Ticket #16436: 16436-annotate-select_related-only-r16522.diff

File 16436-annotate-select_related-only-r16522.diff, 3.1 KB (added by mrmachine, 4 years ago)

Failing test case and fix.

  • tests/regressiontests/defer_annotate_select_related/tests.py

     
     1from django.db.models import Count
     2from django.test import TestCase
     3from models import *
     4
     5class DeferAnnotateSelectRelatedTest(TestCase):
     6    def test(self):
     7        location = Location.objects.create()
     8        request = Request.objects.create(location=location)
     9        self.assertIsInstance(list(Request.objects
     10            .annotate(Count('items')).select_related('profile', 'location')
     11            .only('profile', 'location')), list)
     12        self.assertIsInstance(list(Request.objects
     13            .annotate(Count('items')).select_related('profile', 'location')
     14            .only('profile__profile1', 'location__location1')), list)
     15        self.assertIsInstance(list(Request.objects
     16            .annotate(Count('items')).select_related('profile', 'location')
     17            .defer('request1', 'request2', 'request3', 'request4')), list)
  • tests/regressiontests/defer_annotate_select_related/models.py

     
     1from django.db import models
     2
     3class Profile(models.Model):
     4    profile1 = models.TextField(default='profile1')
     5
     6class Location(models.Model):
     7    location1 = models.TextField(default='location1')
     8
     9class Item(models.Model):
     10    pass
     11
     12class Request(models.Model):
     13    profile = models.ForeignKey(Profile, null=True, blank=True)
     14    location = models.ForeignKey(Location)
     15    items = models.ManyToManyField(Item)
     16
     17    request1 = models.TextField(default='request1')
     18    request2 = models.TextField(default='request2')
     19    request3 = models.TextField(default='request3')
     20    request4 = models.TextField(default='request4')
  • django/db/models/sql/compiler.py

     
    717717                    row = self.resolve_columns(row, fields)
    718718
    719719                if has_aggregate_select:
    720                     aggregate_start = len(self.query.extra_select.keys()) + len(self.query.select)
     720                    loaded_fields = self.query.get_loaded_field_names().get(self.query.model, set()) or self.query.select
     721                    aggregate_start = len(self.query.extra_select.keys()) + len(loaded_fields)
    721722                    aggregate_end = aggregate_start + len(self.query.aggregate_select)
    722723                    row = tuple(row[:aggregate_start]) + tuple([
    723724                        self.query.resolve_aggregate(value, aggregate, self.connection)
Back to Top