Ticket #17001: prefetch_extensions.diff

File prefetch_extensions.diff, 35.3 KB (added by Anssi Kääriäinen, 13 years ago)

Now with settings.DEBUG removed

  • django/contrib/contenttypes/generic.py

    diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py
    index c513787..1cdc572 100644
    a b class GenericForeignKey(object):  
    6262            # This should never happen. I love comments like this, don't you?
    6363            raise Exception("Impossible arguments to GFK.get_content_type!")
    6464
    65     def get_prefetch_query_set(self, instances):
     65    def get_prefetch_query_set(self, instances, custom_qs=None):
     66        if custom_qs is not None:
     67            raise ValueError("Custom queryset can't be used for this lookup")
    6668        # For efficiency, group the instances by content type and then do one
    6769        # query per model
    6870        fk_dict = defaultdict(set)
    def create_generic_related_manager(superclass):  
    320322                db = self._db or router.db_for_read(self.model, instance=self.instance)
    321323                return super(GenericRelatedObjectManager, self).get_query_set().using(db).filter(**self.core_filters)
    322324
    323         def get_prefetch_query_set(self, instances):
    324             db = self._db or router.db_for_read(self.model)
     325        def get_prefetch_query_set(self, instances, custom_qs=None):
     326            if not instances:
     327                return self.model._default_manager.none()
    325328            query = {
    326329                '%s__pk' % self.content_type_field_name: self.content_type.id,
    327330                '%s__in' % self.object_id_field_name:
    328331                    set(obj._get_pk_val() for obj in instances)
    329                 }
    330             qs = super(GenericRelatedObjectManager, self).get_query_set().using(db).filter(**query)
     332            }
     333            if custom_qs is not None:
     334                qs = custom_qs.filter(**query)
     335            else:
     336                db = self._db or router.db_for_read(self.model, instance=instances[0])
     337                qs = super(GenericRelatedObjectManager, self).get_query_set()\
     338                         .using(db).filter(**query)
    331339            return (qs,
    332340                    attrgetter(self.object_id_field_name),
    333341                    lambda obj: obj._get_pk_val(),
    334342                    False,
    335343                    self.prefetch_cache_name)
    336344
     345
     346        def all(self):
     347            try:
     348                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
     349            except (AttributeError, KeyError):
     350                return super(GenericRelatedObjectManager, self).all()
     351
     352
    337353        def add(self, *objs):
    338354            for obj in objs:
    339355                if not isinstance(obj, self.model):
  • django/db/models/__init__.py

    diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
    index 3582720..fd07ca6 100644
    a b from django.db import connection  
    44from django.db.models.loading import get_apps, get_app, get_models, get_model, register_models
    55from django.db.models.query import Q
    66from django.db.models.expressions import F
     7from django.db.models.related import R
    78from django.db.models.manager import Manager
    89from django.db.models.base import Model
    910from django.db.models.aggregates import *
    1011from django.db.models.fields import *
    1112from django.db.models.fields.subclassing import SubfieldBase
    1213from django.db.models.fields.files import FileField, ImageField
    13 from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
    14 from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError
     14from django.db.models.fields.related import (ForeignKey, OneToOneField,
     15        ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel)
     16from django.db.models.deletion import (CASCADE, PROTECT, SET, SET_NULL,
     17        SET_DEFAULT, DO_NOTHING, ProtectedError)
    1518from django.db.models import signals
    1619from django.utils.decorators import wraps
    1720
  • django/db/models/fields/related.py

    diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
    index 8c054e7..a0748a1 100644
    a b class SingleRelatedObjectDescriptor(object):  
    236236        db = router.db_for_read(self.related.model, **db_hints)
    237237        return self.related.model._base_manager.using(db)
    238238
    239     def get_prefetch_query_set(self, instances):
     239    def get_prefetch_query_set(self, instances, custom_qs=None):
     240        if custom_qs is not None:
     241            # TODO: This error message is too SQLish, and might be downright
     242            # wrong.
     243            raise ValueError(
     244                "Custom querysets can't be used for one-to-one relations")
     245
    240246        vals = set(instance._get_pk_val() for instance in instances)
    241247        params = {'%s__pk__in' % self.related.field.name: vals}
    242248        return (self.get_query_set(),
    class ReverseSingleRelatedObjectDescriptor(object):  
    315321        else:
    316322            return QuerySet(self.field.rel.to).using(db)
    317323
    318     def get_prefetch_query_set(self, instances):
     324    def get_prefetch_query_set(self, instances, custom_qs=None):
     325        if custom_qs is not None:
     326            # TODO: This error message is too SQLish, and I am not even sure
     327            # this desriptor is used for m2o...
     328            raise ValueError(
     329                "Custom querysets can't be used for many-to-one relations")
     330
    319331        vals = set(getattr(instance, self.field.attname) for instance in instances)
    320332        other_field = self.field.rel.get_related_field()
    321333        if other_field.rel:
    class ForeignRelatedObjectsDescriptor(object):  
    460472                    db = self._db or router.db_for_read(self.model, instance=self.instance)
    461473                    return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
    462474
    463             def get_prefetch_query_set(self, instances):
    464                 db = self._db or router.db_for_read(self.model)
     475            def get_prefetch_query_set(self, instances, custom_qs=None):
     476                """
     477                Return a queryset that does the bulk lookup needed
     478                by prefetch_related functionality.
     479                """
    465480                query = {'%s__%s__in' % (rel_field.name, attname):
    466                              set(getattr(obj, attname) for obj in instances)}
    467                 qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
     481                            set(getattr(obj, attname) for obj in instances)}
     482                if custom_qs is not None:
     483                    qs = custom_qs.filter(**query)
     484                else:
     485                    db = self._db or router.db_for_read(self.model)
     486                    qs = super(RelatedManager, self).get_query_set().\
     487                                    using(db).filter(**query)
    468488                return (qs,
    469489                        attrgetter(rel_field.get_attname()),
    470490                        attrgetter(attname),
    471491                        False,
    472492                        rel_field.related_query_name())
    473493
     494            def all(self):
     495                try:
     496                    return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
     497                except (AttributeError, KeyError):
     498                    return super(RelatedManager, self).all()
     499
    474500            def add(self, *objs):
    475501                for obj in objs:
    476502                    if not isinstance(obj, self.model):
    def create_many_related_manager(superclass, rel):  
    542568                db = self._db or router.db_for_read(self.instance.__class__, instance=self.instance)
    543569                return super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**self.core_filters)
    544570
    545         def get_prefetch_query_set(self, instances):
     571        def get_prefetch_query_set(self, instances, custom_qs=None):
    546572            from django.db import connections
    547573            db = self._db or router.db_for_read(self.model)
    548574            query = {'%s__pk__in' % self.query_field_name:
    549                          set(obj._get_pk_val() for obj in instances)}
    550             qs = super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query)
     575                                  set(obj._get_pk_val() for obj in instances)}
     576
     577            if custom_qs is not None:
     578                qs = custom_qs._next_is_sticky().filter(**query)
     579            else:
     580                qs = (super(ManyRelatedManager, self).get_query_set().using(db)
     581                      ._next_is_sticky().filter(**query))
    551582
    552583            # M2M: need to annotate the query in order to get the primary model
    553             # that the secondary model was actually related to. We know that
    554             # there will already be a join on the join table, so we can just add
    555             # the select.
     584            # that the secondary model was actually related to.
     585
     586            # We know that there will already be a join on the join table, so we
     587            # can just add the select.
    556588
    557589            # For non-autocreated 'through' models, can't assume we are
    558590            # dealing with PK values.
     591
     592            # TODO: This is at the wrong level of abstraction. We should not
     593            # be generating SQL here, but instead maybe pass this information
     594            # to the connection. NoSQL camp will have problems with this, for
     595            # example.
    559596            fk = self.through._meta.get_field(self.source_field_name)
    560597            source_col = fk.column
    561598            join_table = self.through._meta.db_table
    562             connection = connections[db]
     599            if custom_qs is not None:
     600                connection = connections[custom_qs.db]
     601            else:
     602                connection = connections[db]
     603
    563604            qn = connection.ops.quote_name
    564605            qs = qs.extra(select={'_prefetch_related_val':
    565606                                      '%s.%s' % (qn(join_table), qn(source_col))})
    def create_many_related_manager(superclass, rel):  
    570611                    False,
    571612                    self.prefetch_cache_name)
    572613
     614        def all(self):
     615            try:
     616                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
     617            except (AttributeError, KeyError):
     618                return super(ManyRelatedManager, self).all()
     619
    573620        # If the ManyToMany relation has an intermediary model,
    574621        # the add and remove methods do not exist.
    575622        if rel.through._meta.auto_created:
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index be42d02..f5cf2ae 100644
    a b from django.db.models.query_utils import (Q, select_related_descend,  
    1111    deferred_class_factory, InvalidQuery)
    1212from django.db.models.deletion import Collector
    1313from django.db.models import sql
     14from django.db.models.related import R
    1415from django.utils.functional import partition
    1516
    1617# Used to control how many objects are worked with at once in some cases (e.g.
    def insert_query(model, objs, fields, return_id=False, raw=False, using=None):  
    15551556    query.insert_values(fields, objs, raw=raw)
    15561557    return query.get_compiler(using=using).execute_sql(return_id)
    15571558
     1559def prl_to_r_objs(lookups, prefix=None):
     1560    """
     1561    This little helper function will convert a list containing R objects or
     1562    normal lookups into all R objects list.
     1563    """
     1564    from django.db.models.sql.constants import LOOKUP_SEP
     1565    if prefix is None:
     1566        return [isinstance(lup, R) and lup or R(lup) for lup in lookups]
     1567    ret = []
     1568    for lup in lookups:
     1569        if isinstance(lup, R):
     1570            r_obj = lup._new_prefixed(prefix)
     1571        else:
     1572            r_obj = R(prefix + LOOKUP_SEP + lup)
     1573        ret.append(r_obj)
     1574    return ret
    15581575
    15591576def prefetch_related_objects(result_cache, related_lookups):
    15601577    """
    def prefetch_related_objects(result_cache, related_lookups):  
    15671584
    15681585    if len(result_cache) == 0:
    15691586        return # nothing to do
    1570 
     1587    r_objs = prl_to_r_objs(related_lookups)
    15711588    model = result_cache[0].__class__
    15721589
    15731590    # We need to be able to dynamically add to the list of prefetch_related
    15741591    # lookups that we look up (see below).  So we need some book keeping to
    15751592    # ensure we don't do duplicate work.
    1576     done_lookups = set() # list of lookups like foo__bar__baz
     1593    seen_lookups = set() # list of lookups like foo__bar__baz
    15771594    done_queries = {}    # dictionary of things like 'foo__bar': [results]
    15781595
    1579     manual_lookups = list(related_lookups)
     1596    manual_lookups = list(r_objs)
    15801597    auto_lookups = [] # we add to this as we go through.
    15811598    followed_descriptors = set() # recursion protection
    15821599
    1583     related_lookups = itertools.chain(manual_lookups, auto_lookups)
    1584     for lookup in related_lookups:
    1585         if lookup in done_lookups:
     1600    # For R-objects, we have two different lookups:
     1601    #   - lookup: This is the related object attribute name
     1602    #   - lookup_refpath: This is to be used when this R-object is referenced
     1603    #     in chained prefetches.
     1604    # One way to explain these would be to say lookup is how we go forward,
     1605    # lookup_refpath is what happened in the past.
     1606
     1607    r_objs = itertools.chain(manual_lookups, auto_lookups)
     1608    for r_obj in r_objs:
     1609        if r_obj.lookup_refpath in seen_lookups:
    15861610            # We've done exactly this already, skip the whole thing
    15871611            continue
    1588         done_lookups.add(lookup)
     1612        seen_lookups.add(r_obj.lookup_refpath)
    15891613
    15901614        # Top level, the list of objects to decorate is the the result cache
    15911615        # from the primary QuerySet. It won't be for deeper levels.
    15921616        obj_list = result_cache
    15931617
    1594         attrs = lookup.split(LOOKUP_SEP)
     1618        attrs = r_obj.lookup.split(LOOKUP_SEP)
    15951619        for level, attr in enumerate(attrs):
    15961620            # Prepare main instances
    15971621            if len(obj_list) == 0:
    def prefetch_related_objects(result_cache, related_lookups):  
    16191643
    16201644            # We assume that objects retrieved are homogenous (which is the premise
    16211645            # of prefetch_related), so what applies to first object applies to all.
     1646            # TODO: Make sure this is really true for objects coming from generic
     1647            # relations.
    16221648            first_obj = obj_list[0]
    1623             prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(first_obj, attr)
     1649            prefetcher, descriptor, attr_found, is_fetched = \
     1650                get_prefetcher(first_obj, attr)
    16241651
    16251652            if not attr_found:
    16261653                raise AttributeError("Cannot find '%s' on %s object, '%s' is an invalid "
    16271654                                     "parameter to prefetch_related()" %
    1628                                      (attr, first_obj.__class__.__name__, lookup))
     1655                                     (attr, first_obj.__class__.__name__,
     1656                                      r_obj.lookup))
    16291657
    16301658            if level == len(attrs) - 1 and prefetcher is None:
    16311659                # Last one, this *must* resolve to something that supports
    16321660                # prefetching, otherwise there is no point adding it and the
    16331661                # developer asking for it has made a mistake.
    1634                 raise ValueError("'%s' does not resolve to a item that supports "
     1662                raise ValueError("'%s' does not resolve to an item that supports "
    16351663                                 "prefetching - this is an invalid parameter to "
    1636                                  "prefetch_related()." % lookup)
     1664                                 "prefetch_related()." % r_obj.lookup)
    16371665
    16381666            if prefetcher is not None and not is_fetched:
    1639                 # Check we didn't do this already
    1640                 current_lookup = LOOKUP_SEP.join(attrs[0:level+1])
     1667                current_lookup = r_obj.get_current_lookup(level)
    16411668                if current_lookup in done_queries:
    16421669                    obj_list = done_queries[current_lookup]
    16431670                else:
    1644                     obj_list, additional_prl = prefetch_one_level(obj_list, prefetcher, attr)
     1671                    obj_list, additional_prl = prefetch_one_level(
     1672                            obj_list, prefetcher, r_obj, level)
    16451673                    # We need to ensure we don't keep adding lookups from the
    16461674                    # same relationships to stop infinite recursion. So, if we
    16471675                    # are already on an automatically added lookup, don't add
    16481676                    # the new lookups from relationships we've seen already.
    1649                     if not (lookup in auto_lookups and
     1677                    if not (r_obj in auto_lookups and
    16501678                            descriptor in followed_descriptors):
    1651                         for f in additional_prl:
    1652                             new_prl = LOOKUP_SEP.join([current_lookup, f])
    1653                             auto_lookups.append(new_prl)
    16541679                        done_queries[current_lookup] = obj_list
     1680                        additional_prl = prl_to_r_objs(additional_prl,
     1681                                                       current_lookup)
     1682                        auto_lookups.extend(additional_prl)
    16551683                    followed_descriptors.add(descriptor)
     1684
     1685            elif isinstance(getattr(obj_list[0], attr), list):
     1686                # The current part of the lookup relates to a r_obj.to_attr
     1687                # defined previous fetch. This means that obj.attr is a list
     1688                # of related objects, and thus we must turn the obj.attr lists
     1689                # into a single related object list.
     1690                new_list = []
     1691                for obj in obj_list:
     1692                    new_list.extend(getattr(obj, attr))
     1693                obj_list = new_list
    16561694            else:
    16571695                # Either a singly related object that has already been fetched
    16581696                # (e.g. via select_related), or hopefully some other property
    16591697                # that doesn't support prefetching but needs to be traversed.
    16601698
    16611699                # We replace the current list of parent objects with that list.
     1700                # TODO: Check what happens if attr resolves to local field?
     1701                # User typoing rel_attr_id instead of rel_attr? AND there are
     1702                # multiple parts in the path left.
    16621703                obj_list = [getattr(obj, attr) for obj in obj_list]
    16631704
    16641705                # Filter out 'None' so that we can continue with nullable
    def get_prefetcher(instance, attr):  
    16881729        try:
    16891730            rel_obj = getattr(instance, attr)
    16901731            attr_found = True
     1732            # If we are following a r_obj lookup path which leads us through
     1733            # a previous fetch with to_attr, then we might end up into a list
     1734            # instead of related qs. This means the objects are already
     1735            # fetched.
     1736            if isinstance(rel_obj, list):
     1737                is_fetched = True
    16911738        except AttributeError:
    16921739            pass
    16931740    else:
    def get_prefetcher(instance, attr):  
    17091756    return prefetcher, rel_obj_descriptor, attr_found, is_fetched
    17101757
    17111758
    1712 def prefetch_one_level(instances, prefetcher, attname):
     1759def prefetch_one_level(instances, prefetcher, r_obj, level):
    17131760    """
    17141761    Helper function for prefetch_related_objects
    17151762
    def prefetch_one_level(instances, prefetcher, attname):  
    17331780    # in a dictionary.
    17341781
    17351782    rel_qs, rel_obj_attr, instance_attr, single, cache_name =\
    1736         prefetcher.get_prefetch_query_set(instances)
     1783        prefetcher.get_prefetch_query_set(instances, custom_qs=r_obj.qs)
     1784
    17371785    # We have to handle the possibility that the default manager itself added
    17381786    # prefetch_related lookups to the QuerySet we just got back. We don't want to
    17391787    # trigger the prefetch_related functionality by evaluating the query.
    def prefetch_one_level(instances, prefetcher, attname):  
    17541802            rel_obj_cache[rel_attr_val] = []
    17551803        rel_obj_cache[rel_attr_val].append(rel_obj)
    17561804
     1805
     1806    # to_attr is the name of the attribute we will be fetching into, to_list
     1807    # is False if to_attr refers to related manager. If it refers to related
     1808    # manager, we will be caching in rel_manager.all(), otherwise in a list.
     1809    to_attr, to_list = r_obj.get_to_attr(level)
    17571810    for obj in instances:
     1811        # TODO: in this case we could set the reverse attribute if the relation
     1812        # is o2o. Both this and the TODO below are handled by select_related
     1813        # in the get_cached_row iterator construction. Maybe that code could
     1814        # be generalized and shared.
    17581815        instance_attr_val = instance_attr(obj)
    17591816        vals = rel_obj_cache.get(instance_attr_val, [])
    17601817        if single:
    def prefetch_one_level(instances, prefetcher, attname):  
    17641821        else:
    17651822            # Multi, attribute represents a manager with an .all() method that
    17661823            # returns a QuerySet
    1767             qs = getattr(obj, attname).all()
    1768             qs._result_cache = vals
    1769             # We don't want the individual qs doing prefetch_related now, since we
    1770             # have merged this into the current work.
    1771             qs._prefetch_done = True
    1772             obj._prefetched_objects_cache[cache_name] = qs
     1824            # TODO: we could set the reverse relation, so that if user does
     1825            # access the just fetched relation in the reverse order, we would
     1826            # not need to do a query. We can't do this for m2m, of course.
     1827            if to_list:
     1828                setattr(obj, to_attr, vals)
     1829            else:
     1830                # Cache in the QuerySet.all().
     1831                qs = getattr(obj, to_attr).all()
     1832                qs._result_cache = vals
     1833                # We don't want the individual qs doing prefetch_related now,
     1834                # since we have merged this into the current work.
     1835                qs._prefetch_done = True
     1836                obj._prefetched_objects_cache[cache_name] = qs
    17731837    return all_related_objects, additional_prl
  • django/db/models/related.py

    diff --git a/django/db/models/related.py b/django/db/models/related.py
    index 90995d7..2c1e954 100644
    a b  
    11from django.utils.encoding import smart_unicode
    22from django.db.models.fields import BLANK_CHOICE_DASH
     3from django.db.models.sql.constants import LOOKUP_SEP
    34
    45class BoundRelatedObject(object):
    56    def __init__(self, related_object, field_mapping, original):
    class RelatedObject(object):  
    3637                {'%s__isnull' % self.parent_model._meta.module_name: False})
    3738        lst = [(x._get_pk_val(), smart_unicode(x)) for x in queryset]
    3839        return first_choice + lst
    39        
     40
    4041    def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
    4142        # Defer to the actual field definition for db prep
    4243        return self.field.get_db_prep_lookup(lookup_type, value,
    class RelatedObject(object):  
    6768
    6869    def get_cache_name(self):
    6970        return "_%s_cache" % self.get_accessor_name()
     71
     72# Not knowing a better place for this, I just planted R here.
     73# Feel free to move this to a better place or remove this comment.
     74class R(object):
     75    """
     76    A class used for passing options to .prefetch_related. Note that instances
     77    of this class should be considered immutable.
     78    """
     79
     80    # For R-objects, we have two different internal lookup paths:
     81    #   - lookup: This is the related object attribute name
     82    #   - lookup_refpath: This is to be used when this R-object is referenced
     83    #     in chained prefetches.
     84    # Check out the source of R-objects to see what is happening there.
     85    #
     86    # The difference is needed, because when we chain R-objects with to_attr
     87    # defined, the lookup_path (how we got here) and lookup_refpath (how to
     88    # get forward from here) will be different. For example:
     89    # R('foo', to_attr='foolst') -> lookup_path = foo, that is we are going
     90    # to prefetch through relation foo.
     91    #
     92    # If there would be another qs produced by R, the lookup_refpath would
     93    # need to be 'foolst__nextpart'. Otherwise we can't distinguish between
     94    # two different prefetch_related lookups to 'foo' (perhaps with custom
     95    # querysets).
     96    #
     97    # Luckily the user does not need to know anything about this.
     98
     99    def __init__(self, lookup, to_attr=None, qs=None):
     100        if qs is not None and not to_attr:
     101            raise ValueError('When custom qs is defined, to_attr '
     102                             'must also be defined')
     103        self.lookup = lookup
     104        self.to_attr = to_attr
     105        self.qs = qs
     106
     107    def _new_prefixed(self, prefix):
     108        """
     109        _new_internal is to be used when prefetches are chained internally.
     110        The returned R-object is identical to self, except lookup_path
     111        is prefixed with prefix.
     112        """
     113        new_lookup = prefix + LOOKUP_SEP + self.lookup
     114        return R(new_lookup, to_attr=self.to_attr, qs=self.qs)
     115
     116    def __unicode__(self):
     117        return ("lookup: %s, to_attr: %s, qs: %s" %
     118            (self.lookup, self.to_attr or None, self.qs))
     119
     120    def __repr__(self):
     121        return '<%s: %s>' % (self.__class__.__name__, unicode(self))
     122
     123    def __eq__(self, other):
     124        if isinstance(other, R):
     125            return self.lookup_refpath == other.lookup_refpath
     126        return False
     127
     128    def _lookup_refpath(self):
     129        if self.to_attr is None:
     130            return self.lookup
     131        else:
     132            path, sep, last_part = self.lookup.rpartition(LOOKUP_SEP)
     133            return path + sep + self.to_attr
     134    lookup_refpath = property(_lookup_refpath)
     135
     136    def get_current_lookup(self, level):
     137        """
     138        Returns the first level + 1 parts of the self.lookup_refpath
     139        """
     140        parts = self.lookup_refpath.split(LOOKUP_SEP)
     141        return LOOKUP_SEP.join(parts[0:level + 1])
     142
     143    def get_to_attr(self, level):
     144        """
     145        Returns information about into what attribute should the results be
     146        fetched, and if that attribute is related object manager, or will the
     147        objects be fetched into a list.
     148        """
     149        parts = self.lookup_refpath.split(LOOKUP_SEP)
     150        if self.to_attr is None or level < len(parts) - 1:
     151            return parts[level], False
     152        else:
     153            return self.to_attr, True
  • tests/modeltests/prefetch_related/models.py

    diff --git a/tests/modeltests/prefetch_related/models.py b/tests/modeltests/prefetch_related/models.py
    index 1c14c88..3a996db 100644
    a b class BookWithYear(Book):  
    5858        AuthorWithAge, related_name='books_with_year')
    5959
    6060
     61class AuthorDefManager(models.Manager):
     62    # Default manager with possibly recursive results.
     63    def get_query_set(self):
     64        qs = super(AuthorDefManager, self).get_query_set()
     65        return qs.prefetch_related('best_friend_reverse', 'books')
     66
     67class AuthorWithDefPrefetch(models.Model):
     68    name = models.TextField()
     69    best_friend = models.ForeignKey(
     70         'self', related_name='best_friend_reverse', null=True)
     71    objects = AuthorDefManager()
     72
     73class BookDefManager(models.Manager):
     74    # No need for guard here, author's manager will take care of that.
     75    def get_query_set(self):
     76        return (super(BookDefManager, self).get_query_set()
     77                .prefetch_related('authors'))
     78
     79class BookWithDefPrefetch(models.Model):
     80    name = models.TextField()
     81    authors = models.ManyToManyField(AuthorWithDefPrefetch,
     82                                     related_name='books')
     83
     84    objects = BookDefManager()
     85 
     86
    6187class Reader(models.Model):
    6288    name = models.CharField(max_length=50)
    6389    books_read = models.ManyToManyField(Book, related_name='read_by')
    class Person(models.Model):  
    155181        ordering = ['id']
    156182
    157183
    158 ## Models for nullable FK tests
     184## Models for nullable FK tests and recursive prefetch_related tests.
    159185
    160186class Employee(models.Model):
    161187    name = models.CharField(max_length=50)
  • tests/modeltests/prefetch_related/tests.py

    diff --git a/tests/modeltests/prefetch_related/tests.py b/tests/modeltests/prefetch_related/tests.py
    index bdbb056..f424fc1 100644
    a b  
    11from __future__ import with_statement
    22
    33from django.contrib.contenttypes.models import ContentType
     4from django.db.models import R
    45from django.test import TestCase
    56from django.utils import unittest
    67
    78from models import (Author, Book, Reader, Qualification, Teacher, Department,
    89                    TaggedItem, Bookmark, AuthorAddress, FavoriteAuthors,
    910                    AuthorWithAge, BookWithYear, Person, House, Room,
    10                     Employee)
     11                    Employee, AuthorWithDefPrefetch, BookWithDefPrefetch)
     12
     13def traverse_qs(obj_iter, path):
     14    """
     15    Helper method that returns a list containing a list of the objects in the
     16    obj_iter. Then for each object in the obj_iter, the path will be
     17    recursively travelled and the found objects are added to the return value.
     18    """
     19    ret_val = []
     20    if hasattr(obj_iter, 'all'):
     21        obj_iter = obj_iter.all()
     22    for obj in obj_iter:
     23        rel_objs = []
     24        for part in path:
     25            if not part:
     26                continue
     27            rel_objs.extend(traverse_qs(getattr(obj, part[0]), [part[1:]]))
     28        ret_val.append((obj, rel_objs))
     29    return ret_val
    1130
    1231
    1332class PrefetchRelatedTests(TestCase):
    class PrefetchRelatedTests(TestCase):  
    3958        self.reader1.books_read.add(self.book1, self.book4)
    4059        self.reader2.books_read.add(self.book2, self.book4)
    4160
     61    def test_metatest_traverse_qs(self):
     62        qs = Book.objects.prefetch_related('authors')
     63        related_objs_normal = [list(b.authors.all()) for b in qs],
     64        related_objs_from_traverse = [[inner[0] for inner in o[1]]
     65                                      for o in traverse_qs(qs, [['authors']])]
     66        self.assertEquals(related_objs_normal, (related_objs_from_traverse,))
     67        self.assertFalse(related_objs_from_traverse == traverse_qs(qs.filter(pk=1),
     68                         [['authors']]))
     69
    4270    def test_m2m_forward(self):
    4371        with self.assertNumQueries(2):
    4472            lists = [list(b.authors.all()) for b in Book.objects.prefetch_related('authors')]
    class NullableTest(TestCase):  
    472500                        for e in qs2]
    473501
    474502        self.assertEqual(co_serfs, co_serfs2)
     503
     504
     505class RObjectTest(TestCase):
     506    def setUp(self):
     507        self.person1 = Person.objects.create(name="Joe")
     508        self.person2 = Person.objects.create(name="Mary")
     509
     510        self.house1 = House.objects.create(address="123 Main St")
     511        self.house2 = House.objects.create(address="45 Side St")
     512        self.house3 = House.objects.create(address="6 Downing St")
     513        self.house4 = House.objects.create(address="7 Regents St")
     514
     515        self.room1_1 = Room.objects.create(name="Dining room", house=self.house1)
     516        self.room1_2 = Room.objects.create(name="Lounge", house=self.house1)
     517        self.room1_3 = Room.objects.create(name="Kitchen", house=self.house1)
     518
     519        self.room2_1 = Room.objects.create(name="Dining room", house=self.house2)
     520        self.room2_2 = Room.objects.create(name="Lounge", house=self.house2)
     521
     522        self.room3_1 = Room.objects.create(name="Dining room", house=self.house3)
     523        self.room3_2 = Room.objects.create(name="Lounge", house=self.house3)
     524        self.room3_3 = Room.objects.create(name="Kitchen", house=self.house3)
     525
     526        self.room4_1 = Room.objects.create(name="Dining room", house=self.house4)
     527        self.room4_2 = Room.objects.create(name="Lounge", house=self.house4)
     528
     529        self.person1.houses.add(self.house1, self.house2)
     530        self.person2.houses.add(self.house3, self.house4)
     531
     532    def test_robj_basics(self):
     533        # Test different combinations of R and non-R lookups
     534        with self.assertNumQueries(2):
     535            lst1 = traverse_qs(Person.objects.prefetch_related('houses'),
     536                               [['houses']])
     537        with self.assertNumQueries(2):
     538            lst2 = traverse_qs(Person.objects.prefetch_related(R('houses')),
     539                               [['houses']])
     540        self.assertEquals(lst1, lst2)
     541        with self.assertNumQueries(3):
     542            lst1 = traverse_qs(Person.objects.prefetch_related('houses', 'houses__rooms'),
     543                               [['houses', 'rooms']])
     544        with self.assertNumQueries(3):
     545            lst2 = traverse_qs(Person.objects.prefetch_related(R('houses'), R('houses__rooms')),
     546                               [['houses', 'rooms']])
     547        self.assertEquals(lst1, lst2)
     548        with self.assertNumQueries(3):
     549            lst1 = traverse_qs(Person.objects.prefetch_related('houses', 'houses__rooms'),
     550                               [['houses', 'rooms']])
     551        with self.assertNumQueries(3):
     552            lst2 = traverse_qs(Person.objects.prefetch_related(R('houses'), 'houses__rooms'),
     553                               [['houses', 'rooms']])
     554        self.assertEquals(lst1, lst2)
     555        # Test to_attr
     556        with self.assertNumQueries(3):
     557            lst1 = traverse_qs(Person.objects.prefetch_related('houses', 'houses__rooms'),
     558                               [['houses', 'rooms']])
     559        with self.assertNumQueries(3):
     560            lst2 = traverse_qs(Person.objects.prefetch_related(
     561                                  R('houses', to_attr='houses_lst'),
     562                                  'houses_lst__rooms'),
     563                               [['houses_lst', 'rooms']])
     564        self.assertEquals(lst1, lst2)
     565
     566        with self.assertNumQueries(4):
     567            qs = list(Person.objects.prefetch_related(
     568                    R('houses', to_attr='houses_lst'),
     569                    R('houses__rooms', to_attr='rooms_lst')
     570            ))
     571            with self.assertRaises(AttributeError):
     572                qs[0].houses_lst2[0].rooms_lst
     573            qs[0].houses.all()[0].rooms_lst
     574            lst2 = traverse_qs(
     575                qs, [['houses', 'rooms_lst']]
     576            )
     577            self.assertEquals(lst1, lst2)
     578            self.assertEquals(
     579                traverse_qs(qs, [['houses']]),
     580                traverse_qs(qs, [['houses_lst']])
     581            )
     582
     583    def test_custom_qs(self):
     584        person_qs = Person.objects.all()
     585        houses_qs = House.objects.all()
     586        with self.assertNumQueries(2):
     587             lst1 = list(person_qs.prefetch_related('houses'))
     588        with self.assertNumQueries(2):
     589             lst2 = list(person_qs.prefetch_related(
     590                 R('houses', qs=houses_qs, to_attr='houses_lst')
     591             ))
     592        self.assertEquals(
     593            traverse_qs(lst1, [['houses']]),
     594            traverse_qs(lst2, [['houses_lst']])
     595        )
     596        with self.assertNumQueries(2):
     597            lst2 = list(person_qs.prefetch_related(
     598                R('houses', qs=houses_qs.filter(pk__in=[self.house1.pk, self.house3.pk]),
     599                  to_attr='hlst')
     600            ))
     601        self.assertEquals(len(lst2[0].hlst), 1)
     602        self.assertEquals(lst2[0].hlst[0], self.house1)
     603        self.assertEquals(len(lst2[1].hlst), 1)
     604        self.assertEquals(lst2[1].hlst[0], self.house3)
     605
     606        inner_rooms_qs = Room.objects.filter(pk__in=[self.room1_1.pk, self.room1_2.pk])
     607        houses_qs_prf = houses_qs.prefetch_related(
     608            R('rooms', qs=inner_rooms_qs, to_attr='rooms_lst'))
     609        with self.assertNumQueries(3):
     610            lst2 = list(person_qs.prefetch_related(
     611                       R('houses', qs=houses_qs_prf.filter(pk=self.house1.pk), to_attr='hlst'),
     612                   ))
     613        self.assertEquals(len(lst2[0].hlst[0].rooms_lst), 2)
     614        self.assertEquals(lst2[0].hlst[0].rooms_lst[0], self.room1_1)
     615        self.assertEquals(lst2[0].hlst[0].rooms_lst[1], self.room1_2)
     616        self.assertEquals(len(lst2[1].hlst), 0)
Back to Top