Ticket #7270: reverse_select_related.diff

File reverse_select_related.diff, 11.5 KB (added by Alex Gaynor, 14 years ago)
  • django/db/models/fields/related.py

    diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
    index 8fec836..8e27244 100644
    a b class SingleRelatedObjectDescriptor(object):  
    188188    # SingleRelatedObjectDescriptor instance.
    189189    def __init__(self, related):
    190190        self.related = related
    191         self.cache_name = '_%s_cache' % related.get_accessor_name()
     191        self.cache_name = related.get_accessor_cache()
    192192
    193193    def __get__(self, instance, instance_type=None):
    194194        if instance is None:
    class ReverseSingleRelatedObjectDescriptor(object):  
    307307            # cache. This cache also might not exist if the related object
    308308            # hasn't been accessed yet.
    309309            if related:
    310                 cache_name = '_%s_cache' % self.field.related.get_accessor_name()
     310                cache_name = self.field.related.get_accessor_cache()
    311311                try:
    312312                    delattr(related, cache_name)
    313313                except AttributeError:
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 4e3326a..afbcc24 100644
    a b def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,  
    11471147            rel_obj, index_end = cached_row
    11481148            if obj is not None:
    11491149                setattr(obj, f.get_cache_name(), rel_obj)
     1150            if f.unique:
     1151                setattr(rel_obj, f.related.get_accessor_cache(), obj)
     1152
     1153    if restricted:
     1154        related_fields = [(o.field, o.model) for o in klass._meta.get_all_related_objects()
     1155            if o.field.unique and o.field.related_query_name() in requested]
     1156        for f, model in related_fields:
     1157            next = requested.get(f.related_query_name(), {})
     1158            cached_row = get_cached_row(model, row, index_end, max_depth,
     1159                cur_depth+1, next)
     1160            if cached_row:
     1161                rel_obj, index_end = cached_row
     1162                if obj is not None:
     1163                    setattr(obj, f.related.get_accessor_cache(), rel_obj)
     1164                if rel_obj is not None:
     1165                    setattr(rel_obj, f.get_cache_name(), obj)
     1166
     1167
    11501168    return obj, index_end
    11511169
    11521170def delete_objects(seen_objs, using):
  • django/db/models/related.py

    diff --git a/django/db/models/related.py b/django/db/models/related.py
    index afdf3f7..54258ca 100644
    a b class RelatedObject(object):  
    4545            return self.field.rel.related_name or (self.opts.object_name.lower() + '_set')
    4646        else:
    4747            return self.field.rel.related_name or (self.opts.object_name.lower())
     48
     49    def get_accessor_cache(self):
     50        return "_%s_cache" % self.get_accessor_name()
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 6a95d32..99d4c7e 100644
    a b class SQLCompiler(object):  
    520520
    521521        # Setup for the case when only particular related fields should be
    522522        # included in the related selection.
    523         if requested is None and restricted is not False:
     523        if requested is None:
    524524            if isinstance(self.query.select_related, dict):
    525525                requested = self.query.select_related
    526526                restricted = True
    class SQLCompiler(object):  
    600600            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
    601601                    used, next, restricted, new_nullable, dupe_set, avoid)
    602602
     603        if restricted and requested is not None:
     604            related_fields = [(o.field, o.model) for o in opts.get_all_related_objects()
     605                if o.field.unique and o.field.related_query_name() in requested
     606            ]
     607            for f, model in related_fields:
     608                table = model._meta.db_table
     609                int_opts = opts
     610                alias = root_alias
     611                alias_chain = []
     612                chain = opts.get_base_chain(f.rel.to)
     613                avoid = avoid_set.copy()
     614                if chain is not None:
     615                    for int_model in chain:
     616                        if not int_opts.parents[int_model]:
     617                            int_opts = int_model._meta
     618                            continue
     619                        lhs_col = int_opts.parents[int_model].column
     620                        dedupe = lhs_col in opts.duplicate_targets
     621                        if dedupe:
     622                            avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), ())
     623                            dupe_set.add((opts, lhs_col))
     624                        int_opts = int_model._meta
     625                        alias = self.query.join(
     626                            (alias, int_opts.db_table, lhs_col, int_opts.pk.column),
     627                            exclusions=used, promote=True, reuse=used
     628                        )
     629                        alias_chain.append(alias)
     630                        for dupe_opts, dupe_col in dupe_set:
     631                            self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
     632                    dedupe = f.column in opts.duplicate_targets
     633                    if dupe_set or dedupe:
     634                        avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ()))
     635                        if dedupe:
     636                            dupe_set.add((opts, f.column))
     637                alias = self.query.join(
     638                    (alias, table, f.rel.get_related_field().column, f.column),
     639                    exclusions=used.union(avoid),
     640                    promote=True
     641                )
     642                used.add(alias)
     643                columns, aliases = self.get_default_columns(start_alias=alias,
     644                    opts=model._meta, as_pairs=True)
     645                self.query.related_select_cols.extend(columns)
     646                self.query.related_select_fields.extend(model._meta.fields)
     647
     648                next = requested.get(f.related_query_name(), {})
     649                new_nullable = f.null or None
     650
     651                self.fill_related_selections(model._meta, table, cur_depth+1,
     652                    used, next, restricted, new_nullable)
     653
    603654    def deferred_to_columns(self):
    604655        """
    605656        Converts the self.deferred_loading data structure to mapping of table
  • new file tests/regressiontests/select_related_onetoone/models.py

    diff --git a/tests/regressiontests/select_related_onetoone/__init__.py b/tests/regressiontests/select_related_onetoone/__init__.py
    new file mode 100644
    index 0000000..e69de29
    diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
    new file mode 100644
    index 0000000..05efadf
    - +  
     1from django.db import models
     2
     3
     4class User(models.Model):
     5    username = models.CharField(max_length=100)
     6    email = models.EmailField()
     7
     8    def __unicode__(self):
     9        return self.username
     10
     11
     12class UserProfile(models.Model):
     13    user = models.OneToOneField(User)
     14    city = models.CharField(max_length=100)
     15    state = models.CharField(max_length=2)
     16
     17    def __unicode__(self):
     18        return "%s, %s" % (self.city, self.state)
     19
     20
     21class UserStatResult(models.Model):
     22    results = models.CharField(max_length=50)
     23
     24    def __unicode__(self):
     25        return 'UserStatResults, results = %s' % (self.results,)
     26   
     27
     28class UserStat(models.Model):
     29    user = models.OneToOneField(User, primary_key=True)
     30    posts = models.IntegerField()
     31    results = models.ForeignKey(UserStatResult)
     32
     33    def __unicode__(self):
     34        return 'UserStat, posts = %s' % (self.posts,)
     35
     36class StatDetails(models.Model):
     37    base_stats = models.OneToOneField(UserStat)
     38    comments = models.IntegerField()
     39
     40    def __unicode__(self):
     41        return 'StatDetails, comments = %s' % (self.comments,)
     42
     43class AdvancedUserStat(UserStat):
     44    pass
  • new file tests/regressiontests/select_related_onetoone/tests.py

    diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
    new file mode 100644
    index 0000000..08e798b
    - +  
     1from django import db
     2from django.conf import settings
     3from django.test import TestCase
     4
     5from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat
     6
     7class ReverseSelectRelatedTestCase(TestCase):
     8    def setUp(self):
     9        self.old_debug = settings.DEBUG
     10        settings.DEBUG = True
     11
     12        user = User.objects.create(username="test")
     13        userprofile = UserProfile.objects.create(user=user, state="KS",
     14                                                 city="Lawrence")
     15        results = UserStatResult.objects.create(results='first results')
     16        userstat = UserStat.objects.create(user=user, posts=150,
     17                                           results=results)
     18        details = StatDetails.objects.create(base_stats=userstat, comments=259)
     19
     20        user2 = User.objects.create(username="bob")
     21        results2 = UserStatResult.objects.create(results='moar results')
     22        advstat = AdvancedUserStat.objects.create(user=user2, posts=200,
     23                                                  results=results2)
     24        StatDetails.objects.create(base_stats=advstat, comments=250)
     25
     26        db.reset_queries()
     27
     28    def assertQueries(self, queries):
     29        self.assertEqual(len(db.connection.queries), queries)
     30
     31    def tearDown(self):
     32        settings.DEBUG = self.old_debug
     33
     34    def test_basic(self):
     35        u = User.objects.select_related("userprofile").get(username="test")
     36        self.assertEqual(u.userprofile.state, "KS")
     37        self.assertQueries(1)
     38
     39    def test_follow_next_level(self):
     40        u = User.objects.select_related("userstat__results").get(username="test")
     41        self.assertEqual(u.userstat.posts, 150)
     42        self.assertEqual(u.userstat.results.results, 'first results')
     43        self.assertQueries(1)
     44
     45    def test_follow_two(self):
     46        u = User.objects.select_related("userprofile", "userstat").get(username="test")
     47        self.assertEqual(u.userprofile.state, "KS")
     48        self.assertEqual(u.userstat.posts, 150)
     49        self.assertQueries(1)
     50
     51    def test_follow_two_next_level(self):
     52        u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
     53        self.assertEqual(u.userstat.results.results, 'first results')
     54        self.assertEqual(u.userstat.statdetails.comments, 259)
     55        self.assertQueries(1)
     56
     57    def test_forward_and_back(self):
     58        stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
     59        self.assertEqual(stat.user.userprofile.state, 'KS')
     60        self.assertEqual(stat.user.userstat.posts, 150)
     61        self.assertQueries(1)
     62
     63    def test_back_and_forward(self):
     64        u = User.objects.select_related("userstat").get(username="test")
     65        self.assertEqual(u.userstat.user.username, 'test')
     66        self.assertQueries(1)
     67
     68    def test_not_followed_by_default(self):
     69        u = User.objects.select_related().get(username="test")
     70        self.assertEqual(u.userstat.posts, 150)
     71        self.assertQueries(2)
     72
     73    def test_follow_from_child_class(self):
     74        stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200)
     75        self.assertEqual(stat.statdetails.comments, 250)
     76        self.assertQueries(1)
     77
     78    def test_follow_inheritance(self):
     79        stat = UserStat.objects.select_related('advanceduserstat').get(posts=200)
     80        self.assertEqual(stat.advanceduserstat.posts, 200)
     81        self.assertQueries(1)
Back to Top