Ticket #4260: update.diff

File update.diff, 6.9 KB (added by Jonathan Buchanan, 16 years ago)

Implements an update method using queries

  • django/db/models/query.py

     
    2727    'regex', 'iregex',
    2828)
    2929
     30UPDATE_TERMS = (
     31    'add', 'exact', 'sql', 'subtract',
     32)
     33
    3034# Size of each "chunk" for get_iterator calls.
    3135# Larger values are slightly faster at the expense of more storage space.
    3236GET_ITERATOR_CHUNK_SIZE = 100
     
    319323        qs._params.extend(id_list)
    320324        return dict([(obj._get_pk_val(), obj) for obj in qs.iterator()])
    321325
     326    def update(self, **kwargs):
     327        """
     328        Performs a SQL UPDATE.
     329        """
     330        if not kwargs:
     331            raise ValueError('No updates were specified.')
     332
     333        opts = self.model._meta
     334        db_table = connection.ops.quote_name(opts.db_table)
     335        set_, set_params = parse_update(kwargs.items(), opts)
     336        joins, where, where_params = self._filters.get_sql(opts)
     337
     338        sql = ['UPDATE %s' % db_table]
     339
     340        if joins:
     341            # Special case for databases which don't allow joins in
     342            # UPDATE statements - select primary keys and add a new
     343            # where IN clause.
     344            if not connection.features.allows_join_in_update:
     345                try:
     346                    select, pk_sql, pk_params = self._get_sql_clause()
     347                except EmptyResultSet:
     348                    return
     349
     350                cursor = connection.cursor()
     351                pk_col = '%s.%s' % (db_table,
     352                                    connection.ops.quote_name(opts.pk.column))
     353                cursor.execute('SELECT %s' % pk_col + pk_sql, pk_params)
     354                pks = [row[0] for row in cursor.fetchall()]
     355
     356                # Remove join details from where and where_params
     357                joins, where, where_params = get_non_join_sql(self._filters, opts)
     358
     359                try:
     360                    where.append(get_where_clause('in', db_table + '.',
     361                                                  opts.pk.column, pks,
     362                                                  opts.pk.db_type()))
     363                except EmptyResultSet:
     364                    return
     365                where_params.extend(opts.pk.get_db_prep_lookup('in', pks))
     366            else:
     367                sql.append(' '.join(['%s %s AS %s ON %s' % (join_type, table, alias, condition)
     368                                     for (alias, (table, join_type, condition)) in joins.items()]))
     369
     370        sql.append('SET %s' % ','.join(set_))
     371
     372        if where:
     373            sql.append('WHERE %s' % ' AND '.join(where))
     374
     375        cursor = connection.cursor()
     376        cursor.execute(' '.join(sql), set_params + where_params)
     377        transaction.commit_unless_managed()
     378
    322379    def delete(self):
    323380        """
    324381        Deletes the records in the current QuerySet.
     
    929986        params.extend(params2)
    930987    return joins, where, params
    931988
     989def parse_update(kwarg_items, opts):
     990    set_, params = [], []
     991
     992    for kwarg, value in kwarg_items:
     993        path = kwarg.split(LOOKUP_SEPARATOR)
     994        update_type = path.pop()
     995        if len(path) == 0 or update_type not in UPDATE_TERMS:
     996            path.append(update_type)
     997            update_type = 'exact'
     998
     999        if len(path) < 1:
     1000            raise TypeError('Cannot parse keyword update %r') % kwarg
     1001
     1002        if value is None:
     1003            # Interpret '__exact=None' as the sql '= NULL'; otherwise, reject
     1004            # all uses of None as a query value.
     1005            if update_type != 'exact':
     1006                raise ValueError('Cannot use None as a query value.')
     1007        elif callable(value):
     1008            value = value()
     1009
     1010        field = opts.get_field(path[0], many_to_many=False)
     1011        db_field = connection.ops.quote_name(field.column)
     1012        if update_type == 'sql':
     1013            try:
     1014                sql, sql_params = value
     1015                set_.append('%s=%s' % (db_field, sql))
     1016                params.extend(sql_params)
     1017            except ValueError:
     1018                set_.append('%s=%s' % (db_field, value))
     1019        else:
     1020            if update_type == 'exact':
     1021                set_.append('%s=%%s' % db_field)
     1022            elif update_type == 'add':
     1023                set_.append('%s=%s+%%s' % (db_field, db_field))
     1024            elif update_type == 'subtract':
     1025                set_.append('%s=%s-%%s' % (db_field, db_field))
     1026            params.append(field.get_db_prep_save(value))
     1027    return set_, params
     1028
     1029def get_non_join_sql(q, opts):
     1030    # Here be dragyns
     1031    if isinstance(q, QNot):
     1032        try:
     1033            joins, where, params = get_non_join_sql(q.q, opts)
     1034            where2 = ['(NOT (%s))' % ' AND '.join(where)]
     1035        except EmptyResultSet:
     1036            return SortedDict(), [], []
     1037        return joins, where2, params
     1038    elif isinstance(q, Q):
     1039        return parse_lookup(exclude_joins(q.kwargs).items(), opts)
     1040    else:
     1041        joins, where, params = SortedDict(), [], []
     1042        for val in q.args:
     1043            try:
     1044                joins, where2, params2 = get_non_join_sql(val, opts)
     1045                where.extend(where2)
     1046                params.extend(params2)
     1047            except EmptyResultSet:
     1048                if not isinstance(q, QOr):
     1049                    raise EmptyResultSet
     1050        if where:
     1051            return joins, ['(%s)' % q.operator.join(where)], params
     1052        return joins, [], params
     1053
     1054def exclude_joins(kwargs):
     1055    """
     1056    Returns a dict of filters consisting of those in the given dict
     1057    which do not require joins. These are of the following form::
     1058
     1059    field_name
     1060    field_name__lookup
     1061    """
     1062    filters = {}
     1063    for kwarg, value in kwargs.items():
     1064        path = kwarg.split(LOOKUP_SEPARATOR)
     1065        if len(path) == 1:
     1066            filters[kwarg] = value
     1067        else:
     1068            lookup_type = path.pop()
     1069            if len(path) == 1 and (lookup_type == 'pk' or lookup_type in QUERY_TERMS):
     1070                filters[kwarg] = value
     1071    return filters
     1072
    9321073class FieldFound(Exception):
    9331074    "Exception used to short circuit field-finding operations."
    9341075    pass
  • django/db/backends/sqlite3/base.py

     
    3939Database.register_adapter(decimal.Decimal, util.rev_typecast_decimal)
    4040
    4141class DatabaseFeatures(BaseDatabaseFeatures):
     42    allows_join_in_update = False
    4243    supports_constraints = False
    4344
    4445class DatabaseOperations(BaseDatabaseOperations):
  • django/db/backends/__init__.py

     
    4141
    4242class BaseDatabaseFeatures(object):
    4343    allows_group_by_ordinal = True
     44    allows_join_in_update = True
    4445    allows_unique_and_pk = True
    4546    autoindexes_primary_keys = True
    4647    inline_fk_references = True
Back to Top