Django

Code

Ticket #7210: 0001-Added-expression-support-for-QuerySet.update.patch

File 0001-Added-expression-support-for-QuerySet.update.patch, 8.5 kB (added by sebastian_noack, 7 months ago)
  • /dev/null

    old new  
     1from copy import deepcopy 
     2 
     3from django.db import connection 
     4from django.db.models.fields import FieldDoesNotExist 
     5from django.core.exceptions import FieldError 
     6from django.utils import tree 
     7 
     8class Expression(object): 
     9    """ 
     10    Base class for all sql expressions, expected by QuerySet.update. 
     11    """ 
     12    # Arithmetic connection types 
     13    ADD = '+' 
     14    SUB = '-' 
     15    MUL = '*' 
     16    DIV = '/' 
     17    MOD = '%' 
     18 
     19    # Logical connection types 
     20    AND = 'AND' 
     21    OR = 'OR' 
     22 
     23    def _combine(self, other, conn, node=None): 
     24        if not isinstance(other, Expression): 
     25            raise TypeError(other) 
     26        obj = node or ExpressionNode([self], conn) 
     27        obj.add(other, conn) 
     28        return obj 
     29 
     30    def __add__(self, other): 
     31        return self._combine(other, self.ADD) 
     32 
     33    def __sub__(self, other): 
     34        return self._combine(other, self.SUB) 
     35 
     36    def __mul__(self, other): 
     37        return self._combine(other, self.MUL) 
     38 
     39    def __div__(self, other): 
     40        return self._combine(other, self.DIV) 
     41 
     42    def __mod__(self, other): 
     43        return self._combine(other, self.MOD) 
     44 
     45    def __and__(self, other): 
     46        return self._combine(other, self.AND) 
     47 
     48    def __or__(self, other): 
     49        return self._combine(other, self.OR) 
     50 
     51    def __invert__(self, node=None): 
     52        obj = node or ExpressionNode([self]) 
     53        obj.negate() 
     54        return obj 
     55 
     56    def as_sql(self, field, opts, qn=None): 
     57        raise NotImplementedError 
     58 
     59class ExpressionNode(Expression, tree.Node): 
     60    def __init__(self, children=None, connector=None, negated=False): 
     61        if children and len(children) > 1 and connector in (None, self.default): 
     62            raise TypeError('You have to specify a connector.') 
     63        super(ExpressionNode, self).__init__(children, connector, negated) 
     64 
     65    def _combine(self, *args, **kwargs): 
     66        return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs) 
     67 
     68    def __invert__(self): 
     69        return super(ExpressionNode, self).__invert__(node=deepcopy(self)) 
     70 
     71    def as_sql(self, field, opts, qn=None, node=None): 
     72        if node is None: 
     73            node = self 
     74        result = [] 
     75        result_params = [] 
     76        for child in node.children: 
     77            if hasattr(child, 'as_sql'): 
     78                sql, params = child.as_sql(field, opts, qn) 
     79                format = '%s' 
     80            else: 
     81                sql, params = self.as_sql(field, opts, qn, child) 
     82                if child.negated: 
     83                    format = 'NOT %s' 
     84                else: 
     85                    format = '%s' 
     86                if len(child.children) > 1: 
     87                    format %= '(%s)' 
     88            if sql: 
     89                result.append(format % sql) 
     90                result_params.extend(params) 
     91        conn = ' %s ' % node.connector 
     92        return conn.join(result), result_params 
     93 
     94class LiteralExpr(Expression): 
     95    """ 
     96    An expression representing the given value. 
     97    """ 
     98    def __init__(self, value): 
     99        self.value = value 
     100 
     101    def as_sql(self, field, opts, qn=None): 
     102        if self.value is None: 
     103            return 'NULL', () 
     104        if hasattr(field, 'get_placeholder'): 
     105            return field.get_placeholder(self.value), (self.value,) 
     106        return '%s', (self.value,) 
     107 
     108class ColumnExpr(Expression): 
     109    """ 
     110    An expression representing the value of the given column. 
     111    """ 
     112    def __init__(self, column): 
     113        self.column = column 
     114 
     115    def as_sql(self, field, opts, qn=None): 
     116        if not qn: 
     117            qn = connection.ops.quote_name 
     118        try: 
     119            column = opts.get_field(self.column).attname 
     120        except FieldDoesNotExist: 
     121            names = opts.get_all_field_names() 
     122            raise FieldError('Cannot resolve keyword %r into field. ' 
     123                    'Choices are: %s' % (self.column, ', '.join(names))) 
     124        return '%s.%s' % (qn(opts.db_table), qn(column)), () 
     125 
     126class CurrentExpr(Expression): 
     127    """ 
     128    An expression representing the value of the current column. 
     129    """ 
     130    def as_sql(self, field, opts, qn=None): 
     131        if not qn: 
     132            qn = connection.ops.quote_name 
     133        return '%s.%s' % (qn(opts.db_table), qn(field.attname)), () 
  • a/django/db/models/sql/subqueries.py

    old new  
    88from django.db.models.sql.datastructures import RawValue, Date 
    99from django.db.models.sql.query import Query 
    1010from django.db.models.sql.where import AND 
     11from django.db.models.sql.expressions import Expression, LiteralExpr 
    1112 
    1213__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 
    1314        'CountQuery'] 
     
    126127        result = ['UPDATE %s' % qn(table)] 
    127128        result.append('SET') 
    128129        values, update_params = [], [] 
    129         for name, val, placeholder in self.values: 
    130             if val is not None: 
    131                 values.append('%s = %s' % (qn(name), placeholder)) 
    132                 update_params.append(val) 
    133             else: 
    134                 values.append('%s = NULL' % qn(name)) 
     130        for name, sql, params in self.values: 
     131            values.append('%s = %s' % (qn(name), sql)) 
     132            update_params.extend(params) 
    135133        result.append(', '.join(values)) 
    136134        where, params = self.where.as_sql() 
    137135        if where: 
     
    207205            self.where.add((None, f.column, f, 'in', 
    208206                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), 
    209207                    AND) 
    210             self.values = [(related_field.column, None, '%s')] 
     208            self.values = [(related_field.column, 'NULL', ())] 
    211209            self.execute_sql(None) 
    212210 
    213211    def add_update_values(self, values): 
     
    232230        """ 
    233231        from django.db.models.base import Model 
    234232        for field, model, val in values_seq: 
    235             # FIXME: Some sort of db_prep_* is probably more appropriate here. 
    236             if field.rel and isinstance(val, Model): 
    237                 val = val.pk 
     233            if isinstance(val, Expression): 
     234                expr = val 
     235            elif field.rel and isinstance(val, Model):  # FIXME: Some sort of 
     236                expr = LiteralExpr(val.pk)              # db_prep_* is probably 
     237            else:                                       # more appropriate here. 
     238                expr = LiteralExpr(val) 
    238239 
    239             # Getting the placeholder for the field. 
    240             if hasattr(field, 'get_placeholder'): 
    241                 placeholder = field.get_placeholder(val) 
    242             else: 
    243                 placeholder = '%s' 
     240            sql, params = expr.as_sql( 
     241                field, self.get_meta(), self.connection.ops.quote_name) 
    244242 
    245243            if model: 
    246                 self.add_related_update(model, field.column, val, placeholder
     244                self.add_related_update(model, field.column, sql, params
    247245            else: 
    248                 self.values.append((field.column, val, placeholder)) 
     246                self.values.append((field.column, sql, params)) 
    249247 
    250     def add_related_update(self, model, column, value, placeholder): 
     248    def add_related_update(self, model, column, sql, params): 
    251249        """ 
    252250        Adds (name, value) to an update query for an ancestor model. 
    253251 
    254252        Updates are coalesced so that we only run one update query per ancestor. 
    255253        """ 
    256254        try: 
    257             self.related_updates[model].append((column, value, placeholder)) 
     255            self.related_updates[model].append((column, sql, params)) 
    258256        except KeyError: 
    259             self.related_updates[model] = [(column, value, placeholder)] 
     257            self.related_updates[model] = [(column, sql, params)] 
    260258 
    261259    def get_related_updates(self): 
    262260        """