diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index a2349cf..3d8b63a5 100644
|
a
|
b
|
class Aggregate(object):
|
| 6 | 6 | """ |
| 7 | 7 | Default Aggregate definition. |
| 8 | 8 | """ |
| 9 | | def __init__(self, lookup, **extra): |
| | 9 | def __init__(self, lookup, only=None, **extra): |
| 10 | 10 | """Instantiate a new aggregate. |
| 11 | 11 | |
| 12 | 12 | * lookup is the field on which the aggregate operates. |
| | 13 | * only is a Q-object used in conditional aggregation. |
| 13 | 14 | * extra is a dictionary of additional data to provide for the |
| 14 | 15 | aggregate definition |
| 15 | 16 | |
| … |
… |
class Aggregate(object):
|
| 18 | 19 | """ |
| 19 | 20 | self.lookup = lookup |
| 20 | 21 | self.extra = extra |
| | 22 | self.only = only |
| | 23 | self.condition = None |
| 21 | 24 | |
| 22 | 25 | def _default_alias(self): |
| 23 | 26 | return '%s__%s' % (self.lookup, self.name.lower()) |
| … |
… |
class Aggregate(object):
|
| 42 | 45 | summary value rather than an annotation. |
| 43 | 46 | """ |
| 44 | 47 | klass = getattr(query.aggregates_module, self.name) |
| 45 | | aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) |
| | 48 | aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra) |
| 46 | 49 | query.aggregates[alias] = aggregate |
| 47 | 50 | |
| 48 | 51 | class Avg(Aggregate): |
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
index 207bc0c..0de19a3 100644
|
a
|
b
|
class Aggregate(object):
|
| 22 | 22 | is_ordinal = False |
| 23 | 23 | is_computed = False |
| 24 | 24 | sql_template = '%(function)s(%(field)s)' |
| | 25 | conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END" |
| 25 | 26 | |
| 26 | | def __init__(self, col, source=None, is_summary=False, **extra): |
| | 27 | def __init__(self, col, source=None, is_summary=False, condition=None, **extra): |
| 27 | 28 | """Instantiate an SQL aggregate |
| 28 | 29 | |
| 29 | 30 | * col is a column reference describing the subject field |
| … |
… |
class Aggregate(object):
|
| 52 | 53 | self.source = source |
| 53 | 54 | self.is_summary = is_summary |
| 54 | 55 | self.extra = extra |
| | 56 | self.condition = condition |
| 55 | 57 | |
| 56 | 58 | # Follow the chain of aggregate sources back until you find an |
| 57 | 59 | # actual field, or an aggregate that forces a particular output |
| … |
… |
class Aggregate(object):
|
| 76 | 78 | def as_sql(self, qn, connection): |
| 77 | 79 | "Return the aggregate, rendered as SQL." |
| 78 | 80 | |
| | 81 | query_params = [] |
| 79 | 82 | if hasattr(self.col, 'as_sql'): |
| 80 | 83 | field_name = self.col.as_sql(qn, connection) |
| 81 | 84 | elif isinstance(self.col, (list, tuple)): |
| 82 | 85 | field_name = '.'.join([qn(c) for c in self.col]) |
| 83 | 86 | else: |
| 84 | 87 | field_name = self.col |
| 85 | | |
| 86 | | params = { |
| 87 | | 'function': self.sql_function, |
| 88 | | 'field': field_name |
| 89 | | } |
| | 88 | if self.condition: |
| | 89 | condition = self.condition.as_sql(qn, connection) |
| | 90 | query_params = condition[1] |
| | 91 | conditional_field = self.conditional_template % { |
| | 92 | 'condition': condition[0], |
| | 93 | 'field_name': field_name |
| | 94 | } |
| | 95 | params = { |
| | 96 | 'function': self.sql_function, |
| | 97 | 'field': conditional_field, |
| | 98 | } |
| | 99 | else: |
| | 100 | params = { |
| | 101 | 'function': self.sql_function, |
| | 102 | 'field': field_name |
| | 103 | } |
| 90 | 104 | params.update(self.extra) |
| 91 | 105 | |
| 92 | | return self.sql_template % params |
| | 106 | return (self.sql_template % params, query_params) |
| 93 | 107 | |
| 94 | 108 | |
| 95 | 109 | class Avg(Aggregate): |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 841ec12..a0d2f35 100644
|
a
|
b
|
class SQLCompiler(object):
|
| 58 | 58 | return '', () |
| 59 | 59 | |
| 60 | 60 | self.pre_sql_setup() |
| 61 | | out_cols = self.get_columns(with_col_aliases) |
| | 61 | out_cols, c_params = self.get_columns(with_col_aliases) |
| 62 | 62 | ordering, ordering_group_by = self.get_ordering() |
| | 63 | params = [] |
| | 64 | params.extend(c_params) |
| 63 | 65 | |
| 64 | 66 | # This must come after 'select' and 'ordering' -- see docstring of |
| 65 | 67 | # get_from_clause() for details. |
| … |
… |
class SQLCompiler(object):
|
| 69 | 71 | |
| 70 | 72 | where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) |
| 71 | 73 | having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) |
| 72 | | params = [] |
| 73 | 74 | for val in self.query.extra_select.itervalues(): |
| 74 | 75 | params.extend(val[1]) |
| 75 | 76 | |
| … |
… |
class SQLCompiler(object):
|
| 126 | 127 | if nowait and not self.connection.features.has_select_for_update_nowait: |
| 127 | 128 | raise DatabaseError('NOWAIT is not supported on this database backend.') |
| 128 | 129 | result.append(self.connection.ops.for_update_sql(nowait=nowait)) |
| 129 | | |
| 130 | 130 | return ' '.join(result), tuple(params) |
| 131 | 131 | |
| 132 | 132 | def as_nested_sql(self): |
| … |
… |
class SQLCompiler(object):
|
| 158 | 158 | qn = self.quote_name_unless_alias |
| 159 | 159 | qn2 = self.connection.ops.quote_name |
| 160 | 160 | result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] |
| | 161 | query_params = [] |
| 161 | 162 | aliases = set(self.query.extra_select.keys()) |
| 162 | 163 | if with_aliases: |
| 163 | 164 | col_aliases = aliases.copy() |
| … |
… |
class SQLCompiler(object):
|
| 200 | 201 | aliases.update(new_aliases) |
| 201 | 202 | |
| 202 | 203 | max_name_length = self.connection.ops.max_name_length() |
| 203 | | result.extend([ |
| 204 | | '%s%s' % ( |
| 205 | | aggregate.as_sql(qn, self.connection), |
| 206 | | alias is not None |
| 207 | | and ' AS %s' % qn(truncate_name(alias, max_name_length)) |
| 208 | | or '' |
| | 204 | for alias, aggregate in self.query.aggregate_select.items(): |
| | 205 | sql, params = aggregate.as_sql(qn, self.connection) |
| | 206 | result.append( |
| | 207 | '%s%s' % ( |
| | 208 | sql, |
| | 209 | alias is not None |
| | 210 | and ' AS %s' % qn(truncate_name(alias, max_name_length)) |
| | 211 | or '' |
| | 212 | ) |
| 209 | 213 | ) |
| 210 | | for alias, aggregate in self.query.aggregate_select.items() |
| 211 | | ]) |
| | 214 | query_params.extend(params) |
| 212 | 215 | |
| 213 | 216 | for table, col in self.query.related_select_cols: |
| 214 | 217 | r = '%s.%s' % (qn(table), qn(col)) |
| … |
… |
class SQLCompiler(object):
|
| 223 | 226 | col_aliases.add(col) |
| 224 | 227 | |
| 225 | 228 | self._select_aliases = aliases |
| 226 | | return result |
| | 229 | return result, query_params |
| 227 | 230 | |
| 228 | 231 | def get_default_columns(self, with_aliases=False, col_aliases=None, |
| 229 | 232 | start_alias=None, opts=None, as_pairs=False, local_only=False): |
| … |
… |
class SQLAggregateCompiler(SQLCompiler):
|
| 948 | 951 | """ |
| 949 | 952 | if qn is None: |
| 950 | 953 | qn = self.quote_name_unless_alias |
| | 954 | buf = [] |
| | 955 | a_params = [] |
| | 956 | for aggregate in self.query.aggregate_select.values(): |
| | 957 | sql, query_params = aggregate.as_sql(qn, self.connection) |
| | 958 | buf.append(sql) |
| | 959 | a_params.extend(query_params) |
| | 960 | aggregate_sql = ', '.join(buf) |
| 951 | 961 | sql = ('SELECT %s FROM (%s) subquery' % ( |
| 952 | | ', '.join([ |
| 953 | | aggregate.as_sql(qn, self.connection) |
| 954 | | for aggregate in self.query.aggregate_select.values() |
| 955 | | ]), |
| | 962 | aggregate_sql, |
| 956 | 963 | self.query.subquery) |
| 957 | 964 | ) |
| 958 | | params = self.query.sub_params |
| | 965 | params = tuple(a_params) + (self.query.sub_params) |
| 959 | 966 | return (sql, params) |
| 960 | 967 | |
| 961 | 968 | class SQLDateCompiler(SQLCompiler): |
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
index 1bbf742..f9c23a9 100644
|
a
|
b
|
class SQLEvaluator(object):
|
| 65 | 65 | for child in node.children: |
| 66 | 66 | if hasattr(child, 'evaluate'): |
| 67 | 67 | sql, params = child.evaluate(self, qn, connection) |
| | 68 | if isinstance(sql, tuple): |
| | 69 | expression_params.extend(sql[1]) |
| | 70 | sql = sql[0] |
| 68 | 71 | else: |
| 69 | 72 | sql, params = '%s', (child,) |
| 70 | 73 | |
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 110e317..7735261 100644
|
a
|
b
|
class Query(object):
|
| 956 | 956 | """ |
| 957 | 957 | opts = model._meta |
| 958 | 958 | field_list = aggregate.lookup.split(LOOKUP_SEP) |
| | 959 | only = aggregate.only |
| | 960 | join_list = [] |
| 959 | 961 | if len(field_list) == 1 and aggregate.lookup in self.aggregates: |
| 960 | 962 | # Aggregate is over an annotation |
| 961 | 963 | field_name = field_list[0] |
| … |
… |
class Query(object):
|
| 964 | 966 | if not is_summary: |
| 965 | 967 | raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( |
| 966 | 968 | aggregate.name, field_name, field_name)) |
| | 969 | if only: |
| | 970 | raise FieldError("Cannot use aggregated fields in conditional aggregates") |
| 967 | 971 | elif ((len(field_list) > 1) or |
| 968 | 972 | (field_list[0] not in [i.name for i in opts.fields]) or |
| 969 | 973 | self.group_by is None or |
| … |
… |
class Query(object):
|
| 995 | 999 | source = opts.get_field(field_name) |
| 996 | 1000 | col = field_name |
| 997 | 1001 | |
| | 1002 | if only: |
| | 1003 | original_where = self.where |
| | 1004 | original_having = self.having |
| | 1005 | aggregate.condition = self.where_class() |
| | 1006 | self.where = aggregate.condition |
| | 1007 | self.having = self.where_class() |
| | 1008 | original_alias_map = self.alias_map.keys()[:] |
| | 1009 | self.add_q(only, used_aliases=set(original_alias_map)) |
| | 1010 | if original_alias_map != self.alias_map.keys(): |
| | 1011 | raise FieldError("Aggregate's only condition can not require additional joins, Original joins: %s, joins after: %s" % (original_alias_map, self.alias_map.keys())) |
| | 1012 | if self.having.children: |
| | 1013 | raise FieldError("Aggregate's only condition can not reference annotated fields") |
| | 1014 | self.having = original_having |
| | 1015 | self.where = original_where |
| 998 | 1016 | # Add the aggregate to the query |
| 999 | 1017 | aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) |
| 1000 | 1018 | |
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 2427a52..7ac0287 100644
|
a
|
b
|
class WhereNode(tree.Node):
|
| 137 | 137 | it. |
| 138 | 138 | """ |
| 139 | 139 | lvalue, lookup_type, value_annot, params_or_value = child |
| | 140 | additional_params = [] |
| 140 | 141 | if hasattr(lvalue, 'process'): |
| 141 | 142 | try: |
| 142 | 143 | lvalue, params = lvalue.process(lookup_type, params_or_value, connection) |
| … |
… |
class WhereNode(tree.Node):
|
| 151 | 152 | else: |
| 152 | 153 | # A smart object with an as_sql() method. |
| 153 | 154 | field_sql = lvalue.as_sql(qn, connection) |
| | 155 | if isinstance(field_sql, tuple): |
| | 156 | # It returned also params |
| | 157 | additional_params.extend(field_sql[1]) |
| | 158 | field_sql = field_sql[0] |
| 154 | 159 | |
| 155 | 160 | if value_annot is datetime.datetime: |
| 156 | 161 | cast_sql = connection.ops.datetime_cast_sql() |
| … |
… |
class WhereNode(tree.Node):
|
| 159 | 164 | |
| 160 | 165 | if hasattr(params, 'as_sql'): |
| 161 | 166 | extra, params = params.as_sql(qn, connection) |
| | 167 | if isinstance(extra, tuple): |
| | 168 | params = params + tuple(extra[1]) |
| | 169 | extra = extra[0] |
| 162 | 170 | cast_sql = '' |
| 163 | 171 | else: |
| 164 | 172 | extra = '' |
| … |
… |
class WhereNode(tree.Node):
|
| 168 | 176 | lookup_type = 'isnull' |
| 169 | 177 | value_annot = True |
| 170 | 178 | |
| | 179 | additional_params.extend(params) |
| | 180 | params = additional_params |
| 171 | 181 | if lookup_type in connection.operators: |
| 172 | 182 | format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) |
| 173 | 183 | return (format % (field_sql, |
diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
index 6f68800..3a04fa9 100644
|
a
|
b
|
import datetime
|
| 2 | 2 | from decimal import Decimal |
| 3 | 3 | |
| 4 | 4 | from django.db.models import Avg, Sum, Count, Max, Min |
| | 5 | from django.db.models import Q, F |
| | 6 | from django.core.exceptions import FieldError |
| 5 | 7 | from django.test import TestCase, Approximate |
| 6 | 8 | |
| 7 | 9 | from models import Author, Publisher, Book, Store |
| … |
… |
class BaseAggregateTestCase(TestCase):
|
| 16 | 18 | def test_single_aggregate(self): |
| 17 | 19 | vals = Author.objects.aggregate(Avg("age")) |
| 18 | 20 | self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) |
| | 21 | vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29))) |
| | 22 | self.assertEqual(vals, {"age__sum": 254}) |
| | 23 | vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian'))) |
| | 24 | self.assertEqual(vals, {"age__sum": 69}) |
| 19 | 25 | |
| 20 | 26 | def test_multiple_aggregates(self): |
| 21 | 27 | vals = Author.objects.aggregate(Sum("age"), Avg("age")) |
| 22 | 28 | self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) |
| | 29 | vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age")) |
| | 30 | self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)}) |
| 23 | 31 | |
| 24 | 32 | def test_filter_aggregate(self): |
| 25 | 33 | vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) |
| 26 | 34 | self.assertEqual(len(vals), 1) |
| 27 | 35 | self.assertEqual(vals["age__sum"], 254) |
| | 36 | vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29))) |
| | 37 | # If there are no matching aggregates, then None, not 0 is the answer. |
| | 38 | self.assertEqual(vals["age__sum"], None) |
| 28 | 39 | |
| 29 | 40 | def test_related_aggregate(self): |
| 30 | 41 | vals = Author.objects.aggregate(Avg("friends__age")) |
| 31 | 42 | self.assertEqual(len(vals), 1) |
| 32 | 43 | self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2) |
| 33 | 44 | |
| | 45 | vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29))) |
| | 46 | self.assertEqual(len(vals), 1) |
| | 47 | self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2) |
| | 48 | vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age")) |
| | 49 | self.assertEqual(vals, vals2) |
| | 50 | |
| | 51 | vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35))) |
| | 52 | self.assertEqual(len(vals), 1) |
| | 53 | self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2) |
| | 54 | |
| | 55 | # The average age of author's friends, whose age is lower than the authors age. |
| | 56 | vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age')))) |
| | 57 | self.assertEqual(len(vals), 1) |
| | 58 | self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2) |
| | 59 | |
| 34 | 60 | vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age")) |
| 35 | 61 | self.assertEqual(len(vals), 1) |
| 36 | 62 | self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2) |
| … |
… |
class BaseAggregateTestCase(TestCase):
|
| 51 | 77 | vals = Store.objects.aggregate(Max("books__authors__age")) |
| 52 | 78 | self.assertEqual(len(vals), 1) |
| 53 | 79 | self.assertEqual(vals["books__authors__age__max"], 57) |
| | 80 | |
| | 81 | vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56))) |
| | 82 | self.assertEqual(len(vals), 1) |
| | 83 | self.assertEqual(vals["books__authors__age__max"], 46) |
| 54 | 84 | |
| 55 | 85 | vals = Author.objects.aggregate(Min("book__publisher__num_awards")) |
| 56 | 86 | self.assertEqual(len(vals), 1) |
| … |
… |
class BaseAggregateTestCase(TestCase):
|
| 106 | 136 | ], |
| 107 | 137 | lambda b: (b.name, b.num_authors) |
| 108 | 138 | ) |
| | 139 | |
| | 140 | def raises_exception(): |
| | 141 | list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name")) |
| | 142 | |
| | 143 | self.assertRaises(FieldError, raises_exception) |
| 109 | 144 | |
| 110 | 145 | def test_backwards_m2m_annotate(self): |
| 111 | 146 | authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name") |
| … |
… |
class BaseAggregateTestCase(TestCase):
|
| 192 | 227 | } |
| 193 | 228 | ] |
| 194 | 229 | ) |
| | 230 | books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age') |
| | 231 | self.assertEqual( |
| | 232 | list(books), [ |
| | 233 | { |
| | 234 | "pk": 1, |
| | 235 | "isbn": "159059725", |
| | 236 | "mean_age": 34.0, |
| | 237 | } |
| | 238 | ] |
| | 239 | ) |
| 195 | 240 | |
| 196 | 241 | books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name") |
| 197 | 242 | self.assertEqual( |
| … |
… |
class BaseAggregateTestCase(TestCase):
|
| 269 | 314 | |
| 270 | 315 | vals = Book.objects.aggregate(Count("rating", distinct=True)) |
| 271 | 316 | self.assertEqual(vals, {"rating__count": 4}) |
| | 317 | vals = Book.objects.aggregate( |
| | 318 | low_count=Count("rating", only=Q(rating__lt=4)), |
| | 319 | high_count=Count("rating", only=Q(rating__gte=4)) |
| | 320 | ) |
| | 321 | self.assertEqual(vals, {"low_count": 1, 'high_count': 5}) |
| | 322 | vals = Book.objects.aggregate( |
| | 323 | low_count=Count("rating", distinct=True, only=Q(rating__lt=4)), |
| | 324 | high_count=Count("rating", distinct=True, only=Q(rating__gte=4)) |
| | 325 | ) |
| | 326 | self.assertEqual(vals, {"low_count": 1, 'high_count': 3}) |
| 272 | 327 | |
| 273 | 328 | def test_fkey_aggregate(self): |
| 274 | 329 | explicit = list(Author.objects.annotate(Count('book__id'))) |
| … |
… |
class BaseAggregateTestCase(TestCase):
|
| 388 | 443 | ], |
| 389 | 444 | lambda p: p.name, |
| 390 | 445 | ) |
| | 446 | publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk") |
| | 447 | self.assertQuerysetEqual( |
| | 448 | publishers, [ |
| | 449 | "Expensive Publisher", |
| | 450 | ], |
| | 451 | lambda p: p.name, |
| | 452 | ) |
| 391 | 453 | |
| 392 | 454 | publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") |
| 393 | 455 | self.assertQuerysetEqual( |