ARQ + SQLAlchemy - Async done right

Author: Daniel Herrmann

When writing and designing APIs, at some point you will have jobs that should be deferred and handled outside of the main request/response lifecycle, to ensure a timely response to the API user. There are many libraries out there, including Celery, Dramatiq, RQ and Walnats, as well as more recent additions such as Taskiq. In my projects I typically use ARQ, which provides cron job support and is relatively easy to use.

SQLAlchemy is IMHO the de-facto standard for ORM functionality in Python, and it's widely used and supported within the community. Unlike FastAPI, ARQ does not support dependency injection, nor does it have great documentation on properly integrating a database layer, including proper pool and session handling. The approach below is what I came up with when combining FastAPI as API framework and ARQ as distributed job queue, with the following design goals in mind:

  • Reuse existing model definitions across API and worker
  • Reuse existing CRUD methods
  • Full async support
  • Task scoped DB sessions
  • Automatic rollback on error and commit on success

The general idea is to have a connection pool available, which is then used to instantiate one DB session per task that is executed. At the end of the task, the session is automatically returned to the pool: .commit() is called on success and .rollback() otherwise.
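
Conceptually, every task execution should be wrapped like the sketch below; the helper name run_task_with_session is purely illustrative and not part of the implementation we build in the rest of this post:

# Conceptual sketch only - the real implementation hides this behind ARQ's job hooks
async def run_task_with_session(session_factory, task, *args):
    session = session_factory()          # one fresh session per task
    try:
        await task(session, *args)
        await session.commit()           # commit on success
    except Exception:
        await session.rollback()         # roll back on error
        raise
    finally:
        await session.close()            # return the connection to the pool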

Project Structure

We're assuming an existing project containing the SQLAlchemy models that we want to re-use. Optionally, it could also contain a set of CRUD utilities to simplify database interaction. The folder structure is as follows:

.
├── README.md
├── alembic.ini
├── pyproject.toml
└── src
    ├── backend
    │   └── models
    │       ├── __init__.py
    │       └── hero.py
    └── worker
        ├── db.py
        ├── main.py
        └── tasks
            └── hero.py

Database Connection

The idea is to use Python contextvars to distinguish individual job and/or task executions, and then make use of SQLAlchemy's async_scoped_session to get a session tied to that particular execution. Using a scoped session allows us to access the database session from various places, and as long as the related context variable doesn't change, we'll always get the same session without the need to pass the session object around.

Database session manager

We're encapsulating the database methods into a dedicated class. It handles connecting to the database, tearing down the connection on shutdown and providing a session to us if needed:

# src/worker/db.py
from contextvars import ContextVar
from typing import Any

from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine

from app.core import settings

db_session_context: ContextVar[str | None] = ContextVar("db_session_context", default=None)


class DatabaseConnectionManager:
    def __init__(self, host: str, engine_kwargs: dict[str, Any] | None = None):
        self._host = host
        self._engine_kwargs = engine_kwargs or {}
        self.engine = None
        self.scoped_session = None

    async def connect(self) -> None:
        """
        Connect to the database
        """
        self.engine = create_async_engine(self._host, **self._engine_kwargs)

        # Magic happens here
        self.scoped_session = async_scoped_session(
            session_factory=async_sessionmaker(bind=self.engine, autoflush=False, autocommit=False),
            scopefunc=db_session_context.get,
        )

    async def disconnect(self) -> None:
        """
        Close the database connection
        """
        if self.engine is not None:
            await self.engine.dispose()
            self.engine = None


sessionmanager = DatabaseConnectionManager(settings.POSTGRES_URI)

The contextvar db_session_context is created here and then used as the scopefunc for async_scoped_session. You can think of this essentially as a big dictionary: for each result of the scopefunc (db_session_context.get in this code), a separate session is maintained.
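
To illustrate the analogy (snippet for illustration only, assuming sessionmanager.connect() has already been awaited), the value of the context variable decides which session you get back:

# Illustration only - the scope key selects the session
db_session_context.set("job-1")
session_a = sessionmanager.scoped_session()

db_session_context.set("job-2")
session_b = sessionmanager.scoped_session()

assert session_a is not session_b  # different scope values -> different sessions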

Arq Worker Settings

For this to work we need a unique identifier for each job. ARQ automatically generates a job_id within the job context, which we can use for this purpose. We'll use the on_job_start and after_job_end hooks provided by ARQ to handle the database session, and the regular startup and shutdown hooks to set up and tear down the database connection.

  • On startup, we initiate the database connection
  • On job start, we take the job_id and store it in the context variable created earlier
  • After job end, we .commit() the changes on success (or .rollback() on failure) and return the session to the pool
  • On shutdown, we close the connection properly

# src/worker/main.py
import asyncio
from typing import Any
from uuid import uuid4

from arq.connections import RedisSettings
from arq.jobs import Job
import uvloop

import structlog

from app.core import FastAPIStructLogger, settings, setup_logging

from .db import db_session_context, sessionmanager

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

setup_logging()  # project logging setup from app.core
log = FastAPIStructLogger()  # project structlog wrapper, used in the hooks below


async def startup(ctx: dict[Any, Any] | None) -> None:
    """
    Worker startup function - initialize DB connection
    """
    await sessionmanager.connect()


async def shutdown(ctx: dict[Any, Any] | None) -> None:
    """
    Worker shutdown function - disconnect from DB
    """
    await sessionmanager.disconnect()


async def on_job_start(ctx: dict[Any, Any] | None) -> None:
    """
    Job start - store the job_id in the context variable
    """
    db_session_context.set(ctx["job_id"])
    structlog.contextvars.bind_contextvars(job_id=ctx["job_id"])


async def on_job_complete(ctx: dict[Any, Any] | None) -> None:
    """
    Job complete - commit on success, roll back otherwise, then close the session
    """
    job_def = await Job(ctx["job_id"], ctx["redis"]).info()
    if hasattr(job_def, "success") and job_def.success:
        await sessionmanager.scoped_session().commit()
    else:
        await sessionmanager.scoped_session().rollback()

    # Close DB session and return the connection to the pool
    await sessionmanager.scoped_session.remove()
    log.info("Job execution completed")
    structlog.contextvars.unbind_contextvars("job_id")


class WorkerSettings:
    """
    ARQ Worker settings
    """

    functions = []  # Add your functions here
    redis_settings = RedisSettings(host=settings.REDIS_QUEUE_HOST, port=settings.REDIS_QUEUE_PORT)
    on_startup = startup
    on_shutdown = shutdown
    on_job_start = on_job_start
    after_job_end = on_job_complete
    handle_signals = False
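
Enqueueing work from the API side happens through ARQ as usual; here's a minimal sketch assuming the same Redis settings and a task named print_hero (which we'll define in the next section) registered in functions:

# Sketch - enqueue a job from the API side
from arq import create_pool
from arq.connections import RedisSettings

from app.core import settings


async def enqueue_print_hero(hero_id) -> None:
    redis = await create_pool(RedisSettings(host=settings.REDIS_QUEUE_HOST, port=settings.REDIS_QUEUE_PORT))
    await redis.enqueue_job("print_hero", hero_id)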

Use the database

With this setup in place, you can simply use the database within your worker functions as follows:

# src/worker/tasks/hero.py

import uuid
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.crud import crud_hero

from .db import sessionmanager


async def print_hero(ctx: dict[Any, Any] | None, hero_id: uuid.UUID) -> None:
    """
    Example task - load a hero and print it
    """
    # Get the task-scoped database session
    db: AsyncSession = sessionmanager.scoped_session()

    # Do whatever you want with it
    hero = await crud_hero.get(db, hero_id)
    print(hero)
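
Finally, register the task in WorkerSettings.functions so that the worker can execute it; a small excerpt of what that could look like:

# src/worker/main.py (excerpt) - register the task with the worker
from .tasks.hero import print_hero


class WorkerSettings:
    functions = [print_hero]
    # ... remaining settings as shown above

The worker itself is then started with ARQ's command line interface, e.g. arq src.worker.main.WorkerSettings (adjust the dotted module path to your project layout).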