Ticket #13781: 13781-master.patch

File 13781-master.patch, 9.4 KB (added by Tomáš Ehrlich, 12 years ago)

Diff against latest master

  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 0210a79..e144956 100644
    a b class EmptyQuerySet(QuerySet):  
    12761276    value_annotation = False
    12771277
    12781278def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
    1279                    only_load=None, local_only=False):
     1279                   only_load=None, local_only=False, last_klass=None):
    12801280    """
    12811281    Helper function that recursively returns an information for a klass, to be
    12821282    used in get_cached_row.  It exists just to compute this information only
    def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,  
    12981298       the full field list for `klass` can be assumed.
    12991299     * local_only - Only populate local fields. This is used when
    13001300       following reverse select-related relations
     1301     * last_klass - the last class seen when following reverse
     1302       select-related relations
    13011303    """
    13021304    if max_depth and requested is None and cur_depth > max_depth:
    13031305        # We've recursed deeply enough; stop now.
    def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,  
    13431345        # But kwargs version of Model.__init__ is slower, so we should avoid using
    13441346        # it when it is not really neccesary.
    13451347        if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
    1346             field_count = len(klass._meta.local_fields)
    1347             field_names = [f.attname for f in klass._meta.local_fields]
     1348            parents = [p for p in klass._meta.get_parent_list()
     1349                       if p is not last_klass]
     1350            field_names = [f.attname for f in klass._meta.fields
     1351                           if f in klass._meta.local_fields
     1352                           or f.model in parents]
     1353            field_count = len(field_names)
     1354            if field_count == len(klass._meta.fields):
     1355                field_names = ()
    13481356        else:
    13491357            field_count = len(klass._meta.fields)
    13501358            field_names = ()
    def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,  
    13691377                                                         only_load.get(o.model), reverse=True):
    13701378                next = requested[o.field.related_query_name()]
    13711379                klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
    1372                                             requested=next, only_load=only_load, local_only=True)
     1380                                            requested=next, only_load=only_load, local_only=True,
     1381                                            last_klass=klass)
    13731382                reverse_related_fields.append((o.field, klass_info))
    13741383
    13751384    return klass, field_names, field_count, related_fields, reverse_related_fields
    def get_cached_row(row, index_start, using, klass_info, offset=0):  
    14551464                # Now populate all the non-local field values
    14561465                # on the related object
    14571466                for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
    1458                     if rel_model is not None:
     1467                    if rel_model is not None and isinstance(obj, rel_model):
    14591468                        setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
    14601469                        # populate the field cache for any related object
    14611470                        # that has already been retrieved
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index a68f6e0..b8cf74d 100644
    a b class SQLCompiler(object):  
    249249        return result
    250250
    251251    def get_default_columns(self, with_aliases=False, col_aliases=None,
    252             start_alias=None, opts=None, as_pairs=False, local_only=False):
     252            start_alias=None, opts=None, as_pairs=False, local_only=False,
     253            last_opts=None):
    253254        """
    254255        Computes the default columns for selecting every field in the base
    255256        model. Will sometimes be called to pull in related models (e.g. via
    class SQLCompiler(object):  
    273274
    274275        if start_alias:
    275276            seen = {None: start_alias}
     277        parents = [p for p in opts.get_parent_list() if p._meta is not last_opts]
    276278        for field, model in opts.get_fields_with_model():
    277             if local_only and model is not None:
     279            if local_only and model is not None and model not in parents:
    278280                continue
    279281            if start_alias:
    280282                try:
    class SQLCompiler(object):  
    282284                except KeyError:
    283285                    link_field = opts.get_ancestor_link(model)
    284286                    alias = self.query.join((start_alias, model._meta.db_table,
    285                             link_field.column, model._meta.pk.column))
     287                            link_field.column, model._meta.pk.column),
     288                            promote=(model in parents))
    286289                    seen[model] = alias
    287290            else:
    288291                # If we're starting from the base model of the queryset, the
    class SQLCompiler(object):  
    733736                )
    734737                used.add(alias)
    735738                columns, aliases = self.get_default_columns(start_alias=alias,
    736                     opts=model._meta, as_pairs=True, local_only=True)
     739                    opts=model._meta, as_pairs=True, local_only=True,
     740                    last_opts=opts)
    737741                self.query.related_select_cols.extend(columns)
    738742                self.query.related_select_fields.extend(model._meta.fields)
    739743
  • new file tests/django

    diff --git a/tests/django b/tests/django
    new file mode 120000
    index 0000000..8016dee
    - +  
     1../django
     2 No newline at end of file
  • tests/regressiontests/select_related_onetoone/models.py

    diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
    index 3284def..6216e0c 100644
    a b class StatDetails(models.Model):  
    5151class AdvancedUserStat(UserStat):
    5252    karma = models.IntegerField()
    5353
     54
    5455class Image(models.Model):
    5556    name = models.CharField(max_length=100)
    5657
    class Image(models.Model):  
    5859class Product(models.Model):
    5960    name = models.CharField(max_length=100)
    6061    image = models.OneToOneField(Image, null=True)
     62
     63
     64class Parent1(models.Model):
     65    name1 = models.CharField(max_length=50)
     66    def __unicode__(self):
     67        return self.name1
     68
     69
     70class Parent2(models.Model):
     71    name2 = models.CharField(max_length=50)
     72    def __unicode__(self):
     73        return self.name2
     74
     75
     76class Child1(Parent1, Parent2):
     77    other = models.CharField(max_length=50)
     78    def __unicode__(self):
     79        return self.name1
     80
     81
     82class Child2(Parent1):
     83    parent2 = models.OneToOneField(Parent2)
     84    other = models.CharField(max_length=50)
     85    def __unicode__(self):
     86        return self.name1
  • tests/regressiontests/select_related_onetoone/tests.py

    diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
    index 1373f04..3c8623f 100644
    a b from __future__ import absolute_import  
    33from django.test import TestCase
    44
    55from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
    6     AdvancedUserStat, Image, Product)
     6    AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2)
    77
    88
    99class ReverseSelectRelatedTestCase(TestCase):
    class ReverseSelectRelatedTestCase(TestCase):  
    2121        advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
    2222                                                  results=results2)
    2323        StatDetails.objects.create(base_stats=advstat, comments=250)
     24        p1 = Parent1(name1="Only Parent1")
     25        p1.save()
     26        c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2")
     27        c1.save()
     28        p2 = Parent2(name2="Child2 Parent2")
     29        p2.save()
     30        c2 = Child2(name1="Child2 Parent1", parent2=p2)
     31        c2.save()
    2432
    2533    def test_basic(self):
    2634        with self.assertNumQueries(1):
    class ReverseSelectRelatedTestCase(TestCase):  
    7987        p1 = Product.objects.create(name="Django Plushie", image=im)
    8088        p2 = Product.objects.create(name="Talking Django Plushie")
    8189
     90        self.assertEqual(len(Product.objects.select_related("image")), 2)
     91
    8292        with self.assertNumQueries(1):
    8393            result = sorted(Product.objects.select_related("image"), key=lambda x: x.name)
    8494            self.assertEqual([p.name for p in result], ["Django Plushie", "Talking Django Plushie"])
    class ReverseSelectRelatedTestCase(TestCase):  
    108118            image = Image.objects.select_related('product').get()
    109119            with self.assertRaises(Product.DoesNotExist):
    110120                image.product
     121
     122    def test_parent_only(self):
     123        Parent1.objects.select_related('child1').get(name1="Only Parent1")
     124
     125    def test_multiple_subclass(self):
     126        with self.assertNumQueries(1):
     127            p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
     128            self.assertEqual(p.child1.name2, u'Child1 Parent2')
     129
     130    def test_onetoone_with_subclass(self):
     131        with self.assertNumQueries(1):
     132            p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
     133            self.assertEqual(p.child2.name1, u'Child2 Parent1')
Back to Top