Code

Ticket #10695: defer.6.diff

File defer.6.diff, 9.2 KB (added by Alex, 5 years ago)
Line 
1diff --git a/django/contrib/gis/tests/relatedapp/tests.py b/django/contrib/gis/tests/relatedapp/tests.py
2index 3d162b0..acc616c 100644
3--- a/django/contrib/gis/tests/relatedapp/tests.py
4+++ b/django/contrib/gis/tests/relatedapp/tests.py
5@@ -184,12 +184,12 @@ class RelatedGeoModelTest(unittest.TestCase):
6             self.assertEqual(m.point, t[1])
7 
8     # Test disabled until #10572 is resolved.
9-    #def test08_defer_only(self):
10-    #    "Testing defer() and only() on Geographic models."
11-    #    qs = Location.objects.all()
12-    #    def_qs = Location.objects.defer('point')
13-    #    for loc, def_loc in zip(qs, def_qs):
14-    #        self.assertEqual(loc.point, def_loc.point)
15+    def test08_defer_only(self):
16+        "Testing defer() and only() on Geographic models."
17+        qs = Location.objects.all()
18+        def_qs = Location.objects.defer('point')
19+        for loc, def_loc in zip(qs, def_qs):
20+            self.assertEqual(loc.point, def_loc.point)
21 
22     # TODO: Related tests for KML, GML, and distance lookups.
23 
24diff --git a/django/db/models/base.py b/django/db/models/base.py
25index 01e2ca7..05cd0d9 100644
26--- a/django/db/models/base.py
27+++ b/django/db/models/base.py
28@@ -362,9 +362,8 @@ class Model(object):
29                     # DeferredAttribute classes, so we only need to do this
30                     # once.
31                     obj = self.__class__.__dict__[field.attname]
32-                    pk_val = obj.pk_value
33                     model = obj.model_ref()
34-        return (model_unpickle, (model, pk_val, defers), data)
35+        return (model_unpickle, (model, defers), data)
36 
37     def _get_pk_val(self, meta=None):
38         if not meta:
39@@ -635,12 +634,12 @@ def get_absolute_url(opts, func, self, *args, **kwargs):
40 class Empty(object):
41     pass
42 
43-def model_unpickle(model, pk_val, attrs):
44+def model_unpickle(model, attrs):
45     """
46     Used to unpickle Model subclasses with deferred fields.
47     """
48     from django.db.models.query_utils import deferred_class_factory
49-    cls = deferred_class_factory(model, pk_val, attrs)
50+    cls = deferred_class_factory(model, attrs)
51     return cls.__new__(cls)
52 model_unpickle.__safe_for_unpickle__ = True
53 
54diff --git a/django/db/models/query.py b/django/db/models/query.py
55index ea7129b..9dcc031 100644
56--- a/django/db/models/query.py
57+++ b/django/db/models/query.py
58@@ -190,6 +190,20 @@ class QuerySet(object):
59         index_start = len(extra_select)
60         aggregate_start = index_start + len(self.model._meta.fields)
61 
62+        load_fields = only_load.get(self.model)
63+        skip = None
64+        if load_fields and not fill_cache:
65+            # Some fields have been deferred, so we have to initialise
66+            # via keyword arguments.
67+            skip = set()
68+            init_list = []
69+            for field in fields:
70+                if field.name not in load_fields:
71+                    skip.add(field.attname)
72+                else:
73+                    init_list.append(field.attname)
74+            model_cls = deferred_class_factory(self.model, skip)
75+
76         for row in self.query.results_iter():
77             if fill_cache:
78                 obj, _ = get_cached_row(self.model, row,
79@@ -197,25 +211,10 @@ class QuerySet(object):
80                             requested=requested, offset=len(aggregate_select),
81                             only_load=only_load)
82             else:
83-                load_fields = only_load.get(self.model)
84-                if load_fields:
85-                    # Some fields have been deferred, so we have to initialise
86-                    # via keyword arguments.
87+                if skip:
88                     row_data = row[index_start:aggregate_start]
89                     pk_val = row_data[pk_idx]
90-                    skip = set()
91-                    init_list = []
92-                    for field in fields:
93-                        if field.name not in load_fields:
94-                            skip.add(field.attname)
95-                        else:
96-                            init_list.append(field.attname)
97-                    if skip:
98-                        model_cls = deferred_class_factory(self.model, pk_val,
99-                                skip)
100-                        obj = model_cls(**dict(zip(init_list, row_data)))
101-                    else:
102-                        obj = self.model(*row[index_start:aggregate_start])
103+                    obj = model_cls(**dict(zip(init_list, row_data)))
104                 else:
105                     # Omit aggregates in object creation.
106                     obj = self.model(*row[index_start:aggregate_start])
107@@ -927,7 +926,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
108                 else:
109                     init_list.append(field.attname)
110             if skip:
111-                klass = deferred_class_factory(klass, pk_val, skip)
112+                klass = deferred_class_factory(klass, skip)
113                 obj = klass(**dict(zip(init_list, fields)))
114             else:
115                 obj = klass(*fields)
116diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
117index 8baa654..2a84c35 100644
118--- a/django/db/models/query_utils.py
119+++ b/django/db/models/query_utils.py
120@@ -158,9 +158,8 @@ class DeferredAttribute(object):
121     A wrapper for a deferred-loading field. When the value is read from this
122     object the first time, the query is executed.
123     """
124-    def __init__(self, field_name, pk_value, model):
125+    def __init__(self, field_name, model):
126         self.field_name = field_name
127-        self.pk_value = pk_value
128         self.model_ref = weakref.ref(model)
129         self.loaded = False
130 
131@@ -170,21 +169,17 @@ class DeferredAttribute(object):
132         Returns the cached value.
133         """
134         assert instance is not None
135-        if not self.loaded:
136-            obj = self.model_ref()
137-            if obj is None:
138-                return
139-            self.value = list(obj._base_manager.filter(pk=self.pk_value).values_list(self.field_name, flat=True))[0]
140-            self.loaded = True
141-        return self.value
142-
143-    def __set__(self, name, value):
144+        cls = self.model_ref()
145+        if self.field_name not in instance.__dict__:
146+            instance.__dict__[self.field_name] =  cls._base_manager.filter(pk=instance.pk).values_list(self.field_name, flat=True).get()
147+        return instance.__dict__[self.field_name]
148+
149+    def __set__(self, instance, value):
150         """
151         Deferred loading attributes can be set normally (which means there will
152         never be a database lookup involved.
153         """
154-        self.value = value
155-        self.loaded = True
156+        instance.__dict__[self.field_name] = value
157 
158 def select_related_descend(field, restricted, requested):
159     """
160@@ -206,7 +201,7 @@ def select_related_descend(field, restricted, requested):
161 # This function is needed because data descriptors must be defined on a class
162 # object, not an instance, to have any effect.
163 
164-def deferred_class_factory(model, pk_value, attrs):
165+def deferred_class_factory(model, attrs):
166     """
167     Returns a class object that is a copy of "model" with the specified "attrs"
168     being replaced with DeferredAttribute objects. The "pk_value" ties the
169@@ -223,7 +218,7 @@ def deferred_class_factory(model, pk_value, attrs):
170     # are identical.
171     name = "%s_Deferred_%s" % (model.__name__, '_'.join(sorted(list(attrs))))
172 
173-    overrides = dict([(attr, DeferredAttribute(attr, pk_value, model))
174+    overrides = dict([(attr, DeferredAttribute(attr, model))
175             for attr in attrs])
176     overrides["Meta"] = Meta
177     overrides["__module__"] = model.__module__
178@@ -233,4 +228,3 @@ def deferred_class_factory(model, pk_value, attrs):
179 # The above function is also used to unpickle model instances with deferred
180 # fields.
181 deferred_class_factory.__safe_for_unpickling__ = True
182-
183diff --git a/tests/regressiontests/defer_regress/models.py b/tests/regressiontests/defer_regress/models.py
184index c46d7ce..5f51513 100644
185--- a/tests/regressiontests/defer_regress/models.py
186+++ b/tests/regressiontests/defer_regress/models.py
187@@ -6,7 +6,7 @@ from django.conf import settings
188 from django.db import connection, models
189 
190 class Item(models.Model):
191-    name = models.CharField(max_length=10)
192+    name = models.CharField(max_length=15)
193     text = models.TextField(default="xyzzy")
194     value = models.IntegerField()
195     other_value = models.IntegerField(default=0)
196@@ -14,6 +14,9 @@ class Item(models.Model):
197     def __unicode__(self):
198         return self.name
199 
200+class RelatedItem(models.Model):
201+    item = models.ForeignKey(Item)
202+
203 __test__ = {"regression_tests": """
204 Deferred fields should really be deferred and not accidentally use the field's
205 default value just because they aren't passed to __init__.
206@@ -39,9 +42,26 @@ True
207 u"xyzzy"
208 >>> len(connection.queries) == num + 2      # Effect of text lookup.
209 True
210+>>> obj.text
211+u"xyzzy"
212+>>> len(connection.queries) == num + 2
213+True
214+
215+>>> i = Item.objects.create(name="no I'm first", value=37)
216+>>> items = Item.objects.only('value').order_by('-value')
217+>>> items[0].name
218+u'first'
219+>>> items[1].name
220+u"no I'm first"
221+
222+>>> _ = RelatedItem.objects.create(item=i)
223+>>> r = RelatedItem.objects.defer('item').get()
224+>>> r.item_id == i.id
225+True
226+>>> r.item == i
227+True
228 
229 >>> settings.DEBUG = False
230 
231 """
232 }
233-