How to filter a Django QuerySet by a window function

The Django ORM API is well designed and works well in most cases. It can also be extended with custom QuerySet and Manager classes. But as soon as you need a query that Django isn't suited for, the ORM stops being an aid and becomes an obstacle. Fortunately, Django lets you run raw SQL queries for such cases.

One kind of query that Django can't execute is a filter by a window function. In raw SQL you can do this by selecting from a subquery, but Django (as of the version tested here) doesn't support this.


In this post, I will explain why you would need such a filter and how to work around the problem by wrapping a given QuerySet in a raw SQL query.

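All the code in this article was tested in Django 3.1. Django 4.2 and later can filter against window functions directly in most cases, so the workaround described here matters mainly for older versions.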

TL;DR

Given a QuerySet, annotated with a window function:

from django.db import models
from django.db.models.functions import window

from catalog.models import Product

products = Product.objects.all().annotate(row_number=models.Window(
    expression=window.RowNumber(),
    partition_by=[models.F('category_id')],
    order_by=[models.F('price').asc()],
)).order_by('price')
Pay attention to the row_number=models.Window(...) part

We can extract raw SQL out of this QuerySet like this:

sql, params = products.query.sql_with_params()

To wrap in a query of our own:

count_per_category = 3
count = 10
products_filtered = Product.objects.raw("""
        SELECT * FROM ({}) products_with_row_numbers
        WHERE row_number <= %s
        LIMIT %s
    """.format(sql),
    [*params, count_per_category, count],
)

selected_products = list(products_filtered)
We put QuerySet's sql inside our SQL query and use QuerySet's params to execute it

Now let me elaborate.

The problem: why would you need to filter by a window function

Let's say we're working on an online shop. Its catalog is organized into categories:

class Category(models.Model):
    name = models.CharField(max_length=255)

Which are populated with products:

class Product(models.Model):
    category = models.ForeignKey(Category, on_delete=models.CASCADE)
    name = models.CharField(max_length=255)
    price = models.DecimalField(max_digits=12, decimal_places=2)

We're given a task to display a selection of the 10 most inexpensive products in the catalog. Easy one, right? It's a simple query:

products = Product.objects.all().order_by('price')[:10]

ID   Category     Name                Price
101  Accessories  A red phone case    $10
102  Accessories  A blue phone case   $15
103  Accessories  A black phone case  $16
104  Accessories  A gray phone case   $16
...  ...          ...                 ...

But here's the thing: the selection must be representative of the shop's diverse assortment. It would not be helpful if the selection contained just 10 accessories; we would like to display products from pricier categories too. So a new condition is imposed on the selection: it must contain no more than 3 products from a single category.

The selection is expected to look somewhat like this:

ID   Category     Name                            Price
101  Accessories  A red phone case                $10
102  Accessories  A blue phone case               $15
103  Accessories  A black phone case              $16
201  Smartphones  Simple Android phone            $200
202  Smartphones  Simple Android phone (with 4G)  $220
...  ...          ...                             ...
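
To make the rule concrete before turning to SQL, here is a minimal pure-Python sketch of the selection we are after. The function name select_cheapest and the dictionary layout are made up for illustration; they are not part of the shop's code.

```python
from collections import Counter


def select_cheapest(products, per_category=3, total=10):
    """Pick up to `total` cheapest products, at most `per_category` per category."""
    taken = Counter()  # how many products were already picked from each category
    selection = []
    # sorted() is stable, so products with equal prices keep their input order
    for product in sorted(products, key=lambda p: p["price"]):
        if taken[product["category"]] >= per_category:
            continue  # this category has already reached its quota
        taken[product["category"]] += 1
        selection.append(product)
        if len(selection) == total:
            break
    return selection


# Illustrative data mirroring the table above
catalog = [
    {"id": 101, "category": "Accessories", "price": 10},
    {"id": 102, "category": "Accessories", "price": 15},
    {"id": 103, "category": "Accessories", "price": 16},
    {"id": 104, "category": "Accessories", "price": 16},
    {"id": 201, "category": "Smartphones", "price": 200},
    {"id": 202, "category": "Smartphones", "price": 220},
]

selected_ids = [p["id"] for p in select_cheapest(catalog)]
print(selected_ids)
```

Doing this in Python means loading the whole catalog into memory, which is why the rest of the article pushes the same logic into the database.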

If you're using a suitable DBMS, you can solve this with a window function. Window functions are a powerful SQL feature, but explaining them in depth is outside the scope of this article, so I will only show what the query looks like:

from django.db import models
from django.db.models.functions import window

products = Product.objects.all().annotate(row_number=models.Window(
    expression=window.RowNumber(),
    partition_by=[models.F('category_id')],
    order_by=[models.F('price').asc()],
)).order_by('price')

For PostgreSQL this QuerySet will produce the following SQL:

SELECT
    "catalog_product"."id",
    "catalog_product"."category_id",
    "catalog_product"."name",
    "catalog_product"."price",
    ROW_NUMBER() OVER (
        PARTITION BY "catalog_product"."category_id"
        ORDER BY "catalog_product"."price" ASC
    ) AS "row_number"
FROM
    "catalog_product"
ORDER BY
    "catalog_product"."price" ASC

In this query, we add a row_number attribute to each product. It's the ordinal number of the product in the list of its category's products sorted by price, cheapest first. So the cheapest phone case will have row_number set to 1, the next case will have a row_number of 2, and so on. Likewise, the cheapest smartphone will have a row_number of 1, the next one a row_number of 2, and so on.
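
If you'd like to watch the numbering restart per category without setting up a Django project, the sketch below runs an equivalent query with Python's built-in sqlite3 module. It assumes the bundled SQLite is 3.25 or newer (the first version with window functions); the table name, the sample rows and the extra id tie-breaker in the ordering are my own additions for a deterministic result.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE product ("
    "id INTEGER PRIMARY KEY, category_id INTEGER, name TEXT, price INTEGER)"
)
conn.executemany(
    "INSERT INTO product VALUES (?, ?, ?, ?)",
    [
        (101, 1, "A red phone case", 10),
        (102, 1, "A blue phone case", 15),
        (103, 1, "A black phone case", 16),
        (104, 1, "A gray phone case", 16),
        (201, 2, "Simple Android phone", 200),
        (202, 2, "Simple Android phone (with 4G)", 220),
    ],
)

# The numbering restarts at 1 inside every category_id partition
rows = conn.execute(
    """
    SELECT id, category_id, price,
           ROW_NUMBER() OVER (
               PARTITION BY category_id
               ORDER BY price ASC, id ASC
           ) AS row_number
    FROM product
    ORDER BY price ASC, id ASC
    """
).fetchall()

for row in rows:
    print(row)
```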

So, is the problem solved? We just need to add a filter by row_number to exclude the 4th, 5th, etc. cheapest products of each category from the result. Yet if we attempt such a filter:

products = products.filter(row_number__lte=3)

We will get an exception:

django.db.utils.NotSupportedError: Window is disallowed in the filter clause.

Neither Django nor PostgreSQL supports filtering rows by the value of a window function within the same query! Fortunately, in PostgreSQL we can work around this by wrapping our query in a subquery and putting the WHERE clause in the outer query:

SELECT *
FROM
    (
        SELECT
            "catalog_product"."id",
            "catalog_product"."category_id",
            "catalog_product"."name",
            "catalog_product"."price",
            ROW_NUMBER() OVER (
                PARTITION BY "catalog_product"."category_id"
                ORDER BY "catalog_product"."price" ASC
            ) AS "row_number"
        FROM
            "catalog_product"
        ORDER BY
            "catalog_product"."price" ASC
    ) subquery
WHERE subquery.row_number <= 3
We calculate the window function in the inner query, and filter by it in the outer query

ID   Category     Name                            Price  Row Number
101  Accessories  A red phone case                $10    1
102  Accessories  A blue phone case               $15    2
103  Accessories  A black phone case              $16    3
201  Smartphones  Simple Android phone            $200   1
202  Smartphones  Simple Android phone (with 4G)  $220   2
...  ...          ...                             ...    ...
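
The same behaviour can be reproduced outside PostgreSQL. The sketch below uses Python's built-in sqlite3 module (assuming SQLite 3.25 or newer) with made-up sample rows: the direct filter on the window function is rejected, while the wrapped query returns at most three products per category. The outer ORDER BY is my addition to keep the output deterministic.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE product (id INTEGER PRIMARY KEY, category_id INTEGER, price INTEGER)"
)
conn.executemany(
    "INSERT INTO product VALUES (?, ?, ?)",
    [(101, 1, 10), (102, 1, 15), (103, 1, 16), (104, 1, 16), (201, 2, 200), (202, 2, 220)],
)

# 1. A window function directly in WHERE is refused by the database
try:
    conn.execute(
        "SELECT id FROM product "
        "WHERE ROW_NUMBER() OVER (PARTITION BY category_id ORDER BY price) <= 3"
    )
    direct_filter_failed = False
except sqlite3.OperationalError as error:
    direct_filter_failed = True
    print("Direct filter failed:", error)

# 2. Compute the window function in the inner query, filter in the outer one
inner_sql = """
    SELECT id, category_id, price,
           ROW_NUMBER() OVER (
               PARTITION BY category_id
               ORDER BY price ASC, id ASC
           ) AS row_number
    FROM product
"""
rows = conn.execute(
    f"""
    SELECT id, row_number
    FROM ({inner_sql}) AS subquery
    WHERE subquery.row_number <= 3
    ORDER BY price ASC, id ASC
    """
).fetchall()
print(rows)
```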

How would we write this query using the Django ORM? Unfortunately, I don't think it is possible with the ORM API. I will describe a way to do it using raw SQL. It's not very portable, but for projects that use a single DBMS, it should be a good enough solution.

Writing the filter with minimum SQL

If you have previously debugged Django's queries, you might know about the QuerySet.query attribute, which you might have used to print the SQL of a query:

print(products.query)

QuerySet.query is an instance of django.db.models.sql.Query, and it's important to know that str(products.query) does not always return valid SQL.

As stated in the docstring of Query.__str__:

def __str__(self):
    """
    Return the query as a string of SQL with the parameter values
    substituted in (use sql_with_params() to see the unsubstituted string).

    Parameter values won't necessarily be quoted correctly, since that is
    done by the database interface at execution time.
    """
"Parameter values won't necessarily be quoted correctly"

We should use Query.sql_with_params instead. It's a method that returns the raw SQL and its parameters separately.

sql, params = products.query.sql_with_params()
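
To see why a string with the values already substituted is unreliable, here is a small stand-alone sketch using Python's built-in sqlite3 module. It is only an analogy for what the docstring warns about: the table and the value are invented, and sqlite3 uses ? placeholders where Django's raw SQL uses %s.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE product (name TEXT)")
conn.execute("INSERT INTO product VALUES (?)", ("O'Reilly phone case",))

name = "O'Reilly phone case"

# Naive substitution: the value lands in the SQL text without any quoting
substituted_sql = "SELECT name FROM product WHERE name = %s" % name
try:
    conn.execute(substituted_sql)
    substituted_ok = True
except sqlite3.OperationalError as error:
    substituted_ok = False
    print("Substituted SQL is not valid:", error)

# SQL and parameters kept separate: the database driver does the quoting
rows = conn.execute("SELECT name FROM product WHERE name = ?", (name,)).fetchall()
print(rows)
```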

We can then modify the SQL as we'd like and substitute its parameters only when we need to execute it.

products_filtered = Product.objects.raw("""
        SELECT * FROM ({}) products_with_row_numbers
        WHERE row_number <= 3
        LIMIT 10
    """.format(sql),
    params,
)

ID   Category     Name                            Price  Row Number
101  Accessories  A red phone case                $10    1
102  Accessories  A blue phone case               $15    2
103  Accessories  A black phone case              $16    3
201  Smartphones  Simple Android phone            $200   1
202  Smartphones  Simple Android phone (with 4G)  $220   2
...  ...          ...                             ...    ...

Let's replace the hardcoded 3 and 10 with parameters. We have to make sure to put them in the correct order relative to the parameters of the query that we're wrapping: the inner query's parameters come first, because its placeholders appear first in the final SQL.

count_per_category = 3
count = 10
products_filtered = Product.objects.raw("""
        SELECT * FROM ({}) products_with_row_numbers
        WHERE row_number <= %s
        LIMIT %s
    """.format(sql),
    [*params, count_per_category, count],
)
Notice how we substitute the SQL using the {} placeholder and the parameters – using the %s placeholder. 
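
The parameter ordering is easy to get wrong, so here is a sketch that isolates it in a small helper. The name wrap_with_row_number_filter is hypothetical, and the inner SQL string is a hand-written stand-in for what sql_with_params() would return for a QuerySet with one filter.

```python
def wrap_with_row_number_filter(sql, params, count_per_category, count):
    """Wrap an inner query and append the outer parameters after the inner ones."""
    wrapped_sql = (
        "SELECT * FROM ({}) products_with_row_numbers "
        "WHERE row_number <= %s "
        "LIMIT %s"
    ).format(sql)
    # Placeholders are filled left to right: inner query first, then ours
    return wrapped_sql, [*params, count_per_category, count]


# Stand-in for the (sql, params) pair of an annotated and filtered QuerySet
inner_sql = (
    "SELECT *, ROW_NUMBER() OVER ("
    "PARTITION BY category_id ORDER BY price) AS row_number "
    "FROM catalog_product WHERE price > %s"
)
inner_params = (5,)

wrapped_sql, wrapped_params = wrap_with_row_number_filter(
    inner_sql, inner_params, count_per_category=3, count=10
)
print(wrapped_sql)
print(wrapped_params)
```

The resulting pair would then be passed on as Product.objects.raw(wrapped_sql, wrapped_params).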

That's it. products_filtered is a RawQuerySet which, when iterated, yields the selection of products that we need.

Here's the full code:

from django.db import models
from django.db.models.functions import window

from catalog.models import Product

products = Product.objects.all().annotate(row_number=models.Window(
    expression=window.RowNumber(),
    partition_by=[models.F('category_id')],
    order_by=[models.F('price').asc()],
)).order_by('price')

sql, params = products.query.sql_with_params()

count_per_category = 3
count = 10
products_filtered = Product.objects.raw("""
        SELECT * FROM ({}) products_with_row_numbers
        WHERE row_number <= %s
        LIMIT %s
    """.format(sql),
    [*params, count_per_category, count],
)

selected_products = list(products_filtered)

Bonus: how to keep the same ordering in the window function as in the QuerySet

You might have noticed that we have to specify the ordering twice: in the window function and in the QuerySet itself:

products = Product.objects.all().annotate(row_number=models.Window(
    expression=window.RowNumber(),
    partition_by=[models.F('category_id')],
    order_by=[models.F('price').asc()],  # <- here
)).order_by('price')                     # <- and here

But what if we don't have access to the original ordering of the QuerySet? For example, if the QuerySet has come from some other part of the project.

def annotate_row_number(products):
    return products.annotate(row_number=models.Window(
        expression=window.RowNumber(),
        partition_by=[models.F('category_id')],
        order_by=["what to put here?"],
    ))

It's possible to generate the ordering for the window function programmatically from a QuerySet. The query's ordering is stored in the QuerySet.query.order_by attribute, a tuple holding the values exactly as they were passed to the order_by method (plain strings in the examples here):

products = Product.objects.all().order_by('price')
print(products.query.order_by)  # ('price',)

products = Product.objects.all().order_by('-price', 'name')
print(products.query.order_by)  # ('-price', 'name')

products = Product.objects.all()
# empty tuple - ordering is not specified
print(products.query.order_by)  # ()

We have to transform this tuple into a list of expression objects, as required by models.Window. Note that descending order is denoted differently in the two places: '-price' in products.query.order_by becomes models.F('price').desc() in models.Window, without the dash.

Here is my code to prepare the ordering for the window function:

order_by = []
for field in products.query.order_by:
    if field.startswith('-'):
        # Descending order
        desc = True
        field = field[1:]  # remove the dash
    else:
        # Ascending order
        desc = False

    order_field = models.F(field)
    if desc:
        order_field = order_field.desc()
    else:
        order_field = order_field.asc()

    order_by.append(order_field)

The resulting list order_by can then be passed to models.Window:

products = products.annotate(row_number=models.Window(
    expression=window.RowNumber(),
    partition_by=[models.F('category_id')],
    order_by=order_by,
))
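
The dash handling can also be checked in isolation, without Django. The sketch below is a hypothetical helper, split_ordering, that turns order_by strings into (field name, descending) pairs; each pair would then map to models.F(name).desc() or models.F(name).asc(). It assumes every entry is a plain string and raises TypeError otherwise, since expression objects passed to order_by would need separate handling.

```python
def split_ordering(order_by):
    """Turn order_by strings such as '-price' into (field_name, descending) pairs."""
    pairs = []
    for field in order_by:
        if not isinstance(field, str):
            # Assumption: only string orderings are supported by this sketch
            raise TypeError(f"Unsupported ordering entry: {field!r}")
        descending = field.startswith("-")
        name = field[1:] if descending else field  # drop the leading dash
        pairs.append((name, descending))
    return pairs


print(split_ordering(("-price", "name")))
print(split_ordering(()))
```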

Conclusion

Thank you for reading this post. In case you have experienced a similar problem before and you can share it, or if you can correct me or build a better solution upon mine, I would like to see it in the comments below.