Ticket #11305: 11305-2013-09-07-master.patch

File 11305-2013-09-07-master.patch, 23.9 KB (added by Garry Polley, 11 years ago)

patch_update_for_Django_1.6-1.7

  • django/db/models/aggregates.py

    diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
    index 1db3890..9806d34 100644
    a b class Aggregate(object):  
    2121    """
    2222    Default Aggregate definition.
    2323    """
    24     def __init__(self, lookup, **extra):
     24    def __init__(self, lookup, only=None, **extra):
    2525        """Instantiate a new aggregate.
    2626
    2727         * lookup is the field on which the aggregate operates.
     28         * only is a Q-object used in conditional aggregation.
    2829         * extra is a dictionary of additional data to provide for the
    2930           aggregate definition
    3031
    class Aggregate(object):  
    3334        """
    3435        self.lookup = lookup
    3536        self.extra = extra
     37        self.only = only
     38        self.condition = None
    3639
    3740    def _default_alias(self):
     41        if hasattr(self.lookup, 'evaluate'):
     42            raise ValueError('When aggregating over an expression, you need to give an alias.')
    3843        return '%s__%s' % (self.lookup, self.name.lower())
    3944    default_alias = property(_default_alias)
    4045
    class Aggregate(object):  
    5762           summary value rather than an annotation.
    5863        """
    5964        klass = getattr(query.aggregates_module, self.name)
    60         aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
     65        aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra)
    6166        query.aggregates[alias] = aggregate
    6267
    6368
  • django/db/models/expressions.py

    diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
    index b4eea5f..bdb8ec5 100644
    a b class ExpressionNode(tree.Node):  
    5757    # VISITOR METHODS #
    5858    ###################
    5959
    60     def prepare(self, evaluator, query, allow_joins):
    61         return evaluator.prepare_node(self, query, allow_joins)
     60    def prepare(self, evaluator, query, allow_joins, promote_joins=False):
     61        return evaluator.prepare_node(self, query, allow_joins, promote_joins)
    6262
    6363    def evaluate(self, evaluator, qn, connection):
    6464        return evaluator.evaluate_node(self, qn, connection)
    class F(ExpressionNode):  
    143143        obj.name = self.name
    144144        return obj
    145145
    146     def prepare(self, evaluator, query, allow_joins):
    147         return evaluator.prepare_leaf(self, query, allow_joins)
     146    def prepare(self, evaluator, query, allow_joins, promote_joins=False):
     147        return evaluator.prepare_leaf(self, query, allow_joins, promote_joins)
    148148
    149149    def evaluate(self, evaluator, qn, connection):
    150150        return evaluator.evaluate_leaf(self, qn, connection)
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index 3cda4d2..8c1b0e0 100644
    a b Classes to represent the default SQL aggregate functions  
    44import copy
    55
    66from django.db.models.fields import IntegerField, FloatField
     7from django.db.models.sql.expressions import SQLEvaluator
    78
    89# Fake fields used to identify aggregate types in data-conversion operations.
    910ordinal_aggregate_field = IntegerField()
    class Aggregate(object):  
    1718    is_ordinal = False
    1819    is_computed = False
    1920    sql_template = '%(function)s(%(field)s)'
     21    conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END"
    2022
    21     def __init__(self, col, source=None, is_summary=False, **extra):
     23    def __init__(self, col, source=None, is_summary=False, condition=None, **extra):
    2224        """Instantiate an SQL aggregate
    2325
    2426         * col is a column reference describing the subject field
    class Aggregate(object):  
    2830           the column reference. If the aggregate is not an ordinal or
    2931           computed type, this reference is used to determine the coerced
    3032           output type of the aggregate.
     33         * condition is used in conditional aggregation.
    3134         * extra is a dictionary of additional data to provide for the
    3235           aggregate definition
    3336
    class Aggregate(object):  
    4750        self.source = source
    4851        self.is_summary = is_summary
    4952        self.extra = extra
     53        self.condition = condition
    5054
    5155        # Follow the chain of aggregate sources back until you find an
    5256        # actual field, or an aggregate that forces a particular output
    class Aggregate(object):  
    6872        clone = copy.copy(self)
    6973        if isinstance(self.col, (list, tuple)):
    7074            clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
     75        else:
     76            clone.col.relabel_aliases(change_map)
     77        if clone.condition:
     78            clone.condition.relabel_aliases(change_map)
    7179        return clone
    7280
    7381    def as_sql(self, qn, connection):
    7482        "Return the aggregate, rendered as SQL with parameters."
    75         params = []
     83        condition_params = []
     84        col_params = []
    7685
    7786        if hasattr(self.col, 'as_sql'):
    78             field_name, params = self.col.as_sql(qn, connection)
     87            if isinstance(self.col, SQLEvaluator):
     88                field_name, col_params = self.col.as_sql(qn, connection)
     89            else:
     90                field_name = self.col.as_sql(qn, connection)
    7991        elif isinstance(self.col, (list, tuple)):
    8092            field_name = '.'.join(qn(c) for c in self.col)
    8193        else:
    8294            field_name = self.col
    8395
    84         substitutions = {
    85             'function': self.sql_function,
    86             'field': field_name
    87         }
     96        if self.condition:
     97            condition, condition_params = self.condition.as_sql(qn, connection)
     98            conditional_field = self.conditional_template % {
     99                'condition': condition,
     100                'field_name': field_name
     101            }
     102            substitutions = {
     103                'function': self.sql_function,
     104                'field': conditional_field,
     105            }
     106        else:
     107            substitutions = {
     108                'function': self.sql_function,
     109                'field': field_name
     110            }
     111
    88112        substitutions.update(self.extra)
    89113
    90         return self.sql_template % substitutions, params
     114        return (self.sql_template % substitutions, condition_params)
    91115
    92116
    93117class Avg(Aggregate):
  • django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    index f9a8929..37608fa 100644
    a b from django.db.models.fields import FieldDoesNotExist  
    66
    77
    88class SQLEvaluator(object):
    9     def __init__(self, expression, query, allow_joins=True, reuse=None):
     9    def __init__(self, expression, query, allow_joins=True, reuse=None, promote_joins=False):
    1010        self.expression = expression
    1111        self.opts = query.get_meta()
    1212        self.reuse = reuse
    1313        self.cols = []
    14         self.expression.prepare(self, query, allow_joins)
     14        self.expression.prepare(self, query, allow_joins, promote_joins)
    1515
    1616    def relabeled_clone(self, change_map):
    1717        clone = copy.copy(self)
    class SQLEvaluator(object):  
    4343    # Vistor methods for initial expression preparation #
    4444    #####################################################
    4545
    46     def prepare_node(self, node, query, allow_joins):
     46    def prepare_node(self, node, query, allow_joins, promote_joins):
    4747        for child in node.children:
    4848            if hasattr(child, 'prepare'):
    49                 child.prepare(self, query, allow_joins)
     49                child.prepare(self, query, allow_joins, promote_joins)
    5050
    51     def prepare_leaf(self, node, query, allow_joins):
     51    def prepare_leaf(self, node, query, allow_joins, promote_joins):
    5252        if not allow_joins and LOOKUP_SEP in node.name:
    5353            raise FieldError("Joined field references are not permitted in this query")
    5454
    class SQLEvaluator(object):  
    6161                    field_list, query.get_meta(),
    6262                    query.get_initial_alias(), self.reuse)
    6363                targets, _, join_list = query.trim_joins(sources, join_list, path)
     64                if promote_joins:
     65                    query.promote_joins(join_list, unconditional=True)
    6466                if self.reuse is not None:
    6567                    self.reuse.update(join_list)
    6668                for t in targets:
    class SQLEvaluator(object):  
    8082        for child in node.children:
    8183            if hasattr(child, 'evaluate'):
    8284                sql, params = child.evaluate(self, qn, connection)
     85                if isinstance(sql, tuple):
     86                    expression_params.extend(sql[1])
     87                    sql = sql[0]
    8388            else:
    8489                sql, params = '%s', (child,)
    8590
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index 93a0b52..dc0072f 100644
    a b class Query(object):  
    966966        Adds a single aggregate expression to the Query
    967967        """
    968968        opts = model._meta
    969         field_list = aggregate.lookup.split(LOOKUP_SEP)
    970         if len(field_list) == 1 and aggregate.lookup in self.aggregates:
    971             # Aggregate is over an annotation
    972             field_name = field_list[0]
    973             col = field_name
    974             source = self.aggregates[field_name]
    975             if not is_summary:
    976                 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
    977                     aggregate.name, field_name, field_name))
    978         elif ((len(field_list) > 1) or
    979             (field_list[0] not in [i.name for i in opts.fields]) or
    980             self.group_by is None or
    981             not is_summary):
    982             # If:
    983             #   - the field descriptor has more than one part (foo__bar), or
    984             #   - the field descriptor is referencing an m2m/m2o field, or
    985             #   - this is a reference to a model field (possibly inherited), or
    986             #   - this is an annotation over a model field
    987             # then we need to explore the joins that are required.
    988 
    989             field, sources, opts, join_list, path = self.setup_joins(
    990                 field_list, opts, self.get_initial_alias())
    991 
    992             # Process the join chain to see if it can be trimmed
    993             targets, _, join_list = self.trim_joins(sources, join_list, path)
    994 
    995             # If the aggregate references a model or field that requires a join,
    996             # those joins must be LEFT OUTER - empty join rows must be returned
    997             # in order for zeros to be returned for those aggregates.
    998             self.promote_joins(join_list)
    999 
    1000             col = targets[0].column
    1001             source = sources[0]
    1002             col = (join_list[-1], col)
     969        only = aggregate.only
     970        if hasattr(aggregate.lookup, 'evaluate'):
     971            # If lookup is a query expression, evaluate it
     972            col = SQLEvaluator(aggregate.lookup, self, promote_joins=True)
     973            source = opts.get_field(col)
    1003974        else:
    1004             # The simplest cases. No joins required -
    1005             # just reference the provided column alias.
    1006             field_name = field_list[0]
    1007             source = opts.get_field(field_name)
    1008             col = field_name
     975            field_list = aggregate.lookup.split(LOOKUP_SEP)
     976            if len(field_list) == 1 and aggregate.lookup in self.aggregates:
     977                # Aggregate is over an annotation
     978                field_name = field_list[0]
     979                col = field_name
     980                source = self.aggregates[field_name]
     981                if not is_summary:
     982                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
     983                        aggregate.name, field_name, field_name))
     984            elif ((len(field_list) > 1) or
     985                (field_list[0] not in [i.name for i in opts.fields]) or
     986                self.group_by is None or
     987                not is_summary):
     988                # If:
     989                #   - the field descriptor has more than one part (foo__bar), or
     990                #   - the field descriptor is referencing an m2m/m2o field, or
     991                #   - this is a reference to a model field (possibly inherited), or
     992                #   - this is an annotation over a model field
     993                # then we need to explore the joins that are required.
     994
     995                field, sources, opts, join_list, path = self.setup_joins(
     996                    field_list, opts, self.get_initial_alias())
     997
     998                # Process the join chain to see if it can be trimmed
     999                targets, _, join_list = self.trim_joins(sources, join_list, path)
     1000
     1001                # If the aggregate references a model or field that requires a join,
     1002                # those joins must be LEFT OUTER - empty join rows must be returned
     1003                # in order for zeros to be returned for those aggregates.
     1004                self.promote_joins(join_list)
     1005
     1006                col = targets[0].column
     1007                source = sources[0]
     1008                col = (join_list[-1], col)
     1009            else:
     1010                # The simplest cases. No joins required -
     1011                # just reference the provided column alias.
     1012                field_name = field_list[0]
     1013                source = opts.get_field(field_name)
     1014                col = field_name
    10091015        # We want to have the alias in SELECT clause even if mask is set.
    10101016        self.append_aggregate_mask([alias])
    10111017
     1018        if only:
     1019            original_where = self.where
     1020            original_having = self.having
     1021            aggregate.condition = self.where_class()
     1022            self.where = aggregate.condition
     1023            self.having = self.where_class()
     1024            original_alias_map = self.alias_map.keys()[:]
     1025            self.add_q(only, used_aliases=set(original_alias_map))
     1026            if original_alias_map != self.alias_map.keys():
     1027                raise FieldError("Aggregate's only condition can not require additional joins, "
     1028                                 "Original joins: %s, joins after: %s"
     1029                                 % (original_alias_map, self.alias_map.keys()))
     1030            if self.having.children:
     1031                raise FieldError("Aggregate's only condition can not reference annotated fields")
     1032            self.having = original_having
     1033            self.where = original_where
     1034
    10121035        # Add the aggregate to the query
    10131036        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
    10141037
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 7b71580..d06d5b0 100644
    a b class WhereNode(tree.Node):  
    175175        it.
    176176        """
    177177        lvalue, lookup_type, value_annotation, params_or_value = child
     178        additional_params = []
    178179        field_internal_type = lvalue.field.get_internal_type() if lvalue.field else None
    179180
    180181        if isinstance(lvalue, Constraint):
    class WhereNode(tree.Node):  
    194195        else:
    195196            # A smart object with an as_sql() method.
    196197            field_sql, field_params = lvalue.as_sql(qn, connection)
     198            if isinstance(field_sql, tuple):
     199                # It also returned params
     200                additional_params.extend(field_sql[1])
     201                field_sql = field_sql[0]
    197202
    198203        is_datetime_field = value_annotation is datetime.datetime
    199204        cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
    class WhereNode(tree.Node):  
    201206        if hasattr(params, 'as_sql'):
    202207            extra, params = params.as_sql(qn, connection)
    203208            cast_sql = ''
     209            if isinstance(extra, tuple):
     210                params = params + tuple(extra[1])
     211                extra = extra[0]
    204212        else:
    205213            extra = ''
    206214
    class WhereNode(tree.Node):  
    211219            lookup_type = 'isnull'
    212220            value_annotation = True
    213221
     222        additional_params.extend(params)
     223        params = additional_params
     224
    214225        if lookup_type in connection.operators:
    215226            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
    216227            return (format % (field_sql,
  • tests/aggregation/tests.py

    diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py
    index ce7f4e9..a55d22f 100644
    a b import re  
    66
    77from django.db import connection
    88from django.db.models import Avg, Sum, Count, Max, Min
     9from django.db.models import Q, F
     10from django.core.exceptions import FieldError
    911from django.test import TestCase, Approximate
    1012from django.test.utils import CaptureQueriesContext
    1113
    class BaseAggregateTestCase(TestCase):  
    2123    def test_single_aggregate(self):
    2224        vals = Author.objects.aggregate(Avg("age"))
    2325        self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)})
     26        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)))
     27        self.assertEqual(vals, {"age__sum": 254})
     28        vals = Author.objects.extra(select={'testparams': 'age < %s'}, select_params=[0])\
     29               .aggregate(Sum("age", only=Q(age__gt=29)))
     30        self.assertEqual(vals, {"age__sum": 254})
     31        vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco') | Q(name__icontains='adrian')))
     32        self.assertEqual(vals, {"age__sum": 69})
    2433
    2534    def test_multiple_aggregates(self):
    2635        vals = Author.objects.aggregate(Sum("age"), Avg("age"))
    2736        self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)})
     37        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age"))
     38        self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)})
    2839
    2940    def test_filter_aggregate(self):
    3041        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age"))
    3142        self.assertEqual(len(vals), 1)
    3243        self.assertEqual(vals["age__sum"], 254)
     44        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29)))
     45        # If there are no matching aggregates, then None, not 0 is the answer.
     46        self.assertEqual(vals["age__sum"], None)
    3347
    3448    def test_related_aggregate(self):
    3549        vals = Author.objects.aggregate(Avg("friends__age"))
    class BaseAggregateTestCase(TestCase):  
    5266        self.assertEqual(len(vals), 1)
    5367        self.assertEqual(vals["book__price__sum"], Decimal("270.27"))
    5468
     69        vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29)))
     70        self.assertEqual(len(vals), 1)
     71        self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2)
     72        vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age"))
     73        self.assertEqual(vals, vals2)
     74
     75        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35)))
     76        self.assertEqual(len(vals), 1)
     77        self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2)
     78
     79        # The average age of author's friends, whose age is lower than the authors age.
     80        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age'))))
     81        self.assertEqual(len(vals), 1)
     82        self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2)
     83
    5584    def test_aggregate_multi_join(self):
    5685        vals = Store.objects.aggregate(Max("books__authors__age"))
    5786        self.assertEqual(len(vals), 1)
    5887        self.assertEqual(vals["books__authors__age__max"], 57)
    5988
     89        vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56)))
     90        self.assertEqual(len(vals), 1)
     91        self.assertEqual(vals["books__authors__age__max"], 46)
     92
    6093        vals = Author.objects.aggregate(Min("book__publisher__num_awards"))
    6194        self.assertEqual(len(vals), 1)
    6295        self.assertEqual(vals["book__publisher__num_awards__min"], 1)
    class BaseAggregateTestCase(TestCase):  
    87120        )
    88121        self.assertEqual(b.mean_age, 34.5)
    89122
     123        # Test extra-select
     124        books = Book.objects.annotate(mean_age=Avg("authors__age"))
     125        books = books.annotate(mean_age2=Avg('authors__age', only=Q(authors__age__gte=0)))
     126        books = books.extra(select={'testparams': 'publisher_id = %s'}, select_params=[1])
     127        b = books.get(pk=1)
     128        self.assertEqual(b.mean_age, 34.5)
     129        self.assertEqual(b.mean_age2, 34.5)
     130        self.assertEqual(b.testparams, True)
     131
     132        # Test relabel_aliases
     133        excluded_authors = Author.objects.annotate(book_rating=Min(F('book__rating') + 5, only=Q(pk__gte=1)))
     134        excluded_authors = excluded_authors.filter(book_rating__lt=0)
     135        books = books.exclude(authors__in=excluded_authors)
     136        b = books.get(pk=1)
     137        self.assertEqual(b.mean_age, 34.5)
     138
     139        # Test joins in F-based annotation
     140        books = Book.objects.annotate(oldest=Max(F('authors__age')))
     141        books = books.values_list('rating', 'oldest').order_by('rating', 'oldest')
     142        self.assertEqual(
     143            list(books),
     144            [(3.0, 45), (4.0, 29), (4.0, 37), (4.0, 57), (4.5, 35), (5.0, 57)]
     145        )
     146
     147        publishers = Publisher.objects.annotate(avg_rating=Avg(F('book__rating') - 0))
     148        publishers = publishers.values_list('id', 'avg_rating').order_by('id')
     149        self.assertEqual(list(publishers), [(1, 4.25), (2, 3.0), (3, 4.0), (4, 5.0), (5, None)])
     150
    90151    def test_annotate_m2m(self):
    91152        books = Book.objects.filter(rating__lt=4.5).annotate(Avg("authors__age")).order_by("name")
    92153        self.assertQuerysetEqual(
    class BaseAggregateTestCase(TestCase):  
    99160            lambda b: (b.name, b.authors__age__avg),
    100161        )
    101162
     163        def raises_exception():
     164            list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name"))
     165
     166        self.assertRaises(FieldError, raises_exception)
     167
    102168        books = Book.objects.annotate(num_authors=Count("authors")).order_by("name")
    103169        self.assertQuerysetEqual(
    104170            books, [
    class BaseAggregateTestCase(TestCase):  
    169235        )
    170236
    171237    def test_annotate_values(self):
     238        books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age')
     239        self.assertEqual(
     240            list(books), [
     241                {
     242                    "pk": 1,
     243                    "isbn": "159059725",
     244                    "mean_age": 34.0,
     245                }
     246            ]
     247        )
    172248        books = list(Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values())
    173249        self.assertEqual(
    174250            books, [
    class BaseAggregateTestCase(TestCase):  
    275351        vals = Book.objects.aggregate(Count("rating", distinct=True))
    276352        self.assertEqual(vals, {"rating__count": 4})
    277353
     354        vals = Book.objects.aggregate(
     355                low_count=Count("rating", only=Q(rating__lt=4)),
     356                high_count=Count("rating", only=Q(rating__gte=4))
     357        )
     358        self.assertEqual(vals, {"low_count": 1, 'high_count': 5})
     359        vals = Book.objects.aggregate(
     360            low_count=Count("rating", distinct=True, only=Q(rating__lt=4)),
     361            high_count=Count("rating", distinct=True, only=Q(rating__gte=4))
     362        )
     363        self.assertEqual(vals, {"low_count": 1, 'high_count': 3})
     364
    278365    def test_fkey_aggregate(self):
    279366        explicit = list(Author.objects.annotate(Count('book__id')))
    280367        implicit = list(Author.objects.annotate(Count('book')))
    class BaseAggregateTestCase(TestCase):  
    394481            lambda p: p.name,
    395482        )
    396483
     484        publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk")
     485        self.assertQuerysetEqual(
     486            publishers, [
     487                "Expensive Publisher",
     488            ],
     489            lambda p: p.name,
     490        )
     491
    397492        publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk")
    398493        self.assertQuerysetEqual(
    399494            publishers, [
Back to Top