Ticket #18177: prefetch_related_cache_with_tests.patch

File prefetch_related_cache_with_tests.patch, 13.4 KB (added by kaiser.yann@…, 12 years ago)

Patch against r17916 (Forgot to add the tests first time around.)

  • tests/modeltests/existing_related_instances/tests.py

     
     1from __future__ import absolute_import
     2
     3from django.test import TestCase
     4
     5from .models import Tournament, Pool, PoolStyle
     6
     7class EriTests(TestCase):
     8    fixtures = ['tournament.json',]
     9
     10    def test_fk(self):
     11        with self.assertNumQueries(2):
     12            tournament = Tournament.objects.get(pk=1)
     13            pool = tournament.pool_set.all()[0]
     14            self.assertIs(tournament, pool.tournament)
     15
     16    def test_fk_prefetch_related(self):
     17        with self.assertNumQueries(2):
     18            tournament = (
     19                Tournament.objects.prefetch_related('pool_set')
     20                .get(pk=1)
     21                )
     22            pool = tournament.pool_set.all()[0]
     23            self.assertIs(tournament, pool.tournament)
     24
     25    def test_fk_multiple_prefetch(self):
     26        with self.assertNumQueries(2):
     27            tournaments = list(
     28                Tournament.objects.prefetch_related('pool_set')
     29                )
     30            pool1 = tournaments[0].pool_set.all()[0]
     31            self.assertIs(tournaments[0], pool1.tournament)
     32            pool2 = tournaments[1].pool_set.all()[0]
     33            self.assertIs(tournaments[1], pool2.tournament)
     34
     35    def test_1t1(self):
     36        with self.assertNumQueries(2):
     37            style = PoolStyle.objects.get(pk=1)
     38            pool = style.pool
     39            self.assertIs(style, pool.poolstyle)
     40
     41    def test_1t1_select_related(self):
     42        with self.assertNumQueries(1):
     43            style = PoolStyle.objects.select_related('pool').get(pk=1)
     44            pool = style.pool
     45            self.assertIs(style, pool.poolstyle)
     46
     47    def test_1t1_prefetch_related(self):
     48        with self.assertNumQueries(2):
     49            style = PoolStyle.objects.prefetch_related('pool').get(pk=1)
     50            pool = style.pool
     51            self.assertIs(style, pool.poolstyle)
     52
     53    def test_1t1r(self):
     54        with self.assertNumQueries(2):
     55            pool = Pool.objects.get(pk=2)
     56            style = pool.poolstyle
     57            self.assertIs(pool, style.pool)
     58
     59    def test_1t1r_select_related(self):
     60        with self.assertNumQueries(1):
     61            pool = Pool.objects.select_related('poolstyle').get(pk=2)
     62            style = pool.poolstyle
     63            self.assertIs(pool, style.pool)
     64
     65    def test_1t1r_prefetch_related(self):
     66        with self.assertNumQueries(2):
     67            pool = Pool.objects.prefetch_related('poolstyle').get(pk=2)
     68            style = pool.poolstyle
     69            self.assertIs(pool, style.pool)
     70
     71    def test_1t1_multi_related(self):
     72        with self.assertNumQueries(1):
     73            poolstyles = list(PoolStyle.objects.select_related('pool'))
     74            self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle)
     75            self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle)
     76
     77    def test_1t1_multi_prefetch(self):
     78        with self.assertNumQueries(2):
     79            poolstyles = list(PoolStyle.objects.prefetch_related('pool'))
     80            self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle)
     81            self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle)
     82
     83    def test_1t1r_multi_related(self):
     84        with self.assertNumQueries(1):
     85            pools = list(Pool.objects.select_related('poolstyle'))
     86            self.assertIs(pools[1], pools[1].poolstyle.pool)
     87            self.assertIs(pools[2], pools[2].poolstyle.pool)
     88
     89    def test_1t1r_multi_prefetch(self):
     90        with self.assertNumQueries(2):
     91            pools = list(Pool.objects.prefetch_related('poolstyle'))
     92            self.assertIs(pools[1], pools[1].poolstyle.pool)
     93            self.assertIs(pools[2], pools[2].poolstyle.pool)
  • tests/modeltests/existing_related_instances/fixtures/tournament.json

     
     1[
     2    {
     3        "pk": 1,
     4        "model": "existing_related_instances.tournament",
     5        "fields": {
     6            "name": "Tourney 1"
     7            }
     8        },
     9    {
     10        "pk": 2,
     11        "model": "existing_related_instances.tournament",
     12        "fields": {
     13            "name": "Tourney 2"
     14            }
     15        },
     16    {
     17        "pk": 1,
     18        "model": "existing_related_instances.pool",
     19        "fields": {
     20            "tournament": 1,
     21            "name": "T1 Pool 1"
     22            }
     23        },
     24    {
     25        "pk": 2,
     26        "model": "existing_related_instances.pool",
     27        "fields": {
     28            "tournament": 1,
     29            "name": "T1 Pool 2"
     30            }
     31        },
     32    {
     33        "pk": 3,
     34        "model": "existing_related_instances.pool",
     35        "fields": {
     36            "tournament": 2,
     37            "name": "T2 Pool 1"
     38            }
     39        },
     40    {
     41        "pk": 4,
     42        "model": "existing_related_instances.pool",
     43        "fields": {
     44            "tournament": 2,
     45            "name": "T2 Pool 2"
     46            }
     47        },
     48    {
     49        "pk": 1,
     50        "model": "existing_related_instances.poolstyle",
     51        "fields": {
     52            "name": "T1 Pool 2 Style",
     53            "pool": 2
     54            }
     55        },
     56    {
     57        "pk": 2,
     58        "model": "existing_related_instances.poolstyle",
     59        "fields": {
     60            "name": "T2 Pool 1 Style",
     61            "pool": 3
     62            }
     63        }
     64]
     65
  • tests/modeltests/existing_related_instances/models.py

     
     1"""
     21. Existing related object instance caching
     3
     4These test that queries are not redone when going back through already
     5explored relations.
     6"""
     7
     8from django.db import models
     9
     10class Tournament(models.Model):
     11    name = models.CharField(max_length=30)
     12
     13class Pool(models.Model):
     14    name = models.CharField(max_length=30)
     15    tournament = models.ForeignKey(Tournament)
     16
     17class PoolStyle(models.Model):
     18    name = models.CharField(max_length=30)
     19    pool = models.OneToOneField(Pool)
     20
  • django/db/models/fields/related.py

     
    8484    for cls, field, operation in pending_lookups.pop(key, []):
    8585        operation(field, sender, cls)
    8686
     87def set_eri(qs, field_name, rel_field_name, instances):
     88    qs._existing_related_instances = {
     89        field_name: dict(
     90            (getattr(instance, rel_field_name), instance)
     91            for instance in instances),
     92        }
     93
    8794signals.class_prepared.connect(do_pending_lookups)
    8895
    8996#HACK
     
    232239    def is_cached(self, instance):
    233240        return hasattr(instance, self.cache_name)
    234241
     242    def set_eri(self, qs, instances):
     243        set_eri(qs,
     244                self.related.field.name,
     245                self.related.field.rel.field_name,
     246                instances)
     247
    235248    def get_query_set(self, **db_hints):
    236249        db = router.db_for_read(self.related.model, **db_hints)
    237250        return self.related.model._base_manager.using(db)
     
    239252    def get_prefetch_query_set(self, instances):
    240253        vals = set(instance._get_pk_val() for instance in instances)
    241254        params = {'%s__pk__in' % self.related.field.name: vals}
    242         return (self.get_query_set(instance=instances[0]).filter(**params),
     255        qs = self.get_query_set(instance=instances[0]).filter(**params)
     256        self.set_eri(qs, instances)
     257        return (qs,
    243258                attrgetter(self.related.field.attname),
    244259                lambda obj: obj._get_pk_val(),
    245260                True,
     
    313328    def is_cached(self, instance):
    314329        return hasattr(instance, self.cache_name)
    315330
     331    def set_eri(self, qs, instances):
     332        set_eri(qs,
     333                self.field.related.get_accessor_name(),
     334                self.field.attname,
     335                instances
     336                )
     337
    316338    def get_query_set(self, **db_hints):
    317339        db = router.db_for_read(self.field.rel.to, **db_hints)
    318340        rel_mgr = self.field.rel.to._default_manager
     
    330352            params = {'%s__pk__in' % self.field.rel.field_name: vals}
    331353        else:
    332354            params = {'%s__in' % self.field.rel.field_name: vals}
    333         return (self.get_query_set(instance=instances[0]).filter(**params),
     355        qs = self.get_query_set(instance=instances[0]).filter(**params)
     356        self.set_eri(qs, instances)
     357        return (qs,
    334358                attrgetter(self.field.rel.field_name),
    335359                attrgetter(self.field.attname),
    336360                True,
     
    462486                }
    463487                self.model = rel_model
    464488
     489            def set_eri(self, qs, instances):
     490                set_eri(qs,
     491                        rel_field.name,
     492                        rel_field.rel.field_name,
     493                        instances)
     494
    465495            def get_query_set(self):
    466496                try:
    467497                    return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
    468498                except (AttributeError, KeyError):
    469499                    db = self._db or router.db_for_read(self.model, instance=self.instance)
    470                     return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
     500                    qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
     501                    self.set_eri(qs, (self.instance,))
     502                    return qs
    471503
    472504            def get_prefetch_query_set(self, instances):
    473505                db = self._db or router.db_for_read(self.model, instance=instances[0])
    474506                query = {'%s__%s__in' % (rel_field.name, attname):
    475507                             set(getattr(obj, attname) for obj in instances)}
    476508                qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
     509                self.set_eri(qs, instances)
     510
    477511                return (qs,
    478512                        attrgetter(rel_field.get_attname()),
    479513                        attrgetter(attname),
  • django/db/models/query.py

     
    4141        self._for_write = False
    4242        self._prefetch_related_lookups = []
    4343        self._prefetch_done = False
     44        self._existing_related_instances = {}
    4445
    4546    ########################
    4647    # PYTHON MAGIC METHODS #
     
    294295                obj, _ = get_cached_row(row, index_start, db, klass_info,
    295296                                        offset=len(aggregate_select))
    296297            else:
     298                # Omit aggregates in object creation.
     299                row_data = row[index_start:aggregate_start]
    297300                if skip:
    298                     row_data = row[index_start:aggregate_start]
    299301                    obj = model_cls(**dict(zip(init_list, row_data)))
    300302                else:
    301                     # Omit aggregates in object creation.
    302                     obj = model(*row[index_start:aggregate_start])
     303                    obj = model(*row_data)
    303304
    304305                # Store the source database of the object
    305306                obj._state.db = db
     
    315316                for i, aggregate in enumerate(aggregate_select):
    316317                    setattr(obj, aggregate, row[i+aggregate_start])
    317318
     319            if self._existing_related_instances:
     320                for attr, values in self._existing_related_instances.iteritems():
     321                    descriptor = getattr(obj.__class__, attr)
     322                    if len(values) > 1:
     323                        try:
     324                            attname = descriptor.field.attname
     325                        except AttributeError:
     326                            # If descriptor has no field attribute,
     327                            # this is usually a reverse one-to-one
     328                            # relation in which case obj PKs are used
     329                            val = values[getattr(
     330                                obj,
     331                                descriptor.related.field.rel.field_name
     332                                )]
     333                        else:
     334                            val = values[
     335                                getattr(obj, attname)]
     336                    else:
     337                        val = values.values()[0]
     338
     339                    setattr(
     340                        obj,
     341                        descriptor.cache_name,
     342                        val
     343                        )
     344
    318345            yield obj
    319346
    320347    def aggregate(self, *args, **kwargs):
     
    860887        c = klass(model=self.model, query=query, using=self._db)
    861888        c._for_write = self._for_write
    862889        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
     890        c._existing_related_instances = dict(self._existing_related_instances)
    863891        c.__dict__.update(kwargs)
    864892        if setup and hasattr(c, '_setup_query'):
    865893            c._setup_query()
Back to Top