Ticket #1133: query_args.patch

File query_args.patch, 9.1 KB (added by freakboy@…, 9 years ago)

Patch to add *args handling to DB queries

  • django/db/models/manager.py

     
    5050           self.creation_counter < klass._default_manager.creation_counter:
    5151                klass._default_manager = self
    5252       
    53     def _get_sql_clause(self, **kwargs):
     53    def _get_sql_clause(self, *args, **kwargs):
    5454        def quote_only_if_word(word):
    5555            if ' ' in word:
    5656                return word
     
    6666        where = kwargs.get('where') and kwargs['where'][:] or []
    6767        params = kwargs.get('params') and kwargs['params'][:] or []
    6868
     69        # Convert all the args into SQL.
     70        table_count = 0
     71        for arg in args:
     72            # check that the provided argument is a Query (i.e., it has a get_sql method)
     73            if not hasattr(arg, 'get_sql'):
     74                raise TypeError, "got unknown query argument '%s'" % str(arg)
     75
     76            tables2, join_where2, where2, params2, table_count = arg.get_sql(opts, table_count)
     77            tables.extend(tables2)
     78            where.extend(join_where2 + where2)
     79            params.extend(params2)
     80
    6981        # Convert the kwargs into SQL.
    70         tables2, join_where2, where2, params2, _ = parse_lookup(kwargs.items(), opts)
     82        tables2, join_where2, where2, params2, _ = parse_lookup(kwargs.items(), opts, table_count)
    7183        tables.extend(tables2)
    7284        where.extend(join_where2 + where2)
    7385        params.extend(params2)
     
    117129
    118130        return select, " FROM " + ",".join(tables) + (where and " WHERE " + " AND ".join(where) or "") + (order_by and " ORDER BY " + order_by or "") + limit_sql, params
    119131
    120     def get_iterator(self, **kwargs):
     132    def get_iterator(self, *args, **kwargs):
    121133        # kwargs['select'] is a dictionary, and dictionaries' key order is
    122134        # undefined, so we convert it to a list of tuples internally.
    123135        kwargs['select'] = kwargs.get('select', {}).items()
    124136
    125137        cursor = connection.cursor()
    126         select, sql, params = self._get_sql_clause(**kwargs)
     138        select, sql, params = self._get_sql_clause(*args, **kwargs)
    127139        cursor.execute("SELECT " + (kwargs.get('distinct') and "DISTINCT " or "") + ",".join(select) + sql, params)
    128140        fill_cache = kwargs.get('select_related')
    129141        index_end = len(self.klass._meta.fields)
     
    140152                    setattr(obj, k[0], row[index_end+i])
    141153                yield obj
    142154
    143     def get_list(self, **kwargs):
    144         return list(self.get_iterator(**kwargs))
     155    def get_list(self, *args, **kwargs):
     156        return list(self.get_iterator(*args, **kwargs))
    145157
    146     def get_count(self, **kwargs):
     158    def get_count(self, *args, **kwargs):
    147159        kwargs['order_by'] = []
    148160        kwargs['offset'] = None
    149161        kwargs['limit'] = None
    150162        kwargs['select_related'] = False
    151         _, sql, params = self._get_sql_clause(**kwargs)
     163        _, sql, params = self._get_sql_clause(*args, **kwargs)
    152164        cursor = connection.cursor()
    153165        cursor.execute("SELECT COUNT(*)" + sql, params)
    154166        return cursor.fetchone()[0]
    155167
    156     def get_object(self, **kwargs):
    157         obj_list = self.get_list(**kwargs)
     168    def get_object(self, *args, **kwargs):
     169        obj_list = self.get_list(*args, **kwargs)
    158170        if len(obj_list) < 1:
    159171            raise self.klass.DoesNotExist, "%s does not exist for %s" % (self.klass._meta.object_name, kwargs)
    160172        assert len(obj_list) == 1, "get_object() returned more than one %s -- it returned %s! Lookup parameters were %s" % (self.klass._meta.object_name, len(obj_list), kwargs)
    161173        return obj_list[0]
    162174
    163175    def get_in_bulk(self, *args, **kwargs):
    164         id_list = args and args[0] or kwargs['id_list']
    165         assert id_list != [], "get_in_bulk() cannot be passed an empty list."
     176        # Separate any list arguments: these will be added together to provide the id list
     177        id_args = filter(lambda arg: isinstance(arg, list), args)
     178        # Separate any non-list arguments: these are assumed to be query arguments
     179        sql_args = filter(lambda arg: not isinstance(arg, list), args)
     180
     181        id_list = id_args and id_args[0] or kwargs.get('id_list', [])
     182        assert id_list != [], "get_in_bulk() cannot be passed an empty ID list."
    166183        kwargs['where'] = ["%s.%s IN (%s)" % (backend.quote_name(self.klass._meta.db_table), backend.quote_name(self.klass._meta.pk.column), ",".join(['%s'] * len(id_list)))]
    167184        kwargs['params'] = id_list
    168         obj_list = self.get_list(**kwargs)
     185        obj_list = self.get_list(*sql_args, **kwargs)
    169186        return dict([(getattr(o, self.klass._meta.pk.attname), o) for o in obj_list])
    170187
    171     def get_values_iterator(self, **kwargs):
     188    def get_values_iterator(self, *args, **kwargs):
    172189        # select_related and select aren't supported in get_values().
    173190        kwargs['select_related'] = False
    174191        kwargs['select'] = {}
     
    180197            fields = [f.column for f in self.klass._meta.fields]
    181198
    182199        cursor = connection.cursor()
    183         _, sql, params = self._get_sql_clause(**kwargs)
     200        _, sql, params = self._get_sql_clause(*args, **kwargs)
    184201        select = ['%s.%s' % (backend.quote_name(self.klass._meta.db_table), backend.quote_name(f)) for f in fields]
    185202        cursor.execute("SELECT " + (kwargs.get('distinct') and "DISTINCT " or "") + ",".join(select) + sql, params)
    186203        while 1:
     
    190207            for row in rows:
    191208                yield dict(zip(fields, row))
    192209
    193     def get_values(self, **kwargs):
    194         return list(self.get_values_iterator(**kwargs))
     210    def get_values(self, *args, **kwargs):
     211        return list(self.get_values_iterator(*args, **kwargs))
    195212
    196     def __get_latest(self, **kwargs):
     213    def __get_latest(self, *args, **kwargs):
    197214        kwargs['order_by'] = ('-' + self.klass._meta.get_latest_by,)
    198215        kwargs['limit'] = 1
    199         return self.get_object(**kwargs)
     216        return self.get_object(*args, **kwargs)
    200217
    201218    def __get_date_list(self, field, *args, **kwargs):
     219        # Separate any string arguments: the first will be used as the kind
     220        kind_args = filter(lambda arg: isinstance(arg, str), args)
     221        # Separate any non-list arguments: these are assumed to be query arguments
     222        sql_args = filter(lambda arg: not isinstance(arg, str), args)
     223
    202224        from django.db.backends.util import typecast_timestamp
    203         kind = args and args[0] or kwargs['kind']
     225        kind = kind_args and kind_args[0] or kwargs.get('kind', "")
    204226        assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'."
    205227        order = 'ASC'
    206228        if kwargs.has_key('order'):
     
    211233        if field.null:
    212234            kwargs.setdefault('where', []).append('%s.%s IS NOT NULL' % \
    213235                (backend.quote_name(self.klass._meta.db_table), backend.quote_name(field.column)))
    214         select, sql, params = self._get_sql_clause(**kwargs)
     236        select, sql, params = self._get_sql_clause(*sql_args, **kwargs)
    215237        sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \
    216238            (backend.get_date_trunc_sql(kind, '%s.%s' % (backend.quote_name(self.klass._meta.db_table),
    217239            backend.quote_name(field.column))), sql, order)
  • django/db/models/query.py

     
    190190        if kwarg_value is None:
    191191            continue
    192192        if kwarg == 'complex':
     193            if not hasattr(kwarg_value, 'get_sql'):
     194                raise TypeError, "got unknown query argument '%s'" % str(arg)   
    193195            tables2, join_where2, where2, params2, table_count = kwarg_value.get_sql(opts, table_count)
    194196            tables.extend(tables2)
    195197            join_where.extend(join_where2)
     
    212214            else:
    213215                lookup_list = lookup_list[:-1] + [opts.pk.name, 'exact']
    214216        if len(lookup_list) == 1:
    215             _throw_bad_kwarg_error(kwarg)
     217            throw_bad_kwarg_error(kwarg)
    216218        lookup_type = lookup_list.pop()
    217219        current_opts = opts # We'll be overwriting this, so keep a reference to the original opts.
    218220        current_table_alias = current_opts.db_table
  • tests/modeltests/or_lookups/models.py

     
    5454>>> Article.objects.get_list(complex=(Q(pk=1) | Q(pk=2) | Q(pk=3)))
    5555[Hello, Goodbye, Hello and goodbye]
    5656
     57>>> Article.objects.get_list(Q(headline__startswith='Hello'))
     58[Hello, Hello and goodbye]
     59
     60>>> Article.objects.get_list(Q(headline__startswith='Hello'), Q(headline__contains='bye'))
     61[Hello and goodbye]
     62
     63>>> Article.objects.get_list(Q(headline__startswith='Hello') & Q(headline__contains='bye'))
     64[Hello and goodbye]
     65
     66>>> Article.objects.get_list(Q(headline__contains='bye'), headline__startswith='Hello')
     67[Hello and goodbye]
     68
     69>>> Article.objects.get_list(Q(headline__contains='Hello') | Q(headline__contains='bye'))
     70[Hello, Goodbye, Hello and goodbye]
     71
    5772"""
Back to Top