feat: add SQLiteUserRepository with tests
This commit is contained in:
parent
6c4ba79eed
commit
d941209f1e
3 changed files with 356 additions and 0 deletions
142
src/fastapi_oidc_op/store/sqlite/repositories.py
Normal file
142
src/fastapi_oidc_op/store/sqlite/repositories.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from fastapi_oidc_op.models import User
|
||||
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||
|
||||
|
||||
class SQLiteUserRepository:
|
||||
def __init__(self, db: aiosqlite.Connection) -> None:
|
||||
self._db = db
|
||||
|
||||
async def _get_groups(self, userid: str) -> list[str]:
|
||||
async with self._db.execute(
|
||||
"SELECT group_name FROM user_groups WHERE userid = ? ORDER BY group_name", (userid,)
|
||||
) as cursor:
|
||||
return [row[0] async for row in cursor]
|
||||
|
||||
async def _set_groups(self, userid: str, groups: list[str]) -> None:
|
||||
await self._db.execute("DELETE FROM user_groups WHERE userid = ?", (userid,))
|
||||
for group in groups:
|
||||
await self._db.execute("INSERT INTO user_groups (userid, group_name) VALUES (?, ?)", (userid, group))
|
||||
|
||||
def _row_to_user(self, row: aiosqlite.Row, groups: list[str]) -> User:
|
||||
return User(
|
||||
userid=row["userid"],
|
||||
username=row["username"],
|
||||
preferred_username=row["preferred_username"],
|
||||
given_name=row["given_name"],
|
||||
family_name=row["family_name"],
|
||||
nickname=row["nickname"],
|
||||
email=row["email"],
|
||||
email_verified=bool(row["email_verified"]),
|
||||
phone_number=row["phone_number"],
|
||||
phone_number_verified=bool(row["phone_number_verified"]),
|
||||
picture=row["picture"],
|
||||
locale=row["locale"],
|
||||
active=bool(row["active"]),
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
updated_at=datetime.fromisoformat(row["updated_at"]),
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
async def create(self, user: User) -> User:
|
||||
try:
|
||||
await self._db.execute(
|
||||
"""
|
||||
INSERT INTO users (
|
||||
userid, username, preferred_username, given_name, family_name,
|
||||
nickname, email, email_verified, phone_number, phone_number_verified,
|
||||
picture, locale, active, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
user.userid,
|
||||
user.username,
|
||||
user.preferred_username,
|
||||
user.given_name,
|
||||
user.family_name,
|
||||
user.nickname,
|
||||
user.email,
|
||||
int(user.email_verified),
|
||||
user.phone_number,
|
||||
int(user.phone_number_verified),
|
||||
user.picture,
|
||||
user.locale,
|
||||
int(user.active),
|
||||
user.created_at.isoformat(),
|
||||
user.updated_at.isoformat(),
|
||||
),
|
||||
)
|
||||
await self._set_groups(user.userid, user.groups)
|
||||
await self._db.commit()
|
||||
except aiosqlite.IntegrityError as e:
|
||||
raise DuplicateError(str(e)) from e
|
||||
return user
|
||||
|
||||
async def get_by_userid(self, userid: str) -> User | None:
|
||||
async with self._db.execute("SELECT * FROM users WHERE userid = ?", (userid,)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
groups = await self._get_groups(userid)
|
||||
return self._row_to_user(row, groups)
|
||||
|
||||
async def get_by_username(self, username: str) -> User | None:
|
||||
async with self._db.execute("SELECT * FROM users WHERE username = ?", (username,)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
groups = await self._get_groups(row["userid"])
|
||||
return self._row_to_user(row, groups)
|
||||
|
||||
async def update(self, user: User) -> User:
|
||||
updated = user.model_copy(update={"updated_at": datetime.now(UTC)})
|
||||
cursor = await self._db.execute(
|
||||
"""
|
||||
UPDATE users SET
|
||||
username = ?, preferred_username = ?, given_name = ?, family_name = ?,
|
||||
nickname = ?, email = ?, email_verified = ?, phone_number = ?,
|
||||
phone_number_verified = ?, picture = ?, locale = ?, active = ?,
|
||||
updated_at = ?
|
||||
WHERE userid = ?
|
||||
""",
|
||||
(
|
||||
updated.username,
|
||||
updated.preferred_username,
|
||||
updated.given_name,
|
||||
updated.family_name,
|
||||
updated.nickname,
|
||||
updated.email,
|
||||
int(updated.email_verified),
|
||||
updated.phone_number,
|
||||
int(updated.phone_number_verified),
|
||||
updated.picture,
|
||||
updated.locale,
|
||||
int(updated.active),
|
||||
updated.updated_at.isoformat(),
|
||||
updated.userid,
|
||||
),
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise ValueError(f"User not found: {user.userid}")
|
||||
await self._set_groups(updated.userid, updated.groups)
|
||||
await self._db.commit()
|
||||
return updated
|
||||
|
||||
async def list_users(self, offset: int = 0, limit: int = 100) -> list[User]:
|
||||
async with self._db.execute(
|
||||
"SELECT * FROM users ORDER BY username LIMIT ? OFFSET ?", (limit, offset)
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
users = []
|
||||
for row in rows:
|
||||
groups = await self._get_groups(row["userid"])
|
||||
users.append(self._row_to_user(row, groups))
|
||||
return users
|
||||
|
||||
async def delete(self, userid: str) -> bool:
|
||||
cursor = await self._db.execute("DELETE FROM users WHERE userid = ?", (userid,))
|
||||
await self._db.commit()
|
||||
return cursor.rowcount > 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue