Code

Ticket #11305: conditional_aggregates.1.6.patch

File conditional_aggregates.1.6.patch, 29.3 KB (added by airstrike, 19 months ago)

Updated to work with 1.5+, except for one test which isn't passing.

  • django/db/models/aggregates.py

    From 55f415fa2fe31b5191f9a67d17364f9590be401d Mon Sep 17 00:00:00 2001
    From: Andre Terra <loftarasa@gmail.com>
    Date: Tue, 11 Dec 2012 18:32:25 -0200
    Subject: [PATCH] Added support for conditional aggregates (ticket #11305)
    
    ---
     django/db/models/aggregates.py        |   9 ++-
     django/db/models/expressions.py       |   8 +--
     django/db/models/sql/aggregates.py    |  43 +++++++++++---
     django/db/models/sql/compiler.py      |  39 ++++++++-----
     django/db/models/sql/expressions.py   |  15 +++--
     django/db/models/sql/query.py         | 103 +++++++++++++++++++++-------------
     django/db/models/sql/where.py         |  10 ++++
     tests/modeltests/aggregation/tests.py |  93 ++++++++++++++++++++++++++++++
     8 files changed, 246 insertions(+), 74 deletions(-)
    
    diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
    index a2349cf..d816aa7 100644
    a b class Aggregate(object): 
    66    """ 
    77    Default Aggregate definition. 
    88    """ 
    9     def __init__(self, lookup, **extra): 
     9    def __init__(self, lookup, only=None, **extra): 
    1010        """Instantiate a new aggregate. 
    1111 
    1212         * lookup is the field on which the aggregate operates. 
     13         * only is a Q-object used in conditional aggregation. 
    1314         * extra is a dictionary of additional data to provide for the 
    1415           aggregate definition 
    1516 
    class Aggregate(object): 
    1819        """ 
    1920        self.lookup = lookup 
    2021        self.extra = extra 
     22        self.only = only 
     23        self.condition = None 
    2124 
    2225    def _default_alias(self): 
     26        if hasattr(self.lookup, 'evaluate'): 
     27             raise ValueError('When aggregating over an expression, you need to give an alias.') 
    2328        return '%s__%s' % (self.lookup, self.name.lower()) 
    2429    default_alias = property(_default_alias) 
    2530 
    class Aggregate(object): 
    4247           summary value rather than an annotation. 
    4348        """ 
    4449        klass = getattr(query.aggregates_module, self.name) 
    45         aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) 
     50        aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra) 
    4651        query.aggregates[alias] = aggregate 
    4752 
    4853class Avg(Aggregate): 
  • django/db/models/expressions.py

    diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
    index 3566d77..2b782ea 100644
    a b class ExpressionNode(tree.Node): 
    4141    # VISITOR METHODS # 
    4242    ################### 
    4343 
    44     def prepare(self, evaluator, query, allow_joins): 
    45         return evaluator.prepare_node(self, query, allow_joins) 
     44    def prepare(self, evaluator, query, allow_joins, promote_joins=False): 
     45        return evaluator.prepare_node(self, query, allow_joins, promote_joins) 
    4646 
    4747    def evaluate(self, evaluator, qn, connection): 
    4848        return evaluator.evaluate_node(self, qn, connection) 
    class F(ExpressionNode): 
    129129        obj.name = self.name 
    130130        return obj 
    131131 
    132     def prepare(self, evaluator, query, allow_joins): 
    133         return evaluator.prepare_leaf(self, query, allow_joins) 
     132    def prepare(self, evaluator, query, allow_joins, promote_joins=False): 
     133        return evaluator.prepare_leaf(self, query, allow_joins, promote_joins) 
    134134 
    135135    def evaluate(self, evaluator, qn, connection): 
    136136        return evaluator.evaluate_leaf(self, qn, connection) 
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index b41314a..5fe2215 100644
    a b Classes to represent the default SQL aggregate functions 
    33""" 
    44 
    55from django.db.models.fields import IntegerField, FloatField 
     6from django.db.models.sql.expressions import SQLEvaluator 
    67 
    78# Fake fields used to identify aggregate types in data-conversion operations. 
    89ordinal_aggregate_field = IntegerField() 
    class Aggregate(object): 
    1516    is_ordinal = False 
    1617    is_computed = False 
    1718    sql_template = '%(function)s(%(field)s)' 
     19    conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END" 
    1820 
    19     def __init__(self, col, source=None, is_summary=False, **extra): 
     21    def __init__(self, col, source=None, is_summary=False, condition=None, **extra): 
    2022        """Instantiate an SQL aggregate 
    2123 
    2224         * col is a column reference describing the subject field 
    class Aggregate(object): 
    2628           the column reference. If the aggregate is not an ordinal or 
    2729           computed type, this reference is used to determine the coerced 
    2830           output type of the aggregate. 
     31         * condition is used in conditional aggregation. 
    2932         * extra is a dictionary of additional data to provide for the 
    30            aggregate definition 
     33           aggregate definition. 
    3134 
    3235        Also utilizes the class variables: 
    3336         * sql_function, the name of the SQL function that implements the 
    class Aggregate(object): 
    3538         * sql_template, a template string that is used to render the 
    3639           aggregate into SQL. 
    3740         * is_ordinal, a boolean indicating if the output of this aggregate 
    38            is an integer (e.g., a count) 
     41           is an integer (e.g., a count). 
    3942         * is_computed, a boolean indicating if this output of this aggregate 
    4043           is a computed float (e.g., an average), regardless of the input 
    4144           type. 
    class Aggregate(object): 
    4548        self.source = source 
    4649        self.is_summary = is_summary 
    4750        self.extra = extra 
     51        self.condition = condition 
    4852 
    4953        # Follow the chain of aggregate sources back until you find an 
    5054        # actual field, or an aggregate that forces a particular output 
    class Aggregate(object): 
    6569    def relabel_aliases(self, change_map): 
    6670        if isinstance(self.col, (list, tuple)): 
    6771            self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) 
     72        else: 
     73            self.col.relabel_aliases(change_map) 
     74        if self.condition: 
     75            self.condition.relabel_aliases(change_map) 
    6876 
    6977    def as_sql(self, qn, connection): 
    7078        "Return the aggregate, rendered as SQL." 
    7179 
     80        condition_params = [] 
     81        col_params = [] 
    7282        if hasattr(self.col, 'as_sql'): 
    73             field_name = self.col.as_sql(qn, connection) 
     83            if isinstance(self.col, SQLEvaluator): 
     84                field_name, col_params = self.col.as_sql(qn, connection) 
     85            else: 
     86                field_name = self.col.as_sql(qn, connection) 
    7487        elif isinstance(self.col, (list, tuple)): 
    7588            field_name = '.'.join([qn(c) for c in self.col]) 
    7689        else: 
    7790            field_name = self.col 
    7891 
    79         params = { 
    80             'function': self.sql_function, 
    81             'field': field_name 
    82         } 
     92        if self.condition: 
     93            condition, condition_params = self.condition.as_sql(qn, connection) 
     94            conditional_field = self.conditional_template % { 
     95                'condition': condition, 
     96                'field_name': field_name 
     97            } 
     98            params = { 
     99                'function': self.sql_function, 
     100                'field': conditional_field, 
     101            } 
     102        else: 
     103            params = { 
     104                'function': self.sql_function, 
     105                'field': field_name 
     106            } 
    83107        params.update(self.extra) 
    84108 
    85         return self.sql_template % params 
     109        condition_params.extend(col_params) 
     110        return (self.sql_template % params, condition_params) 
    86111 
    87112 
    88113class Avg(Aggregate): 
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 8cfb12a..ed97a11 100644
    a b class SQLCompiler(object): 
    7171        # as the pre_sql_setup will modify query state in a way that forbids 
    7272        # another run of it. 
    7373        self.refcounts_before = self.query.alias_refcount.copy() 
    74         out_cols = self.get_columns(with_col_aliases) 
     74        out_cols, c_params = self.get_columns(with_col_aliases) 
    7575        ordering, ordering_group_by = self.get_ordering() 
    7676 
    7777        distinct_fields = self.get_distinct() 
    class SQLCompiler(object): 
    8787        params = [] 
    8888        for val in six.itervalues(self.query.extra_select): 
    8989            params.extend(val[1]) 
     90        # Extra-select comes before aggregation in the select list 
     91        params.extend(c_params) 
    9092 
    9193        result = ['SELECT'] 
    9294 
    class SQLCompiler(object): 
    172174        qn = self.quote_name_unless_alias 
    173175        qn2 = self.connection.ops.quote_name 
    174176        result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] 
     177        query_params = [] 
    175178        aliases = set(self.query.extra_select.keys()) 
    176179        if with_aliases: 
    177180            col_aliases = aliases.copy() 
    class SQLCompiler(object): 
    214217            aliases.update(new_aliases) 
    215218 
    216219        max_name_length = self.connection.ops.max_name_length() 
    217         result.extend([ 
    218             '%s%s' % ( 
    219                 aggregate.as_sql(qn, self.connection), 
    220                 alias is not None 
    221                     and ' AS %s' % qn(truncate_name(alias, max_name_length)) 
    222                     or '' 
     220        for alias, aggregate in self.query.aggregate_select.items(): 
     221            sql, params = aggregate.as_sql(qn, self.connection) 
     222            result.append( 
     223                '%s%s' % ( 
     224                    sql, 
     225                    alias is not None 
     226                       and ' AS %s' % qn(truncate_name(alias, max_name_length)) 
     227                       or '' 
     228                ) 
    223229            ) 
    224             for alias, aggregate in self.query.aggregate_select.items() 
    225         ]) 
     230            query_params.extend(params) 
    226231 
    227232        for (table, col), _ in self.query.related_select_cols: 
    228233            r = '%s.%s' % (qn(table), qn(col)) 
    class SQLCompiler(object): 
    237242                col_aliases.add(col) 
    238243 
    239244        self._select_aliases = aliases 
    240         return result 
     245        return result, query_params 
    241246 
    242247    def get_default_columns(self, with_aliases=False, col_aliases=None, 
    243248            start_alias=None, opts=None, as_pairs=False, from_parent=None): 
    class SQLAggregateCompiler(SQLCompiler): 
    10321037        """ 
    10331038        if qn is None: 
    10341039            qn = self.quote_name_unless_alias 
     1040        buf = [] 
     1041        a_params = [] 
     1042        for aggregate in self.query.aggregate_select.values(): 
     1043            sql, query_params = aggregate.as_sql(qn, self.connection) 
     1044            buf.append(sql) 
     1045            a_params.extend(query_params) 
     1046        aggregate_sql = ', '.join(buf) 
    10351047 
    10361048        sql = ('SELECT %s FROM (%s) subquery' % ( 
    1037             ', '.join([ 
    1038                 aggregate.as_sql(qn, self.connection) 
    1039                 for aggregate in self.query.aggregate_select.values() 
    1040             ]), 
     1049            aggregate_sql, 
    10411050            self.query.subquery) 
    10421051        ) 
    1043         params = self.query.sub_params 
     1052        params = tuple(a_params) + (self.query.sub_params) 
    10441053        return (sql, params) 
    10451054 
    10461055class SQLDateCompiler(SQLCompiler): 
  • django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    index c809e25..4e53ac4 100644
    a b from django.db.models.fields import FieldDoesNotExist 
    44from django.db.models.sql.constants import REUSE_ALL 
    55 
    66class SQLEvaluator(object): 
    7     def __init__(self, expression, query, allow_joins=True, reuse=REUSE_ALL): 
     7    def __init__(self, expression, query, allow_joins=True, reuse=REUSE_ALL, promote_joins=False): 
    88        self.expression = expression 
    99        self.opts = query.get_meta() 
    1010        self.cols = [] 
    1111 
    1212        self.contains_aggregate = False 
    1313        self.reuse = reuse 
    14         self.expression.prepare(self, query, allow_joins) 
     14        self.expression.prepare(self, query, allow_joins, promote_joins) 
    1515 
    1616    def prepare(self): 
    1717        return self 
    class SQLEvaluator(object): 
    3434    # Vistor methods for initial expression preparation # 
    3535    ##################################################### 
    3636 
    37     def prepare_node(self, node, query, allow_joins): 
     37    def prepare_node(self, node, query, allow_joins, promote_joins): 
    3838        for child in node.children: 
    3939            if hasattr(child, 'prepare'): 
    40                 child.prepare(self, query, allow_joins) 
     40                child.prepare(self, query, allow_joins, promote_joins) 
    4141 
    42     def prepare_leaf(self, node, query, allow_joins): 
     42    def prepare_leaf(self, node, query, allow_joins, promote_joins): 
    4343        if not allow_joins and LOOKUP_SEP in node.name: 
    4444            raise FieldError("Joined field references are not permitted in this query") 
    4545 
    class SQLEvaluator(object): 
    5454                    field_list, query.get_meta(), 
    5555                    query.get_initial_alias(), self.reuse) 
    5656                col, _, join_list = query.trim_joins(source, join_list, last, False) 
     57                if promote_joins: 
     58                    query.promote_joins(join_list, unconditional=True) 
    5759                if self.reuse is not None and self.reuse != REUSE_ALL: 
    5860                    self.reuse.update(join_list) 
    5961                self.cols.append((node, (join_list[-1], col))) 
    class SQLEvaluator(object): 
    7274        for child in node.children: 
    7375            if hasattr(child, 'evaluate'): 
    7476                sql, params = child.evaluate(self, qn, connection) 
     77                if isinstance(sql, tuple): 
     78                    expression_params.extend(sql[1]) 
     79                    sql = sql[0] 
    7580            else: 
    7681                sql, params = '%s', (child,) 
    7782 
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index e24dc22..c0fa92c 100644
    a b class Query(object): 
    119119        self.filter_is_sticky = False 
    120120        self.included_inherited_models = {} 
    121121 
    122         # SQL-related attributes   
     122        # SQL-related attributes 
    123123        # Select and related select clauses as SelectInfo instances. 
    124124        # The select is used for cases where we want to set up the select 
    125125        # clause to contain other than default fields (values(), annotate(), 
    class Query(object): 
    987987        Adds a single aggregate expression to the Query 
    988988        """ 
    989989        opts = model._meta 
    990         field_list = aggregate.lookup.split(LOOKUP_SEP) 
    991         if len(field_list) == 1 and aggregate.lookup in self.aggregates: 
    992             # Aggregate is over an annotation 
    993             field_name = field_list[0] 
    994             col = field_name 
    995             source = self.aggregates[field_name] 
    996             if not is_summary: 
    997                 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 
    998                     aggregate.name, field_name, field_name)) 
    999         elif ((len(field_list) > 1) or 
    1000             (field_list[0] not in [i.name for i in opts.fields]) or 
    1001             self.group_by is None or 
    1002             not is_summary): 
    1003             # If: 
    1004             #   - the field descriptor has more than one part (foo__bar), or 
    1005             #   - the field descriptor is referencing an m2m/m2o field, or 
    1006             #   - this is a reference to a model field (possibly inherited), or 
    1007             #   - this is an annotation over a model field 
    1008             # then we need to explore the joins that are required. 
    1009  
    1010             field, source, opts, join_list, last, _ = self.setup_joins( 
    1011                 field_list, opts, self.get_initial_alias(), REUSE_ALL) 
    1012  
    1013             # Process the join chain to see if it can be trimmed 
    1014             col, _, join_list = self.trim_joins(source, join_list, last, False) 
    1015  
    1016             # If the aggregate references a model or field that requires a join, 
    1017             # those joins must be LEFT OUTER - empty join rows must be returned 
    1018             # in order for zeros to be returned for those aggregates. 
    1019             self.promote_joins(join_list, True) 
    1020  
    1021             col = (join_list[-1], col) 
     990        only = aggregate.only 
     991        if hasattr(aggregate.lookup, 'evaluate'): 
     992            # If lookup is a query expression, evaluate it 
     993            col = SQLEvaluator(aggregate.lookup, self, promote_joins=True) 
     994            # TODO: find out the real source of this field. If any field has 
     995            # is_computed, then source can be set to is_computed. 
     996            source = None 
    1022997        else: 
    1023             # The simplest cases. No joins required - 
    1024             # just reference the provided column alias. 
    1025             field_name = field_list[0] 
    1026             source = opts.get_field(field_name) 
    1027             col = field_name 
    1028  
     998            field_list = aggregate.lookup.split(LOOKUP_SEP) 
     999            join_list = [] 
     1000            if len(field_list) == 1 and aggregate.lookup in self.aggregates: 
     1001                # Aggregate is over an annotation 
     1002                field_name = field_list[0] 
     1003                col = field_name 
     1004                source = self.aggregates[field_name] 
     1005                if not is_summary: 
     1006                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 
     1007                        aggregate.name, field_name, field_name)) 
     1008                if only: 
     1009                    raise FieldError("Cannot use aggregated fields in conditional aggregates") 
     1010            elif ((len(field_list) > 1) or 
     1011                (field_list[0] not in [i.name for i in opts.fields]) or 
     1012                self.group_by is None or 
     1013                not is_summary): 
     1014                # If: 
     1015                #   - the field descriptor has more than one part (foo__bar), or 
     1016                #   - the field descriptor is referencing an m2m/m2o field, or 
     1017                #   - this is a reference to a model field (possibly inherited), or 
     1018                #   - this is an annotation over a model field 
     1019                # then we need to explore the joins that are required. 
     1020 
     1021                field, source, opts, join_list, last, _ = self.setup_joins( 
     1022                    field_list, opts, self.get_initial_alias(), REUSE_ALL) 
     1023 
     1024                # Process the join chain to see if it can be trimmed 
     1025                col, _, join_list = self.trim_joins(source, join_list, last, False) 
     1026 
     1027                # If the aggregate references a model or field that requires a join, 
     1028                # those joins must be LEFT OUTER - empty join rows must be returned 
     1029                # in order for zeros to be returned for those aggregates. 
     1030                self.promote_joins(join_list, unconditional=True) 
     1031 
     1032                col = (join_list[-1], col) 
     1033            else: 
     1034                # The simplest cases. No joins required - 
     1035                # just reference the provided column alias. 
     1036                field_name = field_list[0] 
     1037                source = opts.get_field(field_name) 
     1038                col = field_name 
     1039 
     1040        if only: 
     1041            original_where = self.where 
     1042            original_having = self.having 
     1043            aggregate.condition = self.where_class() 
     1044            self.where = aggregate.condition 
     1045            self.having = self.where_class() 
     1046            original_alias_map = self.alias_map.keys()[:] 
     1047            self.add_q(only, used_aliases=set(original_alias_map)) 
     1048            if original_alias_map != self.alias_map.keys(): 
     1049                raise FieldError("Aggregate's only condition can not require additional joins, Original joins: %s, joins after: %s" % (original_alias_map, self.alias_map.keys())) 
     1050            if self.having.children: 
     1051                raise FieldError("Aggregate's only condition can not reference annotated fields") 
     1052            self.having = original_having 
     1053            self.where = original_where 
    10291054        # Add the aggregate to the query 
    10301055        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) 
    10311056 
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 47f4ffa..6b001bc 100644
    a b class WhereNode(tree.Node): 
    156156        it. 
    157157        """ 
    158158        lvalue, lookup_type, value_annotation, params_or_value = child 
     159        additional_params = [] 
    159160        if isinstance(lvalue, Constraint): 
    160161            try: 
    161162                lvalue, params = lvalue.process(lookup_type, params_or_value, connection) 
    class WhereNode(tree.Node): 
    173174        else: 
    174175            # A smart object with an as_sql() method. 
    175176            field_sql = lvalue.as_sql(qn, connection) 
     177            if isinstance(field_sql, tuple): 
     178                # It also returned params 
     179                additional_params.extend(field_sql[1]) 
     180                field_sql = field_sql[0] 
    176181 
    177182        if value_annotation is datetime.datetime: 
    178183            cast_sql = connection.ops.datetime_cast_sql() 
    class WhereNode(tree.Node): 
    181186 
    182187        if hasattr(params, 'as_sql'): 
    183188            extra, params = params.as_sql(qn, connection) 
     189            if isinstance(extra, tuple): 
     190                params = params + tuple(extra[1]) 
     191                extra = extra[0] 
    184192            cast_sql = '' 
    185193        else: 
    186194            extra = '' 
    class WhereNode(tree.Node): 
    190198            lookup_type = 'isnull' 
    191199            value_annotation = True 
    192200 
     201        additional_params.extend(params) 
     202        params = additional_params 
    193203        if lookup_type in connection.operators: 
    194204            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) 
    195205            return (format % (field_sql, 
  • tests/modeltests/aggregation/tests.py

    diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
    index c23b32f..0c445c4 100644
    a b import datetime 
    44from decimal import Decimal 
    55 
    66from django.db.models import Avg, Sum, Count, Max, Min 
     7from django.db.models import Q, F 
     8from django.core.exceptions import FieldError 
    79from django.test import TestCase, Approximate 
    810 
    911from .models import Author, Publisher, Book, Store 
    class BaseAggregateTestCase(TestCase): 
    1820    def test_single_aggregate(self): 
    1921        vals = Author.objects.aggregate(Avg("age")) 
    2022        self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) 
     23        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29))) 
     24        self.assertEqual(vals, {"age__sum": 254}) 
     25        vals = Author.objects.extra(select={'testparams':'age < %s'}, select_params=[0])\ 
     26               .aggregate(Sum("age", only=Q(age__gt=29))) 
     27        self.assertEqual(vals, {"age__sum": 254}) 
     28        vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian'))) 
     29        self.assertEqual(vals, {"age__sum": 69}) 
    2130 
    2231    def test_multiple_aggregates(self): 
    2332        vals = Author.objects.aggregate(Sum("age"), Avg("age")) 
    2433        self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) 
     34        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age")) 
     35        self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)}) 
    2536 
    2637    def test_filter_aggregate(self): 
    2738        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) 
    2839        self.assertEqual(len(vals), 1) 
    2940        self.assertEqual(vals["age__sum"], 254) 
     41        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29))) 
     42        # If there are no matching aggregates, then None, not 0 is the answer. 
     43        self.assertEqual(vals["age__sum"], None) 
    3044 
    3145    def test_related_aggregate(self): 
    3246        vals = Author.objects.aggregate(Avg("friends__age")) 
    3347        self.assertEqual(len(vals), 1) 
    3448        self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2) 
    3549 
     50        vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29))) 
     51        self.assertEqual(len(vals), 1) 
     52        self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2) 
     53        vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age")) 
     54        self.assertEqual(vals, vals2) 
     55 
     56        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35))) 
     57        self.assertEqual(len(vals), 1) 
     58        self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2) 
     59 
     60        # The average age of author's friends, whose age is lower than the authors age. 
     61        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age')))) 
     62        self.assertEqual(len(vals), 1) 
     63        self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2) 
     64 
    3665        vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age")) 
    3766        self.assertEqual(len(vals), 1) 
    3867        self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2) 
    class BaseAggregateTestCase(TestCase): 
    5483        self.assertEqual(len(vals), 1) 
    5584        self.assertEqual(vals["books__authors__age__max"], 57) 
    5685 
     86        vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56))) 
     87        self.assertEqual(len(vals), 1) 
     88        self.assertEqual(vals["books__authors__age__max"], 46) 
     89 
    5790        vals = Author.objects.aggregate(Min("book__publisher__num_awards")) 
    5891        self.assertEqual(len(vals), 1) 
    5992        self.assertEqual(vals["book__publisher__num_awards__min"], 1) 
    class BaseAggregateTestCase(TestCase): 
    84117        ) 
    85118        self.assertEqual(b.mean_age, 34.5) 
    86119 
     120        # Test extra-select 
     121        books = Book.objects.annotate(mean_age=Avg("authors__age")) 
     122        books = books.annotate(mean_age2=Avg('authors__age', only=Q(authors__age__gte=0))) 
     123        books = books.extra(select={'testparams': 'publisher_id = %s'}, select_params=[1]) 
     124        b = books.get(pk=1) 
     125        self.assertEqual(b.mean_age, 34.5) 
     126        self.assertEqual(b.mean_age2, 34.5) 
     127        self.assertEqual(b.testparams, True) 
     128 
     129        # Test relabel_aliases 
     130        excluded_authors = Author.objects.annotate(book_rating=Min(F('book__rating') + 5, only=Q(pk__gte=1))) 
     131        excluded_authors = excluded_authors.filter(book_rating__lt=0) 
     132        books = books.exclude(authors__in=excluded_authors) 
     133        b = books.get(pk=1) 
     134        self.assertEqual(b.mean_age, 34.5) 
     135 
     136        # Test joins in F-based annotation 
     137        books = Book.objects.annotate(oldest=Max(F('authors__age'))) 
     138        books = books.values_list('rating', 'oldest').order_by('rating', 'oldest') 
     139        self.assertEqual( 
     140            list(books), 
     141            [(3.0, 45), (4.0, 29), (4.0, 37), (4.0, 57), (4.5, 35), (5.0, 57)] 
     142        ) 
     143 
     144        publishers = Publisher.objects.annotate(avg_rating=Avg(F('book__rating') - 0)) 
     145        publishers = publishers.values_list('id', 'avg_rating').order_by('id') 
     146        self.assertEqual(list(publishers), [(1, 4.25), (2, 3.0), (3, 4.0), (4, 5.0), (5, None)]) 
     147 
    87148    def test_annotate_m2m(self): 
    88149        books = Book.objects.filter(rating__lt=4.5).annotate(Avg("authors__age")).order_by("name") 
    89150        self.assertQuerysetEqual( 
    class BaseAggregateTestCase(TestCase): 
    109170            lambda b: (b.name, b.num_authors) 
    110171        ) 
    111172 
     173        def raises_exception(): 
     174            list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name")) 
     175 
     176        self.assertRaises(FieldError, raises_exception) 
     177 
    112178    def test_backwards_m2m_annotate(self): 
    113179        authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name") 
    114180        self.assertQuerysetEqual( 
    class BaseAggregateTestCase(TestCase): 
    194260                } 
    195261            ] 
    196262        ) 
     263        books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age') 
     264        self.assertEqual( 
     265            list(books), [ 
     266                { 
     267                    "pk": 1, 
     268                    "isbn": "159059725", 
     269                    "mean_age": 34.0, 
     270                } 
     271            ] 
     272        ) 
    197273 
    198274        books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name") 
    199275        self.assertEqual( 
    class BaseAggregateTestCase(TestCase): 
    271347 
    272348        vals = Book.objects.aggregate(Count("rating", distinct=True)) 
    273349        self.assertEqual(vals, {"rating__count": 4}) 
     350        vals = Book.objects.aggregate( 
     351            low_count=Count("rating", only=Q(rating__lt=4)), 
     352            high_count=Count("rating", only=Q(rating__gte=4)) 
     353        ) 
     354        self.assertEqual(vals, {"low_count": 1, 'high_count': 5}) 
     355        vals = Book.objects.aggregate( 
     356            low_count=Count("rating", distinct=True, only=Q(rating__lt=4)), 
     357            high_count=Count("rating", distinct=True, only=Q(rating__gte=4)) 
     358        ) 
     359        self.assertEqual(vals, {"low_count": 1, 'high_count': 3}) 
    274360 
    275361    def test_fkey_aggregate(self): 
    276362        explicit = list(Author.objects.annotate(Count('book__id'))) 
    class BaseAggregateTestCase(TestCase): 
    390476            ], 
    391477            lambda p: p.name, 
    392478        ) 
     479        publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk") 
     480        self.assertQuerysetEqual( 
     481            publishers, [ 
     482                "Expensive Publisher", 
     483            ], 
     484            lambda p: p.name, 
     485        ) 
    393486 
    394487        publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") 
    395488        self.assertQuerysetEqual(