Code

Ticket #18177: prefetch_related_cache.patch

File prefetch_related_cache.patch, 6.9 KB (added by kaiser.yann@…, 2 years ago)

Patch

  • django/db/models/fields/related.py

     
    8484    for cls, field, operation in pending_lookups.pop(key, []): 
    8585        operation(field, sender, cls) 
    8686 
     87def set_eri(qs, field_name, rel_field_name, instances): 
     88    qs._existing_related_instances = { 
     89        field_name: dict( 
     90            (getattr(instance, rel_field_name), instance) 
     91            for instance in instances), 
     92        } 
     93 
    8794signals.class_prepared.connect(do_pending_lookups) 
    8895 
    8996#HACK 
     
    232239    def is_cached(self, instance): 
    233240        return hasattr(instance, self.cache_name) 
    234241 
     242    def set_eri(self, qs, instances): 
     243        set_eri(qs, 
     244                self.related.field.name, 
     245                self.related.field.rel.field_name, 
     246                instances) 
     247 
    235248    def get_query_set(self, **db_hints): 
    236249        db = router.db_for_read(self.related.model, **db_hints) 
    237250        return self.related.model._base_manager.using(db) 
     
    239252    def get_prefetch_query_set(self, instances): 
    240253        vals = set(instance._get_pk_val() for instance in instances) 
    241254        params = {'%s__pk__in' % self.related.field.name: vals} 
    242         return (self.get_query_set(instance=instances[0]).filter(**params), 
     255        qs = self.get_query_set(instance=instances[0]).filter(**params) 
     256        self.set_eri(qs, instances) 
     257        return (qs, 
    243258                attrgetter(self.related.field.attname), 
    244259                lambda obj: obj._get_pk_val(), 
    245260                True, 
     
    313328    def is_cached(self, instance): 
    314329        return hasattr(instance, self.cache_name) 
    315330 
     331    def set_eri(self, qs, instances): 
     332        set_eri(qs, 
     333                self.field.related.get_accessor_name(), 
     334                self.field.attname, 
     335                instances 
     336                ) 
     337 
    316338    def get_query_set(self, **db_hints): 
    317339        db = router.db_for_read(self.field.rel.to, **db_hints) 
    318340        rel_mgr = self.field.rel.to._default_manager 
     
    330352            params = {'%s__pk__in' % self.field.rel.field_name: vals} 
    331353        else: 
    332354            params = {'%s__in' % self.field.rel.field_name: vals} 
    333         return (self.get_query_set(instance=instances[0]).filter(**params), 
     355        qs = self.get_query_set(instance=instances[0]).filter(**params) 
     356        self.set_eri(qs, instances) 
     357        return (qs, 
    334358                attrgetter(self.field.rel.field_name), 
    335359                attrgetter(self.field.attname), 
    336360                True, 
     
    462486                } 
    463487                self.model = rel_model 
    464488 
     489            def set_eri(self, qs, instances): 
     490                set_eri(qs, 
     491                        rel_field.name, 
     492                        rel_field.rel.field_name, 
     493                        instances) 
     494 
    465495            def get_query_set(self): 
    466496                try: 
    467497                    return self.instance._prefetched_objects_cache[rel_field.related_query_name()] 
    468498                except (AttributeError, KeyError): 
    469499                    db = self._db or router.db_for_read(self.model, instance=self.instance) 
    470                     return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) 
     500                    qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) 
     501                    self.set_eri(qs, (self.instance,)) 
     502                    return qs 
    471503 
    472504            def get_prefetch_query_set(self, instances): 
    473505                db = self._db or router.db_for_read(self.model, instance=instances[0]) 
    474506                query = {'%s__%s__in' % (rel_field.name, attname): 
    475507                             set(getattr(obj, attname) for obj in instances)} 
    476508                qs = super(RelatedManager, self).get_query_set().using(db).filter(**query) 
     509                self.set_eri(qs, instances) 
     510 
    477511                return (qs, 
    478512                        attrgetter(rel_field.get_attname()), 
    479513                        attrgetter(attname), 
  • django/db/models/query.py

     
    4141        self._for_write = False 
    4242        self._prefetch_related_lookups = [] 
    4343        self._prefetch_done = False 
     44        self._existing_related_instances = {} 
    4445 
    4546    ######################## 
    4647    # PYTHON MAGIC METHODS # 
     
    294295                obj, _ = get_cached_row(row, index_start, db, klass_info, 
    295296                                        offset=len(aggregate_select)) 
    296297            else: 
     298                # Omit aggregates in object creation. 
     299                row_data = row[index_start:aggregate_start] 
    297300                if skip: 
    298                     row_data = row[index_start:aggregate_start] 
    299301                    obj = model_cls(**dict(zip(init_list, row_data))) 
    300302                else: 
    301                     # Omit aggregates in object creation. 
    302                     obj = model(*row[index_start:aggregate_start]) 
     303                    obj = model(*row_data) 
    303304 
    304305                # Store the source database of the object 
    305306                obj._state.db = db 
     
    315316                for i, aggregate in enumerate(aggregate_select): 
    316317                    setattr(obj, aggregate, row[i+aggregate_start]) 
    317318 
     319            if self._existing_related_instances: 
     320                for attr, values in self._existing_related_instances.iteritems(): 
     321                    descriptor = getattr(obj.__class__, attr) 
     322                    if len(values) > 1: 
     323                        try: 
     324                            attname = descriptor.field.attname 
     325                        except AttributeError: 
     326                            # If descriptor has no field attribute, 
     327                            # this is usually a reverse one-to-one 
     328                            # relation in which case obj PKs are used 
     329                            val = values[getattr( 
     330                                obj, 
     331                                descriptor.related.field.rel.field_name 
     332                                )] 
     333                        else: 
     334                            val = values[ 
     335                                getattr(obj, attname)] 
     336                    else: 
     337                        val = values.values()[0] 
     338 
     339                    setattr( 
     340                        obj, 
     341                        descriptor.cache_name, 
     342                        val 
     343                        ) 
     344 
    318345            yield obj 
    319346 
    320347    def aggregate(self, *args, **kwargs): 
     
    860887        c = klass(model=self.model, query=query, using=self._db) 
    861888        c._for_write = self._for_write 
    862889        c._prefetch_related_lookups = self._prefetch_related_lookups[:] 
     890        c._existing_related_instances = dict(self._existing_related_instances) 
    863891        c.__dict__.update(kwargs) 
    864892        if setup and hasattr(c, '_setup_query'): 
    865893            c._setup_query()