Django

Code

Ticket #5020: query.py.4.diff

File query.py.4.diff, 7.3 kB (added by nostgard <nostgard@yahoo.com>, 1 year ago)

clones and uses _recurse_fields

  • query.py

    old new  
    8989        self._filters = Q() 
    9090        self._order_by = None        # Ordering, e.g. ('date', '-name'). If None, use model's ordering. 
    9191        self._select_related = False # Whether to fill cache for related objects. 
     92        self._recurse_fields = None 
    9293        self._max_related_depth = 0  # Maximum "depth" for select_related 
    9394        self._distinct = False       # Whether the query should use SELECT DISTINCT. 
    9495        self._select = {}            # Dictionary of attname -> SQL. 
     
    200201                    row = self.resolve_columns(row, fields) 
    201202                if fill_cache: 
    202203                    obj, index_end = get_cached_row(klass=self.model, row=row, 
    203                                                     index_start=0, max_depth=self._max_related_depth
     204                                                    index_start=0, max_depth=self._max_related_depth, fields=self._recurse_fields
    204205                else: 
    205206                    obj = self.model(*row[:index_end]) 
    206207                for i, k in enumerate(extra_select): 
     
    408409        else: 
    409410            return self._filter_or_exclude(None, **filter_obj) 
    410411 
    411     def select_related(self, true_or_false=True, depth=0): 
     412    def select_related(self, *fields, **kwargs): 
    412413        "Returns a new QuerySet instance with '_select_related' modified." 
    413         return self._clone(_select_related=true_or_false, _max_related_depth=depth) 
     414        true_or_false = kwargs.pop('true_or_false', True) 
     415        depth = kwargs.pop('depth', 0) 
     416        if not fields: fields = None 
     417        return self._clone(_select_related=true_or_false, _max_related_depth=depth, _recurse_fields=fields) 
    414418 
    415419    def order_by(self, *field_names): 
    416420        "Returns a new QuerySet instance with the ordering changed." 
     
    444448        c._filters = self._filters 
    445449        c._order_by = self._order_by 
    446450        c._select_related = self._select_related 
     451        c._recurse_fields = self._recurse_fields 
    447452        c._max_related_depth = self._max_related_depth 
    448453        c._distinct = self._distinct 
    449454        c._select = self._select.copy() 
     
    501506            fill_table_cache(opts, select, tables, where, 
    502507                             old_prefix=opts.db_table, 
    503508                             cache_tables_seen=[opts.db_table], 
    504                              max_depth=self._max_related_depth
     509                             max_depth=self._max_related_depth, fields=self._recurse_fields
    505510 
    506511        # Add any additional SELECTs. 
    507512        if self._select: 
     
    819824            raise NotImplementedError 
    820825    raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type) 
    821826 
    822 def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0): 
     827def get_cached_row(klass, row, index_start, fields=None, max_depth=0, cur_depth=0): 
    823828    """Helper function that recursively returns an object with cache filled""" 
    824829 
    825830    # If we've got a max_depth set and we've exceeded that depth, bail now. 
    826831    if max_depth and cur_depth > max_depth: 
    827832        return None 
    828833 
     834    fields_to_join = get_select_related_fields(klass._meta, fields) 
     835 
    829836    index_end = index_start + len(klass._meta.fields) 
    830837    obj = klass(*row[index_start:index_end]) 
    831     for f in klass._meta.fields: 
    832         if f.rel and not f.null: 
    833             cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1) 
    834             if cached_row: 
    835                 rel_obj, index_end = cached_row 
    836                 setattr(obj, f.get_cache_name(), rel_obj) 
     838    for f in fields_to_join.iterkeys(): 
     839       cached_row = get_cached_row(f.rel.to, row, index_end, fields_to_join[f], max_depth, cur_depth+1) 
     840       if cached_row: 
     841               rel_obj, index_end = cached_row 
     842               setattr(obj, f.get_cache_name(), rel_obj) 
    837843    return obj, index_end 
    838844 
    839 def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0): 
     845def get_select_related_fields(opts, fields=None): 
    840846    """ 
     847    Helper function that returns a dictionary of fields for select_related() 
     848    """ 
     849    if fields is None: 
     850        fields_to_join = dict([(f, None) for f in opts.fields if f.rel and not f.null]) 
     851    else: 
     852        fields_for_lookup = dict([(f.name, f) for f in opts.fields if f.rel]) 
     853        fields_to_join = dict() 
     854        for f in fields: 
     855            path = f.split(LOOKUP_SEPARATOR) 
     856            try: 
     857                fn = fields_for_lookup[path[0]] 
     858                if fn not in fields_to_join: 
     859                    fields_to_join[fn] = [] 
     860                if len(path) > 1: 
     861                    fields_to_join[fn].append(LOOKUP_SEPARATOR.join(path[1:])) 
     862            except KeyError: 
     863                raise FieldDoesNotExist, '%s has no field named %s' % (opts.object_name, path[0]) 
     864    return fields_to_join 
     865 
     866def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, fields=None, cur_depth=0): 
     867    """ 
    841868    Helper function that recursively populates the select, tables and where (in 
    842869    place) for select_related queries. 
     870 
     871    Implicit select_related calls on NULL fields will force an INNER JOIN currently. 
    843872    """ 
    844873 
    845874    # If we've got a max_depth set and we've exceeded that depth, bail now. 
    846     if max_depth and cur_depth > max_depth: 
     875    if max_depth and cur_depth >= max_depth: 
    847876        return None 
    848877 
     878    fields_to_join = get_select_related_fields(opts, fields) 
     879 
    849880    qn = backend.quote_name 
    850     for f in opts.fields: 
    851         if f.rel and not f.null: 
    852             db_table = f.rel.to._meta.db_table 
    853             if db_table not in cache_tables_seen: 
    854                 tables.append(qn(db_table)) 
    855             else: # The table was already seen, so give it a table alias. 
    856                 new_prefix = '%s%s' % (db_table, len(cache_tables_seen)) 
    857                 tables.append('%s %s' % (qn(db_table), qn(new_prefix))) 
    858                 db_table = new_prefix 
    859             cache_tables_seen.append(db_table) 
    860             where.append('%s.%s = %s.%s' % \ 
    861                 (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) 
    862             select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) 
    863             fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1) 
     881    for f in fields_to_join.iterkeys(): 
     882        db_table = f.rel.to._meta.db_table 
     883        if db_table not in cache_tables_seen: 
     884            tables.append(qn(db_table)) 
     885        else: # The table was already seen, so give it a table alias. 
     886            new_prefix = '%s%s' % (db_table, len(cache_tables_seen)) 
     887            tables.append('%s %s' % (qn(db_table), qn(new_prefix))) 
     888            db_table = new_prefix 
     889        cache_tables_seen.append(db_table) 
     890        where.append('%s.%s = %s.%s' % \ 
     891            (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) 
     892        select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) 
     893        fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, fields_to_join[f], cur_depth+1) 
    864894 
    865895def parse_lookup(kwarg_items, opts): 
    866896    # Helper function that handles converting API kwargs