Ticket #11305: conditional_aggregates.1.6.cleanup.patch

File conditional_aggregates.1.6.cleanup.patch, 28.5 KB (added by Anssi Kääriäinen, 11 years ago)
  • django/db/models/aggregates.py

    diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
    index a2349cf..d816aa7 100644
    a b class Aggregate(object):  
    66    """
    77    Default Aggregate definition.
    88    """
    9     def __init__(self, lookup, **extra):
     9    def __init__(self, lookup, only=None, **extra):
    1010        """Instantiate a new aggregate.
    1111
    1212         * lookup is the field on which the aggregate operates.
     13         * only is a Q-object used in conditional aggregation.
    1314         * extra is a dictionary of additional data to provide for the
    1415           aggregate definition
    1516
    class Aggregate(object):  
    1819        """
    1920        self.lookup = lookup
    2021        self.extra = extra
     22        self.only = only
     23        self.condition = None
    2124
    2225    def _default_alias(self):
     26        if hasattr(self.lookup, 'evaluate'):
     27             raise ValueError('When aggregating over an expression, you need to give an alias.')
    2328        return '%s__%s' % (self.lookup, self.name.lower())
    2429    default_alias = property(_default_alias)
    2530
    class Aggregate(object):  
    4247           summary value rather than an annotation.
    4348        """
    4449        klass = getattr(query.aggregates_module, self.name)
    45         aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
     50        aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra)
    4651        query.aggregates[alias] = aggregate
    4752
    4853class Avg(Aggregate):
  • django/db/models/expressions.py

    diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
    index 3566d77..2b782ea 100644
    a b class ExpressionNode(tree.Node):  
    4141    # VISITOR METHODS #
    4242    ###################
    4343
    44     def prepare(self, evaluator, query, allow_joins):
    45         return evaluator.prepare_node(self, query, allow_joins)
     44    def prepare(self, evaluator, query, allow_joins, promote_joins=False):
     45        return evaluator.prepare_node(self, query, allow_joins, promote_joins)
    4646
    4747    def evaluate(self, evaluator, qn, connection):
    4848        return evaluator.evaluate_node(self, qn, connection)
    class F(ExpressionNode):  
    129129        obj.name = self.name
    130130        return obj
    131131
    132     def prepare(self, evaluator, query, allow_joins):
    133         return evaluator.prepare_leaf(self, query, allow_joins)
     132    def prepare(self, evaluator, query, allow_joins, promote_joins=False):
     133        return evaluator.prepare_leaf(self, query, allow_joins, promote_joins)
    134134
    135135    def evaluate(self, evaluator, qn, connection):
    136136        return evaluator.evaluate_leaf(self, qn, connection)
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index b41314a..5fe2215 100644
    a b Classes to represent the default SQL aggregate functions  
    33"""
    44
    55from django.db.models.fields import IntegerField, FloatField
     6from django.db.models.sql.expressions import SQLEvaluator
    67
    78# Fake fields used to identify aggregate types in data-conversion operations.
    89ordinal_aggregate_field = IntegerField()
    class Aggregate(object):  
    1516    is_ordinal = False
    1617    is_computed = False
    1718    sql_template = '%(function)s(%(field)s)'
     19    conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END"
    1820
    19     def __init__(self, col, source=None, is_summary=False, **extra):
     21    def __init__(self, col, source=None, is_summary=False, condition=None, **extra):
    2022        """Instantiate an SQL aggregate
    2123
    2224         * col is a column reference describing the subject field
    class Aggregate(object):  
    2628           the column reference. If the aggregate is not an ordinal or
    2729           computed type, this reference is used to determine the coerced
    2830           output type of the aggregate.
     31         * condition is used in conditional aggregation.
    2932         * extra is a dictionary of additional data to provide for the
    30            aggregate definition
     33           aggregate definition.
    3134
    3235        Also utilizes the class variables:
    3336         * sql_function, the name of the SQL function that implements the
    class Aggregate(object):  
    3538         * sql_template, a template string that is used to render the
    3639           aggregate into SQL.
    3740         * is_ordinal, a boolean indicating if the output of this aggregate
    38            is an integer (e.g., a count)
     41           is an integer (e.g., a count).
    3942         * is_computed, a boolean indicating if this output of this aggregate
    4043           is a computed float (e.g., an average), regardless of the input
    4144           type.
    class Aggregate(object):  
    4548        self.source = source
    4649        self.is_summary = is_summary
    4750        self.extra = extra
     51        self.condition = condition
    4852
    4953        # Follow the chain of aggregate sources back until you find an
    5054        # actual field, or an aggregate that forces a particular output
    class Aggregate(object):  
    6569    def relabel_aliases(self, change_map):
    6670        if isinstance(self.col, (list, tuple)):
    6771            self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
     72        else:
     73            self.col.relabel_aliases(change_map)
     74        if self.condition:
     75            self.condition.relabel_aliases(change_map)
    6876
    6977    def as_sql(self, qn, connection):
    7078        "Return the aggregate, rendered as SQL."
    7179
     80        condition_params = []
     81        col_params = []
    7282        if hasattr(self.col, 'as_sql'):
    73             field_name = self.col.as_sql(qn, connection)
     83            if isinstance(self.col, SQLEvaluator):
     84                field_name, col_params = self.col.as_sql(qn, connection)
     85            else:
     86                field_name = self.col.as_sql(qn, connection)
    7487        elif isinstance(self.col, (list, tuple)):
    7588            field_name = '.'.join([qn(c) for c in self.col])
    7689        else:
    7790            field_name = self.col
    7891
    79         params = {
    80             'function': self.sql_function,
    81             'field': field_name
    82         }
     92        if self.condition:
     93            condition, condition_params = self.condition.as_sql(qn, connection)
     94            conditional_field = self.conditional_template % {
     95                'condition': condition,
     96                'field_name': field_name
     97            }
     98            params = {
     99                'function': self.sql_function,
     100                'field': conditional_field,
     101            }
     102        else:
     103            params = {
     104                'function': self.sql_function,
     105                'field': field_name
     106            }
    83107        params.update(self.extra)
    84108
    85         return self.sql_template % params
     109        condition_params.extend(col_params)
     110        return (self.sql_template % params, condition_params)
    86111
    87112
    88113class Avg(Aggregate):
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 4d846fb..606b364 100644
    a b class SQLCompiler(object):  
    7171        # as the pre_sql_setup will modify query state in a way that forbids
    7272        # another run of it.
    7373        self.refcounts_before = self.query.alias_refcount.copy()
    74         out_cols = self.get_columns(with_col_aliases)
     74        out_cols, c_params = self.get_columns(with_col_aliases)
    7575        ordering, ordering_group_by = self.get_ordering()
    7676
    7777        distinct_fields = self.get_distinct()
    class SQLCompiler(object):  
    8787        params = []
    8888        for val in six.itervalues(self.query.extra_select):
    8989            params.extend(val[1])
     90        # Extra-select comes before aggregation in the select list
     91        params.extend(c_params)
    9092
    9193        result = ['SELECT']
    9294
    class SQLCompiler(object):  
    172174        qn = self.quote_name_unless_alias
    173175        qn2 = self.connection.ops.quote_name
    174176        result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
     177        query_params = []
    175178        aliases = set(self.query.extra_select.keys())
    176179        if with_aliases:
    177180            col_aliases = aliases.copy()
    class SQLCompiler(object):  
    214217            aliases.update(new_aliases)
    215218
    216219        max_name_length = self.connection.ops.max_name_length()
    217         result.extend([
    218             '%s%s' % (
    219                 aggregate.as_sql(qn, self.connection),
    220                 alias is not None
    221                     and ' AS %s' % qn(truncate_name(alias, max_name_length))
    222                     or ''
     220        for alias, aggregate in self.query.aggregate_select.items():
     221            sql, params = aggregate.as_sql(qn, self.connection)
     222            result.append(
     223                '%s%s' % (
     224                    sql,
     225                    alias is not None
     226                       and ' AS %s' % qn(truncate_name(alias, max_name_length))
     227                       or ''
     228                )
    223229            )
    224             for alias, aggregate in self.query.aggregate_select.items()
    225         ])
     230            query_params.extend(params)
    226231
    227232        for (table, col), _ in self.query.related_select_cols:
    228233            r = '%s.%s' % (qn(table), qn(col))
    class SQLCompiler(object):  
    237242                col_aliases.add(col)
    238243
    239244        self._select_aliases = aliases
    240         return result
     245        return result, query_params
    241246
    242247    def get_default_columns(self, with_aliases=False, col_aliases=None,
    243248            start_alias=None, opts=None, as_pairs=False, from_parent=None):
    class SQLAggregateCompiler(SQLCompiler):  
    10401045        """
    10411046        if qn is None:
    10421047            qn = self.quote_name_unless_alias
     1048        buf = []
     1049        a_params = []
     1050        for aggregate in self.query.aggregate_select.values():
     1051            sql, query_params = aggregate.as_sql(qn, self.connection)
     1052            buf.append(sql)
     1053            a_params.extend(query_params)
     1054        aggregate_sql = ', '.join(buf)
    10431055
    10441056        sql = ('SELECT %s FROM (%s) subquery' % (
    1045             ', '.join([
    1046                 aggregate.as_sql(qn, self.connection)
    1047                 for aggregate in self.query.aggregate_select.values()
    1048             ]),
     1057            aggregate_sql,
    10491058            self.query.subquery)
    10501059        )
    1051         params = self.query.sub_params
     1060        params = tuple(a_params) + (self.query.sub_params)
    10521061        return (sql, params)
    10531062
    10541063class SQLDateCompiler(SQLCompiler):
  • django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    index af7e45e..9441e0e 100644
    a b from django.db.models.fields import FieldDoesNotExist  
    44from django.db.models.sql.constants import REUSE_ALL
    55
    66class SQLEvaluator(object):
    7     def __init__(self, expression, query, allow_joins=True, reuse=REUSE_ALL):
     7    def __init__(self, expression, query, allow_joins=True, reuse=REUSE_ALL, promote_joins=False):
    88        self.expression = expression
    99        self.opts = query.get_meta()
    1010        self.cols = []
    1111
    1212        self.contains_aggregate = False
    1313        self.reuse = reuse
    14         self.expression.prepare(self, query, allow_joins)
     14        self.expression.prepare(self, query, allow_joins, promote_joins)
    1515
    1616    def prepare(self):
    1717        return self
    class SQLEvaluator(object):  
    3434    # Vistor methods for initial expression preparation #
    3535    #####################################################
    3636
    37     def prepare_node(self, node, query, allow_joins):
     37    def prepare_node(self, node, query, allow_joins, promote_joins):
    3838        for child in node.children:
    3939            if hasattr(child, 'prepare'):
    40                 child.prepare(self, query, allow_joins)
     40                child.prepare(self, query, allow_joins, promote_joins)
    4141
    42     def prepare_leaf(self, node, query, allow_joins):
     42    def prepare_leaf(self, node, query, allow_joins, promote_joins):
    4343        if not allow_joins and LOOKUP_SEP in node.name:
    4444            raise FieldError("Joined field references are not permitted in this query")
    4545
    class SQLEvaluator(object):  
    5454                    field_list, query.get_meta(),
    5555                    query.get_initial_alias(), self.reuse)
    5656                col, _, join_list = query.trim_joins(source, join_list, path)
     57                self.source = source
     58                if promote_joins:
     59                    query.promote_joins(join_list, unconditional=True)
    5760                if self.reuse is not None and self.reuse != REUSE_ALL:
    5861                    self.reuse.update(join_list)
    5962                self.cols.append((node, (join_list[-1], col)))
    class SQLEvaluator(object):  
    7275        for child in node.children:
    7376            if hasattr(child, 'evaluate'):
    7477                sql, params = child.evaluate(self, qn, connection)
     78                if isinstance(sql, tuple):
     79                    expression_params.extend(sql[1])
     80                    sql = sql[0]
    7581            else:
    7682                sql, params = '%s', (child,)
    7783
    class SQLEvaluator(object):  
    108114            return sql, params
    109115
    110116        return connection.ops.date_interval_sql(sql, node.connector, timedelta), params
     117
     118    def get_field_type(self):
     119        """
     120        Returns the field type of the result.
     121
     122        TODO: The field type resolving is very simple and will likely give
     123        incorrect field type in many situations.
     124        """
     125        return self.source
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index ff56211..205e31b 100644
    a b class Query(object):  
    10191019        Adds a single aggregate expression to the Query
    10201020        """
    10211021        opts = model._meta
    1022         field_list = aggregate.lookup.split(LOOKUP_SEP)
    1023         if len(field_list) == 1 and aggregate.lookup in self.aggregates:
    1024             # Aggregate is over an annotation
    1025             field_name = field_list[0]
    1026             col = field_name
    1027             source = self.aggregates[field_name]
    1028             if not is_summary:
    1029                 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
    1030                     aggregate.name, field_name, field_name))
    1031         elif ((len(field_list) > 1) or
    1032             (field_list[0] not in [i.name for i in opts.fields]) or
    1033             self.group_by is None or
    1034             not is_summary):
    1035             # If:
    1036             #   - the field descriptor has more than one part (foo__bar), or
    1037             #   - the field descriptor is referencing an m2m/m2o field, or
    1038             #   - this is a reference to a model field (possibly inherited), or
    1039             #   - this is an annotation over a model field
    1040             # then we need to explore the joins that are required.
    1041 
    1042             field, source, opts, join_list, path = self.setup_joins(
    1043                 field_list, opts, self.get_initial_alias(), REUSE_ALL)
    1044 
    1045             # Process the join chain to see if it can be trimmed
    1046             col, _, join_list = self.trim_joins(source, join_list, path)
    1047 
    1048             # If the aggregate references a model or field that requires a join,
    1049             # those joins must be LEFT OUTER - empty join rows must be returned
    1050             # in order for zeros to be returned for those aggregates.
    1051             self.promote_joins(join_list, True)
    1052 
    1053             col = (join_list[-1], col)
     1022        only = aggregate.only
     1023        if hasattr(aggregate.lookup, 'evaluate'):
     1024            # If lookup is a query expression, evaluate it
     1025            col = SQLEvaluator(aggregate.lookup, self, promote_joins=True)
     1026            source = col.get_field_type()
    10541027        else:
    1055             # The simplest cases. No joins required -
    1056             # just reference the provided column alias.
    1057             field_name = field_list[0]
    1058             source = opts.get_field(field_name)
    1059             col = field_name
     1028            field_list = aggregate.lookup.split(LOOKUP_SEP)
     1029            join_list = []
     1030            if len(field_list) == 1 and aggregate.lookup in self.aggregates:
     1031                # Aggregate is over an annotation
     1032                field_name = field_list[0]
     1033                col = field_name
     1034                source = self.aggregates[field_name]
     1035                if not is_summary:
     1036                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
     1037                        aggregate.name, field_name, field_name))
     1038                if only:
     1039                    raise FieldError("Cannot use aggregated fields in conditional aggregates")
     1040            elif ((len(field_list) > 1) or
     1041                    (field_list[0] not in [i.name for i in opts.fields]) or
     1042                    self.group_by is None or
     1043                    not is_summary):
     1044                # If:
     1045                #   - the field descriptor has more than one part (foo__bar), or
     1046                #   - the field descriptor is referencing an m2m/m2o field, or
     1047                #   - this is a reference to a model field (possibly inherited), or
     1048                #   - this is an annotation over a model field
     1049                # then we need to explore the joins that are required.
     1050
     1051                field, source, opts, join_list, path = self.setup_joins(
     1052                    field_list, opts, self.get_initial_alias(), REUSE_ALL)
     1053
     1054                # Process the join chain to see if it can be trimmed
     1055                col, _, join_list = self.trim_joins(source, join_list, path)
     1056
     1057                # If the aggregate references a model or field that requires a join,
     1058                # those joins must be LEFT OUTER - empty join rows must be returned
     1059                # in order for zeros to be returned for those aggregates.
     1060                self.promote_joins(join_list, unconditional=True)
     1061
     1062                col = (join_list[-1], col)
     1063            else:
     1064                # The simplest cases. No joins required -
     1065                # just reference the provided column alias.
     1066                field_name = field_list[0]
     1067                source = opts.get_field(field_name)
     1068                col = field_name
     1069
     1070        if only:
     1071            original_where = self.where
     1072            original_having = self.having
     1073            aggregate.condition = self.where_class()
     1074            self.where = aggregate.condition
     1075            self.having = self.where_class()
     1076            original_alias_map = self.alias_map.keys()[:]
     1077            self.add_q(only, used_aliases=set(original_alias_map))
     1078            if original_alias_map != self.alias_map.keys():
     1079                raise FieldError("Aggregate's only condition can not require additional joins, "
     1080                                 "Original joins: %s, joins after: %s"
     1081                                 % (original_alias_map, self.alias_map.keys()))
     1082            if self.having.children:
     1083                raise FieldError("Aggregate's only condition can not reference annotated fields")
     1084            self.having = original_having
     1085            self.where = original_where
    10601086
    10611087        # Add the aggregate to the query
    10621088        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 47f4ffa..6b001bc 100644
    a b class WhereNode(tree.Node):  
    156156        it.
    157157        """
    158158        lvalue, lookup_type, value_annotation, params_or_value = child
     159        additional_params = []
    159160        if isinstance(lvalue, Constraint):
    160161            try:
    161162                lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
    class WhereNode(tree.Node):  
    173174        else:
    174175            # A smart object with an as_sql() method.
    175176            field_sql = lvalue.as_sql(qn, connection)
     177            if isinstance(field_sql, tuple):
     178                # It also returned params
     179                additional_params.extend(field_sql[1])
     180                field_sql = field_sql[0]
    176181
    177182        if value_annotation is datetime.datetime:
    178183            cast_sql = connection.ops.datetime_cast_sql()
    class WhereNode(tree.Node):  
    181186
    182187        if hasattr(params, 'as_sql'):
    183188            extra, params = params.as_sql(qn, connection)
     189            if isinstance(extra, tuple):
     190                params = params + tuple(extra[1])
     191                extra = extra[0]
    184192            cast_sql = ''
    185193        else:
    186194            extra = ''
    class WhereNode(tree.Node):  
    190198            lookup_type = 'isnull'
    191199            value_annotation = True
    192200
     201        additional_params.extend(params)
     202        params = additional_params
    193203        if lookup_type in connection.operators:
    194204            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
    195205            return (format % (field_sql,
  • tests/modeltests/aggregation/tests.py

    diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
    index c23b32f..0c445c4 100644
    a b import datetime  
    44from decimal import Decimal
    55
    66from django.db.models import Avg, Sum, Count, Max, Min
     7from django.db.models import Q, F
     8from django.core.exceptions import FieldError
    79from django.test import TestCase, Approximate
    810
    911from .models import Author, Publisher, Book, Store
    class BaseAggregateTestCase(TestCase):  
    1820    def test_single_aggregate(self):
    1921        vals = Author.objects.aggregate(Avg("age"))
    2022        self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)})
     23        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)))
     24        self.assertEqual(vals, {"age__sum": 254})
     25        vals = Author.objects.extra(select={'testparams':'age < %s'}, select_params=[0])\
     26               .aggregate(Sum("age", only=Q(age__gt=29)))
     27        self.assertEqual(vals, {"age__sum": 254})
     28        vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian')))
     29        self.assertEqual(vals, {"age__sum": 69})
    2130
    2231    def test_multiple_aggregates(self):
    2332        vals = Author.objects.aggregate(Sum("age"), Avg("age"))
    2433        self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)})
     34        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age"))
     35        self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)})
    2536
    2637    def test_filter_aggregate(self):
    2738        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age"))
    2839        self.assertEqual(len(vals), 1)
    2940        self.assertEqual(vals["age__sum"], 254)
     41        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29)))
     42        # If there are no matching aggregates, then None, not 0 is the answer.
     43        self.assertEqual(vals["age__sum"], None)
    3044
    3145    def test_related_aggregate(self):
    3246        vals = Author.objects.aggregate(Avg("friends__age"))
    3347        self.assertEqual(len(vals), 1)
    3448        self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2)
    3549
     50        vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29)))
     51        self.assertEqual(len(vals), 1)
     52        self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2)
     53        vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age"))
     54        self.assertEqual(vals, vals2)
     55
     56        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35)))
     57        self.assertEqual(len(vals), 1)
     58        self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2)
     59
     60        # The average age of author's friends, whose age is lower than the authors age.
     61        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age'))))
     62        self.assertEqual(len(vals), 1)
     63        self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2)
     64
    3665        vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age"))
    3766        self.assertEqual(len(vals), 1)
    3867        self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2)
    class BaseAggregateTestCase(TestCase):  
    5483        self.assertEqual(len(vals), 1)
    5584        self.assertEqual(vals["books__authors__age__max"], 57)
    5685
     86        vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56)))
     87        self.assertEqual(len(vals), 1)
     88        self.assertEqual(vals["books__authors__age__max"], 46)
     89
    5790        vals = Author.objects.aggregate(Min("book__publisher__num_awards"))
    5891        self.assertEqual(len(vals), 1)
    5992        self.assertEqual(vals["book__publisher__num_awards__min"], 1)
    class BaseAggregateTestCase(TestCase):  
    84117        )
    85118        self.assertEqual(b.mean_age, 34.5)
    86119
     120        # Test extra-select
     121        books = Book.objects.annotate(mean_age=Avg("authors__age"))
     122        books = books.annotate(mean_age2=Avg('authors__age', only=Q(authors__age__gte=0)))
     123        books = books.extra(select={'testparams': 'publisher_id = %s'}, select_params=[1])
     124        b = books.get(pk=1)
     125        self.assertEqual(b.mean_age, 34.5)
     126        self.assertEqual(b.mean_age2, 34.5)
     127        self.assertEqual(b.testparams, True)
     128
     129        # Test relabel_aliases
     130        excluded_authors = Author.objects.annotate(book_rating=Min(F('book__rating') + 5, only=Q(pk__gte=1)))
     131        excluded_authors = excluded_authors.filter(book_rating__lt=0)
     132        books = books.exclude(authors__in=excluded_authors)
     133        b = books.get(pk=1)
     134        self.assertEqual(b.mean_age, 34.5)
     135
     136        # Test joins in F-based annotation
     137        books = Book.objects.annotate(oldest=Max(F('authors__age')))
     138        books = books.values_list('rating', 'oldest').order_by('rating', 'oldest')
     139        self.assertEqual(
     140            list(books),
     141            [(3.0, 45), (4.0, 29), (4.0, 37), (4.0, 57), (4.5, 35), (5.0, 57)]
     142        )
     143
     144        publishers = Publisher.objects.annotate(avg_rating=Avg(F('book__rating') - 0))
     145        publishers = publishers.values_list('id', 'avg_rating').order_by('id')
     146        self.assertEqual(list(publishers), [(1, 4.25), (2, 3.0), (3, 4.0), (4, 5.0), (5, None)])
     147
    87148    def test_annotate_m2m(self):
    88149        books = Book.objects.filter(rating__lt=4.5).annotate(Avg("authors__age")).order_by("name")
    89150        self.assertQuerysetEqual(
    class BaseAggregateTestCase(TestCase):  
    109170            lambda b: (b.name, b.num_authors)
    110171        )
    111172
     173        def raises_exception():
     174            list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name"))
     175
     176        self.assertRaises(FieldError, raises_exception)
     177
    112178    def test_backwards_m2m_annotate(self):
    113179        authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name")
    114180        self.assertQuerysetEqual(
    class BaseAggregateTestCase(TestCase):  
    194260                }
    195261            ]
    196262        )
     263        books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age')
     264        self.assertEqual(
     265            list(books), [
     266                {
     267                    "pk": 1,
     268                    "isbn": "159059725",
     269                    "mean_age": 34.0,
     270                }
     271            ]
     272        )
    197273
    198274        books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name")
    199275        self.assertEqual(
    class BaseAggregateTestCase(TestCase):  
    271347
    272348        vals = Book.objects.aggregate(Count("rating", distinct=True))
    273349        self.assertEqual(vals, {"rating__count": 4})
     350        vals = Book.objects.aggregate(
     351            low_count=Count("rating", only=Q(rating__lt=4)),
     352            high_count=Count("rating", only=Q(rating__gte=4))
     353        )
     354        self.assertEqual(vals, {"low_count": 1, 'high_count': 5})
     355        vals = Book.objects.aggregate(
     356            low_count=Count("rating", distinct=True, only=Q(rating__lt=4)),
     357            high_count=Count("rating", distinct=True, only=Q(rating__gte=4))
     358        )
     359        self.assertEqual(vals, {"low_count": 1, 'high_count': 3})
    274360
    275361    def test_fkey_aggregate(self):
    276362        explicit = list(Author.objects.annotate(Count('book__id')))
    class BaseAggregateTestCase(TestCase):  
    390476            ],
    391477            lambda p: p.name,
    392478        )
     479        publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk")
     480        self.assertQuerysetEqual(
     481            publishers, [
     482                "Expensive Publisher",
     483            ],
     484            lambda p: p.name,
     485        )
    393486
    394487        publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk")
    395488        self.assertQuerysetEqual(
Back to Top