From 2b446904bfeb8bb79762426eb5ee0b97f57e6844 Mon Sep 17 00:00:00 2001
From: David Bennett <david@dbinit.com>
Date: Mon, 30 Jan 2012 12:54:44 -0600
Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
---
django/db/models/query.py | 19 ++++++++++----
django/db/models/sql/compiler.py | 12 ++++++---
.../select_related_onetoone/models.py | 26 ++++++++++++++++++++
.../select_related_onetoone/tests.py | 23 ++++++++++++++++-
4 files changed, 70 insertions(+), 10 deletions(-)
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 41c24c7..c76d6d0 100644
|
a
|
b
|
class EmptyQuerySet(QuerySet):
|
| 1238 | 1238 | value_annotation = False |
| 1239 | 1239 | |
| 1240 | 1240 | def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, |
| 1241 | | only_load=None, local_only=False): |
| | 1241 | only_load=None, local_only=False, last_klass=None): |
| 1242 | 1242 | """ |
| 1243 | 1243 | Helper function that recursively returns an information for a klass, to be |
| 1244 | 1244 | used in get_cached_row. It exists just to compute this information only |
| … |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
| 1260 | 1260 | the full field list for `klass` can be assumed. |
| 1261 | 1261 | * local_only - Only populate local fields. This is used when |
| 1262 | 1262 | following reverse select-related relations |
| | 1263 | * last_klass - the last class seen when following reverse |
| | 1264 | select-related relations |
| 1263 | 1265 | """ |
| 1264 | 1266 | if max_depth and requested is None and cur_depth > max_depth: |
| 1265 | 1267 | # We've recursed deeply enough; stop now. |
| … |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
| 1305 | 1307 | # But kwargs version of Model.__init__ is slower, so we should avoid using |
| 1306 | 1308 | # it when it is not really neccesary. |
| 1307 | 1309 | if local_only and len(klass._meta.local_fields) != len(klass._meta.fields): |
| 1308 | | field_count = len(klass._meta.local_fields) |
| 1309 | | field_names = [f.attname for f in klass._meta.local_fields] |
| | 1310 | parents = [p for p in klass._meta.get_parent_list() |
| | 1311 | if p is not last_klass] |
| | 1312 | field_names = [f.attname for f in klass._meta.fields |
| | 1313 | if f in klass._meta.local_fields |
| | 1314 | or f.model in parents] |
| | 1315 | field_count = len(field_names) |
| | 1316 | if field_count == len(klass._meta.fields): |
| | 1317 | field_names = () |
| 1310 | 1318 | else: |
| 1311 | 1319 | field_count = len(klass._meta.fields) |
| 1312 | 1320 | field_names = () |
| … |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
| 1330 | 1338 | if o.field.unique and select_related_descend(o.field, restricted, requested, reverse=True): |
| 1331 | 1339 | next = requested[o.field.related_query_name()] |
| 1332 | 1340 | klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, |
| 1333 | | requested=next, only_load=only_load, local_only=True) |
| | 1341 | requested=next, only_load=only_load, local_only=True, |
| | 1342 | last_klass=klass) |
| 1334 | 1343 | reverse_related_fields.append((o.field, klass_info)) |
| 1335 | 1344 | |
| 1336 | 1345 | return klass, field_names, field_count, related_fields, reverse_related_fields |
| … |
… |
def get_cached_row(row, index_start, using, klass_info, offset=0):
|
| 1416 | 1425 | # Now populate all the non-local field values |
| 1417 | 1426 | # on the related object |
| 1418 | 1427 | for rel_field, rel_model in rel_obj._meta.get_fields_with_model(): |
| 1419 | | if rel_model is not None: |
| | 1428 | if rel_model is not None and isinstance(obj, rel_model): |
| 1420 | 1429 | setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) |
| 1421 | 1430 | # populate the field cache for any related object |
| 1422 | 1431 | # that has already been retrieved |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 72948f9..c773a3b 100644
|
a
|
b
|
class SQLCompiler(object):
|
| 246 | 246 | return result |
| 247 | 247 | |
| 248 | 248 | def get_default_columns(self, with_aliases=False, col_aliases=None, |
| 249 | | start_alias=None, opts=None, as_pairs=False, local_only=False): |
| | 249 | start_alias=None, opts=None, as_pairs=False, local_only=False, |
| | 250 | last_opts=None): |
| 250 | 251 | """ |
| 251 | 252 | Computes the default columns for selecting every field in the base |
| 252 | 253 | model. Will sometimes be called to pull in related models (e.g. via |
| … |
… |
class SQLCompiler(object):
|
| 270 | 271 | |
| 271 | 272 | if start_alias: |
| 272 | 273 | seen = {None: start_alias} |
| | 274 | parents = [p for p in opts.get_parent_list() if p._meta is not last_opts] |
| 273 | 275 | for field, model in opts.get_fields_with_model(): |
| 274 | | if local_only and model is not None: |
| | 276 | if local_only and model is not None and model not in parents: |
| 275 | 277 | continue |
| 276 | 278 | if start_alias: |
| 277 | 279 | try: |
| … |
… |
class SQLCompiler(object):
|
| 282 | 284 | else: |
| 283 | 285 | link_field = opts.get_ancestor_link(model) |
| 284 | 286 | alias = self.query.join((start_alias, model._meta.db_table, |
| 285 | | link_field.column, model._meta.pk.column)) |
| | 287 | link_field.column, model._meta.pk.column), |
| | 288 | promote=(model in parents)) |
| 286 | 289 | seen[model] = alias |
| 287 | 290 | else: |
| 288 | 291 | # If we're starting from the base model of the queryset, the |
| … |
… |
class SQLCompiler(object):
|
| 728 | 731 | ) |
| 729 | 732 | used.add(alias) |
| 730 | 733 | columns, aliases = self.get_default_columns(start_alias=alias, |
| 731 | | opts=model._meta, as_pairs=True, local_only=True) |
| | 734 | opts=model._meta, as_pairs=True, local_only=True, |
| | 735 | last_opts=opts) |
| 732 | 736 | self.query.related_select_cols.extend(columns) |
| 733 | 737 | self.query.related_select_fields.extend(model._meta.fields) |
| 734 | 738 | |
diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
index 3d6da9b..4bfad1d 100644
|
a
|
b
|
class StatDetails(models.Model):
|
| 45 | 45 | class AdvancedUserStat(UserStat): |
| 46 | 46 | karma = models.IntegerField() |
| 47 | 47 | |
| | 48 | |
| 48 | 49 | class Image(models.Model): |
| 49 | 50 | name = models.CharField(max_length=100) |
| 50 | 51 | |
| … |
… |
class Image(models.Model):
|
| 52 | 53 | class Product(models.Model): |
| 53 | 54 | name = models.CharField(max_length=100) |
| 54 | 55 | image = models.OneToOneField(Image, null=True) |
| | 56 | |
| | 57 | |
| | 58 | class Parent1(models.Model): |
| | 59 | name1 = models.CharField(max_length=50) |
| | 60 | def __unicode__(self): |
| | 61 | return self.name1 |
| | 62 | |
| | 63 | |
| | 64 | class Parent2(models.Model): |
| | 65 | name2 = models.CharField(max_length=50) |
| | 66 | def __unicode__(self): |
| | 67 | return self.name2 |
| | 68 | |
| | 69 | |
| | 70 | class Child1(Parent1, Parent2): |
| | 71 | other = models.CharField(max_length=50) |
| | 72 | def __unicode__(self): |
| | 73 | return self.name1 |
| | 74 | |
| | 75 | |
| | 76 | class Child2(Parent1): |
| | 77 | parent2 = models.OneToOneField(Parent2) |
| | 78 | other = models.CharField(max_length=50) |
| | 79 | def __unicode__(self): |
| | 80 | return self.name1 |
diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
index 643a0ff..a57142c 100644
|
a
|
b
|
from __future__ import with_statement, absolute_import
|
| 3 | 3 | from django.test import TestCase |
| 4 | 4 | |
| 5 | 5 | from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, |
| 6 | | AdvancedUserStat, Image, Product) |
| | 6 | AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2) |
| 7 | 7 | |
| 8 | 8 | |
| 9 | 9 | class ReverseSelectRelatedTestCase(TestCase): |
| … |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
| 21 | 21 | advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, |
| 22 | 22 | results=results2) |
| 23 | 23 | StatDetails.objects.create(base_stats=advstat, comments=250) |
| | 24 | p1 = Parent1(name1="Only Parent1") |
| | 25 | p1.save() |
| | 26 | c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2") |
| | 27 | c1.save() |
| | 28 | p2 = Parent2(name2="Child2 Parent2") |
| | 29 | p2.save() |
| | 30 | c2 = Child2(name1="Child2 Parent1", parent2=p2) |
| | 31 | c2.save() |
| 24 | 32 | |
| 25 | 33 | def test_basic(self): |
| 26 | 34 | with self.assertNumQueries(1): |
| … |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
| 80 | 88 | p2 = Product.objects.create(name="Talking Django Plushie") |
| 81 | 89 | |
| 82 | 90 | self.assertEqual(len(Product.objects.select_related("image")), 2) |
| | 91 | |
| | 92 | def test_parent_only(self): |
| | 93 | Parent1.objects.select_related('child1').get(name1="Only Parent1") |
| | 94 | |
| | 95 | def test_multiple_subclass(self): |
| | 96 | with self.assertNumQueries(1): |
| | 97 | p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") |
| | 98 | self.assertEqual(p.child1.name2, u'Child1 Parent2') |
| | 99 | |
| | 100 | def test_onetoone_with_subclass(self): |
| | 101 | with self.assertNumQueries(1): |
| | 102 | p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") |
| | 103 | self.assertEqual(p.child2.name1, u'Child2 Parent1') |