diff --git a/src/fastapi_oidc_op/store/sqlite/repositories.py b/src/fastapi_oidc_op/store/sqlite/repositories.py new file mode 100644 index 0000000..d2b96c2 --- /dev/null +++ b/src/fastapi_oidc_op/store/sqlite/repositories.py @@ -0,0 +1,142 @@ +from datetime import UTC, datetime + +import aiosqlite + +from fastapi_oidc_op.models import User +from fastapi_oidc_op.store.exceptions import DuplicateError + + +class SQLiteUserRepository: + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def _get_groups(self, userid: str) -> list[str]: + async with self._db.execute( + "SELECT group_name FROM user_groups WHERE userid = ? ORDER BY group_name", (userid,) + ) as cursor: + return [row[0] async for row in cursor] + + async def _set_groups(self, userid: str, groups: list[str]) -> None: + await self._db.execute("DELETE FROM user_groups WHERE userid = ?", (userid,)) + for group in groups: + await self._db.execute("INSERT INTO user_groups (userid, group_name) VALUES (?, ?)", (userid, group)) + + def _row_to_user(self, row: aiosqlite.Row, groups: list[str]) -> User: + return User( + userid=row["userid"], + username=row["username"], + preferred_username=row["preferred_username"], + given_name=row["given_name"], + family_name=row["family_name"], + nickname=row["nickname"], + email=row["email"], + email_verified=bool(row["email_verified"]), + phone_number=row["phone_number"], + phone_number_verified=bool(row["phone_number_verified"]), + picture=row["picture"], + locale=row["locale"], + active=bool(row["active"]), + created_at=datetime.fromisoformat(row["created_at"]), + updated_at=datetime.fromisoformat(row["updated_at"]), + groups=groups, + ) + + async def create(self, user: User) -> User: + try: + await self._db.execute( + """ + INSERT INTO users ( + userid, username, preferred_username, given_name, family_name, + nickname, email, email_verified, phone_number, phone_number_verified, + picture, locale, active, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + user.userid, + user.username, + user.preferred_username, + user.given_name, + user.family_name, + user.nickname, + user.email, + int(user.email_verified), + user.phone_number, + int(user.phone_number_verified), + user.picture, + user.locale, + int(user.active), + user.created_at.isoformat(), + user.updated_at.isoformat(), + ), + ) + await self._set_groups(user.userid, user.groups) + await self._db.commit() + except aiosqlite.IntegrityError as e: + raise DuplicateError(str(e)) from e + return user + + async def get_by_userid(self, userid: str) -> User | None: + async with self._db.execute("SELECT * FROM users WHERE userid = ?", (userid,)) as cursor: + row = await cursor.fetchone() + if row is None: + return None + groups = await self._get_groups(userid) + return self._row_to_user(row, groups) + + async def get_by_username(self, username: str) -> User | None: + async with self._db.execute("SELECT * FROM users WHERE username = ?", (username,)) as cursor: + row = await cursor.fetchone() + if row is None: + return None + groups = await self._get_groups(row["userid"]) + return self._row_to_user(row, groups) + + async def update(self, user: User) -> User: + updated = user.model_copy(update={"updated_at": datetime.now(UTC)}) + cursor = await self._db.execute( + """ + UPDATE users SET + username = ?, preferred_username = ?, given_name = ?, family_name = ?, + nickname = ?, email = ?, email_verified = ?, phone_number = ?, + phone_number_verified = ?, picture = ?, locale = ?, active = ?, + updated_at = ? + WHERE userid = ? + """, + ( + updated.username, + updated.preferred_username, + updated.given_name, + updated.family_name, + updated.nickname, + updated.email, + int(updated.email_verified), + updated.phone_number, + int(updated.phone_number_verified), + updated.picture, + updated.locale, + int(updated.active), + updated.updated_at.isoformat(), + updated.userid, + ), + ) + if cursor.rowcount == 0: + raise ValueError(f"User not found: {user.userid}") + await self._set_groups(updated.userid, updated.groups) + await self._db.commit() + return updated + + async def list_users(self, offset: int = 0, limit: int = 100) -> list[User]: + async with self._db.execute( + "SELECT * FROM users ORDER BY username LIMIT ? OFFSET ?", (limit, offset) + ) as cursor: + rows = await cursor.fetchall() + users = [] + for row in rows: + groups = await self._get_groups(row["userid"]) + users.append(self._row_to_user(row, groups)) + return users + + async def delete(self, userid: str) -> bool: + cursor = await self._db.execute("DELETE FROM users WHERE userid = ?", (userid,)) + await self._db.commit() + return cursor.rowcount > 0 diff --git a/tests/test_store/conftest.py b/tests/test_store/conftest.py new file mode 100644 index 0000000..1e7b55c --- /dev/null +++ b/tests/test_store/conftest.py @@ -0,0 +1,20 @@ +from pathlib import Path + +import aiosqlite +import pytest + +from fastapi_oidc_op.store.sqlite.migrations import run_migrations + +MIGRATIONS_DIR = ( + Path(__file__).resolve().parent.parent.parent / "src" / "fastapi_oidc_op" / "store" / "sqlite" / "migrations" +) + + +@pytest.fixture +async def db(): + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + await conn.execute("PRAGMA foreign_keys=ON") + await run_migrations(conn, MIGRATIONS_DIR) + yield conn + await conn.close() diff --git a/tests/test_store/test_sqlite_user_repo.py b/tests/test_store/test_sqlite_user_repo.py new file mode 100644 index 0000000..f496d60 --- /dev/null +++ b/tests/test_store/test_sqlite_user_repo.py @@ -0,0 +1,194 @@ +import aiosqlite +import pytest + +from fastapi_oidc_op.models import User +from fastapi_oidc_op.store.exceptions import DuplicateError +from fastapi_oidc_op.store.protocols import UserRepository +from fastapi_oidc_op.store.sqlite.repositories import SQLiteUserRepository + + +@pytest.fixture +def user_repo(db: aiosqlite.Connection) -> SQLiteUserRepository: + return SQLiteUserRepository(db) + + +def _make_user(**overrides) -> User: + defaults = {"userid": "lusab-bansen", "username": "alice"} + defaults.update(overrides) + return User(**defaults) + + +async def test_implements_protocol(user_repo: SQLiteUserRepository) -> None: + assert isinstance(user_repo, UserRepository) + + +async def test_create_and_get_by_userid(user_repo: SQLiteUserRepository) -> None: + user = _make_user() + created = await user_repo.create(user) + assert created.userid == "lusab-bansen" + assert created.username == "alice" + + fetched = await user_repo.get_by_userid("lusab-bansen") + assert fetched is not None + assert fetched.userid == "lusab-bansen" + assert fetched.username == "alice" + + +async def test_get_by_username(user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.create(user) + + fetched = await user_repo.get_by_username("alice") + assert fetched is not None + assert fetched.username == "alice" + + +async def test_get_by_userid_not_found(user_repo: SQLiteUserRepository) -> None: + result = await user_repo.get_by_userid("nonexistent") + assert result is None + + +async def test_get_by_username_not_found(user_repo: SQLiteUserRepository) -> None: + result = await user_repo.get_by_username("nonexistent") + assert result is None + + +async def test_create_with_groups(user_repo: SQLiteUserRepository) -> None: + user = _make_user(groups=["admin", "users"]) + await user_repo.create(user) + + fetched = await user_repo.get_by_userid("lusab-bansen") + assert fetched is not None + assert sorted(fetched.groups) == ["admin", "users"] + + +async def test_update(user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.create(user) + + user.email = "alice@example.com" + user.given_name = "Alice" + updated = await user_repo.update(user) + assert updated.email == "alice@example.com" + assert updated.given_name == "Alice" + + fetched = await user_repo.get_by_userid("lusab-bansen") + assert fetched is not None + assert fetched.email == "alice@example.com" + + +async def test_update_not_found(user_repo: SQLiteUserRepository) -> None: + user = _make_user() + with pytest.raises(ValueError, match="User not found"): + await user_repo.update(user) + + +async def test_update_does_not_mutate_input(user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.create(user) + original_updated_at = user.updated_at + + await user_repo.update(user) + assert user.updated_at == original_updated_at + + +async def test_update_groups(user_repo: SQLiteUserRepository) -> None: + user = _make_user(groups=["users"]) + await user_repo.create(user) + + user.groups = ["admin", "editors"] + await user_repo.update(user) + + fetched = await user_repo.get_by_userid("lusab-bansen") + assert fetched is not None + assert sorted(fetched.groups) == ["admin", "editors"] + + +async def test_list_users(user_repo: SQLiteUserRepository) -> None: + await user_repo.create(_make_user(userid="id-1", username="alice")) + await user_repo.create(_make_user(userid="id-2", username="bob")) + await user_repo.create(_make_user(userid="id-3", username="charlie")) + + users = await user_repo.list_users() + assert len(users) == 3 + + +async def test_list_users_pagination(user_repo: SQLiteUserRepository) -> None: + for i in range(5): + await user_repo.create(_make_user(userid=f"id-{i}", username=f"user-{i}")) + + page1 = await user_repo.list_users(offset=0, limit=2) + page2 = await user_repo.list_users(offset=2, limit=2) + page3 = await user_repo.list_users(offset=4, limit=2) + assert len(page1) == 2 + assert len(page2) == 2 + assert len(page3) == 1 + + +async def test_delete(user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.create(user) + + deleted = await user_repo.delete("lusab-bansen") + assert deleted is True + + fetched = await user_repo.get_by_userid("lusab-bansen") + assert fetched is None + + +async def test_delete_not_found(user_repo: SQLiteUserRepository) -> None: + deleted = await user_repo.delete("nonexistent") + assert deleted is False + + +async def test_delete_cascades_groups(user_repo: SQLiteUserRepository) -> None: + user = _make_user(groups=["admin"]) + await user_repo.create(user) + + await user_repo.delete("lusab-bansen") + + async with user_repo._db.execute("SELECT COUNT(*) FROM user_groups WHERE userid = ?", ("lusab-bansen",)) as cursor: + row = await cursor.fetchone() + assert row[0] == 0 + + +async def test_create_duplicate_username(user_repo: SQLiteUserRepository) -> None: + await user_repo.create(_make_user()) + + with pytest.raises(DuplicateError): + await user_repo.create(_make_user(userid="different-id", username="alice")) + + +async def test_roundtrip_preserves_all_fields(user_repo: SQLiteUserRepository) -> None: + user = _make_user( + preferred_username="ally", + given_name="Alice", + family_name="Smith", + nickname="Al", + email="alice@example.com", + email_verified=True, + phone_number="+1234567890", + phone_number_verified=True, + picture="https://example.com/alice.jpg", + locale="en-US", + active=False, + groups=["admin", "users"], + ) + await user_repo.create(user) + + fetched = await user_repo.get_by_userid("lusab-bansen") + assert fetched is not None + assert fetched.preferred_username == "ally" + assert fetched.given_name == "Alice" + assert fetched.family_name == "Smith" + assert fetched.nickname == "Al" + assert fetched.email == "alice@example.com" + assert fetched.email_verified is True + assert fetched.phone_number == "+1234567890" + assert fetched.phone_number_verified is True + assert fetched.picture == "https://example.com/alice.jpg" + assert fetched.locale == "en-US" + assert fetched.active is False + assert sorted(fetched.groups) == ["admin", "users"] + assert fetched.created_at == user.created_at + assert fetched.updated_at == user.updated_at