Django

Code

Ticket #3566: queryset_modular_aggregates.diff

File queryset_modular_aggregates.diff, 5.4 kB (added by jbronn, 1 year ago)

Makes aggregation functionality more amenable to subclassing.

  • a/django/db/models/sql/query.py

    old new  
    1616from django.db.models import signals 
    1717from django.db.models.fields import FieldDoesNotExist 
    1818from django.db.models.query_utils import select_related_descend 
    19 from django.db.models.sql import aggregates 
     19from django.db.models.sql import aggregates as aggregates_module 
    2020from django.db.models.sql.where import WhereNode, EverythingNode, AND, OR 
    2121from django.core.exceptions import FieldError 
    2222from datastructures import EmptyResultSet, Empty, MultiJoin 
     
    4040 
    4141    alias_prefix = 'T' 
    4242    query_terms = QUERY_TERMS 
    43  
     43    aggregates = aggregates_module 
     44     
    4445    def __init__(self, model, connection, where=WhereNode): 
    4546        self.model = model 
    4647        self.connection = connection 
     
    198199            obj._setup_query() 
    199200        return obj 
    200201 
     202    def normalize(self, aggregate, value): 
     203        """ 
     204        Returns a normalized Python object from the given aggregate object 
     205        and raw database value. 
     206        """ 
     207        return self.connection.ops.db_aggregate_to_value(aggregate, value) 
     208 
    201209    def results_iter(self): 
    202210        """ 
    203211        Returns an iterator over the results from executing this query. 
    204212        """ 
    205213        resolve_columns = hasattr(self, 'resolve_columns') 
    206214        fields = None 
    207         normalize = self.connection.ops.db_aggregate_to_value 
    208215        for rows in self.execute_sql(MULTI): 
    209216            for row in rows: 
    210217                if resolve_columns: 
     
    221228                if self.aggregate_select: 
    222229                    aggregate_start = len(self.extra_select.keys()) + len(self.select) 
    223230                    row = row[:aggregate_start] + tuple( 
    224                         normalize(aggregate, value) 
     231                        self.normalize(aggregate, value) 
    225232                        for (alias, aggregate), value 
    226233                        in zip(self.aggregate_select.items(), row[aggregate_start:]) 
    227234                    ) 
     
    263270        query.select_related = False 
    264271        query.related_select_cols = [] 
    265272        query.related_select_fields = [] 
    266  
    267         normalize = self.connection.ops.db_aggregate_to_value 
     273         
    268274        return dict( 
    269             (alias, normalize(aggregate, val)) 
     275            (alias, self.normalize(aggregate, val)) 
    270276            for (alias, aggregate), val 
    271277            in zip(query.aggregate_select.items(), query.execute_sql(SINGLE)) 
    272278        ) 
     
    11781184            aggregate_expr.lookup in self.aggregate_select.keys()): 
    11791185            # Aggregate is over an annotation 
    11801186            field_name = field_list[0] 
    1181             aggregate = aggregate_expr.add_to_query(self, aggregates, 
     1187            aggregate = aggregate_expr.add_to_query(self, self.aggregates, 
    11821188                col=field_name, 
    11831189                source=self.aggregate_select[field_name], 
    11841190                is_summary=is_summary) 
     
    11971203            for column_alias in join_list: 
    11981204                self.promote_alias(column_alias, unconditional=True) 
    11991205 
    1200             aggregate = aggregate_expr.add_to_query(self, aggregates, 
     1206            aggregate = aggregate_expr.add_to_query(self, self.aggregates, 
    12011207                col=(join_list[-1], col), 
    12021208                source=target, 
    12031209                is_summary=is_summary) 
     
    12111217                col = (opts.db_table, field.column) 
    12121218            else: 
    12131219                col = field_name 
    1214             aggregate = aggregate_expr.add_to_query(self, aggregates, 
     1220            aggregate = aggregate_expr.add_to_query(self, self.aggregates, 
    12151221                col=col, 
    12161222                source=field, 
    12171223                is_summary=is_summary) 
     
    17901796        """ 
    17911797        if not self.distinct: 
    17921798            if not self.select: 
    1793                 count = aggregates.Count('*', is_summary=True) 
     1799                count = self.aggregates.Count('*', is_summary=True) 
    17941800            else: 
    17951801                assert len(self.select) == 1, \ 
    17961802                        "Cannot add count col with multiple cols in 'select': %r" % self.select 
    1797                 count = aggregates.Count(self.select[0]) 
     1803                count = self.aggregates.Count(self.select[0]) 
    17981804        else: 
    17991805            opts = self.model._meta 
    18001806            if not self.select: 
    1801                 count = aggregates.Count((self.join((None, opts.db_table, None, None)), opts.pk.column), 
     1807                count = self.aggregates.Count((self.join((None, opts.db_table, None, None)), opts.pk.column), 
    18021808                                         is_summary=True, distinct=True) 
    18031809            else: 
    18041810                # Because of SQL portability issues, multi-column, distinct 
     
    18061812                assert len(self.select) == 1, \ 
    18071813                        "Cannot add count col with multiple cols in 'select'." 
    18081814 
    1809                 count = aggregates.Count(self.select[0], distinct=True) 
     1815                count = self.aggregates.Count(self.select[0], distinct=True) 
    18101816            # Distinct handling is done in Count(), so don't do it at this 
    18111817            # level. 
    18121818            self.distinct = False