Avoiding N+1 queries in Strawberry GraphQL with DataLoaders

published 2024-05-28 [ home ]

If you prefer, the code for this article is also available in a gist.

GraphQL is great because it lets consumers define the shape of their requests; however, that makes it somewhat harder to optimize than typical REST requests. If you implement GraphQL naively, you will quickly run into the infamous N+1 issue. Let me demonstrate the issue with an example project using Quart, Quart-DB, PostgreSQL and Strawberry.

Let us start by defining a simple schema for a music collection and populating the database with a migration:

from quart_db import Connection


async def create_schema(cnx: Connection) -> None:
    """Create the ``bands``, ``albums`` and ``songs`` tables.

    Args:
        cnx: Open database connection on which the DDL is executed.
    """
    # The foreign keys are NOT NULL because the GraphQL types expose
    # ``band_id`` and ``album_id`` as non-nullable ints; a NULL row would make
    # the resolvers fail. PostgreSQL does not index referencing columns
    # automatically, so the columns the resolvers filter on are indexed
    # explicitly.
    await cnx.execute(
        """
            CREATE TABLE bands (
                id bigint PRIMARY KEY,
                name text NOT NULL
            );
            CREATE TABLE albums (
                id bigint PRIMARY KEY,
                name text NOT NULL,
                band_id bigint NOT NULL REFERENCES bands(id)
            );
            CREATE TABLE songs (
                id bigint GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
                name text NOT NULL,
                album_id bigint NOT NULL REFERENCES albums(id)
            );
            CREATE INDEX albums_band_id_idx ON albums (band_id);
            CREATE INDEX songs_album_id_idx ON songs (album_id);
        """
    )


async def populate(cnx: Connection) -> None:
    """Insert the sample bands, albums and songs.

    Band and album IDs are given explicitly; song IDs are left to the
    identity column.

    Args:
        cnx: Open database connection on which the inserts are executed.
    """
    seed_sql = """
            INSERT INTO bands (id, name) VALUES
                (1, 'Dark Tranquillity'),
                (2, 'Pineapple Thief'),
                (3, 'Wintersun');

            INSERT INTO albums (id, name, band_id) VALUES
                (1, 'Haven', 1),
                (2, 'Fiction', 1),
                (3, 'Atoma', 1),
                (4, 'Time I', 3),
                (5, 'Your Wilderness', 2),
                (6, 'Versions of the Truth', 2);

            INSERT INTO songs (name, album_id) VALUES
                ('The Wonders at Your Feet', 1),
                ('Not Built to Last', 1),
                ('Indifferent Suns', 1),
                ('At Loss for Words', 1),
                ('Terminus', 2),
                ('Inside the Particle Storm', 2),
                ('Focus Shift', 2),
                ('Forward Momentum', 3),
                ('Caves and Embers', 3),
                ('When Mountains Fall', 4),
                ('Sons of Winter and Stars', 4),
                ('Land of Snow and Sorrow', 4),
                ('Time', 4),
                ('The Final Thing on My Mind', 5),
                ('Tear You Up', 5);
        """
    await cnx.execute(seed_sql)


async def migrate(cnx: Connection) -> None:
    """Run the migration: create the tables, then load the sample rows.

    Args:
        cnx: Open database connection passed to each migration step.
    """
    # Order matters: the inserts need the tables to exist.
    for step in (create_schema, populate):
        await step(cnx)

We expose bands, albums and songs; now let us build a GraphQL interface to query them the naive way:

import logging
from typing import Any

import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.quart.views import GraphQLView as QuartGraphQLView

from quart import Quart, Request, Response

# Application object; INFO level so the resolvers' log lines are emitted.
app = Quart("sfdl")
app.logger.setLevel(logging.INFO)

# Resolvers open connections explicitly through ``db.connection()``.
app.config.update(
    QUART_DB_DATABASE_URL="postgresql://postgres@localhost/cwl_sfdl",
    QUART_DB_AUTO_REQUEST_CONNECTION=False,
)
db = QuartDB(app)

cors(app, allow_origin="*", allow_methods=["GET", "POST"])


@strawberry.type
class Song:
    """GraphQL type built from one row of the ``songs`` table."""

    id: int
    name: str
    album_id: int  # ID of the album the song belongs to


@strawberry.type
class Album:
    """GraphQL type built from one row of the ``albums`` table."""

    id: int
    name: str
    band_id: int

    @strawberry.field
    async def songs(self) -> list[Song]:
        """Fetch this album's songs with a dedicated query.

        Naive resolver: one query is issued for every album being resolved.
        """
        sql = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = :album_id
        """
        params = {"album_id": self.id}
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql, params)
        songs = [Song(**record) for record in rows]
        app.logger.info(f"Got {len(songs)} songs.")
        return songs


@strawberry.type
class Band:
    """GraphQL type built from one row of the ``bands`` table."""

    id: int
    name: str

    @strawberry.field
    async def albums(self) -> list[Album]:
        """Fetch this band's albums with a dedicated query.

        Naive resolver: one query is issued for every band being resolved.
        """
        sql = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = :band_id
        """
        params = {"band_id": self.id}
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql, params)
        albums = [Album(**record) for record in rows]
        app.logger.info(f"Got {len(albums)} albums.")
        return albums


@strawberry.type
class Query:
    """Root query type of the schema."""

    @strawberry.field
    async def bands(self) -> list[Band]:
        """Return every band stored in the ``bands`` table."""
        sql = """
            SELECT id, name
            FROM bands
        """
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql)
        bands = [Band(**record) for record in rows]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands


class GraphQLView(QuartGraphQLView):
    """Quart GraphQL view that supplies the resolver context."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Return the context dict holding the current request and response."""
        return dict(request=request, response=response)


# Build the schema from the root query type and wrap it in a view with the
# GraphiQL IDE enabled.
schema = strawberry.Schema(query=Query)
view = GraphQLView.as_view("graphql_view", schema=schema, graphql_ide="graphiql")
app.add_url_rule("/", view_func=view)

We can try it with GraphiQL, to check that it works:

music collection request in GraphiQL

It does, but if we look at the logs we see this:

INFO in __init__: Got 3 bands.
INFO in __init__: Got 2 albums.
INFO in __init__: Got 1 albums.
INFO in __init__: Got 3 albums.
INFO in __init__: Got 2 songs.
INFO in __init__: Got 4 songs.
INFO in __init__: Got 0 songs.
INFO in __init__: Got 4 songs.
INFO in __init__: Got 2 songs.
INFO in __init__: Got 3 songs.
[59520] [INFO] 127.0.0.1:52944 POST / 1.1 200 806 5311

It makes too many queries: first it gets all the bands, then for each band it gets the albums, then for each album it gets the songs… That is the N+1 problem.

If you look it up online, you will quickly find out that a popular solution is dataloaders. However, a lot of the examples you will find do not correspond exactly to that problem, because they only demonstrate how to use dataloaders on primary keys, whereas here we want to load on a foreign key. This is the case with the official Strawberry documentation, which you should still read if you use that library.

The important thing to understand with dataloaders is that they take a list of keys and return a list of answers of the same size, corresponding to those keys. So you cannot, for instance, have a dataloader that takes a list of album IDs and returns a list of all the songs in those albums. However, the trick is that you can have a dataloader that returns a list of lists of songs corresponding to each album!

Now let us rewrite our example with that pattern:

import logging
from collections import defaultdict
from functools import cached_property
from typing import Any

import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.dataloader import DataLoader
from strawberry.quart.views import GraphQLView as QuartGraphQLView
from strawberry.types import Info

from quart import Quart, Request, Response

# Application object; INFO level so the loaders' log lines are emitted.
app = Quart("sfdl")
app.logger.setLevel(logging.INFO)

# Queries open connections explicitly through ``db.connection()``.
app.config.update(
    QUART_DB_DATABASE_URL="postgresql://postgres@localhost/cwl_sfdl",
    QUART_DB_AUTO_REQUEST_CONNECTION=False,
)
db = QuartDB(app)

cors(app, allow_origin="*", allow_methods=["GET", "POST"])


@strawberry.type
class Song:
    """GraphQL type built from one row of the ``songs`` table."""

    id: int
    name: str
    album_id: int  # ID of the album the song belongs to; songs are grouped by it


@strawberry.type
class Album:
    """GraphQL type built from one row of the ``albums`` table."""

    id: int
    name: str
    band_id: int

    @strawberry.field
    async def songs(self, info: Info) -> list[Song]:
        """Resolve this album's songs through the dataloaders in the context."""
        loaders = info.context["dataloaders"]
        return await loaders.songs_for_albums.load(self.id)


@strawberry.type
class Band:
    """GraphQL type built from one row of the ``bands`` table."""

    id: int
    name: str

    @strawberry.field
    async def albums(self, info: Info) -> list[Album]:
        """Resolve this band's albums through the dataloaders in the context."""
        loaders = info.context["dataloaders"]
        return await loaders.albums_for_bands.load(self.id)


@strawberry.type
class Query:
    """Root query type of the schema."""

    @strawberry.field
    async def bands(self) -> list[Band]:
        """Return every band stored in the ``bands`` table."""
        sql = """
            SELECT id, name
            FROM bands
        """
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql)
        bands = [Band(**record) for record in rows]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands


class DataLoaders:
    """Dataloaders that batch child lookups by foreign key.

    Each batch function receives a list of parent IDs and returns one list of
    children per ID, in the same order as the IDs.
    """

    @staticmethod
    async def load_songs_for_albums(keys: list[int]) -> list[list[Song]]:
        """Load the songs of several albums with a single query.

        Args:
            keys: Album IDs to load songs for.

        Returns:
            One list of songs per entry of ``keys``, in the same order. An
            album without songs gets an empty list.
        """
        sql = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = ANY(:keys)
        """
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql, {"keys": keys})
        songs = [Song(**record) for record in rows]
        app.logger.info(f"Got {len(songs)} songs.")
        grouped: dict[int, list[Song]] = {}
        for song in songs:
            grouped.setdefault(song.album_id, []).append(song)
        # setdefault also covers requested albums that have no song at all.
        return [grouped.setdefault(key, []) for key in keys]

    @staticmethod
    async def load_albums_for_bands(keys: list[int]) -> list[list[Album]]:
        """Load the albums of several bands with a single query.

        Args:
            keys: Band IDs to load albums for.

        Returns:
            One list of albums per entry of ``keys``, in the same order. A
            band without albums gets an empty list.
        """
        sql = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = ANY(:keys)
        """
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql, {"keys": keys})
        albums = [Album(**record) for record in rows]
        app.logger.info(f"Got {len(albums)} albums.")
        grouped: dict[int, list[Album]] = {}
        for album in albums:
            grouped.setdefault(album.band_id, []).append(album)
        # setdefault also covers requested bands that have no album at all.
        return [grouped.setdefault(key, []) for key in keys]

    @cached_property
    def songs_for_albums(self) -> DataLoader[int, list[Song]]:
        """Dataloader mapping an album ID to its songs, created on first access."""
        return DataLoader(load_fn=self.load_songs_for_albums)

    @cached_property
    def albums_for_bands(self) -> DataLoader[int, list[Album]]:
        """Dataloader mapping a band ID to its albums, created on first access."""
        return DataLoader(load_fn=self.load_albums_for_bands)


class GraphQLView(QuartGraphQLView):
    """Quart GraphQL view that supplies the resolver context."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Return the context dict with a fresh ``DataLoaders`` instance.

        A new instance per call means the loaders are not shared between
        calls.
        """
        return dict(request=request, response=response, dataloaders=DataLoaders())


# Build the schema from the root query type and wrap it in a view with the
# GraphiQL IDE enabled.
schema = strawberry.Schema(query=Query)
view = GraphQLView.as_view("graphql_view", schema=schema, graphql_ide="graphiql")
app.add_url_rule("/", view_func=view)

The request for the whole collection still works, but now if we look at the logs we can see this:

INFO in __init__: Got 3 bands.
INFO in __init__: Got 6 albums.
INFO in __init__: Got 15 songs.
[59520] [INFO] 127.0.0.1:60244 POST / 1.1 200 806 4509

No more N+1, our job is done here.