From 9282d1d27823e780ccad7ddb006182e27e66262e Mon Sep 17 00:00:00 2001
From: David Bennett <david@dbinit.com>
Date: Mon, 30 Jan 2012 12:42:57 -0600
Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
---
 django/db/models/query.py                          |   16 +++++++++---
 django/db/models/sql/compiler.py                   |   12 ++++++---
 .../select_related_onetoone/models.py              |   26 ++++++++++++++++++++
 .../select_related_onetoone/tests.py               |   23 ++++++++++++++++-
 4 files changed, 68 insertions(+), 9 deletions(-)
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 324554e..ede2581 100644
      
        
          
        
        
          
            | a | b | class EmptyQuerySet(QuerySet): | 
        
        
          
            | 1126 | 1126 |  | 
          
            | 1127 | 1127 |  | 
          
            | 1128 | 1128 | def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | 
        
        
          
            | 1129 |  | requested=None, offset=0, only_load=None, local_only=False): | 
          
            |  | 1129 | requested=None, offset=0, only_load=None, local_only=False, | 
          
            |  | 1130 | last_klass=None): | 
        
        
          
            | 1130 | 1131 | """ | 
          
            | 1131 | 1132 | Helper function that recursively returns an object with the specified | 
          
            | 1132 | 1133 | related attributes already populated. | 
        
        
          
            | … | … | def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | 
        
        
          
            | 1156 | 1157 | the full field list for `klass` can be assumed. | 
          
            | 1157 | 1158 | * local_only - Only populate local fields. This is used when building | 
          
            | 1158 | 1159 | following reverse select-related relations | 
        
        
          
            |  | 1160 | * last_klass - the last class seen when following reverse | 
          
            |  | 1161 | select-related relations | 
        
        
          
            | 1159 | 1162 | """ | 
          
            | 1160 | 1163 | if max_depth and requested is None and cur_depth > max_depth: | 
          
            | 1161 | 1164 | # We've recursed deeply enough; stop now. | 
        
        
          
            | … | … | def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | 
        
        
          
            | 1202 | 1205 | else: | 
          
            | 1203 | 1206 | # Load all fields on klass | 
          
            | 1204 | 1207 | if local_only: | 
        
        
          
            | 1205 |  | field_names = [f.attname for f in klass._meta.local_fields] | 
          
            |  | 1208 | parents = [p for p in klass._meta.get_parent_list() | 
          
            |  | 1209 | if p is not last_klass] | 
          
            |  | 1210 | field_names = [f.attname for f in klass._meta.fields | 
          
            |  | 1211 | if f in klass._meta.local_fields | 
          
            |  | 1212 | or f.model in parents] | 
        
        
          
            | 1206 | 1213 | else: | 
          
            | 1207 | 1214 | field_names = [f.attname for f in klass._meta.fields] | 
          
            | 1208 | 1215 | field_count = len(field_names) | 
        
        
          
            | … | … | def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | 
        
        
          
            | 1261 | 1268 | next = requested[f.related_query_name()] | 
          
            | 1262 | 1269 | # Recursively retrieve the data for the related object | 
          
            | 1263 | 1270 | cached_row = get_cached_row(model, row, index_end, using, | 
        
        
          
            | 1264 |  | max_depth, cur_depth+1, next, only_load=only_load, local_only=True) | 
          
            |  | 1271 | max_depth, cur_depth+1, next, only_load=only_load, local_only=True, | 
          
            |  | 1272 | last_klass=klass) | 
        
        
          
            | 1265 | 1273 | # If the recursive descent found an object, populate the | 
          
            | 1266 | 1274 | # descriptor caches relevant to the object | 
          
            | 1267 | 1275 | if cached_row: | 
        
        
          
            | … | … | def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | 
        
        
          
            | 1277 | 1285 | # Now populate all the non-local field values | 
          
            | 1278 | 1286 | # on the related object | 
          
            | 1279 | 1287 | for rel_field,rel_model in rel_obj._meta.get_fields_with_model(): | 
        
        
          
            | 1280 |  | if rel_model is not None : | 
          
            |  | 1288 | if rel_model is not None and isinstance(obj, rel_model): | 
        
        
          
            | 1281 | 1289 | setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) | 
          
            | 1282 | 1290 | # populate the field cache for any related object | 
          
            | 1283 | 1291 | # that has already been retrieved | 
        
      
    
    
      
      diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index d425c8b..7c0ff75 100644
      
        
          
        
        
          
            | a | b | class SQLCompiler(object): | 
        
        
          
            | 216 | 216 | return result | 
          
            | 217 | 217 |  | 
          
            | 218 | 218 | def get_default_columns(self, with_aliases=False, col_aliases=None, | 
        
        
          
            | 219 |  | start_alias=None, opts=None, as_pairs=False, local_only=False): | 
          
            |  | 219 | start_alias=None, opts=None, as_pairs=False, local_only=False, | 
          
            |  | 220 | last_opts=None): | 
        
        
          
            | 220 | 221 | """ | 
          
            | 221 | 222 | Computes the default columns for selecting every field in the base | 
          
            | 222 | 223 | model. Will sometimes be called to pull in related models (e.g. via | 
        
        
          
            | … | … | class SQLCompiler(object): | 
        
        
          
            | 240 | 241 |  | 
          
            | 241 | 242 | if start_alias: | 
          
            | 242 | 243 | seen = {None: start_alias} | 
        
        
          
            |  | 244 | parents = [p for p in opts.get_parent_list() if p._meta is not last_opts] | 
        
        
          
            | 243 | 245 | for field, model in opts.get_fields_with_model(): | 
        
        
          
            | 244 |  | if local_only and model is not None : | 
          
            |  | 246 | if local_only and model is not None and model not in parents: | 
        
        
          
            | 245 | 247 | continue | 
          
            | 246 | 248 | if start_alias: | 
          
            | 247 | 249 | try: | 
        
        
          
            | … | … | class SQLCompiler(object): | 
        
        
          
            | 252 | 254 | else: | 
          
            | 253 | 255 | link_field = opts.get_ancestor_link(model) | 
          
            | 254 | 256 | alias = self.query.join((start_alias, model._meta.db_table, | 
        
        
          
            | 255 |  | link_field.column, model._meta.pk.column)) | 
          
            |  | 257 | link_field.column, model._meta.pk.column), | 
          
            |  | 258 | promote=(model in parents)) | 
        
        
          
            | 256 | 259 | seen[model] = alias | 
          
            | 257 | 260 | else: | 
          
            | 258 | 261 | # If we're starting from the base model of the queryset, the | 
        
        
          
            | … | … | class SQLCompiler(object): | 
        
        
          
            | 650 | 653 | ) | 
          
            | 651 | 654 | used.add(alias) | 
          
            | 652 | 655 | columns, aliases = self.get_default_columns(start_alias=alias, | 
        
        
          
            | 653 |  | opts=model._meta, as_pairs=True, local_only=True) | 
          
            |  | 656 | opts=model._meta, as_pairs=True, local_only=True, | 
          
            |  | 657 | last_opts=opts) | 
        
        
          
            | 654 | 658 | self.query.related_select_cols.extend(columns) | 
          
            | 655 | 659 | self.query.related_select_fields.extend(model._meta.fields) | 
          
            | 656 | 660 |  | 
        
      
    
    
      
      diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
index 3d6da9b..4bfad1d 100644
      
        
          
        
        
          
            | a | b | class StatDetails(models.Model): | 
        
        
          
            | 45 | 45 | class AdvancedUserStat(UserStat): | 
          
            | 46 | 46 | karma = models.IntegerField() | 
          
            | 47 | 47 |  | 
        
        
          
            |  | 48 |  | 
        
        
          
            | 48 | 49 | class Image(models.Model): | 
          
            | 49 | 50 | name = models.CharField(max_length=100) | 
          
            | 50 | 51 |  | 
        
        
          
            | … | … | class Image(models.Model): | 
        
        
          
            | 52 | 53 | class Product(models.Model): | 
          
            | 53 | 54 | name = models.CharField(max_length=100) | 
          
            | 54 | 55 | image = models.OneToOneField(Image, null=True) | 
        
        
          
            |  | 56 |  | 
          
            |  | 57 |  | 
          
            |  | 58 | class Parent1(models.Model): | 
          
            |  | 59 | name1 = models.CharField(max_length=50) | 
          
            |  | 60 | def __unicode__(self): | 
          
            |  | 61 | return self.name1 | 
          
            |  | 62 |  | 
          
            |  | 63 |  | 
          
            |  | 64 | class Parent2(models.Model): | 
          
            |  | 65 | name2 = models.CharField(max_length=50) | 
          
            |  | 66 | def __unicode__(self): | 
          
            |  | 67 | return self.name2 | 
          
            |  | 68 |  | 
          
            |  | 69 |  | 
          
            |  | 70 | class Child1(Parent1, Parent2): | 
          
            |  | 71 | other = models.CharField(max_length=50) | 
          
            |  | 72 | def __unicode__(self): | 
          
            |  | 73 | return self.name1 | 
          
            |  | 74 |  | 
          
            |  | 75 |  | 
          
            |  | 76 | class Child2(Parent1): | 
          
            |  | 77 | parent2 = models.OneToOneField(Parent2) | 
          
            |  | 78 | other = models.CharField(max_length=50) | 
          
            |  | 79 | def __unicode__(self): | 
          
            |  | 80 | return self.name1 | 
        
      
    
    
      
      diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
index ab35fec..407cdc9 100644
      
        
          
        
        
          
            | a | b | from django.conf import settings | 
        
        
          
            | 3 | 3 | from django.test import TestCase | 
          
            | 4 | 4 |  | 
          
            | 5 | 5 | from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, | 
        
        
          
            | 6 |  | AdvancedUserStat, Image, Product ) | 
          
            |  | 6 | AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2) | 
        
        
          
            | 7 | 7 |  | 
          
            | 8 | 8 | class ReverseSelectRelatedTestCase(TestCase): | 
          
            | 9 | 9 | def setUp(self): | 
        
        
          
            | … | … | class ReverseSelectRelatedTestCase(TestCase): | 
        
        
          
            | 20 | 20 | advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, | 
          
            | 21 | 21 | results=results2) | 
          
            | 22 | 22 | StatDetails.objects.create(base_stats=advstat, comments=250) | 
        
        
          
            |  | 23 | p1 = Parent1(name1="Only Parent1") | 
          
            |  | 24 | p1.save() | 
          
            |  | 25 | c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2") | 
          
            |  | 26 | c1.save() | 
          
            |  | 27 | p2 = Parent2(name2="Child2 Parent2") | 
          
            |  | 28 | p2.save() | 
          
            |  | 29 | c2 = Child2(name1="Child2 Parent1", parent2=p2) | 
          
            |  | 30 | c2.save() | 
        
        
          
            | 23 | 31 |  | 
          
            | 24 | 32 | def test_basic(self): | 
          
            | 25 | 33 | def test(): | 
        
        
          
            | … | … | class ReverseSelectRelatedTestCase(TestCase): | 
        
        
          
            | 88 | 96 | p2 = Product.objects.create(name="Talking Django Plushie") | 
          
            | 89 | 97 |  | 
          
            | 90 | 98 | self.assertEqual(len(Product.objects.select_related("image")), 2) | 
        
        
          
            |  | 99 |  | 
          
            |  | 100 | def test_parent_only(self): | 
          
            |  | 101 | Parent1.objects.select_related('child1').get(name1="Only Parent1") | 
          
            |  | 102 |  | 
          
            |  | 103 | def test_multiple_subclass(self): | 
          
            |  | 104 | with self.assertNumQueries(1): | 
          
            |  | 105 | p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") | 
          
            |  | 106 | self.assertEqual(p.child1.name2, u"Child1 Parent2") | 
          
            |  | 107 |  | 
          
            |  | 108 | def test_onetoone_with_subclass(self): | 
          
            |  | 109 | with self.assertNumQueries(1): | 
          
            |  | 110 | p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") | 
          
            |  | 111 | self.assertEqual(p.child2.name1, u"Child2 Parent1") |