Code

Ticket #13781: select_related_subclass_patch_1.2.X.diff

File select_related_subclass_patch_1.2.X.diff, 8.9 KB (added by ungenio, 2 years ago)

Tests and patch (1.2.X)

Line 
1From bbd8e472f24db8a3134687d4deae9cf39faa437c Mon Sep 17 00:00:00 2001
2From: David Bennett <david@dbinit.com>
3Date: Mon, 30 Jan 2012 13:02:13 -0600
4Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
5
6---
7 django/db/models/query.py                          |   16 +++++++++---
8 django/db/models/sql/compiler.py                   |   12 ++++++---
9 .../select_related_onetoone/models.py              |   26 ++++++++++++++++++++
10 .../select_related_onetoone/tests.py               |   23 ++++++++++++++++-
11 4 files changed, 68 insertions(+), 9 deletions(-)
12
13diff --git a/django/db/models/query.py b/django/db/models/query.py
14index a2d7ffb..cb50647 100644
15--- a/django/db/models/query.py
16+++ b/django/db/models/query.py
17@@ -1142,7 +1142,8 @@ class EmptyQuerySet(QuerySet):
18 
19 
20 def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
21-                   requested=None, offset=0, only_load=None, local_only=False):
22+                   requested=None, offset=0, only_load=None, local_only=False,
23+                   last_klass=None):
24     """
25     Helper function that recursively returns an object with the specified
26     related attributes already populated.
27@@ -1172,6 +1173,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
28        the full field list for `klass` can be assumed.
29      * local_only - Only populate local fields. This is used when building
30        following reverse select-related relations
31+     * last_klass - the last class seen when following reverse
32+       select-related relations
33     """
34     if max_depth and requested is None and cur_depth > max_depth:
35         # We've recursed deeply enough; stop now.
36@@ -1218,7 +1221,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
37     else:
38         # Load all fields on klass
39         if local_only:
40-            field_names = [f.attname for f in klass._meta.local_fields]
41+            parents = [p for p in klass._meta.get_parent_list()
42+                       if p is not last_klass]
43+            field_names = [f.attname for f in klass._meta.fields
44+                           if f in klass._meta.local_fields
45+                           or f.model in parents]
46         else:
47             field_names = [f.attname for f in klass._meta.fields]
48         field_count = len(field_names)
49@@ -1277,7 +1284,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
50             next = requested[f.related_query_name()]
51             # Recursively retrieve the data for the related object
52             cached_row = get_cached_row(model, row, index_end, using,
53-                max_depth, cur_depth+1, next, only_load=only_load, local_only=True)
54+                max_depth, cur_depth+1, next, only_load=only_load, local_only=True,
55+                last_klass=klass)
56             # If the recursive descent found an object, populate the
57             # descriptor caches relevant to the object
58             if cached_row:
59@@ -1293,7 +1301,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
60                     # Now populate all the non-local field values
61                     # on the related object
62                     for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
63-                        if rel_model is not None:
64+                        if rel_model is not None and isinstance(obj, rel_model):
65                             setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
66                             # populate the field cache for any related object
67                             # that has already been retrieved
68diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
69index fb9674c..bcfd014 100644
70--- a/django/db/models/sql/compiler.py
71+++ b/django/db/models/sql/compiler.py
72@@ -213,7 +213,8 @@ class SQLCompiler(object):
73         return result
74 
75     def get_default_columns(self, with_aliases=False, col_aliases=None,
76-            start_alias=None, opts=None, as_pairs=False, local_only=False):
77+            start_alias=None, opts=None, as_pairs=False, local_only=False,
78+            last_opts=None):
79         """
80         Computes the default columns for selecting every field in the base
81         model. Will sometimes be called to pull in related models (e.g. via
82@@ -237,8 +238,9 @@ class SQLCompiler(object):
83 
84         if start_alias:
85             seen = {None: start_alias}
86+        parents = [p for p in opts.get_parent_list() if p._meta is not last_opts]
87         for field, model in opts.get_fields_with_model():
88-            if local_only and model is not None:
89+            if local_only and model is not None and model not in parents:
90                 continue
91             if start_alias:
92                 try:
93@@ -249,7 +251,8 @@ class SQLCompiler(object):
94                     else:
95                         link_field = opts.get_ancestor_link(model)
96                         alias = self.query.join((start_alias, model._meta.db_table,
97-                                link_field.column, model._meta.pk.column))
98+                                link_field.column, model._meta.pk.column),
99+                                promote=(model in parents))
100                     seen[model] = alias
101             else:
102                 # If we're starting from the base model of the queryset, the
103@@ -647,7 +650,8 @@ class SQLCompiler(object):
104                 )
105                 used.add(alias)
106                 columns, aliases = self.get_default_columns(start_alias=alias,
107-                    opts=model._meta, as_pairs=True, local_only=True)
108+                    opts=model._meta, as_pairs=True, local_only=True,
109+                    last_opts=opts)
110                 self.query.related_select_cols.extend(columns)
111                 self.query.related_select_fields.extend(model._meta.fields)
112 
113diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
114index 3d6da9b..4bfad1d 100644
115--- a/tests/regressiontests/select_related_onetoone/models.py
116+++ b/tests/regressiontests/select_related_onetoone/models.py
117@@ -45,6 +45,7 @@ class StatDetails(models.Model):
118 class AdvancedUserStat(UserStat):
119     karma = models.IntegerField()
120 
121+
122 class Image(models.Model):
123     name = models.CharField(max_length=100)
124 
125@@ -52,3 +53,28 @@ class Image(models.Model):
126 class Product(models.Model):
127     name = models.CharField(max_length=100)
128     image = models.OneToOneField(Image, null=True)
129+
130+
131+class Parent1(models.Model):
132+    name1 = models.CharField(max_length=50)
133+    def __unicode__(self):
134+        return self.name1
135+
136+
137+class Parent2(models.Model):
138+    name2 = models.CharField(max_length=50)
139+    def __unicode__(self):
140+        return self.name2
141+
142+
143+class Child1(Parent1, Parent2):
144+    other = models.CharField(max_length=50)
145+    def __unicode__(self):
146+        return self.name1
147+
148+
149+class Child2(Parent1):
150+    parent2 = models.OneToOneField(Parent2)
151+    other = models.CharField(max_length=50)
152+    def __unicode__(self):
153+        return self.name1
154diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
155index 4ccb584..b2f8549 100644
156--- a/tests/regressiontests/select_related_onetoone/tests.py
157+++ b/tests/regressiontests/select_related_onetoone/tests.py
158@@ -3,7 +3,7 @@ from django.conf import settings
159 from django.test import TestCase
160 
161 from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
162-    AdvancedUserStat, Image, Product)
163+    AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2)
164 
165 class ReverseSelectRelatedTestCase(TestCase):
166     def setUp(self):
167@@ -25,6 +25,14 @@ class ReverseSelectRelatedTestCase(TestCase):
168         advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
169                                                   results=results2)
170         StatDetails.objects.create(base_stats=advstat, comments=250)
171+        p1 = Parent1(name1="Only Parent1")
172+        p1.save()
173+        c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2")
174+        c1.save()
175+        p2 = Parent2(name2="Child2 Parent2")
176+        p2.save()
177+        c2 = Child2(name1="Child2 Parent1", parent2=p2)
178+        c2.save()
179 
180         db.reset_queries()
181 
182@@ -92,3 +100,16 @@ class ReverseSelectRelatedTestCase(TestCase):
183         p2 = Product.objects.create(name="Talking Django Plushie")
184 
185         self.assertEqual(len(Product.objects.select_related("image")), 2)
186+
187+    def test_parent_only(self):
188+        Parent1.objects.select_related('child1').get(name1="Only Parent1")
189+
190+    def test_multiple_subclass(self):
191+        p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
192+        self.assertEqual(p.child1.name2, u'Child1 Parent2')
193+        self.assertQueries(1)
194+
195+    def test_onetoone_with_subclass(self):
196+        p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
197+        self.assertEqual(p.child2.name1, u'Child2 Parent1')
198+        self.assertQueries(1)
199--
2001.7.7.3
201