feat: add SQLiteUserRepository with tests
This commit is contained in:
parent
6c4ba79eed
commit
d941209f1e
3 changed files with 356 additions and 0 deletions
142
src/fastapi_oidc_op/store/sqlite/repositories.py
Normal file
142
src/fastapi_oidc_op/store/sqlite/repositories.py
Normal file
|
|
@ -0,0 +1,142 @@
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from fastapi_oidc_op.models import User
|
||||||
|
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteUserRepository:
|
||||||
|
def __init__(self, db: aiosqlite.Connection) -> None:
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
async def _get_groups(self, userid: str) -> list[str]:
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT group_name FROM user_groups WHERE userid = ? ORDER BY group_name", (userid,)
|
||||||
|
) as cursor:
|
||||||
|
return [row[0] async for row in cursor]
|
||||||
|
|
||||||
|
async def _set_groups(self, userid: str, groups: list[str]) -> None:
|
||||||
|
await self._db.execute("DELETE FROM user_groups WHERE userid = ?", (userid,))
|
||||||
|
for group in groups:
|
||||||
|
await self._db.execute("INSERT INTO user_groups (userid, group_name) VALUES (?, ?)", (userid, group))
|
||||||
|
|
||||||
|
def _row_to_user(self, row: aiosqlite.Row, groups: list[str]) -> User:
|
||||||
|
return User(
|
||||||
|
userid=row["userid"],
|
||||||
|
username=row["username"],
|
||||||
|
preferred_username=row["preferred_username"],
|
||||||
|
given_name=row["given_name"],
|
||||||
|
family_name=row["family_name"],
|
||||||
|
nickname=row["nickname"],
|
||||||
|
email=row["email"],
|
||||||
|
email_verified=bool(row["email_verified"]),
|
||||||
|
phone_number=row["phone_number"],
|
||||||
|
phone_number_verified=bool(row["phone_number_verified"]),
|
||||||
|
picture=row["picture"],
|
||||||
|
locale=row["locale"],
|
||||||
|
active=bool(row["active"]),
|
||||||
|
created_at=datetime.fromisoformat(row["created_at"]),
|
||||||
|
updated_at=datetime.fromisoformat(row["updated_at"]),
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create(self, user: User) -> User:
|
||||||
|
try:
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO users (
|
||||||
|
userid, username, preferred_username, given_name, family_name,
|
||||||
|
nickname, email, email_verified, phone_number, phone_number_verified,
|
||||||
|
picture, locale, active, created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
user.userid,
|
||||||
|
user.username,
|
||||||
|
user.preferred_username,
|
||||||
|
user.given_name,
|
||||||
|
user.family_name,
|
||||||
|
user.nickname,
|
||||||
|
user.email,
|
||||||
|
int(user.email_verified),
|
||||||
|
user.phone_number,
|
||||||
|
int(user.phone_number_verified),
|
||||||
|
user.picture,
|
||||||
|
user.locale,
|
||||||
|
int(user.active),
|
||||||
|
user.created_at.isoformat(),
|
||||||
|
user.updated_at.isoformat(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await self._set_groups(user.userid, user.groups)
|
||||||
|
await self._db.commit()
|
||||||
|
except aiosqlite.IntegrityError as e:
|
||||||
|
raise DuplicateError(str(e)) from e
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_by_userid(self, userid: str) -> User | None:
|
||||||
|
async with self._db.execute("SELECT * FROM users WHERE userid = ?", (userid,)) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
groups = await self._get_groups(userid)
|
||||||
|
return self._row_to_user(row, groups)
|
||||||
|
|
||||||
|
async def get_by_username(self, username: str) -> User | None:
|
||||||
|
async with self._db.execute("SELECT * FROM users WHERE username = ?", (username,)) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
groups = await self._get_groups(row["userid"])
|
||||||
|
return self._row_to_user(row, groups)
|
||||||
|
|
||||||
|
async def update(self, user: User) -> User:
|
||||||
|
updated = user.model_copy(update={"updated_at": datetime.now(UTC)})
|
||||||
|
cursor = await self._db.execute(
|
||||||
|
"""
|
||||||
|
UPDATE users SET
|
||||||
|
username = ?, preferred_username = ?, given_name = ?, family_name = ?,
|
||||||
|
nickname = ?, email = ?, email_verified = ?, phone_number = ?,
|
||||||
|
phone_number_verified = ?, picture = ?, locale = ?, active = ?,
|
||||||
|
updated_at = ?
|
||||||
|
WHERE userid = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
updated.username,
|
||||||
|
updated.preferred_username,
|
||||||
|
updated.given_name,
|
||||||
|
updated.family_name,
|
||||||
|
updated.nickname,
|
||||||
|
updated.email,
|
||||||
|
int(updated.email_verified),
|
||||||
|
updated.phone_number,
|
||||||
|
int(updated.phone_number_verified),
|
||||||
|
updated.picture,
|
||||||
|
updated.locale,
|
||||||
|
int(updated.active),
|
||||||
|
updated.updated_at.isoformat(),
|
||||||
|
updated.userid,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if cursor.rowcount == 0:
|
||||||
|
raise ValueError(f"User not found: {user.userid}")
|
||||||
|
await self._set_groups(updated.userid, updated.groups)
|
||||||
|
await self._db.commit()
|
||||||
|
return updated
|
||||||
|
|
||||||
|
async def list_users(self, offset: int = 0, limit: int = 100) -> list[User]:
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT * FROM users ORDER BY username LIMIT ? OFFSET ?", (limit, offset)
|
||||||
|
) as cursor:
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
users = []
|
||||||
|
for row in rows:
|
||||||
|
groups = await self._get_groups(row["userid"])
|
||||||
|
users.append(self._row_to_user(row, groups))
|
||||||
|
return users
|
||||||
|
|
||||||
|
async def delete(self, userid: str) -> bool:
|
||||||
|
cursor = await self._db.execute("DELETE FROM users WHERE userid = ?", (userid,))
|
||||||
|
await self._db.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
20
tests/test_store/conftest.py
Normal file
20
tests/test_store/conftest.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_oidc_op.store.sqlite.migrations import run_migrations
|
||||||
|
|
||||||
|
MIGRATIONS_DIR = (
|
||||||
|
Path(__file__).resolve().parent.parent.parent / "src" / "fastapi_oidc_op" / "store" / "sqlite" / "migrations"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db():
|
||||||
|
conn = await aiosqlite.connect(":memory:")
|
||||||
|
conn.row_factory = aiosqlite.Row
|
||||||
|
await conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
await run_migrations(conn, MIGRATIONS_DIR)
|
||||||
|
yield conn
|
||||||
|
await conn.close()
|
||||||
194
tests/test_store/test_sqlite_user_repo.py
Normal file
194
tests/test_store/test_sqlite_user_repo.py
Normal file
|
|
@ -0,0 +1,194 @@
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_oidc_op.models import User
|
||||||
|
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||||
|
from fastapi_oidc_op.store.protocols import UserRepository
|
||||||
|
from fastapi_oidc_op.store.sqlite.repositories import SQLiteUserRepository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_repo(db: aiosqlite.Connection) -> SQLiteUserRepository:
|
||||||
|
return SQLiteUserRepository(db)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_user(**overrides) -> User:
|
||||||
|
defaults = {"userid": "lusab-bansen", "username": "alice"}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return User(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_implements_protocol(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
assert isinstance(user_repo, UserRepository)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_and_get_by_userid(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user()
|
||||||
|
created = await user_repo.create(user)
|
||||||
|
assert created.userid == "lusab-bansen"
|
||||||
|
assert created.username == "alice"
|
||||||
|
|
||||||
|
fetched = await user_repo.get_by_userid("lusab-bansen")
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.userid == "lusab-bansen"
|
||||||
|
assert fetched.username == "alice"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_by_username(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user()
|
||||||
|
await user_repo.create(user)
|
||||||
|
|
||||||
|
fetched = await user_repo.get_by_username("alice")
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.username == "alice"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_by_userid_not_found(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
result = await user_repo.get_by_userid("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_by_username_not_found(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
result = await user_repo.get_by_username("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_with_groups(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user(groups=["admin", "users"])
|
||||||
|
await user_repo.create(user)
|
||||||
|
|
||||||
|
fetched = await user_repo.get_by_userid("lusab-bansen")
|
||||||
|
assert fetched is not None
|
||||||
|
assert sorted(fetched.groups) == ["admin", "users"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user()
|
||||||
|
await user_repo.create(user)
|
||||||
|
|
||||||
|
user.email = "alice@example.com"
|
||||||
|
user.given_name = "Alice"
|
||||||
|
updated = await user_repo.update(user)
|
||||||
|
assert updated.email == "alice@example.com"
|
||||||
|
assert updated.given_name == "Alice"
|
||||||
|
|
||||||
|
fetched = await user_repo.get_by_userid("lusab-bansen")
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.email == "alice@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_not_found(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user()
|
||||||
|
with pytest.raises(ValueError, match="User not found"):
|
||||||
|
await user_repo.update(user)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_does_not_mutate_input(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user()
|
||||||
|
await user_repo.create(user)
|
||||||
|
original_updated_at = user.updated_at
|
||||||
|
|
||||||
|
await user_repo.update(user)
|
||||||
|
assert user.updated_at == original_updated_at
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_groups(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user(groups=["users"])
|
||||||
|
await user_repo.create(user)
|
||||||
|
|
||||||
|
user.groups = ["admin", "editors"]
|
||||||
|
await user_repo.update(user)
|
||||||
|
|
||||||
|
fetched = await user_repo.get_by_userid("lusab-bansen")
|
||||||
|
assert fetched is not None
|
||||||
|
assert sorted(fetched.groups) == ["admin", "editors"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_users(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
await user_repo.create(_make_user(userid="id-1", username="alice"))
|
||||||
|
await user_repo.create(_make_user(userid="id-2", username="bob"))
|
||||||
|
await user_repo.create(_make_user(userid="id-3", username="charlie"))
|
||||||
|
|
||||||
|
users = await user_repo.list_users()
|
||||||
|
assert len(users) == 3
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_users_pagination(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
for i in range(5):
|
||||||
|
await user_repo.create(_make_user(userid=f"id-{i}", username=f"user-{i}"))
|
||||||
|
|
||||||
|
page1 = await user_repo.list_users(offset=0, limit=2)
|
||||||
|
page2 = await user_repo.list_users(offset=2, limit=2)
|
||||||
|
page3 = await user_repo.list_users(offset=4, limit=2)
|
||||||
|
assert len(page1) == 2
|
||||||
|
assert len(page2) == 2
|
||||||
|
assert len(page3) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user()
|
||||||
|
await user_repo.create(user)
|
||||||
|
|
||||||
|
deleted = await user_repo.delete("lusab-bansen")
|
||||||
|
assert deleted is True
|
||||||
|
|
||||||
|
fetched = await user_repo.get_by_userid("lusab-bansen")
|
||||||
|
assert fetched is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_not_found(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
deleted = await user_repo.delete("nonexistent")
|
||||||
|
assert deleted is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_cascades_groups(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user(groups=["admin"])
|
||||||
|
await user_repo.create(user)
|
||||||
|
|
||||||
|
await user_repo.delete("lusab-bansen")
|
||||||
|
|
||||||
|
async with user_repo._db.execute("SELECT COUNT(*) FROM user_groups WHERE userid = ?", ("lusab-bansen",)) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row[0] == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_duplicate_username(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
await user_repo.create(_make_user())
|
||||||
|
|
||||||
|
with pytest.raises(DuplicateError):
|
||||||
|
await user_repo.create(_make_user(userid="different-id", username="alice"))
|
||||||
|
|
||||||
|
|
||||||
|
async def test_roundtrip_preserves_all_fields(user_repo: SQLiteUserRepository) -> None:
|
||||||
|
user = _make_user(
|
||||||
|
preferred_username="ally",
|
||||||
|
given_name="Alice",
|
||||||
|
family_name="Smith",
|
||||||
|
nickname="Al",
|
||||||
|
email="alice@example.com",
|
||||||
|
email_verified=True,
|
||||||
|
phone_number="+1234567890",
|
||||||
|
phone_number_verified=True,
|
||||||
|
picture="https://example.com/alice.jpg",
|
||||||
|
locale="en-US",
|
||||||
|
active=False,
|
||||||
|
groups=["admin", "users"],
|
||||||
|
)
|
||||||
|
await user_repo.create(user)
|
||||||
|
|
||||||
|
fetched = await user_repo.get_by_userid("lusab-bansen")
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.preferred_username == "ally"
|
||||||
|
assert fetched.given_name == "Alice"
|
||||||
|
assert fetched.family_name == "Smith"
|
||||||
|
assert fetched.nickname == "Al"
|
||||||
|
assert fetched.email == "alice@example.com"
|
||||||
|
assert fetched.email_verified is True
|
||||||
|
assert fetched.phone_number == "+1234567890"
|
||||||
|
assert fetched.phone_number_verified is True
|
||||||
|
assert fetched.picture == "https://example.com/alice.jpg"
|
||||||
|
assert fetched.locale == "en-US"
|
||||||
|
assert fetched.active is False
|
||||||
|
assert sorted(fetched.groups) == ["admin", "users"]
|
||||||
|
assert fetched.created_at == user.created_at
|
||||||
|
assert fetched.updated_at == user.updated_at
|
||||||
Loading…
Add table
Add a link
Reference in a new issue