If you prefer, the code for this article is also available in a gist.
GraphQL is great because it lets consumers define the shape of their requests; however, that makes it somewhat harder to optimize than typical REST requests. If you implement GraphQL naively, you will quickly run into the infamous N+1 issue. Let me demonstrate the issue with an example project using Quart, Quart-DB, PostgreSQL and Strawberry.
Let us start by defining a simple schema for a music collection and populating the database with a migration:
from quart_db import Connection
async def create_schema(cnx: Connection) -> None:
    """Create the ``bands``, ``albums`` and ``songs`` tables.

    All three ``CREATE TABLE`` statements are sent to the database in a
    single ``cnx.execute`` call.

    The foreign keys ``albums.band_id`` and ``songs.album_id`` are declared
    ``NOT NULL``: the GraphQL types expose both columns as non-optional
    ``int`` fields, so a NULL foreign key could not be represented in a
    response.

    Args:
        cnx: Open Quart-DB connection the DDL is executed on.
    """
    ddl = (
        "\n"
        "CREATE TABLE bands (\n"
        "id bigint PRIMARY KEY,\n"
        "name text NOT NULL\n"
        ");\n"
        "CREATE TABLE albums (\n"
        "id bigint PRIMARY KEY,\n"
        "name text NOT NULL,\n"
        "band_id bigint NOT NULL REFERENCES bands(id)\n"
        ");\n"
        "CREATE TABLE songs (\n"
        "id bigint GENERATED ALWAYS AS IDENTITY PRIMARY KEY,\n"
        "name text NOT NULL,\n"
        "album_id bigint NOT NULL REFERENCES albums(id)\n"
        ");\n"
    )
    await cnx.execute(ddl)
async def populate(cnx: Connection) -> None:
    """Seed the music collection with sample rows.

    Inserts three bands, six albums and fifteen songs. Band and album ids
    are given explicitly so that later rows can reference them; the song
    rows do not specify an id. Everything goes through a single
    ``cnx.execute`` call.

    Args:
        cnx: Open Quart-DB connection the inserts are executed on.
    """
    seed_sql = (
        "\n"
        "INSERT INTO bands (id, name) VALUES\n"
        "(1, 'Dark Tranquillity'),\n"
        "(2, 'Pineapple Thief'),\n"
        "(3, 'Wintersun');\n"
        "INSERT INTO albums (id, name, band_id) VALUES\n"
        "(1, 'Haven', 1),\n"
        "(2, 'Fiction', 1),\n"
        "(3, 'Atoma', 1),\n"
        "(4, 'Time I', 3),\n"
        "(5, 'Your Wilderness', 2),\n"
        "(6, 'Versions of the Truth', 2);\n"
        "INSERT INTO songs (name, album_id) VALUES\n"
        "('The Wonders at Your Feet', 1),\n"
        "('Not Built to Last', 1),\n"
        "('Indifferent Suns', 1),\n"
        "('At Loss for Words', 1),\n"
        "('Terminus', 2),\n"
        "('Inside the Particle Storm', 2),\n"
        "('Focus Shift', 2),\n"
        "('Forward Momentum', 3),\n"
        "('Caves and Embers', 3),\n"
        "('When Mountains Fall', 4),\n"
        "('Sons of Winter and Stars', 4),\n"
        "('Land of Snow and Sorrow', 4),\n"
        "('Time', 4),\n"
        "('The Final Thing on My Mind', 5),\n"
        "('Tear You Up', 5);\n"
    )
    await cnx.execute(seed_sql)
async def migrate(cnx: Connection) -> None:
    """Run the migration: create the tables, then fill them with sample data.

    Args:
        cnx: Connection handed to both migration steps.
    """
    # Order matters: the tables must exist before rows are inserted.
    for step in (create_schema, populate):
        await step(cnx)
We expose bands, albums and songs; now let us build a GraphQL interface to query them the naive way:
import logging
from typing import Any
import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.quart.views import GraphQLView as QuartGraphQLView
from quart import Quart, Request, Response
# Application object for the example project.
app = Quart("sfdl")
# Let the INFO-level messages logged by the resolvers below through.
app.logger.setLevel(logging.INFO)
# Both settings are assigned before QuartDB(app) is constructed.
app.config["QUART_DB_DATABASE_URL"] = "postgresql://postgres@localhost/cwl_sfdl"
# NOTE(review): presumably turns off Quart-DB's automatic per-request
# connection -- confirm against the Quart-DB docs. The resolvers below open
# their own connections with db.connection().
app.config["QUART_DB_AUTO_REQUEST_CONNECTION"] = False
db = QuartDB(app)
# Accept cross-origin GET and POST requests from any origin.
cors(app, allow_origin="*", allow_methods=["GET", "POST"])
# GraphQL type for a song. Instances are built with Song(**row) from rows
# selecting id, name and album_id from the songs table.
@strawberry.type
class Song:
    id: int
    name: str
    album_id: int  # id of the album the song belongs to
# GraphQL type for an album, with a resolver for the songs it contains.
@strawberry.type
class Album:
    id: int
    name: str
    band_id: int  # id of the band that released the album

    @strawberry.field
    async def songs(self) -> list[Song]:
        # Naive resolver: every album that is resolved runs its own SELECT,
        # so a query over many albums issues one database query per album.
        sql = (
            "\n"
            "SELECT id, name, album_id\n"
            "FROM songs\n"
            "WHERE album_id = :album_id\n"
        )
        params = {"album_id": self.id}
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql, params)
        songs = [Song(**row) for row in rows]
        app.logger.info(f"Got {len(songs)} songs.")
        return songs
# GraphQL type for a band, with a resolver for the albums it released.
@strawberry.type
class Band:
    id: int
    name: str

    @strawberry.field
    async def albums(self) -> list[Album]:
        # Naive resolver: every band that is resolved runs its own SELECT,
        # so a query over many bands issues one database query per band.
        sql = (
            "\n"
            "SELECT id, name, band_id\n"
            "FROM albums\n"
            "WHERE band_id = :band_id\n"
        )
        params = {"band_id": self.id}
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql, params)
        albums = [Album(**row) for row in rows]
        app.logger.info(f"Got {len(albums)} albums.")
        return albums
# Root query type: the entry point of the GraphQL schema.
@strawberry.type
class Query:
    @strawberry.field
    async def bands(self) -> list[Band]:
        # Load every band with a single SELECT and log how many were found.
        sql = (
            "\n"
            "SELECT id, name\n"
            "FROM bands\n"
        )
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql)
        bands = [Band(**row) for row in rows]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands
class GraphQLView(QuartGraphQLView):
    """GraphQL view whose context carries the request and the response."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Build the context dictionary for a GraphQL operation.

        Args:
            request: The incoming Quart request.
            response: The response object associated with the request.

        Returns:
            A dict with the keys ``"request"`` and ``"response"``.
        """
        context: dict[str, Any] = {"request": request, "response": response}
        return context
# Turn the view class into a view function named "graphql_view" that serves
# a schema whose root query type is Query; graphql_ide selects GraphiQL.
view = GraphQLView.as_view(
    "graphql_view",
    schema=strawberry.Schema(query=Query),
    graphql_ide="graphiql",
)
# Mount the GraphQL endpoint at the root URL of the application.
app.add_url_rule("/", view_func=view)
We can try it with GraphiQL, to check that it works:
It does, but if we look at the logs we see this:
INFO in __init__: Got 3 bands.
INFO in __init__: Got 2 albums.
INFO in __init__: Got 1 albums.
INFO in __init__: Got 3 albums.
INFO in __init__: Got 2 songs.
INFO in __init__: Got 4 songs.
INFO in __init__: Got 0 songs.
INFO in __init__: Got 4 songs.
INFO in __init__: Got 2 songs.
INFO in __init__: Got 3 songs.
[59520] [INFO] 127.0.0.1:52944 POST / 1.1 200 806 5311
It makes too many database queries: first it gets all the bands, then for each band it gets the albums, then for each album it gets the songs… That is the N+1 problem.
If you look it up online, you will quickly find out that a popular solution is to use dataloaders. However, many of the examples you will find do not correspond exactly to this problem, because they only demonstrate how to use dataloaders on primary keys, whereas here we want to load on a foreign key. This is the case for the official Strawberry documentation, which you should still read if you use that library.
The important thing to understand with dataloaders is that they take a list of keys and return a list of answers of the same size, corresponding to those keys. So you cannot, for instance, have a dataloader that takes a list of album IDs and returns a list of all the songs in those albums. However, the trick is that you can have a dataloader that returns a list of lists of songs corresponding to each album!
Now let us rewrite our example with that pattern:
import logging
from collections import defaultdict
from functools import cached_property
from typing import Any
import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.dataloader import DataLoader
from strawberry.quart.views import GraphQLView as QuartGraphQLView
from strawberry.types import Info
from quart import Quart, Request, Response
# Application object for the example project.
app = Quart("sfdl")
# Let the INFO-level messages logged by the resolver and loaders below through.
app.logger.setLevel(logging.INFO)
# Both settings are assigned before QuartDB(app) is constructed.
app.config["QUART_DB_DATABASE_URL"] = "postgresql://postgres@localhost/cwl_sfdl"
# NOTE(review): presumably turns off Quart-DB's automatic per-request
# connection -- confirm against the Quart-DB docs. The query resolver and the
# dataloaders below open their own connections with db.connection().
app.config["QUART_DB_AUTO_REQUEST_CONNECTION"] = False
db = QuartDB(app)
# Accept cross-origin GET and POST requests from any origin.
cors(app, allow_origin="*", allow_methods=["GET", "POST"])
# GraphQL type for a song. Instances are built with Song(**row) from rows
# selecting id, name and album_id from the songs table.
@strawberry.type
class Song:
    id: int
    name: str
    album_id: int  # id of the album the song belongs to; used to group songs
# GraphQL type for an album; its songs are fetched through a dataloader.
@strawberry.type
class Album:
    id: int
    name: str
    band_id: int  # id of the band that released the album; used to group albums

    @strawberry.field
    async def songs(self, info: Info) -> list[Song]:
        # Hand this album's id to the songs loader stored in the context and
        # wait for the list of songs it resolves to.
        loaders = info.context["dataloaders"]
        return await loaders.songs_for_albums.load(self.id)
# GraphQL type for a band; its albums are fetched through a dataloader.
@strawberry.type
class Band:
    id: int
    name: str

    @strawberry.field
    async def albums(self, info: Info) -> list[Album]:
        # Hand this band's id to the albums loader stored in the context and
        # wait for the list of albums it resolves to.
        loaders = info.context["dataloaders"]
        return await loaders.albums_for_bands.load(self.id)
# Root query type: the entry point of the GraphQL schema.
@strawberry.type
class Query:
    @strawberry.field
    async def bands(self) -> list[Band]:
        # Load every band with a single SELECT and log how many were found.
        sql = (
            "\n"
            "SELECT id, name\n"
            "FROM bands\n"
        )
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql)
        bands = [Band(**row) for row in rows]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands
class DataLoaders:
    """Dataloaders that load children through a foreign key.

    Each batch function receives a list of parent ids and returns one list
    of children per id, in the same order as the ids it was given.
    """

    @staticmethod
    async def load_songs_for_albums(keys: list[int]) -> list[list[Song]]:
        """Fetch the songs of several albums with a single query.

        Args:
            keys: Ids of the albums whose songs are wanted.

        Returns:
            One list of songs per entry of ``keys``, in the same order; an
            album without songs gets an empty list.
        """
        sql = (
            "\n"
            "SELECT id, name, album_id\n"
            "FROM songs\n"
            "WHERE album_id = ANY(:keys)\n"
        )
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql, {"keys": keys})
        songs = [Song(**row) for row in rows]
        app.logger.info(f"Got {len(songs)} songs.")

        # Bucket the flat result by album id; the defaultdict yields an
        # empty list for albums that have no row in the result.
        per_album: defaultdict[int, list[Song]] = defaultdict(list)
        for song in songs:
            per_album[song.album_id].append(song)
        return [per_album[album_id] for album_id in keys]

    @staticmethod
    async def load_albums_for_bands(keys: list[int]) -> list[list[Album]]:
        """Fetch the albums of several bands with a single query.

        Args:
            keys: Ids of the bands whose albums are wanted.

        Returns:
            One list of albums per entry of ``keys``, in the same order; a
            band without albums gets an empty list.
        """
        sql = (
            "\n"
            "SELECT id, name, band_id\n"
            "FROM albums\n"
            "WHERE band_id = ANY(:keys)\n"
        )
        async with db.connection() as cnx:
            rows = await cnx.fetch_all(sql, {"keys": keys})
        albums = [Album(**row) for row in rows]
        app.logger.info(f"Got {len(albums)} albums.")

        # Bucket the flat result by band id; the defaultdict yields an
        # empty list for bands that have no row in the result.
        per_band: defaultdict[int, list[Album]] = defaultdict(list)
        for album in albums:
            per_band[album.band_id].append(album)
        return [per_band[band_id] for band_id in keys]

    @cached_property
    def songs_for_albums(self) -> DataLoader[int, list[Song]]:
        """Loader mapping an album id to its songs.

        Created on first access and then reused by this instance.
        """
        return DataLoader(load_fn=self.load_songs_for_albums)

    @cached_property
    def albums_for_bands(self) -> DataLoader[int, list[Album]]:
        """Loader mapping a band id to its albums.

        Created on first access and then reused by this instance.
        """
        return DataLoader(load_fn=self.load_albums_for_bands)
class GraphQLView(QuartGraphQLView):
    """GraphQL view whose context also carries a set of dataloaders."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Build the context dictionary for a GraphQL operation.

        Every call creates a new ``DataLoaders`` instance, so loaders are
        not shared between calls.

        Args:
            request: The incoming Quart request.
            response: The response object associated with the request.

        Returns:
            A dict with the keys ``"request"``, ``"response"`` and
            ``"dataloaders"``.
        """
        context: dict[str, Any] = {"request": request, "response": response}
        context["dataloaders"] = DataLoaders()
        return context
# Turn the view class into a view function named "graphql_view" that serves
# a schema whose root query type is Query; graphql_ide selects GraphiQL.
view = GraphQLView.as_view(
    "graphql_view",
    schema=strawberry.Schema(query=Query),
    graphql_ide="graphiql",
)
# Mount the GraphQL endpoint at the root URL of the application.
app.add_url_rule("/", view_func=view)
The request for the whole collection still works, but now if we look at the logs we can see this:
INFO in __init__: Got 3 bands.
INFO in __init__: Got 6 albums.
INFO in __init__: Got 15 songs.
[59520] [INFO] 127.0.0.1:60244 POST / 1.1 200 806 4509
No more N+1, our job is done here.