Django

Code

Ticket #5020: query.py.3.diff

File query.py.3.diff, 5.4 kB (added by David Cramer <dcramer@gmail.com>, 1 year ago)

get_cached_row had a slipup

  • query.py

    old new  
    408408        else: 
    409409            return self._filter_or_exclude(None, **filter_obj) 
    410410 
    411     def select_related(self, true_or_false=True, depth=0): 
     411    def select_related(self, *fields, **kwargs): 
    412412        "Returns a new QuerySet instance with '_select_related' modified." 
    413         return self._clone(_select_related=true_or_false, _max_related_depth=depth) 
     413        true_or_false = kwargs.pop('true_or_false', True) 
     414        depth = kwargs.pop('depth', 0) 
     415        if not fields: fields = None 
     416        return self._clone(_select_related=true_or_false, _max_related_depth=depth, _recurse_fields=fields) 
    414417 
    415418    def order_by(self, *field_names): 
    416419        "Returns a new QuerySet instance with the ordering changed." 
     
    819822            raise NotImplementedError 
    820823    raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type) 
    821824 
    822 def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0): 
     825def get_cached_row(klass, row, index_start, fields=None, max_depth=0, cur_depth=0): 
    823826    """Helper function that recursively returns an object with cache filled""" 
    824827 
    825828    # If we've got a max_depth set and we've exceeded that depth, bail now. 
    826829    if max_depth and cur_depth > max_depth: 
    827830        return None 
    828831 
     832    fields_to_join = get_select_related_fields(klass._meta, fields) 
     833 
    829834    index_end = index_start + len(klass._meta.fields) 
    830835    obj = klass(*row[index_start:index_end]) 
    831     for f in klass._meta.fields: 
    832         if f.rel and not f.null: 
    833             cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1) 
    834             if cached_row: 
    835                 rel_obj, index_end = cached_row 
    836                 setattr(obj, f.get_cache_name(), rel_obj) 
     836    for f in fields_to_join.iterkeys(): 
     837       cached_row = get_cached_row(f.rel.to, row, index_end, fields_to_join[f], max_depth, cur_depth+1) 
     838       if cached_row: 
     839               rel_obj, index_end = cached_row 
     840               setattr(obj, f.get_cache_name(), rel_obj) 
    837841    return obj, index_end 
    838842 
    839 def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0): 
     843def get_select_related_fields(opts, fields=None): 
     844    """ 
     845    Helper function that returns a dictionary of fields for select_related() 
     846    """ 
     847    if fields is None: 
     848        fields_to_join = dict([(f, None) for f in opts.fields if f.rel and not f.null]) 
     849    else: 
     850        fields_for_lookup = dict([(f.name, f) for f in opts.fields if f.rel]) 
     851        fields_to_join = dict() 
     852        for f in fields: 
     853            path = f.split(LOOKUP_SEPARATOR) 
     854            try: 
     855                fn = fields_for_lookup[path[0]] 
     856                if fn not in fields_to_join: 
     857                    fields_to_join[fn] = [] 
     858                if len(path) > 1: 
     859                    fields_to_join[fn].append(LOOKUP_SEPARATOR.join(path[1:])) 
     860            except KeyError: 
     861                raise FieldDoesNotExist, '%s has no field named %s' % (opts.object_name, path[0]) 
     862    return fields_to_join 
     863 
     864def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, fields=None, cur_depth=0): 
    840865    """ 
    841866    Helper function that recursively populates the select, tables and where (in 
    842867    place) for select_related queries. 
     868 
     869    Implicit select_related calls on NULL fields will force an INNER JOIN currently. 
    843870    """ 
    844871 
    845872    # If we've got a max_depth set and we've exceeded that depth, bail now. 
    846     if max_depth and cur_depth > max_depth: 
     873    if max_depth and cur_depth >= max_depth: 
    847874        return None 
    848875 
     876    fields_to_join = get_select_related_fields(opts, fields) 
     877 
    849878    qn = backend.quote_name 
    850     for f in opts.fields: 
    851         if f.rel and not f.null: 
    852             db_table = f.rel.to._meta.db_table 
    853             if db_table not in cache_tables_seen: 
    854                 tables.append(qn(db_table)) 
    855             else: # The table was already seen, so give it a table alias. 
    856                 new_prefix = '%s%s' % (db_table, len(cache_tables_seen)) 
    857                 tables.append('%s %s' % (qn(db_table), qn(new_prefix))) 
    858                 db_table = new_prefix 
    859             cache_tables_seen.append(db_table) 
    860             where.append('%s.%s = %s.%s' % \ 
    861                 (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) 
    862             select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) 
    863             fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1) 
     879    for f in fields_to_join.iterkeys(): 
     880        db_table = f.rel.to._meta.db_table 
     881        if db_table not in cache_tables_seen: 
     882            tables.append(qn(db_table)) 
     883        else: # The table was already seen, so give it a table alias. 
     884            new_prefix = '%s%s' % (db_table, len(cache_tables_seen)) 
     885            tables.append('%s %s' % (qn(db_table), qn(new_prefix))) 
     886            db_table = new_prefix 
     887        cache_tables_seen.append(db_table) 
     888        where.append('%s.%s = %s.%s' % \ 
     889            (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) 
     890        select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) 
     891        fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, fields_to_join[f], cur_depth+1) 
    864892 
    865893def parse_lookup(kwarg_items, opts): 
    866894    # Helper function that handles converting API kwargs