Code

Ticket #3275: query.py.diff

File query.py.diff, 5.9 KB (added by David Cramer <dcramer@…>, 7 years ago)

diffs for django/db/models/query.py

Line 
1Index: query.py
2===================================================================
3--- query.py    (revision 4300)
4+++ query.py    (working copy)
5@@ -80,6 +80,8 @@
6         self._filters = Q()
7         self._order_by = None        # Ordering, e.g. ('date', '-name'). If None, use model's ordering.
8         self._select_related = False # Whether to fill cache for related objects.
9+        self._recurse_depth = 0      # Used to track how deep we are following for select_related()
10+        self._recurse_fields = []    # Fields to recurse through for select_related()
11         self._distinct = False       # Whether the query should use SELECT DISTINCT.
12         self._select = {}            # Dictionary of attname -> SQL.
13         self._where = []             # List of extra WHERE clauses to use.
14@@ -178,7 +180,7 @@
15                 raise StopIteration
16             for row in rows:
17                 if fill_cache:
18-                    obj, index_end = get_cached_row(self.model, row, 0)
19+                    obj, index_end = get_cached_row(self.model, row, 0, self._recurse_fields, self._recurse_depth)
20                 else:
21                     obj = self.model(*row[:index_end])
22                 for i, k in enumerate(extra_select):
23@@ -194,12 +196,12 @@
24         counter._select_related = False
25         select, sql, params = counter._get_sql_clause()
26         cursor = connection.cursor()
27+        id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
28+                backend.quote_name(self.model._meta.pk.column))
29         if self._distinct:
30-            id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
31-                    backend.quote_name(self.model._meta.pk.column))
32             cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params)
33         else:
34-            cursor.execute("SELECT COUNT(*)" + sql, params)
35+            cursor.execute("SELECT COUNT(%s)" % id_col + sql, params)
36         return cursor.fetchone()[0]
37 
38     def get(self, *args, **kwargs):
39@@ -359,9 +361,13 @@
40         else:
41             return self._filter_or_exclude(None, **filter_obj)
42 
43-    def select_related(self, true_or_false=True):
44+    # fields should be a list of field names in the root table, if specified, it modifies depth to 1
45+    # depth is the maximum number of children to recurse through, defaults to infinite
46+    def select_related(self, true_or_false=True, depth=0, fields=[]):
47         "Returns a new QuerySet instance with '_select_related' modified."
48-        return self._clone(_select_related=true_or_false)
49+        if fields != []:
50+            depth = 1
51+        return self._clone(_select_related=true_or_false, _recurse_depth=depth, _recurse_fields=fields)
52 
53     def order_by(self, *field_names):
54         "Returns a new QuerySet instance with the ordering changed."
55@@ -395,6 +401,8 @@
56         c._filters = self._filters
57         c._order_by = self._order_by
58         c._select_related = self._select_related
59+        c._recurse_fields = self._recurse_fields
60+        c._recurse_depth = self._recurse_depth
61         c._distinct = self._distinct
62         c._select = self._select.copy()
63         c._where = self._where[:]
64@@ -448,7 +456,7 @@
65 
66         # Add additional tables and WHERE clauses based on select_related.
67         if self._select_related:
68-            fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
69+            fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table], self._recurse_depth, self._recurse_fields)
70 
71         # Add any additional SELECTs.
72         if self._select:
73@@ -660,24 +668,30 @@
74         return backend.get_fulltext_search_sql(table_prefix + field_name)
75     raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
76 
77-def get_cached_row(klass, row, index_start):
78+def get_cached_row(klass, row, index_start, fields=[], max_depth=0, cur_depth=0):
79     "Helper function that recursively returns an object with cache filled"
80+    if max_depth and cur_depth > max_depth:
81+        return None
82     index_end = index_start + len(klass._meta.fields)
83     obj = klass(*row[index_start:index_end])
84     for f in klass._meta.fields:
85-        if f.rel and not f.null:
86-            rel_obj, index_end = get_cached_row(f.rel.to, row, index_end)
87-            setattr(obj, f.get_cache_name(), rel_obj)
88+        if f.rel and not f.null and (fields == [] or (cur_depth == 0 and f.name in fields)):
89+            cached_row = get_cached_row(f.rel.to, row, index_end, fields, max_depth, cur_depth+1)
90+            if cached_row:
91+                    rel_obj, index_end = cached_row
92+                    setattr(obj, f.get_cache_name(), rel_obj)
93     return obj, index_end
94 
95-def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen):
96+def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, fields=[], cur_depth=0):
97     """
98     Helper function that recursively populates the select, tables and where (in
99     place) for select_related queries.
100     """
101     qn = backend.quote_name
102+    if max_depth and cur_depth > max_depth:
103+        return
104     for f in opts.fields:
105-        if f.rel and not f.null:
106+        if f.rel and not f.null and (fields == [] or (cur_depth == 0 and f.name in fields)):
107             db_table = f.rel.to._meta.db_table
108             if db_table not in cache_tables_seen:
109                 tables.append(qn(db_table))
110@@ -689,7 +703,7 @@
111             where.append('%s.%s = %s.%s' % \
112                 (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
113             select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
114-            fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen)
115+            fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, fields, cur_depth+1)
116 
117 def parse_lookup(kwarg_items, opts):
118     # Helper function that handles converting API kwargs