Ticket #11305: conditional_aggregates.1.6.patch
File conditional_aggregates.1.6.patch, 29.3 KB (added by , 12 years ago) |
---|
-
django/db/models/aggregates.py
From 55f415fa2fe31b5191f9a67d17364f9590be401d Mon Sep 17 00:00:00 2001 From: Andre Terra <loftarasa@gmail.com> Date: Tue, 11 Dec 2012 18:32:25 -0200 Subject: [PATCH] Added support for conditional aggregates (ticket #11305) --- django/db/models/aggregates.py | 9 ++- django/db/models/expressions.py | 8 +-- django/db/models/sql/aggregates.py | 43 +++++++++++--- django/db/models/sql/compiler.py | 39 ++++++++----- django/db/models/sql/expressions.py | 15 +++-- django/db/models/sql/query.py | 103 +++++++++++++++++++++------------- django/db/models/sql/where.py | 10 ++++ tests/modeltests/aggregation/tests.py | 93 ++++++++++++++++++++++++++++++ 8 files changed, 246 insertions(+), 74 deletions(-) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index a2349cf..d816aa7 100644
a b class Aggregate(object): 6 6 """ 7 7 Default Aggregate definition. 8 8 """ 9 def __init__(self, lookup, **extra):9 def __init__(self, lookup, only=None, **extra): 10 10 """Instantiate a new aggregate. 11 11 12 12 * lookup is the field on which the aggregate operates. 13 * only is a Q-object used in conditional aggregation. 13 14 * extra is a dictionary of additional data to provide for the 14 15 aggregate definition 15 16 … … class Aggregate(object): 18 19 """ 19 20 self.lookup = lookup 20 21 self.extra = extra 22 self.only = only 23 self.condition = None 21 24 22 25 def _default_alias(self): 26 if hasattr(self.lookup, 'evaluate'): 27 raise ValueError('When aggregating over an expression, you need to give an alias.') 23 28 return '%s__%s' % (self.lookup, self.name.lower()) 24 29 default_alias = property(_default_alias) 25 30 … … class Aggregate(object): 42 47 summary value rather than an annotation. 43 48 """ 44 49 klass = getattr(query.aggregates_module, self.name) 45 aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)50 aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra) 46 51 query.aggregates[alias] = aggregate 47 52 48 53 class Avg(Aggregate): -
django/db/models/expressions.py
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 3566d77..2b782ea 100644
a b class ExpressionNode(tree.Node): 41 41 # VISITOR METHODS # 42 42 ################### 43 43 44 def prepare(self, evaluator, query, allow_joins ):45 return evaluator.prepare_node(self, query, allow_joins )44 def prepare(self, evaluator, query, allow_joins, promote_joins=False): 45 return evaluator.prepare_node(self, query, allow_joins, promote_joins) 46 46 47 47 def evaluate(self, evaluator, qn, connection): 48 48 return evaluator.evaluate_node(self, qn, connection) … … class F(ExpressionNode): 129 129 obj.name = self.name 130 130 return obj 131 131 132 def prepare(self, evaluator, query, allow_joins ):133 return evaluator.prepare_leaf(self, query, allow_joins )132 def prepare(self, evaluator, query, allow_joins, promote_joins=False): 133 return evaluator.prepare_leaf(self, query, allow_joins, promote_joins) 134 134 135 135 def evaluate(self, evaluator, qn, connection): 136 136 return evaluator.evaluate_leaf(self, qn, connection) -
django/db/models/sql/aggregates.py
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index b41314a..5fe2215 100644
a b Classes to represent the default SQL aggregate functions 3 3 """ 4 4 5 5 from django.db.models.fields import IntegerField, FloatField 6 from django.db.models.sql.expressions import SQLEvaluator 6 7 7 8 # Fake fields used to identify aggregate types in data-conversion operations. 8 9 ordinal_aggregate_field = IntegerField() … … class Aggregate(object): 15 16 is_ordinal = False 16 17 is_computed = False 17 18 sql_template = '%(function)s(%(field)s)' 19 conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END" 18 20 19 def __init__(self, col, source=None, is_summary=False, **extra):21 def __init__(self, col, source=None, is_summary=False, condition=None, **extra): 20 22 """Instantiate an SQL aggregate 21 23 22 24 * col is a column reference describing the subject field … … class Aggregate(object): 26 28 the column reference. If the aggregate is not an ordinal or 27 29 computed type, this reference is used to determine the coerced 28 30 output type of the aggregate. 31 * condition is used in conditional aggregation. 29 32 * extra is a dictionary of additional data to provide for the 30 aggregate definition 33 aggregate definition. 31 34 32 35 Also utilizes the class variables: 33 36 * sql_function, the name of the SQL function that implements the … … class Aggregate(object): 35 38 * sql_template, a template string that is used to render the 36 39 aggregate into SQL. 37 40 * is_ordinal, a boolean indicating if the output of this aggregate 38 is an integer (e.g., a count) 41 is an integer (e.g., a count). 39 42 * is_computed, a boolean indicating if this output of this aggregate 40 43 is a computed float (e.g., an average), regardless of the input 41 44 type. … … class Aggregate(object): 45 48 self.source = source 46 49 self.is_summary = is_summary 47 50 self.extra = extra 51 self.condition = condition 48 52 49 53 # Follow the chain of aggregate sources back until you find an 50 54 # actual field, or an aggregate that forces a particular output … … class Aggregate(object): 65 69 def relabel_aliases(self, change_map): 66 70 if isinstance(self.col, (list, tuple)): 67 71 self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) 72 else: 73 self.col.relabel_aliases(change_map) 74 if self.condition: 75 self.condition.relabel_aliases(change_map) 68 76 69 77 def as_sql(self, qn, connection): 70 78 "Return the aggregate, rendered as SQL." 71 79 80 condition_params = [] 81 col_params = [] 72 82 if hasattr(self.col, 'as_sql'): 73 field_name = self.col.as_sql(qn, connection) 83 if isinstance(self.col, SQLEvaluator): 84 field_name, col_params = self.col.as_sql(qn, connection) 85 else: 86 field_name = self.col.as_sql(qn, connection) 74 87 elif isinstance(self.col, (list, tuple)): 75 88 field_name = '.'.join([qn(c) for c in self.col]) 76 89 else: 77 90 field_name = self.col 78 91 79 params = { 80 'function': self.sql_function, 81 'field': field_name 82 } 92 if self.condition: 93 condition, condition_params = self.condition.as_sql(qn, connection) 94 conditional_field = self.conditional_template % { 95 'condition': condition, 96 'field_name': field_name 97 } 98 params = { 99 'function': self.sql_function, 100 'field': conditional_field, 101 } 102 else: 103 params = { 104 'function': self.sql_function, 105 'field': field_name 106 } 83 107 params.update(self.extra) 84 108 85 return self.sql_template % params 109 condition_params.extend(col_params) 110 return (self.sql_template % params, condition_params) 86 111 87 112 88 113 class Avg(Aggregate): -
django/db/models/sql/compiler.py
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 8cfb12a..ed97a11 100644
a b class SQLCompiler(object): 71 71 # as the pre_sql_setup will modify query state in a way that forbids 72 72 # another run of it. 73 73 self.refcounts_before = self.query.alias_refcount.copy() 74 out_cols = self.get_columns(with_col_aliases)74 out_cols, c_params = self.get_columns(with_col_aliases) 75 75 ordering, ordering_group_by = self.get_ordering() 76 76 77 77 distinct_fields = self.get_distinct() … … class SQLCompiler(object): 87 87 params = [] 88 88 for val in six.itervalues(self.query.extra_select): 89 89 params.extend(val[1]) 90 # Extra-select comes before aggregation in the select list 91 params.extend(c_params) 90 92 91 93 result = ['SELECT'] 92 94 … … class SQLCompiler(object): 172 174 qn = self.quote_name_unless_alias 173 175 qn2 = self.connection.ops.quote_name 174 176 result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] 177 query_params = [] 175 178 aliases = set(self.query.extra_select.keys()) 176 179 if with_aliases: 177 180 col_aliases = aliases.copy() … … class SQLCompiler(object): 214 217 aliases.update(new_aliases) 215 218 216 219 max_name_length = self.connection.ops.max_name_length() 217 result.extend([ 218 '%s%s' % ( 219 aggregate.as_sql(qn, self.connection), 220 alias is not None 221 and ' AS %s' % qn(truncate_name(alias, max_name_length)) 222 or '' 220 for alias, aggregate in self.query.aggregate_select.items(): 221 sql, params = aggregate.as_sql(qn, self.connection) 222 result.append( 223 '%s%s' % ( 224 sql, 225 alias is not None 226 and ' AS %s' % qn(truncate_name(alias, max_name_length)) 227 or '' 228 ) 223 229 ) 224 for alias, aggregate in self.query.aggregate_select.items() 225 ]) 230 query_params.extend(params) 226 231 227 232 for (table, col), _ in self.query.related_select_cols: 228 233 r = '%s.%s' % (qn(table), qn(col)) … … class SQLCompiler(object): 237 242 col_aliases.add(col) 238 243 239 244 self._select_aliases = aliases 240 return result 245 return result, query_params 241 246 242 247 def get_default_columns(self, with_aliases=False, col_aliases=None, 243 248 start_alias=None, opts=None, as_pairs=False, from_parent=None): … … class SQLAggregateCompiler(SQLCompiler): 1032 1037 """ 1033 1038 if qn is None: 1034 1039 qn = self.quote_name_unless_alias 1040 buf = [] 1041 a_params = [] 1042 for aggregate in self.query.aggregate_select.values(): 1043 sql, query_params = aggregate.as_sql(qn, self.connection) 1044 buf.append(sql) 1045 a_params.extend(query_params) 1046 aggregate_sql = ', '.join(buf) 1035 1047 1036 1048 sql = ('SELECT %s FROM (%s) subquery' % ( 1037 ', '.join([ 1038 aggregate.as_sql(qn, self.connection) 1039 for aggregate in self.query.aggregate_select.values() 1040 ]), 1049 aggregate_sql, 1041 1050 self.query.subquery) 1042 1051 ) 1043 params = self.query.sub_params1052 params = tuple(a_params) + (self.query.sub_params) 1044 1053 return (sql, params) 1045 1054 1046 1055 class SQLDateCompiler(SQLCompiler): -
django/db/models/sql/expressions.py
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index c809e25..4e53ac4 100644
a b from django.db.models.fields import FieldDoesNotExist 4 4 from django.db.models.sql.constants import REUSE_ALL 5 5 6 6 class SQLEvaluator(object): 7 def __init__(self, expression, query, allow_joins=True, reuse=REUSE_ALL ):7 def __init__(self, expression, query, allow_joins=True, reuse=REUSE_ALL, promote_joins=False): 8 8 self.expression = expression 9 9 self.opts = query.get_meta() 10 10 self.cols = [] 11 11 12 12 self.contains_aggregate = False 13 13 self.reuse = reuse 14 self.expression.prepare(self, query, allow_joins )14 self.expression.prepare(self, query, allow_joins, promote_joins) 15 15 16 16 def prepare(self): 17 17 return self … … class SQLEvaluator(object): 34 34 # Vistor methods for initial expression preparation # 35 35 ##################################################### 36 36 37 def prepare_node(self, node, query, allow_joins ):37 def prepare_node(self, node, query, allow_joins, promote_joins): 38 38 for child in node.children: 39 39 if hasattr(child, 'prepare'): 40 child.prepare(self, query, allow_joins )40 child.prepare(self, query, allow_joins, promote_joins) 41 41 42 def prepare_leaf(self, node, query, allow_joins ):42 def prepare_leaf(self, node, query, allow_joins, promote_joins): 43 43 if not allow_joins and LOOKUP_SEP in node.name: 44 44 raise FieldError("Joined field references are not permitted in this query") 45 45 … … class SQLEvaluator(object): 54 54 field_list, query.get_meta(), 55 55 query.get_initial_alias(), self.reuse) 56 56 col, _, join_list = query.trim_joins(source, join_list, last, False) 57 if promote_joins: 58 query.promote_joins(join_list, unconditional=True) 57 59 if self.reuse is not None and self.reuse != REUSE_ALL: 58 60 self.reuse.update(join_list) 59 61 self.cols.append((node, (join_list[-1], col))) … … class SQLEvaluator(object): 72 74 for child in node.children: 73 75 if hasattr(child, 'evaluate'): 74 76 sql, params = child.evaluate(self, qn, connection) 77 if isinstance(sql, tuple): 78 expression_params.extend(sql[1]) 79 sql = sql[0] 75 80 else: 76 81 sql, params = '%s', (child,) 77 82 -
django/db/models/sql/query.py
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index e24dc22..c0fa92c 100644
a b class Query(object): 119 119 self.filter_is_sticky = False 120 120 self.included_inherited_models = {} 121 121 122 # SQL-related attributes 122 # SQL-related attributes 123 123 # Select and related select clauses as SelectInfo instances. 124 124 # The select is used for cases where we want to set up the select 125 125 # clause to contain other than default fields (values(), annotate(), … … class Query(object): 987 987 Adds a single aggregate expression to the Query 988 988 """ 989 989 opts = model._meta 990 field_list = aggregate.lookup.split(LOOKUP_SEP) 991 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 992 # Aggregate is over an annotation 993 field_name = field_list[0] 994 col = field_name 995 source = self.aggregates[field_name] 996 if not is_summary: 997 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 998 aggregate.name, field_name, field_name)) 999 elif ((len(field_list) > 1) or 1000 (field_list[0] not in [i.name for i in opts.fields]) or 1001 self.group_by is None or 1002 not is_summary): 1003 # If: 1004 # - the field descriptor has more than one part (foo__bar), or 1005 # - the field descriptor is referencing an m2m/m2o field, or 1006 # - this is a reference to a model field (possibly inherited), or 1007 # - this is an annotation over a model field 1008 # then we need to explore the joins that are required. 1009 1010 field, source, opts, join_list, last, _ = self.setup_joins( 1011 field_list, opts, self.get_initial_alias(), REUSE_ALL) 1012 1013 # Process the join chain to see if it can be trimmed 1014 col, _, join_list = self.trim_joins(source, join_list, last, False) 1015 1016 # If the aggregate references a model or field that requires a join, 1017 # those joins must be LEFT OUTER - empty join rows must be returned 1018 # in order for zeros to be returned for those aggregates. 1019 self.promote_joins(join_list, True) 1020 1021 col = (join_list[-1], col) 990 only = aggregate.only 991 if hasattr(aggregate.lookup, 'evaluate'): 992 # If lookup is a query expression, evaluate it 993 col = SQLEvaluator(aggregate.lookup, self, promote_joins=True) 994 # TODO: find out the real source of this field. If any field has 995 # is_computed, then source can be set to is_computed. 996 source = None 1022 997 else: 1023 # The simplest cases. No joins required - 1024 # just reference the provided column alias. 1025 field_name = field_list[0] 1026 source = opts.get_field(field_name) 1027 col = field_name 1028 998 field_list = aggregate.lookup.split(LOOKUP_SEP) 999 join_list = [] 1000 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 1001 # Aggregate is over an annotation 1002 field_name = field_list[0] 1003 col = field_name 1004 source = self.aggregates[field_name] 1005 if not is_summary: 1006 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 1007 aggregate.name, field_name, field_name)) 1008 if only: 1009 raise FieldError("Cannot use aggregated fields in conditional aggregates") 1010 elif ((len(field_list) > 1) or 1011 (field_list[0] not in [i.name for i in opts.fields]) or 1012 self.group_by is None or 1013 not is_summary): 1014 # If: 1015 # - the field descriptor has more than one part (foo__bar), or 1016 # - the field descriptor is referencing an m2m/m2o field, or 1017 # - this is a reference to a model field (possibly inherited), or 1018 # - this is an annotation over a model field 1019 # then we need to explore the joins that are required. 1020 1021 field, source, opts, join_list, last, _ = self.setup_joins( 1022 field_list, opts, self.get_initial_alias(), REUSE_ALL) 1023 1024 # Process the join chain to see if it can be trimmed 1025 col, _, join_list = self.trim_joins(source, join_list, last, False) 1026 1027 # If the aggregate references a model or field that requires a join, 1028 # those joins must be LEFT OUTER - empty join rows must be returned 1029 # in order for zeros to be returned for those aggregates. 1030 self.promote_joins(join_list, unconditional=True) 1031 1032 col = (join_list[-1], col) 1033 else: 1034 # The simplest cases. No joins required - 1035 # just reference the provided column alias. 1036 field_name = field_list[0] 1037 source = opts.get_field(field_name) 1038 col = field_name 1039 1040 if only: 1041 original_where = self.where 1042 original_having = self.having 1043 aggregate.condition = self.where_class() 1044 self.where = aggregate.condition 1045 self.having = self.where_class() 1046 original_alias_map = self.alias_map.keys()[:] 1047 self.add_q(only, used_aliases=set(original_alias_map)) 1048 if original_alias_map != self.alias_map.keys(): 1049 raise FieldError("Aggregate's only condition can not require additional joins, Original joins: %s, joins after: %s" % (original_alias_map, self.alias_map.keys())) 1050 if self.having.children: 1051 raise FieldError("Aggregate's only condition can not reference annotated fields") 1052 self.having = original_having 1053 self.where = original_where 1029 1054 # Add the aggregate to the query 1030 1055 aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) 1031 1056 -
django/db/models/sql/where.py
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 47f4ffa..6b001bc 100644
a b class WhereNode(tree.Node): 156 156 it. 157 157 """ 158 158 lvalue, lookup_type, value_annotation, params_or_value = child 159 additional_params = [] 159 160 if isinstance(lvalue, Constraint): 160 161 try: 161 162 lvalue, params = lvalue.process(lookup_type, params_or_value, connection) … … class WhereNode(tree.Node): 173 174 else: 174 175 # A smart object with an as_sql() method. 175 176 field_sql = lvalue.as_sql(qn, connection) 177 if isinstance(field_sql, tuple): 178 # It also returned params 179 additional_params.extend(field_sql[1]) 180 field_sql = field_sql[0] 176 181 177 182 if value_annotation is datetime.datetime: 178 183 cast_sql = connection.ops.datetime_cast_sql() … … class WhereNode(tree.Node): 181 186 182 187 if hasattr(params, 'as_sql'): 183 188 extra, params = params.as_sql(qn, connection) 189 if isinstance(extra, tuple): 190 params = params + tuple(extra[1]) 191 extra = extra[0] 184 192 cast_sql = '' 185 193 else: 186 194 extra = '' … … class WhereNode(tree.Node): 190 198 lookup_type = 'isnull' 191 199 value_annotation = True 192 200 201 additional_params.extend(params) 202 params = additional_params 193 203 if lookup_type in connection.operators: 194 204 format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) 195 205 return (format % (field_sql, -
tests/modeltests/aggregation/tests.py
diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py index c23b32f..0c445c4 100644
a b import datetime 4 4 from decimal import Decimal 5 5 6 6 from django.db.models import Avg, Sum, Count, Max, Min 7 from django.db.models import Q, F 8 from django.core.exceptions import FieldError 7 9 from django.test import TestCase, Approximate 8 10 9 11 from .models import Author, Publisher, Book, Store … … class BaseAggregateTestCase(TestCase): 18 20 def test_single_aggregate(self): 19 21 vals = Author.objects.aggregate(Avg("age")) 20 22 self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) 23 vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29))) 24 self.assertEqual(vals, {"age__sum": 254}) 25 vals = Author.objects.extra(select={'testparams':'age < %s'}, select_params=[0])\ 26 .aggregate(Sum("age", only=Q(age__gt=29))) 27 self.assertEqual(vals, {"age__sum": 254}) 28 vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian'))) 29 self.assertEqual(vals, {"age__sum": 69}) 21 30 22 31 def test_multiple_aggregates(self): 23 32 vals = Author.objects.aggregate(Sum("age"), Avg("age")) 24 33 self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) 34 vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age")) 35 self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)}) 25 36 26 37 def test_filter_aggregate(self): 27 38 vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) 28 39 self.assertEqual(len(vals), 1) 29 40 self.assertEqual(vals["age__sum"], 254) 41 vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29))) 42 # If there are no matching aggregates, then None, not 0 is the answer. 43 self.assertEqual(vals["age__sum"], None) 30 44 31 45 def test_related_aggregate(self): 32 46 vals = Author.objects.aggregate(Avg("friends__age")) 33 47 self.assertEqual(len(vals), 1) 34 48 self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2) 35 49 50 vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29))) 51 self.assertEqual(len(vals), 1) 52 self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2) 53 vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age")) 54 self.assertEqual(vals, vals2) 55 56 vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35))) 57 self.assertEqual(len(vals), 1) 58 self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2) 59 60 # The average age of author's friends, whose age is lower than the authors age. 61 vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age')))) 62 self.assertEqual(len(vals), 1) 63 self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2) 64 36 65 vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age")) 37 66 self.assertEqual(len(vals), 1) 38 67 self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2) … … class BaseAggregateTestCase(TestCase): 54 83 self.assertEqual(len(vals), 1) 55 84 self.assertEqual(vals["books__authors__age__max"], 57) 56 85 86 vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56))) 87 self.assertEqual(len(vals), 1) 88 self.assertEqual(vals["books__authors__age__max"], 46) 89 57 90 vals = Author.objects.aggregate(Min("book__publisher__num_awards")) 58 91 self.assertEqual(len(vals), 1) 59 92 self.assertEqual(vals["book__publisher__num_awards__min"], 1) … … class BaseAggregateTestCase(TestCase): 84 117 ) 85 118 self.assertEqual(b.mean_age, 34.5) 86 119 120 # Test extra-select 121 books = Book.objects.annotate(mean_age=Avg("authors__age")) 122 books = books.annotate(mean_age2=Avg('authors__age', only=Q(authors__age__gte=0))) 123 books = books.extra(select={'testparams': 'publisher_id = %s'}, select_params=[1]) 124 b = books.get(pk=1) 125 self.assertEqual(b.mean_age, 34.5) 126 self.assertEqual(b.mean_age2, 34.5) 127 self.assertEqual(b.testparams, True) 128 129 # Test relabel_aliases 130 excluded_authors = Author.objects.annotate(book_rating=Min(F('book__rating') + 5, only=Q(pk__gte=1))) 131 excluded_authors = excluded_authors.filter(book_rating__lt=0) 132 books = books.exclude(authors__in=excluded_authors) 133 b = books.get(pk=1) 134 self.assertEqual(b.mean_age, 34.5) 135 136 # Test joins in F-based annotation 137 books = Book.objects.annotate(oldest=Max(F('authors__age'))) 138 books = books.values_list('rating', 'oldest').order_by('rating', 'oldest') 139 self.assertEqual( 140 list(books), 141 [(3.0, 45), (4.0, 29), (4.0, 37), (4.0, 57), (4.5, 35), (5.0, 57)] 142 ) 143 144 publishers = Publisher.objects.annotate(avg_rating=Avg(F('book__rating') - 0)) 145 publishers = publishers.values_list('id', 'avg_rating').order_by('id') 146 self.assertEqual(list(publishers), [(1, 4.25), (2, 3.0), (3, 4.0), (4, 5.0), (5, None)]) 147 87 148 def test_annotate_m2m(self): 88 149 books = Book.objects.filter(rating__lt=4.5).annotate(Avg("authors__age")).order_by("name") 89 150 self.assertQuerysetEqual( … … class BaseAggregateTestCase(TestCase): 109 170 lambda b: (b.name, b.num_authors) 110 171 ) 111 172 173 def raises_exception(): 174 list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name")) 175 176 self.assertRaises(FieldError, raises_exception) 177 112 178 def test_backwards_m2m_annotate(self): 113 179 authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name") 114 180 self.assertQuerysetEqual( … … class BaseAggregateTestCase(TestCase): 194 260 } 195 261 ] 196 262 ) 263 books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age') 264 self.assertEqual( 265 list(books), [ 266 { 267 "pk": 1, 268 "isbn": "159059725", 269 "mean_age": 34.0, 270 } 271 ] 272 ) 197 273 198 274 books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name") 199 275 self.assertEqual( … … class BaseAggregateTestCase(TestCase): 271 347 272 348 vals = Book.objects.aggregate(Count("rating", distinct=True)) 273 349 self.assertEqual(vals, {"rating__count": 4}) 350 vals = Book.objects.aggregate( 351 low_count=Count("rating", only=Q(rating__lt=4)), 352 high_count=Count("rating", only=Q(rating__gte=4)) 353 ) 354 self.assertEqual(vals, {"low_count": 1, 'high_count': 5}) 355 vals = Book.objects.aggregate( 356 low_count=Count("rating", distinct=True, only=Q(rating__lt=4)), 357 high_count=Count("rating", distinct=True, only=Q(rating__gte=4)) 358 ) 359 self.assertEqual(vals, {"low_count": 1, 'high_count': 3}) 274 360 275 361 def test_fkey_aggregate(self): 276 362 explicit = list(Author.objects.annotate(Count('book__id'))) … … class BaseAggregateTestCase(TestCase): 390 476 ], 391 477 lambda p: p.name, 392 478 ) 479 publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk") 480 self.assertQuerysetEqual( 481 publishers, [ 482 "Expensive Publisher", 483 ], 484 lambda p: p.name, 485 ) 393 486 394 487 publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") 395 488 self.assertQuerysetEqual(