feat: add SQLiteMagicLinkRepository with tests
This commit is contained in:
parent
bbe0dac8cb
commit
9f4914a922
2 changed files with 154 additions and 1 deletions
|
|
@ -2,7 +2,7 @@ from datetime import UTC, datetime
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
from fastapi_oidc_op.models import PasswordCredential, User, WebAuthnCredential
|
from fastapi_oidc_op.models import MagicLink, PasswordCredential, User, WebAuthnCredential
|
||||||
from fastapi_oidc_op.store.exceptions import DuplicateError
|
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -238,3 +238,54 @@ class SQLiteCredentialRepository:
|
||||||
cursor = await self._db.execute("DELETE FROM password_credentials WHERE user_id = ?", (user_id,))
|
cursor = await self._db.execute("DELETE FROM password_credentials WHERE user_id = ?", (user_id,))
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
return cursor.rowcount > 0
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteMagicLinkRepository:
|
||||||
|
def __init__(self, db: aiosqlite.Connection) -> None:
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
def _row_to_magic_link(self, row: aiosqlite.Row) -> MagicLink:
|
||||||
|
return MagicLink(
|
||||||
|
token=row["token"],
|
||||||
|
username=row["username"],
|
||||||
|
expires_at=datetime.fromisoformat(row["expires_at"]),
|
||||||
|
used=bool(row["used"]),
|
||||||
|
created_by=row["created_by"],
|
||||||
|
note=row["note"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create(self, link: MagicLink) -> MagicLink:
|
||||||
|
try:
|
||||||
|
await self._db.execute(
|
||||||
|
"INSERT INTO magic_links (token, username, expires_at, used, created_by, note) VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
link.token,
|
||||||
|
link.username,
|
||||||
|
link.expires_at.isoformat(),
|
||||||
|
int(link.used),
|
||||||
|
link.created_by,
|
||||||
|
link.note,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
except aiosqlite.IntegrityError as e:
|
||||||
|
raise DuplicateError(str(e)) from e
|
||||||
|
return link
|
||||||
|
|
||||||
|
async def get_by_token(self, token: str) -> MagicLink | None:
|
||||||
|
async with self._db.execute("SELECT * FROM magic_links WHERE token = ?", (token,)) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return self._row_to_magic_link(row)
|
||||||
|
|
||||||
|
async def mark_used(self, token: str) -> bool:
|
||||||
|
cursor = await self._db.execute("UPDATE magic_links SET used = 1 WHERE token = ?", (token,))
|
||||||
|
await self._db.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
|
async def delete_expired(self) -> int:
|
||||||
|
now = datetime.now(UTC).isoformat()
|
||||||
|
cursor = await self._db.execute("DELETE FROM magic_links WHERE expires_at < ? AND used = 0", (now,))
|
||||||
|
await self._db.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
|
|
||||||
102
tests/test_store/test_sqlite_magic_link_repo.py
Normal file
102
tests/test_store/test_sqlite_magic_link_repo.py
Normal file
|
|
@ -0,0 +1,102 @@
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_oidc_op.models import MagicLink
|
||||||
|
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||||
|
from fastapi_oidc_op.store.protocols import MagicLinkRepository
|
||||||
|
from fastapi_oidc_op.store.sqlite.repositories import SQLiteMagicLinkRepository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def magic_link_repo(db: aiosqlite.Connection) -> SQLiteMagicLinkRepository:
|
||||||
|
return SQLiteMagicLinkRepository(db)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_link(**overrides) -> MagicLink:
|
||||||
|
defaults = {
|
||||||
|
"token": "abc123",
|
||||||
|
"username": "alice",
|
||||||
|
"expires_at": datetime.now(UTC) + timedelta(hours=24),
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return MagicLink(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_implements_protocol(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
assert isinstance(magic_link_repo, MagicLinkRepository)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_and_get_by_token(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
link = _make_link()
|
||||||
|
created = await magic_link_repo.create(link)
|
||||||
|
assert created.token == "abc123"
|
||||||
|
|
||||||
|
fetched = await magic_link_repo.get_by_token("abc123")
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.token == "abc123"
|
||||||
|
assert fetched.username == "alice"
|
||||||
|
assert fetched.used is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_by_token_not_found(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
result = await magic_link_repo.get_by_token("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mark_used(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
link = _make_link()
|
||||||
|
await magic_link_repo.create(link)
|
||||||
|
|
||||||
|
marked = await magic_link_repo.mark_used("abc123")
|
||||||
|
assert marked is True
|
||||||
|
|
||||||
|
fetched = await magic_link_repo.get_by_token("abc123")
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.used is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mark_used_not_found(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
marked = await magic_link_repo.mark_used("nonexistent")
|
||||||
|
assert marked is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_expired(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
expired = _make_link(token="expired", expires_at=datetime.now(UTC) - timedelta(hours=1))
|
||||||
|
await magic_link_repo.create(expired)
|
||||||
|
|
||||||
|
valid = _make_link(token="valid", expires_at=datetime.now(UTC) + timedelta(hours=24))
|
||||||
|
await magic_link_repo.create(valid)
|
||||||
|
|
||||||
|
count = await magic_link_repo.delete_expired()
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
assert await magic_link_repo.get_by_token("expired") is None
|
||||||
|
assert await magic_link_repo.get_by_token("valid") is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_expired_skips_used(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
link = _make_link(token="used-expired", expires_at=datetime.now(UTC) - timedelta(hours=1))
|
||||||
|
await magic_link_repo.create(link)
|
||||||
|
await magic_link_repo.mark_used("used-expired")
|
||||||
|
|
||||||
|
count = await magic_link_repo.delete_expired()
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_with_optional_fields(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
link = _make_link(created_by="admin", note="Welcome aboard")
|
||||||
|
await magic_link_repo.create(link)
|
||||||
|
|
||||||
|
fetched = await magic_link_repo.get_by_token("abc123")
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.created_by == "admin"
|
||||||
|
assert fetched.note == "Welcome aboard"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_duplicate_token(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||||
|
await magic_link_repo.create(_make_link())
|
||||||
|
|
||||||
|
with pytest.raises(DuplicateError):
|
||||||
|
await magic_link_repo.create(_make_link())
|
||||||
Loading…
Add table
Add a link
Reference in a new issue