Ticket #5020: query.py.2.diff

File query.py.2.diff, 5.5 KB (added by David Cramer <dcramer@…>, 17 years ago)

fixed two minor bugs

  • query.py

    old new  
    408408        else:
    409409            return self._filter_or_exclude(None, **filter_obj)
    410410
    411     def select_related(self, true_or_false=True, depth=0):
     411    def select_related(self, *fields, **kwargs):
    412412        "Returns a new QuerySet instance with '_select_related' modified."
    413         return self._clone(_select_related=true_or_false, _max_related_depth=depth)
     413        true_or_false = kwargs.pop('true_or_false', True)
     414        depth = kwargs.pop('depth', 0)
     415        if not fields: fields = None
     416        return self._clone(_select_related=true_or_false, _max_related_depth=depth, _recurse_fields=fields)
    414417
    415418    def order_by(self, *field_names):
    416419        "Returns a new QuerySet instance with the ordering changed."
     
    819822            raise NotImplementedError
    820823    raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
    821824
    822 def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0):
     825def get_cached_row(klass, row, index_start, fields=None, max_depth=0, cur_depth=0):
    823826    """Helper function that recursively returns an object with cache filled"""
    824827
    825828    # If we've got a max_depth set and we've exceeded that depth, bail now.
    826     if max_depth and cur_depth > max_depth:
     829    if max_depth and cur_depth >= max_depth:
    827830        return None
    828831
     832    fields_to_join = get_select_related_fields(klass._meta, fields)
     833
    829834    index_end = index_start + len(klass._meta.fields)
    830835    obj = klass(*row[index_start:index_end])
    831     for f in klass._meta.fields:
    832         if f.rel and not f.null:
    833             cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1)
    834             if cached_row:
    835                 rel_obj, index_end = cached_row
    836                 setattr(obj, f.get_cache_name(), rel_obj)
     836    for f in fields_to_join.iterkeys():
     837       cached_row = get_cached_row(f.rel.to, row, index_end, fields_to_join[f], max_depth, cur_depth+1)
     838       if cached_row:
     839               rel_obj, index_end = cached_row
     840               setattr(obj, f.get_cache_name(), rel_obj)
    837841    return obj, index_end
    838842
    839 def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0):
     843def get_select_related_fields(opts, fields=None):
     844    """
     845    Helper function that returns a dictionary of fields for select_related()
     846    """
     847    if fields is None:
     848        fields_to_join = dict([(f, None) for f in opts.fields if f.rel and not f.null])
     849    else:
     850        fields_for_lookup = dict([(f.name, f) for f in opts.fields if f.rel])
     851        fields_to_join = dict()
     852        for f in fields:
     853            path = f.split(LOOKUP_SEPARATOR)
     854            try:
     855                fn = fields_for_lookup[path[0]]
     856                if fn not in fields_to_join:
     857                    fields_to_join[fn] = []
     858                if len(path) > 1:
     859                    fields_to_join[fn].append(LOOKUP_SEPARATOR.join(path[1:]))
     860            except KeyError:
     861                raise FieldDoesNotExist, '%s has no field named %s' % (opts.object_name, path[0])
     862    return fields_to_join
     863
     864def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, fields=None, cur_depth=0):
    840865    """
    841866    Helper function that recursively populates the select, tables and where (in
    842867    place) for select_related queries.
     868
     869    Implicit select_related calls on NULL fields will force an INNER JOIN currently.
    843870    """
    844871
    845872    # If we've got a max_depth set and we've exceeded that depth, bail now.
    846     if max_depth and cur_depth > max_depth:
     873    if max_depth and cur_depth >= max_depth:
    847874        return None
    848875
     876    fields_to_join = get_select_related_fields(opts, fields)
     877
    849878    qn = backend.quote_name
    850     for f in opts.fields:
    851         if f.rel and not f.null:
    852             db_table = f.rel.to._meta.db_table
    853             if db_table not in cache_tables_seen:
    854                 tables.append(qn(db_table))
    855             else: # The table was already seen, so give it a table alias.
    856                 new_prefix = '%s%s' % (db_table, len(cache_tables_seen))
    857                 tables.append('%s %s' % (qn(db_table), qn(new_prefix)))
    858                 db_table = new_prefix
    859             cache_tables_seen.append(db_table)
    860             where.append('%s.%s = %s.%s' % \
    861                 (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
    862             select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
    863             fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1)
     879    for f in fields_to_join.iterkeys():
     880        db_table = f.rel.to._meta.db_table
     881        if db_table not in cache_tables_seen:
     882            tables.append(qn(db_table))
     883        else: # The table was already seen, so give it a table alias.
     884            new_prefix = '%s%s' % (db_table, len(cache_tables_seen))
     885            tables.append('%s %s' % (qn(db_table), qn(new_prefix)))
     886            db_table = new_prefix
     887        cache_tables_seen.append(db_table)
     888        where.append('%s.%s = %s.%s' % \
     889            (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
     890        select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
     891        fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, fields_to_join[f], cur_depth+1)
    864892
    865893def parse_lookup(kwarg_items, opts):
    866894    # Helper function that handles converting API kwargs
Back to Top