Django

Code

Ticket #4260: update.diff

File update.diff, 6.9 kB (added by insin, 6 months ago)

Implements an update method using queries

  • django/db/models/query.py

    old new  
    2727    'regex', 'iregex', 
    2828) 
    2929 
     30UPDATE_TERMS = ( 
     31    'add', 'exact', 'sql', 'subtract', 
     32) 
     33 
    3034# Size of each "chunk" for get_iterator calls. 
    3135# Larger values are slightly faster at the expense of more storage space. 
    3236GET_ITERATOR_CHUNK_SIZE = 100 
     
    319323        qs._params.extend(id_list) 
    320324        return dict([(obj._get_pk_val(), obj) for obj in qs.iterator()]) 
    321325 
     326    def update(self, **kwargs): 
     327        """ 
     328        Performs a SQL UPDATE. 
     329        """ 
     330        if not kwargs: 
     331            raise ValueError('No updates were specified.') 
     332 
     333        opts = self.model._meta 
     334        db_table = connection.ops.quote_name(opts.db_table) 
     335        set_, set_params = parse_update(kwargs.items(), opts) 
     336        joins, where, where_params = self._filters.get_sql(opts) 
     337 
     338        sql = ['UPDATE %s' % db_table] 
     339 
     340        if joins: 
     341            # Special case for databases which don't allow joins in 
     342            # UPDATE statements - select primary keys and add a new 
     343            # where IN clause. 
     344            if not connection.features.allows_join_in_update: 
     345                try: 
     346                    select, pk_sql, pk_params = self._get_sql_clause() 
     347                except EmptyResultSet: 
     348                    return 
     349 
     350                cursor = connection.cursor() 
     351                pk_col = '%s.%s' % (db_table, 
     352                                    connection.ops.quote_name(opts.pk.column)) 
     353                cursor.execute('SELECT %s' % pk_col + pk_sql, pk_params) 
     354                pks = [row[0] for row in cursor.fetchall()] 
     355 
     356                # Remove join details from where and where_params 
     357                joins, where, where_params = get_non_join_sql(self._filters, opts) 
     358 
     359                try: 
     360                    where.append(get_where_clause('in', db_table + '.', 
     361                                                  opts.pk.column, pks, 
     362                                                  opts.pk.db_type())) 
     363                except EmptyResultSet: 
     364                    return 
     365                where_params.extend(opts.pk.get_db_prep_lookup('in', pks)) 
     366            else: 
     367                sql.append(' '.join(['%s %s AS %s ON %s' % (join_type, table, alias, condition) 
     368                                     for (alias, (table, join_type, condition)) in joins.items()])) 
     369 
     370        sql.append('SET %s' % ','.join(set_)) 
     371 
     372        if where: 
     373            sql.append('WHERE %s' % ' AND '.join(where)) 
     374 
     375        cursor = connection.cursor() 
     376        cursor.execute(' '.join(sql), set_params + where_params) 
     377        transaction.commit_unless_managed() 
     378 
    322379    def delete(self): 
    323380        """ 
    324381        Deletes the records in the current QuerySet. 
     
    929986        params.extend(params2) 
    930987    return joins, where, params 
    931988 
     989def parse_update(kwarg_items, opts): 
     990    set_, params = [], [] 
     991 
     992    for kwarg, value in kwarg_items: 
     993        path = kwarg.split(LOOKUP_SEPARATOR) 
     994        update_type = path.pop() 
     995        if len(path) == 0 or update_type not in UPDATE_TERMS: 
     996            path.append(update_type) 
     997            update_type = 'exact' 
     998 
     999        if len(path) < 1: 
     1000            raise TypeError('Cannot parse keyword update %r') % kwarg 
     1001 
     1002        if value is None: 
     1003            # Interpret '__exact=None' as the sql '= NULL'; otherwise, reject 
     1004            # all uses of None as a query value. 
     1005            if update_type != 'exact': 
     1006                raise ValueError('Cannot use None as a query value.') 
     1007        elif callable(value): 
     1008            value = value() 
     1009 
     1010        field = opts.get_field(path[0], many_to_many=False) 
     1011        db_field = connection.ops.quote_name(field.column) 
     1012        if update_type == 'sql': 
     1013            try: 
     1014                sql, sql_params = value 
     1015                set_.append('%s=%s' % (db_field, sql)) 
     1016                params.extend(sql_params) 
     1017            except ValueError: 
     1018                set_.append('%s=%s' % (db_field, value)) 
     1019        else: 
     1020            if update_type == 'exact': 
     1021                set_.append('%s=%%s' % db_field) 
     1022            elif update_type == 'add': 
     1023                set_.append('%s=%s+%%s' % (db_field, db_field)) 
     1024            elif update_type == 'subtract': 
     1025                set_.append('%s=%s-%%s' % (db_field, db_field)) 
     1026            params.append(field.get_db_prep_save(value)) 
     1027    return set_, params 
     1028 
     1029def get_non_join_sql(q, opts): 
     1030    # Here be dragyns 
     1031    if isinstance(q, QNot): 
     1032        try: 
     1033            joins, where, params = get_non_join_sql(q.q, opts) 
     1034            where2 = ['(NOT (%s))' % ' AND '.join(where)] 
     1035        except EmptyResultSet: 
     1036            return SortedDict(), [], [] 
     1037        return joins, where2, params 
     1038    elif isinstance(q, Q): 
     1039        return parse_lookup(exclude_joins(q.kwargs).items(), opts) 
     1040    else: 
     1041        joins, where, params = SortedDict(), [], [] 
     1042        for val in q.args: 
     1043            try: 
     1044                joins, where2, params2 = get_non_join_sql(val, opts) 
     1045                where.extend(where2) 
     1046                params.extend(params2) 
     1047            except EmptyResultSet: 
     1048                if not isinstance(q, QOr): 
     1049                    raise EmptyResultSet 
     1050        if where: 
     1051            return joins, ['(%s)' % q.operator.join(where)], params 
     1052        return joins, [], params 
     1053 
     1054def exclude_joins(kwargs): 
     1055    """ 
     1056    Returns a dict of filters consisting of those in the given dict 
     1057    which do not require joins. These are of the following form:: 
     1058 
     1059    field_name 
     1060    field_name__lookup 
     1061    """ 
     1062    filters = {} 
     1063    for kwarg, value in kwargs.items(): 
     1064        path = kwarg.split(LOOKUP_SEPARATOR) 
     1065        if len(path) == 1: 
     1066            filters[kwarg] = value 
     1067        else: 
     1068            lookup_type = path.pop() 
     1069            if len(path) == 1 and (lookup_type == 'pk' or lookup_type in QUERY_TERMS): 
     1070                filters[kwarg] = value 
     1071    return filters 
     1072 
    9321073class FieldFound(Exception): 
    9331074    "Exception used to short circuit field-finding operations." 
    9341075    pass 
  • django/db/backends/sqlite3/base.py

    old new  
    3939Database.register_adapter(decimal.Decimal, util.rev_typecast_decimal) 
    4040 
    4141class DatabaseFeatures(BaseDatabaseFeatures): 
     42    allows_join_in_update = False 
    4243    supports_constraints = False 
    4344 
    4445class DatabaseOperations(BaseDatabaseOperations): 
  • django/db/backends/__init__.py

    old new  
    4141 
    4242class BaseDatabaseFeatures(object): 
    4343    allows_group_by_ordinal = True 
     44    allows_join_in_update = True 
    4445    allows_unique_and_pk = True 
    4546    autoindexes_primary_keys = True 
    4647    inline_fk_references = True