Code

Ticket #3566: aggregate.diff

File aggregate.diff, 4.9 KB (added by Nicolas Lara <nicolaslara@…>, 6 years ago)

Patch that adds aggregate functionality (no annotate yet)

Line 
1Index: django/db/models/sql/query.py
2===================================================================
3--- django/db/models/sql/query.py       (revision 7350)
4+++ django/db/models/sql/query.py       (working copy)
5@@ -55,6 +55,7 @@
6         self.start_meta = None
7 
8         # SQL-related attributes
9+        self.aggregates = []
10         self.select = []
11         self.tables = []    # Aliases in the order they are created.
12         self.where = where()
13@@ -140,6 +141,7 @@
14         obj.standard_ordering = self.standard_ordering
15         obj.start_meta = self.start_meta
16         obj.select = self.select[:]
17+        obj.aggregates = self.aggregates[:]
18         obj.tables = self.tables[:]
19         obj.where = deepcopy(self.where)
20         obj.where_class = self.where_class
21@@ -173,6 +175,30 @@
22                     row = self.resolve_columns(row, fields)
23                 yield row
24 
25+    def get_aggregation(self):
26+        for field in self.select:
27+            self.group_by.append(field)
28+        self.select.extend(self.aggregates)
29+        self.aggregates = []
30+        #print self.as_sql()
31+        #print 'after', self.select
32+
33+        get_name = lambda x : isinstance(x, tuple) and x[1] or x.aliased_name
34+
35+        print 'final query', self.as_sql()
36+
37+        if self.group_by:
38+            data = self.execute_sql(MULTI)
39+            result = []
40+            for rs in data.next():
41+                result.append(dict(zip([get_name(i) for i in self.select], rs)))
42+        else:
43+            data = self.execute_sql(SINGLE)
44+            result = dict(zip([get_name(i) for i in self.select], data))
45+
46+        self.select = []
47+        return result
48+
49     def get_count(self):
50         """
51         Performs a COUNT() query using the current filter constraints.
52@@ -808,6 +834,53 @@
53             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
54                     used, next, restricted)
55 
56+    def add_aggregate(self, aggregate_expr, aliased_name, model):
57+        """
58+        Adds a single aggregate expression to the Query
59+        """
60+       
61+        field_list = aggregate_expr.split(LOOKUP_SEP)
62+        opts = model._meta
63+
64+        aggregate_func = field_list.pop()
65+       
66+        if len(field_list) > 1:
67+            field, target, opts, join_list, last = self.setup_joins(
68+                field_list, opts, self.get_initial_alias(), False)
69+            final = len(join_list)
70+            penultimate = last.pop()
71+            if penultimate == final:
72+                penultimate = last.pop()
73+            if len(join_list) > 1:
74+                extra = join_list[penultimate:]
75+                final = penultimate
76+                col = self.alias_map[extra[0]][LHS_JOIN_COL]
77+            else:
78+                col = target.column
79+               
80+            field_name = field_list.pop()
81+            alias = join_list[-1]
82+            alias = extra[final]
83+        else:
84+            field_name = field_list[0]
85+            alias = opts.db_table
86+
87+        class AggregateNode:
88+            def __init__(self, field_name, aggregate_func, aliased_name, alias):
89+                self.field_name = field_name
90+                self.aggregate_func = aggregate_func
91+                self.aliased_name = aliased_name
92+                self.alias = alias
93+               
94+            def as_sql(self, quote_func=None):
95+                if not quote_func:
96+                    quote_func = lambda x: x
97+                return '%s(%s.%s)' % (self.aggregate_func.upper(),
98+                                      quote_func(self.alias),
99+                                      quote_func(self.field_name))
100+
101+        self.aggregates.append(AggregateNode(field_name, aggregate_func, aliased_name, alias))
102+       
103     def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
104             single_filter=False):
105         """
106Index: django/db/models/query.py
107===================================================================
108--- django/db/models/query.py   (revision 7350)
109+++ django/db/models/query.py   (working copy)
110@@ -165,6 +165,17 @@
111                 setattr(obj, k, row[i])
112             yield obj
113 
114+    def aggregate(self, *args, **kwargs):
115+        """
116+        Returns the aggregation over the current model as values (or so it should).
117+        """
118+        if args:
119+            TypeError('Unexpected positional arguments')
120+           
121+        for (aggregate_expr, alias) in kwargs.items():
122+            self.query.add_aggregate(aggregate_expr, alias, self.model)
123+        return self.query.get_aggregation()
124+
125     def count(self):
126         """
127         Performs a SELECT COUNT() and returns the number of records as an
128@@ -342,7 +353,7 @@
129         Returns a new QuerySet that is a copy of the current one. This allows a
130         QuerySet to proxy for a model manager in some cases.
131         """
132-        return self._clone()
133+        return self._clone()       
134 
135     def filter(self, *args, **kwargs):
136         """