Ticket #10847: values-extra.3.diff

File values-extra.3.diff, 9.9 KB (added by Alex Gaynor, 15 years ago)
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 9323e9b..8b0bce6 100644
    a b class ValuesQuerySet(QuerySet):  
    698698
    699699    def iterator(self):
    700700        # Purge any extra columns that haven't been explicitly asked for
    701         if self.extra_names is not None:
    702             self.query.trim_extra_select(self.extra_names)
    703 
    704701        extra_names = self.query.extra_select.keys()
    705702        field_names = self.field_names
    706703        aggregate_names = self.query.aggregate_select.keys()
    class ValuesQuerySet(QuerySet):  
    724721        if self._fields:
    725722            self.extra_names = []
    726723            self.aggregate_names = []
    727             if not self.query.extra_select and not self.query.aggregate_select:
     724            if not self.query._extra_select and not self.query.aggregate_select:
    728725                self.field_names = list(self._fields)
    729726            else:
    730727                self.query.default_cols = False
    731728                self.field_names = []
    732729                for f in self._fields:
    733                     if self.query.extra_select.has_key(f):
     730                    # we poke at the inner attribute here since we might be adding
     731                    # back an extra select item that we hadn't had selected previously.
     732                    if self.query._extra_select.has_key(f):
    734733                        self.extra_names.append(f)
    735734                    elif self.query.aggregate_select.has_key(f):
    736735                        self.aggregate_names.append(f)
    class ValuesQuerySet(QuerySet):  
    743742            self.aggregate_names = None
    744743
    745744        self.query.select = []
     745        if self.extra_names is not None:
     746            self.query.set_extra_mask(self.extra_names)
    746747        self.query.add_fields(self.field_names, False)
    747748        if self.aggregate_names is not None:
    748749            self.query.set_aggregate_mask(self.aggregate_names)
    class ValuesQuerySet(QuerySet):  
    799800
    800801class ValuesListQuerySet(ValuesQuerySet):
    801802    def iterator(self):
    802         if self.extra_names is not None:
    803             self.query.trim_extra_select(self.extra_names)
    804 
    805803        if self.flat and len(self._fields) == 1:
    806804            for row in self.query.results_iter():
    807805                yield row[0]
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index f4bf8b2..8960e8e 100644
    a b class BaseQuery(object):  
    8888
    8989        # These are for extensions. The contents are more or less appended
    9090        # verbatim to the appropriate clause.
    91         self.extra_select = SortedDict()  # Maps col_alias -> (col_sql, params).
     91        self._extra_select = SortedDict()  # Maps col_alias -> (col_sql, params).
     92        self.extra_select_mask = None # says which items from extra_select should
     93                                      # actually be selected
    9294        self.extra_tables = ()
    9395        self.extra_where = ()
    9496        self.extra_params = ()
    class BaseQuery(object):  
    220222        else:
    221223            obj._aggregate_select_cache = self._aggregate_select_cache.copy()
    222224        obj.max_depth = self.max_depth
    223         obj.extra_select = self.extra_select.copy()
     225        obj._extra_select = self._extra_select.copy()
     226        if self.extra_select_mask is None:
     227            obj.extra_select_mask = None
     228        else:
     229            obj.extra_select_mask = self.extra_select_mask[:]
    224230        obj.extra_tables = self.extra_tables
    225231        obj.extra_where = self.extra_where
    226232        obj.extra_params = self.extra_params
    class BaseQuery(object):  
    325331            query = self
    326332            self.select = []
    327333            self.default_cols = False
    328             self.extra_select = {}
     334            self._extra_select = {}
    329335            self.remove_inherited_models()
    330336
    331337        query.clear_ordering(True)
    class BaseQuery(object):  
    540546            # It would be nice to be able to handle this, but the queries don't
    541547            # really make sense (or return consistent value sets). Not worth
    542548            # the extra complexity when you can write a real query instead.
    543             if self.extra_select and rhs.extra_select:
     549            if self._extra_select and rhs._extra_select:
    544550                raise ValueError("When merging querysets using 'or', you "
    545551                        "cannot have extra(select=...) on both sides.")
    546552            if self.extra_where and rhs.extra_where:
    547553                raise ValueError("When merging querysets using 'or', you "
    548554                        "cannot have extra(where=...) on both sides.")
    549         self.extra_select.update(rhs.extra_select)
     555        self._extra_select.update(rhs._extra_select)
     556        extra_select_mask = []
     557        if self.extra_select_mask is not None:
     558            extra_select_mask.extend(self.extra_select_mask)
     559        if rhs.extra_select_mask is not None:
     560            extra_select_mask.extend(rhs.extra_select_mask)
     561        if not extra_select_mask:
     562            extra_select_mask = None
     563        self.extra_select_mask = extra_select_mask
    550564        self.extra_tables += rhs.extra_tables
    551565        self.extra_where += rhs.extra_where
    552566        self.extra_params += rhs.extra_params
    class BaseQuery(object):  
    20112025        except MultiJoin:
    20122026            raise FieldError("Invalid field name: '%s'" % name)
    20132027        except FieldError:
    2014             names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
     2028            names = opts.get_all_field_names() + self._extra_select.keys() + self.aggregate_select.keys()
    20152029            names.sort()
    20162030            raise FieldError("Cannot resolve keyword %r into field. "
    20172031                    "Choices are: %s" % (name, ", ".join(names)))
    class BaseQuery(object):  
    21392153                    pos = entry.find("%s", pos + 2)
    21402154                select_pairs[name] = (entry, entry_params)
    21412155            # This is order preserving, since self.extra_select is a SortedDict.
    2142             self.extra_select.update(select_pairs)
     2156            self._extra_select.update(select_pairs)
    21432157        if where:
    21442158            self.extra_where += tuple(where)
    21452159        if params:
    class BaseQuery(object):  
    22132227        """
    22142228        target[model] = set([f.name for f in fields])
    22152229
    2216     def trim_extra_select(self, names):
    2217         """
    2218         Removes any aliases in the extra_select dictionary that aren't in
    2219         'names'.
    2220 
    2221         This is needed if we are selecting certain values that don't incldue
    2222         all of the extra_select names.
    2223         """
    2224         for key in set(self.extra_select).difference(set(names)):
    2225             del self.extra_select[key]
    2226 
    22272230    def set_aggregate_mask(self, names):
    22282231        "Set the mask of aggregates that will actually be returned by the SELECT"
    22292232        self.aggregate_select_mask = names
    22302233        self._aggregate_select_cache = None
    22312234
     2235    def set_extra_mask(self, names):
     2236        """
     2237        Set the mask of extra select items that will be returned by SELECT,
     2238        we don't actually remove them from the Query since they might be used
     2239        later
     2240        """
     2241        self.extra_select_mask = names
     2242
    22322243    def _aggregate_select(self):
    22332244        """The SortedDict of aggregate columns that are not masked, and should
    22342245        be used in the SELECT clause.
    class BaseQuery(object):  
    22472258            return self.aggregates
    22482259    aggregate_select = property(_aggregate_select)
    22492260
     2261    def _extra_select(self):
     2262        if self.extra_select_mask is None:
     2263            return self._extra_select
     2264        return SortedDict([
     2265            (k, v) for k, v in self._extra_select.items()
     2266            if k in self.extra_select_mask
     2267        ])
     2268    extra_select = property(_extra_select)
     2269
    22502270    def set_start(self, start):
    22512271        """
    22522272        Sets the table from which to start joining. The start position is
  • django/db/models/sql/subqueries.py

    diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
    index 4c62457..9b78131 100644
    a b class UpdateQuery(Query):  
    178178        # from other tables.
    179179        query = self.clone(klass=Query)
    180180        query.bump_prefix()
    181         query.extra_select = {}
     181        query._extra_select = {}
    182182        query.select = []
    183183        query.add_fields([query.model._meta.pk.name])
    184184        must_pre_select = count > 1 and not self.connection.features.update_can_self_select
    class DateQuery(Query):  
    409409        self.select = [select]
    410410        self.select_fields = [None]
    411411        self.select_related = False # See #7097.
    412         self.extra_select = {}
     412        self._extra_select = {}
    413413        self.distinct = True
    414414        self.order_by = order == 'ASC' and [1] or [-1]
    415415
  • tests/regressiontests/extra_regress/models.py

    diff --git a/tests/regressiontests/extra_regress/models.py b/tests/regressiontests/extra_regress/models.py
    index fd34982..0a31e5b 100644
    a b class TestObject(models.Model):  
    3535    second = models.CharField(max_length=20)
    3636    third = models.CharField(max_length=20)
    3737
     38    def __unicode__(self):
     39        return "%s-%s-%s" % (self.first, self.second, self.third)
     40
    3841__test__ = {"API_TESTS": """
    3942# Regression tests for #7314 and #7372
    4043
    True  
    189192>>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id')
    190193[(u'third', u'first', u'second', 1)]
    191194
    192 """}
     195>>> list(TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 1}).values('pk'))) == list(TestObject.objects.all())
     196True
    193197
     198>>> TestObject.objects.values('pk').extra(select={'extra': 1})
     199[{'pk': 1}]
    194200
     201>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1}))
     202[<TestObject: first-second-third>]
     203
     204>>> TestObject.objects.values('pk').extra(select={'extra': 5}).values_list('pk', 'extra')
     205[(1, 5)]
     206
     207"""}
Back to Top