refactor: extract open_db() context manager from lifespan

This commit is contained in:
Johan Lundberg 2026-02-16 15:41:15 +01:00
parent 3462e38131
commit 51d03bc780
No known key found for this signature in database
GPG key ID: A6C152738D03C7D1
3 changed files with 102 additions and 42 deletions

View file

@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
import aiosqlite
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
@ -18,7 +17,7 @@ from porchlight.invite.service import MagicLinkService
from porchlight.manage.routes import router as manage_router from porchlight.manage.routes import router as manage_router
from porchlight.oidc.endpoints import router as oidc_router from porchlight.oidc.endpoints import router as oidc_router
from porchlight.oidc.provider import create_oidc_server from porchlight.oidc.provider import create_oidc_server
from porchlight.store.sqlite.migrations import run_migrations from porchlight.store.sqlite.db import open_db
from porchlight.store.sqlite.repositories import ( from porchlight.store.sqlite.repositories import (
SQLiteCredentialRepository, SQLiteCredentialRepository,
SQLiteMagicLinkRepository, SQLiteMagicLinkRepository,
@ -33,52 +32,45 @@ MIGRATIONS_DIR = PACKAGE_DIR / "store" / "sqlite" / "migrations"
async def lifespan(app: FastAPI) -> AsyncIterator[None]: async def lifespan(app: FastAPI) -> AsyncIterator[None]:
settings: Settings = app.state.settings settings: Settings = app.state.settings
if settings.storage_backend == StorageBackend.SQLITE: if settings.storage_backend == StorageBackend.SQLITE:
if settings.sqlite_path != ":memory:": async with open_db(settings.sqlite_path, MIGRATIONS_DIR) as db:
Path(settings.sqlite_path).parent.mkdir(parents=True, exist_ok=True) app.state.user_repo = SQLiteUserRepository(db)
db = await aiosqlite.connect(settings.sqlite_path) app.state.credential_repo = SQLiteCredentialRepository(db)
db.row_factory = aiosqlite.Row app.state.magic_link_repo = SQLiteMagicLinkRepository(db)
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA foreign_keys=ON")
await run_migrations(db, MIGRATIONS_DIR)
app.state.user_repo = SQLiteUserRepository(db)
app.state.credential_repo = SQLiteCredentialRepository(db)
app.state.magic_link_repo = SQLiteMagicLinkRepository(db)
# Auth services # Auth services
app.state.password_service = PasswordService() app.state.password_service = PasswordService()
rp_id = urlparse(settings.issuer).hostname or "localhost" rp_id = urlparse(settings.issuer).hostname or "localhost"
app.state.webauthn_service = WebAuthnService( app.state.webauthn_service = WebAuthnService(
rp_id=rp_id, rp_id=rp_id,
rp_name=app.title, rp_name=app.title,
origin=settings.issuer, origin=settings.issuer,
) )
app.state.magic_link_service = MagicLinkService( app.state.magic_link_service = MagicLinkService(
repo=app.state.magic_link_repo, repo=app.state.magic_link_repo,
ttl=settings.invite_ttl, ttl=settings.invite_ttl,
) )
# OIDC Server # OIDC Server
oidc_server = create_oidc_server(settings) oidc_server = create_oidc_server(settings)
app.state.oidc_server = oidc_server app.state.oidc_server = oidc_server
# Register management client # Register management client
manage_secret = settings.session_secret or secrets.token_hex(32) manage_secret = settings.session_secret or secrets.token_hex(32)
oidc_server.context.cdb[settings.manage_client_id] = { oidc_server.context.cdb[settings.manage_client_id] = {
"client_id": settings.manage_client_id, "client_id": settings.manage_client_id,
"client_secret": manage_secret, "client_secret": manage_secret,
"redirect_uris": [(f"{settings.issuer}/manage/callback", {})], "redirect_uris": [(f"{settings.issuer}/manage/callback", {})],
"response_types_supported": ["code"], "response_types_supported": ["code"],
"token_endpoint_auth_method": "client_secret_basic", "token_endpoint_auth_method": "client_secret_basic",
"scope": ["openid", "profile", "email"], "scope": ["openid", "profile", "email"],
"allowed_scopes": ["openid", "profile", "email"], "allowed_scopes": ["openid", "profile", "email"],
"client_salt": secrets.token_hex(8), "client_salt": secrets.token_hex(8),
} }
oidc_server.keyjar.add_symmetric(settings.manage_client_id, manage_secret) oidc_server.keyjar.add_symmetric(settings.manage_client_id, manage_secret)
yield yield
await db.close()
else: else:
raise NotImplementedError("MongoDB backend not yet implemented") raise NotImplementedError("MongoDB backend not yet implemented")

View file

@ -0,0 +1,23 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
import aiosqlite
from porchlight.store.sqlite.migrations import run_migrations
@asynccontextmanager
async def open_db(db_path: str, migrations_dir: Path) -> AsyncIterator[aiosqlite.Connection]:
"""Open a SQLite connection with WAL mode, foreign keys, and migrations applied."""
if db_path != ":memory:":
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
db = await aiosqlite.connect(db_path)
try:
db.row_factory = aiosqlite.Row
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA foreign_keys=ON")
await run_migrations(db, migrations_dir)
yield db
finally:
await db.close()

View file

@ -0,0 +1,45 @@
from pathlib import Path
import aiosqlite
import pytest
from porchlight.store.sqlite.db import open_db
MIGRATIONS_DIR = (
Path(__file__).resolve().parent.parent.parent / "src" / "porchlight" / "store" / "sqlite" / "migrations"
)
@pytest.fixture
def migrations_dir() -> Path:
return MIGRATIONS_DIR
async def test_open_db_returns_connection(tmp_path: Path, migrations_dir: Path) -> None:
db_path = str(tmp_path / "test.db")
async with open_db(db_path, migrations_dir) as db:
assert isinstance(db, aiosqlite.Connection)
async def test_open_db_runs_migrations(tmp_path: Path, migrations_dir: Path) -> None:
db_path = str(tmp_path / "test.db")
async with (
open_db(db_path, migrations_dir) as db,
db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") as cursor,
):
row = await cursor.fetchone()
assert row is not None
async def test_open_db_sets_wal_mode(tmp_path: Path, migrations_dir: Path) -> None:
db_path = str(tmp_path / "test.db")
async with open_db(db_path, migrations_dir) as db, db.execute("PRAGMA journal_mode") as cursor:
row = await cursor.fetchone()
assert row[0] == "wal"
async def test_open_db_creates_parent_dirs(tmp_path: Path, migrations_dir: Path) -> None:
db_path = str(tmp_path / "sub" / "dir" / "test.db")
async with open_db(db_path, migrations_dir) as db:
assert isinstance(db, aiosqlite.Connection)
assert (tmp_path / "sub" / "dir" / "test.db").exists()