Code

Ticket #10790: 10790v4.diff

File 10790v4.diff, 9.3 KB (added by PhiR, 3 years ago)

updated for current trunk

Line 
1Index: tests/modeltests/null_trimjoin/__init__.py
2===================================================================
3--- tests/modeltests/null_trimjoin/__init__.py  (revision 0)
4+++ tests/modeltests/null_trimjoin/__init__.py  (revision 0)
5@@ -0,0 +1 @@
6+# dummy text for patch
7Index: tests/modeltests/null_trimjoin/tests.py
8===================================================================
9--- tests/modeltests/null_trimjoin/tests.py     (revision 0)
10+++ tests/modeltests/null_trimjoin/tests.py     (revision 0)
11@@ -0,0 +1,44 @@
12+from django.test import TestCase
13+from models import Article, Reporter
14+
15+class OneToOneTests(TestCase):
16+
17+    def setUp(self):
18+        self.r = Reporter(name='John Smith')
19+        self.r.save()
20+        self.a = Article(headline="First", reporter=self.r)
21+        self.a.save()
22+        self.a2 = Article(headline="Second")
23+        self.a2.save()
24+
25+    def test_query_with_isnull(self):
26+        """Querying with isnull should not join Reporter table."""
27+        q = Article.objects.filter(reporter=None)
28+        # check that reporter is not in the query's used_aliases
29+        self.assertFalse('null_trimjoin_reporter' in q.query.used_aliases)
30+        self.assertTrue('null_trimjoin_article' in q.query.used_aliases)
31+        # but it should still be in query.tables
32+        self.assertTrue('null_trimjoin_article' in q.query.tables)
33+
34+    def test_query_across_tables(self):
35+        """Querying across several tables should strip only the last join, while
36+        preserving the preceding left outer joins."""
37+        q = Article.objects.filter(reporter__type=None)
38+        self.assertEquals(len(q), 2)
39+        self.assertTrue('null_trimjoin_article' in q.query.used_aliases)
40+        self.assertTrue('null_trimjoin_reporter' in q.query.used_aliases)
41+        self.assertFalse('null_trimjoin_reportertype' in q.query.used_aliases)
42+
43+    def test_m2m_query(self):
44+        """Querying across m2m field should not strip the m2m table from join."""
45+        q = Article.objects.filter(reporter__category__isnull=True)
46+        self.assertTrue('null_trimjoin_article' in q.query.used_aliases)
47+        self.assertTrue('null_trimjoin_reporter' in q.query.used_aliases)
48+        self.assertTrue('null_trimjoin_category' in q.query.used_aliases)
49+
50+    def test_reverse_query(self):
51+        """Reverse querying with isnull should not strip the join."""
52+        q = Reporter.objects.filter(article__isnull=True)
53+        self.assertTrue('null_trimjoin_reporter' in q.query.used_aliases)
54+
55+
56Index: tests/modeltests/null_trimjoin/models.py
57===================================================================
58--- tests/modeltests/null_trimjoin/models.py    (revision 0)
59+++ tests/modeltests/null_trimjoin/models.py    (revision 0)
60@@ -0,0 +1,31 @@
61+"""
62+Do not join table when querying on isnull
63+
64+"""
65+
66+from django.db import models
67+
68+class Category(models.Model):
69+    name = models.CharField(max_length=30)
70+
71+class ReporterType(models.Model):
72+    name = models.CharField(max_length=30)
73+
74+class Reporter(models.Model):
75+    name = models.CharField(max_length=30)
76+    type = models.ForeignKey(ReporterType, null=True)
77+    category = models.ManyToManyField(Category, null=True)
78+
79+    def __unicode__(self):
80+        return self.name
81+
82+class Article(models.Model):
83+    headline = models.CharField(max_length=100)
84+    reporter = models.ForeignKey(Reporter, null=True)
85+
86+    class Meta:
87+        ordering = ('headline',)
88+
89+    def __unicode__(self):
90+        return self.headline
91+
92Index: django/db/models/sql/compiler.py
93===================================================================
94--- django/db/models/sql/compiler.py    (revision 16730)
95+++ django/db/models/sql/compiler.py    (working copy)
96@@ -384,8 +384,8 @@
97         pieces = name.split(LOOKUP_SEP)
98         if not alias:
99             alias = self.query.get_initial_alias()
100-        field, target, opts, joins, last, extra = self.query.setup_joins(pieces,
101-                opts, alias, False)
102+        field, target, opts, joins, last, extra, allow_trim_join = self.query.setup_joins(
103+                pieces, opts, alias, False)
104         alias = joins[-1]
105         col = target.column
106         if not field.rel:
107Index: django/db/models/sql/expressions.py
108===================================================================
109--- django/db/models/sql/expressions.py (revision 16730)
110+++ django/db/models/sql/expressions.py (working copy)
111@@ -44,7 +44,7 @@
112             self.cols[node] = query.aggregate_select[node.name]
113         else:
114             try:
115-                field, source, opts, join_list, last, _ = query.setup_joins(
116+                field, source, opts, join_list, last, _, allow_trim_join = query.setup_joins(
117                     field_list, query.get_meta(),
118                     query.get_initial_alias(), False)
119                 col, _, join_list = query.trim_joins(source, join_list, last, False)
120Index: django/db/models/sql/query.py
121===================================================================
122--- django/db/models/sql/query.py       (revision 16730)
123+++ django/db/models/sql/query.py       (working copy)
124@@ -694,6 +694,14 @@
125             return True
126         return False
127 
128+    def demote_alias(self, alias):
129+        """
130+        Demotes the join type of an alias to an inner join.
131+        """
132+        data = list(self.alias_map[alias])
133+        data[JOIN_TYPE] = self.INNER
134+        self.alias_map[alias] = tuple(data)
135+
136     def promote_alias_chain(self, chain, must_promote=False):
137         """
138         Walks along a chain of aliases, promoting the first nullable join and
139@@ -991,7 +999,7 @@
140             #   - this is an annotation over a model field
141             # then we need to explore the joins that are required.
142 
143-            field, source, opts, join_list, last, _ = self.setup_joins(
144+            field, source, opts, join_list, last, _, allow_trim_join = self.setup_joins(
145                 field_list, opts, self.get_initial_alias(), False)
146 
147             # Process the join chain to see if it can be trimmed
148@@ -1083,7 +1091,7 @@
149         allow_many = trim or not negate
150 
151         try:
152-            field, target, opts, join_list, last, extra_filters = self.setup_joins(
153+            field, target, opts, join_list, last, extra_filters, allow_trim_join = self.setup_joins(
154                     parts, opts, alias, True, allow_many, allow_explicit_fk=True,
155                     can_reuse=can_reuse, negate=negate,
156                     process_extras=process_extras)
157@@ -1103,6 +1111,13 @@
158             self.promote_alias_chain(join_list)
159             join_promote = True
160 
161+        # If we have a one2one or many2one field, we can trim the left outer
162+        # join from the end of a list of joins.
163+        # In order to do this, we convert alias join type back to INNER and
164+        # trim_joins later will do the strip for us.
165+        if allow_trim_join and field.rel:
166+            self.demote_alias(join_list[-1])
167+
168         # Process the join list to see if we can remove any inner joins from
169         # the far end (fewer tables in a query is better).
170         nonnull_comparison = (lookup_type == 'isnull' and value is False)
171@@ -1259,6 +1274,7 @@
172         dupe_set = set()
173         exclusions = set()
174         extra_filters = []
175+        allow_trim_join = True
176         int_alias = None
177         for pos, name in enumerate(names):
178             if int_alias is not None:
179@@ -1282,6 +1298,11 @@
180                     raise FieldError("Cannot resolve keyword %r into field. "
181                             "Choices are: %s" % (name, ", ".join(names)))
182 
183+            # presence of indirect field in the filter requires
184+            # left outer join for isnull
185+            if not direct and allow_trim_join:
186+                allow_trim_join = False
187+
188             if not allow_many and (m2m or not direct):
189                 for alias in joins:
190                     self.unref_alias(alias)
191@@ -1323,6 +1344,8 @@
192                 extra_filters.extend(field.extra_filters(names, pos, negate))
193             if direct:
194                 if m2m:
195+                    # null query on m2mfield requires outer join
196+                    allow_trim_join = False
197                     # Many-to-many field defined on the current model.
198                     if cached_data:
199                         (table1, from_col1, to_col1, table2, from_col2,
200@@ -1443,7 +1466,7 @@
201             else:
202                 raise FieldError("Join on field %r not permitted." % name)
203 
204-        return field, target, opts, joins, last, extra_filters
205+        return field, target, opts, joins, last, extra_filters, allow_trim_join
206 
207     def trim_joins(self, target, join_list, last, trim, nonnull_check=False):
208         """
209@@ -1605,7 +1628,7 @@
210 
211         try:
212             for name in field_names:
213-                field, target, u2, joins, u3, u4 = self.setup_joins(
214+                field, target, u2, joins, u3, u4, allow_trim_join = self.setup_joins(
215                         name.split(LOOKUP_SEP), opts, alias, False, allow_m2m,
216                         True)
217                 final_alias = joins[-1]
218@@ -1887,7 +1910,7 @@
219         """
220         opts = self.model._meta
221         alias = self.get_initial_alias()
222-        field, col, opts, joins, last, extra = self.setup_joins(
223+        field, col, opts, joins, last, extra, allow_trim_join = self.setup_joins(
224                 start.split(LOOKUP_SEP), opts, alias, False)
225         select_col = self.alias_map[joins[1]][LHS_JOIN_COL]
226         select_alias = alias