Ticket #17000: add_q_refactor.diff

File add_q_refactor.diff, 25.8 KB (added by Anssi Kääriäinen, 13 years ago)
  • django/db/models/query_utils.py

    diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
    index a56ab5c..2f3b7a7 100644
    a b class Q(tree.Node):  
    4747        if not isinstance(other, Q):
    4848            raise TypeError(other)
    4949        obj = type(self)()
    50         obj.add(self, conn)
    51         obj.add(other, conn)
     50        obj.connector = conn
     51        if len(self) == 1 and not self.negated:
     52            obj.add(self.children[0], conn)
     53        else:
     54            obj.add(self, conn)
     55        if len(other) == 1 and not other.negated:
     56            obj.add(other.children[0], conn)
     57        else:
     58            obj.add(other, conn)
    5259        return obj
    5360
    5461    def __or__(self, other):
    class Q(tree.Node):  
    5865        return self._combine(other, self.AND)
    5966
    6067    def __invert__(self):
    61         obj = type(self)()
    62         obj.add(self, self.AND)
     68        obj = self.clone()
    6369        obj.negate()
    6470        return obj
    6571
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index 207bc0c..ee43808 100644
    a b class Aggregate(object):  
    6969
    7070        self.field = tmp
    7171
     72    def clone(self):
     73        clone = copy.copy(self)
     74        clone.col = self.col[:]
     75        return clone
     76
    7277    def relabel_aliases(self, change_map):
    7378        if isinstance(self.col, (list, tuple)):
    7479            self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 6bf7de2..4eb69cb 100644
    a b from django.db import transaction  
    55from django.db.backends.util import truncate_name
    66from django.db.models.query_utils import select_related_descend
    77from django.db.models.sql.constants import *
    8 from django.db.models.sql.datastructures import EmptyResultSet
     8from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet
    99from django.db.models.sql.expressions import SQLEvaluator
    1010from django.db.models.sql.query import get_proxied_model, get_order_dir, Query
    1111from django.db.utils import DatabaseError
    class SQLCompiler(object):  
    6868        from_, f_params = self.get_from_clause()
    6969
    7070        qn = self.quote_name_unless_alias
    71 
    7271        where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
    7372        having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
    7473        params = []
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index 61fd2be..73b597d 100644
    a b from django.db.models.sql import aggregates as base_aggregates_module  
    2020from django.db.models.sql.constants import *
    2121from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
    2222from django.db.models.sql.expressions import SQLEvaluator
    23 from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
    24     ExtraWhere, AND, OR)
     23from django.db.models.sql.where import (WhereNode, Constraint, ExtraWhere,
     24    AND, OR)
    2525from django.core.exceptions import FieldError
    2626
    2727__all__ = ['Query', 'RawQuery']
    class Query(object):  
    254254        obj.dupe_avoidance = self.dupe_avoidance.copy()
    255255        obj.select = self.select[:]
    256256        obj.tables = self.tables[:]
    257         obj.where = copy.deepcopy(self.where, memo=memo)
     257        obj.where = self.where.clone()
    258258        obj.where_class = self.where_class
    259259        if self.group_by is None:
    260260            obj.group_by = None
    261261        else:
    262262            obj.group_by = self.group_by[:]
    263         obj.having = copy.deepcopy(self.having, memo=memo)
     263        obj.having = self.having.clone()
    264264        obj.order_by = self.order_by[:]
    265265        obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
    266266        obj.distinct = self.distinct
    class Query(object):  
    268268        obj.select_for_update_nowait = self.select_for_update_nowait
    269269        obj.select_related = self.select_related
    270270        obj.related_select_cols = []
    271         obj.aggregates = copy.deepcopy(self.aggregates, memo=memo)
     271        if self.aggregates:
     272            obj.aggregates = copy.deepcopy(self.aggregates, memo=memo)
     273        else:
     274            obj.aggregates = SortedDict()
    272275        if self.aggregate_select_mask is None:
    273276            obj.aggregate_select_mask = None
    274277        else:
    class Query(object):  
    291294            obj._extra_select_cache = self._extra_select_cache.copy()
    292295        obj.extra_tables = self.extra_tables
    293296        obj.extra_order_by = self.extra_order_by
    294         obj.deferred_loading = copy.deepcopy(self.deferred_loading, memo=memo)
     297        obj.deferred_loading = self.deferred_loading[0].copy(), self.deferred_loading[1]
    295298        if self.filter_is_sticky and self.used_aliases:
    296299            obj.used_aliases = self.used_aliases.copy()
    297300        else:
    class Query(object):  
    499502                if self.alias_refcount.get(alias) or rhs.alias_refcount.get(alias):
    500503                    self.promote_alias(alias, True)
    501504
    502         # Now relabel a copy of the rhs where-clause and add it to the current
    503         # one.
    504         if rhs.where:
    505             w = copy.deepcopy(rhs.where)
    506             w.relabel_aliases(change_map)
    507             if not self.where:
    508                 # Since 'self' matches everything, add an explicit "include
    509                 # everything" where-constraint so that connections between the
    510                 # where clauses won't exclude valid results.
    511                 self.where.add(EverythingNode(), AND)
    512         elif self.where:
    513             # rhs has an empty where clause.
    514             w = self.where_class()
    515             w.add(EverythingNode(), AND)
     505        if connector == OR and (not self.where or not rhs.where):
     506            # One of the two sides matches everything and the connector is OR.
     507            # This means the new where condition must match everything.
     508            self.where = self.where_class()
    516509        else:
    517             w = self.where_class()
    518         self.where.add(w, connector)
     510            rhs_where = rhs.where.clone()
     511            rhs_where.relabel_aliases(change_map)
     512            self.where = self.where_class([self.where, rhs_where], connector)
     513            # the root node's connector must always be AND
     514            if self.where.connector == OR:
     515                self.where = self.where_class([self.where])
     516            self.where.prune_tree(recurse=True)
    519517
    520518        # Selection columns and extra extensions are those provided by 'rhs'.
    521519        self.select = []
    class Query(object):  
    10711069
    10721070        for alias, aggregate in self.aggregates.items():
    10731071            if alias in (parts[0], LOOKUP_SEP.join(parts)):
    1074                 entry = self.where_class()
    1075                 entry.add((aggregate, lookup_type, value), AND)
    1076                 if negate:
    1077                     entry.negate()
    1078                 self.having.add(entry, connector)
     1072                self.having.add((aggregate, lookup_type, value), connector)
    10791073                return
    10801074
    10811075        opts = self.get_meta()
    class Query(object):  
    11831177                self.add_filter(filter, negate=negate, can_reuse=can_reuse,
    11841178                        process_extras=False)
    11851179
    1186     def add_q(self, q_object, used_aliases=None, force_having=False):
     1180    def add_q(self, q_object, force_having=False):
    11871181        """
    11881182        Adds a Q-object to the current filter.
    11891183
    11901184        Can also be used to add anything that has an 'add_to_query()' method.
     1185
     1186        In case add_to_query path is not executed, this method's main purpose
     1187        is to walk the q_object's internal nodes and manage the state of the
     1188        self.where / self.having trees. Leaf nodes will be handled by
     1189        add_filter.
     1190
     1191        The self.where / self.having trees are managed by pushing new nodes
     1192        to self.where / self.having. This way self.where / self.having is
     1193        always at the right node when add_filter adds items to them.
     1194
     1195        We need to start a new subtree when:
     1196           - The connector of the q_object is different than the connector of
     1197             the where / having tree.
     1198           - The q_object is negated.
     1199
     1200        After call of this function with q_object=~Q(pk=1)&~Q(Q(pk=3)|Q(pk=2))
     1201        we should have the following tree:
     1202                      AND
     1203                     /   \
     1204                    NOT  NOT
     1205                     |     \
     1206                    pk=1   OR
     1207                          /  \
     1208                        pk=3 pk=2
     1209
     1210        This method will call recursively itself for those childrens of the
     1211        q_object which are Q-objs, and call add_filter for the leaf nodes.
    11911212        """
    1192         if used_aliases is None:
    1193             used_aliases = self.used_aliases
     1213
     1214        # Complex custom objects are responsible for adding themselves.
    11941215        if hasattr(q_object, 'add_to_query'):
    1195             # Complex custom objects are responsible for adding themselves.
    1196             q_object.add_to_query(self, used_aliases)
    1197         else:
    1198             if self.where and q_object.connector != AND and len(q_object) > 1:
    1199                 self.where.start_subtree(AND)
    1200                 subtree = True
     1216            q_object.add_to_query(self, self.used_aliases)
     1217            return
     1218
     1219        # We need to check upfront if this whole tree should be placed in
     1220        # the query's having clause or not. The reason is we can't have
     1221        # one part of ORed clause in having and the other in where. Once set,
     1222        # force_having can't be changed later on.
     1223        if not force_having and q_object.connector == OR:
     1224            force_having = self.need_force_having(q_object)
     1225
     1226        # Start subtrees for both having and where if needed. At the end we
     1227        # check if anything got added into the subtrees. If not, prune em.
     1228        where_subtree = False
     1229        having_subtree = False
     1230        connector = q_object.connector
     1231        if self.having.connector <> connector or q_object.negated:
     1232            self.having = self.having.subtree(q_object.connector)
     1233            having_subtree = True
     1234        if self.where.connector <> connector or q_object.negated:
     1235            self.where = self.where.subtree(q_object.connector)
     1236            where_subtree = True
     1237        if q_object.negated:
     1238            self.where.negate()
     1239            self.having.negate()
     1240
     1241        # Aliases that were newly added or not used at all need to
     1242        # be promoted to outer joins if they are nullable relations.
     1243        # (they shouldn't turn the whole conditional into the empty
     1244        # set just because they don't match anything). Take the
     1245        # before snapshot of the aliases.
     1246        if connector == OR:
     1247            refcounts_before = self.alias_refcount.copy()
     1248
     1249        for child in q_object.children:
     1250            if isinstance(child, Node):
     1251                self.add_q(child, force_having=force_having)
    12011252            else:
    1202                 subtree = False
    1203             connector = AND
    1204             if q_object.connector == OR and not force_having:
    1205                 force_having = self.need_force_having(q_object)
    1206             for child in q_object.children:
    1207                 if connector == OR:
    1208                     refcounts_before = self.alias_refcount.copy()
    1209                 if force_having:
    1210                     self.having.start_subtree(connector)
    1211                 else:
    1212                     self.where.start_subtree(connector)
    1213                 if isinstance(child, Node):
    1214                     self.add_q(child, used_aliases, force_having=force_having)
    1215                 else:
    1216                     self.add_filter(child, connector, q_object.negated,
    1217                             can_reuse=used_aliases, force_having=force_having)
    1218                 if force_having:
    1219                     self.having.end_subtree()
    1220                 else:
    1221                     self.where.end_subtree()
    1222 
    1223                 if connector == OR:
    1224                     # Aliases that were newly added or not used at all need to
    1225                     # be promoted to outer joins if they are nullable relations.
    1226                     # (they shouldn't turn the whole conditional into the empty
    1227                     # set just because they don't match anything).
    1228                     self.promote_unused_aliases(refcounts_before, used_aliases)
    1229                 connector = q_object.connector
    1230             if q_object.negated:
    1231                 self.where.negate()
    1232             if subtree:
    1233                 self.where.end_subtree()
    1234         if self.filter_is_sticky:
    1235             self.used_aliases = used_aliases
     1253                self.add_filter(child, connector, q_object.negated,
     1254                        can_reuse=self.used_aliases, force_having=force_having)
     1255
     1256        if connector == OR:
     1257            self.promote_unused_aliases(refcounts_before, self.used_aliases)
     1258        if having_subtree:
     1259            self.having = self.having.parent
     1260            self.having.prune_tree()
     1261        if where_subtree:
     1262            self.where = self.where.parent
     1263            self.where.prune_tree()
    12361264
    12371265    def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
    12381266            allow_explicit_fk=False, can_reuse=None, negate=False,
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 3e9dbf0..b2b5b3c 100644
    a b class WhereNode(tree.Node):  
    6464
    6565        if hasattr(obj, "prepare"):
    6666            value = obj.prepare(lookup_type, value)
    67             super(WhereNode, self).add((obj, lookup_type, annotation, value),
    68                 connector)
    69             return
    7067
    7168        super(WhereNode, self).add((obj, lookup_type, annotation, value),
    7269                connector)
    7370
    7471    def as_sql(self, qn, connection):
    7572        """
    76         Returns the SQL version of the where clause and the value to be
    77         substituted in. Returns None, None if this node is empty.
     73        Returns the SQL version of the where clause and the values to be
     74        substituted in.
     75
     76        If the tree evaluates to always true, then the function will return
     77        ("", []). If the function evaluates to always false, EmptyResultSet
     78        will be risen.
     79        """
     80        try:
     81            return self._as_sql(qn, connection)
     82        except FullResultSet:
     83            return "", []
    7884
    79         If 'node' is provided, that is the root of the SQL generation
    80         (generally not needed except by the internal implementation for
    81         recursion).
     85    def _as_sql(self, qn, connection):
     86        """
     87        Internal helper.
    8288        """
    8389        if not self.children:
    84             return None, []
     90            return '', []
    8591        result = []
    8692        result_params = []
    87         empty = True
     93        # Track the amount of EmptyResultSets and FullResultSets this node
     94        # has. These + self.connector and self.negated are used to check
     95        # if this node matches nothing or matches everything.
     96        empty_vals = 0
     97        full_vals = 0
    8898        for child in self.children:
    8999            try:
    90100                if hasattr(child, 'as_sql'):
    class WhereNode(tree.Node):  
    92102                else:
    93103                    # A leaf node in the tree.
    94104                    sql, params = self.make_atom(child, qn, connection)
    95 
    96             except EmptyResultSet:
    97                 if self.connector == AND and not self.negated:
    98                     # We can bail out early in this particular case (only).
    99                     raise
    100                 elif self.negated:
    101                     empty = False
    102                 continue
    103             except FullResultSet:
    104                 if self.connector == OR:
    105                     if self.negated:
    106                         empty = True
    107                         break
    108                     # We match everything. No need for any constraints.
    109                     return '', []
    110                 if self.negated:
    111                     empty = True
    112                 continue
    113 
    114             empty = False
    115             if sql:
    116105                result.append(sql)
    117106                result_params.extend(params)
    118         if empty:
     107            except EmptyResultSet:
     108                empty_vals += 1
     109            except FullResultSet:
     110                full_vals += 1
     111
     112        if self.negated:
     113            full_vals, empty_vals = empty_vals, full_vals
     114
     115        if full_vals > 0 and self.connector == OR:
     116            raise FullResultSet
     117        if full_vals == len(self) and self.connector == AND:
     118            raise FullResultSet
     119        if empty_vals == len(self) and self.connector == OR:
     120            raise EmptyResultSet
     121        if empty_vals > 0 and self.connector == AND:
    119122            raise EmptyResultSet
    120123
    121124        conn = ' %s ' % self.connector
    class WhereNode(tree.Node):  
    249252                if hasattr(child[3], 'relabel_aliases'):
    250253                    child[3].relabel_aliases(change_map)
    251254
    252 class EverythingNode(object):
    253     """
    254     A node that matches everything.
    255     """
    256 
    257     def as_sql(self, qn=None, connection=None):
    258         raise FullResultSet
    259 
    260     def relabel_aliases(self, change_map, node=None):
    261         return
    262 
    263 class NothingNode(object):
    264     """
    265     A node that matches nothing.
    266     """
    267     def as_sql(self, qn=None, connection=None):
    268         raise EmptyResultSet
    269 
    270     def relabel_aliases(self, change_map, node=None):
    271         return
    272 
    273255class ExtraWhere(object):
    274256    def __init__(self, sqls, params):
    275257        self.sqls = sqls
    class ExtraWhere(object):  
    278260    def as_sql(self, qn=None, connection=None):
    279261        return " AND ".join(self.sqls), tuple(self.params or ())
    280262
     263    def clone(self):
     264        return self
     265
    281266class Constraint(object):
    282267    """
    283268    An object that can be passed to WhereNode.add() and knows how to
    class Constraint(object):  
    342327    def relabel_aliases(self, change_map):
    343328        if self.alias in change_map:
    344329            self.alias = change_map[self.alias]
     330
     331    def clone(self):
     332        return Constraint(self.alias, self.col, self.field)
  • django/utils/tree.py

    diff --git a/django/utils/tree.py b/django/utils/tree.py
    index 36b5977..4be04d5 100644
    a b class Node(object):  
    2626        """
    2727        self.children = children and children[:] or []
    2828        self.connector = connector or self.default
    29         self.subtree_parents = []
     29        self.parent = None
    3030        self.negated = negated
    3131
    3232    # We need this because of django.db.models.query_utils.Q. Q. __init__() is
    class Node(object):  
    4545        return obj
    4646    _new_instance = classmethod(_new_instance)
    4747
     48    def empty(cls):
     49        return cls._new_instance([])
     50    empty = classmethod(empty)
     51
     52
     53    def clone(self, memo=None):
     54        """
     55        Clones the whole tree, not just the subtree. We have loops in
     56        the tree due to keeping both parent and child links. Because
     57        of this, we must keep a memo of objects already copied.
     58        """
     59        if memo is None:
     60            memo = {}
     61        if self in memo:
     62            return memo[self]
     63        obj = self.empty()
     64        memo[self] = obj
     65        for child in self.children:
     66             if isinstance(child, Node):
     67                 child = child.clone(memo=memo)
     68             obj._add(child)
     69        if self.parent is not None:
     70            new_parent = self.parent.clone(memo=memo)
     71            obj.parent = new_parent
     72        obj.connector = self.connector
     73        obj.negated = self.negated
     74        return obj
     75
     76    def __repr__(self):
     77        return self.as_subtree
     78
    4879    def __str__(self):
    4980        if self.negated:
    5081            return '(NOT (%s: %s))' % (self.connector, ', '.join([str(c) for c
    class Node(object):  
    5283        return '(%s: %s)' % (self.connector, ', '.join([str(c) for c in
    5384                self.children]))
    5485
    55     def __deepcopy__(self, memodict):
    56         """
    57         Utility method used by copy.deepcopy().
    58         """
    59         obj = Node(connector=self.connector, negated=self.negated)
    60         obj.__class__ = self.__class__
    61         obj.children = copy.deepcopy(self.children, memodict)
    62         obj.subtree_parents = copy.deepcopy(self.subtree_parents, memodict)
    63         return obj
     86    def _as_subtree(self, indent=0):
     87        buf = []
     88        if self.negated:
     89            buf.append(" " * indent + "NOT")
     90        buf.append((" " * indent) + self.connector + ":")
     91        indent += 2
     92        for child in self.children:
     93            if isinstance(child, Node):
     94                buf.append(child._as_subtree(indent=indent))
     95            else:
     96                buf.append((" " * indent) + str(child))
     97        return "\n".join(buf)
     98    as_subtree = property(_as_subtree)
     99
     100    def _as_tree(self):
     101        root = self
     102        while root.parent:
     103            root = root.parent
     104        return root._as_subtree(indent=0)
     105    as_tree = property(_as_tree)
    64106
    65107    def __len__(self):
    66108        """
    class Node(object):  
    80122        """
    81123        return other in self.children
    82124
     125    def _add(self, *nodes):
     126        """
     127        A helper method to keep the parent/child links in valid state.
     128        """
     129        for node in nodes:
     130            self.children.append(node)
     131            if isinstance(node, Node):
     132                node.parent = self
     133
    83134    def add(self, node, conn_type):
    84135        """
    85         Adds a new node to the tree. If the conn_type is the same as the root's
    86         current connector type, the node is added to the first level.
     136        Adds a new node to the tree. If the conn_type is the same as the
     137        root's current connector type, the node is added to the first level.
    87138        Otherwise, the whole tree is pushed down one level and a new root
    88         connector is created, connecting the existing tree and the new node.
     139        connector is created, connecting the existing tree and the added node.
    89140        """
    90141        if node in self.children and conn_type == self.connector:
    91142            return
    92         if len(self.children) < 2:
    93             self.connector = conn_type
    94143        if self.connector == conn_type:
    95             if isinstance(node, Node) and (node.connector == conn_type or
    96                     len(node) == 1):
    97                 self.children.extend(node.children)
    98             else:
    99                 self.children.append(node)
     144            self._add(node)
    100145        else:
    101             obj = self._new_instance(self.children, self.connector,
    102                     self.negated)
    103             self.connector = conn_type
    104             self.children = [obj, node]
     146            obj = self._new_instance([node], conn_type)
     147            obj2 = self.clone()
     148            self._add(obj, obj2)
    105149
    106150    def negate(self):
    107151        """
    108         Negate the sense of the root connector. This reorganises the children
    109         so that the current node has a single child: a negated node containing
    110         all the previous children. This slightly odd construction makes adding
    111         new children behave more intuitively.
    112 
    113         Interpreting the meaning of this negate is up to client code. This
    114         method is useful for implementing "not" arrangements.
    115         """
    116         self.children = [self._new_instance(self.children, self.connector,
    117                 not self.negated)]
    118         self.connector = self.default
    119 
    120     def start_subtree(self, conn_type):
    121         """
    122         Sets up internal state so that new nodes are added to a subtree of the
    123         current node. The conn_type specifies how the sub-tree is joined to the
    124         existing children.
     152        Negate the sense of this node.
    125153        """
    126         if len(self.children) == 1:
    127             self.connector = conn_type
    128         elif self.connector != conn_type:
    129             self.children = [self._new_instance(self.children, self.connector,
    130                     self.negated)]
    131             self.connector = conn_type
    132             self.negated = False
     154        self.negated = not self.negated
    133155
    134         self.subtree_parents.append(self.__class__(self.children,
    135                 self.connector, self.negated))
    136         self.connector = self.default
    137         self.negated = False
    138         self.children = []
     156    def subtree(self, conn_type):
     157        obj = self.empty()
     158        obj.connector = conn_type
     159        obj.parent = self
     160        self.children.append(obj)
     161        return obj
    139162
    140     def end_subtree(self):
     163    def prune_tree(self, recurse=False):
    141164        """
    142         Closes off the most recently unmatched start_subtree() call.
    143 
    144         This puts the current state into a node of the parent tree and returns
    145         the current instances state to be the parent.
     165        Removes empty children nodes, and non-necessary intermediatry
     166        nodes from this node. If recurse is true, will recurse down
     167        the tree.
    146168        """
    147         obj = self.subtree_parents.pop()
    148         node = self.__class__(self.children, self.connector)
    149         self.connector = obj.connector
    150         self.negated = obj.negated
    151         self.children = obj.children
    152         self.children.append(node)
    153 
     169        old_childs = self.children[:]
     170        self.children = []
     171        for child in old_childs:
     172            if not child:
     173                continue
     174            if isinstance(child, Node):
     175                if recurse:
     176                    child.prune_tree(recurse=True)
     177                if not child.negated and len(child) == 1:
     178                    child = child.children[0]
     179            self._add(child)
  • tests/regressiontests/queries/tests.py

    diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py
    index d8fd5bc..4f505a3 100644
    a b class Queries1Tests(BaseQuerysetTest):  
    820820        q = Note.objects.filter(Q(extrainfo__author=self.a1)|Q(extrainfo=xx)).query
    821821        self.assertEqual(
    822822            len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]]),
    823             1
     823            2
    824824        )
    825825
    826826
Back to Top