Code

Ticket #11305: poc_11305.2.patch

File poc_11305.2.patch, 17.9 KB (added by akaariai, 3 years ago)
  • django/db/models/aggregates.py

    diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
    index a2349cf..3d8b63a5 100644
    a b class Aggregate(object): 
    66    """ 
    77    Default Aggregate definition. 
    88    """ 
    9     def __init__(self, lookup, **extra): 
     9    def __init__(self, lookup, only=None, **extra): 
    1010        """Instantiate a new aggregate. 
    1111 
    1212         * lookup is the field on which the aggregate operates. 
     13         * only is a Q-object used in conditional aggregation. 
    1314         * extra is a dictionary of additional data to provide for the 
    1415           aggregate definition 
    1516 
    class Aggregate(object): 
    1819        """ 
    1920        self.lookup = lookup 
    2021        self.extra = extra 
     22        self.only = only 
     23        self.condition = None 
    2124 
    2225    def _default_alias(self): 
    2326        return '%s__%s' % (self.lookup, self.name.lower()) 
    class Aggregate(object): 
    4245           summary value rather than an annotation. 
    4346        """ 
    4447        klass = getattr(query.aggregates_module, self.name) 
    45         aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) 
     48        aggregate = klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra) 
    4649        query.aggregates[alias] = aggregate 
    4750 
    4851class Avg(Aggregate): 
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index 207bc0c..0de19a3 100644
    a b class Aggregate(object): 
    2222    is_ordinal = False 
    2323    is_computed = False 
    2424    sql_template = '%(function)s(%(field)s)' 
     25    conditional_template = "CASE WHEN %(condition)s THEN %(field_name)s ELSE null END" 
    2526 
    26     def __init__(self, col, source=None, is_summary=False, **extra): 
     27    def __init__(self, col, source=None, is_summary=False, condition=None, **extra): 
    2728        """Instantiate an SQL aggregate 
    2829 
    2930         * col is a column reference describing the subject field 
    class Aggregate(object): 
    5253        self.source = source 
    5354        self.is_summary = is_summary 
    5455        self.extra = extra 
     56        self.condition = condition 
    5557 
    5658        # Follow the chain of aggregate sources back until you find an 
    5759        # actual field, or an aggregate that forces a particular output 
    class Aggregate(object): 
    7678    def as_sql(self, qn, connection): 
    7779        "Return the aggregate, rendered as SQL." 
    7880 
     81        query_params = [] 
    7982        if hasattr(self.col, 'as_sql'): 
    8083            field_name = self.col.as_sql(qn, connection) 
    8184        elif isinstance(self.col, (list, tuple)): 
    8285            field_name = '.'.join([qn(c) for c in self.col]) 
    8386        else: 
    8487            field_name = self.col 
    85  
    86         params = { 
    87             'function': self.sql_function, 
    88             'field': field_name 
    89         } 
     88        if self.condition: 
     89            condition = self.condition.as_sql(qn, connection) 
     90            query_params = condition[1] 
     91            conditional_field = self.conditional_template % { 
     92                'condition': condition[0],  
     93                'field_name': field_name 
     94            }  
     95            params = { 
     96                'function': self.sql_function, 
     97                'field': conditional_field, 
     98            }  
     99        else: 
     100            params = { 
     101                'function': self.sql_function, 
     102                'field': field_name 
     103            } 
    90104        params.update(self.extra) 
    91105 
    92         return self.sql_template % params 
     106        return (self.sql_template % params, query_params) 
    93107 
    94108 
    95109class Avg(Aggregate): 
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 841ec12..a0d2f35 100644
    a b class SQLCompiler(object): 
    5858            return '', () 
    5959 
    6060        self.pre_sql_setup() 
    61         out_cols = self.get_columns(with_col_aliases) 
     61        out_cols, c_params = self.get_columns(with_col_aliases) 
    6262        ordering, ordering_group_by = self.get_ordering() 
     63        params = [] 
     64        params.extend(c_params) 
    6365 
    6466        # This must come after 'select' and 'ordering' -- see docstring of 
    6567        # get_from_clause() for details. 
    class SQLCompiler(object): 
    6971 
    7072        where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) 
    7173        having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) 
    72         params = [] 
    7374        for val in self.query.extra_select.itervalues(): 
    7475            params.extend(val[1]) 
    7576 
    class SQLCompiler(object): 
    126127            if nowait and not self.connection.features.has_select_for_update_nowait: 
    127128                raise DatabaseError('NOWAIT is not supported on this database backend.') 
    128129            result.append(self.connection.ops.for_update_sql(nowait=nowait)) 
    129  
    130130        return ' '.join(result), tuple(params) 
    131131 
    132132    def as_nested_sql(self): 
    class SQLCompiler(object): 
    158158        qn = self.quote_name_unless_alias 
    159159        qn2 = self.connection.ops.quote_name 
    160160        result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] 
     161        query_params = [] 
    161162        aliases = set(self.query.extra_select.keys()) 
    162163        if with_aliases: 
    163164            col_aliases = aliases.copy() 
    class SQLCompiler(object): 
    200201            aliases.update(new_aliases) 
    201202 
    202203        max_name_length = self.connection.ops.max_name_length() 
    203         result.extend([ 
    204             '%s%s' % ( 
    205                 aggregate.as_sql(qn, self.connection), 
    206                 alias is not None 
    207                     and ' AS %s' % qn(truncate_name(alias, max_name_length)) 
    208                     or '' 
     204        for alias, aggregate in self.query.aggregate_select.items(): 
     205            sql, params = aggregate.as_sql(qn, self.connection) 
     206            result.append( 
     207                '%s%s' % ( 
     208                    sql, 
     209                    alias is not None 
     210                       and ' AS %s' % qn(truncate_name(alias, max_name_length)) 
     211                       or '' 
     212                ) 
    209213            ) 
    210             for alias, aggregate in self.query.aggregate_select.items() 
    211         ]) 
     214            query_params.extend(params) 
    212215 
    213216        for table, col in self.query.related_select_cols: 
    214217            r = '%s.%s' % (qn(table), qn(col)) 
    class SQLCompiler(object): 
    223226                col_aliases.add(col) 
    224227 
    225228        self._select_aliases = aliases 
    226         return result 
     229        return result, query_params 
    227230 
    228231    def get_default_columns(self, with_aliases=False, col_aliases=None, 
    229232            start_alias=None, opts=None, as_pairs=False, local_only=False): 
    class SQLAggregateCompiler(SQLCompiler): 
    948951        """ 
    949952        if qn is None: 
    950953            qn = self.quote_name_unless_alias 
     954        buf = [] 
     955        a_params = [] 
     956        for aggregate in self.query.aggregate_select.values(): 
     957            sql, query_params = aggregate.as_sql(qn, self.connection) 
     958            buf.append(sql) 
     959            a_params.extend(query_params) 
     960        aggregate_sql = ', '.join(buf) 
    951961        sql = ('SELECT %s FROM (%s) subquery' % ( 
    952             ', '.join([ 
    953                 aggregate.as_sql(qn, self.connection) 
    954                 for aggregate in self.query.aggregate_select.values() 
    955             ]), 
     962            aggregate_sql,   
    956963            self.query.subquery) 
    957964        ) 
    958         params = self.query.sub_params 
     965        params = tuple(a_params) + (self.query.sub_params) 
    959966        return (sql, params) 
    960967 
    961968class SQLDateCompiler(SQLCompiler): 
  • django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    index 1bbf742..f9c23a9 100644
    a b class SQLEvaluator(object): 
    6565        for child in node.children: 
    6666            if hasattr(child, 'evaluate'): 
    6767                sql, params = child.evaluate(self, qn, connection) 
     68                if isinstance(sql, tuple): 
     69                    expression_params.extend(sql[1]) 
     70                    sql = sql[0] 
    6871            else: 
    6972                sql, params = '%s', (child,) 
    7073 
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index 110e317..7735261 100644
    a b class Query(object): 
    956956        """ 
    957957        opts = model._meta 
    958958        field_list = aggregate.lookup.split(LOOKUP_SEP) 
     959        only = aggregate.only 
     960        join_list = [] 
    959961        if len(field_list) == 1 and aggregate.lookup in self.aggregates: 
    960962            # Aggregate is over an annotation 
    961963            field_name = field_list[0] 
    class Query(object): 
    964966            if not is_summary: 
    965967                raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 
    966968                    aggregate.name, field_name, field_name)) 
     969            if only: 
     970                raise FieldError("Cannot use aggregated fields in conditional aggregates") 
    967971        elif ((len(field_list) > 1) or 
    968972            (field_list[0] not in [i.name for i in opts.fields]) or 
    969973            self.group_by is None or 
    class Query(object): 
    995999            source = opts.get_field(field_name) 
    9961000            col = field_name 
    9971001 
     1002        if only: 
     1003            original_where = self.where 
     1004            original_having = self.having 
     1005            aggregate.condition = self.where_class() 
     1006            self.where = aggregate.condition 
     1007            self.having = self.where_class() 
     1008            original_alias_map = self.alias_map.keys()[:] 
     1009            self.add_q(only, used_aliases=set(original_alias_map))  
     1010            if original_alias_map != self.alias_map.keys(): 
     1011                raise FieldError("Aggregate's only condition can not require additional joins, Original joins: %s, joins after: %s" % (original_alias_map, self.alias_map.keys())) 
     1012            if self.having.children: 
     1013                raise FieldError("Aggregate's only condition can not reference annotated fields") 
     1014            self.having = original_having 
     1015            self.where = original_where 
    9981016        # Add the aggregate to the query 
    9991017        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) 
    10001018 
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 2427a52..7ac0287 100644
    a b class WhereNode(tree.Node): 
    137137        it. 
    138138        """ 
    139139        lvalue, lookup_type, value_annot, params_or_value = child 
     140        additional_params = [] 
    140141        if hasattr(lvalue, 'process'): 
    141142            try: 
    142143                lvalue, params = lvalue.process(lookup_type, params_or_value, connection) 
    class WhereNode(tree.Node): 
    151152        else: 
    152153            # A smart object with an as_sql() method. 
    153154            field_sql = lvalue.as_sql(qn, connection) 
     155            if isinstance(field_sql, tuple): 
     156                # It returned also params 
     157                additional_params.extend(field_sql[1]) 
     158                field_sql = field_sql[0] 
    154159 
    155160        if value_annot is datetime.datetime: 
    156161            cast_sql = connection.ops.datetime_cast_sql() 
    class WhereNode(tree.Node): 
    159164 
    160165        if hasattr(params, 'as_sql'): 
    161166            extra, params = params.as_sql(qn, connection) 
     167            if isinstance(extra, tuple): 
     168                params = params + tuple(extra[1]) 
     169                extra = extra[0] 
    162170            cast_sql = '' 
    163171        else: 
    164172            extra = '' 
    class WhereNode(tree.Node): 
    168176            lookup_type = 'isnull' 
    169177            value_annot = True 
    170178 
     179        additional_params.extend(params) 
     180        params = additional_params 
    171181        if lookup_type in connection.operators: 
    172182            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) 
    173183            return (format % (field_sql, 
  • tests/modeltests/aggregation/tests.py

    diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
    index 6f68800..3a04fa9 100644
    a b import datetime 
    22from decimal import Decimal 
    33 
    44from django.db.models import Avg, Sum, Count, Max, Min 
     5from django.db.models import Q, F 
     6from django.core.exceptions import FieldError 
    57from django.test import TestCase, Approximate 
    68 
    79from models import Author, Publisher, Book, Store 
    class BaseAggregateTestCase(TestCase): 
    1618    def test_single_aggregate(self): 
    1719        vals = Author.objects.aggregate(Avg("age")) 
    1820        self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) 
     21        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29))) 
     22        self.assertEqual(vals, {"age__sum": 254}) 
     23        vals = Author.objects.aggregate(Sum("age", only=Q(name__icontains='jaco')|Q(name__icontains='adrian'))) 
     24        self.assertEqual(vals, {"age__sum": 69}) 
    1925 
    2026    def test_multiple_aggregates(self): 
    2127        vals = Author.objects.aggregate(Sum("age"), Avg("age")) 
    2228        self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) 
     29        vals = Author.objects.aggregate(Sum("age", only=Q(age__gt=29)), Avg("age")) 
     30        self.assertEqual(vals, {"age__sum": 254, "age__avg": Approximate(37.4, places=1)}) 
    2331 
    2432    def test_filter_aggregate(self): 
    2533        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) 
    2634        self.assertEqual(len(vals), 1) 
    2735        self.assertEqual(vals["age__sum"], 254) 
     36        vals = Author.objects.filter(age__gt=29).aggregate(Sum("age", only=Q(age__lt=29))) 
     37        # If there are no matching aggregates, then None, not 0 is the answer. 
     38        self.assertEqual(vals["age__sum"], None) 
    2839 
    2940    def test_related_aggregate(self): 
    3041        vals = Author.objects.aggregate(Avg("friends__age")) 
    3142        self.assertEqual(len(vals), 1) 
    3243        self.assertAlmostEqual(vals["friends__age__avg"], 34.07, places=2) 
    3344 
     45        vals = Author.objects.aggregate(Avg("friends__age", only=Q(age__lt=29))) 
     46        self.assertEqual(len(vals), 1) 
     47        self.assertAlmostEqual(vals["friends__age__avg"], 33.67, places=2) 
     48        vals2 = Author.objects.filter(age__lt=29).aggregate(Avg("friends__age")) 
     49        self.assertEqual(vals, vals2) 
     50 
     51        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=35))) 
     52        self.assertEqual(len(vals), 1) 
     53        self.assertAlmostEqual(vals["friends__age__avg"], 28.75, places=2) 
     54 
     55        # The average age of author's friends, whose age is lower than the authors age. 
     56        vals = Author.objects.aggregate(Avg("friends__age", only=Q(friends__age__lt=F('age')))) 
     57        self.assertEqual(len(vals), 1) 
     58        self.assertAlmostEqual(vals["friends__age__avg"], 30.43, places=2) 
     59 
    3460        vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age")) 
    3561        self.assertEqual(len(vals), 1) 
    3662        self.assertAlmostEqual(vals["authors__age__avg"], 38.2857, places=2) 
    class BaseAggregateTestCase(TestCase): 
    5177        vals = Store.objects.aggregate(Max("books__authors__age")) 
    5278        self.assertEqual(len(vals), 1) 
    5379        self.assertEqual(vals["books__authors__age__max"], 57) 
     80         
     81        vals = Store.objects.aggregate(Max("books__authors__age", only=Q(books__authors__age__lt=56))) 
     82        self.assertEqual(len(vals), 1) 
     83        self.assertEqual(vals["books__authors__age__max"], 46) 
    5484 
    5585        vals = Author.objects.aggregate(Min("book__publisher__num_awards")) 
    5686        self.assertEqual(len(vals), 1) 
    class BaseAggregateTestCase(TestCase): 
    106136            ], 
    107137            lambda b: (b.name, b.num_authors) 
    108138        ) 
     139         
     140        def raises_exception(): 
     141            list(Book.objects.annotate(num_authors=Count("authors")).annotate(num_authors2=Count("authors", only=Q(num_authors__gt=1))).order_by("name")) 
     142 
     143        self.assertRaises(FieldError, raises_exception) 
    109144 
    110145    def test_backwards_m2m_annotate(self): 
    111146        authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name") 
    class BaseAggregateTestCase(TestCase): 
    192227                } 
    193228            ] 
    194229        ) 
     230        books = Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age', only=Q(authors__age__lt=35))).values('pk', 'isbn', 'mean_age') 
     231        self.assertEqual( 
     232            list(books), [ 
     233                { 
     234                    "pk": 1, 
     235                    "isbn": "159059725", 
     236                    "mean_age": 34.0, 
     237                } 
     238            ] 
     239        ) 
    195240 
    196241        books = Book.objects.filter(pk=1).annotate(mean_age=Avg("authors__age")).values("name") 
    197242        self.assertEqual( 
    class BaseAggregateTestCase(TestCase): 
    269314 
    270315        vals = Book.objects.aggregate(Count("rating", distinct=True)) 
    271316        self.assertEqual(vals, {"rating__count": 4}) 
     317        vals = Book.objects.aggregate( 
     318            low_count=Count("rating", only=Q(rating__lt=4)),  
     319            high_count=Count("rating", only=Q(rating__gte=4)) 
     320        ) 
     321        self.assertEqual(vals, {"low_count": 1, 'high_count': 5}) 
     322        vals = Book.objects.aggregate( 
     323            low_count=Count("rating", distinct=True, only=Q(rating__lt=4)),  
     324            high_count=Count("rating", distinct=True, only=Q(rating__gte=4)) 
     325        ) 
     326        self.assertEqual(vals, {"low_count": 1, 'high_count': 3}) 
    272327 
    273328    def test_fkey_aggregate(self): 
    274329        explicit = list(Author.objects.annotate(Count('book__id'))) 
    class BaseAggregateTestCase(TestCase): 
    388443            ], 
    389444            lambda p: p.name, 
    390445        ) 
     446        publishers = Publisher.objects.annotate(num_books=Count("book__id", only=Q(book__id__gt=5))).filter(num_books__gt=1, book__price__lt=Decimal("40.0")).order_by("pk") 
     447        self.assertQuerysetEqual( 
     448            publishers, [ 
     449                "Expensive Publisher", 
     450            ], 
     451            lambda p: p.name, 
     452        ) 
    391453 
    392454        publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") 
    393455        self.assertQuerysetEqual(