Django

Code

Ticket #3275: query.py.diff

File query.py.diff, 5.9 kB (added by David Cramer <dcramer@gmail.com>, 2 years ago)

diffs for django/db/models/query.py

  • query.py

    old new  
    8080        self._filters = Q() 
    8181        self._order_by = None        # Ordering, e.g. ('date', '-name'). If None, use model's ordering. 
    8282        self._select_related = False # Whether to fill cache for related objects. 
     83        self._recurse_depth = 0      # Used to track how deep we are following for select_related() 
     84        self._recurse_fields = []    # Fields to recurse through for select_related() 
    8385        self._distinct = False       # Whether the query should use SELECT DISTINCT. 
    8486        self._select = {}            # Dictionary of attname -> SQL. 
    8587        self._where = []             # List of extra WHERE clauses to use. 
     
    178180                raise StopIteration 
    179181            for row in rows: 
    180182                if fill_cache: 
    181                     obj, index_end = get_cached_row(self.model, row, 0
     183                    obj, index_end = get_cached_row(self.model, row, 0, self._recurse_fields, self._recurse_depth
    182184                else: 
    183185                    obj = self.model(*row[:index_end]) 
    184186                for i, k in enumerate(extra_select): 
     
    194196        counter._select_related = False 
    195197        select, sql, params = counter._get_sql_clause() 
    196198        cursor = connection.cursor() 
     199        id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table), 
     200                backend.quote_name(self.model._meta.pk.column)) 
    197201        if self._distinct: 
    198             id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table), 
    199                     backend.quote_name(self.model._meta.pk.column)) 
    200202            cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params) 
    201203        else: 
    202             cursor.execute("SELECT COUNT(*)" + sql, params) 
     204            cursor.execute("SELECT COUNT(%s)" % id_col + sql, params) 
    203205        return cursor.fetchone()[0] 
    204206 
    205207    def get(self, *args, **kwargs): 
     
    359361        else: 
    360362            return self._filter_or_exclude(None, **filter_obj) 
    361363 
    362     def select_related(self, true_or_false=True): 
     364    # fields should be a list of field names in the root table, if specified, it modifies depth to 1 
     365    # depth is the maximum number of children to recurse through, defaults to infinite 
     366    def select_related(self, true_or_false=True, depth=0, fields=[]): 
    363367        "Returns a new QuerySet instance with '_select_related' modified." 
    364         return self._clone(_select_related=true_or_false) 
     368        if fields != []: 
     369            depth = 1 
     370        return self._clone(_select_related=true_or_false, _recurse_depth=depth, _recurse_fields=fields) 
    365371 
    366372    def order_by(self, *field_names): 
    367373        "Returns a new QuerySet instance with the ordering changed." 
     
    395401        c._filters = self._filters 
    396402        c._order_by = self._order_by 
    397403        c._select_related = self._select_related 
     404        c._recurse_fields = self._recurse_fields 
     405        c._recurse_depth = self._recurse_depth 
    398406        c._distinct = self._distinct 
    399407        c._select = self._select.copy() 
    400408        c._where = self._where[:] 
     
    448456 
    449457        # Add additional tables and WHERE clauses based on select_related. 
    450458        if self._select_related: 
    451             fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]
     459            fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table], self._recurse_depth, self._recurse_fields
    452460 
    453461        # Add any additional SELECTs. 
    454462        if self._select: 
     
    660668        return backend.get_fulltext_search_sql(table_prefix + field_name) 
    661669    raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type) 
    662670 
    663 def get_cached_row(klass, row, index_start): 
     671def get_cached_row(klass, row, index_start, fields=[], max_depth=0, cur_depth=0): 
    664672    "Helper function that recursively returns an object with cache filled" 
     673    if max_depth and cur_depth > max_depth: 
     674        return None 
    665675    index_end = index_start + len(klass._meta.fields) 
    666676    obj = klass(*row[index_start:index_end]) 
    667677    for f in klass._meta.fields: 
    668         if f.rel and not f.null: 
    669             rel_obj, index_end = get_cached_row(f.rel.to, row, index_end) 
    670             setattr(obj, f.get_cache_name(), rel_obj) 
     678        if f.rel and not f.null and (fields == [] or (cur_depth == 0 and f.name in fields)): 
     679            cached_row = get_cached_row(f.rel.to, row, index_end, fields, max_depth, cur_depth+1) 
     680            if cached_row: 
     681                    rel_obj, index_end = cached_row 
     682                    setattr(obj, f.get_cache_name(), rel_obj) 
    671683    return obj, index_end 
    672684 
    673 def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen): 
     685def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, fields=[], cur_depth=0): 
    674686    """ 
    675687    Helper function that recursively populates the select, tables and where (in 
    676688    place) for select_related queries. 
    677689    """ 
    678690    qn = backend.quote_name 
     691    if max_depth and cur_depth > max_depth: 
     692        return 
    679693    for f in opts.fields: 
    680         if f.rel and not f.null
     694        if f.rel and not f.null and (fields == [] or (cur_depth == 0 and f.name in fields))
    681695            db_table = f.rel.to._meta.db_table 
    682696            if db_table not in cache_tables_seen: 
    683697                tables.append(qn(db_table)) 
     
    689703            where.append('%s.%s = %s.%s' % \ 
    690704                (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) 
    691705            select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) 
    692             fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen
     706            fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, fields, cur_depth+1
    693707 
    694708def parse_lookup(kwarg_items, opts): 
    695709    # Helper function that handles converting API kwargs