Code

Ticket #16759: #16759-remove_deepcopy_in_qs.diff

File #16759-remove_deepcopy_in_qs.diff, 9.4 KB (added by Kronuz, 14 months ago)

Fixes #19964

Line 
1diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
2index 1bbf742..51061ff 100644
3--- django/db/models/sql/expressions.py
4+++ django/db/models/sql/expressions.py
5@@ -1,6 +1,7 @@
6 from django.core.exceptions import FieldError
7 from django.db.models.fields import FieldDoesNotExist
8 from django.db.models.sql.constants import LOOKUP_SEP
9+import copy
10 
11 class SQLEvaluator(object):
12     def __init__(self, expression, query, allow_joins=True):
13@@ -12,6 +13,17 @@ def __init__(self, expression, query, allow_joins=True, reuse=None):
14         self.contains_aggregate = False
15         self.expression.prepare(self, query, allow_joins)
16 
17+    def clone(self):
18+        clone = copy.copy(self)
19+        clone.cols = {}
20+        for key, col in self.cols.items():
21+            if hasattr(col, 'clone'):
22+                clone.cols[key] = col.clone()
23+            else:
24+                clone.cols[key] = col
25+        return clone
26+
27+
28     def prepare(self):
29         return self
30 
31diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
32index b41314a..75a330f 100644
33--- django/db/models/sql/aggregates.py
34+++ django/db/models/sql/aggregates.py
35@@ -1,6 +1,7 @@
36 """
37 Classes to represent the default SQL aggregate functions
38 """
39+import copy
40 
41 from django.db.models.fields import IntegerField, FloatField
42 
43@@ -62,6 +63,11 @@ def __init__(self, col, source=None, is_summary=False, **extra):
44 
45         self.field = tmp
46 
47+    def clone(self):
48+        # Different aggregates have different init methods, so use copy here
49+        # deepcopy is not needed, as self.col is only changing variable.
50+        return copy.copy(self)
51+
52     def relabel_aliases(self, change_map):
53         if isinstance(self.col, (list, tuple)):
54             self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
55diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
56index 7f331bf..449404a 100644
57--- django/db/models/sql/query.py
58+++ django/db/models/sql/query.py
59@@ -256,13 +256,13 @@ def clone(self, klass=None, memo=None, **kwargs):
60         obj.dupe_avoidance = self.dupe_avoidance.copy()
61         obj.select = self.select[:]
62         obj.tables = self.tables[:]
63-        obj.where = copy.deepcopy(self.where, memo=memo)
64+        obj.where = self.where.clone()
65         obj.where_class = self.where_class
66         if self.group_by is None:
67             obj.group_by = None
68         else:
69             obj.group_by = self.group_by[:]
70-        obj.having = copy.deepcopy(self.having, memo=memo)
71+        obj.having = self.having.clone()
72         obj.order_by = self.order_by[:]
73         obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
74         obj.distinct = self.distinct
75@@ -271,7 +271,8 @@ def clone(self, klass=None, memo=None, **kwargs):
76         obj.select_for_update_nowait = self.select_for_update_nowait
77         obj.select_related = self.select_related
78         obj.related_select_cols = []
79-        obj.aggregates = copy.deepcopy(self.aggregates, memo=memo)
80+        obj.aggregates = SortedDict((k, v.clone())
81+                                    for k, v in self.aggregates.items())
82         if self.aggregate_select_mask is None:
83             obj.aggregate_select_mask = None
84         else:
85@@ -294,7 +295,7 @@ def clone(self, klass=None, memo=None, **kwargs):
86             obj._extra_select_cache = self._extra_select_cache.copy()
87         obj.extra_tables = self.extra_tables
88         obj.extra_order_by = self.extra_order_by
89-        obj.deferred_loading = copy.deepcopy(self.deferred_loading, memo=memo)
90+        obj.deferred_loading = copy.copy(self.deferred_loading[0]), self.deferred_loading[1]
91         if self.filter_is_sticky and self.used_aliases:
92             obj.used_aliases = self.used_aliases.copy()
93         else:
94@@ -509,7 +510,7 @@ def combine(self, rhs, connector):
95         # Now relabel a copy of the rhs where-clause and add it to the current
96         # one.
97         if rhs.where:
98-            w = copy.deepcopy(rhs.where)
99+            w = rhs.where.clone()
100             w.relabel_aliases(change_map)
101             if not self.where:
102                 # Since 'self' matches everything, add an explicit "include
103@@ -530,7 +531,7 @@ def combine(self, rhs, connector):
104             if isinstance(col, (list, tuple)):
105                 self.select.append((change_map.get(col[0], col[0]), col[1]))
106             else:
107-                item = copy.deepcopy(col)
108+                item = col.clone()
109                 item.relabel_aliases(change_map)
110                 self.select.append(item)
111         self.select_fields = rhs.select_fields[:]
112diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
113index 5515bc4..f9475a6 100644
114--- django/db/models/sql/where.py
115+++ django/db/models/sql/where.py
116@@ -10,7 +10,7 @@
117 
118 from django.utils import tree
119 from django.db.models.fields import Field
120-from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet
121+from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet, Empty
122 from django.db.models.sql.aggregates import Aggregate
123 
124 # Connection types
125@@ -254,6 +254,23 @@ def relabel_aliases(self, change_map, node=None):
126                 # Check if the query value also requires relabelling
127                 if hasattr(child[3], 'relabel_aliases'):
128                     child[3].relabel_aliases(change_map)
129+
130+    def clone(self):
131+        """
132+        Creates a clone of the tree. Must only be called on root nodes (nodes
133+        with empty subtree_parents). Childs must be either Contraint, lookup,
134+        value tuples, or objects supporting .clone().
135+        """
136+        assert not self.subtree_parents, '.clone() can only be called on root nodes'
137+        clone = self.__class__._new_instance(
138+            children=[], connector=self.connector, negated=self.negated)
139+        for child in self.children:
140+            if isinstance(child, tuple):
141+                clone.children.append(
142+                    tuple(map(lambda o: o.clone() if hasattr(o, 'clone') else o, child)))
143+            else:
144+                clone.children.append(child.clone())
145+        return clone
146 
147 class EverythingNode(object):
148     """
149@@ -266,6 +283,9 @@ def as_sql(self, qn=None, connection=None):
150     def relabel_aliases(self, change_map, node=None):
151         return
152 
153+    def clone(self):
154+        return self
155+
156 class NothingNode(object):
157     """
158     A node that matches nothing.
159@@ -276,6 +296,9 @@ def as_sql(self, qn=None, connection=None):
160     def relabel_aliases(self, change_map, node=None):
161         return
162 
163+    def clone(self):
164+        return self
165+
166 class ExtraWhere(object):
167     def __init__(self, sqls, params):
168         self.sqls = sqls
169@@ -285,6 +308,9 @@ def as_sql(self, qn=None, connection=None):
170         sqls = ["(%s)" % sql for sql in self.sqls]
171         return " AND ".join(sqls), tuple(self.params or ())
172 
173+    def clone(self):
174+        return self
175+
176 class Constraint(object):
177     """
178     An object that can be passed to WhereNode.add() and knows how to
179@@ -349,3 +375,9 @@ def process(self, lookup_type, value, connection):
180     def relabel_aliases(self, change_map):
181         if self.alias in change_map:
182             self.alias = change_map[self.alias]
183+
184+    def clone(self):
185+        new = Empty()
186+        new.__class__ = self.__class__
187+        new.alias, new.col, new.field = self.alias, self.col, self.field
188+        return new
189diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py
190index 1035a38..9ef06ca 100644
191--- tests/regressiontests/queries/tests.py
192+++ tests/regressiontests/queries/tests.py
193@@ -1649,6 +1649,7 @@ def test_sliced_delete(self):
194 
195 
196 class CloneTests(TestCase):
197+
198     def test_evaluated_queryset_as_argument(self):
199         "#13227 -- If a queryset is already evaluated, it can still be used as a query arg"
200         n = Note(note='Test1', misc='misc')
201@@ -1666,6 +1667,39 @@ def test_evaluated_queryset_as_argument(self):
202         except:
203             self.fail('Query should be clonable')
204 
205+    def test_no_model_options_cloning(self):
206+        """
207+        Test that cloning a queryset does not get out of hand. While complete
208+        testing is impossible, this is a sanity check against invalid use of
209+        deepcopy. refs #16759.
210+        """
211+        opts_class = type(Note._meta)
212+        note_deepcopy = getattr(opts_class, "__deepcopy__", None)
213+        opts_class.__deepcopy__ = lambda obj, memo: self.fail("Model options shouldn't be cloned.")
214+        try:
215+            Note.objects.filter(pk__lte=F('pk') + 1).all()
216+        finally:
217+            if note_deepcopy is None:
218+                delattr(opts_class, "__deepcopy__")
219+            else:
220+                opts_class.__deepcopy__ = note_deepcopy
221+
222+    def test_no_fields_cloning(self):
223+        """
224+        Test that cloning a queryset does not get out of hand. While complete
225+        testing is impossible, this is a sanity check against invalid use of
226+        deepcopy. refs #16759.
227+        """
228+        opts_class = type(Note._meta.get_field_by_name("misc")[0])
229+        note_deepcopy = getattr(opts_class, "__deepcopy__", None)
230+        opts_class.__deepcopy__ = lambda obj, memo: self.fail("Model fields shouldn't be cloned")
231+        try:
232+            Note.objects.filter(note=F('misc')).all()
233+        finally:
234+            if note_deepcopy is None:
235+                delattr(opts_class, "__deepcopy__")
236+            else:
237+                opts_class.__deepcopy__ = note_deepcopy
238 
239 class EmptyQuerySetTests(TestCase):
240     def test_emptyqueryset_values(self):