Ticket #18177: prefetch_related_cache.patch
File prefetch_related_cache.patch, 6.9 KB (added by , 13 years ago) |
---|
-
django/db/models/fields/related.py
84 84 for cls, field, operation in pending_lookups.pop(key, []): 85 85 operation(field, sender, cls) 86 86 87 def set_eri(qs, field_name, rel_field_name, instances): 88 qs._existing_related_instances = { 89 field_name: dict( 90 (getattr(instance, rel_field_name), instance) 91 for instance in instances), 92 } 93 87 94 signals.class_prepared.connect(do_pending_lookups) 88 95 89 96 #HACK … … 232 239 def is_cached(self, instance): 233 240 return hasattr(instance, self.cache_name) 234 241 242 def set_eri(self, qs, instances): 243 set_eri(qs, 244 self.related.field.name, 245 self.related.field.rel.field_name, 246 instances) 247 235 248 def get_query_set(self, **db_hints): 236 249 db = router.db_for_read(self.related.model, **db_hints) 237 250 return self.related.model._base_manager.using(db) … … 239 252 def get_prefetch_query_set(self, instances): 240 253 vals = set(instance._get_pk_val() for instance in instances) 241 254 params = {'%s__pk__in' % self.related.field.name: vals} 242 return (self.get_query_set(instance=instances[0]).filter(**params), 255 qs = self.get_query_set(instance=instances[0]).filter(**params) 256 self.set_eri(qs, instances) 257 return (qs, 243 258 attrgetter(self.related.field.attname), 244 259 lambda obj: obj._get_pk_val(), 245 260 True, … … 313 328 def is_cached(self, instance): 314 329 return hasattr(instance, self.cache_name) 315 330 331 def set_eri(self, qs, instances): 332 set_eri(qs, 333 self.field.related.get_accessor_name(), 334 self.field.attname, 335 instances 336 ) 337 316 338 def get_query_set(self, **db_hints): 317 339 db = router.db_for_read(self.field.rel.to, **db_hints) 318 340 rel_mgr = self.field.rel.to._default_manager … … 330 352 params = {'%s__pk__in' % self.field.rel.field_name: vals} 331 353 else: 332 354 params = {'%s__in' % self.field.rel.field_name: vals} 333 return (self.get_query_set(instance=instances[0]).filter(**params), 355 qs = self.get_query_set(instance=instances[0]).filter(**params) 356 self.set_eri(qs, instances) 357 return (qs, 334 358 attrgetter(self.field.rel.field_name), 335 359 attrgetter(self.field.attname), 336 360 True, … … 462 486 } 463 487 self.model = rel_model 464 488 489 def set_eri(self, qs, instances): 490 set_eri(qs, 491 rel_field.name, 492 rel_field.rel.field_name, 493 instances) 494 465 495 def get_query_set(self): 466 496 try: 467 497 return self.instance._prefetched_objects_cache[rel_field.related_query_name()] 468 498 except (AttributeError, KeyError): 469 499 db = self._db or router.db_for_read(self.model, instance=self.instance) 470 return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) 500 qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) 501 self.set_eri(qs, (self.instance,)) 502 return qs 471 503 472 504 def get_prefetch_query_set(self, instances): 473 505 db = self._db or router.db_for_read(self.model, instance=instances[0]) 474 506 query = {'%s__%s__in' % (rel_field.name, attname): 475 507 set(getattr(obj, attname) for obj in instances)} 476 508 qs = super(RelatedManager, self).get_query_set().using(db).filter(**query) 509 self.set_eri(qs, instances) 510 477 511 return (qs, 478 512 attrgetter(rel_field.get_attname()), 479 513 attrgetter(attname), -
django/db/models/query.py
41 41 self._for_write = False 42 42 self._prefetch_related_lookups = [] 43 43 self._prefetch_done = False 44 self._existing_related_instances = {} 44 45 45 46 ######################## 46 47 # PYTHON MAGIC METHODS # … … 294 295 obj, _ = get_cached_row(row, index_start, db, klass_info, 295 296 offset=len(aggregate_select)) 296 297 else: 298 # Omit aggregates in object creation. 299 row_data = row[index_start:aggregate_start] 297 300 if skip: 298 row_data = row[index_start:aggregate_start]299 301 obj = model_cls(**dict(zip(init_list, row_data))) 300 302 else: 301 # Omit aggregates in object creation. 302 obj = model(*row[index_start:aggregate_start]) 303 obj = model(*row_data) 303 304 304 305 # Store the source database of the object 305 306 obj._state.db = db … … 315 316 for i, aggregate in enumerate(aggregate_select): 316 317 setattr(obj, aggregate, row[i+aggregate_start]) 317 318 319 if self._existing_related_instances: 320 for attr, values in self._existing_related_instances.iteritems(): 321 descriptor = getattr(obj.__class__, attr) 322 if len(values) > 1: 323 try: 324 attname = descriptor.field.attname 325 except AttributeError: 326 # If descriptor has no field attribute, 327 # this is usually a reverse one-to-one 328 # relation in which case obj PKs are used 329 val = values[getattr( 330 obj, 331 descriptor.related.field.rel.field_name 332 )] 333 else: 334 val = values[ 335 getattr(obj, attname)] 336 else: 337 val = values.values()[0] 338 339 setattr( 340 obj, 341 descriptor.cache_name, 342 val 343 ) 344 318 345 yield obj 319 346 320 347 def aggregate(self, *args, **kwargs): … … 860 887 c = klass(model=self.model, query=query, using=self._db) 861 888 c._for_write = self._for_write 862 889 c._prefetch_related_lookups = self._prefetch_related_lookups[:] 890 c._existing_related_instances = dict(self._existing_related_instances) 863 891 c.__dict__.update(kwargs) 864 892 if setup and hasattr(c, '_setup_query'): 865 893 c._setup_query()