Ticket #18177: prefetch_related_cache.patch

File prefetch_related_cache.patch, 6.9 KB (added by kaiser.yann@…, 12 years ago)

Patch

  • django/db/models/fields/related.py

     
    8484    for cls, field, operation in pending_lookups.pop(key, []):
    8585        operation(field, sender, cls)
    8686
     87def set_eri(qs, field_name, rel_field_name, instances):
     88    qs._existing_related_instances = {
     89        field_name: dict(
     90            (getattr(instance, rel_field_name), instance)
     91            for instance in instances),
     92        }
     93
    8794signals.class_prepared.connect(do_pending_lookups)
    8895
    8996#HACK
     
    232239    def is_cached(self, instance):
    233240        return hasattr(instance, self.cache_name)
    234241
     242    def set_eri(self, qs, instances):
     243        set_eri(qs,
     244                self.related.field.name,
     245                self.related.field.rel.field_name,
     246                instances)
     247
    235248    def get_query_set(self, **db_hints):
    236249        db = router.db_for_read(self.related.model, **db_hints)
    237250        return self.related.model._base_manager.using(db)
     
    239252    def get_prefetch_query_set(self, instances):
    240253        vals = set(instance._get_pk_val() for instance in instances)
    241254        params = {'%s__pk__in' % self.related.field.name: vals}
    242         return (self.get_query_set(instance=instances[0]).filter(**params),
     255        qs = self.get_query_set(instance=instances[0]).filter(**params)
     256        self.set_eri(qs, instances)
     257        return (qs,
    243258                attrgetter(self.related.field.attname),
    244259                lambda obj: obj._get_pk_val(),
    245260                True,
     
    313328    def is_cached(self, instance):
    314329        return hasattr(instance, self.cache_name)
    315330
     331    def set_eri(self, qs, instances):
     332        set_eri(qs,
     333                self.field.related.get_accessor_name(),
     334                self.field.attname,
     335                instances
     336                )
     337
    316338    def get_query_set(self, **db_hints):
    317339        db = router.db_for_read(self.field.rel.to, **db_hints)
    318340        rel_mgr = self.field.rel.to._default_manager
     
    330352            params = {'%s__pk__in' % self.field.rel.field_name: vals}
    331353        else:
    332354            params = {'%s__in' % self.field.rel.field_name: vals}
    333         return (self.get_query_set(instance=instances[0]).filter(**params),
     355        qs = self.get_query_set(instance=instances[0]).filter(**params)
     356        self.set_eri(qs, instances)
     357        return (qs,
    334358                attrgetter(self.field.rel.field_name),
    335359                attrgetter(self.field.attname),
    336360                True,
     
    462486                }
    463487                self.model = rel_model
    464488
     489            def set_eri(self, qs, instances):
     490                set_eri(qs,
     491                        rel_field.name,
     492                        rel_field.rel.field_name,
     493                        instances)
     494
    465495            def get_query_set(self):
    466496                try:
    467497                    return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
    468498                except (AttributeError, KeyError):
    469499                    db = self._db or router.db_for_read(self.model, instance=self.instance)
    470                     return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
     500                    qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
     501                    self.set_eri(qs, (self.instance,))
     502                    return qs
    471503
    472504            def get_prefetch_query_set(self, instances):
    473505                db = self._db or router.db_for_read(self.model, instance=instances[0])
    474506                query = {'%s__%s__in' % (rel_field.name, attname):
    475507                             set(getattr(obj, attname) for obj in instances)}
    476508                qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
     509                self.set_eri(qs, instances)
     510
    477511                return (qs,
    478512                        attrgetter(rel_field.get_attname()),
    479513                        attrgetter(attname),
  • django/db/models/query.py

     
    4141        self._for_write = False
    4242        self._prefetch_related_lookups = []
    4343        self._prefetch_done = False
     44        self._existing_related_instances = {}
    4445
    4546    ########################
    4647    # PYTHON MAGIC METHODS #
     
    294295                obj, _ = get_cached_row(row, index_start, db, klass_info,
    295296                                        offset=len(aggregate_select))
    296297            else:
     298                # Omit aggregates in object creation.
     299                row_data = row[index_start:aggregate_start]
    297300                if skip:
    298                     row_data = row[index_start:aggregate_start]
    299301                    obj = model_cls(**dict(zip(init_list, row_data)))
    300302                else:
    301                     # Omit aggregates in object creation.
    302                     obj = model(*row[index_start:aggregate_start])
     303                    obj = model(*row_data)
    303304
    304305                # Store the source database of the object
    305306                obj._state.db = db
     
    315316                for i, aggregate in enumerate(aggregate_select):
    316317                    setattr(obj, aggregate, row[i+aggregate_start])
    317318
     319            if self._existing_related_instances:
     320                for attr, values in self._existing_related_instances.iteritems():
     321                    descriptor = getattr(obj.__class__, attr)
     322                    if len(values) > 1:
     323                        try:
     324                            attname = descriptor.field.attname
     325                        except AttributeError:
     326                            # If descriptor has no field attribute,
     327                            # this is usually a reverse one-to-one
     328                            # relation in which case obj PKs are used
     329                            val = values[getattr(
     330                                obj,
     331                                descriptor.related.field.rel.field_name
     332                                )]
     333                        else:
     334                            val = values[
     335                                getattr(obj, attname)]
     336                    else:
     337                        val = values.values()[0]
     338
     339                    setattr(
     340                        obj,
     341                        descriptor.cache_name,
     342                        val
     343                        )
     344
    318345            yield obj
    319346
    320347    def aggregate(self, *args, **kwargs):
     
    860887        c = klass(model=self.model, query=query, using=self._db)
    861888        c._for_write = self._for_write
    862889        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
     890        c._existing_related_instances = dict(self._existing_related_instances)
    863891        c.__dict__.update(kwargs)
    864892        if setup and hasattr(c, '_setup_query'):
    865893            c._setup_query()
Back to Top