feat: add lifespan integration and dependency injection
This commit is contained in:
parent
9f4914a922
commit
a45604ff2f
4 changed files with 88 additions and 3 deletions
|
|
@ -1,6 +1,39 @@
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from fastapi_oidc_op.config import Settings
|
from fastapi_oidc_op.config import Settings, StorageBackend
|
||||||
|
from fastapi_oidc_op.store.sqlite.migrations import run_migrations
|
||||||
|
from fastapi_oidc_op.store.sqlite.repositories import (
|
||||||
|
SQLiteCredentialRepository,
|
||||||
|
SQLiteMagicLinkRepository,
|
||||||
|
SQLiteUserRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
MIGRATIONS_DIR = Path(__file__).parent / "store" / "sqlite" / "migrations"
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||||
|
settings: Settings = app.state.settings
|
||||||
|
if settings.storage_backend == StorageBackend.SQLITE:
|
||||||
|
if settings.sqlite_path != ":memory:":
|
||||||
|
Path(settings.sqlite_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
db = await aiosqlite.connect(settings.sqlite_path)
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
await db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
await db.execute("PRAGMA foreign_keys=ON")
|
||||||
|
await run_migrations(db, MIGRATIONS_DIR)
|
||||||
|
app.state.user_repo = SQLiteUserRepository(db)
|
||||||
|
app.state.credential_repo = SQLiteCredentialRepository(db)
|
||||||
|
app.state.magic_link_repo = SQLiteMagicLinkRepository(db)
|
||||||
|
yield
|
||||||
|
await db.close()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("MongoDB backend not yet implemented")
|
||||||
|
|
||||||
|
|
||||||
def create_app(settings: Settings | None = None) -> FastAPI:
|
def create_app(settings: Settings | None = None) -> FastAPI:
|
||||||
|
|
@ -12,6 +45,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
docs_url="/docs" if settings.debug else None,
|
docs_url="/docs" if settings.debug else None,
|
||||||
redoc_url=None,
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.settings = settings
|
app.state.settings = settings
|
||||||
|
|
|
||||||
19
src/fastapi_oidc_op/dependencies.py
Normal file
19
src/fastapi_oidc_op/dependencies.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from fastapi_oidc_op.store.protocols import (
|
||||||
|
CredentialRepository,
|
||||||
|
MagicLinkRepository,
|
||||||
|
UserRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_repo(request: Request) -> UserRepository:
|
||||||
|
return request.app.state.user_repo
|
||||||
|
|
||||||
|
|
||||||
|
def get_credential_repo(request: Request) -> CredentialRepository:
|
||||||
|
return request.app.state.credential_repo
|
||||||
|
|
||||||
|
|
||||||
|
def get_magic_link_repo(request: Request) -> MagicLinkRepository:
|
||||||
|
return request.app.state.magic_link_repo
|
||||||
|
|
@ -9,12 +9,12 @@ from fastapi_oidc_op.config import Settings
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def settings() -> Settings:
|
def settings() -> Settings:
|
||||||
return Settings(issuer="http://localhost:8000")
|
return Settings(issuer="http://localhost:8000", sqlite_path=":memory:")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def client(settings: Settings) -> AsyncIterator[AsyncClient]:
|
async def client(settings: Settings) -> AsyncIterator[AsyncClient]:
|
||||||
app = create_app(settings)
|
app = create_app(settings)
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
async with AsyncClient(transport=transport, base_url=settings.issuer) as ac:
|
async with app.router.lifespan_context(app), AsyncClient(transport=transport, base_url=settings.issuer) as ac:
|
||||||
yield ac
|
yield ac
|
||||||
|
|
|
||||||
|
|
@ -13,3 +13,35 @@ async def test_app_has_title(client: AsyncClient) -> None:
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["info"]["title"] == "FastAPI OIDC OP"
|
assert data["info"]["title"] == "FastAPI OIDC OP"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_app_has_repos_on_state(client: AsyncClient) -> None:
|
||||||
|
from fastapi_oidc_op.store.protocols import (
|
||||||
|
CredentialRepository,
|
||||||
|
MagicLinkRepository,
|
||||||
|
UserRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
app = client._transport.app # type: ignore[union-attr]
|
||||||
|
assert isinstance(app.state.user_repo, UserRepository)
|
||||||
|
assert isinstance(app.state.credential_repo, CredentialRepository)
|
||||||
|
assert isinstance(app.state.magic_link_repo, MagicLinkRepository)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_dependency_functions() -> None:
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from fastapi_oidc_op.dependencies import (
|
||||||
|
get_credential_repo,
|
||||||
|
get_magic_link_repo,
|
||||||
|
get_user_repo,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.app.state.user_repo = "user_repo_sentinel"
|
||||||
|
request.app.state.credential_repo = "credential_repo_sentinel"
|
||||||
|
request.app.state.magic_link_repo = "magic_link_repo_sentinel"
|
||||||
|
|
||||||
|
assert get_user_repo(request) == "user_repo_sentinel"
|
||||||
|
assert get_credential_repo(request) == "credential_repo_sentinel"
|
||||||
|
assert get_magic_link_repo(request) == "magic_link_repo_sentinel"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue