Code

Ticket #7210: 0001-Added-expression-support-for-QuerySet.update.2.patch

File 0001-Added-expression-support-for-QuerySet.update.2.patch, 8.2 KB (added by sebastian_noack, 6 years ago)
  • new file django/db/models/sql/expressions.py

    From 2555dc3f525243548e8476bd1fe474171a350531 Mon Sep 17 00:00:00 2001
    From: Sebastian Noack <sebastian.noack@gmail.com>
    Date: Thu, 8 May 2008 14:30:19 +0200
    Subject: [PATCH] Added expression support for QuerySet.update.
    
    ---
     django/db/models/sql/expressions.py |  127 +++++++++++++++++++++++++++++++++++
     django/db/models/sql/subqueries.py  |   38 +++++------
     2 files changed, 145 insertions(+), 20 deletions(-)
     create mode 100644 django/db/models/sql/expressions.py
    
    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    new file mode 100644
    index 0000000..1532479
    - +  
     1from copy import deepcopy 
     2 
     3from django.db import connection 
     4from django.db.models.fields import FieldDoesNotExist 
     5from django.core.exceptions import FieldError 
     6from django.utils import tree 
     7 
     8class Expression(object): 
     9    """ 
     10    Base class for all sql expressions, expected by QuerySet.update. 
     11    """ 
     12    # Arithmetic connection types 
     13    ADD = '+' 
     14    SUB = '-' 
     15    MUL = '*' 
     16    DIV = '/' 
     17    MOD = '%' 
     18 
     19    # Logical connection types 
     20    AND = 'AND' 
     21    OR = 'OR' 
     22 
     23    def _combine(self, other, conn, node=None): 
     24        obj = node or ExpressionNode([self], conn) 
     25        if isinstance(other, Expression): 
     26            obj.add(other, conn) 
     27        else: 
     28            obj.add(L(other), conn) 
     29        return obj 
     30 
     31    def __add__(self, other): 
     32        return self._combine(other, self.ADD) 
     33 
     34    def __sub__(self, other): 
     35        return self._combine(other, self.SUB) 
     36 
     37    def __mul__(self, other): 
     38        return self._combine(other, self.MUL) 
     39 
     40    def __div__(self, other): 
     41        return self._combine(other, self.DIV) 
     42 
     43    def __mod__(self, other): 
     44        return self._combine(other, self.MOD) 
     45 
     46    def __and__(self, other): 
     47        return self._combine(other, self.AND) 
     48 
     49    def __or__(self, other): 
     50        return self._combine(other, self.OR) 
     51 
     52    def __invert__(self, node=None): 
     53        obj = node or ExpressionNode([self]) 
     54        obj.negate() 
     55        return obj 
     56 
     57    def as_sql(self, field, opts, qn=None): 
     58        raise NotImplementedError 
     59 
     60class ExpressionNode(Expression, tree.Node): 
     61    default = None 
     62 
     63    def __init__(self, children=None, connector=None, negated=False): 
     64        if children is not None and len(children) > 1 and connector is None: 
     65            raise TypeError('You have to specify a connector.') 
     66        super(ExpressionNode, self).__init__(children, connector, negated) 
     67 
     68    def _combine(self, *args, **kwargs): 
     69        return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs) 
     70 
     71    def __invert__(self): 
     72        return super(ExpressionNode, self).__invert__(node=deepcopy(self)) 
     73 
     74    def as_sql(self, field, opts, qn=None, node=None): 
     75        if node is None: 
     76            node = self 
     77        result = [] 
     78        result_params = [] 
     79        for child in node.children: 
     80            if hasattr(child, 'as_sql'): 
     81                sql, params = child.as_sql(field, opts, qn) 
     82                format = '%s' 
     83            else: 
     84                sql, params = self.as_sql(field, opts, qn, child) 
     85                if child.negated: 
     86                    format = 'NOT %s' 
     87                else: 
     88                    format = '%s' 
     89                if len(child.children) > 1: 
     90                    format %= '(%s)' 
     91            if sql: 
     92                result.append(format % sql) 
     93                result_params.extend(params) 
     94        conn = ' %s ' % node.connector 
     95        return conn.join(result), result_params 
     96 
     97class L(Expression): 
     98    """ 
     99    An expression representing the given value. 
     100    """ 
     101    def __init__(self, value): 
     102        self.value = value 
     103 
     104    def as_sql(self, field, opts, qn=None): 
     105        if self.value is None: 
     106            return 'NULL', () 
     107        if hasattr(field, 'get_placeholder'): 
     108            return field.get_placeholder(self.value), (self.value,) 
     109        return '%s', (self.value,) 
     110 
     111class F(Expression): 
     112    """ 
     113    An expression representing the value of the given field. 
     114    """ 
     115    def __init__(self, name): 
     116        self.name = name 
     117 
     118    def as_sql(self, field, opts, qn=None): 
     119        if not qn: 
     120            qn = connection.ops.quote_name 
     121        try: 
     122            column = opts.get_field(self.name).attname 
     123        except FieldDoesNotExist: 
     124            names = opts.get_all_field_names() 
     125            raise FieldError('Cannot resolve keyword %r into field. ' 
     126                    'Choices are: %s' % (self.name, ', '.join(names))) 
     127        return '%s.%s' % (qn(opts.db_table), qn(column)), () 
  • django/db/models/sql/subqueries.py

    diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
    index 7385cd0..d0d5393 100644
    a b from django.db.models.sql.constants import * 
    88from django.db.models.sql.datastructures import RawValue, Date 
    99from django.db.models.sql.query import Query 
    1010from django.db.models.sql.where import AND 
     11from django.db.models.sql.expressions import Expression, L 
    1112 
    1213__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 
    1314        'CountQuery'] 
    class UpdateQuery(Query): 
    126127        result = ['UPDATE %s' % qn(table)] 
    127128        result.append('SET') 
    128129        values, update_params = [], [] 
    129         for name, val, placeholder in self.values: 
    130             if val is not None: 
    131                 values.append('%s = %s' % (qn(name), placeholder)) 
    132                 update_params.append(val) 
    133             else: 
    134                 values.append('%s = NULL' % qn(name)) 
     130        for name, sql, params in self.values: 
     131            values.append('%s = %s' % (qn(name), sql)) 
     132            update_params.extend(params) 
    135133        result.append(', '.join(values)) 
    136134        where, params = self.where.as_sql() 
    137135        if where: 
    class UpdateQuery(Query): 
    207205            self.where.add((None, f.column, f, 'in', 
    208206                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), 
    209207                    AND) 
    210             self.values = [(related_field.column, None, '%s')] 
     208            self.values = [(related_field.column, 'NULL', ())] 
    211209            self.execute_sql(None) 
    212210 
    213211    def add_update_values(self, values): 
    class UpdateQuery(Query): 
    232230        """ 
    233231        from django.db.models.base import Model 
    234232        for field, model, val in values_seq: 
    235             # FIXME: Some sort of db_prep_* is probably more appropriate here. 
    236             if field.rel and isinstance(val, Model): 
    237                 val = val.pk 
     233            if isinstance(val, Expression): 
     234                expr = val 
     235            elif field.rel and isinstance(val, Model):  # FIXME: Some sort of 
     236                expr = L(val.pk)                        # db_prep_* is probably 
     237            else:                                       # more appropriate here. 
     238                expr = L(val) 
    238239 
    239             # Getting the placeholder for the field. 
    240             if hasattr(field, 'get_placeholder'): 
    241                 placeholder = field.get_placeholder(val) 
    242             else: 
    243                 placeholder = '%s' 
     240            sql, params = expr.as_sql( 
     241                field, self.get_meta(), self.connection.ops.quote_name) 
    244242 
    245243            if model: 
    246                 self.add_related_update(model, field.column, val, placeholder) 
     244                self.add_related_update(model, field.column, sql, params) 
    247245            else: 
    248                 self.values.append((field.column, val, placeholder)) 
     246                self.values.append((field.column, sql, params)) 
    249247 
    250     def add_related_update(self, model, column, value, placeholder): 
     248    def add_related_update(self, model, column, sql, params): 
    251249        """ 
    252250        Adds (name, value) to an update query for an ancestor model. 
    253251 
    254252        Updates are coalesced so that we only run one update query per ancestor. 
    255253        """ 
    256254        try: 
    257             self.related_updates[model].append((column, value, placeholder)) 
     255            self.related_updates[model].append((column, sql, params)) 
    258256        except KeyError: 
    259             self.related_updates[model] = [(column, value, placeholder)] 
     257            self.related_updates[model] = [(column, sql, params)] 
    260258 
    261259    def get_related_updates(self): 
    262260        """