Ticket #13781: select_related_subclass_patch_1.2.X.diff

File select_related_subclass_patch_1.2.X.diff, 8.9 KB (added by ungenio, 3 years ago)

Tests and patch (1.2.X)

  • django/db/models/query.py

    From bbd8e472f24db8a3134687d4deae9cf39faa437c Mon Sep 17 00:00:00 2001
    From: David Bennett <david@dbinit.com>
    Date: Mon, 30 Jan 2012 13:02:13 -0600
    Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
    
    ---
     django/db/models/query.py                          |   16 +++++++++---
     django/db/models/sql/compiler.py                   |   12 ++++++---
     .../select_related_onetoone/models.py              |   26 ++++++++++++++++++++
     .../select_related_onetoone/tests.py               |   23 ++++++++++++++++-
     4 files changed, 68 insertions(+), 9 deletions(-)
    
    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index a2d7ffb..cb50647 100644
    a b class EmptyQuerySet(QuerySet): 
    11421142
    11431143
    11441144def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
    1145                    requested=None, offset=0, only_load=None, local_only=False):
     1145                   requested=None, offset=0, only_load=None, local_only=False,
     1146                   last_klass=None):
    11461147    """
    11471148    Helper function that recursively returns an object with the specified
    11481149    related attributes already populated.
    def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, 
    11721173       the full field list for `klass` can be assumed.
    11731174     * local_only - Only populate local fields. This is used when building
    11741175       following reverse select-related relations
     1176     * last_klass - the last class seen when following reverse
     1177       select-related relations
    11751178    """
    11761179    if max_depth and requested is None and cur_depth > max_depth:
    11771180        # We've recursed deeply enough; stop now.
    def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, 
    12181221    else:
    12191222        # Load all fields on klass
    12201223        if local_only:
    1221             field_names = [f.attname for f in klass._meta.local_fields]
     1224            parents = [p for p in klass._meta.get_parent_list()
     1225                       if p is not last_klass]
     1226            field_names = [f.attname for f in klass._meta.fields
     1227                           if f in klass._meta.local_fields
     1228                           or f.model in parents]
    12221229        else:
    12231230            field_names = [f.attname for f in klass._meta.fields]
    12241231        field_count = len(field_names)
    def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, 
    12771284            next = requested[f.related_query_name()]
    12781285            # Recursively retrieve the data for the related object
    12791286            cached_row = get_cached_row(model, row, index_end, using,
    1280                 max_depth, cur_depth+1, next, only_load=only_load, local_only=True)
     1287                max_depth, cur_depth+1, next, only_load=only_load, local_only=True,
     1288                last_klass=klass)
    12811289            # If the recursive descent found an object, populate the
    12821290            # descriptor caches relevant to the object
    12831291            if cached_row:
    def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, 
    12931301                    # Now populate all the non-local field values
    12941302                    # on the related object
    12951303                    for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
    1296                         if rel_model is not None:
     1304                        if rel_model is not None and isinstance(obj, rel_model):
    12971305                            setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
    12981306                            # populate the field cache for any related object
    12991307                            # that has already been retrieved
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index fb9674c..bcfd014 100644
    a b class SQLCompiler(object): 
    213213        return result
    214214
    215215    def get_default_columns(self, with_aliases=False, col_aliases=None,
    216             start_alias=None, opts=None, as_pairs=False, local_only=False):
     216            start_alias=None, opts=None, as_pairs=False, local_only=False,
     217            last_opts=None):
    217218        """
    218219        Computes the default columns for selecting every field in the base
    219220        model. Will sometimes be called to pull in related models (e.g. via
    class SQLCompiler(object): 
    237238
    238239        if start_alias:
    239240            seen = {None: start_alias}
     241        parents = [p for p in opts.get_parent_list() if p._meta is not last_opts]
    240242        for field, model in opts.get_fields_with_model():
    241             if local_only and model is not None:
     243            if local_only and model is not None and model not in parents:
    242244                continue
    243245            if start_alias:
    244246                try:
    class SQLCompiler(object): 
    249251                    else:
    250252                        link_field = opts.get_ancestor_link(model)
    251253                        alias = self.query.join((start_alias, model._meta.db_table,
    252                                 link_field.column, model._meta.pk.column))
     254                                link_field.column, model._meta.pk.column),
     255                                promote=(model in parents))
    253256                    seen[model] = alias
    254257            else:
    255258                # If we're starting from the base model of the queryset, the
    class SQLCompiler(object): 
    647650                )
    648651                used.add(alias)
    649652                columns, aliases = self.get_default_columns(start_alias=alias,
    650                     opts=model._meta, as_pairs=True, local_only=True)
     653                    opts=model._meta, as_pairs=True, local_only=True,
     654                    last_opts=opts)
    651655                self.query.related_select_cols.extend(columns)
    652656                self.query.related_select_fields.extend(model._meta.fields)
    653657
  • tests/regressiontests/select_related_onetoone/models.py

    diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
    index 3d6da9b..4bfad1d 100644
    a b class StatDetails(models.Model): 
    4545class AdvancedUserStat(UserStat):
    4646    karma = models.IntegerField()
    4747
     48
    4849class Image(models.Model):
    4950    name = models.CharField(max_length=100)
    5051
    class Image(models.Model): 
    5253class Product(models.Model):
    5354    name = models.CharField(max_length=100)
    5455    image = models.OneToOneField(Image, null=True)
     56
     57
     58class Parent1(models.Model):
     59    name1 = models.CharField(max_length=50)
     60    def __unicode__(self):
     61        return self.name1
     62
     63
     64class Parent2(models.Model):
     65    name2 = models.CharField(max_length=50)
     66    def __unicode__(self):
     67        return self.name2
     68
     69
     70class Child1(Parent1, Parent2):
     71    other = models.CharField(max_length=50)
     72    def __unicode__(self):
     73        return self.name1
     74
     75
     76class Child2(Parent1):
     77    parent2 = models.OneToOneField(Parent2)
     78    other = models.CharField(max_length=50)
     79    def __unicode__(self):
     80        return self.name1
  • tests/regressiontests/select_related_onetoone/tests.py

    diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
    index 4ccb584..b2f8549 100644
    a b from django.conf import settings 
    33from django.test import TestCase
    44
    55from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
    6     AdvancedUserStat, Image, Product)
     6    AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2)
    77
    88class ReverseSelectRelatedTestCase(TestCase):
    99    def setUp(self):
    class ReverseSelectRelatedTestCase(TestCase): 
    2525        advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
    2626                                                  results=results2)
    2727        StatDetails.objects.create(base_stats=advstat, comments=250)
     28        p1 = Parent1(name1="Only Parent1")
     29        p1.save()
     30        c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2")
     31        c1.save()
     32        p2 = Parent2(name2="Child2 Parent2")
     33        p2.save()
     34        c2 = Child2(name1="Child2 Parent1", parent2=p2)
     35        c2.save()
    2836
    2937        db.reset_queries()
    3038
    class ReverseSelectRelatedTestCase(TestCase): 
    92100        p2 = Product.objects.create(name="Talking Django Plushie")
    93101
    94102        self.assertEqual(len(Product.objects.select_related("image")), 2)
     103
     104    def test_parent_only(self):
     105        Parent1.objects.select_related('child1').get(name1="Only Parent1")
     106
     107    def test_multiple_subclass(self):
     108        p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
     109        self.assertEqual(p.child1.name2, u'Child1 Parent2')
     110        self.assertQueries(1)
     111
     112    def test_onetoone_with_subclass(self):
     113        p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
     114        self.assertEqual(p.child2.name1, u'Child2 Parent1')
     115        self.assertQueries(1)
Back to Top