Ticket #3566: aggregate.2.diff

File aggregate.2.diff, 8.6 KB (added by nicolas, 7 years ago)
  • django/db/models/sql/query.py

     
    5656        self.start_meta = None
    5757
    5858        # SQL-related attributes
     59        self.aggregates = []
    5960        self.select = []
    6061        self.tables = []    # Aliases in the order they are created.
    6162        self.where = where()
     
    141142        obj.standard_ordering = self.standard_ordering
    142143        obj.start_meta = self.start_meta
    143144        obj.select = self.select[:]
     145        obj.aggregates = self.aggregates[:]
    144146        obj.tables = self.tables[:]
    145147        obj.where = deepcopy(self.where)
    146148        obj.where_class = self.where_class
     
    174176                    row = self.resolve_columns(row, fields)
    175177                yield row
    176178
     179    def get_aggregation(self):
     180        for field in self.select:
     181            self.group_by.append(field)
     182        self.select.extend(self.aggregates)
     183        self.aggregates = []
     184        #print self.as_sql()
     185        #print 'after', self.select
     186
     187        get_name = lambda x : isinstance(x, tuple) and x[1] or x.aliased_name
     188
     189        print 'final query', self.as_sql()
     190
     191        if self.group_by:
     192            data = self.execute_sql(MULTI)
     193            result = []
     194            for rs in data.next():
     195                result.append(dict(zip([get_name(i) for i in self.select], rs)))
     196        else:
     197            data = self.execute_sql(SINGLE)
     198            result = dict(zip([get_name(i) for i in self.select], data))
     199
     200        self.select = []
     201        return result
     202
    177203    def get_count(self):
    178204        """
    179205        Performs a COUNT() query using the current filter constraints.
     
    811837            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
    812838                    used, next, restricted)
    813839
     840    def annotate(self, aggregate_expr, aliased_name, model):
     841        field_list = aggregate_expr.split(LOOKUP_SEP)
     842        opts = model._meta
     843
     844        aggregate_func = field_list.pop()
     845       
     846        if len(field_list) > 1:
     847            field, target, opts, join_list, last = self.setup_joins(
     848                field_list, opts, self.get_initial_alias(), False)
     849            final = len(join_list)
     850            penultimate = last.pop()
     851            if penultimate == final:
     852                penultimate = last.pop()
     853            if len(join_list) > 1:
     854                extra = join_list[penultimate:]
     855                final = penultimate
     856                col = self.alias_map[extra[0]][LHS_JOIN_COL]
     857            else:
     858                col = target.column
     859               
     860            field_name = field_list.pop()
     861            alias = join_list[-1]
     862            alias = extra[final]
     863        else:
     864            field_name = field_list[0]
     865            alias = opts.db_table
     866         
     867
     868    def add_aggregate(self, aggregate_expr, aliased_name, model):
     869        """
     870        Adds a single aggregate expression to the Query
     871        """
     872       
     873        field_list = aggregate_expr.split(LOOKUP_SEP)
     874        opts = model._meta
     875
     876        aggregate_func = field_list.pop()
     877       
     878        if len(field_list) > 1:
     879            field, target, opts, join_list, last = self.setup_joins(
     880                field_list, opts, self.get_initial_alias(), False)
     881            final = len(join_list)
     882            penultimate = last.pop()
     883            if penultimate == final:
     884                penultimate = last.pop()
     885            if len(join_list) > 1:
     886                extra = join_list[penultimate:]
     887                final = penultimate
     888                col = self.alias_map[extra[0]][LHS_JOIN_COL]
     889            else:
     890                col = target.column
     891               
     892            field_name = field_list.pop()
     893            alias = join_list[-1]
     894            alias = extra[final]
     895        else:
     896            field_name = field_list[0]
     897            alias = opts.db_table
     898
     899        class AggregateNode:
     900            def __init__(self, field_name, aggregate_func, aliased_name, alias):
     901                self.field_name = field_name
     902                self.aggregate_func = aggregate_func
     903                self.aliased_name = aliased_name
     904                self.alias = alias
     905               
     906            def as_sql(self, quote_func=None):
     907                if not quote_func:
     908                    quote_func = lambda x: x
     909                return '%s(%s.%s)' % (self.aggregate_func.upper(),
     910                                      quote_func(self.alias),
     911                                      quote_func(self.field_name))
     912
     913        self.aggregates.append(AggregateNode(field_name, aggregate_func, aliased_name, alias))
     914       
    814915    def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
    815916            single_filter=False):
    816917        """
     
    829930        if not parts:
    830931            raise FieldError("Cannot parse keyword query %r" % arg)
    831932
     933        # if arg in (x.aliased_name for x in self.aggregates):
     934        #     self.having.append(arg)
     935        #     return
     936
    832937        # Work out the lookup type and remove it from 'parts', if necessary.
    833938        if len(parts) == 1 or parts[-1] not in self.query_terms:
    834939            lookup_type = 'exact'
  • django/db/models/query.py

     
    158158                setattr(obj, k, row[i])
    159159            yield obj
    160160
     161    def aggregate(self, *args, **kwargs):
     162        """
     163        Returns the aggregation over the current model as a
     164        dictionary.
     165
     166        When applied to a ValuesQuerySet the results are GROUP BY-ed
     167        by the fields specified in the values queryset.
     168
     169        The kwargs are parsed as expression='alias'.
     170
     171        If args is present the expression is passed as a kwarg with
     172        itself as an alias.
     173        """
     174        #Bug (or is it?): when doing both an aggregation on a related
     175        #field and one on a 'local' field the local one goes
     176        #wrong. something similar to: SELECT SUM(a.f1) FROM a INNER JOIN b;
     177        #the value gets aggregated more than one time.
     178
     179        if args:
     180            newargs = {}
     181            for arg in args:
     182                newargs[arg] = arg
     183            kwargs.update(newargs)
     184           
     185        for (aggregate_expr, alias) in kwargs.items():
     186            self.query.add_aggregate(aggregate_expr, alias, self.model)
     187        return self.query.get_aggregation()
     188
    161189    def count(self):
    162190        """
    163191        Performs a SELECT COUNT() and returns the number of records as an
     
    326354        """
    327355        return self._clone(klass=EmptyQuerySet)
    328356
     357    def annotate(self, *args, **kwargs):
     358        # Fix: Values is not working propperly
     359        # Suffers from the same bug as aggrgate
     360        # To-Do: HAVING
     361
     362        if args:
     363            newargs = {}
     364            for arg in args:
     365                newargs[arg] = arg
     366            kwargs.update(newargs)
     367
     368        opts = self.model._meta
     369        fields = []
     370       
     371        if isinstance(self, ValuesQuerySet):
     372            obj = self._clone()
     373        else:
     374            fields.extend([f.name for f in opts.fields])
     375            obj = self._clone(klass=ValuesQuerySet, setup=True, _fields=fields)
     376           
     377
     378        for (aggregate_expr, alias) in kwargs.items():
     379            obj.query.add_aggregate(aggregate_expr, alias, self.model)
     380
     381        return obj
     382
    329383    ##################################################################
    330384    # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
    331385    ##################################################################
     
    335389        Returns a new QuerySet that is a copy of the current one. This allows a
    336390        QuerySet to proxy for a model manager in some cases.
    337391        """
    338         return self._clone()
     392        return self._clone()       
    339393
    340394    def filter(self, *args, **kwargs):
    341395        """
     
    488542        # names of the model fields to select.
    489543
    490544    def __iter__(self):
     545        if self.query.aggregates:
     546            return self.aggregate_iterator()
    491547        return self.iterator()
    492548
     549    def aggregate_iterator(self):
     550        #Not lazy.. review
     551        for i in self.query.get_aggregation():
     552            yield i
     553
    493554    def iterator(self):
    494555        self.query.trim_extra_select(self.extra_names)
    495556        names = self.query.extra_select.keys() + self.field_names
    496557        for row in self.query.results_iter():
    497558            yield dict(zip(names, row))
    498 
     559 
    499560    def _setup_query(self):
    500561        """
    501562        Constructs the field_names list that the values query will be
Back to Top