Code

Ticket #18177: prefetch_related_cache_with_tests.patch

File prefetch_related_cache_with_tests.patch, 13.4 KB (added by kaiser.yann@…, 2 years ago)

Patch against r17916 (Forgot to add the tests first time around.)

  • tests/modeltests/existing_related_instances/tests.py

     
     1from __future__ import absolute_import 
     2 
     3from django.test import TestCase 
     4 
     5from .models import Tournament, Pool, PoolStyle 
     6 
     7class EriTests(TestCase): 
     8    fixtures = ['tournament.json',] 
     9 
     10    def test_fk(self): 
     11        with self.assertNumQueries(2): 
     12            tournament = Tournament.objects.get(pk=1) 
     13            pool = tournament.pool_set.all()[0] 
     14            self.assertIs(tournament, pool.tournament) 
     15 
     16    def test_fk_prefetch_related(self): 
     17        with self.assertNumQueries(2): 
     18            tournament = ( 
     19                Tournament.objects.prefetch_related('pool_set') 
     20                .get(pk=1) 
     21                ) 
     22            pool = tournament.pool_set.all()[0] 
     23            self.assertIs(tournament, pool.tournament) 
     24 
     25    def test_fk_multiple_prefetch(self): 
     26        with self.assertNumQueries(2): 
     27            tournaments = list( 
     28                Tournament.objects.prefetch_related('pool_set') 
     29                ) 
     30            pool1 = tournaments[0].pool_set.all()[0] 
     31            self.assertIs(tournaments[0], pool1.tournament) 
     32            pool2 = tournaments[1].pool_set.all()[0] 
     33            self.assertIs(tournaments[1], pool2.tournament) 
     34 
     35    def test_1t1(self): 
     36        with self.assertNumQueries(2): 
     37            style = PoolStyle.objects.get(pk=1) 
     38            pool = style.pool 
     39            self.assertIs(style, pool.poolstyle) 
     40 
     41    def test_1t1_select_related(self): 
     42        with self.assertNumQueries(1): 
     43            style = PoolStyle.objects.select_related('pool').get(pk=1) 
     44            pool = style.pool 
     45            self.assertIs(style, pool.poolstyle) 
     46 
     47    def test_1t1_prefetch_related(self): 
     48        with self.assertNumQueries(2): 
     49            style = PoolStyle.objects.prefetch_related('pool').get(pk=1) 
     50            pool = style.pool 
     51            self.assertIs(style, pool.poolstyle) 
     52 
     53    def test_1t1r(self): 
     54        with self.assertNumQueries(2): 
     55            pool = Pool.objects.get(pk=2) 
     56            style = pool.poolstyle 
     57            self.assertIs(pool, style.pool) 
     58 
     59    def test_1t1r_select_related(self): 
     60        with self.assertNumQueries(1): 
     61            pool = Pool.objects.select_related('poolstyle').get(pk=2) 
     62            style = pool.poolstyle 
     63            self.assertIs(pool, style.pool) 
     64 
     65    def test_1t1r_prefetch_related(self): 
     66        with self.assertNumQueries(2): 
     67            pool = Pool.objects.prefetch_related('poolstyle').get(pk=2) 
     68            style = pool.poolstyle 
     69            self.assertIs(pool, style.pool) 
     70 
     71    def test_1t1_multi_related(self): 
     72        with self.assertNumQueries(1): 
     73            poolstyles = list(PoolStyle.objects.select_related('pool')) 
     74            self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle) 
     75            self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle) 
     76 
     77    def test_1t1_multi_prefetch(self): 
     78        with self.assertNumQueries(2): 
     79            poolstyles = list(PoolStyle.objects.prefetch_related('pool')) 
     80            self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle) 
     81            self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle) 
     82 
     83    def test_1t1r_multi_related(self): 
     84        with self.assertNumQueries(1): 
     85            pools = list(Pool.objects.select_related('poolstyle')) 
     86            self.assertIs(pools[1], pools[1].poolstyle.pool) 
     87            self.assertIs(pools[2], pools[2].poolstyle.pool) 
     88 
     89    def test_1t1r_multi_prefetch(self): 
     90        with self.assertNumQueries(2): 
     91            pools = list(Pool.objects.prefetch_related('poolstyle')) 
     92            self.assertIs(pools[1], pools[1].poolstyle.pool) 
     93            self.assertIs(pools[2], pools[2].poolstyle.pool) 
  • tests/modeltests/existing_related_instances/fixtures/tournament.json

     
     1[ 
     2    { 
     3        "pk": 1, 
     4        "model": "existing_related_instances.tournament", 
     5        "fields": { 
     6            "name": "Tourney 1" 
     7            } 
     8        }, 
     9    { 
     10        "pk": 2, 
     11        "model": "existing_related_instances.tournament", 
     12        "fields": { 
     13            "name": "Tourney 2" 
     14            } 
     15        }, 
     16    { 
     17        "pk": 1, 
     18        "model": "existing_related_instances.pool", 
     19        "fields": { 
     20            "tournament": 1, 
     21            "name": "T1 Pool 1" 
     22            } 
     23        }, 
     24    { 
     25        "pk": 2, 
     26        "model": "existing_related_instances.pool", 
     27        "fields": { 
     28            "tournament": 1, 
     29            "name": "T1 Pool 2" 
     30            } 
     31        }, 
     32    { 
     33        "pk": 3, 
     34        "model": "existing_related_instances.pool", 
     35        "fields": { 
     36            "tournament": 2, 
     37            "name": "T2 Pool 1" 
     38            } 
     39        }, 
     40    { 
     41        "pk": 4, 
     42        "model": "existing_related_instances.pool", 
     43        "fields": { 
     44            "tournament": 2, 
     45            "name": "T2 Pool 2" 
     46            } 
     47        }, 
     48    { 
     49        "pk": 1, 
     50        "model": "existing_related_instances.poolstyle", 
     51        "fields": { 
     52            "name": "T1 Pool 2 Style", 
     53            "pool": 2 
     54            } 
     55        }, 
     56    { 
     57        "pk": 2, 
     58        "model": "existing_related_instances.poolstyle", 
     59        "fields": { 
     60            "name": "T2 Pool 1 Style", 
     61            "pool": 3 
     62            } 
     63        } 
     64] 
     65 
  • tests/modeltests/existing_related_instances/models.py

     
     1""" 
     21. Existing related object instance caching 
     3 
     4These test that queries are not redone when going back through already 
     5explored relations. 
     6""" 
     7 
     8from django.db import models 
     9 
     10class Tournament(models.Model): 
     11    name = models.CharField(max_length=30) 
     12 
     13class Pool(models.Model): 
     14    name = models.CharField(max_length=30) 
     15    tournament = models.ForeignKey(Tournament) 
     16 
     17class PoolStyle(models.Model): 
     18    name = models.CharField(max_length=30) 
     19    pool = models.OneToOneField(Pool) 
     20 
  • django/db/models/fields/related.py

     
    8484    for cls, field, operation in pending_lookups.pop(key, []): 
    8585        operation(field, sender, cls) 
    8686 
     87def set_eri(qs, field_name, rel_field_name, instances): 
     88    qs._existing_related_instances = { 
     89        field_name: dict( 
     90            (getattr(instance, rel_field_name), instance) 
     91            for instance in instances), 
     92        } 
     93 
    8794signals.class_prepared.connect(do_pending_lookups) 
    8895 
    8996#HACK 
     
    232239    def is_cached(self, instance): 
    233240        return hasattr(instance, self.cache_name) 
    234241 
     242    def set_eri(self, qs, instances): 
     243        set_eri(qs, 
     244                self.related.field.name, 
     245                self.related.field.rel.field_name, 
     246                instances) 
     247 
    235248    def get_query_set(self, **db_hints): 
    236249        db = router.db_for_read(self.related.model, **db_hints) 
    237250        return self.related.model._base_manager.using(db) 
     
    239252    def get_prefetch_query_set(self, instances): 
    240253        vals = set(instance._get_pk_val() for instance in instances) 
    241254        params = {'%s__pk__in' % self.related.field.name: vals} 
    242         return (self.get_query_set(instance=instances[0]).filter(**params), 
     255        qs = self.get_query_set(instance=instances[0]).filter(**params) 
     256        self.set_eri(qs, instances) 
     257        return (qs, 
    243258                attrgetter(self.related.field.attname), 
    244259                lambda obj: obj._get_pk_val(), 
    245260                True, 
     
    313328    def is_cached(self, instance): 
    314329        return hasattr(instance, self.cache_name) 
    315330 
     331    def set_eri(self, qs, instances): 
     332        set_eri(qs, 
     333                self.field.related.get_accessor_name(), 
     334                self.field.attname, 
     335                instances 
     336                ) 
     337 
    316338    def get_query_set(self, **db_hints): 
    317339        db = router.db_for_read(self.field.rel.to, **db_hints) 
    318340        rel_mgr = self.field.rel.to._default_manager 
     
    330352            params = {'%s__pk__in' % self.field.rel.field_name: vals} 
    331353        else: 
    332354            params = {'%s__in' % self.field.rel.field_name: vals} 
    333         return (self.get_query_set(instance=instances[0]).filter(**params), 
     355        qs = self.get_query_set(instance=instances[0]).filter(**params) 
     356        self.set_eri(qs, instances) 
     357        return (qs, 
    334358                attrgetter(self.field.rel.field_name), 
    335359                attrgetter(self.field.attname), 
    336360                True, 
     
    462486                } 
    463487                self.model = rel_model 
    464488 
     489            def set_eri(self, qs, instances): 
     490                set_eri(qs, 
     491                        rel_field.name, 
     492                        rel_field.rel.field_name, 
     493                        instances) 
     494 
    465495            def get_query_set(self): 
    466496                try: 
    467497                    return self.instance._prefetched_objects_cache[rel_field.related_query_name()] 
    468498                except (AttributeError, KeyError): 
    469499                    db = self._db or router.db_for_read(self.model, instance=self.instance) 
    470                     return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) 
     500                    qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) 
     501                    self.set_eri(qs, (self.instance,)) 
     502                    return qs 
    471503 
    472504            def get_prefetch_query_set(self, instances): 
    473505                db = self._db or router.db_for_read(self.model, instance=instances[0]) 
    474506                query = {'%s__%s__in' % (rel_field.name, attname): 
    475507                             set(getattr(obj, attname) for obj in instances)} 
    476508                qs = super(RelatedManager, self).get_query_set().using(db).filter(**query) 
     509                self.set_eri(qs, instances) 
     510 
    477511                return (qs, 
    478512                        attrgetter(rel_field.get_attname()), 
    479513                        attrgetter(attname), 
  • django/db/models/query.py

     
    4141        self._for_write = False 
    4242        self._prefetch_related_lookups = [] 
    4343        self._prefetch_done = False 
     44        self._existing_related_instances = {} 
    4445 
    4546    ######################## 
    4647    # PYTHON MAGIC METHODS # 
     
    294295                obj, _ = get_cached_row(row, index_start, db, klass_info, 
    295296                                        offset=len(aggregate_select)) 
    296297            else: 
     298                # Omit aggregates in object creation. 
     299                row_data = row[index_start:aggregate_start] 
    297300                if skip: 
    298                     row_data = row[index_start:aggregate_start] 
    299301                    obj = model_cls(**dict(zip(init_list, row_data))) 
    300302                else: 
    301                     # Omit aggregates in object creation. 
    302                     obj = model(*row[index_start:aggregate_start]) 
     303                    obj = model(*row_data) 
    303304 
    304305                # Store the source database of the object 
    305306                obj._state.db = db 
     
    315316                for i, aggregate in enumerate(aggregate_select): 
    316317                    setattr(obj, aggregate, row[i+aggregate_start]) 
    317318 
     319            if self._existing_related_instances: 
     320                for attr, values in self._existing_related_instances.iteritems(): 
     321                    descriptor = getattr(obj.__class__, attr) 
     322                    if len(values) > 1: 
     323                        try: 
     324                            attname = descriptor.field.attname 
     325                        except AttributeError: 
     326                            # If descriptor has no field attribute, 
     327                            # this is usually a reverse one-to-one 
     328                            # relation in which case obj PKs are used 
     329                            val = values[getattr( 
     330                                obj, 
     331                                descriptor.related.field.rel.field_name 
     332                                )] 
     333                        else: 
     334                            val = values[ 
     335                                getattr(obj, attname)] 
     336                    else: 
     337                        val = values.values()[0] 
     338 
     339                    setattr( 
     340                        obj, 
     341                        descriptor.cache_name, 
     342                        val 
     343                        ) 
     344 
    318345            yield obj 
    319346 
    320347    def aggregate(self, *args, **kwargs): 
     
    860887        c = klass(model=self.model, query=query, using=self._db) 
    861888        c._for_write = self._for_write 
    862889        c._prefetch_related_lookups = self._prefetch_related_lookups[:] 
     890        c._existing_related_instances = dict(self._existing_related_instances) 
    863891        c.__dict__.update(kwargs) 
    864892        if setup and hasattr(c, '_setup_query'): 
    865893            c._setup_query()