Code

Ticket #11305: poc_11305.patch

File poc_11305.patch, 28.1 KB (added by akaariai, 3 years ago)

Proof of concept (latest, ignore the .2.)

  • django/db/models/aggregates.py

    diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
    index a2349cf..d816aa7 100644
    a b class Aggregate(object): 
    66    """ 
    77    Default Aggregate definition. 
    88    """ 
    9     def __init__(self, lookup, **extra): 
     9    def __init__(self, lookup, only=None, **extra): 
    1010        """Instantiate a new aggregate. 
    1111 
    1212         * lookup is the field on which the aggregate operates. 
     13         * only is a Q-object used in conditional aggregation. 
    1314         * extra is a dictionary of additional data to provide for the 
    1415           aggregate definition 
    1516 
    class Aggregate(object): 
    1819        """ 
    1920        self.lookup = lookup 
    2021        self.extra = extra 
     22        self.only = only 
     23        self.condition = None 
    2124 
    2225    def _default_alias(self): 
     26        if hasattr(self.lookup, 'evaluate'): 
     27             raise ValueError('When aggregating over an expression, you need to give an alias.') 
    2328        return '%s__%s' % (self.lookup, self.name.lower()) 
    2429    default_alias = property(_default_alias) 
    2530 
    class Aggregate(object): 
    4247           summary value rather than an annotation. 
    4348        """ 
    4449        klass = getattr(query.aggregates_module, self.name) 
    45         aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) 
     50        aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra) 
    4651        query.aggregates[alias] = aggregate 
    4752 
    4853class Avg(Aggregate): 
  • django/db/models/expressions.py

    diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
    index a71f4a3..390f475 100644
    a b class ExpressionNode(tree.Node): 
    3939    # VISITOR METHODS # 
    4040    ################### 
    4141 
    42     def prepare(self, evaluator, query, allow_joins): 
    43         return evaluator.prepare_node(self, query, allow_joins) 
     42    def prepare(self, evaluator, query, allow_joins, promote_joins=False): 
     43        return evaluator.prepare_node(self, query, allow_joins, promote_joins) 
    4444 
    4545    def evaluate(self, evaluator, qn, connection): 
    4646        return evaluator.evaluate_node(self, qn, connection) 
    class F(ExpressionNode): 
    107107        obj.name = self.name 
    108108        return obj 
    109109 
    110     def prepare(self, evaluator, query, allow_joins): 
    111         return evaluator.prepare_leaf(self, query, allow_joins) 
     110    def prepare(self, evaluator, query, allow_joins, promote_joins=False): 
     111        return evaluator.prepare_leaf(self, query, allow_joins, promote_joins) 
    112112 
    113113    def evaluate(self, evaluator, qn, connection): 
    114114        return evaluator.evaluate_leaf(self, qn, connection) 
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index 207bc0c..45f9104 100644
    a b  
    11""" 
    22Classes to represent the default SQL aggregate functions 
    33""" 
     4from django.db.models.sql.expressions import SQLEvaluator 
    45 
    56class AggregateField(object): 
    67    """An internal field mockup used to identify aggregates in the 
    class Aggregate(object): 
    2223    is_ordinal = False 
    2324    is_computed = False 
    2425    sql_template = '%(function)s(%(field)s)' 
     26    conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END" 
    2527 
    26     def __init__(self, col, source=None, is_summary=False, **extra): 
     28    def __init__(self, col, source=None, is_summary=False, condition=None, **extra): 
    2729        """Instantiate an SQL aggregate 
    2830 
    2931         * col is a column reference describing the subject field 
    class Aggregate(object): 
    3335           the column reference. If the aggregate is not an ordinal or 
    3436           computed type, this reference is used to determine the coerced 
    3537           output type of the aggregate. 
     38         * condition is used in conditional aggregation 
    3639         * extra is a dictionary of additional data to provide for the 
    3740           aggregate definition 
    3841 
    class Aggregate(object): 
    5255        self.source = source 
    5356        self.is_summary = is_summary 
    5457        self.extra = extra 
     58        self.condition = condition 
    5559 
    5660        # Follow the chain of aggregate sources back until you find an 
    5761        # actual field, or an aggregate that forces a particular output 
    class Aggregate(object): 
    6670                tmp = computed_aggregate_field 
    6771            else: 
    6872                tmp = tmp.source 
    69  
     73         
     74        # We don't know the real source of this aggregate, and the 
     75        # aggregate doesn't define ordinal or computed either. So 
     76        # we default to computed for these cases.  
     77        if tmp is None: 
     78            tmp = computed_aggregate_field 
    7079        self.field = tmp 
    7180 
     81 
    7282    def relabel_aliases(self, change_map): 
    7383        if isinstance(self.col, (list, tuple)): 
    7484            self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) 
     85        else: 
     86            self.col.relabel_aliases(change_map) 
     87        if self.condition: 
     88            self.condition.relabel_aliases(change_map) 
    7589 
    7690    def as_sql(self, qn, connection): 
    7791        "Return the aggregate, rendered as SQL." 
    7892 
     93        condition_params = [] 
     94        col_params = [] 
    7995        if hasattr(self.col, 'as_sql'): 
    80             field_name = self.col.as_sql(qn, connection) 
     96            if isinstance(self.col, SQLEvaluator): 
     97                field_name, col_params = self.col.as_sql(qn, connection) 
     98            else: 
     99                field_name = self.col.as_sql(qn, connection) 
     100             
    81101        elif isinstance(self.col, (list, tuple)): 
    82102            field_name = '.'.join([qn(c) for c in self.col]) 
    83103        else: 
    84104            field_name = self.col 
    85  
    86         params = { 
    87             'function': self.sql_function, 
    88             'field': field_name 
    89         } 
     105        if self.condition: 
     106            condition, condition_params = self.condition.as_sql(qn, connection) 
     107            conditional_field = self.conditional_template % { 
     108                'condition': condition,  
     109                'field_name': field_name 
     110            }  
     111            params = { 
     112                'function': self.sql_function, 
     113                'field': conditional_field, 
     114            }  
     115        else: 
     116            params = { 
     117                'function': self.sql_function, 
     118                'field': field_name 
     119            } 
    90120        params.update(self.extra) 
    91  
    92         return self.sql_template % params 
     121        condition_params.extend(col_params) 
     122        return (self.sql_template % params, condition_params) 
    93123 
    94124 
    95125class Avg(Aggregate): 
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 841ec12..775d9ca 100644
    a b class SQLCompiler(object): 
    5858            return '', () 
    5959 
    6060        self.pre_sql_setup() 
    61         out_cols = self.get_columns(with_col_aliases) 
     61        out_cols, c_params = self.get_columns(with_col_aliases) 
    6262        ordering, ordering_group_by = self.get_ordering() 
    6363 
    6464        # This must come after 'select' and 'ordering' -- see docstring of 
    class SQLCompiler(object): 
    7272        params = [] 
    7373        for val in self.query.extra_select.itervalues(): 
    7474            params.extend(val[1]) 
     75        # Extra-select comes before aggregation in the select list 
     76        params.extend(c_params) 
    7577 
    7678        result = ['SELECT'] 
    7779        if self.query.distinct: 
    class SQLCompiler(object): 
    126128            if nowait and not self.connection.features.has_select_for_update_nowait: 
    127129                raise DatabaseError('NOWAIT is not supported on this database backend.') 
    128130            result.append(self.connection.ops.for_update_sql(nowait=nowait)) 
    129  
    130131        return ' '.join(result), tuple(params) 
    131132 
    132133    def as_nested_sql(self): 
    class SQLCompiler(object): 
    158159        qn = self.quote_name_unless_alias 
    159160        qn2 = self.connection.ops.quote_name 
    160161        result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] 
     162        query_params = [] 
    161163        aliases = set(self.query.extra_select.keys()) 
    162164        if with_aliases: 
    163165            col_aliases = aliases.copy() 
    class SQLCompiler(object): 
    200202            aliases.update(new_aliases) 
    201203 
    202204        max_name_length = self.connection.ops.max_name_length() 
    203         result.extend([ 
    204             '%s%s' % ( 
    205                 aggregate.as_sql(qn, self.connection), 
    206                 alias is not None 
    207                     and ' AS %s' % qn(truncate_name(alias, max_name_length)) 
    208                     or '' 
     205        for alias, aggregate in self.query.aggregate_select.items(): 
     206            sql, params = aggregate.as_sql(qn, self.connection) 
     207            result.append( 
     208                '%s%s' % ( 
     209                    sql, 
     210                    alias is not None 
     211                       and ' AS %s' % qn(truncate_name(alias, max_name_length)) 
     212                       or '' 
     213                ) 
    209214            ) 
    210             for alias, aggregate in self.query.aggregate_select.items() 
    211         ]) 
     215            query_params.extend(params) 
    212216 
    213217        for table, col in self.query.related_select_cols: 
    214218            r = '%s.%s' % (qn(table), qn(col)) 
    class SQLCompiler(object): 
    223227                col_aliases.add(col) 
    224228 
    225229        self._select_aliases = aliases 
    226         return result 
     230        return result, query_params 
    227231 
    228232    def get_default_columns(self, with_aliases=False, col_aliases=None, 
    229233            start_alias=None, opts=None, as_pairs=False, local_only=False): 
    class SQLAggregateCompiler(SQLCompiler): 
    948952        """ 
    949953        if qn is None: 
    950954            qn = self.quote_name_unless_alias 
     955        buf = [] 
     956        a_params = [] 
     957        for aggregate in self.query.aggregate_select.values(): 
     958            sql, query_params = aggregate.as_sql(qn, self.connection) 
     959            buf.append(sql) 
     960            a_params.extend(query_params) 
     961        aggregate_sql = ', '.join(buf) 
    951962        sql = ('SELECT %s FROM (%s) subquery' % ( 
    952             ', '.join([ 
    953                 aggregate.as_sql(qn, self.connection) 
    954                 for aggregate in self.query.aggregate_select.values() 
    955             ]), 
     963            aggregate_sql,   
    956964            self.query.subquery) 
    957965        ) 
    958         params = self.query.sub_params 
     966        params = tuple(a_params) + (self.query.sub_params) 
    959967        return (sql, params) 
    960968 
    961969class SQLDateCompiler(SQLCompiler): 
  • django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    index 1bbf742..3df33f0 100644
    a b from django.db.models.fields import FieldDoesNotExist 
    33from django.db.models.sql.constants import LOOKUP_SEP 
    44 
    55class SQLEvaluator(object): 
    6     def __init__(self, expression, query, allow_joins=True): 
     6    def __init__(self, expression, query, allow_joins=True, promote_joins=False): 
    77        self.expression = expression 
    88        self.opts = query.get_meta() 
    99        self.cols = {} 
    1010 
    1111        self.contains_aggregate = False 
    12         self.expression.prepare(self, query, allow_joins) 
     12        self.expression.prepare(self, query, allow_joins, promote_joins) 
    1313 
    1414    def prepare(self): 
    1515        return self 
    class SQLEvaluator(object): 
    2828    # Vistor methods for initial expression preparation # 
    2929    ##################################################### 
    3030 
    31     def prepare_node(self, node, query, allow_joins): 
     31    def prepare_node(self, node, query, allow_joins, promote_joins): 
    3232        for child in node.children: 
    3333            if hasattr(child, 'prepare'): 
    34                 child.prepare(self, query, allow_joins) 
     34                child.prepare(self, query, allow_joins, promote_joins) 
    3535 
    36     def prepare_leaf(self, node, query, allow_joins): 
     36    def prepare_leaf(self, node, query, allow_joins, promote_joins): 
    3737        if not allow_joins and LOOKUP_SEP in node.name: 
    3838            raise FieldError("Joined field references are not permitted in this query") 
    3939 
    class SQLEvaluator(object): 
    4848                    field_list, query.get_meta(), 
    4949                    query.get_initial_alias(), False) 
    5050                col, _, join_list = query.trim_joins(source, join_list, last, False) 
     51                if promote_joins: 
     52                    for column_alias in join_list: 
     53                        query.promote_alias(column_alias, unconditional=True) 
    5154 
    5255                self.cols[node] = (join_list[-1], col) 
    5356            except FieldDoesNotExist: 
    class SQLEvaluator(object): 
    6568        for child in node.children: 
    6669            if hasattr(child, 'evaluate'): 
    6770                sql, params = child.evaluate(self, qn, connection) 
     71                if isinstance(sql, tuple): 
     72                    expression_params.extend(sql[1]) 
     73                    sql = sql[0] 
    6874            else: 
    6975                sql, params = '%s', (child,) 
    7076 
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index 110e317..b80f1c9 100644
    a b class Query(object): 
    955955        Adds a single aggregate expression to the Query 
    956956        """ 
    957957        opts = model._meta 
    958         field_list = aggregate.lookup.split(LOOKUP_SEP) 
    959         if len(field_list) == 1 and aggregate.lookup in self.aggregates: 
    960             # Aggregate is over an annotation 
    961             field_name = field_list[0] 
    962             col = field_name 
    963             source = self.aggregates[field_name] 
    964             if not is_summary: 
    965                 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 
    966                     aggregate.name, field_name, field_name)) 
    967         elif ((len(field_list) > 1) or 
    968             (field_list[0] not in [i.name for i in opts.fields]) or 
    969             self.group_by is None or 
    970             not is_summary): 
    971             # If: 
    972             #   - the field descriptor has more than one part (foo__bar), or 
    973             #   - the field descriptor is referencing an m2m/m2o field, or 
    974             #   - this is a reference to a model field (possibly inherited), or 
    975             #   - this is an annotation over a model field 
    976             # then we need to explore the joins that are required. 
    977  
    978             field, source, opts, join_list, last, _ = self.setup_joins( 
    979                 field_list, opts, self.get_initial_alias(), False) 
    980  
    981             # Process the join chain to see if it can be trimmed 
    982             col, _, join_list = self.trim_joins(source, join_list, last, False) 
    983  
    984             # If the aggregate references a model or field that requires a join, 
    985             # those joins must be LEFT OUTER - empty join rows must be returned 
    986             # in order for zeros to be returned for those aggregates. 
    987             for column_alias in join_list: 
    988                 self.promote_alias(column_alias, unconditional=True) 
    989  
    990             col = (join_list[-1], col) 
     958        only = aggregate.only 
     959        if hasattr(aggregate.lookup, 'evaluate'): 
     960            # If lookup is a query expression, evaluate it 
     961            col = SQLEvaluator(aggregate.lookup, self, promote_joins=True) 
     962            # TODO: find out the real source of this field. If any field has 
     963            # is_computed, then source can be set to is_computed. 
     964            source = None 
    991965        else: 
    992             # The simplest cases. No joins required - 
    993             # just reference the provided column alias. 
    994             field_name = field_list[0] 
    995             source = opts.get_field(field_name) 
    996             col = field_name 
    997  
     966            field_list = aggregate.lookup.split(LOOKUP_SEP) 
     967            join_list = [] 
     968            if len(field_list) == 1 and aggregate.lookup in self.aggregates: 
     969                # Aggregate is over an annotation 
     970                field_name = field_list[0] 
     971                col = field_name 
     972                source = self.aggregates[field_name] 
     973                if not is_summary: 
     974                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 
     975                        aggregate.name, field_name, field_name)) 
     976                if only: 
     977                    raise FieldError("Cannot use aggregated fields in conditional aggregates") 
     978            elif ((len(field_list) > 1) or 
     979                (field_list[0] not in [i.name for i in opts.fields]) or 
     980                self.group_by is None or 
     981                not is_summary): 
     982                # If: 
     983                #   - the field descriptor has more than one part (foo__bar), or 
     984                #   - the field descriptor is referencing an m2m/m2o field, or 
     985                #   - this is a reference to a model field (possibly inherited), or 
     986                #   - this is an annotation over a model field 
     987                # then we need to explore the joins that are required. 
     988 
     989                field, source, opts, join_list, last, _ = self.setup_joins( 
     990                    field_list, opts, self.get_initial_alias(), False) 
     991 
     992                # Process the join chain to see if it can be trimmed 
     993                col, _, join_list = self.trim_joins(source, join_list, last, False) 
     994 
     995                # If the aggregate references a model or field that requires a join, 
     996                # those joins must be LEFT OUTER - empty join rows must be returned 
     997                # in order for zeros to be returned for those aggregates. 
     998                for column_alias in join_list: 
     999                    self.promote_alias(column_alias, unconditional=True) 
     1000 
     1001                col = (join_list[-1], col) 
     1002            else: 
     1003                # The simplest cases. No joins required - 
     1004                # just reference the provided column alias. 
     1005                field_name = field_list[0] 
     1006                source = opts.get_field(field_name) 
     1007                col = field_name 
     1008 
     1009        if only: 
     1010            original_where = self.where 
     1011            original_having = self.having 
     1012            aggregate.condition = self.where_class() 
     1013            self.where = aggregate.condition 
     1014            self.having = self.where_class() 
     1015            original_alias_map = self.alias_map.keys()[:] 
     1016            self.add_q(only, used_aliases=set(original_alias_map))  
     1017            if original_alias_map != self.alias_map.keys(): 
     1018                raise FieldError("Aggregate's only condition can not require additional joins, Original joins: %s, joins after: %s" % (original_alias_map, self.alias_map.keys())) 
     1019            if self.having.children: 
     1020                raise FieldError("Aggregate's only condition can not reference annotated fields") 
     1021            self.having = original_having 
     1022            self.where = original_where 
    9981023        # Add the aggregate to the query 
    9991024        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) 
    10001025 
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 2427a52..7ac0287 100644
    a b class WhereNode(tree.Node): 
    137137        it. 
    138138        """ 
    139139        lvalue, lookup_type, value_annot, params_or_value = child 
     140        additional_params = [] 
    140141        if hasattr(lvalue, 'process'): 
    141142            try: 
    142143                lvalue, params = lvalue.process(lookup_type, params_or_value, connection) 
    class WhereNode(tree.Node): 
    151152        else: 
    152153            # A smart object with an as_sql() method. 
    153154            field_sql = lvalue.as_sql(qn, connection) 
     155            if isinstance(field_sql, tuple): 
     156                # It returned also params 
     157                additional_params.extend(field_sql[1]) 
     158                field_sql = field_sql[0] 
    154159 
    155160        if value_annot is datetime.datetime: 
    156161            cast_sql = connection.ops.datetime_cast_sql() 
    class WhereNode(tree.Node): 
    159164 
    160165        if hasattr(params, 'as_sql'): 
    161166            extra, params = params.as_sql(qn, connection) 
     167            if isinstance(extra, tuple): 
     168                params = params + tuple(extra[1]) 
     169                extra = extra[0] 
    162170            cast_sql = '' 
    163171        else: 
    164172            extra = '' 
    class WhereNode(tree.Node): 
    168176            lookup_type = 'isnull' 
    169177            value_annot = True 
    170178 
     179        additional_params.extend(params) 
     180        params = additional_params 
    171181        if lookup_type in connection.operators: 
    172182            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) 
    173183            return (format % (field_sql, 
  • tests/modeltests/aggregation/tests.py

    diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
    index 6f68800..c5ec1e5 100644
    a b import datetime 
    22from decimal import Decimal 
    33 
    44from django.db.models import Avg, Sum, Count, Max, Min 
     5from django.db.models import Q, F 
     6from django.core.exceptions import FieldError 
    57from django.test import TestCase, Approximate 
    68 
    79from models import Author, Publisher, Book, Store 
    class BaseAggregateTestCase(TestCase): 
    1618    def test_single_aggregate(self): 
    1719        vals = Author.objects.aggregate(Avg("age")) 
    1820        self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) 
     21        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29))) 
     22        self.assertEqual(vals, {"age__sum": 254}) 
     23        vals = Author.objects.extra(select={'testparams':'age < %s'}, select_params=[0])\ 
     24               .aggregate(Sum("age", only=Q(age__gt=29))) 
     25        self.assertEqual(vals, {"age__sum": 254}) 
     26        vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian'))) 
     27        self.assertEqual(vals, {"age__sum": 69})  
    1928 
    2029    def test_multiple_aggregates(self): 
    2130        vals = Author.objects.aggregate(Sum("age"), Avg("age")) 
    2231        self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) 
     32        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age")) 
     33        self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)}) 
    2334 
    2435    def test_filter_aggregate(self): 
    2536        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) 
    2637        self.assertEqual(len(vals), 1) 
    2738        self.assertEqual(vals["age__sum"], 254) 
     39        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29))) 
     40        # If there are no matching aggregates, then None, not 0 is the answer. 
     41        self.assertEqual(vals["age__sum"], None) 
    2842 
    2943    def test_related_aggregate(self): 
    3044        vals = Author.objects.aggregate(Avg("friends__age")) 
    3145        self.assertEqual(len(vals), 1) 
    3246        self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2) 
    3347 
     48        vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29))) 
     49        self.assertEqual(len(vals), 1) 
     50        self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2) 
     51        vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age")) 
     52        self.assertEqual(vals, vals2) 
     53 
     54        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35))) 
     55        self.assertEqual(len(vals), 1) 
     56        self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2) 
     57 
     58        # The average age of author's friends, whose age is lower than the authors age. 
     59        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age')))) 
     60        self.assertEqual(len(vals), 1) 
     61        self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2) 
     62 
    3463        vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age")) 
    3564        self.assertEqual(len(vals), 1) 
    3665        self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2) 
    class BaseAggregateTestCase(TestCase): 
    5180        vals = Store.objects.aggregate(Max("books__authors__age")) 
    5281        self.assertEqual(len(vals), 1) 
    5382        self.assertEqual(vals["books__authors__age__max"], 57) 
     83         
     84        vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56))) 
     85        self.assertEqual(len(vals), 1) 
     86        self.assertEqual(vals["books__authors__age__max"], 46) 
    5487 
    5588        vals = Author.objects.aggregate(Min("book__publisher__num_awards")) 
    5689        self.assertEqual(len(vals), 1) 
    class BaseAggregateTestCase(TestCase): 
    82115        ) 
    83116        self.assertEqual(b.mean_age, 34.5) 
    84117 
     118        # Test extra-select 
     119        books = Book.objects.annotate(mean_age=Avg("authors__age")) 
     120        books = books.annotate(mean_age2=Avg('authors__age', only=Q(authors__age__gte=0))) 
     121        books = books.extra(select={'testparams': 'publisher_id = %s'}, select_params=[1]) 
     122        b = books.get(pk=1) 
     123        self.assertEqual(b.mean_age, 34.5) 
     124        self.assertEqual(b.mean_age2, 34.5) 
     125        self.assertEqual(b.testparams, True) 
     126 
     127        # Test relabel_aliases 
     128        excluded_authors = Author.objects.annotate(book_rating=Min(F('book__rating') + 5, only=Q(pk__gte=1))) 
     129        excluded_authors = excluded_authors.filter(book_rating__lt=0) 
     130        books = books.exclude(authors__in=excluded_authors) 
     131        b = books.get(pk=1) 
     132        self.assertEqual(b.mean_age, 34.5) 
     133 
     134        # Test joins in F-based annotation 
     135        books = Book.objects.annotate(oldest=Max(F('authors__age'))) 
     136        books = books.values_list('rating', 'oldest').order_by('rating', 'oldest') 
     137        self.assertEqual( 
     138            list(books), 
     139            [(3.0, 45), (4.0, 29), (4.0, 37), (4.0, 57), (4.5, 35), (5.0, 57)] 
     140        ) 
     141 
     142        publishers = Publisher.objects.annotate(avg_rating=Avg(F('book__rating') - 0)) 
     143        publishers = publishers.values_list('id', 'avg_rating').order_by('id') 
     144        self.assertEqual(list(publishers), [(1, 4.25), (2, 3.0), (3, 4.0), (4, 5.0), (5, None)]) 
     145 
    85146    def test_annotate_m2m(self): 
    86147        books = Book.objects.filter(rating__lt=4.5).annotate(Avg("authors__age")).order_by("name") 
    87148        self.assertQuerysetEqual( 
    class BaseAggregateTestCase(TestCase): 
    106167            ], 
    107168            lambda b: (b.name, b.num_authors) 
    108169        ) 
     170         
     171        def raises_exception(): 
     172            list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name")) 
     173 
     174        self.assertRaises(FieldError, raises_exception) 
    109175 
    110176    def test_backwards_m2m_annotate(self): 
    111177        authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name") 
    class BaseAggregateTestCase(TestCase): 
    192258                } 
    193259            ] 
    194260        ) 
     261        books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age') 
     262        self.assertEqual( 
     263            list(books), [ 
     264                { 
     265                    "pk": 1, 
     266                    "isbn": "159059725", 
     267                    "mean_age": 34.0, 
     268                } 
     269            ] 
     270        ) 
    195271 
    196272        books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name") 
    197273        self.assertEqual( 
    class BaseAggregateTestCase(TestCase): 
    269345 
    270346        vals = Book.objects.aggregate(Count("rating", distinct=True)) 
    271347        self.assertEqual(vals, {"rating__count": 4}) 
     348        vals = Book.objects.aggregate( 
     349            low_count=Count("rating", only=Q(rating__lt=4)),  
     350            high_count=Count("rating", only=Q(rating__gte=4)) 
     351        ) 
     352        self.assertEqual(vals, {"low_count": 1, 'high_count': 5}) 
     353        vals = Book.objects.aggregate( 
     354            low_count=Count("rating", distinct=True, only=Q(rating__lt=4)),  
     355            high_count=Count("rating", distinct=True, only=Q(rating__gte=4)) 
     356        ) 
     357        self.assertEqual(vals, {"low_count": 1, 'high_count': 3}) 
    272358 
    273359    def test_fkey_aggregate(self): 
    274360        explicit = list(Author.objects.annotate(Count('book__id'))) 
    class BaseAggregateTestCase(TestCase): 
    388474            ], 
    389475            lambda p: p.name, 
    390476        ) 
     477        publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk") 
     478        self.assertQuerysetEqual( 
     479            publishers, [ 
     480                "Expensive Publisher", 
     481            ], 
     482            lambda p: p.name, 
     483        ) 
    391484 
    392485        publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") 
    393486        self.assertQuerysetEqual(