Ticket #3566: queryset_modular_aggregates.diff

File queryset_modular_aggregates.diff, 5.4 KB (added by jbronn, 6 years ago)

Makes aggregation functionality more amenable to subclassing.

  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index 8b40edd..f736e27 100644
    a b from django.db import connection 
    1616from django.db.models import signals
    1717from django.db.models.fields import FieldDoesNotExist
    1818from django.db.models.query_utils import select_related_descend
    19 from django.db.models.sql import aggregates
     19from django.db.models.sql import aggregates as aggregates_module
    2020from django.db.models.sql.where import WhereNode, EverythingNode, AND, OR
    2121from django.core.exceptions import FieldError
    2222from datastructures import EmptyResultSet, Empty, MultiJoin
    class BaseQuery(object): 
    4040
    4141    alias_prefix = 'T'
    4242    query_terms = QUERY_TERMS
    43 
     43    aggregates = aggregates_module
     44   
    4445    def __init__(self, model, connection, where=WhereNode):
    4546        self.model = model
    4647        self.connection = connection
    class BaseQuery(object): 
    198199            obj._setup_query()
    199200        return obj
    200201
     202    def normalize(self, aggregate, value):
     203        """
     204        Returns a normalized Python object from the given aggregate object
     205        and raw database value.
     206        """
     207        return self.connection.ops.db_aggregate_to_value(aggregate, value)
     208
    201209    def results_iter(self):
    202210        """
    203211        Returns an iterator over the results from executing this query.
    204212        """
    205213        resolve_columns = hasattr(self, 'resolve_columns')
    206214        fields = None
    207         normalize = self.connection.ops.db_aggregate_to_value
    208215        for rows in self.execute_sql(MULTI):
    209216            for row in rows:
    210217                if resolve_columns:
    class BaseQuery(object): 
    221228                if self.aggregate_select:
    222229                    aggregate_start = len(self.extra_select.keys()) + len(self.select)
    223230                    row = row[:aggregate_start] + tuple(
    224                         normalize(aggregate, value)
     231                        self.normalize(aggregate, value)
    225232                        for (alias, aggregate), value
    226233                        in zip(self.aggregate_select.items(), row[aggregate_start:])
    227234                    )
    class BaseQuery(object): 
    263270        query.select_related = False
    264271        query.related_select_cols = []
    265272        query.related_select_fields = []
    266 
    267         normalize = self.connection.ops.db_aggregate_to_value
     273       
    268274        return dict(
    269             (alias, normalize(aggregate, val))
     275            (alias, self.normalize(aggregate, val))
    270276            for (alias, aggregate), val
    271277            in zip(query.aggregate_select.items(), query.execute_sql(SINGLE))
    272278        )
    class BaseQuery(object): 
    11781184            aggregate_expr.lookup in self.aggregate_select.keys()):
    11791185            # Aggregate is over an annotation
    11801186            field_name = field_list[0]
    1181             aggregate = aggregate_expr.add_to_query(self, aggregates,
     1187            aggregate = aggregate_expr.add_to_query(self, self.aggregates,
    11821188                col=field_name,
    11831189                source=self.aggregate_select[field_name],
    11841190                is_summary=is_summary)
    class BaseQuery(object): 
    11971203            for column_alias in join_list:
    11981204                self.promote_alias(column_alias, unconditional=True)
    11991205
    1200             aggregate = aggregate_expr.add_to_query(self, aggregates,
     1206            aggregate = aggregate_expr.add_to_query(self, self.aggregates,
    12011207                col=(join_list[-1], col),
    12021208                source=target,
    12031209                is_summary=is_summary)
    class BaseQuery(object): 
    12111217                col = (opts.db_table, field.column)
    12121218            else:
    12131219                col = field_name
    1214             aggregate = aggregate_expr.add_to_query(self, aggregates,
     1220            aggregate = aggregate_expr.add_to_query(self, self.aggregates,
    12151221                col=col,
    12161222                source=field,
    12171223                is_summary=is_summary)
    class BaseQuery(object): 
    17901796        """
    17911797        if not self.distinct:
    17921798            if not self.select:
    1793                 count = aggregates.Count('*', is_summary=True)
     1799                count = self.aggregates.Count('*', is_summary=True)
    17941800            else:
    17951801                assert len(self.select) == 1, \
    17961802                        "Cannot add count col with multiple cols in 'select': %r" % self.select
    1797                 count = aggregates.Count(self.select[0])
     1803                count = self.aggregates.Count(self.select[0])
    17981804        else:
    17991805            opts = self.model._meta
    18001806            if not self.select:
    1801                 count = aggregates.Count((self.join((None, opts.db_table, None, None)), opts.pk.column),
     1807                count = self.aggregates.Count((self.join((None, opts.db_table, None, None)), opts.pk.column),
    18021808                                         is_summary=True, distinct=True)
    18031809            else:
    18041810                # Because of SQL portability issues, multi-column, distinct
    class BaseQuery(object): 
    18061812                assert len(self.select) == 1, \
    18071813                        "Cannot add count col with multiple cols in 'select'."
    18081814
    1809                 count = aggregates.Count(self.select[0], distinct=True)
     1815                count = self.aggregates.Count(self.select[0], distinct=True)
    18101816            # Distinct handling is done in Count(), so don't do it at this
    18111817            # level.
    18121818            self.distinct = False
Back to Top