Django

Code

Ticket #7270: reverse_select_related.diff

File reverse_select_related.diff, 11.5 kB (added by Alex, 2 months ago)
  • a/django/db/models/fields/related.py

    old new  
    188188    # SingleRelatedObjectDescriptor instance. 
    189189    def __init__(self, related): 
    190190        self.related = related 
    191         self.cache_name = '_%s_cache' % related.get_accessor_name() 
     191        self.cache_name = related.get_accessor_cache() 
    192192 
    193193    def __get__(self, instance, instance_type=None): 
    194194        if instance is None: 
     
    307307            # cache. This cache also might not exist if the related object 
    308308            # hasn't been accessed yet. 
    309309            if related: 
    310                 cache_name = '_%s_cache' % self.field.related.get_accessor_name() 
     310                cache_name = self.field.related.get_accessor_cache() 
    311311                try: 
    312312                    delattr(related, cache_name) 
    313313                except AttributeError: 
  • a/django/db/models/query.py

    old new  
    11471147            rel_obj, index_end = cached_row 
    11481148            if obj is not None: 
    11491149                setattr(obj, f.get_cache_name(), rel_obj) 
     1150            if f.unique: 
     1151                setattr(rel_obj, f.related.get_accessor_cache(), obj) 
     1152 
     1153    if restricted: 
     1154        related_fields = [(o.field, o.model) for o in klass._meta.get_all_related_objects() 
     1155            if o.field.unique and o.field.related_query_name() in requested] 
     1156        for f, model in related_fields: 
     1157            next = requested.get(f.related_query_name(), {}) 
     1158            cached_row = get_cached_row(model, row, index_end, max_depth, 
     1159                cur_depth+1, next) 
     1160            if cached_row: 
     1161                rel_obj, index_end = cached_row 
     1162                if obj is not None: 
     1163                    setattr(obj, f.related.get_accessor_cache(), rel_obj) 
     1164                if rel_obj is not None: 
     1165                    setattr(rel_obj, f.get_cache_name(), obj) 
     1166 
     1167 
    11501168    return obj, index_end 
    11511169 
    11521170def delete_objects(seen_objs, using): 
  • a/django/db/models/related.py

    old new  
    4545            return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') 
    4646        else: 
    4747            return self.field.rel.related_name or (self.opts.object_name.lower()) 
     48 
     49    def get_accessor_cache(self): 
     50        return "_%s_cache" % self.get_accessor_name() 
  • a/django/db/models/sql/compiler.py

    old new  
    520520 
    521521        # Setup for the case when only particular related fields should be 
    522522        # included in the related selection. 
    523         if requested is None and restricted is not False
     523        if requested is None
    524524            if isinstance(self.query.select_related, dict): 
    525525                requested = self.query.select_related 
    526526                restricted = True 
     
    600600            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, 
    601601                    used, next, restricted, new_nullable, dupe_set, avoid) 
    602602 
     603        if restricted and requested is not None: 
     604            related_fields = [(o.field, o.model) for o in opts.get_all_related_objects() 
     605                if o.field.unique and o.field.related_query_name() in requested 
     606            ] 
     607            for f, model in related_fields: 
     608                table = model._meta.db_table 
     609                int_opts = opts 
     610                alias = root_alias 
     611                alias_chain = [] 
     612                chain = opts.get_base_chain(f.rel.to) 
     613                avoid = avoid_set.copy() 
     614                if chain is not None: 
     615                    for int_model in chain: 
     616                        if not int_opts.parents[int_model]: 
     617                            int_opts = int_model._meta 
     618                            continue 
     619                        lhs_col = int_opts.parents[int_model].column 
     620                        dedupe = lhs_col in opts.duplicate_targets 
     621                        if dedupe: 
     622                            avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), ()) 
     623                            dupe_set.add((opts, lhs_col)) 
     624                        int_opts = int_model._meta 
     625                        alias = self.query.join( 
     626                            (alias, int_opts.db_table, lhs_col, int_opts.pk.column), 
     627                            exclusions=used, promote=True, reuse=used 
     628                        ) 
     629                        alias_chain.append(alias) 
     630                        for dupe_opts, dupe_col in dupe_set: 
     631                            self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) 
     632                    dedupe = f.column in opts.duplicate_targets 
     633                    if dupe_set or dedupe: 
     634                        avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ())) 
     635                        if dedupe: 
     636                            dupe_set.add((opts, f.column)) 
     637                alias = self.query.join( 
     638                    (alias, table, f.rel.get_related_field().column, f.column), 
     639                    exclusions=used.union(avoid), 
     640                    promote=True 
     641                ) 
     642                used.add(alias) 
     643                columns, aliases = self.get_default_columns(start_alias=alias, 
     644                    opts=model._meta, as_pairs=True) 
     645                self.query.related_select_cols.extend(columns) 
     646                self.query.related_select_fields.extend(model._meta.fields) 
     647 
     648                next = requested.get(f.related_query_name(), {}) 
     649                new_nullable = f.null or None 
     650 
     651                self.fill_related_selections(model._meta, table, cur_depth+1, 
     652                    used, next, restricted, new_nullable) 
     653 
    603654    def deferred_to_columns(self): 
    604655        """ 
    605656        Converts the self.deferred_loading data structure to mapping of table 
  • /dev/null

    old new  
     1from django.db import models 
     2 
     3 
     4class User(models.Model): 
     5    username = models.CharField(max_length=100) 
     6    email = models.EmailField() 
     7 
     8    def __unicode__(self): 
     9        return self.username 
     10 
     11 
     12class UserProfile(models.Model): 
     13    user = models.OneToOneField(User) 
     14    city = models.CharField(max_length=100) 
     15    state = models.CharField(max_length=2) 
     16 
     17    def __unicode__(self): 
     18        return "%s, %s" % (self.city, self.state) 
     19 
     20 
     21class UserStatResult(models.Model): 
     22    results = models.CharField(max_length=50) 
     23 
     24    def __unicode__(self): 
     25        return 'UserStatResults, results = %s' % (self.results,) 
     26     
     27 
     28class UserStat(models.Model): 
     29    user = models.OneToOneField(User, primary_key=True) 
     30    posts = models.IntegerField() 
     31    results = models.ForeignKey(UserStatResult) 
     32 
     33    def __unicode__(self): 
     34        return 'UserStat, posts = %s' % (self.posts,) 
     35 
     36class StatDetails(models.Model): 
     37    base_stats = models.OneToOneField(UserStat) 
     38    comments = models.IntegerField() 
     39 
     40    def __unicode__(self): 
     41        return 'StatDetails, comments = %s' % (self.comments,) 
     42 
     43class AdvancedUserStat(UserStat): 
     44    pass 
  • /dev/null

    old new  
     1from django import db 
     2from django.conf import settings 
     3from django.test import TestCase 
     4 
     5from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat 
     6 
     7class ReverseSelectRelatedTestCase(TestCase): 
     8    def setUp(self): 
     9        self.old_debug = settings.DEBUG 
     10        settings.DEBUG = True 
     11 
     12        user = User.objects.create(username="test") 
     13        userprofile = UserProfile.objects.create(user=user, state="KS", 
     14                                                 city="Lawrence") 
     15        results = UserStatResult.objects.create(results='first results') 
     16        userstat = UserStat.objects.create(user=user, posts=150, 
     17                                           results=results) 
     18        details = StatDetails.objects.create(base_stats=userstat, comments=259) 
     19 
     20        user2 = User.objects.create(username="bob") 
     21        results2 = UserStatResult.objects.create(results='moar results') 
     22        advstat = AdvancedUserStat.objects.create(user=user2, posts=200, 
     23                                                  results=results2) 
     24        StatDetails.objects.create(base_stats=advstat, comments=250) 
     25 
     26        db.reset_queries() 
     27 
     28    def assertQueries(self, queries): 
     29        self.assertEqual(len(db.connection.queries), queries) 
     30 
     31    def tearDown(self): 
     32        settings.DEBUG = self.old_debug 
     33 
     34    def test_basic(self): 
     35        u = User.objects.select_related("userprofile").get(username="test") 
     36        self.assertEqual(u.userprofile.state, "KS") 
     37        self.assertQueries(1) 
     38 
     39    def test_follow_next_level(self): 
     40        u = User.objects.select_related("userstat__results").get(username="test") 
     41        self.assertEqual(u.userstat.posts, 150) 
     42        self.assertEqual(u.userstat.results.results, 'first results') 
     43        self.assertQueries(1) 
     44 
     45    def test_follow_two(self): 
     46        u = User.objects.select_related("userprofile", "userstat").get(username="test") 
     47        self.assertEqual(u.userprofile.state, "KS") 
     48        self.assertEqual(u.userstat.posts, 150) 
     49        self.assertQueries(1) 
     50 
     51    def test_follow_two_next_level(self): 
     52        u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") 
     53        self.assertEqual(u.userstat.results.results, 'first results') 
     54        self.assertEqual(u.userstat.statdetails.comments, 259) 
     55        self.assertQueries(1) 
     56 
     57    def test_forward_and_back(self): 
     58        stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") 
     59        self.assertEqual(stat.user.userprofile.state, 'KS') 
     60        self.assertEqual(stat.user.userstat.posts, 150) 
     61        self.assertQueries(1) 
     62 
     63    def test_back_and_forward(self): 
     64        u = User.objects.select_related("userstat").get(username="test") 
     65        self.assertEqual(u.userstat.user.username, 'test') 
     66        self.assertQueries(1) 
     67 
     68    def test_not_followed_by_default(self): 
     69        u = User.objects.select_related().get(username="test") 
     70        self.assertEqual(u.userstat.posts, 150) 
     71        self.assertQueries(2) 
     72 
     73    def test_follow_from_child_class(self): 
     74        stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200) 
     75        self.assertEqual(stat.statdetails.comments, 250) 
     76        self.assertQueries(1) 
     77 
     78    def test_follow_inheritance(self): 
     79        stat = UserStat.objects.select_related('advanceduserstat').get(posts=200) 
     80        self.assertEqual(stat.advanceduserstat.posts, 200) 
     81        self.assertQueries(1)