diff --git a/django/db/models/query.py b/django/db/models/query.py
index 0210a79..e144956 100644
|
a
|
b
|
class EmptyQuerySet(QuerySet):
|
| 1276 | 1276 | value_annotation = False |
| 1277 | 1277 | |
| 1278 | 1278 | def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, |
| 1279 | | only_load=None, local_only=False): |
| | 1279 | only_load=None, local_only=False, last_klass=None): |
| 1280 | 1280 | """ |
| 1281 | 1281 | Helper function that recursively returns an information for a klass, to be |
| 1282 | 1282 | used in get_cached_row. It exists just to compute this information only |
| … |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
| 1298 | 1298 | the full field list for `klass` can be assumed. |
| 1299 | 1299 | * local_only - Only populate local fields. This is used when |
| 1300 | 1300 | following reverse select-related relations |
| | 1301 | * last_klass - the last class seen when following reverse |
| | 1302 | select-related relations |
| 1301 | 1303 | """ |
| 1302 | 1304 | if max_depth and requested is None and cur_depth > max_depth: |
| 1303 | 1305 | # We've recursed deeply enough; stop now. |
| … |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
| 1343 | 1345 | # But kwargs version of Model.__init__ is slower, so we should avoid using |
| 1344 | 1346 | # it when it is not really neccesary. |
| 1345 | 1347 | if local_only and len(klass._meta.local_fields) != len(klass._meta.fields): |
| 1346 | | field_count = len(klass._meta.local_fields) |
| 1347 | | field_names = [f.attname for f in klass._meta.local_fields] |
| | 1348 | parents = [p for p in klass._meta.get_parent_list() |
| | 1349 | if p is not last_klass] |
| | 1350 | field_names = [f.attname for f in klass._meta.fields |
| | 1351 | if f in klass._meta.local_fields |
| | 1352 | or f.model in parents] |
| | 1353 | field_count = len(field_names) |
| | 1354 | if field_count == len(klass._meta.fields): |
| | 1355 | field_names = () |
| 1348 | 1356 | else: |
| 1349 | 1357 | field_count = len(klass._meta.fields) |
| 1350 | 1358 | field_names = () |
| … |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
| 1369 | 1377 | only_load.get(o.model), reverse=True): |
| 1370 | 1378 | next = requested[o.field.related_query_name()] |
| 1371 | 1379 | klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, |
| 1372 | | requested=next, only_load=only_load, local_only=True) |
| | 1380 | requested=next, only_load=only_load, local_only=True, |
| | 1381 | last_klass=klass) |
| 1373 | 1382 | reverse_related_fields.append((o.field, klass_info)) |
| 1374 | 1383 | |
| 1375 | 1384 | return klass, field_names, field_count, related_fields, reverse_related_fields |
| … |
… |
def get_cached_row(row, index_start, using, klass_info, offset=0):
|
| 1455 | 1464 | # Now populate all the non-local field values |
| 1456 | 1465 | # on the related object |
| 1457 | 1466 | for rel_field, rel_model in rel_obj._meta.get_fields_with_model(): |
| 1458 | | if rel_model is not None: |
| | 1467 | if rel_model is not None and isinstance(obj, rel_model): |
| 1459 | 1468 | setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) |
| 1460 | 1469 | # populate the field cache for any related object |
| 1461 | 1470 | # that has already been retrieved |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index a68f6e0..b8cf74d 100644
|
a
|
b
|
class SQLCompiler(object):
|
| 249 | 249 | return result |
| 250 | 250 | |
| 251 | 251 | def get_default_columns(self, with_aliases=False, col_aliases=None, |
| 252 | | start_alias=None, opts=None, as_pairs=False, local_only=False): |
| | 252 | start_alias=None, opts=None, as_pairs=False, local_only=False, |
| | 253 | last_opts=None): |
| 253 | 254 | """ |
| 254 | 255 | Computes the default columns for selecting every field in the base |
| 255 | 256 | model. Will sometimes be called to pull in related models (e.g. via |
| … |
… |
class SQLCompiler(object):
|
| 273 | 274 | |
| 274 | 275 | if start_alias: |
| 275 | 276 | seen = {None: start_alias} |
| | 277 | parents = [p for p in opts.get_parent_list() if p._meta is not last_opts] |
| 276 | 278 | for field, model in opts.get_fields_with_model(): |
| 277 | | if local_only and model is not None: |
| | 279 | if local_only and model is not None and model not in parents: |
| 278 | 280 | continue |
| 279 | 281 | if start_alias: |
| 280 | 282 | try: |
| … |
… |
class SQLCompiler(object):
|
| 282 | 284 | except KeyError: |
| 283 | 285 | link_field = opts.get_ancestor_link(model) |
| 284 | 286 | alias = self.query.join((start_alias, model._meta.db_table, |
| 285 | | link_field.column, model._meta.pk.column)) |
| | 287 | link_field.column, model._meta.pk.column), |
| | 288 | promote=(model in parents)) |
| 286 | 289 | seen[model] = alias |
| 287 | 290 | else: |
| 288 | 291 | # If we're starting from the base model of the queryset, the |
| … |
… |
class SQLCompiler(object):
|
| 733 | 736 | ) |
| 734 | 737 | used.add(alias) |
| 735 | 738 | columns, aliases = self.get_default_columns(start_alias=alias, |
| 736 | | opts=model._meta, as_pairs=True, local_only=True) |
| | 739 | opts=model._meta, as_pairs=True, local_only=True, |
| | 740 | last_opts=opts) |
| 737 | 741 | self.query.related_select_cols.extend(columns) |
| 738 | 742 | self.query.related_select_fields.extend(model._meta.fields) |
| 739 | 743 | |
diff --git a/tests/django b/tests/django
new file mode 120000
index 0000000..8016dee
|
-
|
+
|
|
| | 1 | ../django |
| | 2 | No newline at end of file |
diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
index 3284def..6216e0c 100644
|
a
|
b
|
class StatDetails(models.Model):
|
| 51 | 51 | class AdvancedUserStat(UserStat): |
| 52 | 52 | karma = models.IntegerField() |
| 53 | 53 | |
| | 54 | |
| 54 | 55 | class Image(models.Model): |
| 55 | 56 | name = models.CharField(max_length=100) |
| 56 | 57 | |
| … |
… |
class Image(models.Model):
|
| 58 | 59 | class Product(models.Model): |
| 59 | 60 | name = models.CharField(max_length=100) |
| 60 | 61 | image = models.OneToOneField(Image, null=True) |
| | 62 | |
| | 63 | |
| | 64 | class Parent1(models.Model): |
| | 65 | name1 = models.CharField(max_length=50) |
| | 66 | def __unicode__(self): |
| | 67 | return self.name1 |
| | 68 | |
| | 69 | |
| | 70 | class Parent2(models.Model): |
| | 71 | name2 = models.CharField(max_length=50) |
| | 72 | def __unicode__(self): |
| | 73 | return self.name2 |
| | 74 | |
| | 75 | |
| | 76 | class Child1(Parent1, Parent2): |
| | 77 | other = models.CharField(max_length=50) |
| | 78 | def __unicode__(self): |
| | 79 | return self.name1 |
| | 80 | |
| | 81 | |
| | 82 | class Child2(Parent1): |
| | 83 | parent2 = models.OneToOneField(Parent2) |
| | 84 | other = models.CharField(max_length=50) |
| | 85 | def __unicode__(self): |
| | 86 | return self.name1 |
diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
index 1373f04..3c8623f 100644
|
a
|
b
|
from __future__ import absolute_import
|
| 3 | 3 | from django.test import TestCase |
| 4 | 4 | |
| 5 | 5 | from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, |
| 6 | | AdvancedUserStat, Image, Product) |
| | 6 | AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2) |
| 7 | 7 | |
| 8 | 8 | |
| 9 | 9 | class ReverseSelectRelatedTestCase(TestCase): |
| … |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
| 21 | 21 | advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, |
| 22 | 22 | results=results2) |
| 23 | 23 | StatDetails.objects.create(base_stats=advstat, comments=250) |
| | 24 | p1 = Parent1(name1="Only Parent1") |
| | 25 | p1.save() |
| | 26 | c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2") |
| | 27 | c1.save() |
| | 28 | p2 = Parent2(name2="Child2 Parent2") |
| | 29 | p2.save() |
| | 30 | c2 = Child2(name1="Child2 Parent1", parent2=p2) |
| | 31 | c2.save() |
| 24 | 32 | |
| 25 | 33 | def test_basic(self): |
| 26 | 34 | with self.assertNumQueries(1): |
| … |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
| 79 | 87 | p1 = Product.objects.create(name="Django Plushie", image=im) |
| 80 | 88 | p2 = Product.objects.create(name="Talking Django Plushie") |
| 81 | 89 | |
| | 90 | self.assertEqual(len(Product.objects.select_related("image")), 2) |
| | 91 | |
| 82 | 92 | with self.assertNumQueries(1): |
| 83 | 93 | result = sorted(Product.objects.select_related("image"), key=lambda x: x.name) |
| 84 | 94 | self.assertEqual([p.name for p in result], ["Django Plushie", "Talking Django Plushie"]) |
| … |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
| 108 | 118 | image = Image.objects.select_related('product').get() |
| 109 | 119 | with self.assertRaises(Product.DoesNotExist): |
| 110 | 120 | image.product |
| | 121 | |
| | 122 | def test_parent_only(self): |
| | 123 | Parent1.objects.select_related('child1').get(name1="Only Parent1") |
| | 124 | |
| | 125 | def test_multiple_subclass(self): |
| | 126 | with self.assertNumQueries(1): |
| | 127 | p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") |
| | 128 | self.assertEqual(p.child1.name2, u'Child1 Parent2') |
| | 129 | |
| | 130 | def test_onetoone_with_subclass(self): |
| | 131 | with self.assertNumQueries(1): |
| | 132 | p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") |
| | 133 | self.assertEqual(p.child2.name1, u'Child2 Parent1') |