Code

Ticket #14694: 14694-defer-reverse-relations-r17462.diff

File 14694-defer-reverse-relations-r17462.diff, 6.0 KB (added by mrmachine, 2 years ago)
Line 
1Index: tests/regressiontests/defer_regress/tests.py
2===================================================================
3--- tests/regressiontests/defer_regress/tests.py        (revision 17462)
4+++ tests/regressiontests/defer_regress/tests.py        (working copy)
5@@ -9,7 +9,7 @@
6 from django.test import TestCase
7 
8 from .models import (ResolveThis, Item, RelatedItem, Child, Leaf, Proxy,
9-    SimpleItem, Feature)
10+    SimpleItem, Feature, OneToOneItem)
11 
12 
13 class DeferRegressionTest(TestCase):
14@@ -110,6 +110,7 @@
15                 Feature,
16                 Item,
17                 Leaf,
18+                OneToOneItem,
19                 Proxy,
20                 RelatedItem,
21                 ResolveThis,
22@@ -141,6 +142,7 @@
23                 "Leaf_Deferred_name_value",
24                 "Leaf_Deferred_second_child_value",
25                 "Leaf_Deferred_value",
26+                "OneToOneItem",
27                 "Proxy",
28                 "RelatedItem",
29                 "RelatedItem_Deferred_",
30@@ -174,3 +176,14 @@
31         qs = ResolveThis.objects.defer('num')
32         self.assertEqual(1, qs.count())
33         self.assertEqual('Foobar', qs[0].name)
34+
35+    def test_reverse_one_to_one_relations(self):
36+        item = Item.objects.create(name="first", value=42)
37+        OneToOneItem.objects.create(item=item, name="second")
38+        self.assertEqual(len(Item.objects.all()), 1)
39+        self.assertEqual(len(Item.objects.defer('one_to_one_item__name')), 1)
40+        self.assertEqual(len(Item.objects.select_related('one_to_one_item')), 1)
41+        self.assertEqual(len(Item.objects.select_related('one_to_one_item').defer('one_to_one_item__name')), 1)
42+        self.assertEqual(len(Item.objects.select_related('one_to_one_item').defer('value')), 1)
43+        # make sure that `only()` doesn't break when we pass in a reverse relation, rather than a field on the relation.
44+        self.assertEqual(len(Item.objects.only('one_to_one_item')), 1)
45Index: tests/regressiontests/defer_regress/models.py
46===================================================================
47--- tests/regressiontests/defer_regress/models.py       (revision 17462)
48+++ tests/regressiontests/defer_regress/models.py       (working copy)
49@@ -47,3 +47,7 @@
50 
51 class Feature(models.Model):
52     item = models.ForeignKey(SimpleItem)
53+
54+class OneToOneItem(models.Model):
55+    item = models.OneToOneField(Item, related_name="one_to_one_item")
56+    name = models.CharField(max_length=15)
57Index: django/db/models/sql/query.py
58===================================================================
59--- django/db/models/sql/query.py       (revision 17462)
60+++ django/db/models/sql/query.py       (working copy)
61@@ -17,6 +17,7 @@
62 from django.db.models.expressions import ExpressionNode
63 from django.db.models.fields import FieldDoesNotExist
64 from django.db.models.query_utils import InvalidQuery
65+from django.db.models.related import RelatedObject
66 from django.db.models.sql import aggregates as base_aggregates_module
67 from django.db.models.sql.constants import *
68 from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
69@@ -586,17 +587,22 @@
70             for name in parts[:-1]:
71                 old_model = cur_model
72                 source = opts.get_field_by_name(name)[0]
73-                cur_model = opts.get_field_by_name(name)[0].rel.to
74+                if isinstance(source, RelatedObject):
75+                    cur_model = source.model
76+                else:
77+                    cur_model = source.rel.to
78                 opts = cur_model._meta
79                 # Even if we're "just passing through" this model, we must add
80                 # both the current model's pk and the related reference field
81-                # to the things we select.
82-                must_include[old_model].add(source)
83+                # (if it's not a reverse relation) to the things we select.
84+                if not isinstance(source, RelatedObject):
85+                    must_include[old_model].add(source)
86                 add_to_dict(must_include, cur_model, opts.pk)
87             field, model, _, _ = opts.get_field_by_name(parts[-1])
88             if model is None:
89                 model = cur_model
90-            add_to_dict(seen, model, field)
91+            if not isinstance(field, RelatedObject):
92+                add_to_dict(seen, model, field)
93 
94         if defer:
95             # We need to load all fields for each model, except those that
96@@ -636,7 +642,6 @@
97             for model, values in seen.iteritems():
98                 callback(target, model, values)
99 
100-
101     def deferred_to_columns_cb(self, target, model, fields):
102         """
103         Callback used by deferred_to_columns(). The "target" parameter should
104@@ -648,7 +653,6 @@
105         for field in fields:
106             target[table].add(field.column)
107 
108-
109     def table_alias(self, table_name, create=False):
110         """
111         Returns a table alias for the given table_name and whether this is a
112Index: django/db/models/query.py
113===================================================================
114--- django/db/models/query.py   (revision 17462)
115+++ django/db/models/query.py   (working copy)
116@@ -1413,9 +1413,14 @@
117                 # If the related object exists, populate
118                 # the descriptor cache.
119                 setattr(rel_obj, f.get_cache_name(), obj)
120-                # Now populate all the non-local field values
121-                # on the related object
122-                for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
123+                # Now populate all the non-local field values on the related
124+                # object. If this object has deferred fields, we need to use
125+                # the opts from the original model to get non-local fields
126+                # correctly.
127+                opts = rel_obj._meta
128+                if getattr(rel_obj, '_deferred'):
129+                    opts = opts.proxy_for_model._meta
130+                for rel_field, rel_model in opts.get_fields_with_model():
131                     if rel_model is not None:
132                         setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
133                         # populate the field cache for any related object