From 51d03bc780a668e854c69827a875461af19bc7ca Mon Sep 17 00:00:00 2001 From: Johan Lundberg Date: Mon, 16 Feb 2026 15:41:15 +0100 Subject: [PATCH] refactor: extract open_db() context manager from lifespan --- src/porchlight/app.py | 76 ++++++++++++++----------------- src/porchlight/store/sqlite/db.py | 23 ++++++++++ tests/test_store/test_db.py | 45 ++++++++++++++++++ 3 files changed, 102 insertions(+), 42 deletions(-) create mode 100644 src/porchlight/store/sqlite/db.py create mode 100644 tests/test_store/test_db.py diff --git a/src/porchlight/app.py b/src/porchlight/app.py index bdd66b1..90f6f84 100644 --- a/src/porchlight/app.py +++ b/src/porchlight/app.py @@ -4,7 +4,6 @@ from contextlib import asynccontextmanager from pathlib import Path from urllib.parse import urlparse -import aiosqlite from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates @@ -18,7 +17,7 @@ from porchlight.invite.service import MagicLinkService from porchlight.manage.routes import router as manage_router from porchlight.oidc.endpoints import router as oidc_router from porchlight.oidc.provider import create_oidc_server -from porchlight.store.sqlite.migrations import run_migrations +from porchlight.store.sqlite.db import open_db from porchlight.store.sqlite.repositories import ( SQLiteCredentialRepository, SQLiteMagicLinkRepository, @@ -33,52 +32,45 @@ MIGRATIONS_DIR = PACKAGE_DIR / "store" / "sqlite" / "migrations" async def lifespan(app: FastAPI) -> AsyncIterator[None]: settings: Settings = app.state.settings if settings.storage_backend == StorageBackend.SQLITE: - if settings.sqlite_path != ":memory:": - Path(settings.sqlite_path).parent.mkdir(parents=True, exist_ok=True) - db = await aiosqlite.connect(settings.sqlite_path) - db.row_factory = aiosqlite.Row - await db.execute("PRAGMA journal_mode=WAL") - await db.execute("PRAGMA foreign_keys=ON") - await run_migrations(db, MIGRATIONS_DIR) - app.state.user_repo = SQLiteUserRepository(db) - app.state.credential_repo = SQLiteCredentialRepository(db) - app.state.magic_link_repo = SQLiteMagicLinkRepository(db) + async with open_db(settings.sqlite_path, MIGRATIONS_DIR) as db: + app.state.user_repo = SQLiteUserRepository(db) + app.state.credential_repo = SQLiteCredentialRepository(db) + app.state.magic_link_repo = SQLiteMagicLinkRepository(db) - # Auth services - app.state.password_service = PasswordService() + # Auth services + app.state.password_service = PasswordService() - rp_id = urlparse(settings.issuer).hostname or "localhost" - app.state.webauthn_service = WebAuthnService( - rp_id=rp_id, - rp_name=app.title, - origin=settings.issuer, - ) + rp_id = urlparse(settings.issuer).hostname or "localhost" + app.state.webauthn_service = WebAuthnService( + rp_id=rp_id, + rp_name=app.title, + origin=settings.issuer, + ) - app.state.magic_link_service = MagicLinkService( - repo=app.state.magic_link_repo, - ttl=settings.invite_ttl, - ) + app.state.magic_link_service = MagicLinkService( + repo=app.state.magic_link_repo, + ttl=settings.invite_ttl, + ) - # OIDC Server - oidc_server = create_oidc_server(settings) - app.state.oidc_server = oidc_server + # OIDC Server + oidc_server = create_oidc_server(settings) + app.state.oidc_server = oidc_server - # Register management client - manage_secret = settings.session_secret or secrets.token_hex(32) - oidc_server.context.cdb[settings.manage_client_id] = { - "client_id": settings.manage_client_id, - "client_secret": manage_secret, - "redirect_uris": [(f"{settings.issuer}/manage/callback", {})], - "response_types_supported": ["code"], - "token_endpoint_auth_method": "client_secret_basic", - "scope": ["openid", "profile", "email"], - "allowed_scopes": ["openid", "profile", "email"], - "client_salt": secrets.token_hex(8), - } - oidc_server.keyjar.add_symmetric(settings.manage_client_id, manage_secret) + # Register management client + manage_secret = settings.session_secret or secrets.token_hex(32) + oidc_server.context.cdb[settings.manage_client_id] = { + "client_id": settings.manage_client_id, + "client_secret": manage_secret, + "redirect_uris": [(f"{settings.issuer}/manage/callback", {})], + "response_types_supported": ["code"], + "token_endpoint_auth_method": "client_secret_basic", + "scope": ["openid", "profile", "email"], + "allowed_scopes": ["openid", "profile", "email"], + "client_salt": secrets.token_hex(8), + } + oidc_server.keyjar.add_symmetric(settings.manage_client_id, manage_secret) - yield - await db.close() + yield else: raise NotImplementedError("MongoDB backend not yet implemented") diff --git a/src/porchlight/store/sqlite/db.py b/src/porchlight/store/sqlite/db.py new file mode 100644 index 0000000..8ab97ac --- /dev/null +++ b/src/porchlight/store/sqlite/db.py @@ -0,0 +1,23 @@ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from pathlib import Path + +import aiosqlite + +from porchlight.store.sqlite.migrations import run_migrations + + +@asynccontextmanager +async def open_db(db_path: str, migrations_dir: Path) -> AsyncIterator[aiosqlite.Connection]: + """Open a SQLite connection with WAL mode, foreign keys, and migrations applied.""" + if db_path != ":memory:": + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + db = await aiosqlite.connect(db_path) + try: + db.row_factory = aiosqlite.Row + await db.execute("PRAGMA journal_mode=WAL") + await db.execute("PRAGMA foreign_keys=ON") + await run_migrations(db, migrations_dir) + yield db + finally: + await db.close() diff --git a/tests/test_store/test_db.py b/tests/test_store/test_db.py new file mode 100644 index 0000000..8417349 --- /dev/null +++ b/tests/test_store/test_db.py @@ -0,0 +1,45 @@ +from pathlib import Path + +import aiosqlite +import pytest + +from porchlight.store.sqlite.db import open_db + +MIGRATIONS_DIR = ( + Path(__file__).resolve().parent.parent.parent / "src" / "porchlight" / "store" / "sqlite" / "migrations" +) + + +@pytest.fixture +def migrations_dir() -> Path: + return MIGRATIONS_DIR + + +async def test_open_db_returns_connection(tmp_path: Path, migrations_dir: Path) -> None: + db_path = str(tmp_path / "test.db") + async with open_db(db_path, migrations_dir) as db: + assert isinstance(db, aiosqlite.Connection) + + +async def test_open_db_runs_migrations(tmp_path: Path, migrations_dir: Path) -> None: + db_path = str(tmp_path / "test.db") + async with ( + open_db(db_path, migrations_dir) as db, + db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") as cursor, + ): + row = await cursor.fetchone() + assert row is not None + + +async def test_open_db_sets_wal_mode(tmp_path: Path, migrations_dir: Path) -> None: + db_path = str(tmp_path / "test.db") + async with open_db(db_path, migrations_dir) as db, db.execute("PRAGMA journal_mode") as cursor: + row = await cursor.fetchone() + assert row[0] == "wal" + + +async def test_open_db_creates_parent_dirs(tmp_path: Path, migrations_dir: Path) -> None: + db_path = str(tmp_path / "sub" / "dir" / "test.db") + async with open_db(db_path, migrations_dir) as db: + assert isinstance(db, aiosqlite.Connection) + assert (tmp_path / "sub" / "dir" / "test.db").exists()