Opened 8 years ago

Closed 3 years ago

Last modified 3 years ago

#26602 closed New feature (needsinfo)

Provide a way to manage grouping with RawSQL

Reported by: Jamie Cockburn
Owned by: Manav Agarwal
Component: Database layer (models, ORM)
Version: dev
Severity: Normal
Keywords: QuerySet.extra
Cc: josh.smeaton@…
Triage Stage: Accepted
Has patch: no
Needs documentation: no
Needs tests: no
Patch needs improvement: no
Easy pickings: no
UI/UX: no

Description (last modified by Tobin Brown)

I wanted to annotate my objects with a running total:

from django.db import models
from django.db.models.expressions import RawSQL

class A(models.Model):
    amount = models.IntegerField()
    created = models.DateTimeField(auto_now_add=True)

qs = A.objects.annotate(total=RawSQL("SUM(amount) OVER (ORDER BY created)", []))

That works fine, and I get a running total for each object, but I cannot call count() on that queryset:

>>> qs.count()
Traceback...
ProgrammingError: window functions not allowed in GROUP BY clause

Using extra(), I can get the same annotation behaviour as well as being able to call count():

>>> qs = A.objects.extra(select={'total': 'SUM(amount) OVER (ORDER BY created)'})
>>> qs.count()
8

Change History (18)

comment:1 by Tobin Brown, 8 years ago

Description: modified (diff)

comment:2 by Josh Smeaton, 8 years ago

Cc: josh.smeaton@… added
Triage Stage: Unreviewed → Accepted
Version: 1.9 → master

This happens because RawSQL defines:

def get_group_by_cols(self):
    return [self]

It'd be helpful if there was a way to disable grouping in an easy way. For the moment, you should be able to do something like:

window = RawSQL("SUM(amount) OVER (ORDER BY created)", [])
window.get_group_by_cols = lambda: []
qs = A.objects.annotate(total=window)
qs.count()

# Or..

class RawAggregate(RawSQL):
    contains_aggregate = True  # may not be necessary in newer django versions
    def get_group_by_cols(self):
        return []

qs = A.objects.annotate(total=RawAggregate("SUM(amount) OVER (ORDER BY created)", []))

Can you let me know if this workaround helps you?

In the meantime, I think it's worth figuring out how we want to treat RawSQL when it should not be doing group bys, so I'm accepting on that basis.

Two options I see. First, we can define a whole new class as I've done above - like RawAggregate. Second, we could use a kwarg to RawSQL to tweak whether or not it contains aggregate code.

RawSQL('SELECT SUM() ..', [], grouping=False)
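
To make the mechanism concrete, here is a simplified, Django-free model of it: each annotation exposes a get_group_by_cols() hook, and whatever it returns ends up in GROUP BY. The names ToyRawSQL and build_group_by are hypothetical stand-ins, not Django classes, and the grouping keyword is only the proposal from this comment, not an existing RawSQL argument.

```python
class ToyRawSQL:
    """Hypothetical stand-in for a raw SQL expression (not Django's RawSQL)."""

    def __init__(self, sql, params, grouping=True):
        self.sql = sql
        self.params = params
        # Assumed flag mirroring the proposed RawSQL(..., grouping=False).
        self.grouping = grouping

    def get_group_by_cols(self):
        # By default the expression adds itself to GROUP BY, as quoted above.
        return [self] if self.grouping else []


def build_group_by(annotations):
    """Collect GROUP BY fragments from every annotation via the hook."""
    cols = []
    for expr in annotations:
        cols.extend(expr.get_group_by_cols())
    return ", ".join(col.sql for col in cols)


window_sql = "SUM(amount) OVER (ORDER BY created)"

# Default: the window expression leaks into GROUP BY, which the database rejects.
default_clause = build_group_by([ToyRawSQL(window_sql, [])])

# Opted out: nothing is contributed, so no GROUP BY entry is generated for it.
opted_out_clause = build_group_by([ToyRawSQL(window_sql, [], grouping=False)])

print(repr(default_clause))
print(repr(opted_out_clause))
```

Either option (subclass or kwarg) comes down to the same thing: making the hook return an empty list for expressions that must not be grouped on.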

comment:3 by Jamie Cockburn, 8 years ago

Yes, your suggestions did work as solutions, thank you.

Inspired by your RawAggregate subclass, I decided I'd have a go at writing a proper window function Expression. I was mostly guessing at how the API is supposed to work, so any hints/improvements would be welcome.

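from django.core.exceptions import FieldError
from django.db.models import F, Sum
from django.db.models.expressions import Expression, OrderBy
from django.utils import six  # available in the Django versions current at the time

class Window(Expression):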
    template = '%(expression)s OVER (%(window)s)'

    def __init__(self, expression, partition_by=None, order_by=None, output_field=None):
        self.order_by = order_by
        if isinstance(order_by, six.string_types):
            if order_by.startswith('-'):
                self.order_by = OrderBy(F(self.order_by[1:]), descending=True)
            else:
                self.order_by = OrderBy(F(self.order_by))

        self.partition_by = partition_by
        if self.partition_by:
            self.partition_by = self._parse_expressions(partition_by)[0]

        super(Window, self).__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]
        if not self.source_expression.contains_aggregate:
            raise FieldError("Window function expressions must be aggregate functions")

    def _resolve_output_field(self):
        if self._output_field is None:
            self._output_field = self.source_expression.output_field

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = self.copy()
        c.source_expression = c.source_expression.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        if c.partition_by:
            c.partition_by = c.partition_by.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        if c.order_by:
            c.order_by = c.order_by.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        c.is_summary = summarize
        c.for_save = for_save
        return c

    def as_sql(self, compiler, connection, function=None, template=None):
        connection.ops.check_expression_support(self)
        expr_sql, params = compiler.compile(self.source_expression)

        window_sql = []
        if self.partition_by:
            window_sql.append('PARTITION BY ')
            order_sql, order_params = compiler.compile(self.partition_by)
            window_sql.append(order_sql)
            params.extend(order_params)
        if self.order_by:
            window_sql.append(' ORDER BY ')
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            params.extend(order_params)
        template = template or self.template
        return template % {'expression': expr_sql, 'window': "".join(window_sql)}, params

    def copy(self):
        copy = super(Window, self).copy()
        copy.source_expression = self.source_expression.copy()
        copy.partition_by = self.partition_by.copy() if self.partition_by else None
        copy.order_by = self.order_by.copy() if self.order_by else None
        return copy

    def get_group_by_cols(self):
        return []

Which means you can write things like this:

A.objects.annotate(total=Window(Sum('amount'), order_by='created'))

comment:4 by Anssi Kääriäinen, 8 years ago

I guess we could add Window annotations. The problem is that you can't actually filter on window functions, you need a subquery for that. That is:

select SUM(amount) OVER (ORDER BY created) as sum_amount
  from ...
 where sum_amount > 0

is illegal SQL. We'd need to write the query as:

select * from (select SUM(amount) OVER (ORDER BY created) as sum_amount from ...) subq
 where sum_amount > 0

I don't think we have machinery to do that easily right now.
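
The subquery rewrite above can be tried out with Python's built-in sqlite3 module. This is only a sketch under some assumptions: the table a and its three rows are made up for illustration, SQLite stands in for whatever backend is actually in use, and the bundled SQLite must be 3.25 or newer for window functions to work at all. The direct filter is wrapped in try/except because backends are expected to reject it, and the exact error differs between databases.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE a (amount INTEGER, created INTEGER)")
conn.executemany(
    "INSERT INTO a (amount, created) VALUES (?, ?)",
    [(5, 1), (-10, 2), (20, 3)],  # illustrative data only
)

inner = "SELECT SUM(amount) OVER (ORDER BY created) AS sum_amount FROM a"

# Filtering directly on the window alias in WHERE: record whether it is accepted.
try:
    conn.execute(inner + " WHERE sum_amount > 0").fetchall()
    direct_filter_ok = True
except sqlite3.OperationalError:
    direct_filter_ok = False

# Wrapping the window query in a subquery makes the alias an ordinary column.
rows = conn.execute(
    "SELECT sum_amount FROM (" + inner + ") AS subq "
    "WHERE sum_amount > 0 ORDER BY sum_amount"
).fetchall()
running_totals = [row[0] for row in rows]

print(direct_filter_ok, running_totals)
```

The running totals for this data are 5, -5 and 15, so the outer filter keeps the first and last rows.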

I'm OK adding window aggregates to Django if they fail loudly and clearly when one tries to filter over them. They could be useful even if you can't actually filter on the result.

Maybe a new ticket is better for this than hijacking this ticket?

comment:5 by Jamie Cockburn, 8 years ago

I agree a new ticket is appropriate, as this has moved beyond my original problem.

Do you still intend to add the RawAggregate or RawSQL(..., grouping=False) solutions for this ticket?

Should I create the new ticket or is that your job?

comment:6 by Josh Smeaton, 8 years ago

Window functions were always something I intended to look into, so it's really cool to see that working!

This ticket will focus on providing a way to manage grouping on RawSQL. If you'd like to create a ticket to track actual Window expressions, that'd be awesome. Anyone can create new feature tickets so feel free to.

comment:7 by Jamie Cockburn, 8 years ago

#26608 created

comment:8 by Tim Graham, 8 years ago

Summary: RawSQL window functions break count() → Provide a way to manage grouping with RawSQL
Type: Uncategorized → New feature

comment:9 by Manav Agarwal, 3 years ago

Owner: changed from nobody to Manav Agarwal
Status: new → assigned

comment:10 by Manav Agarwal, 3 years ago (in reply to comment:2)

Replying to Josh Smeaton:

I think we could go with Josh's suggestion of adding a kwarg to RawSQL to tweak whether or not it contains aggregate code. Please let me know if this is fine so that I can submit a patch.
Regards

comment:11 by Manav Agarwal, 3 years ago

I think this bug has already been fixed. I tried to reproduce it, but no error was raised and the correct output was returned.

comment:12 by Claude Paroz, 3 years ago

Manav, if you think this is resolved, the best you can do is to find the commit where this has been fixed. Look for bisect examples in the docs if needed.

comment:13 by Manav Agarwal, 3 years ago

I wrote the following test in the BasicExpressionsTests class in expressions/test_regression.py

    def test_regression(self):
        companies = Company.objects.annotate(
            salary=RawSQL("SUM(num_chairs) OVER (ORDER BY num_employees)", [])
        )
        self.assertEqual(companies.count(), 3)

I found that the error no longer occurs as of commit https://github.com/django/django/commits/3f32154f40a855afa063095e3d091ce6be21f2c5, but it does occur in the code before that commit.

comment:14 by Claude Paroz, 3 years ago

Awesome! I'll let a fellow decide if it's worth adding a new regression test or if this issue can be closed right now.

comment:15 by Mariusz Felisiak, 3 years ago

Resolution: wontfix
Status: assigned → closed

Manav, thanks for checking.

With Simon's changes (3f32154f40a855afa063095e3d091ce6be21f2c5 & co.) we avoid unnecessary GROUP BY clauses. This ticket is about adding a way to manage grouping with RawSQL(), not about this specific crash. However, I think we can close it as "wontfix" unless someone can provide a valid use case.


comment:16 by Mariusz Felisiak, 3 years ago

I prepared a PR with extra tests.

comment:17 by Mariusz Felisiak, 3 years ago

Resolution: wontfix → needsinfo

comment:18 by GitHub <noreply@…>, 3 years ago

In b989d213:

Refs #26602 -- Added tests for aggregating over a RawSQL() annotation.

Fixed in 3f32154f40a855afa063095e3d091ce6be21f2c5.

Thanks Manav Agarwal for initial test.
