Ticket #14030: 14030.patch

File 14030.patch, 19.8 KB (added by Nate Bragg, 12 years ago)
  • django/db/models/aggregates.py

    From c7a74c08def758c62997ba037eccfb8f73ba3efc Mon Sep 17 00:00:00 2001
    From: Nate Bragg <jonathan.bragg@alum.rpi.edu>
    Date: Thu, 19 Jan 2012 21:01:32 -0500
    Subject: [PATCH] An attempt at rebasing out the changes required for
     supporting F expressions in aggregation from the more
     complex patch supporting conditional aggregation for
     #11305.
    
    Additional changes needed to make F expressions usable without
    being passed in inside an aggregation function.
    
    Also added some doc, and some tests.
    ---
     django/db/models/aggregates.py        |    2 +
     django/db/models/sql/aggregates.py    |   20 ++++++-
     django/db/models/sql/compiler.py      |   39 +++++++++-----
     django/db/models/sql/expressions.py   |    3 +
     django/db/models/sql/query.py         |   93 +++++++++++++++++++-------------
     django/db/models/sql/where.py         |   10 ++++
     django/test/testcases.py              |   12 ++++
     docs/ref/models/querysets.txt         |   23 ++++++++
     tests/modeltests/aggregation/tests.py |   31 +++++++++++
     9 files changed, 177 insertions(+), 56 deletions(-)
    
    diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
    index a2349cf..61848fe 100644
    a b class Aggregate(object):  
    2020        self.extra = extra
    2121
    2222    def _default_alias(self):
     23        if hasattr(self.lookup, 'evaluate'):
     24             raise ValueError('When aggregating over an expression, you need to give an alias.')
    2325        return '%s__%s' % (self.lookup, self.name.lower())
    2426    default_alias = property(_default_alias)
    2527
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index 207bc0c..7e131b9 100644
    a b  
    11"""
    22Classes to represent the default SQL aggregate functions
    33"""
     4from django.db.models.sql.expressions import SQLEvaluator
    45
    56class AggregateField(object):
    67    """An internal field mockup used to identify aggregates in the
    class Aggregate(object):  
    2223    is_ordinal = False
    2324    is_computed = False
    2425    sql_template = '%(function)s(%(field)s)'
     26    sql_function = ''
    2527
    2628    def __init__(self, col, source=None, is_summary=False, **extra):
    2729        """Instantiate an SQL aggregate
    class Aggregate(object):  
    6668                tmp = computed_aggregate_field
    6769            else:
    6870                tmp = tmp.source
    69 
     71       
     72        # We don't know the real source of this aggregate, and the
     73        # aggregate doesn't define ordinal or computed either. So
     74        # we default to computed for these cases.
     75        if tmp is None:
     76            tmp = computed_aggregate_field
    7077        self.field = tmp
    7178
    7279    def relabel_aliases(self, change_map):
    7380        if isinstance(self.col, (list, tuple)):
    7481            self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
     82        else:
     83            self.col.relabel_aliases(change_map)
    7584
    7685    def as_sql(self, qn, connection):
    7786        "Return the aggregate, rendered as SQL."
    7887
     88        col_params = []
    7989        if hasattr(self.col, 'as_sql'):
    80             field_name = self.col.as_sql(qn, connection)
     90            if isinstance(self.col, SQLEvaluator):
     91                field_name, col_params = self.col.as_sql(qn, connection)
     92            else:
     93                field_name = self.col.as_sql(qn, connection)
     94           
    8195        elif isinstance(self.col, (list, tuple)):
    8296            field_name = '.'.join([qn(c) for c in self.col])
    8397        else:
    class Aggregate(object):  
    89103        }
    90104        params.update(self.extra)
    91105
    92         return self.sql_template % params
     106        return (self.sql_template % params, col_params)
    93107
    94108
    95109class Avg(Aggregate):
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 72948f9..bf3cb25 100644
    a b class SQLCompiler(object):  
    6868        # as the pre_sql_setup will modify query state in a way that forbids
    6969        # another run of it.
    7070        self.refcounts_before = self.query.alias_refcount.copy()
    71         out_cols = self.get_columns(with_col_aliases)
     71        out_cols, c_params = self.get_columns(with_col_aliases)
    7272        ordering, ordering_group_by = self.get_ordering()
    7373
    7474        distinct_fields = self.get_distinct()
    class SQLCompiler(object):  
    8484        params = []
    8585        for val in self.query.extra_select.itervalues():
    8686            params.extend(val[1])
     87        # Extra-select comes before aggregation in the select list
     88        params.extend(c_params)
    8789
    8890        result = ['SELECT']
    8991
    class SQLCompiler(object):  
    178180        qn = self.quote_name_unless_alias
    179181        qn2 = self.connection.ops.quote_name
    180182        result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()]
     183        query_params = []
    181184        aliases = set(self.query.extra_select.keys())
    182185        if with_aliases:
    183186            col_aliases = aliases.copy()
    class SQLCompiler(object):  
    220223            aliases.update(new_aliases)
    221224
    222225        max_name_length = self.connection.ops.max_name_length()
    223         result.extend([
    224             '%s%s' % (
    225                 aggregate.as_sql(qn, self.connection),
    226                 alias is not None
    227                     and ' AS %s' % qn(truncate_name(alias, max_name_length))
    228                     or ''
     226        for alias, aggregate in self.query.aggregate_select.items():
     227            sql, params = aggregate.as_sql(qn, self.connection)
     228            result.append(
     229                '%s%s' % (
     230                    sql,
     231                    alias is not None
     232                       and ' AS %s' % qn(truncate_name(alias, max_name_length))
     233                       or ''
     234                )
    229235            )
    230             for alias, aggregate in self.query.aggregate_select.items()
    231         ])
     236            query_params.extend(params)
    232237
    233238        for table, col in self.query.related_select_cols:
    234239            r = '%s.%s' % (qn(table), qn(col))
    class SQLCompiler(object):  
    243248                col_aliases.add(col)
    244249
    245250        self._select_aliases = aliases
    246         return result
     251        return result, query_params
    247252
    248253    def get_default_columns(self, with_aliases=False, col_aliases=None,
    249254            start_alias=None, opts=None, as_pairs=False, local_only=False):
    class SQLAggregateCompiler(SQLCompiler):  
    10461051        """
    10471052        if qn is None:
    10481053            qn = self.quote_name_unless_alias
     1054        buf = []
     1055        a_params = []
     1056        for aggregate in self.query.aggregate_select.values():
     1057            sql, query_params = aggregate.as_sql(qn, self.connection)
     1058            buf.append(sql)
     1059            a_params.extend(query_params)
     1060        aggregate_sql = ', '.join(buf)
    10491061
    10501062        sql = ('SELECT %s FROM (%s) subquery' % (
    1051             ', '.join([
    1052                 aggregate.as_sql(qn, self.connection)
    1053                 for aggregate in self.query.aggregate_select.values()
    1054             ]),
     1063            aggregate_sql, 
    10551064            self.query.subquery)
    10561065        )
    1057         params = self.query.sub_params
     1066        params = tuple(a_params) + (self.query.sub_params)
    10581067        return (sql, params)
    10591068
    10601069class SQLDateCompiler(SQLCompiler):
  • django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    index 1bbf742..f9c23a9 100644
    a b class SQLEvaluator(object):  
    6565        for child in node.children:
    6666            if hasattr(child, 'evaluate'):
    6767                sql, params = child.evaluate(self, qn, connection)
     68                if isinstance(sql, tuple):
     69                    expression_params.extend(sql[1])
     70                    sql = sql[0]
    6871            else:
    6972                sql, params = '%s', (child,)
    7073
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index ed2bc06..2c0e973 100644
    a b from django.utils.encoding import force_unicode  
    1414from django.utils.tree import Node
    1515from django.db import connections, DEFAULT_DB_ALIAS
    1616from django.db.models import signals
     17from django.db.models.aggregates import Aggregate
    1718from django.db.models.expressions import ExpressionNode
    1819from django.db.models.fields import FieldDoesNotExist
    1920from django.db.models.query_utils import InvalidQuery
    class Query(object):  
    987988        Adds a single aggregate expression to the Query
    988989        """
    989990        opts = model._meta
    990         field_list = aggregate.lookup.split(LOOKUP_SEP)
    991         if len(field_list) == 1 and aggregate.lookup in self.aggregates:
    992             # Aggregate is over an annotation
    993             field_name = field_list[0]
    994             col = field_name
    995             source = self.aggregates[field_name]
    996             if not is_summary:
    997                 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
    998                     aggregate.name, field_name, field_name))
    999         elif ((len(field_list) > 1) or
    1000             (field_list[0] not in [i.name for i in opts.fields]) or
    1001             self.group_by is None or
    1002             not is_summary):
    1003             # If:
    1004             #   - the field descriptor has more than one part (foo__bar), or
    1005             #   - the field descriptor is referencing an m2m/m2o field, or
    1006             #   - this is a reference to a model field (possibly inherited), or
    1007             #   - this is an annotation over a model field
    1008             # then we need to explore the joins that are required.
    1009 
    1010             field, source, opts, join_list, last, _ = self.setup_joins(
    1011                 field_list, opts, self.get_initial_alias(), False)
    1012 
    1013             # Process the join chain to see if it can be trimmed
    1014             col, _, join_list = self.trim_joins(source, join_list, last, False)
    1015 
    1016             # If the aggregate references a model or field that requires a join,
    1017             # those joins must be LEFT OUTER - empty join rows must be returned
    1018             # in order for zeros to be returned for those aggregates.
    1019             for column_alias in join_list:
    1020                 self.promote_alias(column_alias, unconditional=True)
    1021 
    1022             col = (join_list[-1], col)
     991        if hasattr(aggregate, 'evaluate'):
     992            # If aggregate is a query expression, make it an aggregate
     993            # This is a 'cheat' to make an empty aggregate - i.e.,
     994            # one that has no attached function.  This is because
     995            # no computation needs to be done outside that which the
     996            # F expression represents
     997            aggregate = Aggregate(aggregate)
     998            aggregate.name = 'Aggregate'
     999        if hasattr(aggregate.lookup, 'evaluate'):
     1000            # If lookup is a query expression, evaluate it
     1001            col = SQLEvaluator(aggregate.lookup, self)
     1002            # TODO: find out the real source of this field. If any field has
     1003            # is_computed, then source can be set to is_computed.
     1004            source = None
    10231005        else:
    1024             # The simplest cases. No joins required -
    1025             # just reference the provided column alias.
    1026             field_name = field_list[0]
    1027             source = opts.get_field(field_name)
    1028             col = field_name
     1006            field_list = aggregate.lookup.split(LOOKUP_SEP)
     1007            join_list = []
     1008            if len(field_list) == 1 and aggregate.lookup in self.aggregates:
     1009                # Aggregate is over an annotation
     1010                field_name = field_list[0]
     1011                col = field_name
     1012                source = self.aggregates[field_name]
     1013                if not is_summary:
     1014                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
     1015                        aggregate.name, field_name, field_name))
     1016            elif ((len(field_list) > 1) or
     1017                (field_list[0] not in [i.name for i in opts.fields]) or
     1018                self.group_by is None or
     1019                not is_summary):
     1020                # If:
     1021                #   - the field descriptor has more than one part (foo__bar), or
     1022                #   - the field descriptor is referencing an m2m/m2o field, or
     1023                #   - this is a reference to a model field (possibly inherited), or
     1024                #   - this is an annotation over a model field
     1025                # then we need to explore the joins that are required.
     1026
     1027                field, source, opts, join_list, last, _ = self.setup_joins(
     1028                    field_list, opts, self.get_initial_alias(), False)
     1029
     1030                # Process the join chain to see if it can be trimmed
     1031                col, _, join_list = self.trim_joins(source, join_list, last, False)
     1032
     1033                # If the aggregate references a model or field that requires a join,
     1034                # those joins must be LEFT OUTER - empty join rows must be returned
     1035                # in order for zeros to be returned for those aggregates.
     1036                for column_alias in join_list:
     1037                    self.promote_alias(column_alias, unconditional=True)
     1038
     1039                col = (join_list[-1], col)
     1040            else:
     1041                # The simplest cases. No joins required -
     1042                # just reference the provided column alias.
     1043                field_name = field_list[0]
     1044                source = opts.get_field(field_name)
     1045                col = field_name
    10291046
    10301047        # Add the aggregate to the query
    10311048        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 1455ba6..8b530bd 100644
    a b class WhereNode(tree.Node):  
    139139        it.
    140140        """
    141141        lvalue, lookup_type, value_annot, params_or_value = child
     142        additional_params = []
    142143        if hasattr(lvalue, 'process'):
    143144            try:
    144145                lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
    class WhereNode(tree.Node):  
    153154        else:
    154155            # A smart object with an as_sql() method.
    155156            field_sql = lvalue.as_sql(qn, connection)
     157            if isinstance(field_sql, tuple):
     158                # It returned also params
     159                additional_params.extend(field_sql[1])
     160                field_sql = field_sql[0]
    156161
    157162        if value_annot is datetime.datetime:
    158163            cast_sql = connection.ops.datetime_cast_sql()
    class WhereNode(tree.Node):  
    161166
    162167        if hasattr(params, 'as_sql'):
    163168            extra, params = params.as_sql(qn, connection)
     169            if isinstance(extra, tuple):
     170                params = params + tuple(extra[1])
     171                extra = extra[0]
    164172            cast_sql = ''
    165173        else:
    166174            extra = ''
    class WhereNode(tree.Node):  
    170178            lookup_type = 'isnull'
    171179            value_annot = True
    172180
     181        additional_params.extend(params)
     182        params = additional_params
    173183        if lookup_type in connection.operators:
    174184            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
    175185            return (format % (field_sql,
  • django/test/testcases.py

    diff --git a/django/test/testcases.py b/django/test/testcases.py
    index 53ea02a..ba4f496 100644
    a b class TransactionTestCase(SimpleTestCase):  
    646646            return self.assertEqual(set(map(transform, qs)), set(values))
    647647        return self.assertEqual(map(transform, qs), values)
    648648
     649    def assertQuerysetAlmostEqual(self, qs, values, transform=repr, ordered=True, places=7):
     650        # This could have been done by iterating over zip(map(transform, qs), values)
     651        # and checking each pair with assertAlmostEqual, which rounds the difference of
     652        # each pair, but this way you get much nicer error messages, and you can have an
     653        # unordered comparison, at the cost of half a digit of accuracy.
     654        round_to = lambda v: round(v,places)
     655        tqs = map(round_to, map(transform, qs) )
     656        tvs = map(round_to, values)
     657        if not ordered:
     658            return self.assertEqual(set(tqs), set(tvs))
     659        return self.assertEqual(tqs, tvs)
     660
    649661    def assertNumQueries(self, num, func=None, *args, **kwargs):
    650662        using = kwargs.pop("using", DEFAULT_DB_ALIAS)
    651663        conn = connections[using]
  • docs/ref/models/querysets.txt

    diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
    index 7633555..d175f44 100644
    a b control the name of the annotation::  
    245245    >>> q[0].number_of_entries
    246246    42
    247247
     248In addition to aggregation functions, `:ref:`F() objects <query-expressions>`
     249can be used to perform a specific mathematical operation::
     250
     251    # The 1.0 is to force float conversion
     252    >>> q = Entry.objects.annotate(cpb_ratio=F('n_comments')*1.0/F('n_pingbacks'))
     253    # The ratio of comments to pingbacks for the first blog entry
     254    >>> q[0].cpb_ratio
     255    0.0625
     256
    248257For an in-depth discussion of aggregation, see :doc:`the topic guide on
    249258Aggregation </topics/db/aggregation>`.
    250259
    control the name of the aggregation value that is returned::  
    14821491    >>> q = Blog.objects.aggregate(number_of_entries=Count('entry'))
    14831492    {'number_of_entries': 16}
    14841493
     1494Inside aggregation functions, :ref:`F() objects <query-expressions>`
     1495can be used to perform a specific mathematical operation::
     1496
     1497    # The 1.0 is to force float conversion
     1498    >>> q = Entry.objects.aggregate(avg_cpb_ratio=Avg(F('n_comments')*1.0/F('n_pingbacks')))
     1499    {'avg_cpb_ratio': 0.125}
     1500
    14851501For an in-depth discussion of aggregation, see :doc:`the topic guide on
    14861502Aggregation </topics/db/aggregation>`.
    14871503
    Django provides the following aggregation functions in the  
    21162132aggregate functions, see
    21172133:doc:`the topic guide on aggregation </topics/db/aggregation>`.
    21182134
     2135Note that in addition to taking a named field, aggregation
     2136functions can take :ref:`F() objects <query-expressions>`.
     2137
     2138.. admonition:: Default aliases
     2139
     2140    When using ``F()`` objects, note that there is no default alias.
     2141
    21192142Avg
    21202143~~~
    21212144
  • tests/modeltests/aggregation/tests.py

    diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
    index a35dbb3..a5d3a4e 100644
    a b import datetime  
    44from decimal import Decimal
    55
    66from django.db.models import Avg, Sum, Count, Max, Min
     7from django.db.models import F
    78from django.test import TestCase, Approximate
    89
    910from .models import Author, Publisher, Book, Store
    class BaseAggregateTestCase(TestCase):  
    6364        self.assertEqual(len(vals), 1)
    6465        self.assertAlmostEqual(vals["amazon_mean"], 4.08, places=2)
    6566
     67    def test_aggregate_f_expression(self):
     68        vals = Book.objects.all().aggregate(price_per_page=Avg(F('price')*1.0/F('pages')))
     69        self.assertEqual(len(vals), 1)
     70        self.assertAlmostEqual(vals["price_per_page"], 0.0745110754864109, places=2)
     71
     72    def test_annotate_f_expression(self):
     73        self.assertQuerysetAlmostEqual(
     74            Book.objects.all().annotate(price_per_page=F('price')*1.0/F('pages')), [
     75                0.0671140939597315,
     76                0.0437310606060606,
     77                0.0989666666666667,
     78                0.0848285714285714,
     79                0.0731448763250883,
     80                0.0792811839323467,
     81            ],
     82            lambda b: b.price_per_page,
     83            places=4
     84        )
     85
     86        self.assertQuerysetAlmostEqual(
     87            Publisher.objects.all().annotate(price_per_page=Avg(F('book__price')*1.0/F('book__pages'))), [
     88                0.0830403803131991,
     89                0.0437310606060606,
     90                0.0789867238768299,
     91                0.0792811839323467,
     92            ],
     93            lambda p: p.price_per_page,
     94            places=4
     95        )
     96
    6697    def test_annotate_basic(self):
    6798        self.assertQuerysetEqual(
    6899            Book.objects.annotate().order_by('pk'), [
Back to Top