Ticket #18177: prefetch_related_cache_with_tests.patch
File prefetch_related_cache_with_tests.patch, 13.4 KB (added by , 13 years ago) |
---|
-
tests/modeltests/existing_related_instances/tests.py
1 from __future__ import absolute_import 2 3 from django.test import TestCase 4 5 from .models import Tournament, Pool, PoolStyle 6 7 class EriTests(TestCase): 8 fixtures = ['tournament.json',] 9 10 def test_fk(self): 11 with self.assertNumQueries(2): 12 tournament = Tournament.objects.get(pk=1) 13 pool = tournament.pool_set.all()[0] 14 self.assertIs(tournament, pool.tournament) 15 16 def test_fk_prefetch_related(self): 17 with self.assertNumQueries(2): 18 tournament = ( 19 Tournament.objects.prefetch_related('pool_set') 20 .get(pk=1) 21 ) 22 pool = tournament.pool_set.all()[0] 23 self.assertIs(tournament, pool.tournament) 24 25 def test_fk_multiple_prefetch(self): 26 with self.assertNumQueries(2): 27 tournaments = list( 28 Tournament.objects.prefetch_related('pool_set') 29 ) 30 pool1 = tournaments[0].pool_set.all()[0] 31 self.assertIs(tournaments[0], pool1.tournament) 32 pool2 = tournaments[1].pool_set.all()[0] 33 self.assertIs(tournaments[1], pool2.tournament) 34 35 def test_1t1(self): 36 with self.assertNumQueries(2): 37 style = PoolStyle.objects.get(pk=1) 38 pool = style.pool 39 self.assertIs(style, pool.poolstyle) 40 41 def test_1t1_select_related(self): 42 with self.assertNumQueries(1): 43 style = PoolStyle.objects.select_related('pool').get(pk=1) 44 pool = style.pool 45 self.assertIs(style, pool.poolstyle) 46 47 def test_1t1_prefetch_related(self): 48 with self.assertNumQueries(2): 49 style = PoolStyle.objects.prefetch_related('pool').get(pk=1) 50 pool = style.pool 51 self.assertIs(style, pool.poolstyle) 52 53 def test_1t1r(self): 54 with self.assertNumQueries(2): 55 pool = Pool.objects.get(pk=2) 56 style = pool.poolstyle 57 self.assertIs(pool, style.pool) 58 59 def test_1t1r_select_related(self): 60 with self.assertNumQueries(1): 61 pool = Pool.objects.select_related('poolstyle').get(pk=2) 62 style = pool.poolstyle 63 self.assertIs(pool, style.pool) 64 65 def test_1t1r_prefetch_related(self): 66 with self.assertNumQueries(2): 67 pool = Pool.objects.prefetch_related('poolstyle').get(pk=2) 68 style = pool.poolstyle 69 self.assertIs(pool, style.pool) 70 71 def test_1t1_multi_related(self): 72 with self.assertNumQueries(1): 73 poolstyles = list(PoolStyle.objects.select_related('pool')) 74 self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle) 75 self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle) 76 77 def test_1t1_multi_prefetch(self): 78 with self.assertNumQueries(2): 79 poolstyles = list(PoolStyle.objects.prefetch_related('pool')) 80 self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle) 81 self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle) 82 83 def test_1t1r_multi_related(self): 84 with self.assertNumQueries(1): 85 pools = list(Pool.objects.select_related('poolstyle')) 86 self.assertIs(pools[1], pools[1].poolstyle.pool) 87 self.assertIs(pools[2], pools[2].poolstyle.pool) 88 89 def test_1t1r_multi_prefetch(self): 90 with self.assertNumQueries(2): 91 pools = list(Pool.objects.prefetch_related('poolstyle')) 92 self.assertIs(pools[1], pools[1].poolstyle.pool) 93 self.assertIs(pools[2], pools[2].poolstyle.pool) -
tests/modeltests/existing_related_instances/fixtures/tournament.json
1 [ 2 { 3 "pk": 1, 4 "model": "existing_related_instances.tournament", 5 "fields": { 6 "name": "Tourney 1" 7 } 8 }, 9 { 10 "pk": 2, 11 "model": "existing_related_instances.tournament", 12 "fields": { 13 "name": "Tourney 2" 14 } 15 }, 16 { 17 "pk": 1, 18 "model": "existing_related_instances.pool", 19 "fields": { 20 "tournament": 1, 21 "name": "T1 Pool 1" 22 } 23 }, 24 { 25 "pk": 2, 26 "model": "existing_related_instances.pool", 27 "fields": { 28 "tournament": 1, 29 "name": "T1 Pool 2" 30 } 31 }, 32 { 33 "pk": 3, 34 "model": "existing_related_instances.pool", 35 "fields": { 36 "tournament": 2, 37 "name": "T2 Pool 1" 38 } 39 }, 40 { 41 "pk": 4, 42 "model": "existing_related_instances.pool", 43 "fields": { 44 "tournament": 2, 45 "name": "T2 Pool 2" 46 } 47 }, 48 { 49 "pk": 1, 50 "model": "existing_related_instances.poolstyle", 51 "fields": { 52 "name": "T1 Pool 2 Style", 53 "pool": 2 54 } 55 }, 56 { 57 "pk": 2, 58 "model": "existing_related_instances.poolstyle", 59 "fields": { 60 "name": "T2 Pool 1 Style", 61 "pool": 3 62 } 63 } 64 ] 65 -
tests/modeltests/existing_related_instances/models.py
1 """ 2 1. Existing related object instance caching 3 4 These test that queries are not redone when going back through already 5 explored relations. 6 """ 7 8 from django.db import models 9 10 class Tournament(models.Model): 11 name = models.CharField(max_length=30) 12 13 class Pool(models.Model): 14 name = models.CharField(max_length=30) 15 tournament = models.ForeignKey(Tournament) 16 17 class PoolStyle(models.Model): 18 name = models.CharField(max_length=30) 19 pool = models.OneToOneField(Pool) 20 -
django/db/models/fields/related.py
84 84 for cls, field, operation in pending_lookups.pop(key, []): 85 85 operation(field, sender, cls) 86 86 87 def set_eri(qs, field_name, rel_field_name, instances): 88 qs._existing_related_instances = { 89 field_name: dict( 90 (getattr(instance, rel_field_name), instance) 91 for instance in instances), 92 } 93 87 94 signals.class_prepared.connect(do_pending_lookups) 88 95 89 96 #HACK … … 232 239 def is_cached(self, instance): 233 240 return hasattr(instance, self.cache_name) 234 241 242 def set_eri(self, qs, instances): 243 set_eri(qs, 244 self.related.field.name, 245 self.related.field.rel.field_name, 246 instances) 247 235 248 def get_query_set(self, **db_hints): 236 249 db = router.db_for_read(self.related.model, **db_hints) 237 250 return self.related.model._base_manager.using(db) … … 239 252 def get_prefetch_query_set(self, instances): 240 253 vals = set(instance._get_pk_val() for instance in instances) 241 254 params = {'%s__pk__in' % self.related.field.name: vals} 242 return (self.get_query_set(instance=instances[0]).filter(**params), 255 qs = self.get_query_set(instance=instances[0]).filter(**params) 256 self.set_eri(qs, instances) 257 return (qs, 243 258 attrgetter(self.related.field.attname), 244 259 lambda obj: obj._get_pk_val(), 245 260 True, … … 313 328 def is_cached(self, instance): 314 329 return hasattr(instance, self.cache_name) 315 330 331 def set_eri(self, qs, instances): 332 set_eri(qs, 333 self.field.related.get_accessor_name(), 334 self.field.attname, 335 instances 336 ) 337 316 338 def get_query_set(self, **db_hints): 317 339 db = router.db_for_read(self.field.rel.to, **db_hints) 318 340 rel_mgr = self.field.rel.to._default_manager … … 330 352 params = {'%s__pk__in' % self.field.rel.field_name: vals} 331 353 else: 332 354 params = {'%s__in' % self.field.rel.field_name: vals} 333 return (self.get_query_set(instance=instances[0]).filter(**params), 355 qs = self.get_query_set(instance=instances[0]).filter(**params) 356 self.set_eri(qs, instances) 357 return (qs, 334 358 attrgetter(self.field.rel.field_name), 335 359 attrgetter(self.field.attname), 336 360 True, … … 462 486 } 463 487 self.model = rel_model 464 488 489 def set_eri(self, qs, instances): 490 set_eri(qs, 491 rel_field.name, 492 rel_field.rel.field_name, 493 instances) 494 465 495 def get_query_set(self): 466 496 try: 467 497 return self.instance._prefetched_objects_cache[rel_field.related_query_name()] 468 498 except (AttributeError, KeyError): 469 499 db = self._db or router.db_for_read(self.model, instance=self.instance) 470 return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) 500 qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) 501 self.set_eri(qs, (self.instance,)) 502 return qs 471 503 472 504 def get_prefetch_query_set(self, instances): 473 505 db = self._db or router.db_for_read(self.model, instance=instances[0]) 474 506 query = {'%s__%s__in' % (rel_field.name, attname): 475 507 set(getattr(obj, attname) for obj in instances)} 476 508 qs = super(RelatedManager, self).get_query_set().using(db).filter(**query) 509 self.set_eri(qs, instances) 510 477 511 return (qs, 478 512 attrgetter(rel_field.get_attname()), 479 513 attrgetter(attname), -
django/db/models/query.py
41 41 self._for_write = False 42 42 self._prefetch_related_lookups = [] 43 43 self._prefetch_done = False 44 self._existing_related_instances = {} 44 45 45 46 ######################## 46 47 # PYTHON MAGIC METHODS # … … 294 295 obj, _ = get_cached_row(row, index_start, db, klass_info, 295 296 offset=len(aggregate_select)) 296 297 else: 298 # Omit aggregates in object creation. 299 row_data = row[index_start:aggregate_start] 297 300 if skip: 298 row_data = row[index_start:aggregate_start]299 301 obj = model_cls(**dict(zip(init_list, row_data))) 300 302 else: 301 # Omit aggregates in object creation. 302 obj = model(*row[index_start:aggregate_start]) 303 obj = model(*row_data) 303 304 304 305 # Store the source database of the object 305 306 obj._state.db = db … … 315 316 for i, aggregate in enumerate(aggregate_select): 316 317 setattr(obj, aggregate, row[i+aggregate_start]) 317 318 319 if self._existing_related_instances: 320 for attr, values in self._existing_related_instances.iteritems(): 321 descriptor = getattr(obj.__class__, attr) 322 if len(values) > 1: 323 try: 324 attname = descriptor.field.attname 325 except AttributeError: 326 # If descriptor has no field attribute, 327 # this is usually a reverse one-to-one 328 # relation in which case obj PKs are used 329 val = values[getattr( 330 obj, 331 descriptor.related.field.rel.field_name 332 )] 333 else: 334 val = values[ 335 getattr(obj, attname)] 336 else: 337 val = values.values()[0] 338 339 setattr( 340 obj, 341 descriptor.cache_name, 342 val 343 ) 344 318 345 yield obj 319 346 320 347 def aggregate(self, *args, **kwargs): … … 860 887 c = klass(model=self.model, query=query, using=self._db) 861 888 c._for_write = self._for_write 862 889 c._prefetch_related_lookups = self._prefetch_related_lookups[:] 890 c._existing_related_instances = dict(self._existing_related_instances) 863 891 c.__dict__.update(kwargs) 864 892 if setup and hasattr(c, '_setup_query'): 865 893 c._setup_query()