Code

Ticket #7270: reverse_select_related.diff

File reverse_select_related.diff, 11.5 KB (added by Alex, 5 years ago)
Line 
1diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
2index 8fec836..8e27244 100644
3--- a/django/db/models/fields/related.py
4+++ b/django/db/models/fields/related.py
5@@ -188,7 +188,7 @@ class SingleRelatedObjectDescriptor(object):
6     # SingleRelatedObjectDescriptor instance.
7     def __init__(self, related):
8         self.related = related
9-        self.cache_name = '_%s_cache' % related.get_accessor_name()
10+        self.cache_name = related.get_accessor_cache()
11 
12     def __get__(self, instance, instance_type=None):
13         if instance is None:
14@@ -307,7 +307,7 @@ class ReverseSingleRelatedObjectDescriptor(object):
15             # cache. This cache also might not exist if the related object
16             # hasn't been accessed yet.
17             if related:
18-                cache_name = '_%s_cache' % self.field.related.get_accessor_name()
19+                cache_name = self.field.related.get_accessor_cache()
20                 try:
21                     delattr(related, cache_name)
22                 except AttributeError:
23diff --git a/django/db/models/query.py b/django/db/models/query.py
24index 4e3326a..afbcc24 100644
25--- a/django/db/models/query.py
26+++ b/django/db/models/query.py
27@@ -1147,6 +1147,24 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
28             rel_obj, index_end = cached_row
29             if obj is not None:
30                 setattr(obj, f.get_cache_name(), rel_obj)
31+            if f.unique:
32+                setattr(rel_obj, f.related.get_accessor_cache(), obj)
33+
34+    if restricted:
35+        related_fields = [(o.field, o.model) for o in klass._meta.get_all_related_objects()
36+            if o.field.unique and o.field.related_query_name() in requested]
37+        for f, model in related_fields:
38+            next = requested.get(f.related_query_name(), {})
39+            cached_row = get_cached_row(model, row, index_end, max_depth,
40+                cur_depth+1, next)
41+            if cached_row:
42+                rel_obj, index_end = cached_row
43+                if obj is not None:
44+                    setattr(obj, f.related.get_accessor_cache(), rel_obj)
45+                if rel_obj is not None:
46+                    setattr(rel_obj, f.get_cache_name(), obj)
47+
48+
49     return obj, index_end
50 
51 def delete_objects(seen_objs, using):
52diff --git a/django/db/models/related.py b/django/db/models/related.py
53index afdf3f7..54258ca 100644
54--- a/django/db/models/related.py
55+++ b/django/db/models/related.py
56@@ -45,3 +45,6 @@ class RelatedObject(object):
57             return self.field.rel.related_name or (self.opts.object_name.lower() + '_set')
58         else:
59             return self.field.rel.related_name or (self.opts.object_name.lower())
60+
61+    def get_accessor_cache(self):
62+        return "_%s_cache" % self.get_accessor_name()
63diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
64index 6a95d32..99d4c7e 100644
65--- a/django/db/models/sql/compiler.py
66+++ b/django/db/models/sql/compiler.py
67@@ -520,7 +520,7 @@ class SQLCompiler(object):
68 
69         # Setup for the case when only particular related fields should be
70         # included in the related selection.
71-        if requested is None and restricted is not False:
72+        if requested is None:
73             if isinstance(self.query.select_related, dict):
74                 requested = self.query.select_related
75                 restricted = True
76@@ -600,6 +600,57 @@ class SQLCompiler(object):
77             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
78                     used, next, restricted, new_nullable, dupe_set, avoid)
79 
80+        if restricted and requested is not None:
81+            related_fields = [(o.field, o.model) for o in opts.get_all_related_objects()
82+                if o.field.unique and o.field.related_query_name() in requested
83+            ]
84+            for f, model in related_fields:
85+                table = model._meta.db_table
86+                int_opts = opts
87+                alias = root_alias
88+                alias_chain = []
89+                chain = opts.get_base_chain(f.rel.to)
90+                avoid = avoid_set.copy()
91+                if chain is not None:
92+                    for int_model in chain:
93+                        if not int_opts.parents[int_model]:
94+                            int_opts = int_model._meta
95+                            continue
96+                        lhs_col = int_opts.parents[int_model].column
97+                        dedupe = lhs_col in opts.duplicate_targets
98+                        if dedupe:
99+                            avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), ())
100+                            dupe_set.add((opts, lhs_col))
101+                        int_opts = int_model._meta
102+                        alias = self.query.join(
103+                            (alias, int_opts.db_table, lhs_col, int_opts.pk.column),
104+                            exclusions=used, promote=True, reuse=used
105+                        )
106+                        alias_chain.append(alias)
107+                        for dupe_opts, dupe_col in dupe_set:
108+                            self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
109+                    dedupe = f.column in opts.duplicate_targets
110+                    if dupe_set or dedupe:
111+                        avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ()))
112+                        if dedupe:
113+                            dupe_set.add((opts, f.column))
114+                alias = self.query.join(
115+                    (alias, table, f.rel.get_related_field().column, f.column),
116+                    exclusions=used.union(avoid),
117+                    promote=True
118+                )
119+                used.add(alias)
120+                columns, aliases = self.get_default_columns(start_alias=alias,
121+                    opts=model._meta, as_pairs=True)
122+                self.query.related_select_cols.extend(columns)
123+                self.query.related_select_fields.extend(model._meta.fields)
124+
125+                next = requested.get(f.related_query_name(), {})
126+                new_nullable = f.null or None
127+
128+                self.fill_related_selections(model._meta, table, cur_depth+1,
129+                    used, next, restricted, new_nullable)
130+
131     def deferred_to_columns(self):
132         """
133         Converts the self.deferred_loading data structure to mapping of table
134diff --git a/tests/regressiontests/select_related_onetoone/__init__.py b/tests/regressiontests/select_related_onetoone/__init__.py
135new file mode 100644
136index 0000000..e69de29
137diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
138new file mode 100644
139index 0000000..05efadf
140--- /dev/null
141+++ b/tests/regressiontests/select_related_onetoone/models.py
142@@ -0,0 +1,44 @@
143+from django.db import models
144+
145+
146+class User(models.Model):
147+    username = models.CharField(max_length=100)
148+    email = models.EmailField()
149+
150+    def __unicode__(self):
151+        return self.username
152+
153+
154+class UserProfile(models.Model):
155+    user = models.OneToOneField(User)
156+    city = models.CharField(max_length=100)
157+    state = models.CharField(max_length=2)
158+
159+    def __unicode__(self):
160+        return "%s, %s" % (self.city, self.state)
161+
162+
163+class UserStatResult(models.Model):
164+    results = models.CharField(max_length=50)
165+
166+    def __unicode__(self):
167+        return 'UserStatResults, results = %s' % (self.results,)
168+   
169+
170+class UserStat(models.Model):
171+    user = models.OneToOneField(User, primary_key=True)
172+    posts = models.IntegerField()
173+    results = models.ForeignKey(UserStatResult)
174+
175+    def __unicode__(self):
176+        return 'UserStat, posts = %s' % (self.posts,)
177+
178+class StatDetails(models.Model):
179+    base_stats = models.OneToOneField(UserStat)
180+    comments = models.IntegerField()
181+
182+    def __unicode__(self):
183+        return 'StatDetails, comments = %s' % (self.comments,)
184+
185+class AdvancedUserStat(UserStat):
186+    pass
187diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
188new file mode 100644
189index 0000000..08e798b
190--- /dev/null
191+++ b/tests/regressiontests/select_related_onetoone/tests.py
192@@ -0,0 +1,81 @@
193+from django import db
194+from django.conf import settings
195+from django.test import TestCase
196+
197+from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat
198+
199+class ReverseSelectRelatedTestCase(TestCase):
200+    def setUp(self):
201+        self.old_debug = settings.DEBUG
202+        settings.DEBUG = True
203+
204+        user = User.objects.create(username="test")
205+        userprofile = UserProfile.objects.create(user=user, state="KS",
206+                                                 city="Lawrence")
207+        results = UserStatResult.objects.create(results='first results')
208+        userstat = UserStat.objects.create(user=user, posts=150,
209+                                           results=results)
210+        details = StatDetails.objects.create(base_stats=userstat, comments=259)
211+
212+        user2 = User.objects.create(username="bob")
213+        results2 = UserStatResult.objects.create(results='moar results')
214+        advstat = AdvancedUserStat.objects.create(user=user2, posts=200,
215+                                                  results=results2)
216+        StatDetails.objects.create(base_stats=advstat, comments=250)
217+
218+        db.reset_queries()
219+
220+    def assertQueries(self, queries):
221+        self.assertEqual(len(db.connection.queries), queries)
222+
223+    def tearDown(self):
224+        settings.DEBUG = self.old_debug
225+
226+    def test_basic(self):
227+        u = User.objects.select_related("userprofile").get(username="test")
228+        self.assertEqual(u.userprofile.state, "KS")
229+        self.assertQueries(1)
230+
231+    def test_follow_next_level(self):
232+        u = User.objects.select_related("userstat__results").get(username="test")
233+        self.assertEqual(u.userstat.posts, 150)
234+        self.assertEqual(u.userstat.results.results, 'first results')
235+        self.assertQueries(1)
236+
237+    def test_follow_two(self):
238+        u = User.objects.select_related("userprofile", "userstat").get(username="test")
239+        self.assertEqual(u.userprofile.state, "KS")
240+        self.assertEqual(u.userstat.posts, 150)
241+        self.assertQueries(1)
242+
243+    def test_follow_two_next_level(self):
244+        u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
245+        self.assertEqual(u.userstat.results.results, 'first results')
246+        self.assertEqual(u.userstat.statdetails.comments, 259)
247+        self.assertQueries(1)
248+
249+    def test_forward_and_back(self):
250+        stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
251+        self.assertEqual(stat.user.userprofile.state, 'KS')
252+        self.assertEqual(stat.user.userstat.posts, 150)
253+        self.assertQueries(1)
254+
255+    def test_back_and_forward(self):
256+        u = User.objects.select_related("userstat").get(username="test")
257+        self.assertEqual(u.userstat.user.username, 'test')
258+        self.assertQueries(1)
259+
260+    def test_not_followed_by_default(self):
261+        u = User.objects.select_related().get(username="test")
262+        self.assertEqual(u.userstat.posts, 150)
263+        self.assertQueries(2)
264+
265+    def test_follow_from_child_class(self):
266+        stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200)
267+        self.assertEqual(stat.statdetails.comments, 250)
268+        self.assertQueries(1)
269+
270+    def test_follow_inheritance(self):
271+        stat = UserStat.objects.select_related('advanceduserstat').get(posts=200)
272+        self.assertEqual(stat.advanceduserstat.posts, 200)
273+        self.assertQueries(1)