Code

Ticket #3566: aggregate.2.diff

File aggregate.2.diff, 8.6 KB (added by nicolas, 6 years ago)
Line 
1Index: django/db/models/sql/query.py
2===================================================================
3--- django/db/models/sql/query.py       (revision 7436)
4+++ django/db/models/sql/query.py       (working copy)
5@@ -56,6 +56,7 @@
6         self.start_meta = None
7 
8         # SQL-related attributes
9+        self.aggregates = []
10         self.select = []
11         self.tables = []    # Aliases in the order they are created.
12         self.where = where()
13@@ -141,6 +142,7 @@
14         obj.standard_ordering = self.standard_ordering
15         obj.start_meta = self.start_meta
16         obj.select = self.select[:]
17+        obj.aggregates = self.aggregates[:]
18         obj.tables = self.tables[:]
19         obj.where = deepcopy(self.where)
20         obj.where_class = self.where_class
21@@ -174,6 +176,30 @@
22                     row = self.resolve_columns(row, fields)
23                 yield row
24 
25+    def get_aggregation(self):
26+        for field in self.select:
27+            self.group_by.append(field)
28+        self.select.extend(self.aggregates)
29+        self.aggregates = []
30+        #print self.as_sql()
31+        #print 'after', self.select
32+
33+        get_name = lambda x : isinstance(x, tuple) and x[1] or x.aliased_name
34+
35+        print 'final query', self.as_sql()
36+
37+        if self.group_by:
38+            data = self.execute_sql(MULTI)
39+            result = []
40+            for rs in data.next():
41+                result.append(dict(zip([get_name(i) for i in self.select], rs)))
42+        else:
43+            data = self.execute_sql(SINGLE)
44+            result = dict(zip([get_name(i) for i in self.select], data))
45+
46+        self.select = []
47+        return result
48+
49     def get_count(self):
50         """
51         Performs a COUNT() query using the current filter constraints.
52@@ -811,6 +837,81 @@
53             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
54                     used, next, restricted)
55 
56+    def annotate(self, aggregate_expr, aliased_name, model):
57+        field_list = aggregate_expr.split(LOOKUP_SEP)
58+        opts = model._meta
59+
60+        aggregate_func = field_list.pop()
61+       
62+        if len(field_list) > 1:
63+            field, target, opts, join_list, last = self.setup_joins(
64+                field_list, opts, self.get_initial_alias(), False)
65+            final = len(join_list)
66+            penultimate = last.pop()
67+            if penultimate == final:
68+                penultimate = last.pop()
69+            if len(join_list) > 1:
70+                extra = join_list[penultimate:]
71+                final = penultimate
72+                col = self.alias_map[extra[0]][LHS_JOIN_COL]
73+            else:
74+                col = target.column
75+               
76+            field_name = field_list.pop()
77+            alias = join_list[-1]
78+            alias = extra[final]
79+        else:
80+            field_name = field_list[0]
81+            alias = opts.db_table
82+         
83+
84+    def add_aggregate(self, aggregate_expr, aliased_name, model):
85+        """
86+        Adds a single aggregate expression to the Query
87+        """
88+       
89+        field_list = aggregate_expr.split(LOOKUP_SEP)
90+        opts = model._meta
91+
92+        aggregate_func = field_list.pop()
93+       
94+        if len(field_list) > 1:
95+            field, target, opts, join_list, last = self.setup_joins(
96+                field_list, opts, self.get_initial_alias(), False)
97+            final = len(join_list)
98+            penultimate = last.pop()
99+            if penultimate == final:
100+                penultimate = last.pop()
101+            if len(join_list) > 1:
102+                extra = join_list[penultimate:]
103+                final = penultimate
104+                col = self.alias_map[extra[0]][LHS_JOIN_COL]
105+            else:
106+                col = target.column
107+               
108+            field_name = field_list.pop()
109+            alias = join_list[-1]
110+            alias = extra[final]
111+        else:
112+            field_name = field_list[0]
113+            alias = opts.db_table
114+
115+        class AggregateNode:
116+            def __init__(self, field_name, aggregate_func, aliased_name, alias):
117+                self.field_name = field_name
118+                self.aggregate_func = aggregate_func
119+                self.aliased_name = aliased_name
120+                self.alias = alias
121+               
122+            def as_sql(self, quote_func=None):
123+                if not quote_func:
124+                    quote_func = lambda x: x
125+                return '%s(%s.%s)' % (self.aggregate_func.upper(),
126+                                      quote_func(self.alias),
127+                                      quote_func(self.field_name))
128+
129+        self.aggregates.append(AggregateNode(field_name, aggregate_func, aliased_name, alias))
130+       
131     def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
132             single_filter=False):
133         """
134@@ -829,6 +930,10 @@
135         if not parts:
136             raise FieldError("Cannot parse keyword query %r" % arg)
137 
138+        # if arg in (x.aliased_name for x in self.aggregates):
139+        #     self.having.append(arg)
140+        #     return
141+
142         # Work out the lookup type and remove it from 'parts', if necessary.
143         if len(parts) == 1 or parts[-1] not in self.query_terms:
144             lookup_type = 'exact'
145Index: django/db/models/query.py
146===================================================================
147--- django/db/models/query.py   (revision 7436)
148+++ django/db/models/query.py   (working copy)
149@@ -158,6 +158,34 @@
150                 setattr(obj, k, row[i])
151             yield obj
152 
153+    def aggregate(self, *args, **kwargs):
154+        """
155+        Returns the aggregation over the current model as a
156+        dictionary.
157+
158+        When applied to a ValuesQuerySet the results are GROUP BY-ed
159+        by the fields specified in the values queryset.
160+
161+        The kwargs are parsed as expression='alias'.
162+
163+        If args is present the expression is passed as a kwarg with
164+        itself as an alias.
165+        """
166+        #Bug (or is it?): when doing both an aggregation on a related
167+        #field and one on a 'local' field the local one goes
168+        #wrong. something similar to: SELECT SUM(a.f1) FROM a INNER JOIN b;
169+        #the value gets aggregated more than one time.
170+
171+        if args:
172+            newargs = {}
173+            for arg in args:
174+                newargs[arg] = arg
175+            kwargs.update(newargs)
176+           
177+        for (aggregate_expr, alias) in kwargs.items():
178+            self.query.add_aggregate(aggregate_expr, alias, self.model)
179+        return self.query.get_aggregation()
180+
181     def count(self):
182         """
183         Performs a SELECT COUNT() and returns the number of records as an
184@@ -326,6 +354,32 @@
185         """
186         return self._clone(klass=EmptyQuerySet)
187 
188+    def annotate(self, *args, **kwargs):
189+        # Fix: Values is not working propperly
190+        # Suffers from the same bug as aggrgate
191+        # To-Do: HAVING
192+
193+        if args:
194+            newargs = {}
195+            for arg in args:
196+                newargs[arg] = arg
197+            kwargs.update(newargs)
198+
199+        opts = self.model._meta
200+        fields = []
201+       
202+        if isinstance(self, ValuesQuerySet):
203+            obj = self._clone()
204+        else:
205+            fields.extend([f.name for f in opts.fields])
206+            obj = self._clone(klass=ValuesQuerySet, setup=True, _fields=fields)
207+           
208+
209+        for (aggregate_expr, alias) in kwargs.items():
210+            obj.query.add_aggregate(aggregate_expr, alias, self.model)
211+
212+        return obj
213+
214     ##################################################################
215     # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
216     ##################################################################
217@@ -335,7 +389,7 @@
218         Returns a new QuerySet that is a copy of the current one. This allows a
219         QuerySet to proxy for a model manager in some cases.
220         """
221-        return self._clone()
222+        return self._clone()       
223 
224     def filter(self, *args, **kwargs):
225         """
226@@ -488,14 +542,21 @@
227         # names of the model fields to select.
228 
229     def __iter__(self):
230+        if self.query.aggregates:
231+            return self.aggregate_iterator()
232         return self.iterator()
233 
234+    def aggregate_iterator(self):
235+        #Not lazy.. review
236+        for i in self.query.get_aggregation():
237+            yield i
238+
239     def iterator(self):
240         self.query.trim_extra_select(self.extra_names)
241         names = self.query.extra_select.keys() + self.field_names
242         for row in self.query.results_iter():
243             yield dict(zip(names, row))
244-
245+
246     def _setup_query(self):
247         """
248         Constructs the field_names list that the values query will be