Django

Code

Ticket #3566: aggregate.2.diff

File aggregate.2.diff, 8.6 kB (added by nicolas, 3 months ago)
  • django/db/models/sql/query.py

    old new  
    5656        self.start_meta = None 
    5757 
    5858        # SQL-related attributes 
     59        self.aggregates = [] 
    5960        self.select = [] 
    6061        self.tables = []    # Aliases in the order they are created. 
    6162        self.where = where() 
     
    141142        obj.standard_ordering = self.standard_ordering 
    142143        obj.start_meta = self.start_meta 
    143144        obj.select = self.select[:] 
     145        obj.aggregates = self.aggregates[:] 
    144146        obj.tables = self.tables[:] 
    145147        obj.where = deepcopy(self.where) 
    146148        obj.where_class = self.where_class 
     
    174176                    row = self.resolve_columns(row, fields) 
    175177                yield row 
    176178 
     179    def get_aggregation(self): 
     180        for field in self.select: 
     181            self.group_by.append(field) 
     182        self.select.extend(self.aggregates) 
     183        self.aggregates = [] 
     184        #print self.as_sql() 
     185        #print 'after', self.select 
     186 
     187        get_name = lambda x : isinstance(x, tuple) and x[1] or x.aliased_name 
     188 
     189        print 'final query', self.as_sql()  
     190 
     191        if self.group_by: 
     192            data = self.execute_sql(MULTI) 
     193            result = [] 
     194            for rs in data.next(): 
     195                result.append(dict(zip([get_name(i) for i in self.select], rs))) 
     196        else: 
     197            data = self.execute_sql(SINGLE) 
     198            result = dict(zip([get_name(i) for i in self.select], data)) 
     199 
     200        self.select = [] 
     201        return result 
     202 
    177203    def get_count(self): 
    178204        """ 
    179205        Performs a COUNT() query using the current filter constraints. 
     
    811837            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, 
    812838                    used, next, restricted) 
    813839 
     840    def annotate(self, aggregate_expr, aliased_name, model): 
     841        field_list = aggregate_expr.split(LOOKUP_SEP) 
     842        opts = model._meta 
     843 
     844        aggregate_func = field_list.pop() 
     845         
     846        if len(field_list) > 1: 
     847            field, target, opts, join_list, last = self.setup_joins( 
     848                field_list, opts, self.get_initial_alias(), False) 
     849            final = len(join_list) 
     850            penultimate = last.pop() 
     851            if penultimate == final: 
     852                penultimate = last.pop() 
     853            if len(join_list) > 1: 
     854                extra = join_list[penultimate:] 
     855                final = penultimate 
     856                col = self.alias_map[extra[0]][LHS_JOIN_COL] 
     857            else: 
     858                col = target.column 
     859                 
     860            field_name = field_list.pop() 
     861            alias = join_list[-1] 
     862            alias = extra[final] 
     863        else: 
     864            field_name = field_list[0] 
     865            alias = opts.db_table 
     866           
     867 
     868    def add_aggregate(self, aggregate_expr, aliased_name, model): 
     869        """ 
     870        Adds a single aggregate expression to the Query 
     871        """ 
     872         
     873        field_list = aggregate_expr.split(LOOKUP_SEP) 
     874        opts = model._meta 
     875 
     876        aggregate_func = field_list.pop() 
     877         
     878        if len(field_list) > 1: 
     879            field, target, opts, join_list, last = self.setup_joins( 
     880                field_list, opts, self.get_initial_alias(), False) 
     881            final = len(join_list) 
     882            penultimate = last.pop() 
     883            if penultimate == final: 
     884                penultimate = last.pop() 
     885            if len(join_list) > 1: 
     886                extra = join_list[penultimate:] 
     887                final = penultimate 
     888                col = self.alias_map[extra[0]][LHS_JOIN_COL] 
     889            else: 
     890                col = target.column 
     891                 
     892            field_name = field_list.pop() 
     893            alias = join_list[-1] 
     894            alias = extra[final] 
     895        else: 
     896            field_name = field_list[0] 
     897            alias = opts.db_table 
     898 
     899        class AggregateNode: 
     900            def __init__(self, field_name, aggregate_func, aliased_name, alias): 
     901                self.field_name = field_name 
     902                self.aggregate_func = aggregate_func 
     903                self.aliased_name = aliased_name 
     904                self.alias = alias 
     905                 
     906            def as_sql(self, quote_func=None): 
     907                if not quote_func: 
     908                    quote_func = lambda x: x 
     909                return '%s(%s.%s)' % (self.aggregate_func.upper(), 
     910                                      quote_func(self.alias), 
     911                                      quote_func(self.field_name)) 
     912 
     913        self.aggregates.append(AggregateNode(field_name, aggregate_func, aliased_name, alias)) 
     914         
    814915    def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, 
    815916            single_filter=False): 
    816917        """ 
     
    829930        if not parts: 
    830931            raise FieldError("Cannot parse keyword query %r" % arg) 
    831932 
     933        # if arg in (x.aliased_name for x in self.aggregates): 
     934        #     self.having.append(arg) 
     935        #     return 
     936 
    832937        # Work out the lookup type and remove it from 'parts', if necessary. 
    833938        if len(parts) == 1 or parts[-1] not in self.query_terms: 
    834939            lookup_type = 'exact' 
  • django/db/models/query.py

    old new  
    158158                setattr(obj, k, row[i]) 
    159159            yield obj 
    160160 
     161    def aggregate(self, *args, **kwargs): 
     162        """ 
     163        Returns the aggregation over the current model as a 
     164        dictionary. 
     165 
     166        When applied to a ValuesQuerySet the results are GROUP BY-ed 
     167        by the fields specified in the values queryset. 
     168 
     169        The kwargs are parsed as expression='alias'. 
     170 
     171        If args is present the expression is passed as a kwarg with 
     172        itself as an alias. 
     173        """ 
     174        #Bug (or is it?): when doing both an aggregation on a related 
     175        #field and one on a 'local' field the local one goes 
     176        #wrong. something similar to: SELECT SUM(a.f1) FROM a INNER JOIN b; 
     177        #the value gets aggregated more than one time. 
     178 
     179        if args: 
     180            newargs = {} 
     181            for arg in args: 
     182                newargs[arg] = arg 
     183            kwargs.update(newargs) 
     184             
     185        for (aggregate_expr, alias) in kwargs.items(): 
     186            self.query.add_aggregate(aggregate_expr, alias, self.model) 
     187        return self.query.get_aggregation() 
     188 
    161189    def count(self): 
    162190        """ 
    163191        Performs a SELECT COUNT() and returns the number of records as an 
     
    326354        """ 
    327355        return self._clone(klass=EmptyQuerySet) 
    328356 
     357    def annotate(self, *args, **kwargs): 
     358        # Fix: Values is not working propperly 
     359        # Suffers from the same bug as aggrgate 
     360        # To-Do: HAVING 
     361 
     362        if args: 
     363            newargs = {} 
     364            for arg in args: 
     365                newargs[arg] = arg 
     366            kwargs.update(newargs) 
     367 
     368        opts = self.model._meta 
     369        fields = [] 
     370         
     371        if isinstance(self, ValuesQuerySet): 
     372            obj = self._clone() 
     373        else: 
     374            fields.extend([f.name for f in opts.fields]) 
     375            obj = self._clone(klass=ValuesQuerySet, setup=True, _fields=fields) 
     376             
     377 
     378        for (aggregate_expr, alias) in kwargs.items(): 
     379            obj.query.add_aggregate(aggregate_expr, alias, self.model) 
     380 
     381        return obj 
     382 
    329383    ################################################################## 
    330384    # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET # 
    331385    ################################################################## 
     
    335389        Returns a new QuerySet that is a copy of the current one. This allows a 
    336390        QuerySet to proxy for a model manager in some cases. 
    337391        """ 
    338         return self._clone() 
     392        return self._clone()         
    339393 
    340394    def filter(self, *args, **kwargs): 
    341395        """ 
     
    488542        # names of the model fields to select. 
    489543 
    490544    def __iter__(self): 
     545        if self.query.aggregates: 
     546            return self.aggregate_iterator() 
    491547        return self.iterator() 
    492548 
     549    def aggregate_iterator(self): 
     550        #Not lazy.. review 
     551        for i in self.query.get_aggregation(): 
     552            yield i 
     553 
    493554    def iterator(self): 
    494555        self.query.trim_extra_select(self.extra_names) 
    495556        names = self.query.extra_select.keys() + self.field_names 
    496557        for row in self.query.results_iter(): 
    497558            yield dict(zip(names, row)) 
    498  
     559  
    499560    def _setup_query(self): 
    500561        """ 
    501562        Constructs the field_names list that the values query will be