diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index a2349cf..3d8b63a5 100644
a
|
b
|
class Aggregate(object):
|
6 | 6 | """ |
7 | 7 | Default Aggregate definition. |
8 | 8 | """ |
9 | | def __init__(self, lookup, **extra): |
| 9 | def __init__(self, lookup, only=None, **extra): |
10 | 10 | """Instantiate a new aggregate. |
11 | 11 | |
12 | 12 | * lookup is the field on which the aggregate operates. |
| 13 | * only is a Q-object used in conditional aggregation. |
13 | 14 | * extra is a dictionary of additional data to provide for the |
14 | 15 | aggregate definition |
15 | 16 | |
… |
… |
class Aggregate(object):
|
18 | 19 | """ |
19 | 20 | self.lookup = lookup |
20 | 21 | self.extra = extra |
| 22 | self.only = only |
| 23 | self.condition = None |
21 | 24 | |
22 | 25 | def _default_alias(self): |
23 | 26 | return '%s__%s' % (self.lookup, self.name.lower()) |
… |
… |
class Aggregate(object):
|
42 | 45 | summary value rather than an annotation. |
43 | 46 | """ |
44 | 47 | klass = getattr(query.aggregates_module, self.name) |
45 | | aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) |
| 48 | aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra) |
46 | 49 | query.aggregates[alias] = aggregate |
47 | 50 | |
48 | 51 | class Avg(Aggregate): |
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
index 207bc0c..0de19a3 100644
a
|
b
|
class Aggregate(object):
|
22 | 22 | is_ordinal = False |
23 | 23 | is_computed = False |
24 | 24 | sql_template = '%(function)s(%(field)s)' |
| 25 | conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END" |
25 | 26 | |
26 | | def __init__(self, col, source=None, is_summary=False, **extra): |
| 27 | def __init__(self, col, source=None, is_summary=False, condition=None, **extra): |
27 | 28 | """Instantiate an SQL aggregate |
28 | 29 | |
29 | 30 | * col is a column reference describing the subject field |
… |
… |
class Aggregate(object):
|
52 | 53 | self.source = source |
53 | 54 | self.is_summary = is_summary |
54 | 55 | self.extra = extra |
| 56 | self.condition = condition |
55 | 57 | |
56 | 58 | # Follow the chain of aggregate sources back until you find an |
57 | 59 | # actual field, or an aggregate that forces a particular output |
… |
… |
class Aggregate(object):
|
76 | 78 | def as_sql(self, qn, connection): |
77 | 79 | "Return the aggregate, rendered as SQL." |
78 | 80 | |
| 81 | query_params = [] |
79 | 82 | if hasattr(self.col, 'as_sql'): |
80 | 83 | field_name = self.col.as_sql(qn, connection) |
81 | 84 | elif isinstance(self.col, (list, tuple)): |
82 | 85 | field_name = '.'.join([qn(c) for c in self.col]) |
83 | 86 | else: |
84 | 87 | field_name = self.col |
85 | | |
86 | | params = { |
87 | | 'function': self.sql_function, |
88 | | 'field': field_name |
89 | | } |
| 88 | if self.condition: |
| 89 | condition = self.condition.as_sql(qn, connection) |
| 90 | query_params = condition[1] |
| 91 | conditional_field = self.conditional_template % { |
| 92 | 'condition': condition[0], |
| 93 | 'field_name': field_name |
| 94 | } |
| 95 | params = { |
| 96 | 'function': self.sql_function, |
| 97 | 'field': conditional_field, |
| 98 | } |
| 99 | else: |
| 100 | params = { |
| 101 | 'function': self.sql_function, |
| 102 | 'field': field_name |
| 103 | } |
90 | 104 | params.update(self.extra) |
91 | 105 | |
92 | | return self.sql_template % params |
| 106 | return (self.sql_template % params, query_params) |
93 | 107 | |
94 | 108 | |
95 | 109 | class Avg(Aggregate): |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 841ec12..a0d2f35 100644
a
|
b
|
class SQLCompiler(object):
|
58 | 58 | return '', () |
59 | 59 | |
60 | 60 | self.pre_sql_setup() |
61 | | out_cols = self.get_columns(with_col_aliases) |
| 61 | out_cols, c_params = self.get_columns(with_col_aliases) |
62 | 62 | ordering, ordering_group_by = self.get_ordering() |
| 63 | params = [] |
| 64 | params.extend(c_params) |
63 | 65 | |
64 | 66 | # This must come after 'select' and 'ordering' -- see docstring of |
65 | 67 | # get_from_clause() for details. |
… |
… |
class SQLCompiler(object):
|
69 | 71 | |
70 | 72 | where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) |
71 | 73 | having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) |
72 | | params = [] |
73 | 74 | for val in self.query.extra_select.itervalues(): |
74 | 75 | params.extend(val[1]) |
75 | 76 | |
… |
… |
class SQLCompiler(object):
|
126 | 127 | if nowait and not self.connection.features.has_select_for_update_nowait: |
127 | 128 | raise DatabaseError('NOWAIT is not supported on this database backend.') |
128 | 129 | result.append(self.connection.ops.for_update_sql(nowait=nowait)) |
129 | | |
130 | 130 | return ' '.join(result), tuple(params) |
131 | 131 | |
132 | 132 | def as_nested_sql(self): |
… |
… |
class SQLCompiler(object):
|
158 | 158 | qn = self.quote_name_unless_alias |
159 | 159 | qn2 = self.connection.ops.quote_name |
160 | 160 | result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] |
| 161 | query_params = [] |
161 | 162 | aliases = set(self.query.extra_select.keys()) |
162 | 163 | if with_aliases: |
163 | 164 | col_aliases = aliases.copy() |
… |
… |
class SQLCompiler(object):
|
200 | 201 | aliases.update(new_aliases) |
201 | 202 | |
202 | 203 | max_name_length = self.connection.ops.max_name_length() |
203 | | result.extend([ |
204 | | '%s%s' % ( |
205 | | aggregate.as_sql(qn, self.connection), |
206 | | alias is not None |
207 | | and ' AS %s' % qn(truncate_name(alias, max_name_length)) |
208 | | or '' |
| 204 | for alias, aggregate in self.query.aggregate_select.items(): |
| 205 | sql, params = aggregate.as_sql(qn, self.connection) |
| 206 | result.append( |
| 207 | '%s%s' % ( |
| 208 | sql, |
| 209 | alias is not None |
| 210 | and ' AS %s' % qn(truncate_name(alias, max_name_length)) |
| 211 | or '' |
| 212 | ) |
209 | 213 | ) |
210 | | for alias, aggregate in self.query.aggregate_select.items() |
211 | | ]) |
| 214 | query_params.extend(params) |
212 | 215 | |
213 | 216 | for table, col in self.query.related_select_cols: |
214 | 217 | r = '%s.%s' % (qn(table), qn(col)) |
… |
… |
class SQLCompiler(object):
|
223 | 226 | col_aliases.add(col) |
224 | 227 | |
225 | 228 | self._select_aliases = aliases |
226 | | return result |
| 229 | return result, query_params |
227 | 230 | |
228 | 231 | def get_default_columns(self, with_aliases=False, col_aliases=None, |
229 | 232 | start_alias=None, opts=None, as_pairs=False, local_only=False): |
… |
… |
class SQLAggregateCompiler(SQLCompiler):
|
948 | 951 | """ |
949 | 952 | if qn is None: |
950 | 953 | qn = self.quote_name_unless_alias |
| 954 | buf = [] |
| 955 | a_params = [] |
| 956 | for aggregate in self.query.aggregate_select.values(): |
| 957 | sql, query_params = aggregate.as_sql(qn, self.connection) |
| 958 | buf.append(sql) |
| 959 | a_params.extend(query_params) |
| 960 | aggregate_sql = ', '.join(buf) |
951 | 961 | sql = ('SELECT %s FROM (%s) subquery' % ( |
952 | | ', '.join([ |
953 | | aggregate.as_sql(qn, self.connection) |
954 | | for aggregate in self.query.aggregate_select.values() |
955 | | ]), |
| 962 | aggregate_sql, |
956 | 963 | self.query.subquery) |
957 | 964 | ) |
958 | | params = self.query.sub_params |
| 965 | params = tuple(a_params) + (self.query.sub_params) |
959 | 966 | return (sql, params) |
960 | 967 | |
961 | 968 | class SQLDateCompiler(SQLCompiler): |
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
index 1bbf742..f9c23a9 100644
a
|
b
|
class SQLEvaluator(object):
|
65 | 65 | for child in node.children: |
66 | 66 | if hasattr(child, 'evaluate'): |
67 | 67 | sql, params = child.evaluate(self, qn, connection) |
| 68 | if isinstance(sql, tuple): |
| 69 | expression_params.extend(sql[1]) |
| 70 | sql = sql[0] |
68 | 71 | else: |
69 | 72 | sql, params = '%s', (child,) |
70 | 73 | |
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 110e317..7735261 100644
a
|
b
|
class Query(object):
|
956 | 956 | """ |
957 | 957 | opts = model._meta |
958 | 958 | field_list = aggregate.lookup.split(LOOKUP_SEP) |
| 959 | only = aggregate.only |
| 960 | join_list = [] |
959 | 961 | if len(field_list) == 1 and aggregate.lookup in self.aggregates: |
960 | 962 | # Aggregate is over an annotation |
961 | 963 | field_name = field_list[0] |
… |
… |
class Query(object):
|
964 | 966 | if not is_summary: |
965 | 967 | raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( |
966 | 968 | aggregate.name, field_name, field_name)) |
| 969 | if only: |
| 970 | raise FieldError("Cannot use aggregated fields in conditional aggregates") |
967 | 971 | elif ((len(field_list) > 1) or |
968 | 972 | (field_list[0] not in [i.name for i in opts.fields]) or |
969 | 973 | self.group_by is None or |
… |
… |
class Query(object):
|
995 | 999 | source = opts.get_field(field_name) |
996 | 1000 | col = field_name |
997 | 1001 | |
| 1002 | if only: |
| 1003 | original_where = self.where |
| 1004 | original_having = self.having |
| 1005 | aggregate.condition = self.where_class() |
| 1006 | self.where = aggregate.condition |
| 1007 | self.having = self.where_class() |
| 1008 | original_alias_map = self.alias_map.keys()[:] |
| 1009 | self.add_q(only, used_aliases=set(original_alias_map)) |
| 1010 | if original_alias_map != self.alias_map.keys(): |
| 1011 | raise FieldError("Aggregate's only condition can not require additional joins, Original joins: %s, joins after: %s" % (original_alias_map, self.alias_map.keys())) |
| 1012 | if self.having.children: |
| 1013 | raise FieldError("Aggregate's only condition can not reference annotated fields") |
| 1014 | self.having = original_having |
| 1015 | self.where = original_where |
998 | 1016 | # Add the aggregate to the query |
999 | 1017 | aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) |
1000 | 1018 | |
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 2427a52..7ac0287 100644
a
|
b
|
class WhereNode(tree.Node):
|
137 | 137 | it. |
138 | 138 | """ |
139 | 139 | lvalue, lookup_type, value_annot, params_or_value = child |
| 140 | additional_params = [] |
140 | 141 | if hasattr(lvalue, 'process'): |
141 | 142 | try: |
142 | 143 | lvalue, params = lvalue.process(lookup_type, params_or_value, connection) |
… |
… |
class WhereNode(tree.Node):
|
151 | 152 | else: |
152 | 153 | # A smart object with an as_sql() method. |
153 | 154 | field_sql = lvalue.as_sql(qn, connection) |
| 155 | if isinstance(field_sql, tuple): |
| 156 | # It returned also params |
| 157 | additional_params.extend(field_sql[1]) |
| 158 | field_sql = field_sql[0] |
154 | 159 | |
155 | 160 | if value_annot is datetime.datetime: |
156 | 161 | cast_sql = connection.ops.datetime_cast_sql() |
… |
… |
class WhereNode(tree.Node):
|
159 | 164 | |
160 | 165 | if hasattr(params, 'as_sql'): |
161 | 166 | extra, params = params.as_sql(qn, connection) |
| 167 | if isinstance(extra, tuple): |
| 168 | params = params + tuple(extra[1]) |
| 169 | extra = extra[0] |
162 | 170 | cast_sql = '' |
163 | 171 | else: |
164 | 172 | extra = '' |
… |
… |
class WhereNode(tree.Node):
|
168 | 176 | lookup_type = 'isnull' |
169 | 177 | value_annot = True |
170 | 178 | |
| 179 | additional_params.extend(params) |
| 180 | params = additional_params |
171 | 181 | if lookup_type in connection.operators: |
172 | 182 | format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) |
173 | 183 | return (format % (field_sql, |
diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
index 6f68800..3a04fa9 100644
a
|
b
|
import datetime
|
2 | 2 | from decimal import Decimal |
3 | 3 | |
4 | 4 | from django.db.models import Avg, Sum, Count, Max, Min |
| 5 | from django.db.models import Q, F |
| 6 | from django.core.exceptions import FieldError |
5 | 7 | from django.test import TestCase, Approximate |
6 | 8 | |
7 | 9 | from models import Author, Publisher, Book, Store |
… |
… |
class BaseAggregateTestCase(TestCase):
|
16 | 18 | def test_single_aggregate(self): |
17 | 19 | vals = Author.objects.aggregate(Avg("age")) |
18 | 20 | self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) |
| 21 | vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29))) |
| 22 | self.assertEqual(vals, {"age__sum": 254}) |
| 23 | vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian'))) |
| 24 | self.assertEqual(vals, {"age__sum": 69}) |
19 | 25 | |
20 | 26 | def test_multiple_aggregates(self): |
21 | 27 | vals = Author.objects.aggregate(Sum("age"), Avg("age")) |
22 | 28 | self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) |
| 29 | vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age")) |
| 30 | self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)}) |
23 | 31 | |
24 | 32 | def test_filter_aggregate(self): |
25 | 33 | vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) |
26 | 34 | self.assertEqual(len(vals), 1) |
27 | 35 | self.assertEqual(vals["age__sum"], 254) |
| 36 | vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29))) |
| 37 | # If there are no matching aggregates, then None, not 0 is the answer. |
| 38 | self.assertEqual(vals["age__sum"], None) |
28 | 39 | |
29 | 40 | def test_related_aggregate(self): |
30 | 41 | vals = Author.objects.aggregate(Avg("friends__age")) |
31 | 42 | self.assertEqual(len(vals), 1) |
32 | 43 | self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2) |
33 | 44 | |
| 45 | vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29))) |
| 46 | self.assertEqual(len(vals), 1) |
| 47 | self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2) |
| 48 | vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age")) |
| 49 | self.assertEqual(vals, vals2) |
| 50 | |
| 51 | vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35))) |
| 52 | self.assertEqual(len(vals), 1) |
| 53 | self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2) |
| 54 | |
| 55 | # The average age of author's friends, whose age is lower than the authors age. |
| 56 | vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age')))) |
| 57 | self.assertEqual(len(vals), 1) |
| 58 | self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2) |
| 59 | |
34 | 60 | vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age")) |
35 | 61 | self.assertEqual(len(vals), 1) |
36 | 62 | self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2) |
… |
… |
class BaseAggregateTestCase(TestCase):
|
51 | 77 | vals = Store.objects.aggregate(Max("books__authors__age")) |
52 | 78 | self.assertEqual(len(vals), 1) |
53 | 79 | self.assertEqual(vals["books__authors__age__max"], 57) |
| 80 | |
| 81 | vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56))) |
| 82 | self.assertEqual(len(vals), 1) |
| 83 | self.assertEqual(vals["books__authors__age__max"], 46) |
54 | 84 | |
55 | 85 | vals = Author.objects.aggregate(Min("book__publisher__num_awards")) |
56 | 86 | self.assertEqual(len(vals), 1) |
… |
… |
class BaseAggregateTestCase(TestCase):
|
106 | 136 | ], |
107 | 137 | lambda b: (b.name, b.num_authors) |
108 | 138 | ) |
| 139 | |
| 140 | def raises_exception(): |
| 141 | list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name")) |
| 142 | |
| 143 | self.assertRaises(FieldError, raises_exception) |
109 | 144 | |
110 | 145 | def test_backwards_m2m_annotate(self): |
111 | 146 | authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name") |
… |
… |
class BaseAggregateTestCase(TestCase):
|
192 | 227 | } |
193 | 228 | ] |
194 | 229 | ) |
| 230 | books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age') |
| 231 | self.assertEqual( |
| 232 | list(books), [ |
| 233 | { |
| 234 | "pk": 1, |
| 235 | "isbn": "159059725", |
| 236 | "mean_age": 34.0, |
| 237 | } |
| 238 | ] |
| 239 | ) |
195 | 240 | |
196 | 241 | books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name") |
197 | 242 | self.assertEqual( |
… |
… |
class BaseAggregateTestCase(TestCase):
|
269 | 314 | |
270 | 315 | vals = Book.objects.aggregate(Count("rating", distinct=True)) |
271 | 316 | self.assertEqual(vals, {"rating__count": 4}) |
| 317 | vals = Book.objects.aggregate( |
| 318 | low_count=Count("rating", only=Q(rating__lt=4)), |
| 319 | high_count=Count("rating", only=Q(rating__gte=4)) |
| 320 | ) |
| 321 | self.assertEqual(vals, {"low_count": 1, 'high_count': 5}) |
| 322 | vals = Book.objects.aggregate( |
| 323 | low_count=Count("rating", distinct=True, only=Q(rating__lt=4)), |
| 324 | high_count=Count("rating", distinct=True, only=Q(rating__gte=4)) |
| 325 | ) |
| 326 | self.assertEqual(vals, {"low_count": 1, 'high_count': 3}) |
272 | 327 | |
273 | 328 | def test_fkey_aggregate(self): |
274 | 329 | explicit = list(Author.objects.annotate(Count('book__id'))) |
… |
… |
class BaseAggregateTestCase(TestCase):
|
388 | 443 | ], |
389 | 444 | lambda p: p.name, |
390 | 445 | ) |
| 446 | publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk") |
| 447 | self.assertQuerysetEqual( |
| 448 | publishers, [ |
| 449 | "Expensive Publisher", |
| 450 | ], |
| 451 | lambda p: p.name, |
| 452 | ) |
391 | 453 | |
392 | 454 | publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") |
393 | 455 | self.assertQuerysetEqual( |