ARQ + SQLAlchemy - Async done right
Daniel Herrmann
When writing and designing APIs, at some point you will have jobs that should be deferred and handled outside of the main request / response lifecycle, to ensure a timely response to the API user. There are many libraries out there, including Celery, Dramatiq, RQ and Walnats, as well as more recent additions such as Taskiq. In my projects I typically use ARQ, which provides cron job support and is relatively easy to use.
SQLAlchemy is IMHO the de-facto standard for ORM functionality in Python, and it's widely used and supported within the community. Unlike FastAPI, ARQ does not support dependency injection, nor does it offer much documentation on properly integrating a database layer, including proper pool and session handling. The approach below is what I came up with when combining FastAPI as the API framework and ARQ as the distributed job queue, with the following design goals in mind:
- Reuse existing model definitions across API and worker
- Reuse existing CRUD methods
- Full async support
- Task-scoped DB sessions
- Automatic rollback on error and commit on success
The general idea is to have a connection pool available, which is then used to instantiate one DB session per task that is executed. At the end of the task, the session is automatically returned to the pool; `.commit()` is executed on success and `.rollback()` otherwise.
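Conceptually, each task execution then looks roughly like the following sketch. This is illustrative pseudocode, not the actual implementation we'll build below; names like `run_task` and `session_factory` are made up for the example:

```python
# Illustrative sketch of the per-task session lifecycle (hypothetical names)
async def run_task(task, session_factory):
    session = session_factory()   # one fresh session per task, backed by the pool
    try:
        await task(session)
        await session.commit()    # commit on success
    except Exception:
        await session.rollback()  # roll back on error
        raise
    finally:
        await session.close()     # return the connection to the pool
```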
Project Structure
We're assuming an existing project containing the SQLAlchemy models that we want to re-use. Optionally, it could also contain a set of CRUD utilities to simplify database interaction. The folder structure is as follows:
```
.
├── README.md
├── alembic.ini
├── pyproject.toml
└── src
    ├── backend
    │   └── models
    │       ├── __init__.py
    │       └── hero.py
    └── worker
        ├── db.py
        ├── main.py
        └── tasks
            └── hero.py
```
Database Connection
The idea is to use Python contextvars to distinguish between job and/or task executions, and then make use of SQLAlchemy's `async_scoped_session` to get a session tied to that particular job. Using a scoped session allows us to access the database session from various places, and as long as the related context variable doesn't change, we'll always get the same session without the need to pass the session object around.
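To see why contextvars are the right tool here, consider this standalone sketch (not part of the project code): each asyncio task runs in its own copy of the context, so a value set in one job is invisible to concurrently running jobs.

```python
import asyncio
from contextvars import ContextVar

demo_var: ContextVar[str | None] = ContextVar("demo_var", default=None)

async def job(job_id: str) -> None:
    demo_var.set(job_id)             # only visible within this task's context copy
    await asyncio.sleep(0)           # yield control to the other job
    assert demo_var.get() == job_id  # still our own value

async def main() -> None:
    await asyncio.gather(job("job-1"), job("job-2"))

asyncio.run(main())
```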
Database session manager
We're encapsulating the database methods into a dedicated class. It handles connecting to the database, tearing down the connection on shutdown, and providing a session when needed:
```python
# src/worker/db.py
from contextvars import ContextVar
from typing import Any

from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine

from app.core import settings

db_session_context: ContextVar[str | None] = ContextVar("db_session_context", default=None)


class DatabaseConnectionManager:
    def __init__(self, host: str, engine_kwargs: dict[str, Any] | None = None):
        self._host = host
        self._engine_kwargs = engine_kwargs or {}
        self.engine = None
        self.scoped_session = None

    async def connect(self) -> None:
        """
        Connect to the database
        """
        self.engine = create_async_engine(self._host, **self._engine_kwargs)

        # Magic happens here: one session is maintained per scopefunc result
        self.scoped_session = async_scoped_session(
            session_factory=async_sessionmaker(bind=self.engine, autoflush=False, autocommit=False),
            scopefunc=db_session_context.get,
        )

    async def disconnect(self) -> None:
        """
        Close the database connection
        """
        if self.engine is not None:
            await self.engine.dispose()
            self.engine = None


sessionmanager = DatabaseConnectionManager(settings.POSTGRES_URI)
```
The contextvar `db_session_context` is created here and then used as `scopefunc` for `async_scoped_session`. You can imagine this to essentially be a big dictionary: for each result of the `scopefunc` (which is `db_session_context.get` in this code), a separate session is maintained.
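To make the analogy concrete, here is a hypothetical snippet (assuming `sessionmanager.connect()` has already been awaited) showing that each context value maps to its own session:

```python
# Hypothetical illustration of the scoped-session "dictionary"
db_session_context.set("job-1")
session_a = sessionmanager.scoped_session()  # session registered under "job-1"

db_session_context.set("job-2")
session_b = sessionmanager.scoped_session()  # a different session, under "job-2"

db_session_context.set("job-1")
assert sessionmanager.scoped_session() is session_a  # same scope -> same session
assert session_a is not session_b
```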
ARQ Worker Settings
For this to work we need a unique identifier for each job. ARQ automatically generates a `job_id` within the job context which we can use for this purpose. We'll use the `on_job_start` and `after_job_end` hooks provided by ARQ to facilitate database handling, and the regular `startup` and `shutdown` hooks to set up and tear down the database connection.
- On startup, we initiate the database connection
- On job start, we take the `job_id` and store it in the created context var
- After job end, we `commit()` changes on success (or `rollback()` otherwise) and return the connection
- On shutdown, we close the connection properly
```python
# src/worker/main.py
import asyncio
from typing import Any

import structlog
import uvloop
from arq.connections import RedisSettings
from arq.jobs import Job

from app.core import FastAPIStructLogger, settings, setup_logging

from .db import db_session_context, sessionmanager

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Project-specific logging helpers from app.core (a structlog wrapper)
setup_logging()
log = FastAPIStructLogger()


async def startup(ctx: dict[Any, Any] | None) -> None:
    """
    Worker startup function - initialize DB connection
    """
    await sessionmanager.connect()


async def shutdown(ctx: dict[Any, Any] | None) -> None:
    """
    Worker shutdown function - disconnect from DB
    """
    await sessionmanager.disconnect()


async def on_job_start(ctx: dict[Any, Any] | None) -> None:
    """
    Job start - set contextvar to job_id
    """
    db_session_context.set(ctx["job_id"])
    structlog.contextvars.bind_contextvars(job_id=ctx["job_id"])


async def on_job_complete(ctx: dict[Any, Any] | None) -> None:
    """
    Job complete - commit on success, roll back otherwise
    """
    job_def = await Job(ctx["job_id"], ctx["redis"]).info()
    if hasattr(job_def, "success") and job_def.success:
        await sessionmanager.scoped_session().commit()
    else:
        await sessionmanager.scoped_session().rollback()

    # Close DB session and return the connection to the pool
    await sessionmanager.scoped_session.remove()

    log.info("Job execution completed")
    structlog.contextvars.unbind_contextvars("job_id")


class WorkerSettings:
    """
    ARQ Worker settings
    """

    functions = []  # Add your task functions here
    redis_settings = RedisSettings(host=settings.REDIS_QUEUE_HOST, port=settings.REDIS_QUEUE_PORT)
    on_startup = startup
    on_shutdown = shutdown
    on_job_start = on_job_start
    after_job_end = on_job_complete
    handle_signals = False
```
Use the database
With this setup in place, you can simply use the database within your worker functions as follows:
```python
# src/worker/tasks/hero.py
import uuid
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.crud import crud_hero

from ..db import sessionmanager


async def print_hero(ctx: dict[Any, Any] | None, hero_id: uuid.UUID) -> None:
    # Get the task-scoped database session
    db: AsyncSession = sessionmanager.scoped_session()

    # Do whatever you want with it - no manual commit needed,
    # the after_job_end hook commits or rolls back for us
    hero = await crud_hero.get(db, hero_id)
    print(hero)
```
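To actually run this task, register it in `WorkerSettings.functions` (i.e. `functions = [print_hero]`) and enqueue it by name, for example from a FastAPI endpoint. A minimal sketch, assuming the Redis connection details match the worker's:

```python
# Minimal enqueue sketch - connection details are assumed
import uuid

from arq import create_pool
from arq.connections import RedisSettings

async def enqueue_print_hero(hero_id: uuid.UUID) -> None:
    redis = await create_pool(RedisSettings(host="localhost", port=6379))
    await redis.enqueue_job("print_hero", hero_id)
```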