Django

Code

Changeset 7179

Show
Ignore:
Timestamp:
02/29/08 09:53:25 (9 months ago)
Author:
mtredinnick
Message:

queryset-refactor: Made update() work with inherited models.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • django/branches/queryset-refactor/django/db/models/sql/query.py

    r7170 r7179  
    504504        pieces = name.split(LOOKUP_SEP) 
    505505        if not alias: 
    506             alias = self.join((None, opts.db_table, None, None)
     506            alias = self.get_initial_alias(
    507507        field, target, opts, joins = self.setup_joins(pieces, opts, alias, 
    508508                False) 
     
    581581            return True 
    582582        return False 
     583 
     584    def change_alias(self, old_alias, new_alias): 
     585        """ 
     586        Changes old_alias to new_alias, relabelling any references to it in 
     587        select columns and the where clause. 
     588        """ 
     589        assert new_alias not in self.alias_map 
     590 
     591        # 1. Update references in "select" and "where". 
     592        change_map = {old_alias: new_alias} 
     593        self.where.relabel_aliases(change_map) 
     594        for pos, col in enumerate(self.select): 
     595            if isinstance(col, (list, tuple)): 
     596                if col[0] == old_alias: 
     597                    self.select[pos] = (new_alias, col[1]) 
     598            else: 
     599                col.relabel_aliases(change_map) 
     600 
     601        # 2. Rename the alias in the internal table/alias datastructures. 
     602        alias_data = self.alias_map[old_alias] 
     603        alias_data[ALIAS_JOIN][RHS_ALIAS] = new_alias 
     604        table_aliases = self.table_map[alias_data[ALIAS_TABLE]] 
     605        for pos, alias in enumerate(table_aliases): 
     606            if alias == old_alias: 
     607                table_aliases[pos] = new_alias 
     608                break 
     609        self.alias_map[new_alias] = alias_data 
     610        del self.alias_map[old_alias] 
     611        for pos, alias in enumerate(self.tables): 
     612            if alias == old_alias: 
     613                self.tables[pos] = new_alias 
     614                break 
     615 
     616        # 3. Update any joins that refer to the old alias. 
     617        for data in self.alias_map.values(): 
     618            if data[ALIAS_JOIN][LHS_ALIAS] == old_alias: 
     619                data[ALIAS_JOIN][LHS_ALIAS] = new_alias 
     620 
     621    def get_initial_alias(self): 
     622        """ 
     623        Returns the first alias for this query, after increasing its reference 
     624        count. 
     625        """ 
     626        if self.tables: 
     627            alias = self.tables[0] 
     628            self.ref_alias(alias) 
     629        else: 
     630            alias = self.join((None, self.model._meta.db_table, None, None)) 
     631        return alias 
     632 
     633    def count_active_tables(self): 
     634        """ 
     635        Returns the number of tables in this query with a non-zero reference 
     636        count. 
     637        """ 
     638        return len([1 for o in self.alias_map.values() if o[ALIAS_REFCOUNT]]) 
    583639 
    584640    def join(self, connection, always_create=False, exclusions=(), 
     
    729785 
    730786        opts = self.get_meta() 
    731         alias = self.join((None, opts.db_table, None, None)
     787        alias = self.get_initial_alias(
    732788        allow_many = trim or not negate 
    733789 
     
    10221078        the root model (the one given in self.model). 
    10231079        """ 
    1024         table = self.model._meta.db_table 
    1025         self.select.extend([(table, col) for col in columns]) 
     1080        for alias in self.tables: 
     1081            if self.alias_map[alias][ALIAS_REFCOUNT]: 
     1082                break 
     1083        else: 
     1084            alias = self.get_initial_alias() 
     1085        self.select.extend([(alias, col) for col in columns]) 
    10261086 
    10271087    def add_ordering(self, *ordering): 
     
    11121172        """ 
    11131173        opts = self.model._meta 
    1114         alias = self.join((None, opts.db_table, None, None)
     1174        alias = self.get_initial_alias(
    11151175        field, col, opts, joins = self.setup_joins(start.split(LOOKUP_SEP), 
    11161176                opts, alias, False) 
     
    11421202        try: 
    11431203            sql, params = self.as_sql() 
     1204            if not sql: 
     1205                raise EmptyResultSet 
    11441206        except EmptyResultSet: 
    11451207            if result_type == MULTI: 
     
    11741236 
    11751237def results_iter(cursor): 
     1238    """ 
     1239    An iterator over the result set that returns a chunk of rows at a time. 
     1240    """ 
    11761241    while 1: 
    11771242        rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) 
  • django/branches/queryset-refactor/django/db/models/sql/subqueries.py

    r7166 r7179  
    22Query subclasses which provide extra functionality beyond simple data retrieval. 
    33""" 
     4from copy import deepcopy 
     5 
    46from django.contrib.contenttypes import generic 
    57from django.core.exceptions import FieldError 
     
    9597    def _setup_query(self): 
    9698        """ 
    97         Run on initialisation and after cloning. 
     99        Runs on initialisation and after cloning. Any attributes that would 
     100        normally be set in __init__ should go in here, instead, so that they 
     101        are also set up after a clone() call. 
    98102        """ 
    99103        self.values = [] 
     104        self.related_updates = {} 
     105        self.related_ids = None 
     106 
     107    def clone(self, klass=None, **kwargs): 
     108        return super(UpdateQuery, self).clone(klass, 
     109                related_updates=self.related_updates.copy, **kwargs) 
     110 
     111    def execute_sql(self, result_type=None): 
     112        super(UpdateQuery, self).execute_sql(result_type) 
     113        for query in self.get_related_updates(): 
     114            query.execute_sql(result_type) 
    100115 
    101116    def as_sql(self): 
     
    104119        parameters. 
    105120        """ 
    106         self.select_related = False 
    107121        self.pre_sql_setup() 
    108  
    109         if len(self.tables) != 1: 
    110             # We can only update one table at a time, so we need to check that 
    111             # only one alias has a nonzero refcount. 
    112             table = None 
    113             for alias_list in self.table_map.values(): 
    114                 for alias in alias_list: 
    115                     if self.alias_map[alias][ALIAS_REFCOUNT]: 
    116                         if table: 
    117                             raise FieldError('Updates can only access a single database table at a time.') 
    118                         table = alias 
    119         else: 
    120             table = self.tables[0] 
    121  
     122        if not self.values: 
     123            return '', () 
     124        table = self.tables[0] 
    122125        qn = self.quote_name_unless_alias 
    123126        result = ['UPDATE %s' % qn(table)] 
     
    136139        return ' '.join(result), tuple(update_params + params) 
    137140 
     141    def pre_sql_setup(self): 
     142        """ 
     143        If the update depends on results from other tables, we need to do some 
     144        munging of the "where" conditions to match the format required for 
     145        (portable) SQL updates. That is done here. 
     146 
     147        Further, if we are going to be running multiple updates, we pull out 
     148        the id values to update at this point so that they don't change as a 
     149        result of the progressive updates. 
     150        """ 
     151        self.select_related = False 
     152        self.clear_ordering(True) 
     153        super(UpdateQuery, self).pre_sql_setup() 
     154        count = self.count_active_tables() 
     155        if not self.related_updates and count == 1: 
     156            return 
     157 
     158        # We need to use a sub-select in the where clause to filter on things 
     159        # from other tables. 
     160        query = self.clone(klass=Query) 
     161        main_alias = query.tables[0] 
     162        if count != 1: 
     163            query.unref_alias(main_alias) 
     164        if query.alias_map[main_alias][ALIAS_REFCOUNT]: 
     165            alias = '%s0' % self.alias_prefix 
     166            query.change_alias(main_alias, alias) 
     167            col = query.model._meta.pk.column 
     168        else: 
     169            for model in query.model._meta.get_parent_list(): 
     170                for alias in query.table_map.get(model._meta.db_table, []): 
     171                    if query.alias_map[alias][ALIAS_REFCOUNT]: 
     172                        col = model._meta.pk.column 
     173                        break 
     174        query.add_local_columns([col]) 
     175 
     176        # Now we adjust the current query: reset the where clause and get rid 
     177        # of all the tables we don't need (since they're in the sub-select). 
     178        self.where = self.where_class() 
     179        if self.related_updates: 
     180            idents = [] 
     181            for rows in query.execute_sql(MULTI): 
     182                idents.extend([r[0] for r in rows]) 
     183            self.add_filter(('pk__in', idents)) 
     184            self.related_ids = idents 
     185        else: 
     186            self.add_filter(('pk__in', query)) 
     187        for alias in self.tables[1:]: 
     188            self.alias_map[alias][ALIAS_REFCOUNT] = 0 
     189 
    138190    def clear_related(self, related_field, pk_list): 
    139191        """ 
     
    157209            field, model, direct, m2m = self.model._meta.get_field_by_name(name) 
    158210            if not direct or m2m: 
    159                 # Can only update non-relation fields and foreign keys. 
    160211                raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field) 
     212            # FIXME: Some sort of db_prep_* is probably more appropriate here. 
    161213            if field.rel and isinstance(val, Model): 
    162214                val = val.pk 
    163             self.values.append((field.column, val)) 
     215            if model: 
     216                self.add_related_update(model, field.column, val) 
     217            else: 
     218                self.values.append((field.column, val)) 
     219 
     220    def add_related_update(self, model, column, value): 
     221        """ 
     222        Adds (name, value) to an update query for an ancestor model. 
     223 
     224        Updates are coalesced so that we only run one update query per ancestor. 
     225        """ 
     226        try: 
     227            self.related_updates[model].append((column, value)) 
     228        except KeyError: 
     229            self.related_updates[model] = [(column, value)] 
     230 
     231    def get_related_updates(self): 
     232        """ 
     233        Returns a list of query objects: one for each update required to an 
     234        ancestor model. Each query will have the same filtering conditions as 
     235        the current query but will only update a single table. 
     236        """ 
     237        if not self.related_updates: 
     238            return [] 
     239        result = [] 
     240        for model, values in self.related_updates.items(): 
     241            query = UpdateQuery(model, self.connection) 
     242            query.values = values 
     243            if self.related_ids: 
     244                query.add_filter(('pk__in', self.related_ids)) 
     245            result.append(query) 
     246        return result 
    164247 
    165248class InsertQuery(Query): 
  • django/branches/queryset-refactor/tests/modeltests/model_inheritance/models.py

    r7163 r7179  
    217217<Restaurant: Ristorante Miron the restaurant> 
    218218 
     219# The update() command can update fields in parent and child classes at once 
     220# (although it executed multiple SQL queries to do so). 
     221>>> Restaurant.objects.filter(serves_hot_dogs=True, name__contains='D').update(name='Demon Puppies', serves_hot_dogs=False) 
     222>>> r1 = Restaurant.objects.get(pk=r.pk) 
     223>>> r1.serves_hot_dogs == False 
     224True 
     225>>> r1.name 
     226u'Demon Puppies' 
    219227 
    220228"""}