Index: query.py
===================================================================
--- query.py	(revision 4300)
+++ query.py	(working copy)
@@ -80,6 +80,8 @@
         self._filters = Q()
         self._order_by = None        # Ordering, e.g. ('date', '-name'). If None, use model's ordering.
         self._select_related = False # Whether to fill cache for related objects.
+        self._recurse_depth = 0      # Used to track how deep we are following for select_related()
+        self._recurse_fields = []    # Fields to recurse through for select_related()
         self._distinct = False       # Whether the query should use SELECT DISTINCT.
         self._select = {}            # Dictionary of attname -> SQL.
         self._where = []             # List of extra WHERE clauses to use.
@@ -178,7 +180,7 @@
                 raise StopIteration
             for row in rows:
                 if fill_cache:
-                    obj, index_end = get_cached_row(self.model, row, 0)
+                    obj, index_end = get_cached_row(self.model, row, 0, self._recurse_fields, self._recurse_depth)
                 else:
                     obj = self.model(*row[:index_end])
                 for i, k in enumerate(extra_select):
@@ -194,12 +196,12 @@
         counter._select_related = False
         select, sql, params = counter._get_sql_clause()
         cursor = connection.cursor()
+        id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
+                backend.quote_name(self.model._meta.pk.column))
         if self._distinct:
-            id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
-                    backend.quote_name(self.model._meta.pk.column))
             cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params)
         else:
-            cursor.execute("SELECT COUNT(*)" + sql, params)
+            cursor.execute("SELECT COUNT(%s)" % id_col + sql, params)
         return cursor.fetchone()[0]
 
     def get(self, *args, **kwargs):
@@ -359,9 +361,13 @@
         else:
             return self._filter_or_exclude(None, **filter_obj)
 
-    def select_related(self, true_or_false=True):
+    # fields, if given, is a list of field names on the root model to follow; specifying it forces depth to 1.
+    # depth is the maximum number of relations to recurse through; 0 (the default) means unlimited.
+    def select_related(self, true_or_false=True, depth=0, fields=[]):
         "Returns a new QuerySet instance with '_select_related' modified."
-        return self._clone(_select_related=true_or_false)
+        if fields:
+            depth = 1
+        return self._clone(_select_related=true_or_false, _recurse_depth=depth, _recurse_fields=fields)
 
     def order_by(self, *field_names):
         "Returns a new QuerySet instance with the ordering changed."
@@ -395,6 +401,8 @@
         c._filters = self._filters
         c._order_by = self._order_by
         c._select_related = self._select_related
+        c._recurse_fields = self._recurse_fields
+        c._recurse_depth = self._recurse_depth
         c._distinct = self._distinct
         c._select = self._select.copy()
         c._where = self._where[:]
@@ -448,7 +456,7 @@
 
         # Add additional tables and WHERE clauses based on select_related.
         if self._select_related:
-            fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
+            fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table], self._recurse_depth, self._recurse_fields)
 
         # Add any additional SELECTs.
         if self._select:
@@ -660,24 +668,30 @@
         return backend.get_fulltext_search_sql(table_prefix + field_name)
     raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
 
-def get_cached_row(klass, row, index_start):
+def get_cached_row(klass, row, index_start, fields=(), max_depth=0, cur_depth=0):
     "Helper function that recursively returns an object with cache filled"
+    if max_depth and cur_depth > max_depth:
+        return None
     index_end = index_start + len(klass._meta.fields)
     obj = klass(*row[index_start:index_end])
     for f in klass._meta.fields:
-        if f.rel and not f.null:
-            rel_obj, index_end = get_cached_row(f.rel.to, row, index_end)
-            setattr(obj, f.get_cache_name(), rel_obj)
+        if f.rel and not f.null and (not fields or (cur_depth == 0 and f.name in fields)):
+            cached_row = get_cached_row(f.rel.to, row, index_end, fields, max_depth, cur_depth+1)
+            if cached_row:
+                rel_obj, index_end = cached_row
+                setattr(obj, f.get_cache_name(), rel_obj)
     return obj, index_end
 
-def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen):
+def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, fields=(), cur_depth=0):
     """
     Helper function that recursively populates the select, tables and where (in
     place) for select_related queries.
     """
     qn = backend.quote_name
+    if max_depth and cur_depth > max_depth:
+        return
     for f in opts.fields:
-        if f.rel and not f.null:
+        if f.rel and not f.null and (not fields or (cur_depth == 0 and f.name in fields)):
             db_table = f.rel.to._meta.db_table
             if db_table not in cache_tables_seen:
                 tables.append(qn(db_table))
@@ -689,7 +703,7 @@
             where.append('%s.%s = %s.%s' % \
                 (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
             select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
-            fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen)
+            fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, fields, cur_depth+1)
 
 def parse_lookup(kwarg_items, opts):
     # Helper function that handles converting API kwargs
