Code

Ticket #9961: 9961_r17068.diff

File 9961_r17068.diff, 4.3 KB (added by koenb, 3 years ago)

updated the patch to r17068

Line 
1diff --git a/django/db/models/query.py b/django/db/models/query.py
2index be42d02..7875eae 100644
3--- a/django/db/models/query.py
4+++ b/django/db/models/query.py
5@@ -657,7 +657,7 @@ class QuerySet(object):
6         If fields are specified, they must be ForeignKey fields and only those
7         related objects are included in the selection.
8         """
9-        depth = kwargs.pop('depth', 0)
10+        depth = kwargs.pop('depth', None)
11         if kwargs:
12             raise TypeError('Unexpected keyword arguments to select_related: %s'
13                     % (kwargs.keys(),))
14@@ -668,7 +668,7 @@ class QuerySet(object):
15             obj.query.add_select_related(fields)
16         else:
17             obj.query.select_related = True
18-        if depth:
19+        if depth is not None:
20             obj.query.max_depth = depth
21         return obj
22 
23@@ -1217,7 +1217,7 @@ class EmptyQuerySet(QuerySet):
24     # situations).
25     value_annotation = False
26 
27-def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
28+def get_klass_info(klass, max_depth=None, cur_depth=0, requested=None,
29                    only_load=None, local_only=False):
30     """
31     Helper function that recursively returns an information for a klass, to be
32@@ -1241,7 +1241,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
33      * local_only - Only populate local fields. This is used when
34        following reverse select-related relations
35     """
36-    if max_depth and requested is None and cur_depth > max_depth:
37+    if max_depth is not None and requested is None and cur_depth > max_depth:
38         # We've recursed deeply enough; stop now.
39         return None
40 
41diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
42index cebd77f..9c63062 100644
43--- a/django/db/models/sql/compiler.py
44+++ b/django/db/models/sql/compiler.py
45@@ -517,7 +517,7 @@ class SQLCompiler(object):
46         (for example, cur_depth=1 means we are looking at models with direct
47         connections to the root model).
48         """
49-        if not restricted and self.query.max_depth and cur_depth > self.query.max_depth:
50+        if not restricted and self.query.max_depth is not None and cur_depth > self.query.max_depth:
51             # We've recursed far enough; bail out.
52             return
53 
54diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
55index 68da9c7..9cae243 100644
56--- a/docs/ref/models/querysets.txt
57+++ b/docs/ref/models/querysets.txt
58@@ -642,6 +642,14 @@ follow::
59     p = b.author         # Doesn't hit the database.
60     c = p.hometown       # Requires a database call.
61 
62+You can also use the ``depth`` argument to cancel an existing
63+``select_related()`` on the query by setting ``depth`` to "0". For example
64+these three queries are equivalent::
65+
66+    b = Book.objects.get(id=4)
67+    b = Book.objects.select_related(depth=0).get(id=4)
68+    b = Book.objects.select_related('author').select_related(depth=0).get(id=4)
69+
70 Sometimes you only want to access specific models that are related to your root
71 model, not all of the related models. In these cases, you can pass the related
72 field names to ``select_related()`` and it will only follow those relations.
73diff --git a/tests/modeltests/select_related/tests.py b/tests/modeltests/select_related/tests.py
74index 1b3715a..a6641db 100644
75--- a/tests/modeltests/select_related/tests.py
76+++ b/tests/modeltests/select_related/tests.py
77@@ -160,3 +160,26 @@ class SelectRelatedTests(TestCase):
78             Species.objects.select_related,
79             'genus__family__order', depth=4
80         )
81+
82+    def test_depth_zero(self):
83+        with self.assertNumQueries(9):
84+            world = Species.objects.all().select_related(depth=0)
85+            families = [o.genus.family.name for o in world]
86+            self.assertEqual(sorted(families), [
87+                'Amanitacae',
88+                'Drosophilidae',
89+                'Fabaceae',
90+                'Hominidae',
91+            ])
92+
93+    def test_reset_depth(self):
94+        with self.assertNumQueries(9):
95+            world = Species.objects.all().select_related('genus').select_related(depth=0)
96+            families = [o.genus.family.name for o in world]
97+            self.assertEqual(sorted(families), [
98+                'Amanitacae',
99+                'Drosophilidae',
100+                'Fabaceae',
101+                'Hominidae',
102+            ])
103+