Ticket #11305: conditional_agg.patch
File conditional_agg.patch, 29.0 KB (added by , 12 years ago) |
---|
-
django/db/models/aggregates.py
From e3654e9da731fedca85bb31e4fa1afe252f88268 Mon Sep 17 00:00:00 2001 From: Andre Terra <andreterra@gmail.com> Date: Mon, 28 May 2012 18:22:27 -0300 Subject: [PATCH] Added conditional aggregates. --- django/db/models/aggregates.py | 9 ++- django/db/models/expressions.py | 8 +- django/db/models/sql/aggregates.py | 43 +++++++++++--- django/db/models/sql/compiler.py | 39 ++++++++----- django/db/models/sql/expressions.py | 16 ++++-- django/db/models/sql/query.py | 103 ++++++++++++++++++++------------ django/db/models/sql/where.py | 10 +++ tests/modeltests/aggregation/tests.py | 93 +++++++++++++++++++++++++++++ 8 files changed, 247 insertions(+), 74 deletions(-) mode change 100644 => 100755 django/db/models/sql/aggregates.py mode change 100644 => 100755 django/db/models/sql/where.py diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index a2349cf..d816aa7 100644
a b class Aggregate(object): 6 6 """ 7 7 Default Aggregate definition. 8 8 """ 9 def __init__(self, lookup, **extra):9 def __init__(self, lookup, only=None, **extra): 10 10 """Instantiate a new aggregate. 11 11 12 12 * lookup is the field on which the aggregate operates. 13 * only is a Q-object used in conditional aggregation. 13 14 * extra is a dictionary of additional data to provide for the 14 15 aggregate definition 15 16 … … class Aggregate(object): 18 19 """ 19 20 self.lookup = lookup 20 21 self.extra = extra 22 self.only = only 23 self.condition = None 21 24 22 25 def _default_alias(self): 26 if hasattr(self.lookup, 'evaluate'): 27 raise ValueError('When aggregating over an expression, you need to give an alias.') 23 28 return '%s__%s' % (self.lookup, self.name.lower()) 24 29 default_alias = property(_default_alias) 25 30 … … class Aggregate(object): 42 47 summary value rather than an annotation. 43 48 """ 44 49 klass = getattr(query.aggregates_module, self.name) 45 aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)50 aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra) 46 51 query.aggregates[alias] = aggregate 47 52 48 53 class Avg(Aggregate): -
django/db/models/expressions.py
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index a71f4a3..390f475 100644
a b class ExpressionNode(tree.Node): 39 39 # VISITOR METHODS # 40 40 ################### 41 41 42 def prepare(self, evaluator, query, allow_joins ):43 return evaluator.prepare_node(self, query, allow_joins )42 def prepare(self, evaluator, query, allow_joins, promote_joins=False): 43 return evaluator.prepare_node(self, query, allow_joins, promote_joins) 44 44 45 45 def evaluate(self, evaluator, qn, connection): 46 46 return evaluator.evaluate_node(self, qn, connection) … … class F(ExpressionNode): 107 107 obj.name = self.name 108 108 return obj 109 109 110 def prepare(self, evaluator, query, allow_joins ):111 return evaluator.prepare_leaf(self, query, allow_joins )110 def prepare(self, evaluator, query, allow_joins, promote_joins=False): 111 return evaluator.prepare_leaf(self, query, allow_joins, promote_joins) 112 112 113 113 def evaluate(self, evaluator, qn, connection): 114 114 return evaluator.evaluate_leaf(self, qn, connection) -
django/db/models/sql/aggregates.py
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py old mode 100644 new mode 100755 index b41314a..5fe2215
a b Classes to represent the default SQL aggregate functions 3 3 """ 4 4 5 5 from django.db.models.fields import IntegerField, FloatField 6 from django.db.models.sql.expressions import SQLEvaluator 6 7 7 8 # Fake fields used to identify aggregate types in data-conversion operations. 8 9 ordinal_aggregate_field = IntegerField() … … class Aggregate(object): 15 16 is_ordinal = False 16 17 is_computed = False 17 18 sql_template = '%(function)s(%(field)s)' 19 conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END" 18 20 19 def __init__(self, col, source=None, is_summary=False, **extra):21 def __init__(self, col, source=None, is_summary=False, condition=None, **extra): 20 22 """Instantiate an SQL aggregate 21 23 22 24 * col is a column reference describing the subject field … … class Aggregate(object): 26 28 the column reference. If the aggregate is not an ordinal or 27 29 computed type, this reference is used to determine the coerced 28 30 output type of the aggregate. 31 * condition is used in conditional aggregation. 29 32 * extra is a dictionary of additional data to provide for the 30 aggregate definition 33 aggregate definition. 31 34 32 35 Also utilizes the class variables: 33 36 * sql_function, the name of the SQL function that implements the … … class Aggregate(object): 35 38 * sql_template, a template string that is used to render the 36 39 aggregate into SQL. 37 40 * is_ordinal, a boolean indicating if the output of this aggregate 38 is an integer (e.g., a count) 41 is an integer (e.g., a count). 39 42 * is_computed, a boolean indicating if this output of this aggregate 40 43 is a computed float (e.g., an average), regardless of the input 41 44 type. … … class Aggregate(object): 45 48 self.source = source 46 49 self.is_summary = is_summary 47 50 self.extra = extra 51 self.condition = condition 48 52 49 53 # Follow the chain of aggregate sources back until you find an 50 54 # actual field, or an aggregate that forces a particular output … … class Aggregate(object): 65 69 def relabel_aliases(self, change_map): 66 70 if isinstance(self.col, (list, tuple)): 67 71 self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) 72 else: 73 self.col.relabel_aliases(change_map) 74 if self.condition: 75 self.condition.relabel_aliases(change_map) 68 76 69 77 def as_sql(self, qn, connection): 70 78 "Return the aggregate, rendered as SQL." 71 79 80 condition_params = [] 81 col_params = [] 72 82 if hasattr(self.col, 'as_sql'): 73 field_name = self.col.as_sql(qn, connection) 83 if isinstance(self.col, SQLEvaluator): 84 field_name, col_params = self.col.as_sql(qn, connection) 85 else: 86 field_name = self.col.as_sql(qn, connection) 74 87 elif isinstance(self.col, (list, tuple)): 75 88 field_name = '.'.join([qn(c) for c in self.col]) 76 89 else: 77 90 field_name = self.col 78 91 79 params = { 80 'function': self.sql_function, 81 'field': field_name 82 } 92 if self.condition: 93 condition, condition_params = self.condition.as_sql(qn, connection) 94 conditional_field = self.conditional_template % { 95 'condition': condition, 96 'field_name': field_name 97 } 98 params = { 99 'function': self.sql_function, 100 'field': conditional_field, 101 } 102 else: 103 params = { 104 'function': self.sql_function, 105 'field': field_name 106 } 83 107 params.update(self.extra) 84 108 85 return self.sql_template % params 109 condition_params.extend(col_params) 110 return (self.sql_template % params, condition_params) 86 111 87 112 88 113 class Avg(Aggregate): -
django/db/models/sql/compiler.py
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 5801b2f..fa1b15a 100644
a b class SQLCompiler(object): 68 68 # as the pre_sql_setup will modify query state in a way that forbids 69 69 # another run of it. 70 70 self.refcounts_before = self.query.alias_refcount.copy() 71 out_cols = self.get_columns(with_col_aliases)71 out_cols, c_params = self.get_columns(with_col_aliases) 72 72 ordering, ordering_group_by = self.get_ordering() 73 73 74 74 distinct_fields = self.get_distinct() … … class SQLCompiler(object): 84 84 params = [] 85 85 for val in self.query.extra_select.itervalues(): 86 86 params.extend(val[1]) 87 # Extra-select comes before aggregation in the select list 88 params.extend(c_params) 87 89 88 90 result = ['SELECT'] 89 91 … … class SQLCompiler(object): 178 180 qn = self.quote_name_unless_alias 179 181 qn2 = self.connection.ops.quote_name 180 182 result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] 183 query_params = [] 181 184 aliases = set(self.query.extra_select.keys()) 182 185 if with_aliases: 183 186 col_aliases = aliases.copy() … … class SQLCompiler(object): 220 223 aliases.update(new_aliases) 221 224 222 225 max_name_length = self.connection.ops.max_name_length() 223 result.extend([ 224 '%s%s' % ( 225 aggregate.as_sql(qn, self.connection), 226 alias is not None 227 and ' AS %s' % qn(truncate_name(alias, max_name_length)) 228 or '' 226 for alias, aggregate in self.query.aggregate_select.items(): 227 sql, params = aggregate.as_sql(qn, self.connection) 228 result.append( 229 '%s%s' % ( 230 sql, 231 alias is not None 232 and ' AS %s' % qn(truncate_name(alias, max_name_length)) 233 or '' 234 ) 229 235 ) 230 for alias, aggregate in self.query.aggregate_select.items() 231 ]) 236 query_params.extend(params) 232 237 233 238 for table, col in self.query.related_select_cols: 234 239 r = '%s.%s' % (qn(table), qn(col)) … … class SQLCompiler(object): 243 248 col_aliases.add(col) 244 249 245 250 self._select_aliases = aliases 246 return result 251 return result, query_params 247 252 248 253 def get_default_columns(self, with_aliases=False, col_aliases=None, 249 254 start_alias=None, opts=None, as_pairs=False, local_only=False): … … class SQLAggregateCompiler(SQLCompiler): 1053 1058 """ 1054 1059 if qn is None: 1055 1060 qn = self.quote_name_unless_alias 1061 buf = [] 1062 a_params = [] 1063 for aggregate in self.query.aggregate_select.values(): 1064 sql, query_params = aggregate.as_sql(qn, self.connection) 1065 buf.append(sql) 1066 a_params.extend(query_params) 1067 aggregate_sql = ', '.join(buf) 1056 1068 1057 1069 sql = ('SELECT %s FROM (%s) subquery' % ( 1058 ', '.join([ 1059 aggregate.as_sql(qn, self.connection) 1060 for aggregate in self.query.aggregate_select.values() 1061 ]), 1070 aggregate_sql, 1062 1071 self.query.subquery) 1063 1072 ) 1064 params = self.query.sub_params1073 params = tuple(a_params) + (self.query.sub_params) 1065 1074 return (sql, params) 1066 1075 1067 1076 class SQLDateCompiler(SQLCompiler): -
django/db/models/sql/expressions.py
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 1bbf742..3df33f0 100644
a b from django.db.models.fields import FieldDoesNotExist 3 3 from django.db.models.sql.constants import LOOKUP_SEP 4 4 5 5 class SQLEvaluator(object): 6 def __init__(self, expression, query, allow_joins=True ):6 def __init__(self, expression, query, allow_joins=True, promote_joins=False): 7 7 self.expression = expression 8 8 self.opts = query.get_meta() 9 9 self.cols = {} 10 10 11 11 self.contains_aggregate = False 12 self.expression.prepare(self, query, allow_joins )12 self.expression.prepare(self, query, allow_joins, promote_joins) 13 13 14 14 def prepare(self): 15 15 return self … … class SQLEvaluator(object): 28 28 # Vistor methods for initial expression preparation # 29 29 ##################################################### 30 30 31 def prepare_node(self, node, query, allow_joins ):31 def prepare_node(self, node, query, allow_joins, promote_joins): 32 32 for child in node.children: 33 33 if hasattr(child, 'prepare'): 34 child.prepare(self, query, allow_joins )34 child.prepare(self, query, allow_joins, promote_joins) 35 35 36 def prepare_leaf(self, node, query, allow_joins ):36 def prepare_leaf(self, node, query, allow_joins, promote_joins): 37 37 if not allow_joins and LOOKUP_SEP in node.name: 38 38 raise FieldError("Joined field references are not permitted in this query") 39 39 … … class SQLEvaluator(object): 48 48 field_list, query.get_meta(), 49 49 query.get_initial_alias(), False) 50 50 col, _, join_list = query.trim_joins(source, join_list, last, False) 51 if promote_joins: 52 for column_alias in join_list: 53 query.promote_alias(column_alias, unconditional=True) 51 54 52 55 self.cols[node] = (join_list[-1], col) 53 56 except FieldDoesNotExist: … … class SQLEvaluator(object): 65 68 for child in node.children: 66 69 if hasattr(child, 'evaluate'): 67 70 sql, params = child.evaluate(self, qn, connection) 71 if isinstance(sql, tuple): 72 expression_params.extend(sql[1]) 73 sql = sql[0] 68 74 else: 69 75 sql, params = '%s', (child,) 70 76 -
django/db/models/sql/query.py
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 7f331bf..a13c3d4 100644
a b class Query(object): 974 974 Adds a single aggregate expression to the Query 975 975 """ 976 976 opts = model._meta 977 field_list = aggregate.lookup.split(LOOKUP_SEP) 978 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 979 # Aggregate is over an annotation 980 field_name = field_list[0] 981 col = field_name 982 source = self.aggregates[field_name] 983 if not is_summary: 984 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 985 aggregate.name, field_name, field_name)) 986 elif ((len(field_list) > 1) or 987 (field_list[0] not in [i.name for i in opts.fields]) or 988 self.group_by is None or 989 not is_summary): 990 # If: 991 # - the field descriptor has more than one part (foo__bar), or 992 # - the field descriptor is referencing an m2m/m2o field, or 993 # - this is a reference to a model field (possibly inherited), or 994 # - this is an annotation over a model field 995 # then we need to explore the joins that are required. 996 997 field, source, opts, join_list, last, _ = self.setup_joins( 998 field_list, opts, self.get_initial_alias(), False) 999 1000 # Process the join chain to see if it can be trimmed 1001 col, _, join_list = self.trim_joins(source, join_list, last, False) 1002 1003 # If the aggregate references a model or field that requires a join, 1004 # those joins must be LEFT OUTER - empty join rows must be returned 1005 # in order for zeros to be returned for those aggregates. 1006 for column_alias in join_list: 1007 self.promote_alias(column_alias, unconditional=True) 1008 1009 col = (join_list[-1], col) 977 only = aggregate.only 978 if hasattr(aggregate.lookup, 'evaluate'): 979 # If lookup is a query expression, evaluate it 980 col = SQLEvaluator(aggregate.lookup, self, promote_joins=True) 981 # TODO: find out the real source of this field. If any field has 982 # is_computed, then source can be set to is_computed. 983 source = None 1010 984 else: 1011 # The simplest cases. No joins required - 1012 # just reference the provided column alias. 1013 field_name = field_list[0] 1014 source = opts.get_field(field_name) 1015 col = field_name 1016 985 field_list = aggregate.lookup.split(LOOKUP_SEP) 986 join_list = [] 987 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 988 # Aggregate is over an annotation 989 field_name = field_list[0] 990 col = field_name 991 source = self.aggregates[field_name] 992 if not is_summary: 993 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 994 aggregate.name, field_name, field_name)) 995 if only: 996 raise FieldError("Cannot use aggregated fields in conditional aggregates") 997 elif ((len(field_list) > 1) or 998 (field_list[0] not in [i.name for i in opts.fields]) or 999 self.group_by is None or 1000 not is_summary): 1001 # If: 1002 # - the field descriptor has more than one part (foo__bar), or 1003 # - the field descriptor is referencing an m2m/m2o field, or 1004 # - this is a reference to a model field (possibly inherited), or 1005 # - this is an annotation over a model field 1006 # then we need to explore the joins that are required. 1007 1008 field, source, opts, join_list, last, _ = self.setup_joins( 1009 field_list, opts, self.get_initial_alias(), False) 1010 1011 # Process the join chain to see if it can be trimmed 1012 col, _, join_list = self.trim_joins(source, join_list, last, False) 1013 1014 # If the aggregate references a model or field that requires a join, 1015 # those joins must be LEFT OUTER - empty join rows must be returned 1016 # in order for zeros to be returned for those aggregates. 1017 for column_alias in join_list: 1018 self.promote_alias(column_alias, unconditional=True) 1019 1020 col = (join_list[-1], col) 1021 else: 1022 # The simplest cases. No joins required - 1023 # just reference the provided column alias. 1024 field_name = field_list[0] 1025 source = opts.get_field(field_name) 1026 col = field_name 1027 1028 if only: 1029 original_where = self.where 1030 original_having = self.having 1031 aggregate.condition = self.where_class() 1032 self.where = aggregate.condition 1033 self.having = self.where_class() 1034 original_alias_map = self.alias_map.keys()[:] 1035 self.add_q(only, used_aliases=set(original_alias_map)) 1036 if original_alias_map != self.alias_map.keys(): 1037 raise FieldError("Aggregate's only condition can not require additional joins, Original joins: %s, joins after: %s" % (original_alias_map, self.alias_map.keys())) 1038 if self.having.children: 1039 raise FieldError("Aggregate's only condition can not reference annotated fields") 1040 self.having = original_having 1041 self.where = original_where 1017 1042 # Add the aggregate to the query 1018 1043 aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) 1019 1044 -
django/db/models/sql/where.py
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py old mode 100644 new mode 100755 index 5515bc4..90f96c2
a b class WhereNode(tree.Node): 139 139 it. 140 140 """ 141 141 lvalue, lookup_type, value_annotation, params_or_value = child 142 additional_params = [] 142 143 if isinstance(lvalue, Constraint): 143 144 try: 144 145 lvalue, params = lvalue.process(lookup_type, params_or_value, connection) … … class WhereNode(tree.Node): 156 157 else: 157 158 # A smart object with an as_sql() method. 158 159 field_sql = lvalue.as_sql(qn, connection) 160 if isinstance(field_sql, tuple): 161 # It also returned params 162 additional_params.extend(field_sql[1]) 163 field_sql = field_sql[0] 159 164 160 165 if value_annotation is datetime.datetime: 161 166 cast_sql = connection.ops.datetime_cast_sql() … … class WhereNode(tree.Node): 164 169 165 170 if hasattr(params, 'as_sql'): 166 171 extra, params = params.as_sql(qn, connection) 172 if isinstance(extra, tuple): 173 params = params + tuple(extra[1]) 174 extra = extra[0] 167 175 cast_sql = '' 168 176 else: 169 177 extra = '' … … class WhereNode(tree.Node): 173 181 lookup_type = 'isnull' 174 182 value_annotation = True 175 183 184 additional_params.extend(params) 185 params = additional_params 176 186 if lookup_type in connection.operators: 177 187 format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) 178 188 return (format % (field_sql, -
tests/modeltests/aggregation/tests.py
diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py index a35dbb3..1e94784 100644
a b import datetime 4 4 from decimal import Decimal 5 5 6 6 from django.db.models import Avg, Sum, Count, Max, Min 7 from django.db.models import Q, F 8 from django.core.exceptions import FieldError 7 9 from django.test import TestCase, Approximate 8 10 9 11 from .models import Author, Publisher, Book, Store … … class BaseAggregateTestCase(TestCase): 18 20 def test_single_aggregate(self): 19 21 vals = Author.objects.aggregate(Avg("age")) 20 22 self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) 23 vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29))) 24 self.assertEqual(vals, {"age__sum": 254}) 25 vals = Author.objects.extra(select={'testparams':'age < %s'}, select_params=[0])\ 26 .aggregate(Sum("age", only=Q(age__gt=29))) 27 self.assertEqual(vals, {"age__sum": 254}) 28 vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian'))) 29 self.assertEqual(vals, {"age__sum": 69}) 21 30 22 31 def test_multiple_aggregates(self): 23 32 vals = Author.objects.aggregate(Sum("age"), Avg("age")) 24 33 self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) 34 vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age")) 35 self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)}) 25 36 26 37 def test_filter_aggregate(self): 27 38 vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) 28 39 self.assertEqual(len(vals), 1) 29 40 self.assertEqual(vals["age__sum"], 254) 41 vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29))) 42 # If there are no matching aggregates, then None, not 0 is the answer. 43 self.assertEqual(vals["age__sum"], None) 30 44 31 45 def test_related_aggregate(self): 32 46 vals = Author.objects.aggregate(Avg("friends__age")) 33 47 self.assertEqual(len(vals), 1) 34 48 self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2) 35 49 50 vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29))) 51 self.assertEqual(len(vals), 1) 52 self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2) 53 vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age")) 54 self.assertEqual(vals, vals2) 55 56 vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35))) 57 self.assertEqual(len(vals), 1) 58 self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2) 59 60 # The average age of author's friends, whose age is lower than the authors age. 61 vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age')))) 62 self.assertEqual(len(vals), 1) 63 self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2) 64 36 65 vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age")) 37 66 self.assertEqual(len(vals), 1) 38 67 self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2) … … class BaseAggregateTestCase(TestCase): 54 83 self.assertEqual(len(vals), 1) 55 84 self.assertEqual(vals["books__authors__age__max"], 57) 56 85 86 vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56))) 87 self.assertEqual(len(vals), 1) 88 self.assertEqual(vals["books__authors__age__max"], 46) 89 57 90 vals = Author.objects.aggregate(Min("book__publisher__num_awards")) 58 91 self.assertEqual(len(vals), 1) 59 92 self.assertEqual(vals["book__publisher__num_awards__min"], 1) … … class BaseAggregateTestCase(TestCase): 84 117 ) 85 118 self.assertEqual(b.mean_age, 34.5) 86 119 120 # Test extra-select 121 books = Book.objects.annotate(mean_age=Avg("authors__age")) 122 books = books.annotate(mean_age2=Avg('authors__age', only=Q(authors__age__gte=0))) 123 books = books.extra(select={'testparams': 'publisher_id = %s'}, select_params=[1]) 124 b = books.get(pk=1) 125 self.assertEqual(b.mean_age, 34.5) 126 self.assertEqual(b.mean_age2, 34.5) 127 self.assertEqual(b.testparams, True) 128 129 # Test relabel_aliases 130 excluded_authors = Author.objects.annotate(book_rating=Min(F('book__rating') + 5, only=Q(pk__gte=1))) 131 excluded_authors = excluded_authors.filter(book_rating__lt=0) 132 books = books.exclude(authors__in=excluded_authors) 133 b = books.get(pk=1) 134 self.assertEqual(b.mean_age, 34.5) 135 136 # Test joins in F-based annotation 137 books = Book.objects.annotate(oldest=Max(F('authors__age'))) 138 books = books.values_list('rating', 'oldest').order_by('rating', 'oldest') 139 self.assertEqual( 140 list(books), 141 [(3.0, 45), (4.0, 29), (4.0, 37), (4.0, 57), (4.5, 35), (5.0, 57)] 142 ) 143 144 publishers = Publisher.objects.annotate(avg_rating=Avg(F('book__rating') - 0)) 145 publishers = publishers.values_list('id', 'avg_rating').order_by('id') 146 self.assertEqual(list(publishers), [(1, 4.25), (2, 3.0), (3, 4.0), (4, 5.0), (5, None)]) 147 87 148 def test_annotate_m2m(self): 88 149 books = Book.objects.filter(rating__lt=4.5).annotate(Avg("authors__age")).order_by("name") 89 150 self.assertQuerysetEqual( … … class BaseAggregateTestCase(TestCase): 109 170 lambda b: (b.name, b.num_authors) 110 171 ) 111 172 173 def raises_exception(): 174 list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name")) 175 176 self.assertRaises(FieldError, raises_exception) 177 112 178 def test_backwards_m2m_annotate(self): 113 179 authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name") 114 180 self.assertQuerysetEqual( … … class BaseAggregateTestCase(TestCase): 194 260 } 195 261 ] 196 262 ) 263 books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age') 264 self.assertEqual( 265 list(books), [ 266 { 267 "pk": 1, 268 "isbn": "159059725", 269 "mean_age": 34.0, 270 } 271 ] 272 ) 197 273 198 274 books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name") 199 275 self.assertEqual( … … class BaseAggregateTestCase(TestCase): 271 347 272 348 vals = Book.objects.aggregate(Count("rating", distinct=True)) 273 349 self.assertEqual(vals, {"rating__count": 4}) 350 vals = Book.objects.aggregate( 351 low_count=Count("rating", only=Q(rating__lt=4)), 352 high_count=Count("rating", only=Q(rating__gte=4)) 353 ) 354 self.assertEqual(vals, {"low_count": 1, 'high_count': 5}) 355 vals = Book.objects.aggregate( 356 low_count=Count("rating", distinct=True, only=Q(rating__lt=4)), 357 high_count=Count("rating", distinct=True, only=Q(rating__gte=4)) 358 ) 359 self.assertEqual(vals, {"low_count": 1, 'high_count': 3}) 274 360 275 361 def test_fkey_aggregate(self): 276 362 explicit = list(Author.objects.annotate(Count('book__id'))) … … class BaseAggregateTestCase(TestCase): 390 476 ], 391 477 lambda p: p.name, 392 478 ) 479 publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk") 480 self.assertQuerysetEqual( 481 publishers, [ 482 "Expensive Publisher", 483 ], 484 lambda p: p.name, 485 ) 393 486 394 487 publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") 395 488 self.assertQuerysetEqual(