Ticket #7210: 0001-Added-expression-support-for-QuerySet.update.patch

File 0001-Added-expression-support-for-QuerySet.update.patch, 8.5 KB (added by Sebastian Noack, 16 years ago)
  • new file django/db/models/sql/expressions.py

    From e1b81cccb9881c21626aec01fa5b050c972a1b0c Mon Sep 17 00:00:00 2001
    From: Sebastian Noack <sebastian.noack@gmail.com>
    Date: Thu, 8 May 2008 14:30:19 +0200
    Subject: [PATCH] Added expression support for QuerySet.update.
    
    ---
     django/db/models/sql/expressions.py |  133 +++++++++++++++++++++++++++++++++++
     django/db/models/sql/subqueries.py  |   38 +++++-----
     2 files changed, 151 insertions(+), 20 deletions(-)
     create mode 100644 django/db/models/sql/expressions.py
    
    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    new file mode 100644
    index 0000000..d52cc46
    - +  
     1from copy import deepcopy
     2
     3from django.db import connection
     4from django.db.models.fields import FieldDoesNotExist
     5from django.core.exceptions import FieldError
     6from django.utils import tree
     7
     8class Expression(object):
     9    """
     10    Base class for all sql expressions, expected by QuerySet.update.
     11    """
     12    # Arithmetic connection types
     13    ADD = '+'
     14    SUB = '-'
     15    MUL = '*'
     16    DIV = '/'
     17    MOD = '%'
     18
     19    # Logical connection types
     20    AND = 'AND'
     21    OR = 'OR'
     22
     23    def _combine(self, other, conn, node=None):
     24        if not isinstance(other, Expression):
     25            raise TypeError(other)
     26        obj = node or ExpressionNode([self], conn)
     27        obj.add(other, conn)
     28        return obj
     29
     30    def __add__(self, other):
     31        return self._combine(other, self.ADD)
     32
     33    def __sub__(self, other):
     34        return self._combine(other, self.SUB)
     35
     36    def __mul__(self, other):
     37        return self._combine(other, self.MUL)
     38
     39    def __div__(self, other):
     40        return self._combine(other, self.DIV)
     41
     42    def __mod__(self, other):
     43        return self._combine(other, self.MOD)
     44
     45    def __and__(self, other):
     46        return self._combine(other, self.AND)
     47
     48    def __or__(self, other):
     49        return self._combine(other, self.OR)
     50
     51    def __invert__(self, node=None):
     52        obj = node or ExpressionNode([self])
     53        obj.negate()
     54        return obj
     55
     56    def as_sql(self, field, opts, qn=None):
     57        raise NotImplementedError
     58
     59class ExpressionNode(Expression, tree.Node):
     60    def __init__(self, children=None, connector=None, negated=False):
     61        if children and len(children) > 1 and connector in (None, self.default):
     62            raise TypeError('You have to specify a connector.')
     63        super(ExpressionNode, self).__init__(children, connector, negated)
     64
     65    def _combine(self, *args, **kwargs):
     66        return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs)
     67
     68    def __invert__(self):
     69        return super(ExpressionNode, self).__invert__(node=deepcopy(self))
     70
     71    def as_sql(self, field, opts, qn=None, node=None):
     72        if node is None:
     73            node = self
     74        result = []
     75        result_params = []
     76        for child in node.children:
     77            if hasattr(child, 'as_sql'):
     78                sql, params = child.as_sql(field, opts, qn)
     79                format = '%s'
     80            else:
     81                sql, params = self.as_sql(field, opts, qn, child)
     82                if child.negated:
     83                    format = 'NOT %s'
     84                else:
     85                    format = '%s'
     86                if len(child.children) > 1:
     87                    format %= '(%s)'
     88            if sql:
     89                result.append(format % sql)
     90                result_params.extend(params)
     91        conn = ' %s ' % node.connector
     92        return conn.join(result), result_params
     93
     94class LiteralExpr(Expression):
     95    """
     96    An expression representing the given value.
     97    """
     98    def __init__(self, value):
     99        self.value = value
     100
     101    def as_sql(self, field, opts, qn=None):
     102        if self.value is None:
     103            return 'NULL', ()
     104        if hasattr(field, 'get_placeholder'):
     105            return field.get_placeholder(self.value), (self.value,)
     106        return '%s', (self.value,)
     107
     108class ColumnExpr(Expression):
     109    """
     110    An expression representing the value of the given column.
     111    """
     112    def __init__(self, column):
     113        self.column = column
     114
     115    def as_sql(self, field, opts, qn=None):
     116        if not qn:
     117            qn = connection.ops.quote_name
     118        try:
     119            column = opts.get_field(self.column).attname
     120        except FieldDoesNotExist:
     121            names = opts.get_all_field_names()
     122            raise FieldError('Cannot resolve keyword %r into field. '
     123                    'Choices are: %s' % (self.column, ', '.join(names)))
     124        return '%s.%s' % (qn(opts.db_table), qn(column)), ()
     125
     126class CurrentExpr(Expression):
     127    """
     128    An expression representing the value of the current column.
     129    """
     130    def as_sql(self, field, opts, qn=None):
     131        if not qn:
     132            qn = connection.ops.quote_name
     133        return '%s.%s' % (qn(opts.db_table), qn(field.attname)), ()
  • django/db/models/sql/subqueries.py

    diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
    index 7385cd0..0bcca96 100644
    a b from django.db.models.sql.constants import *  
    88from django.db.models.sql.datastructures import RawValue, Date
    99from django.db.models.sql.query import Query
    1010from django.db.models.sql.where import AND
     11from django.db.models.sql.expressions import Expression, LiteralExpr
    1112
    1213__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
    1314        'CountQuery']
    class UpdateQuery(Query):  
    126127        result = ['UPDATE %s' % qn(table)]
    127128        result.append('SET')
    128129        values, update_params = [], []
    129         for name, val, placeholder in self.values:
    130             if val is not None:
    131                 values.append('%s = %s' % (qn(name), placeholder))
    132                 update_params.append(val)
    133             else:
    134                 values.append('%s = NULL' % qn(name))
     130        for name, sql, params in self.values:
     131            values.append('%s = %s' % (qn(name), sql))
     132            update_params.extend(params)
    135133        result.append(', '.join(values))
    136134        where, params = self.where.as_sql()
    137135        if where:
    class UpdateQuery(Query):  
    207205            self.where.add((None, f.column, f, 'in',
    208206                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
    209207                    AND)
    210             self.values = [(related_field.column, None, '%s')]
     208            self.values = [(related_field.column, 'NULL', ())]
    211209            self.execute_sql(None)
    212210
    213211    def add_update_values(self, values):
    class UpdateQuery(Query):  
    232230        """
    233231        from django.db.models.base import Model
    234232        for field, model, val in values_seq:
    235             # FIXME: Some sort of db_prep_* is probably more appropriate here.
    236             if field.rel and isinstance(val, Model):
    237                 val = val.pk
     233            if isinstance(val, Expression):
     234                expr = val
     235            elif field.rel and isinstance(val, Model):  # FIXME: Some sort of
     236                expr = LiteralExpr(val.pk)              # db_prep_* is probably
     237            else:                                       # more appropriate here.
     238                expr = LiteralExpr(val)
    238239
    239             # Getting the placeholder for the field.
    240             if hasattr(field, 'get_placeholder'):
    241                 placeholder = field.get_placeholder(val)
    242             else:
    243                 placeholder = '%s'
     240            sql, params = expr.as_sql(
     241                field, self.get_meta(), self.connection.ops.quote_name)
    244242
    245243            if model:
    246                 self.add_related_update(model, field.column, val, placeholder)
     244                self.add_related_update(model, field.column, sql, params)
    247245            else:
    248                 self.values.append((field.column, val, placeholder))
     246                self.values.append((field.column, sql, params))
    249247
    250     def add_related_update(self, model, column, value, placeholder):
     248    def add_related_update(self, model, column, sql, params):
    251249        """
    252250        Adds (name, value) to an update query for an ancestor model.
    253251
    254252        Updates are coalesced so that we only run one update query per ancestor.
    255253        """
    256254        try:
    257             self.related_updates[model].append((column, value, placeholder))
     255            self.related_updates[model].append((column, sql, params))
    258256        except KeyError:
    259             self.related_updates[model] = [(column, value, placeholder)]
     257            self.related_updates[model] = [(column, sql, params)]
    260258
    261259    def get_related_updates(self):
    262260        """
Back to Top