If you prefer, the code for this article is also available in a gist.
GraphQL is great because it lets consumers define the schema of their requests; however, that makes it somewhat harder to optimize than typical REST requests. If you implement GraphQL naively, you will quickly run into the infamous N+1 issue. Let me demonstrate the issue with an example project using Quart, Quart-DB, PostgreSQL and Strawberry.
Let us start by defining a simple schema for a music collection and populating the database with a migration:
from quart_db import Connection
async def create_schema(cnx: Connection) -> None:
    """Create the ``bands``, ``albums`` and ``songs`` tables.

    Args:
        cnx: Quart-DB connection on which the DDL is executed.
    """
    # NOTE(review): the CREATE TABLE lines, the string delimiters and the
    # ``songs.id`` column were missing from this listing and are reconstructed
    # here. ``songs.id`` must be auto-generated because the INSERT into
    # ``songs`` does not supply an id -- confirm the exact column definition
    # against the original migration.
    await cnx.execute(
        """
        CREATE TABLE bands (
            id bigint PRIMARY KEY,
            name text NOT NULL
        );

        CREATE TABLE albums (
            id bigint PRIMARY KEY,
            name text NOT NULL,
            band_id bigint REFERENCES bands(id)
        );

        CREATE TABLE songs (
            id bigint GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
            name text NOT NULL,
            album_id bigint REFERENCES albums(id)
        );
        """
    )
async def populate(cnx: Connection) -> None:
    """Insert the sample bands, albums and songs.

    Album 6 ('Versions of the Truth') deliberately gets no songs, so the
    example also covers a parent row without children.

    Args:
        cnx: Quart-DB connection on which the INSERT statements are executed.
    """
    # All three INSERTs are sent in a single execute() call, parents first so
    # that the foreign keys resolve.
    await cnx.execute(
        """
        INSERT INTO bands (id, name) VALUES
            (1, 'Dark Tranquillity'),
            (2, 'Pineapple Thief'),
            (3, 'Wintersun');

        INSERT INTO albums (id, name, band_id) VALUES
            (1, 'Haven', 1),
            (2, 'Fiction', 1),
            (3, 'Atoma', 1),
            (4, 'Time I', 3),
            (5, 'Your Wilderness', 2),
            (6, 'Versions of the Truth', 2);

        INSERT INTO songs (name, album_id) VALUES
            ('The Wonders at Your Feet', 1),
            ('Not Built to Last', 1),
            ('Indifferent Suns', 1),
            ('At Loss for Words', 1),
            ('Terminus', 2),
            ('Inside the Particle Storm', 2),
            ('Focus Shift', 2),
            ('Forward Momentum', 3),
            ('Caves and Embers', 3),
            ('When Mountains Fall', 4),
            ('Sons of Winter and Stars', 4),
            ('Land of Snow and Sorrow', 4),
            ('Time', 4),
            ('The Final Thing on My Mind', 5),
            ('Tear You Up', 5);
        """
    )
async def migrate(cnx: Connection) -> None:
    """Create the schema, then load the sample data.

    Args:
        cnx: Quart-DB connection passed on to both steps.
    """
    # Order matters: the tables must exist before rows are inserted.
    await create_schema(cnx)
    await populate(cnx)
We expose bands, albums and songs; now let us build a GraphQL interface to query them the naive way:
import logging
from typing import Any
import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.quart.views import GraphQLView as QuartGraphQLView
from quart import Quart, Request, Response
# Quart application on which the GraphQL view is registered further down.
app = Quart("sfdl")
# PostgreSQL URL stored under the Quart-DB configuration key.
app.config["QUART_DB_DATABASE_URL"] = "postgresql://postgres@localhost/cwl_sfdl"
# Database extension bound to the app; resolvers open connections through it.
db = QuartDB(app)
# Allow cross-origin GET and POST requests from any origin.
cors(app, allow_origin="*", allow_methods=["GET", "POST"])
# GraphQL type for one row of the ``songs`` table. The ``@strawberry.type``
# decorator (missing from the listing) is what makes ``Song(**row)`` work:
# a plain class with only annotations has no keyword constructor.
@strawberry.type
class Song:
    id: int
    name: str
    album_id: int
# GraphQL type for one row of the ``albums`` table (naive version).
@strawberry.type
class Album:
    id: int
    name: str
    band_id: int

    @strawberry.field
    async def songs(self) -> list[Song]:
        # Naive resolver: runs one query per album, which is what causes the
        # N+1 pattern when many albums are resolved in one request.
        query = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = :album_id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"album_id": self.id})
        songs = [Song(**row) for row in result]
        app.logger.info(f"Got {len(songs)} songs.")
        return songs
# GraphQL type for one row of the ``bands`` table (naive version).
@strawberry.type
class Band:
    id: int
    name: str

    @strawberry.field
    async def albums(self) -> list[Album]:
        # Naive resolver: runs one query per band.
        query = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = :band_id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"band_id": self.id})
        albums = [Album(**row) for row in result]
        app.logger.info(f"Got {len(albums)} albums.")
        return albums
# Root GraphQL query type; ``bands`` is the only entry point.
@strawberry.type
class Query:
    @strawberry.field
    async def bands(self) -> list[Band]:
        # Fetches every band in a single query; albums and songs are resolved
        # lazily by the fields on Band and Album.
        query = """
            SELECT id, name
            FROM bands
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query)
        bands = [Band(**row) for row in result]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands
class GraphQLView(QuartGraphQLView):
    """Strawberry Quart view with an explicit per-request context."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Return the context dict handed to resolvers.

        Args:
            request: The incoming Quart request.
            response: The response object for this request.

        Returns:
            A dict with the ``"request"`` and ``"response"`` entries.
        """
        return {"request": request, "response": response}
# NOTE(review): the arguments to as_view() were missing from this listing.
# The view name and schema below follow the usual Strawberry/Quart setup --
# confirm against the original article.
view = GraphQLView.as_view(
    "graphql_view",
    schema=strawberry.Schema(query=Query),
)
# Serve the GraphQL endpoint at the application root.
app.add_url_rule("/", view_func=view)
We can try it with GraphiQL, to check that it works:
It does, but if we look at the logs we see this:
INFO in __init__: Got 3 bands.
INFO in __init__: Got 2 albums.
INFO in __init__: Got 1 albums.
INFO in __init__: Got 3 albums.
INFO in __init__: Got 2 songs.
INFO in __init__: Got 4 songs.
INFO in __init__: Got 0 songs.
INFO in __init__: Got 4 songs.
INFO in __init__: Got 2 songs.
INFO in __init__: Got 3 songs.
[59520] [INFO] POST / 1.1 200 806 5311
It makes too many queries: first it gets all the bands, then for each band it gets the albums, then for each album it gets the songs… That is the N+1 problem.
If you look it up online, you will quickly find out that a popular solution is dataloaders. However, a lot of the examples you will find do not correspond exactly to this problem, because they only demonstrate how to use dataloaders on primary keys, whereas here we want to load on a foreign key. That is the case for the official Strawberry documentation, which you should still read if you use that library.
The important thing to understand with dataloaders is that they take a list of keys and return a list of answers of the same size, corresponding to those keys. So you cannot, for instance, have a dataloader that takes a list of album IDs and returns a flat list of all the songs in those albums. However, the trick is that you can have a dataloader that returns a list of lists of songs, one list per album!
Now let us rewrite our example with that pattern:
import logging
from collections import defaultdict
from functools import cached_property
from typing import Any
import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.dataloader import DataLoader
from strawberry.quart.views import GraphQLView as QuartGraphQLView
from strawberry.types import Info
from quart import Quart, Request, Response
# Quart application on which the GraphQL view is registered further down.
app = Quart("sfdl")
# PostgreSQL URL stored under the Quart-DB configuration key.
app.config["QUART_DB_DATABASE_URL"] = "postgresql://postgres@localhost/cwl_sfdl"
# Database extension bound to the app; resolvers and loaders open connections through it.
db = QuartDB(app)
# Allow cross-origin GET and POST requests from any origin.
cors(app, allow_origin="*", allow_methods=["GET", "POST"])
# GraphQL type for one row of the ``songs`` table. The ``@strawberry.type``
# decorator (missing from the listing) is what makes ``Song(**row)`` work:
# a plain class with only annotations has no keyword constructor.
@strawberry.type
class Song:
    id: int
    name: str
    album_id: int
# GraphQL type for one row of the ``albums`` table (dataloader version).
@strawberry.type
class Album:
    id: int
    name: str
    band_id: int

    @strawberry.field
    async def songs(self, info: Info) -> list[Song]:
        # Delegate to the per-request dataloader so that the songs of all
        # albums resolved in the same batch are fetched with one query.
        dl = info.context["dataloaders"].songs_for_albums
        return await dl.load(self.id)
# GraphQL type for one row of the ``bands`` table (dataloader version).
@strawberry.type
class Band:
    id: int
    name: str

    @strawberry.field
    async def albums(self, info: Info) -> list[Album]:
        # Delegate to the per-request dataloader so that the albums of all
        # bands resolved in the same batch are fetched with one query.
        dl = info.context["dataloaders"].albums_for_bands
        return await dl.load(self.id)
# Root GraphQL query type; ``bands`` is the only entry point.
@strawberry.type
class Query:
    @strawberry.field
    async def bands(self) -> list[Band]:
        # Fetches every band in a single query; albums and songs are resolved
        # by the fields on Band and Album.
        query = """
            SELECT id, name
            FROM bands
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query)
        bands = [Band(**row) for row in result]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands
class DataLoaders:
    """Per-request container for the foreign-key dataloaders.

    One instance is created per request in ``GraphQLView.get_context`` and
    reached from resolvers through ``info.context["dataloaders"]``.
    """

    @staticmethod
    async def load_songs_for_albums(keys: list[int]) -> list[list[Song]]:
        """Batch-load songs for several albums with one query.

        Args:
            keys: Album ids to load songs for.

        Returns:
            One list of songs per key, in the same order as ``keys``; an
            album without songs yields an empty list.
        """
        query = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = ANY(:keys)
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"keys": keys})
        songs = [Song(**row) for row in result]
        app.logger.info(f"Got {len(songs)} songs.")
        # Group rows by foreign key. The defaultdict guarantees an empty list
        # for keys with no rows, so the result always lines up with ``keys``.
        by_key: defaultdict[int, list[Song]] = defaultdict(list)
        for song in songs:
            by_key[song.album_id].append(song)
        return [by_key[k] for k in keys]

    @staticmethod
    async def load_albums_for_bands(keys: list[int]) -> list[list[Album]]:
        """Batch-load albums for several bands with one query.

        Args:
            keys: Band ids to load albums for.

        Returns:
            One list of albums per key, in the same order as ``keys``; a
            band without albums yields an empty list.
        """
        query = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = ANY(:keys)
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"keys": keys})
        albums = [Album(**row) for row in result]
        app.logger.info(f"Got {len(albums)} albums.")
        # Same grouping as for songs, keyed on the band foreign key.
        by_key: defaultdict[int, list[Album]] = defaultdict(list)
        for album in albums:
            by_key[album.band_id].append(album)
        return [by_key[k] for k in keys]

    # cached_property matters here: every resolver must get the SAME
    # DataLoader instance, otherwise each ``load`` call would land in its own
    # loader and nothing would be batched.
    @cached_property
    def songs_for_albums(self) -> DataLoader[int, list[Song]]:
        """Dataloader mapping an album id to that album's songs."""
        return DataLoader(self.load_songs_for_albums)

    @cached_property
    def albums_for_bands(self) -> DataLoader[int, list[Album]]:
        """Dataloader mapping a band id to that band's albums."""
        return DataLoader(self.load_albums_for_bands)
class GraphQLView(QuartGraphQLView):
    """Strawberry Quart view that adds fresh dataloaders to each request's context."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Return the context dict handed to resolvers.

        Args:
            request: The incoming Quart request.
            response: The response object for this request.

        Returns:
            A dict with ``"request"``, ``"response"`` and a new
            ``DataLoaders`` instance under ``"dataloaders"``.
        """
        # A new DataLoaders per call keeps loader state scoped to one request.
        return {"request": request, "response": response, "dataloaders": DataLoaders()}
# NOTE(review): the arguments to as_view() were missing from this listing.
# The view name and schema below follow the usual Strawberry/Quart setup --
# confirm against the original article.
view = GraphQLView.as_view(
    "graphql_view",
    schema=strawberry.Schema(query=Query),
)
# Serve the GraphQL endpoint at the application root.
app.add_url_rule("/", view_func=view)
The request for the whole collection still works, but now if we look at the logs we can see this:
INFO in __init__: Got 3 bands.
INFO in __init__: Got 6 albums.
INFO in __init__: Got 15 songs.
[59520] [INFO] POST / 1.1 200 806 4509
No more N+1, our job is done here.