feat: add SQLiteMagicLinkRepository with tests
This commit is contained in:
parent
bbe0dac8cb
commit
9f4914a922
2 changed files with 154 additions and 1 deletions
|
|
@ -2,7 +2,7 @@ from datetime import UTC, datetime
|
|||
|
||||
import aiosqlite
|
||||
|
||||
from fastapi_oidc_op.models import PasswordCredential, User, WebAuthnCredential
|
||||
from fastapi_oidc_op.models import MagicLink, PasswordCredential, User, WebAuthnCredential
|
||||
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||
|
||||
|
||||
|
|
@ -238,3 +238,54 @@ class SQLiteCredentialRepository:
|
|||
cursor = await self._db.execute("DELETE FROM password_credentials WHERE user_id = ?", (user_id,))
|
||||
await self._db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
|
||||
class SQLiteMagicLinkRepository:
|
||||
def __init__(self, db: aiosqlite.Connection) -> None:
|
||||
self._db = db
|
||||
|
||||
def _row_to_magic_link(self, row: aiosqlite.Row) -> MagicLink:
|
||||
return MagicLink(
|
||||
token=row["token"],
|
||||
username=row["username"],
|
||||
expires_at=datetime.fromisoformat(row["expires_at"]),
|
||||
used=bool(row["used"]),
|
||||
created_by=row["created_by"],
|
||||
note=row["note"],
|
||||
)
|
||||
|
||||
async def create(self, link: MagicLink) -> MagicLink:
|
||||
try:
|
||||
await self._db.execute(
|
||||
"INSERT INTO magic_links (token, username, expires_at, used, created_by, note) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
link.token,
|
||||
link.username,
|
||||
link.expires_at.isoformat(),
|
||||
int(link.used),
|
||||
link.created_by,
|
||||
link.note,
|
||||
),
|
||||
)
|
||||
await self._db.commit()
|
||||
except aiosqlite.IntegrityError as e:
|
||||
raise DuplicateError(str(e)) from e
|
||||
return link
|
||||
|
||||
async def get_by_token(self, token: str) -> MagicLink | None:
|
||||
async with self._db.execute("SELECT * FROM magic_links WHERE token = ?", (token,)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_magic_link(row)
|
||||
|
||||
async def mark_used(self, token: str) -> bool:
|
||||
cursor = await self._db.execute("UPDATE magic_links SET used = 1 WHERE token = ?", (token,))
|
||||
await self._db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
async def delete_expired(self) -> int:
|
||||
now = datetime.now(UTC).isoformat()
|
||||
cursor = await self._db.execute("DELETE FROM magic_links WHERE expires_at < ? AND used = 0", (now,))
|
||||
await self._db.commit()
|
||||
return cursor.rowcount
|
||||
|
|
|
|||
102
tests/test_store/test_sqlite_magic_link_repo.py
Normal file
102
tests/test_store/test_sqlite_magic_link_repo.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from fastapi_oidc_op.models import MagicLink
|
||||
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||
from fastapi_oidc_op.store.protocols import MagicLinkRepository
|
||||
from fastapi_oidc_op.store.sqlite.repositories import SQLiteMagicLinkRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def magic_link_repo(db: aiosqlite.Connection) -> SQLiteMagicLinkRepository:
|
||||
return SQLiteMagicLinkRepository(db)
|
||||
|
||||
|
||||
def _make_link(**overrides) -> MagicLink:
|
||||
defaults = {
|
||||
"token": "abc123",
|
||||
"username": "alice",
|
||||
"expires_at": datetime.now(UTC) + timedelta(hours=24),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return MagicLink(**defaults)
|
||||
|
||||
|
||||
async def test_implements_protocol(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
assert isinstance(magic_link_repo, MagicLinkRepository)
|
||||
|
||||
|
||||
async def test_create_and_get_by_token(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
link = _make_link()
|
||||
created = await magic_link_repo.create(link)
|
||||
assert created.token == "abc123"
|
||||
|
||||
fetched = await magic_link_repo.get_by_token("abc123")
|
||||
assert fetched is not None
|
||||
assert fetched.token == "abc123"
|
||||
assert fetched.username == "alice"
|
||||
assert fetched.used is False
|
||||
|
||||
|
||||
async def test_get_by_token_not_found(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
result = await magic_link_repo.get_by_token("nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_mark_used(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
link = _make_link()
|
||||
await magic_link_repo.create(link)
|
||||
|
||||
marked = await magic_link_repo.mark_used("abc123")
|
||||
assert marked is True
|
||||
|
||||
fetched = await magic_link_repo.get_by_token("abc123")
|
||||
assert fetched is not None
|
||||
assert fetched.used is True
|
||||
|
||||
|
||||
async def test_mark_used_not_found(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
marked = await magic_link_repo.mark_used("nonexistent")
|
||||
assert marked is False
|
||||
|
||||
|
||||
async def test_delete_expired(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
expired = _make_link(token="expired", expires_at=datetime.now(UTC) - timedelta(hours=1))
|
||||
await magic_link_repo.create(expired)
|
||||
|
||||
valid = _make_link(token="valid", expires_at=datetime.now(UTC) + timedelta(hours=24))
|
||||
await magic_link_repo.create(valid)
|
||||
|
||||
count = await magic_link_repo.delete_expired()
|
||||
assert count == 1
|
||||
|
||||
assert await magic_link_repo.get_by_token("expired") is None
|
||||
assert await magic_link_repo.get_by_token("valid") is not None
|
||||
|
||||
|
||||
async def test_delete_expired_skips_used(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
link = _make_link(token="used-expired", expires_at=datetime.now(UTC) - timedelta(hours=1))
|
||||
await magic_link_repo.create(link)
|
||||
await magic_link_repo.mark_used("used-expired")
|
||||
|
||||
count = await magic_link_repo.delete_expired()
|
||||
assert count == 0
|
||||
|
||||
|
||||
async def test_create_with_optional_fields(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
link = _make_link(created_by="admin", note="Welcome aboard")
|
||||
await magic_link_repo.create(link)
|
||||
|
||||
fetched = await magic_link_repo.get_by_token("abc123")
|
||||
assert fetched is not None
|
||||
assert fetched.created_by == "admin"
|
||||
assert fetched.note == "Welcome aboard"
|
||||
|
||||
|
||||
async def test_create_duplicate_token(magic_link_repo: SQLiteMagicLinkRepository) -> None:
|
||||
await magic_link_repo.create(_make_link())
|
||||
|
||||
with pytest.raises(DuplicateError):
|
||||
await magic_link_repo.create(_make_link())
|
||||
Loading…
Add table
Add a link
Reference in a new issue