Code

Ticket #14930: extra_order_by_values_list_with_tests_3.diff

File extra_order_by_values_list_with_tests_3.diff, 5.5 KB (added by fhahn, 15 months ago)

updated version of the patch, with tests

Line 
1diff --git a/django/db/models/query.py b/django/db/models/query.py
2index eda71d2..01e3b9f 100644
3--- a/django/db/models/query.py
4+++ b/django/db/models/query.py
5@@ -1017,8 +1017,17 @@ class ValuesQuerySet(QuerySet):
6 
7         names = extra_names + field_names + aggregate_names
8 
9+        # If a field list has been specified, use it. Otherwise, use the
10+        # full list of fields, including extras and aggregates.
11+        if self._fields:
12+            fields = list(self._fields) + [f for f in aggregate_names if f not in self._fields]
13+        else:
14+            fields = names
15+
16         for row in self.query.get_compiler(self.db).results_iter():
17-            yield dict(zip(names, row))
18+            # removes the non-necessary fields in the intersection
19+            # between names and fields
20+            yield {key: elem for key, elem in zip(names, row) if key in fields}
21 
22     def delete(self):
23         # values().delete() doesn't work currently - make sure it raises an
24@@ -1064,8 +1073,13 @@ class ValuesQuerySet(QuerySet):
25             self.aggregate_names = None
26 
27         self.query.select = []
28+        # add fields that are required by order_by or extra_order_by and
29+        # present in extra_select to extra_mask
30+        order_by_fields = list(self.query.order_by or []) + list(self.query.extra_order_by or [])
31+        extra_mask_names = list(self.extra_names or []) +\
32+            [n for n in order_by_fields if n in self.query.extra_select]
33         if self.extra_names is not None:
34-            self.query.set_extra_mask(self.extra_names)
35+            self.query.set_extra_mask(extra_mask_names)
36         self.query.add_fields(self.field_names, True)
37         if self.aggregate_names is not None:
38             self.query.set_aggregate_mask(self.aggregate_names)
39@@ -1139,8 +1153,12 @@ class ValuesQuerySet(QuerySet):
40 class ValuesListQuerySet(ValuesQuerySet):
41     def iterator(self):
42         if self.flat and len(self._fields) == 1:
43+            # get fields that were added to extra_select_mask but are not
44+            # expected in the value_list (e.g: a field was added to
45+            # extra_select_mask because it was needed be an order by clause)
46+            extra_fields = [f for f in self.query.extra_select_mask if f not in self._fields]
47             for row in self.query.get_compiler(self.db).results_iter():
48-                yield row[0]
49+                yield row[len(extra_fields)]
50         elif not self.query.extra_select and not self.query.aggregate_select:
51             for row in self.query.get_compiler(self.db).results_iter():
52                 yield tuple(row)
53diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py
54index 4adf076..d8bb7fc 100644
55--- a/tests/regressiontests/queries/tests.py
56+++ b/tests/regressiontests/queries/tests.py
57@@ -1531,7 +1531,8 @@ class Queries5Tests(TestCase):
58         # An empty values() call includes all aliases, including those from an
59         # extra()
60         qs = Ranking.objects.extra(select={'good': 'case when rank > 2 then 1 else 0 end'})
61-        dicts = qs.values().order_by('id')
62+        dicts = qs.values()
63+        dicts = dicts.order_by('id')
64         for d in dicts: del d['id']; del d['author_id']
65         self.assertEqual(
66             [sorted(d.items()) for d in dicts],
67@@ -1995,14 +1996,58 @@ class EmptyQuerySetTests(TestCase):
68 
69 
70 class ValuesQuerysetTests(BaseQuerysetTest):
71-    def test_flat_values_lits(self):
72+    def setUp(self):
73         Number.objects.create(num=72)
74+
75+    def test_flat_values_list(self):
76         qs = Number.objects.values_list("num")
77         qs = qs.values_list("num", flat=True)
78         self.assertValueQuerysetEqual(
79             qs, [72]
80         )
81 
82+    def test_extra_values(self):
83+        # testing for ticket 14930 issues
84+        qs = Number.objects.extra(select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1' })
85+        qs = qs.order_by('value_minus_one')
86+        qs = qs.values('num')
87+        identity = lambda x:x
88+        self.assertQuerysetEqual(qs, [{'num': 72}], identity)
89+
90+    def test_extra_values_order_twice(self):
91+        # testing for ticket 14930 issues
92+        qs = Number.objects.extra(select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1' })
93+        qs = qs.order_by('value_minus_one').order_by('value_plus_one')
94+        qs = qs.values('num')
95+        identity = lambda x:x
96+        self.assertQuerysetEqual(qs, [{'num': 72}], identity)
97+
98+    def test_extra_values_order_in_extra(self):
99+        # testing for ticket 14930 issues
100+        qs = Number.objects.extra(
101+                select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1' },
102+                order_by=['value_minus_one'])
103+        qs = qs.values('num')
104+        identity = lambda x:x
105+        self.assertQuerysetEqual(qs, [{'num': 72}], identity)
106+
107+    def test_extra_values_list(self):
108+        # testing for ticket 14930 issues
109+        qs = Number.objects.extra(select={'value_plus_one': 'num+1'})
110+        qs = qs.order_by('value_plus_one')
111+        qs = qs.values_list('num')
112+        identity = lambda x:x
113+        self.assertQuerysetEqual(qs, [(72,)], identity)
114+
115+    def test_flat_extra_values_list(self):
116+        # testing for ticket 14930 issues
117+        qs = Number.objects.extra(select={'value_plus_one': 'num+1'})
118+        qs = qs.order_by('value_plus_one')
119+        qs = qs.values_list('num', flat=True)
120+        identity = lambda x:x
121+        self.assertQuerysetEqual(qs, [72], identity)
122+
123+
124 
125 class WeirdQuerysetSlicingTests(BaseQuerysetTest):
126     def setUp(self):