refactor: extract open_db() context manager from lifespan
This commit is contained in:
parent
3462e38131
commit
51d03bc780
3 changed files with 102 additions and 42 deletions
|
|
@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import aiosqlite
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
@ -18,7 +17,7 @@ from porchlight.invite.service import MagicLinkService
|
||||||
from porchlight.manage.routes import router as manage_router
|
from porchlight.manage.routes import router as manage_router
|
||||||
from porchlight.oidc.endpoints import router as oidc_router
|
from porchlight.oidc.endpoints import router as oidc_router
|
||||||
from porchlight.oidc.provider import create_oidc_server
|
from porchlight.oidc.provider import create_oidc_server
|
||||||
from porchlight.store.sqlite.migrations import run_migrations
|
from porchlight.store.sqlite.db import open_db
|
||||||
from porchlight.store.sqlite.repositories import (
|
from porchlight.store.sqlite.repositories import (
|
||||||
SQLiteCredentialRepository,
|
SQLiteCredentialRepository,
|
||||||
SQLiteMagicLinkRepository,
|
SQLiteMagicLinkRepository,
|
||||||
|
|
@ -33,52 +32,45 @@ MIGRATIONS_DIR = PACKAGE_DIR / "store" / "sqlite" / "migrations"
|
||||||
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||||
settings: Settings = app.state.settings
|
settings: Settings = app.state.settings
|
||||||
if settings.storage_backend == StorageBackend.SQLITE:
|
if settings.storage_backend == StorageBackend.SQLITE:
|
||||||
if settings.sqlite_path != ":memory:":
|
async with open_db(settings.sqlite_path, MIGRATIONS_DIR) as db:
|
||||||
Path(settings.sqlite_path).parent.mkdir(parents=True, exist_ok=True)
|
app.state.user_repo = SQLiteUserRepository(db)
|
||||||
db = await aiosqlite.connect(settings.sqlite_path)
|
app.state.credential_repo = SQLiteCredentialRepository(db)
|
||||||
db.row_factory = aiosqlite.Row
|
app.state.magic_link_repo = SQLiteMagicLinkRepository(db)
|
||||||
await db.execute("PRAGMA journal_mode=WAL")
|
|
||||||
await db.execute("PRAGMA foreign_keys=ON")
|
|
||||||
await run_migrations(db, MIGRATIONS_DIR)
|
|
||||||
app.state.user_repo = SQLiteUserRepository(db)
|
|
||||||
app.state.credential_repo = SQLiteCredentialRepository(db)
|
|
||||||
app.state.magic_link_repo = SQLiteMagicLinkRepository(db)
|
|
||||||
|
|
||||||
# Auth services
|
# Auth services
|
||||||
app.state.password_service = PasswordService()
|
app.state.password_service = PasswordService()
|
||||||
|
|
||||||
rp_id = urlparse(settings.issuer).hostname or "localhost"
|
rp_id = urlparse(settings.issuer).hostname or "localhost"
|
||||||
app.state.webauthn_service = WebAuthnService(
|
app.state.webauthn_service = WebAuthnService(
|
||||||
rp_id=rp_id,
|
rp_id=rp_id,
|
||||||
rp_name=app.title,
|
rp_name=app.title,
|
||||||
origin=settings.issuer,
|
origin=settings.issuer,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.magic_link_service = MagicLinkService(
|
app.state.magic_link_service = MagicLinkService(
|
||||||
repo=app.state.magic_link_repo,
|
repo=app.state.magic_link_repo,
|
||||||
ttl=settings.invite_ttl,
|
ttl=settings.invite_ttl,
|
||||||
)
|
)
|
||||||
|
|
||||||
# OIDC Server
|
# OIDC Server
|
||||||
oidc_server = create_oidc_server(settings)
|
oidc_server = create_oidc_server(settings)
|
||||||
app.state.oidc_server = oidc_server
|
app.state.oidc_server = oidc_server
|
||||||
|
|
||||||
# Register management client
|
# Register management client
|
||||||
manage_secret = settings.session_secret or secrets.token_hex(32)
|
manage_secret = settings.session_secret or secrets.token_hex(32)
|
||||||
oidc_server.context.cdb[settings.manage_client_id] = {
|
oidc_server.context.cdb[settings.manage_client_id] = {
|
||||||
"client_id": settings.manage_client_id,
|
"client_id": settings.manage_client_id,
|
||||||
"client_secret": manage_secret,
|
"client_secret": manage_secret,
|
||||||
"redirect_uris": [(f"{settings.issuer}/manage/callback", {})],
|
"redirect_uris": [(f"{settings.issuer}/manage/callback", {})],
|
||||||
"response_types_supported": ["code"],
|
"response_types_supported": ["code"],
|
||||||
"token_endpoint_auth_method": "client_secret_basic",
|
"token_endpoint_auth_method": "client_secret_basic",
|
||||||
"scope": ["openid", "profile", "email"],
|
"scope": ["openid", "profile", "email"],
|
||||||
"allowed_scopes": ["openid", "profile", "email"],
|
"allowed_scopes": ["openid", "profile", "email"],
|
||||||
"client_salt": secrets.token_hex(8),
|
"client_salt": secrets.token_hex(8),
|
||||||
}
|
}
|
||||||
oidc_server.keyjar.add_symmetric(settings.manage_client_id, manage_secret)
|
oidc_server.keyjar.add_symmetric(settings.manage_client_id, manage_secret)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
await db.close()
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("MongoDB backend not yet implemented")
|
raise NotImplementedError("MongoDB backend not yet implemented")
|
||||||
|
|
||||||
|
|
|
||||||
23
src/porchlight/store/sqlite/db.py
Normal file
23
src/porchlight/store/sqlite/db.py
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from porchlight.store.sqlite.migrations import run_migrations
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def open_db(db_path: str, migrations_dir: Path) -> AsyncIterator[aiosqlite.Connection]:
|
||||||
|
"""Open a SQLite connection with WAL mode, foreign keys, and migrations applied."""
|
||||||
|
if db_path != ":memory:":
|
||||||
|
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
db = await aiosqlite.connect(db_path)
|
||||||
|
try:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
await db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
await db.execute("PRAGMA foreign_keys=ON")
|
||||||
|
await run_migrations(db, migrations_dir)
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
await db.close()
|
||||||
45
tests/test_store/test_db.py
Normal file
45
tests/test_store/test_db.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from porchlight.store.sqlite.db import open_db
|
||||||
|
|
||||||
|
MIGRATIONS_DIR = (
|
||||||
|
Path(__file__).resolve().parent.parent.parent / "src" / "porchlight" / "store" / "sqlite" / "migrations"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def migrations_dir() -> Path:
|
||||||
|
return MIGRATIONS_DIR
|
||||||
|
|
||||||
|
|
||||||
|
async def test_open_db_returns_connection(tmp_path: Path, migrations_dir: Path) -> None:
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
async with open_db(db_path, migrations_dir) as db:
|
||||||
|
assert isinstance(db, aiosqlite.Connection)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_open_db_runs_migrations(tmp_path: Path, migrations_dir: Path) -> None:
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
async with (
|
||||||
|
open_db(db_path, migrations_dir) as db,
|
||||||
|
db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") as cursor,
|
||||||
|
):
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_open_db_sets_wal_mode(tmp_path: Path, migrations_dir: Path) -> None:
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
async with open_db(db_path, migrations_dir) as db, db.execute("PRAGMA journal_mode") as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row[0] == "wal"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_open_db_creates_parent_dirs(tmp_path: Path, migrations_dir: Path) -> None:
|
||||||
|
db_path = str(tmp_path / "sub" / "dir" / "test.db")
|
||||||
|
async with open_db(db_path, migrations_dir) as db:
|
||||||
|
assert isinstance(db, aiosqlite.Connection)
|
||||||
|
assert (tmp_path / "sub" / "dir" / "test.db").exists()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue