diff --git a/src/fastapi_oidc_op/store/sqlite/repositories.py b/src/fastapi_oidc_op/store/sqlite/repositories.py index ed02392..08815b3 100644 --- a/src/fastapi_oidc_op/store/sqlite/repositories.py +++ b/src/fastapi_oidc_op/store/sqlite/repositories.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime import aiosqlite -from fastapi_oidc_op.models import PasswordCredential, User, WebAuthnCredential +from fastapi_oidc_op.models import MagicLink, PasswordCredential, User, WebAuthnCredential from fastapi_oidc_op.store.exceptions import DuplicateError @@ -238,3 +238,54 @@ class SQLiteCredentialRepository: cursor = await self._db.execute("DELETE FROM password_credentials WHERE user_id = ?", (user_id,)) await self._db.commit() return cursor.rowcount > 0 + + +class SQLiteMagicLinkRepository: + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + def _row_to_magic_link(self, row: aiosqlite.Row) -> MagicLink: + return MagicLink( + token=row["token"], + username=row["username"], + expires_at=datetime.fromisoformat(row["expires_at"]), + used=bool(row["used"]), + created_by=row["created_by"], + note=row["note"], + ) + + async def create(self, link: MagicLink) -> MagicLink: + try: + await self._db.execute( + "INSERT INTO magic_links (token, username, expires_at, used, created_by, note) VALUES (?, ?, ?, ?, ?, ?)", + ( + link.token, + link.username, + link.expires_at.isoformat(), + int(link.used), + link.created_by, + link.note, + ), + ) + await self._db.commit() + except aiosqlite.IntegrityError as e: + raise DuplicateError(str(e)) from e + return link + + async def get_by_token(self, token: str) -> MagicLink | None: + async with self._db.execute("SELECT * FROM magic_links WHERE token = ?", (token,)) as cursor: + row = await cursor.fetchone() + if row is None: + return None + return self._row_to_magic_link(row) + + async def mark_used(self, token: str) -> bool: + cursor = await self._db.execute("UPDATE magic_links SET used = 1 WHERE token = ?", (token,)) + await self._db.commit() + return cursor.rowcount > 0 + + async def delete_expired(self) -> int: + now = datetime.now(UTC).isoformat() + cursor = await self._db.execute("DELETE FROM magic_links WHERE expires_at < ? AND used = 0", (now,)) + await self._db.commit() + return cursor.rowcount diff --git a/tests/test_store/test_sqlite_magic_link_repo.py b/tests/test_store/test_sqlite_magic_link_repo.py new file mode 100644 index 0000000..c8e6828 --- /dev/null +++ b/tests/test_store/test_sqlite_magic_link_repo.py @@ -0,0 +1,102 @@ +from datetime import UTC, datetime, timedelta + +import aiosqlite +import pytest + +from fastapi_oidc_op.models import MagicLink +from fastapi_oidc_op.store.exceptions import DuplicateError +from fastapi_oidc_op.store.protocols import MagicLinkRepository +from fastapi_oidc_op.store.sqlite.repositories import SQLiteMagicLinkRepository + + +@pytest.fixture +def magic_link_repo(db: aiosqlite.Connection) -> SQLiteMagicLinkRepository: + return SQLiteMagicLinkRepository(db) + + +def _make_link(**overrides) -> MagicLink: + defaults = { + "token": "abc123", + "username": "alice", + "expires_at": datetime.now(UTC) + timedelta(hours=24), + } + defaults.update(overrides) + return MagicLink(**defaults) + + +async def test_implements_protocol(magic_link_repo: SQLiteMagicLinkRepository) -> None: + assert isinstance(magic_link_repo, MagicLinkRepository) + + +async def test_create_and_get_by_token(magic_link_repo: SQLiteMagicLinkRepository) -> None: + link = _make_link() + created = await magic_link_repo.create(link) + assert created.token == "abc123" + + fetched = await magic_link_repo.get_by_token("abc123") + assert fetched is not None + assert fetched.token == "abc123" + assert fetched.username == "alice" + assert fetched.used is False + + +async def test_get_by_token_not_found(magic_link_repo: SQLiteMagicLinkRepository) -> None: + result = await magic_link_repo.get_by_token("nonexistent") + assert result is None + + +async def test_mark_used(magic_link_repo: SQLiteMagicLinkRepository) -> None: + link = _make_link() + await magic_link_repo.create(link) + + marked = await magic_link_repo.mark_used("abc123") + assert marked is True + + fetched = await magic_link_repo.get_by_token("abc123") + assert fetched is not None + assert fetched.used is True + + +async def test_mark_used_not_found(magic_link_repo: SQLiteMagicLinkRepository) -> None: + marked = await magic_link_repo.mark_used("nonexistent") + assert marked is False + + +async def test_delete_expired(magic_link_repo: SQLiteMagicLinkRepository) -> None: + expired = _make_link(token="expired", expires_at=datetime.now(UTC) - timedelta(hours=1)) + await magic_link_repo.create(expired) + + valid = _make_link(token="valid", expires_at=datetime.now(UTC) + timedelta(hours=24)) + await magic_link_repo.create(valid) + + count = await magic_link_repo.delete_expired() + assert count == 1 + + assert await magic_link_repo.get_by_token("expired") is None + assert await magic_link_repo.get_by_token("valid") is not None + + +async def test_delete_expired_skips_used(magic_link_repo: SQLiteMagicLinkRepository) -> None: + link = _make_link(token="used-expired", expires_at=datetime.now(UTC) - timedelta(hours=1)) + await magic_link_repo.create(link) + await magic_link_repo.mark_used("used-expired") + + count = await magic_link_repo.delete_expired() + assert count == 0 + + +async def test_create_with_optional_fields(magic_link_repo: SQLiteMagicLinkRepository) -> None: + link = _make_link(created_by="admin", note="Welcome aboard") + await magic_link_repo.create(link) + + fetched = await magic_link_repo.get_by_token("abc123") + assert fetched is not None + assert fetched.created_by == "admin" + assert fetched.note == "Welcome aboard" + + +async def test_create_duplicate_token(magic_link_repo: SQLiteMagicLinkRepository) -> None: + await magic_link_repo.create(_make_link()) + + with pytest.raises(DuplicateError): + await magic_link_repo.create(_make_link())