Code

Ticket #17001: prefetch_extensions.diff

File prefetch_extensions.diff, 35.3 KB (added by akaariai, 3 years ago)

Now with settings.DEBUG removed

Line 
1diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py
2index c513787..1cdc572 100644
3--- a/django/contrib/contenttypes/generic.py
4+++ b/django/contrib/contenttypes/generic.py
5@@ -62,7 +62,9 @@ class GenericForeignKey(object):
6             # This should never happen. I love comments like this, don't you?
7             raise Exception("Impossible arguments to GFK.get_content_type!")
8 
9-    def get_prefetch_query_set(self, instances):
10+    def get_prefetch_query_set(self, instances, custom_qs=None):
11+        if custom_qs is not None:
12+            raise ValueError("Custom queryset can't be used for this lookup")
13         # For efficiency, group the instances by content type and then do one
14         # query per model
15         fk_dict = defaultdict(set)
16@@ -320,20 +322,34 @@ def create_generic_related_manager(superclass):
17                 db = self._db or router.db_for_read(self.model, instance=self.instance)
18                 return super(GenericRelatedObjectManager, self).get_query_set().using(db).filter(**self.core_filters)
19 
20-        def get_prefetch_query_set(self, instances):
21-            db = self._db or router.db_for_read(self.model)
22+        def get_prefetch_query_set(self, instances, custom_qs=None):
23+            if not instances:
24+                return self.model._default_manager.none()
25             query = {
26                 '%s__pk' % self.content_type_field_name: self.content_type.id,
27                 '%s__in' % self.object_id_field_name:
28                     set(obj._get_pk_val() for obj in instances)
29-                }
30-            qs = super(GenericRelatedObjectManager, self).get_query_set().using(db).filter(**query)
31+            }
32+            if custom_qs is not None:
33+                qs = custom_qs.filter(**query)
34+            else:
35+                db = self._db or router.db_for_read(self.model, instance=instances[0])
36+                qs = super(GenericRelatedObjectManager, self).get_query_set()\
37+                         .using(db).filter(**query)
38             return (qs,
39                     attrgetter(self.object_id_field_name),
40                     lambda obj: obj._get_pk_val(),
41                     False,
42                     self.prefetch_cache_name)
43 
44+
45+        def all(self):
46+            try:
47+                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
48+            except (AttributeError, KeyError):
49+                return super(GenericRelatedObjectManager, self).all()
50+
51+
52         def add(self, *objs):
53             for obj in objs:
54                 if not isinstance(obj, self.model):
55diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
56index 3582720..fd07ca6 100644
57--- a/django/db/models/__init__.py
58+++ b/django/db/models/__init__.py
59@@ -4,14 +4,17 @@ from django.db import connection
60 from django.db.models.loading import get_apps, get_app, get_models, get_model, register_models
61 from django.db.models.query import Q
62 from django.db.models.expressions import F
63+from django.db.models.related import R
64 from django.db.models.manager import Manager
65 from django.db.models.base import Model
66 from django.db.models.aggregates import *
67 from django.db.models.fields import *
68 from django.db.models.fields.subclassing import SubfieldBase
69 from django.db.models.fields.files import FileField, ImageField
70-from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
71-from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError
72+from django.db.models.fields.related import (ForeignKey, OneToOneField,
73+        ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel)
74+from django.db.models.deletion import (CASCADE, PROTECT, SET, SET_NULL,
75+        SET_DEFAULT, DO_NOTHING, ProtectedError)
76 from django.db.models import signals
77 from django.utils.decorators import wraps
78 
79diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
80index 8c054e7..a0748a1 100644
81--- a/django/db/models/fields/related.py
82+++ b/django/db/models/fields/related.py
83@@ -236,7 +236,13 @@ class SingleRelatedObjectDescriptor(object):
84         db = router.db_for_read(self.related.model, **db_hints)
85         return self.related.model._base_manager.using(db)
86 
87-    def get_prefetch_query_set(self, instances):
88+    def get_prefetch_query_set(self, instances, custom_qs=None):
89+        if custom_qs is not None:
90+            # TODO: This error message is too SQLish, and might be downright
91+            # wrong.
92+            raise ValueError(
93+                "Custom querysets can't be used for one-to-one relations")
94+
95         vals = set(instance._get_pk_val() for instance in instances)
96         params = {'%s__pk__in' % self.related.field.name: vals}
97         return (self.get_query_set(),
98@@ -315,7 +321,13 @@ class ReverseSingleRelatedObjectDescriptor(object):
99         else:
100             return QuerySet(self.field.rel.to).using(db)
101 
102-    def get_prefetch_query_set(self, instances):
103+    def get_prefetch_query_set(self, instances, custom_qs=None):
104+        if custom_qs is not None:
105+            # TODO: This error message is too SQLish, and I am not even sure
106+            # this desriptor is used for m2o...
107+            raise ValueError(
108+                "Custom querysets can't be used for many-to-one relations")
109+
110         vals = set(getattr(instance, self.field.attname) for instance in instances)
111         other_field = self.field.rel.get_related_field()
112         if other_field.rel:
113@@ -460,17 +472,31 @@ class ForeignRelatedObjectsDescriptor(object):
114                     db = self._db or router.db_for_read(self.model, instance=self.instance)
115                     return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
116 
117-            def get_prefetch_query_set(self, instances):
118-                db = self._db or router.db_for_read(self.model)
119+            def get_prefetch_query_set(self, instances, custom_qs=None):
120+                """
121+                Return a queryset that does the bulk lookup needed
122+                by prefetch_related functionality.
123+                """
124                 query = {'%s__%s__in' % (rel_field.name, attname):
125-                             set(getattr(obj, attname) for obj in instances)}
126-                qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
127+                            set(getattr(obj, attname) for obj in instances)}
128+                if custom_qs is not None:
129+                    qs = custom_qs.filter(**query)
130+                else:
131+                    db = self._db or router.db_for_read(self.model)
132+                    qs = super(RelatedManager, self).get_query_set().\
133+                                    using(db).filter(**query)
134                 return (qs,
135                         attrgetter(rel_field.get_attname()),
136                         attrgetter(attname),
137                         False,
138                         rel_field.related_query_name())
139 
140+            def all(self):
141+                try:
142+                    return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
143+                except (AttributeError, KeyError):
144+                    return super(RelatedManager, self).all()
145+
146             def add(self, *objs):
147                 for obj in objs:
148                     if not isinstance(obj, self.model):
149@@ -542,24 +568,39 @@ def create_many_related_manager(superclass, rel):
150                 db = self._db or router.db_for_read(self.instance.__class__, instance=self.instance)
151                 return super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**self.core_filters)
152 
153-        def get_prefetch_query_set(self, instances):
154+        def get_prefetch_query_set(self, instances, custom_qs=None):
155             from django.db import connections
156             db = self._db or router.db_for_read(self.model)
157             query = {'%s__pk__in' % self.query_field_name:
158-                         set(obj._get_pk_val() for obj in instances)}
159-            qs = super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query)
160+                                  set(obj._get_pk_val() for obj in instances)}
161+
162+            if custom_qs is not None:
163+                qs = custom_qs._next_is_sticky().filter(**query)
164+            else:
165+                qs = (super(ManyRelatedManager, self).get_query_set().using(db)
166+                      ._next_is_sticky().filter(**query))
167 
168             # M2M: need to annotate the query in order to get the primary model
169-            # that the secondary model was actually related to. We know that
170-            # there will already be a join on the join table, so we can just add
171-            # the select.
172+            # that the secondary model was actually related to.
173+
174+            # We know that there will already be a join on the join table, so we
175+            # can just add the select.
176 
177             # For non-autocreated 'through' models, can't assume we are
178             # dealing with PK values.
179+
180+            # TODO: This is at the wrong level of abstraction. We should not
181+            # be generating SQL here, but instead maybe pass this information
182+            # to the connection. NoSQL camp will have problems with this, for
183+            # example.
184             fk = self.through._meta.get_field(self.source_field_name)
185             source_col = fk.column
186             join_table = self.through._meta.db_table
187-            connection = connections[db]
188+            if custom_qs is not None:
189+                connection = connections[custom_qs.db]
190+            else:
191+                connection = connections[db]
192+
193             qn = connection.ops.quote_name
194             qs = qs.extra(select={'_prefetch_related_val':
195                                       '%s.%s' % (qn(join_table), qn(source_col))})
196@@ -570,6 +611,12 @@ def create_many_related_manager(superclass, rel):
197                     False,
198                     self.prefetch_cache_name)
199 
200+        def all(self):
201+            try:
202+                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
203+            except (AttributeError, KeyError):
204+                return super(ManyRelatedManager, self).all()
205+
206         # If the ManyToMany relation has an intermediary model,
207         # the add and remove methods do not exist.
208         if rel.through._meta.auto_created:
209diff --git a/django/db/models/query.py b/django/db/models/query.py
210index be42d02..f5cf2ae 100644
211--- a/django/db/models/query.py
212+++ b/django/db/models/query.py
213@@ -11,6 +11,7 @@ from django.db.models.query_utils import (Q, select_related_descend,
214     deferred_class_factory, InvalidQuery)
215 from django.db.models.deletion import Collector
216 from django.db.models import sql
217+from django.db.models.related import R
218 from django.utils.functional import partition
219 
220 # Used to control how many objects are worked with at once in some cases (e.g.
221@@ -1555,6 +1556,22 @@ def insert_query(model, objs, fields, return_id=False, raw=False, using=None):
222     query.insert_values(fields, objs, raw=raw)
223     return query.get_compiler(using=using).execute_sql(return_id)
224 
225+def prl_to_r_objs(lookups, prefix=None):
226+    """
227+    This little helper function will convert a list containing R objects or
228+    normal lookups into all R objects list.
229+    """
230+    from django.db.models.sql.constants import LOOKUP_SEP
231+    if prefix is None:
232+        return [isinstance(lup, R) and lup or R(lup) for lup in lookups]
233+    ret = []
234+    for lup in lookups:
235+        if isinstance(lup, R):
236+            r_obj = lup._new_prefixed(prefix)
237+        else:
238+            r_obj = R(prefix + LOOKUP_SEP + lup)
239+        ret.append(r_obj)
240+    return ret
241 
242 def prefetch_related_objects(result_cache, related_lookups):
243     """
244@@ -1567,31 +1584,38 @@ def prefetch_related_objects(result_cache, related_lookups):
245 
246     if len(result_cache) == 0:
247         return # nothing to do
248-
249+    r_objs = prl_to_r_objs(related_lookups)
250     model = result_cache[0].__class__
251 
252     # We need to be able to dynamically add to the list of prefetch_related
253     # lookups that we look up (see below).  So we need some book keeping to
254     # ensure we don't do duplicate work.
255-    done_lookups = set() # list of lookups like foo__bar__baz
256+    seen_lookups = set() # list of lookups like foo__bar__baz
257     done_queries = {}    # dictionary of things like 'foo__bar': [results]
258 
259-    manual_lookups = list(related_lookups)
260+    manual_lookups = list(r_objs)
261     auto_lookups = [] # we add to this as we go through.
262     followed_descriptors = set() # recursion protection
263 
264-    related_lookups = itertools.chain(manual_lookups, auto_lookups)
265-    for lookup in related_lookups:
266-        if lookup in done_lookups:
267+    # For R-objects, we have two different lookups:
268+    #   - lookup: This is the related object attribute name
269+    #   - lookup_refpath: This is to be used when this R-object is referenced
270+    #     in chained prefetches.
271+    # One way to explain these would be to say lookup is how we go forward,
272+    # lookup_refpath is what happened in the past.
273+
274+    r_objs = itertools.chain(manual_lookups, auto_lookups)
275+    for r_obj in r_objs:
276+        if r_obj.lookup_refpath in seen_lookups:
277             # We've done exactly this already, skip the whole thing
278             continue
279-        done_lookups.add(lookup)
280+        seen_lookups.add(r_obj.lookup_refpath)
281 
282         # Top level, the list of objects to decorate is the the result cache
283         # from the primary QuerySet. It won't be for deeper levels.
284         obj_list = result_cache
285 
286-        attrs = lookup.split(LOOKUP_SEP)
287+        attrs = r_obj.lookup.split(LOOKUP_SEP)
288         for level, attr in enumerate(attrs):
289             # Prepare main instances
290             if len(obj_list) == 0:
291@@ -1619,46 +1643,63 @@ def prefetch_related_objects(result_cache, related_lookups):
292 
293             # We assume that objects retrieved are homogenous (which is the premise
294             # of prefetch_related), so what applies to first object applies to all.
295+            # TODO: Make sure this is really true for objects coming from generic
296+            # relations.
297             first_obj = obj_list[0]
298-            prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(first_obj, attr)
299+            prefetcher, descriptor, attr_found, is_fetched = \
300+                get_prefetcher(first_obj, attr)
301 
302             if not attr_found:
303                 raise AttributeError("Cannot find '%s' on %s object, '%s' is an invalid "
304                                      "parameter to prefetch_related()" %
305-                                     (attr, first_obj.__class__.__name__, lookup))
306+                                     (attr, first_obj.__class__.__name__,
307+                                      r_obj.lookup))
308 
309             if level == len(attrs) - 1 and prefetcher is None:
310                 # Last one, this *must* resolve to something that supports
311                 # prefetching, otherwise there is no point adding it and the
312                 # developer asking for it has made a mistake.
313-                raise ValueError("'%s' does not resolve to a item that supports "
314+                raise ValueError("'%s' does not resolve to an item that supports "
315                                  "prefetching - this is an invalid parameter to "
316-                                 "prefetch_related()." % lookup)
317+                                 "prefetch_related()." % r_obj.lookup)
318 
319             if prefetcher is not None and not is_fetched:
320-                # Check we didn't do this already
321-                current_lookup = LOOKUP_SEP.join(attrs[0:level+1])
322+                current_lookup = r_obj.get_current_lookup(level)
323                 if current_lookup in done_queries:
324                     obj_list = done_queries[current_lookup]
325                 else:
326-                    obj_list, additional_prl = prefetch_one_level(obj_list, prefetcher, attr)
327+                    obj_list, additional_prl = prefetch_one_level(
328+                            obj_list, prefetcher, r_obj, level)
329                     # We need to ensure we don't keep adding lookups from the
330                     # same relationships to stop infinite recursion. So, if we
331                     # are already on an automatically added lookup, don't add
332                     # the new lookups from relationships we've seen already.
333-                    if not (lookup in auto_lookups and
334+                    if not (r_obj in auto_lookups and
335                             descriptor in followed_descriptors):
336-                        for f in additional_prl:
337-                            new_prl = LOOKUP_SEP.join([current_lookup, f])
338-                            auto_lookups.append(new_prl)
339                         done_queries[current_lookup] = obj_list
340+                        additional_prl = prl_to_r_objs(additional_prl,
341+                                                       current_lookup)
342+                        auto_lookups.extend(additional_prl)
343                     followed_descriptors.add(descriptor)
344+
345+            elif isinstance(getattr(obj_list[0], attr), list):
346+                # The current part of the lookup relates to a r_obj.to_attr
347+                # defined previous fetch. This means that obj.attr is a list
348+                # of related objects, and thus we must turn the obj.attr lists
349+                # into a single related object list.
350+                new_list = []
351+                for obj in obj_list:
352+                    new_list.extend(getattr(obj, attr))
353+                obj_list = new_list
354             else:
355                 # Either a singly related object that has already been fetched
356                 # (e.g. via select_related), or hopefully some other property
357                 # that doesn't support prefetching but needs to be traversed.
358 
359                 # We replace the current list of parent objects with that list.
360+                # TODO: Check what happens if attr resolves to local field?
361+                # User typoing rel_attr_id instead of rel_attr? AND there are
362+                # multiple parts in the path left.
363                 obj_list = [getattr(obj, attr) for obj in obj_list]
364 
365                 # Filter out 'None' so that we can continue with nullable
366@@ -1688,6 +1729,12 @@ def get_prefetcher(instance, attr):
367         try:
368             rel_obj = getattr(instance, attr)
369             attr_found = True
370+            # If we are following a r_obj lookup path which leads us through
371+            # a previous fetch with to_attr, then we might end up into a list
372+            # instead of related qs. This means the objects are already
373+            # fetched.
374+            if isinstance(rel_obj, list):
375+                is_fetched = True
376         except AttributeError:
377             pass
378     else:
379@@ -1709,7 +1756,7 @@ def get_prefetcher(instance, attr):
380     return prefetcher, rel_obj_descriptor, attr_found, is_fetched
381 
382 
383-def prefetch_one_level(instances, prefetcher, attname):
384+def prefetch_one_level(instances, prefetcher, r_obj, level):
385     """
386     Helper function for prefetch_related_objects
387 
388@@ -1733,7 +1780,8 @@ def prefetch_one_level(instances, prefetcher, attname):
389     # in a dictionary.
390 
391     rel_qs, rel_obj_attr, instance_attr, single, cache_name =\
392-        prefetcher.get_prefetch_query_set(instances)
393+        prefetcher.get_prefetch_query_set(instances, custom_qs=r_obj.qs)
394+
395     # We have to handle the possibility that the default manager itself added
396     # prefetch_related lookups to the QuerySet we just got back. We don't want to
397     # trigger the prefetch_related functionality by evaluating the query.
398@@ -1754,7 +1802,16 @@ def prefetch_one_level(instances, prefetcher, attname):
399             rel_obj_cache[rel_attr_val] = []
400         rel_obj_cache[rel_attr_val].append(rel_obj)
401 
402+
403+    # to_attr is the name of the attribute we will be fetching into, to_list
404+    # is False if to_attr refers to related manager. If it refers to related
405+    # manager, we will be caching in rel_manager.all(), otherwise in a list.
406+    to_attr, to_list = r_obj.get_to_attr(level)
407     for obj in instances:
408+        # TODO: in this case we could set the reverse attribute if the relation
409+        # is o2o. Both this and the TODO below are handled by select_related
410+        # in the get_cached_row iterator construction. Maybe that code could
411+        # be generalized and shared.
412         instance_attr_val = instance_attr(obj)
413         vals = rel_obj_cache.get(instance_attr_val, [])
414         if single:
415@@ -1764,10 +1821,17 @@ def prefetch_one_level(instances, prefetcher, attname):
416         else:
417             # Multi, attribute represents a manager with an .all() method that
418             # returns a QuerySet
419-            qs = getattr(obj, attname).all()
420-            qs._result_cache = vals
421-            # We don't want the individual qs doing prefetch_related now, since we
422-            # have merged this into the current work.
423-            qs._prefetch_done = True
424-            obj._prefetched_objects_cache[cache_name] = qs
425+            # TODO: we could set the reverse relation, so that if user does
426+            # access the just fetched relation in the reverse order, we would
427+            # not need to do a query. We can't do this for m2m, of course.
428+            if to_list:
429+                setattr(obj, to_attr, vals)
430+            else:
431+                # Cache in the QuerySet.all().
432+                qs = getattr(obj, to_attr).all()
433+                qs._result_cache = vals
434+                # We don't want the individual qs doing prefetch_related now,
435+                # since we have merged this into the current work.
436+                qs._prefetch_done = True
437+                obj._prefetched_objects_cache[cache_name] = qs
438     return all_related_objects, additional_prl
439diff --git a/django/db/models/related.py b/django/db/models/related.py
440index 90995d7..2c1e954 100644
441--- a/django/db/models/related.py
442+++ b/django/db/models/related.py
443@@ -1,5 +1,6 @@
444 from django.utils.encoding import smart_unicode
445 from django.db.models.fields import BLANK_CHOICE_DASH
446+from django.db.models.sql.constants import LOOKUP_SEP
447 
448 class BoundRelatedObject(object):
449     def __init__(self, related_object, field_mapping, original):
450@@ -36,7 +37,7 @@ class RelatedObject(object):
451                 {'%s__isnull' % self.parent_model._meta.module_name: False})
452         lst = [(x._get_pk_val(), smart_unicode(x)) for x in queryset]
453         return first_choice + lst
454-       
455+
456     def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
457         # Defer to the actual field definition for db prep
458         return self.field.get_db_prep_lookup(lookup_type, value,
459@@ -67,3 +68,86 @@ class RelatedObject(object):
460 
461     def get_cache_name(self):
462         return "_%s_cache" % self.get_accessor_name()
463+
464+# Not knowing a better place for this, I just planted R here.
465+# Feel free to move this to a better place or remove this comment.
466+class R(object):
467+    """
468+    A class used for passing options to .prefetch_related. Note that instances
469+    of this class should be considered immutable.
470+    """
471+
472+    # For R-objects, we have two different internal lookup paths:
473+    #   - lookup: This is the related object attribute name
474+    #   - lookup_refpath: This is to be used when this R-object is referenced
475+    #     in chained prefetches.
476+    # Check out the source of R-objects to see what is happening there.
477+    #
478+    # The difference is needed, because when we chain R-objects with to_attr
479+    # defined, the lookup_path (how we got here) and lookup_refpath (how to
480+    # get forward from here) will be different. For example:
481+    # R('foo', to_attr='foolst') -> lookup_path = foo, that is we are going
482+    # to prefetch through relation foo.
483+    #
484+    # If there would be another qs produced by R, the lookup_refpath would
485+    # need to be 'foolst__nextpart'. Otherwise we can't distinguish between
486+    # two different prefetch_related lookups to 'foo' (perhaps with custom
487+    # querysets).
488+    #
489+    # Luckily the user does not need to know anything about this.
490+
491+    def __init__(self, lookup, to_attr=None, qs=None):
492+        if qs is not None and not to_attr:
493+            raise ValueError('When custom qs is defined, to_attr '
494+                             'must also be defined')
495+        self.lookup = lookup
496+        self.to_attr = to_attr
497+        self.qs = qs
498+
499+    def _new_prefixed(self, prefix):
500+        """
501+        _new_internal is to be used when prefetches are chained internally.
502+        The returned R-object is identical to self, except lookup_path
503+        is prefixed with prefix.
504+        """
505+        new_lookup = prefix + LOOKUP_SEP + self.lookup
506+        return R(new_lookup, to_attr=self.to_attr, qs=self.qs)
507+
508+    def __unicode__(self):
509+        return ("lookup: %s, to_attr: %s, qs: %s" %
510+            (self.lookup, self.to_attr or None, self.qs))
511+
512+    def __repr__(self):
513+        return '<%s: %s>' % (self.__class__.__name__, unicode(self))
514+
515+    def __eq__(self, other):
516+        if isinstance(other, R):
517+            return self.lookup_refpath == other.lookup_refpath
518+        return False
519+
520+    def _lookup_refpath(self):
521+        if self.to_attr is None:
522+            return self.lookup
523+        else:
524+            path, sep, last_part = self.lookup.rpartition(LOOKUP_SEP)
525+            return path + sep + self.to_attr
526+    lookup_refpath = property(_lookup_refpath)
527+
528+    def get_current_lookup(self, level):
529+        """
530+        Returns the first level + 1 parts of the self.lookup_refpath
531+        """
532+        parts = self.lookup_refpath.split(LOOKUP_SEP)
533+        return LOOKUP_SEP.join(parts[0:level + 1])
534+
535+    def get_to_attr(self, level):
536+        """
537+        Returns information about into what attribute should the results be
538+        fetched, and if that attribute is related object manager, or will the
539+        objects be fetched into a list.
540+        """
541+        parts = self.lookup_refpath.split(LOOKUP_SEP)
542+        if self.to_attr is None or level < len(parts) - 1:
543+            return parts[level], False
544+        else:
545+            return self.to_attr, True
546diff --git a/tests/modeltests/prefetch_related/models.py b/tests/modeltests/prefetch_related/models.py
547index 1c14c88..3a996db 100644
548--- a/tests/modeltests/prefetch_related/models.py
549+++ b/tests/modeltests/prefetch_related/models.py
550@@ -58,6 +58,32 @@ class BookWithYear(Book):
551         AuthorWithAge, related_name='books_with_year')
552 
553 
554+class AuthorDefManager(models.Manager):
555+    # Default manager with possibly recursive results.
556+    def get_query_set(self):
557+        qs = super(AuthorDefManager, self).get_query_set()
558+        return qs.prefetch_related('best_friend_reverse', 'books')
559+
560+class AuthorWithDefPrefetch(models.Model):
561+    name = models.TextField()
562+    best_friend = models.ForeignKey(
563+         'self', related_name='best_friend_reverse', null=True)
564+    objects = AuthorDefManager()
565+
566+class BookDefManager(models.Manager):
567+    # No need for guard here, author's manager will take care of that.
568+    def get_query_set(self):
569+        return (super(BookDefManager, self).get_query_set()
570+                .prefetch_related('authors'))
571+
572+class BookWithDefPrefetch(models.Model):
573+    name = models.TextField()
574+    authors = models.ManyToManyField(AuthorWithDefPrefetch,
575+                                     related_name='books')
576+
577+    objects = BookDefManager()
578+
579+
580 class Reader(models.Model):
581     name = models.CharField(max_length=50)
582     books_read = models.ManyToManyField(Book, related_name='read_by')
583@@ -155,7 +181,7 @@ class Person(models.Model):
584         ordering = ['id']
585 
586 
587-## Models for nullable FK tests
588+## Models for nullable FK tests and recursive prefetch_related tests.
589 
590 class Employee(models.Model):
591     name = models.CharField(max_length=50)
592diff --git a/tests/modeltests/prefetch_related/tests.py b/tests/modeltests/prefetch_related/tests.py
593index bdbb056..f424fc1 100644
594--- a/tests/modeltests/prefetch_related/tests.py
595+++ b/tests/modeltests/prefetch_related/tests.py
596@@ -1,13 +1,32 @@
597 from __future__ import with_statement
598 
599 from django.contrib.contenttypes.models import ContentType
600+from django.db.models import R
601 from django.test import TestCase
602 from django.utils import unittest
603 
604 from models import (Author, Book, Reader, Qualification, Teacher, Department,
605                     TaggedItem, Bookmark, AuthorAddress, FavoriteAuthors,
606                     AuthorWithAge, BookWithYear, Person, House, Room,
607-                    Employee)
608+                    Employee, AuthorWithDefPrefetch, BookWithDefPrefetch)
609+
610+def traverse_qs(obj_iter, path):
611+    """
612+    Helper method that returns a list containing a list of the objects in the
613+    obj_iter. Then for each object in the obj_iter, the path will be
614+    recursively travelled and the found objects are added to the return value.
615+    """
616+    ret_val = []
617+    if hasattr(obj_iter, 'all'):
618+        obj_iter = obj_iter.all()
619+    for obj in obj_iter:
620+        rel_objs = []
621+        for part in path:
622+            if not part:
623+                continue
624+            rel_objs.extend(traverse_qs(getattr(obj, part[0]), [part[1:]]))
625+        ret_val.append((obj, rel_objs))
626+    return ret_val
627 
628 
629 class PrefetchRelatedTests(TestCase):
630@@ -39,6 +58,15 @@ class PrefetchRelatedTests(TestCase):
631         self.reader1.books_read.add(self.book1, self.book4)
632         self.reader2.books_read.add(self.book2, self.book4)
633 
634+    def test_metatest_traverse_qs(self):
635+        qs = Book.objects.prefetch_related('authors')
636+        related_objs_normal = [list(b.authors.all()) for b in qs],
637+        related_objs_from_traverse = [[inner[0] for inner in o[1]]
638+                                      for o in traverse_qs(qs, [['authors']])]
639+        self.assertEquals(related_objs_normal, (related_objs_from_traverse,))
640+        self.assertFalse(related_objs_from_traverse == traverse_qs(qs.filter(pk=1),
641+                         [['authors']]))
642+
643     def test_m2m_forward(self):
644         with self.assertNumQueries(2):
645             lists = [list(b.authors.all()) for b in Book.objects.prefetch_related('authors')]
646@@ -472,3 +500,117 @@ class NullableTest(TestCase):
647                         for e in qs2]
648 
649         self.assertEqual(co_serfs, co_serfs2)
650+
651+
652+class RObjectTest(TestCase):
653+    def setUp(self):
654+        self.person1 = Person.objects.create(name="Joe")
655+        self.person2 = Person.objects.create(name="Mary")
656+
657+        self.house1 = House.objects.create(address="123 Main St")
658+        self.house2 = House.objects.create(address="45 Side St")
659+        self.house3 = House.objects.create(address="6 Downing St")
660+        self.house4 = House.objects.create(address="7 Regents St")
661+
662+        self.room1_1 = Room.objects.create(name="Dining room", house=self.house1)
663+        self.room1_2 = Room.objects.create(name="Lounge", house=self.house1)
664+        self.room1_3 = Room.objects.create(name="Kitchen", house=self.house1)
665+
666+        self.room2_1 = Room.objects.create(name="Dining room", house=self.house2)
667+        self.room2_2 = Room.objects.create(name="Lounge", house=self.house2)
668+
669+        self.room3_1 = Room.objects.create(name="Dining room", house=self.house3)
670+        self.room3_2 = Room.objects.create(name="Lounge", house=self.house3)
671+        self.room3_3 = Room.objects.create(name="Kitchen", house=self.house3)
672+
673+        self.room4_1 = Room.objects.create(name="Dining room", house=self.house4)
674+        self.room4_2 = Room.objects.create(name="Lounge", house=self.house4)
675+
676+        self.person1.houses.add(self.house1, self.house2)
677+        self.person2.houses.add(self.house3, self.house4)
678+
679+    def test_robj_basics(self):
680+        # Test different combinations of R and non-R lookups
681+        with self.assertNumQueries(2):
682+            lst1 = traverse_qs(Person.objects.prefetch_related('houses'),
683+                               [['houses']])
684+        with self.assertNumQueries(2):
685+            lst2 = traverse_qs(Person.objects.prefetch_related(R('houses')),
686+                               [['houses']])
687+        self.assertEquals(lst1, lst2)
688+        with self.assertNumQueries(3):
689+            lst1 = traverse_qs(Person.objects.prefetch_related('houses', 'houses__rooms'),
690+                               [['houses', 'rooms']])
691+        with self.assertNumQueries(3):
692+            lst2 = traverse_qs(Person.objects.prefetch_related(R('houses'), R('houses__rooms')),
693+                               [['houses', 'rooms']])
694+        self.assertEquals(lst1, lst2)
695+        with self.assertNumQueries(3):
696+            lst1 = traverse_qs(Person.objects.prefetch_related('houses', 'houses__rooms'),
697+                               [['houses', 'rooms']])
698+        with self.assertNumQueries(3):
699+            lst2 = traverse_qs(Person.objects.prefetch_related(R('houses'), 'houses__rooms'),
700+                               [['houses', 'rooms']])
701+        self.assertEquals(lst1, lst2)
702+        # Test to_attr
703+        with self.assertNumQueries(3):
704+            lst1 = traverse_qs(Person.objects.prefetch_related('houses', 'houses__rooms'),
705+                               [['houses', 'rooms']])
706+        with self.assertNumQueries(3):
707+            lst2 = traverse_qs(Person.objects.prefetch_related(
708+                                  R('houses', to_attr='houses_lst'),
709+                                  'houses_lst__rooms'),
710+                               [['houses_lst', 'rooms']])
711+        self.assertEquals(lst1, lst2)
712+
713+        with self.assertNumQueries(4):
714+            qs = list(Person.objects.prefetch_related(
715+                    R('houses', to_attr='houses_lst'),
716+                    R('houses__rooms', to_attr='rooms_lst')
717+            ))
718+            with self.assertRaises(AttributeError):
719+                qs[0].houses_lst2[0].rooms_lst
720+            qs[0].houses.all()[0].rooms_lst
721+            lst2 = traverse_qs(
722+                qs, [['houses', 'rooms_lst']]
723+            )
724+            self.assertEquals(lst1, lst2)
725+            self.assertEquals(
726+                traverse_qs(qs, [['houses']]),
727+                traverse_qs(qs, [['houses_lst']])
728+            )
729+
730+    def test_custom_qs(self):
731+        person_qs = Person.objects.all()
732+        houses_qs = House.objects.all()
733+        with self.assertNumQueries(2):
734+             lst1 = list(person_qs.prefetch_related('houses'))
735+        with self.assertNumQueries(2):
736+             lst2 = list(person_qs.prefetch_related(
737+                 R('houses', qs=houses_qs, to_attr='houses_lst')
738+             ))
739+        self.assertEquals(
740+            traverse_qs(lst1, [['houses']]),
741+            traverse_qs(lst2, [['houses_lst']])
742+        )
743+        with self.assertNumQueries(2):
744+            lst2 = list(person_qs.prefetch_related(
745+                R('houses', qs=houses_qs.filter(pk__in=[self.house1.pk, self.house3.pk]),
746+                  to_attr='hlst')
747+            ))
748+        self.assertEquals(len(lst2[0].hlst), 1)
749+        self.assertEquals(lst2[0].hlst[0], self.house1)
750+        self.assertEquals(len(lst2[1].hlst), 1)
751+        self.assertEquals(lst2[1].hlst[0], self.house3)
752+
753+        inner_rooms_qs = Room.objects.filter(pk__in=[self.room1_1.pk, self.room1_2.pk])
754+        houses_qs_prf = houses_qs.prefetch_related(
755+            R('rooms', qs=inner_rooms_qs, to_attr='rooms_lst'))
756+        with self.assertNumQueries(3):
757+            lst2 = list(person_qs.prefetch_related(
758+                       R('houses', qs=houses_qs_prf.filter(pk=self.house1.pk), to_attr='hlst'),
759+                   ))
760+        self.assertEquals(len(lst2[0].hlst[0].rooms_lst), 2)
761+        self.assertEquals(lst2[0].hlst[0].rooms_lst[0], self.room1_1)
762+        self.assertEquals(lst2[0].hlst[0].rooms_lst[1], self.room1_2)
763+        self.assertEquals(len(lst2[1].hlst), 0)