Ticket #11305: poc_11305.2.patch

File poc_11305.2.patch, 17.9 KB (added by Anssi Kääriäinen, 13 years ago)
  • django/db/models/aggregates.py

    diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
    index a2349cf..3d8b63a5 100644
    a b class Aggregate(object):  
    66    """
    77    Default Aggregate definition.
    88    """
    9     def __init__(self, lookup, **extra):
     9    def __init__(self, lookup, only=None, **extra):
    1010        """Instantiate a new aggregate.
    1111
    1212         * lookup is the field on which the aggregate operates.
     13         * only is a Q-object used in conditional aggregation.
    1314         * extra is a dictionary of additional data to provide for the
    1415           aggregate definition
    1516
    class Aggregate(object):  
    1819        """
    1920        self.lookup = lookup
    2021        self.extra = extra
     22        self.only = only
     23        self.condition = None
    2124
    2225    def _default_alias(self):
    2326        return '%s__%s' % (self.lookup, self.name.lower())
    class Aggregate(object):  
    4245           summary value rather than an annotation.
    4346        """
    4447        klass = getattr(query.aggregates_module, self.name)
    45         aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
     48        aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra)
    4649        query.aggregates[alias] = aggregate
    4750
    4851class Avg(Aggregate):
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index 207bc0c..0de19a3 100644
    a b class Aggregate(object):  
    2222    is_ordinal = False
    2323    is_computed = False
    2424    sql_template = '%(function)s(%(field)s)'
     25    conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END"
    2526
    26     def __init__(self, col, source=None, is_summary=False, **extra):
     27    def __init__(self, col, source=None, is_summary=False, condition=None, **extra):
    2728        """Instantiate an SQL aggregate
    2829
    2930         * col is a column reference describing the subject field
    class Aggregate(object):  
    5253        self.source = source
    5354        self.is_summary = is_summary
    5455        self.extra = extra
     56        self.condition = condition
    5557
    5658        # Follow the chain of aggregate sources back until you find an
    5759        # actual field, or an aggregate that forces a particular output
    class Aggregate(object):  
    7678    def as_sql(self, qn, connection):
    7779        "Return the aggregate, rendered as SQL."
    7880
     81        query_params = []
    7982        if hasattr(self.col, 'as_sql'):
    8083            field_name = self.col.as_sql(qn, connection)
    8184        elif isinstance(self.col, (list, tuple)):
    8285            field_name = '.'.join([qn(c) for c in self.col])
    8386        else:
    8487            field_name = self.col
    85 
    86         params = {
    87             'function': self.sql_function,
    88             'field': field_name
    89         }
     88        if self.condition:
     89            condition = self.condition.as_sql(qn, connection)
     90            query_params = condition[1]
     91            conditional_field = self.conditional_template % {
     92                'condition': condition[0],
     93                'field_name': field_name
     94            }
     95            params = {
     96                'function': self.sql_function,
     97                'field': conditional_field,
     98            }
     99        else:
     100            params = {
     101                'function': self.sql_function,
     102                'field': field_name
     103            }
    90104        params.update(self.extra)
    91105
    92         return self.sql_template % params
     106        return (self.sql_template % params, query_params)
    93107
    94108
    95109class Avg(Aggregate):
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 841ec12..a0d2f35 100644
    a b class SQLCompiler(object):  
    5858            return '', ()
    5959
    6060        self.pre_sql_setup()
    61         out_cols = self.get_columns(with_col_aliases)
     61        out_cols, c_params = self.get_columns(with_col_aliases)
    6262        ordering, ordering_group_by = self.get_ordering()
     63        params = []
     64        params.extend(c_params)
    6365
    6466        # This must come after 'select' and 'ordering' -- see docstring of
    6567        # get_from_clause() for details.
    class SQLCompiler(object):  
    6971
    7072        where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
    7173        having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
    72         params = []
    7374        for val in self.query.extra_select.itervalues():
    7475            params.extend(val[1])
    7576
    class SQLCompiler(object):  
    126127            if nowait and not self.connection.features.has_select_for_update_nowait:
    127128                raise DatabaseError('NOWAIT is not supported on this database backend.')
    128129            result.append(self.connection.ops.for_update_sql(nowait=nowait))
    129 
    130130        return ' '.join(result), tuple(params)
    131131
    132132    def as_nested_sql(self):
    class SQLCompiler(object):  
    158158        qn = self.quote_name_unless_alias
    159159        qn2 = self.connection.ops.quote_name
    160160        result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()]
     161        query_params = []
    161162        aliases = set(self.query.extra_select.keys())
    162163        if with_aliases:
    163164            col_aliases = aliases.copy()
    class SQLCompiler(object):  
    200201            aliases.update(new_aliases)
    201202
    202203        max_name_length = self.connection.ops.max_name_length()
    203         result.extend([
    204             '%s%s' % (
    205                 aggregate.as_sql(qn, self.connection),
    206                 alias is not None
    207                     and ' AS %s' % qn(truncate_name(alias, max_name_length))
    208                     or ''
     204        for alias, aggregate in self.query.aggregate_select.items():
     205            sql, params = aggregate.as_sql(qn, self.connection)
     206            result.append(
     207                '%s%s' % (
     208                    sql,
     209                    alias is not None
     210                       and ' AS %s' % qn(truncate_name(alias, max_name_length))
     211                       or ''
     212                )
    209213            )
    210             for alias, aggregate in self.query.aggregate_select.items()
    211         ])
     214            query_params.extend(params)
    212215
    213216        for table, col in self.query.related_select_cols:
    214217            r = '%s.%s' % (qn(table), qn(col))
    class SQLCompiler(object):  
    223226                col_aliases.add(col)
    224227
    225228        self._select_aliases = aliases
    226         return result
     229        return result, query_params
    227230
    228231    def get_default_columns(self, with_aliases=False, col_aliases=None,
    229232            start_alias=None, opts=None, as_pairs=False, local_only=False):
    class SQLAggregateCompiler(SQLCompiler):  
    948951        """
    949952        if qn is None:
    950953            qn = self.quote_name_unless_alias
     954        buf = []
     955        a_params = []
     956        for aggregate in self.query.aggregate_select.values():
     957            sql, query_params = aggregate.as_sql(qn, self.connection)
     958            buf.append(sql)
     959            a_params.extend(query_params)
     960        aggregate_sql = ', '.join(buf)
    951961        sql = ('SELECT %s FROM (%s) subquery' % (
    952             ', '.join([
    953                 aggregate.as_sql(qn, self.connection)
    954                 for aggregate in self.query.aggregate_select.values()
    955             ]),
     962            aggregate_sql, 
    956963            self.query.subquery)
    957964        )
    958         params = self.query.sub_params
     965        params = tuple(a_params) + (self.query.sub_params)
    959966        return (sql, params)
    960967
    961968class SQLDateCompiler(SQLCompiler):
  • django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    index 1bbf742..f9c23a9 100644
    a b class SQLEvaluator(object):  
    6565        for child in node.children:
    6666            if hasattr(child, 'evaluate'):
    6767                sql, params = child.evaluate(self, qn, connection)
     68                if isinstance(sql, tuple):
     69                    expression_params.extend(sql[1])
     70                    sql = sql[0]
    6871            else:
    6972                sql, params = '%s', (child,)
    7073
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index 110e317..7735261 100644
    a b class Query(object):  
    956956        """
    957957        opts = model._meta
    958958        field_list = aggregate.lookup.split(LOOKUP_SEP)
     959        only = aggregate.only
     960        join_list = []
    959961        if len(field_list) == 1 and aggregate.lookup in self.aggregates:
    960962            # Aggregate is over an annotation
    961963            field_name = field_list[0]
    class Query(object):  
    964966            if not is_summary:
    965967                raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
    966968                    aggregate.name, field_name, field_name))
     969            if only:
     970                raise FieldError("Cannot use aggregated fields in conditional aggregates")
    967971        elif ((len(field_list) > 1) or
    968972            (field_list[0] not in [i.name for i in opts.fields]) or
    969973            self.group_by is None or
    class Query(object):  
    995999            source = opts.get_field(field_name)
    9961000            col = field_name
    9971001
     1002        if only:
     1003            original_where = self.where
     1004            original_having = self.having
     1005            aggregate.condition = self.where_class()
     1006            self.where = aggregate.condition
     1007            self.having = self.where_class()
     1008            original_alias_map = self.alias_map.keys()[:]
     1009            self.add_q(only, used_aliases=set(original_alias_map))
     1010            if original_alias_map != self.alias_map.keys():
     1011                raise FieldError("Aggregate's only condition can not require additional joins, Original joins: %s, joins after: %s" % (original_alias_map, self.alias_map.keys()))
     1012            if self.having.children:
     1013                raise FieldError("Aggregate's only condition can not reference annotated fields")
     1014            self.having = original_having
     1015            self.where = original_where
    9981016        # Add the aggregate to the query
    9991017        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
    10001018
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 2427a52..7ac0287 100644
    a b class WhereNode(tree.Node):  
    137137        it.
    138138        """
    139139        lvalue, lookup_type, value_annot, params_or_value = child
     140        additional_params = []
    140141        if hasattr(lvalue, 'process'):
    141142            try:
    142143                lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
    class WhereNode(tree.Node):  
    151152        else:
    152153            # A smart object with an as_sql() method.
    153154            field_sql = lvalue.as_sql(qn, connection)
     155            if isinstance(field_sql, tuple):
     156                # It returned also params
     157                additional_params.extend(field_sql[1])
     158                field_sql = field_sql[0]
    154159
    155160        if value_annot is datetime.datetime:
    156161            cast_sql = connection.ops.datetime_cast_sql()
    class WhereNode(tree.Node):  
    159164
    160165        if hasattr(params, 'as_sql'):
    161166            extra, params = params.as_sql(qn, connection)
     167            if isinstance(extra, tuple):
     168                params = params + tuple(extra[1])
     169                extra = extra[0]
    162170            cast_sql = ''
    163171        else:
    164172            extra = ''
    class WhereNode(tree.Node):  
    168176            lookup_type = 'isnull'
    169177            value_annot = True
    170178
     179        additional_params.extend(params)
     180        params = additional_params
    171181        if lookup_type in connection.operators:
    172182            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
    173183            return (format % (field_sql,
  • tests/modeltests/aggregation/tests.py

    diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
    index 6f68800..3a04fa9 100644
    a b import datetime  
    22from decimal import Decimal
    33
    44from django.db.models import Avg, Sum, Count, Max, Min
     5from django.db.models import Q, F
     6from django.core.exceptions import FieldError
    57from django.test import TestCase, Approximate
    68
    79from models import Author, Publisher, Book, Store
    class BaseAggregateTestCase(TestCase):  
    1618    def test_single_aggregate(self):
    1719        vals = Author.objects.aggregate(Avg("age"))
    1820        self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)})
     21        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)))
     22        self.assertEqual(vals, {"age__sum": 254})
     23        vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian')))
     24        self.assertEqual(vals, {"age__sum": 69})
    1925
    2026    def test_multiple_aggregates(self):
    2127        vals = Author.objects.aggregate(Sum("age"), Avg("age"))
    2228        self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)})
     29        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age"))
     30        self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)})
    2331
    2432    def test_filter_aggregate(self):
    2533        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age"))
    2634        self.assertEqual(len(vals), 1)
    2735        self.assertEqual(vals["age__sum"], 254)
     36        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29)))
     37        # If there are no matching aggregates, then None, not 0 is the answer.
     38        self.assertEqual(vals["age__sum"], None)
    2839
    2940    def test_related_aggregate(self):
    3041        vals = Author.objects.aggregate(Avg("friends__age"))
    3142        self.assertEqual(len(vals), 1)
    3243        self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2)
    3344
     45        vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29)))
     46        self.assertEqual(len(vals), 1)
     47        self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2)
     48        vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age"))
     49        self.assertEqual(vals, vals2)
     50
     51        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35)))
     52        self.assertEqual(len(vals), 1)
     53        self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2)
     54
     55        # The average age of author's friends, whose age is lower than the authors age.
     56        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age'))))
     57        self.assertEqual(len(vals), 1)
     58        self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2)
     59
    3460        vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age"))
    3561        self.assertEqual(len(vals), 1)
    3662        self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2)
    class BaseAggregateTestCase(TestCase):  
    5177        vals = Store.objects.aggregate(Max("books__authors__age"))
    5278        self.assertEqual(len(vals), 1)
    5379        self.assertEqual(vals["books__authors__age__max"], 57)
     80       
     81        vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56)))
     82        self.assertEqual(len(vals), 1)
     83        self.assertEqual(vals["books__authors__age__max"], 46)
    5484
    5585        vals = Author.objects.aggregate(Min("book__publisher__num_awards"))
    5686        self.assertEqual(len(vals), 1)
    class BaseAggregateTestCase(TestCase):  
    106136            ],
    107137            lambda b: (b.name, b.num_authors)
    108138        )
     139       
     140        def raises_exception():
     141            list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name"))
     142
     143        self.assertRaises(FieldError, raises_exception)
    109144
    110145    def test_backwards_m2m_annotate(self):
    111146        authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name")
    class BaseAggregateTestCase(TestCase):  
    192227                }
    193228            ]
    194229        )
     230        books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age')
     231        self.assertEqual(
     232            list(books), [
     233                {
     234                    "pk": 1,
     235                    "isbn": "159059725",
     236                    "mean_age": 34.0,
     237                }
     238            ]
     239        )
    195240
    196241        books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name")
    197242        self.assertEqual(
    class BaseAggregateTestCase(TestCase):  
    269314
    270315        vals = Book.objects.aggregate(Count("rating", distinct=True))
    271316        self.assertEqual(vals, {"rating__count": 4})
     317        vals = Book.objects.aggregate(
     318            low_count=Count("rating", only=Q(rating__lt=4)),
     319            high_count=Count("rating", only=Q(rating__gte=4))
     320        )
     321        self.assertEqual(vals, {"low_count": 1, 'high_count': 5})
     322        vals = Book.objects.aggregate(
     323            low_count=Count("rating", distinct=True, only=Q(rating__lt=4)),
     324            high_count=Count("rating", distinct=True, only=Q(rating__gte=4))
     325        )
     326        self.assertEqual(vals, {"low_count": 1, 'high_count': 3})
    272327
    273328    def test_fkey_aggregate(self):
    274329        explicit = list(Author.objects.annotate(Count('book__id')))
    class BaseAggregateTestCase(TestCase):  
    388443            ],
    389444            lambda p: p.name,
    390445        )
     446        publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk")
     447        self.assertQuerysetEqual(
     448            publishers, [
     449                "Expensive Publisher",
     450            ],
     451            lambda p: p.name,
     452        )
    391453
    392454        publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk")
    393455        self.assertQuerysetEqual(
Back to Top