Ticket #11305: 11305-2013-09-07-master.patch
File 11305-2013-09-07-master.patch, 23.9 KB (added by , 11 years ago) |
---|
-
django/db/models/aggregates.py
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 1db3890..9806d34 100644
a b class Aggregate(object): 21 21 """ 22 22 Default Aggregate definition. 23 23 """ 24 def __init__(self, lookup, **extra):24 def __init__(self, lookup, only=None, **extra): 25 25 """Instantiate a new aggregate. 26 26 27 27 * lookup is the field on which the aggregate operates. 28 * only is a Q-object used in conditional aggregation. 28 29 * extra is a dictionary of additional data to provide for the 29 30 aggregate definition 30 31 … … class Aggregate(object): 33 34 """ 34 35 self.lookup = lookup 35 36 self.extra = extra 37 self.only = only 38 self.condition = None 36 39 37 40 def _default_alias(self): 41 if hasattr(self.lookup, 'evaluate'): 42 raise ValueError('When aggregating over an expression, you need to give an alias.') 38 43 return '%s__%s' % (self.lookup, self.name.lower()) 39 44 default_alias = property(_default_alias) 40 45 … … class Aggregate(object): 57 62 summary value rather than an annotation. 58 63 """ 59 64 klass = getattr(query.aggregates_module, self.name) 60 aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)65 aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra) 61 66 query.aggregates[alias] = aggregate 62 67 63 68 -
django/db/models/expressions.py
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index b4eea5f..bdb8ec5 100644
a b class ExpressionNode(tree.Node): 57 57 # VISITOR METHODS # 58 58 ################### 59 59 60 def prepare(self, evaluator, query, allow_joins ):61 return evaluator.prepare_node(self, query, allow_joins )60 def prepare(self, evaluator, query, allow_joins, promote_joins=False): 61 return evaluator.prepare_node(self, query, allow_joins, promote_joins) 62 62 63 63 def evaluate(self, evaluator, qn, connection): 64 64 return evaluator.evaluate_node(self, qn, connection) … … class F(ExpressionNode): 143 143 obj.name = self.name 144 144 return obj 145 145 146 def prepare(self, evaluator, query, allow_joins ):147 return evaluator.prepare_leaf(self, query, allow_joins )146 def prepare(self, evaluator, query, allow_joins, promote_joins=False): 147 return evaluator.prepare_leaf(self, query, allow_joins, promote_joins) 148 148 149 149 def evaluate(self, evaluator, qn, connection): 150 150 return evaluator.evaluate_leaf(self, qn, connection) -
django/db/models/sql/aggregates.py
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 3cda4d2..8c1b0e0 100644
a b Classes to represent the default SQL aggregate functions 4 4 import copy 5 5 6 6 from django.db.models.fields import IntegerField, FloatField 7 from django.db.models.sql.expressions import SQLEvaluator 7 8 8 9 # Fake fields used to identify aggregate types in data-conversion operations. 9 10 ordinal_aggregate_field = IntegerField() … … class Aggregate(object): 17 18 is_ordinal = False 18 19 is_computed = False 19 20 sql_template = '%(function)s(%(field)s)' 21 conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END" 20 22 21 def __init__(self, col, source=None, is_summary=False, **extra):23 def __init__(self, col, source=None, is_summary=False, condition=None, **extra): 22 24 """Instantiate an SQL aggregate 23 25 24 26 * col is a column reference describing the subject field … … class Aggregate(object): 28 30 the column reference. If the aggregate is not an ordinal or 29 31 computed type, this reference is used to determine the coerced 30 32 output type of the aggregate. 33 * condition is used in conditional aggregation. 31 34 * extra is a dictionary of additional data to provide for the 32 35 aggregate definition 33 36 … … class Aggregate(object): 47 50 self.source = source 48 51 self.is_summary = is_summary 49 52 self.extra = extra 53 self.condition = condition 50 54 51 55 # Follow the chain of aggregate sources back until you find an 52 56 # actual field, or an aggregate that forces a particular output … … class Aggregate(object): 68 72 clone = copy.copy(self) 69 73 if isinstance(self.col, (list, tuple)): 70 74 clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) 75 else: 76 clone.col.relabel_aliases(change_map) 77 if clone.condition: 78 clone.condition.relabel_aliases(change_map) 71 79 return clone 72 80 73 81 def as_sql(self, qn, connection): 74 82 "Return the aggregate, rendered as SQL with parameters." 75 params = [] 83 condition_params = [] 84 col_params = [] 76 85 77 86 if hasattr(self.col, 'as_sql'): 78 field_name, params = self.col.as_sql(qn, connection) 87 if isinstance(self.col, SQLEvaluator): 88 field_name, col_params = self.col.as_sql(qn, connection) 89 else: 90 field_name = self.col.as_sql(qn, connection) 79 91 elif isinstance(self.col, (list, tuple)): 80 92 field_name = '.'.join(qn(c) for c in self.col) 81 93 else: 82 94 field_name = self.col 83 95 84 substitutions = { 85 'function': self.sql_function, 86 'field': field_name 87 } 96 if self.condition: 97 condition, condition_params = self.condition.as_sql(qn, connection) 98 conditional_field = self.conditional_template % { 99 'condition': condition, 100 'field_name': field_name 101 } 102 substitutions = { 103 'function': self.sql_function, 104 'field': conditional_field, 105 } 106 else: 107 substitutions = { 108 'function': self.sql_function, 109 'field': field_name 110 } 111 88 112 substitutions.update(self.extra) 89 113 90 return self.sql_template % substitutions, params114 return (self.sql_template % substitutions, condition_params) 91 115 92 116 93 117 class Avg(Aggregate): -
django/db/models/sql/expressions.py
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index f9a8929..37608fa 100644
a b from django.db.models.fields import FieldDoesNotExist 6 6 7 7 8 8 class SQLEvaluator(object): 9 def __init__(self, expression, query, allow_joins=True, reuse=None ):9 def __init__(self, expression, query, allow_joins=True, reuse=None, promote_joins=False): 10 10 self.expression = expression 11 11 self.opts = query.get_meta() 12 12 self.reuse = reuse 13 13 self.cols = [] 14 self.expression.prepare(self, query, allow_joins )14 self.expression.prepare(self, query, allow_joins, promote_joins) 15 15 16 16 def relabeled_clone(self, change_map): 17 17 clone = copy.copy(self) … … class SQLEvaluator(object): 43 43 # Vistor methods for initial expression preparation # 44 44 ##################################################### 45 45 46 def prepare_node(self, node, query, allow_joins ):46 def prepare_node(self, node, query, allow_joins, promote_joins): 47 47 for child in node.children: 48 48 if hasattr(child, 'prepare'): 49 child.prepare(self, query, allow_joins )49 child.prepare(self, query, allow_joins, promote_joins) 50 50 51 def prepare_leaf(self, node, query, allow_joins ):51 def prepare_leaf(self, node, query, allow_joins, promote_joins): 52 52 if not allow_joins and LOOKUP_SEP in node.name: 53 53 raise FieldError("Joined field references are not permitted in this query") 54 54 … … class SQLEvaluator(object): 61 61 field_list, query.get_meta(), 62 62 query.get_initial_alias(), self.reuse) 63 63 targets, _, join_list = query.trim_joins(sources, join_list, path) 64 if promote_joins: 65 query.promote_joins(join_list, unconditional=True) 64 66 if self.reuse is not None: 65 67 self.reuse.update(join_list) 66 68 for t in targets: … … class SQLEvaluator(object): 80 82 for child in node.children: 81 83 if hasattr(child, 'evaluate'): 82 84 sql, params = child.evaluate(self, qn, connection) 85 if isinstance(sql, tuple): 86 expression_params.extend(sql[1]) 87 sql = sql[0] 83 88 else: 84 89 sql, params = '%s', (child,) 85 90 -
django/db/models/sql/query.py
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 93a0b52..dc0072f 100644
a b class Query(object): 966 966 Adds a single aggregate expression to the Query 967 967 """ 968 968 opts = model._meta 969 field_list = aggregate.lookup.split(LOOKUP_SEP) 970 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 971 # Aggregate is over an annotation 972 field_name = field_list[0] 973 col = field_name 974 source = self.aggregates[field_name] 975 if not is_summary: 976 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 977 aggregate.name, field_name, field_name)) 978 elif ((len(field_list) > 1) or 979 (field_list[0] not in [i.name for i in opts.fields]) or 980 self.group_by is None or 981 not is_summary): 982 # If: 983 # - the field descriptor has more than one part (foo__bar), or 984 # - the field descriptor is referencing an m2m/m2o field, or 985 # - this is a reference to a model field (possibly inherited), or 986 # - this is an annotation over a model field 987 # then we need to explore the joins that are required. 988 989 field, sources, opts, join_list, path = self.setup_joins( 990 field_list, opts, self.get_initial_alias()) 991 992 # Process the join chain to see if it can be trimmed 993 targets, _, join_list = self.trim_joins(sources, join_list, path) 994 995 # If the aggregate references a model or field that requires a join, 996 # those joins must be LEFT OUTER - empty join rows must be returned 997 # in order for zeros to be returned for those aggregates. 998 self.promote_joins(join_list) 999 1000 col = targets[0].column 1001 source = sources[0] 1002 col = (join_list[-1], col) 969 only = aggregate.only 970 if hasattr(aggregate.lookup, 'evaluate'): 971 # If lookup is a query expression, evaluate it 972 col = SQLEvaluator(aggregate.lookup, self, promote_joins=True) 973 source = opts.get_field(col) 1003 974 else: 1004 # The simplest cases. No joins required - 1005 # just reference the provided column alias. 1006 field_name = field_list[0] 1007 source = opts.get_field(field_name) 1008 col = field_name 975 field_list = aggregate.lookup.split(LOOKUP_SEP) 976 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 977 # Aggregate is over an annotation 978 field_name = field_list[0] 979 col = field_name 980 source = self.aggregates[field_name] 981 if not is_summary: 982 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 983 aggregate.name, field_name, field_name)) 984 elif ((len(field_list) > 1) or 985 (field_list[0] not in [i.name for i in opts.fields]) or 986 self.group_by is None or 987 not is_summary): 988 # If: 989 # - the field descriptor has more than one part (foo__bar), or 990 # - the field descriptor is referencing an m2m/m2o field, or 991 # - this is a reference to a model field (possibly inherited), or 992 # - this is an annotation over a model field 993 # then we need to explore the joins that are required. 994 995 field, sources, opts, join_list, path = self.setup_joins( 996 field_list, opts, self.get_initial_alias()) 997 998 # Process the join chain to see if it can be trimmed 999 targets, _, join_list = self.trim_joins(sources, join_list, path) 1000 1001 # If the aggregate references a model or field that requires a join, 1002 # those joins must be LEFT OUTER - empty join rows must be returned 1003 # in order for zeros to be returned for those aggregates. 1004 self.promote_joins(join_list) 1005 1006 col = targets[0].column 1007 source = sources[0] 1008 col = (join_list[-1], col) 1009 else: 1010 # The simplest cases. No joins required - 1011 # just reference the provided column alias. 1012 field_name = field_list[0] 1013 source = opts.get_field(field_name) 1014 col = field_name 1009 1015 # We want to have the alias in SELECT clause even if mask is set. 1010 1016 self.append_aggregate_mask([alias]) 1011 1017 1018 if only: 1019 original_where = self.where 1020 original_having = self.having 1021 aggregate.condition = self.where_class() 1022 self.where = aggregate.condition 1023 self.having = self.where_class() 1024 original_alias_map = self.alias_map.keys()[:] 1025 self.add_q(only, used_aliases=set(original_alias_map)) 1026 if original_alias_map != self.alias_map.keys(): 1027 raise FieldError("Aggregate's only condition can not require additional joins, " 1028 "Original joins: %s, joins after: %s" 1029 % (original_alias_map, self.alias_map.keys())) 1030 if self.having.children: 1031 raise FieldError("Aggregate's only condition can not reference annotated fields") 1032 self.having = original_having 1033 self.where = original_where 1034 1012 1035 # Add the aggregate to the query 1013 1036 aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) 1014 1037 -
django/db/models/sql/where.py
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 7b71580..d06d5b0 100644
a b class WhereNode(tree.Node): 175 175 it. 176 176 """ 177 177 lvalue, lookup_type, value_annotation, params_or_value = child 178 additional_params = [] 178 179 field_internal_type = lvalue.field.get_internal_type() if lvalue.field else None 179 180 180 181 if isinstance(lvalue, Constraint): … … class WhereNode(tree.Node): 194 195 else: 195 196 # A smart object with an as_sql() method. 196 197 field_sql, field_params = lvalue.as_sql(qn, connection) 198 if isinstance(field_sql, tuple): 199 # It also returned params 200 additional_params.extend(field_sql[1]) 201 field_sql = field_sql[0] 197 202 198 203 is_datetime_field = value_annotation is datetime.datetime 199 204 cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' … … class WhereNode(tree.Node): 201 206 if hasattr(params, 'as_sql'): 202 207 extra, params = params.as_sql(qn, connection) 203 208 cast_sql = '' 209 if isinstance(extra, tuple): 210 params = params + tuple(extra[1]) 211 extra = extra[0] 204 212 else: 205 213 extra = '' 206 214 … … class WhereNode(tree.Node): 211 219 lookup_type = 'isnull' 212 220 value_annotation = True 213 221 222 additional_params.extend(params) 223 params = additional_params 224 214 225 if lookup_type in connection.operators: 215 226 format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) 216 227 return (format % (field_sql, -
tests/aggregation/tests.py
diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index ce7f4e9..a55d22f 100644
a b import re 6 6 7 7 from django.db import connection 8 8 from django.db.models import Avg, Sum, Count, Max, Min 9 from django.db.models import Q, F 10 from django.core.exceptions import FieldError 9 11 from django.test import TestCase, Approximate 10 12 from django.test.utils import CaptureQueriesContext 11 13 … … class BaseAggregateTestCase(TestCase): 21 23 def test_single_aggregate(self): 22 24 vals = Author.objects.aggregate(Avg("age")) 23 25 self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) 26 vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29))) 27 self.assertEqual(vals, {"age__sum": 254}) 28 vals = Author.objects.extra(select={'testparams': 'age < %s'}, select_params=[0])\ 29 .aggregate(Sum("age", only=Q(age__gt=29))) 30 self.assertEqual(vals, {"age__sum": 254}) 31 vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco') | Q(name__icontains='adrian'))) 32 self.assertEqual(vals, {"age__sum": 69}) 24 33 25 34 def test_multiple_aggregates(self): 26 35 vals = Author.objects.aggregate(Sum("age"), Avg("age")) 27 36 self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) 37 vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age")) 38 self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)}) 28 39 29 40 def test_filter_aggregate(self): 30 41 vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) 31 42 self.assertEqual(len(vals), 1) 32 43 self.assertEqual(vals["age__sum"], 254) 44 vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29))) 45 # If there are no matching aggregates, then None, not 0 is the answer. 46 self.assertEqual(vals["age__sum"], None) 33 47 34 48 def test_related_aggregate(self): 35 49 vals = Author.objects.aggregate(Avg("friends__age")) … … class BaseAggregateTestCase(TestCase): 52 66 self.assertEqual(len(vals), 1) 53 67 self.assertEqual(vals["book__price__sum"], Decimal("270.27")) 54 68 69 vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29))) 70 self.assertEqual(len(vals), 1) 71 self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2) 72 vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age")) 73 self.assertEqual(vals, vals2) 74 75 vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35))) 76 self.assertEqual(len(vals), 1) 77 self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2) 78 79 # The average age of author's friends, whose age is lower than the authors age. 80 vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age')))) 81 self.assertEqual(len(vals), 1) 82 self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2) 83 55 84 def test_aggregate_multi_join(self): 56 85 vals = Store.objects.aggregate(Max("books__authors__age")) 57 86 self.assertEqual(len(vals), 1) 58 87 self.assertEqual(vals["books__authors__age__max"], 57) 59 88 89 vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56))) 90 self.assertEqual(len(vals), 1) 91 self.assertEqual(vals["books__authors__age__max"], 46) 92 60 93 vals = Author.objects.aggregate(Min("book__publisher__num_awards")) 61 94 self.assertEqual(len(vals), 1) 62 95 self.assertEqual(vals["book__publisher__num_awards__min"], 1) … … class BaseAggregateTestCase(TestCase): 87 120 ) 88 121 self.assertEqual(b.mean_age, 34.5) 89 122 123 # Test extra-select 124 books = Book.objects.annotate(mean_age=Avg("authors__age")) 125 books = books.annotate(mean_age2=Avg('authors__age', only=Q(authors__age__gte=0))) 126 books = books.extra(select={'testparams': 'publisher_id = %s'}, select_params=[1]) 127 b = books.get(pk=1) 128 self.assertEqual(b.mean_age, 34.5) 129 self.assertEqual(b.mean_age2, 34.5) 130 self.assertEqual(b.testparams, True) 131 132 # Test relabel_aliases 133 excluded_authors = Author.objects.annotate(book_rating=Min(F('book__rating') + 5, only=Q(pk__gte=1))) 134 excluded_authors = excluded_authors.filter(book_rating__lt=0) 135 books = books.exclude(authors__in=excluded_authors) 136 b = books.get(pk=1) 137 self.assertEqual(b.mean_age, 34.5) 138 139 # Test joins in F-based annotation 140 books = Book.objects.annotate(oldest=Max(F('authors__age'))) 141 books = books.values_list('rating', 'oldest').order_by('rating', 'oldest') 142 self.assertEqual( 143 list(books), 144 [(3.0, 45), (4.0, 29), (4.0, 37), (4.0, 57), (4.5, 35), (5.0, 57)] 145 ) 146 147 publishers = Publisher.objects.annotate(avg_rating=Avg(F('book__rating') - 0)) 148 publishers = publishers.values_list('id', 'avg_rating').order_by('id') 149 self.assertEqual(list(publishers), [(1, 4.25), (2, 3.0), (3, 4.0), (4, 5.0), (5, None)]) 150 90 151 def test_annotate_m2m(self): 91 152 books = Book.objects.filter(rating__lt=4.5).annotate(Avg("authors__age")).order_by("name") 92 153 self.assertQuerysetEqual( … … class BaseAggregateTestCase(TestCase): 99 160 lambda b: (b.name, b.authors__age__avg), 100 161 ) 101 162 163 def raises_exception(): 164 list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name")) 165 166 self.assertRaises(FieldError, raises_exception) 167 102 168 books = Book.objects.annotate(num_authors=Count("authors")).order_by("name") 103 169 self.assertQuerysetEqual( 104 170 books, [ … … class BaseAggregateTestCase(TestCase): 169 235 ) 170 236 171 237 def test_annotate_values(self): 238 books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age') 239 self.assertEqual( 240 list(books), [ 241 { 242 "pk": 1, 243 "isbn": "159059725", 244 "mean_age": 34.0, 245 } 246 ] 247 ) 172 248 books = list(Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values()) 173 249 self.assertEqual( 174 250 books, [ … … class BaseAggregateTestCase(TestCase): 275 351 vals = Book.objects.aggregate(Count("rating", distinct=True)) 276 352 self.assertEqual(vals, {"rating__count": 4}) 277 353 354 vals = Book.objects.aggregate( 355 low_count=Count("rating", only=Q(rating__lt=4)), 356 high_count=Count("rating", only=Q(rating__gte=4)) 357 ) 358 self.assertEqual(vals, {"low_count": 1, 'high_count': 5}) 359 vals = Book.objects.aggregate( 360 low_count=Count("rating", distinct=True, only=Q(rating__lt=4)), 361 high_count=Count("rating", distinct=True, only=Q(rating__gte=4)) 362 ) 363 self.assertEqual(vals, {"low_count": 1, 'high_count': 3}) 364 278 365 def test_fkey_aggregate(self): 279 366 explicit = list(Author.objects.annotate(Count('book__id'))) 280 367 implicit = list(Author.objects.annotate(Count('book'))) … … class BaseAggregateTestCase(TestCase): 394 481 lambda p: p.name, 395 482 ) 396 483 484 publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk") 485 self.assertQuerysetEqual( 486 publishers, [ 487 "Expensive Publisher", 488 ], 489 lambda p: p.name, 490 ) 491 397 492 publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") 398 493 self.assertQuerysetEqual( 399 494 publishers, [