Code

Ticket #1133: query_args.r1799.v2.patch

File query_args.r1799.v2.patch, 8.8 KB (added by freakboy@…, 8 years ago)

Updated patch after merge to r1799

  • django/db/models/manager.py

     
    5050           self.creation_counter < klass._default_manager.creation_counter: 
    5151                klass._default_manager = self 
    5252 
    53     def _get_sql_clause(self, **kwargs): 
     53    def _get_sql_clause(self, *args, **kwargs): 
    5454        def quote_only_if_word(word): 
    5555            if ' ' in word: 
    5656                return word 
     
    6262        # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z. 
    6363        select = ["%s.%s" % (backend.quote_name(opts.db_table), backend.quote_name(f.column)) for f in opts.fields] 
    6464        tables = (kwargs.get('tables') and [quote_only_if_word(t) for t in kwargs['tables']] or []) 
     65        joins = {} 
    6566        where = kwargs.get('where') and kwargs['where'][:] or [] 
    6667        params = kwargs.get('params') and kwargs['params'][:] or [] 
    6768 
     69        # Convert all the args into SQL. 
     70        table_count = 0 
     71        for arg in args: 
     72            # check that the provided argument is a Query (i.e., it has a get_sql method) 
     73            if not hasattr(arg, 'get_sql'): 
     74                raise TypeError, "got unknown query argument '%s'" % str(arg) 
     75 
     76            tables2, joins2, where2, params2 = arg.get_sql(opts) 
     77            tables.extend(tables2) 
     78            joins.update(joins2) 
     79            where.extend(where2) 
     80            params.extend(params2) 
     81 
    6882        # Convert the kwargs into SQL. 
    69         tables2, joins, where2, params2 = parse_lookup(kwargs.items(), opts) 
     83        tables2, joins2, where2, params2 = parse_lookup(kwargs.items(), opts) 
    7084        tables.extend(tables2) 
     85        joins.update(joins2) 
    7186        where.extend(where2) 
    7287        params.extend(params2) 
    7388 
     
    132147 
    133148        return select, " ".join(sql), params 
    134149 
    135     def get_iterator(self, **kwargs): 
     150    def get_iterator(self, *args, **kwargs): 
    136151        # kwargs['select'] is a dictionary, and dictionaries' key order is 
    137152        # undefined, so we convert it to a list of tuples internally. 
    138153        kwargs['select'] = kwargs.get('select', {}).items() 
    139154 
    140155        cursor = connection.cursor() 
    141         select, sql, params = self._get_sql_clause(**kwargs) 
     156        select, sql, params = self._get_sql_clause(*args, **kwargs) 
    142157        cursor.execute("SELECT " + (kwargs.get('distinct') and "DISTINCT " or "") + ",".join(select) + sql, params) 
    143158        fill_cache = kwargs.get('select_related') 
    144159        index_end = len(self.klass._meta.fields) 
     
    155170                    setattr(obj, k[0], row[index_end+i]) 
    156171                yield obj 
    157172 
    158     def get_list(self, **kwargs): 
    159         return list(self.get_iterator(**kwargs)) 
     173    def get_list(self, *args, **kwargs): 
     174        return list(self.get_iterator(*args, **kwargs)) 
    160175 
    161     def get_count(self, **kwargs): 
     176    def get_count(self, *args, **kwargs): 
    162177        kwargs['order_by'] = [] 
    163178        kwargs['offset'] = None 
    164179        kwargs['limit'] = None 
    165180        kwargs['select_related'] = False 
    166         _, sql, params = self._get_sql_clause(**kwargs) 
     181        _, sql, params = self._get_sql_clause(*args, **kwargs) 
    167182        cursor = connection.cursor() 
    168183        cursor.execute("SELECT COUNT(*)" + sql, params) 
    169184        return cursor.fetchone()[0] 
    170185 
    171     def get_object(self, **kwargs): 
    172         obj_list = self.get_list(**kwargs) 
     186    def get_object(self, *args, **kwargs): 
     187        obj_list = self.get_list(*args, **kwargs) 
    173188        if len(obj_list) < 1: 
    174189            raise self.klass.DoesNotExist, "%s does not exist for %s" % (self.klass._meta.object_name, kwargs) 
    175190        assert len(obj_list) == 1, "get_object() returned more than one %s -- it returned %s! Lookup parameters were %s" % (self.klass._meta.object_name, len(obj_list), kwargs) 
    176191        return obj_list[0] 
    177192 
    178193    def get_in_bulk(self, *args, **kwargs): 
    179         id_list = args and args[0] or kwargs['id_list'] 
    180         assert id_list != [], "get_in_bulk() cannot be passed an empty list." 
     194        # Separate any list arguments: these will be added together to provide the id list 
     195        id_args = filter(lambda arg: isinstance(arg, list), args) 
     196        # Separate any non-list arguments: these are assumed to be query arguments 
     197        sql_args = filter(lambda arg: not isinstance(arg, list), args) 
     198 
     199        id_list = id_args and id_args[0] or kwargs.get('id_list', []) 
     200        assert id_list != [], "get_in_bulk() cannot be passed an empty ID list." 
    181201        kwargs['where'] = ["%s.%s IN (%s)" % (backend.quote_name(self.klass._meta.db_table), backend.quote_name(self.klass._meta.pk.column), ",".join(['%s'] * len(id_list)))] 
    182202        kwargs['params'] = id_list 
    183         obj_list = self.get_list(**kwargs) 
     203        obj_list = self.get_list(*sql_args, **kwargs) 
    184204        return dict([(getattr(o, self.klass._meta.pk.attname), o) for o in obj_list]) 
    185205 
    186     def get_values_iterator(self, **kwargs): 
     206    def get_values_iterator(self, *args, **kwargs): 
    187207        # select_related and select aren't supported in get_values(). 
    188208        kwargs['select_related'] = False 
    189209        kwargs['select'] = {} 
     
    195215            fields = [f.column for f in self.klass._meta.fields] 
    196216 
    197217        cursor = connection.cursor() 
    198         _, sql, params = self._get_sql_clause(**kwargs) 
     218        _, sql, params = self._get_sql_clause(*args, **kwargs) 
    199219        select = ['%s.%s' % (backend.quote_name(self.klass._meta.db_table), backend.quote_name(f)) for f in fields] 
    200220        cursor.execute("SELECT " + (kwargs.get('distinct') and "DISTINCT " or "") + ",".join(select) + sql, params) 
    201221        while 1: 
     
    205225            for row in rows: 
    206226                yield dict(zip(fields, row)) 
    207227 
    208     def get_values(self, **kwargs): 
    209         return list(self.get_values_iterator(**kwargs)) 
     228    def get_values(self, *args, **kwargs): 
     229        return list(self.get_values_iterator(*args, **kwargs)) 
    210230 
    211     def __get_latest(self, **kwargs): 
     231    def __get_latest(self, *args, **kwargs): 
    212232        kwargs['order_by'] = ('-' + self.klass._meta.get_latest_by,) 
    213233        kwargs['limit'] = 1 
    214         return self.get_object(**kwargs) 
     234        return self.get_object(*args, **kwargs) 
    215235 
    216236    def __get_date_list(self, field, *args, **kwargs): 
     237        # Separate any string arguments: the first will be used as the kind 
     238        kind_args = filter(lambda arg: isinstance(arg, str), args) 
     239        # Separate any non-list arguments: these are assumed to be query arguments 
     240        sql_args = filter(lambda arg: not isinstance(arg, str), args) 
     241 
    217242        from django.db.backends.util import typecast_timestamp 
    218         kind = args and args[0] or kwargs['kind'] 
     243        kind = kind_args and kind_args[0] or kwargs.get('kind', "") 
    219244        assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'." 
    220245        order = 'ASC' 
    221246        if kwargs.has_key('order'): 
     
    226251        if field.null: 
    227252            kwargs.setdefault('where', []).append('%s.%s IS NOT NULL' % \ 
    228253                (backend.quote_name(self.klass._meta.db_table), backend.quote_name(field.column))) 
    229         select, sql, params = self._get_sql_clause(**kwargs) 
     254        select, sql, params = self._get_sql_clause(*sql_args, **kwargs) 
    230255        sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ 
    231256            (backend.get_date_trunc_sql(kind, '%s.%s' % (backend.quote_name(self.klass._meta.db_table), 
    232257            backend.quote_name(field.column))), sql, order) 
  • django/db/models/query.py

     
    197197        if kwarg_value is None: 
    198198            continue 
    199199        if kwarg == 'complex': 
     200            if not hasattr(kwarg_value, 'get_sql'): 
     201                raise TypeError, "got unknown query argument '%s'" % str(arg)    
    200202            tables2, joins2, where2, params2 = kwarg_value.get_sql(opts) 
    201203            tables.extend(tables2) 
    202204            joins.update(joins2) 
  • tests/modeltests/or_lookups/models.py

     
    5454>>> Article.objects.get_list(complex=(Q(pk=1) | Q(pk=2) | Q(pk=3))) 
    5555[Hello, Goodbye, Hello and goodbye] 
    5656 
     57>>> Article.objects.get_list(Q(headline__startswith='Hello')) 
     58[Hello, Hello and goodbye] 
     59 
     60>>> Article.objects.get_list(Q(headline__startswith='Hello'), Q(headline__contains='bye')) 
     61[Hello and goodbye] 
     62 
     63>>> Article.objects.get_list(Q(headline__startswith='Hello') & Q(headline__contains='bye')) 
     64[Hello and goodbye] 
     65 
     66>>> Article.objects.get_list(Q(headline__contains='bye'), headline__startswith='Hello') 
     67[Hello and goodbye] 
     68 
     69>>> Article.objects.get_list(Q(headline__contains='Hello') | Q(headline__contains='bye')) 
     70[Hello, Goodbye, Hello and goodbye] 
     71 
    5772"""