feat: add SQLiteCredentialRepository with tests
This commit is contained in:
parent
d941209f1e
commit
bbe0dac8cb
2 changed files with 283 additions and 1 deletions
|
|
@ -2,7 +2,7 @@ from datetime import UTC, datetime
|
|||
|
||||
import aiosqlite
|
||||
|
||||
from fastapi_oidc_op.models import User
|
||||
from fastapi_oidc_op.models import PasswordCredential, User, WebAuthnCredential
|
||||
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||
|
||||
|
||||
|
|
@ -140,3 +140,101 @@ class SQLiteUserRepository:
|
|||
cursor = await self._db.execute("DELETE FROM users WHERE userid = ?", (userid,))
|
||||
await self._db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
|
||||
class SQLiteCredentialRepository:
|
||||
def __init__(self, db: aiosqlite.Connection) -> None:
|
||||
self._db = db
|
||||
|
||||
def _row_to_webauthn(self, row: aiosqlite.Row) -> WebAuthnCredential:
|
||||
return WebAuthnCredential(
|
||||
user_id=row["user_id"],
|
||||
credential_id=bytes(row["credential_id"]),
|
||||
public_key=bytes(row["public_key"]),
|
||||
sign_count=row["sign_count"],
|
||||
device_name=row["device_name"],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
)
|
||||
|
||||
def _row_to_password(self, row: aiosqlite.Row) -> PasswordCredential:
|
||||
return PasswordCredential(
|
||||
user_id=row["user_id"],
|
||||
password_hash=row["password_hash"],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
)
|
||||
|
||||
async def create_webauthn(self, credential: WebAuthnCredential) -> WebAuthnCredential:
|
||||
try:
|
||||
await self._db.execute(
|
||||
"""
|
||||
INSERT INTO webauthn_credentials (user_id, credential_id, public_key, sign_count, device_name, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
credential.user_id,
|
||||
credential.credential_id,
|
||||
credential.public_key,
|
||||
credential.sign_count,
|
||||
credential.device_name,
|
||||
credential.created_at.isoformat(),
|
||||
),
|
||||
)
|
||||
await self._db.commit()
|
||||
except aiosqlite.IntegrityError as e:
|
||||
raise DuplicateError(str(e)) from e
|
||||
return credential
|
||||
|
||||
async def create_password(self, credential: PasswordCredential) -> PasswordCredential:
|
||||
try:
|
||||
await self._db.execute(
|
||||
"INSERT INTO password_credentials (user_id, password_hash, created_at) VALUES (?, ?, ?)",
|
||||
(credential.user_id, credential.password_hash, credential.created_at.isoformat()),
|
||||
)
|
||||
await self._db.commit()
|
||||
except aiosqlite.IntegrityError as e:
|
||||
raise DuplicateError(str(e)) from e
|
||||
return credential
|
||||
|
||||
async def get_webauthn_by_user(self, user_id: str) -> list[WebAuthnCredential]:
|
||||
async with self._db.execute("SELECT * FROM webauthn_credentials WHERE user_id = ?", (user_id,)) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [self._row_to_webauthn(row) for row in rows]
|
||||
|
||||
async def get_webauthn_by_credential_id(self, credential_id: bytes) -> WebAuthnCredential | None:
|
||||
async with self._db.execute(
|
||||
"SELECT * FROM webauthn_credentials WHERE credential_id = ?", (credential_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_webauthn(row)
|
||||
|
||||
async def get_password_by_user(self, user_id: str) -> PasswordCredential | None:
|
||||
async with self._db.execute("SELECT * FROM password_credentials WHERE user_id = ?", (user_id,)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_password(row)
|
||||
|
||||
async def update_webauthn(self, credential: WebAuthnCredential) -> WebAuthnCredential:
|
||||
cursor = await self._db.execute(
|
||||
"UPDATE webauthn_credentials SET sign_count = ?, device_name = ? WHERE user_id = ? AND credential_id = ?",
|
||||
(credential.sign_count, credential.device_name, credential.user_id, credential.credential_id),
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise ValueError(f"WebAuthn credential not found for user: {credential.user_id}")
|
||||
await self._db.commit()
|
||||
return credential
|
||||
|
||||
async def delete_webauthn(self, user_id: str, credential_id: bytes) -> bool:
|
||||
cursor = await self._db.execute(
|
||||
"DELETE FROM webauthn_credentials WHERE user_id = ? AND credential_id = ?",
|
||||
(user_id, credential_id),
|
||||
)
|
||||
await self._db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
async def delete_password(self, user_id: str) -> bool:
|
||||
cursor = await self._db.execute("DELETE FROM password_credentials WHERE user_id = ?", (user_id,))
|
||||
await self._db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
|
|
|||
184
tests/test_store/test_sqlite_credential_repo.py
Normal file
184
tests/test_store/test_sqlite_credential_repo.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from fastapi_oidc_op.models import PasswordCredential, User, WebAuthnCredential
|
||||
from fastapi_oidc_op.store.exceptions import DuplicateError
|
||||
from fastapi_oidc_op.store.protocols import CredentialRepository
|
||||
from fastapi_oidc_op.store.sqlite.repositories import (
|
||||
SQLiteCredentialRepository,
|
||||
SQLiteUserRepository,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_repo(db: aiosqlite.Connection) -> SQLiteUserRepository:
|
||||
return SQLiteUserRepository(db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def credential_repo(db: aiosqlite.Connection) -> SQLiteCredentialRepository:
|
||||
return SQLiteCredentialRepository(db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def alice(user_repo: SQLiteUserRepository) -> User:
|
||||
return await user_repo.create(User(userid="lusab-bansen", username="alice"))
|
||||
|
||||
|
||||
async def test_implements_protocol(credential_repo: SQLiteCredentialRepository) -> None:
|
||||
assert isinstance(credential_repo, CredentialRepository)
|
||||
|
||||
|
||||
async def test_create_and_get_webauthn(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = WebAuthnCredential(
|
||||
user_id=alice.userid,
|
||||
credential_id=b"\x01\x02\x03",
|
||||
public_key=b"\x04\x05\x06",
|
||||
device_name="YubiKey",
|
||||
)
|
||||
created = await credential_repo.create_webauthn(cred)
|
||||
assert created.user_id == alice.userid
|
||||
|
||||
creds = await credential_repo.get_webauthn_by_user(alice.userid)
|
||||
assert len(creds) == 1
|
||||
assert creds[0].credential_id == b"\x01\x02\x03"
|
||||
assert creds[0].public_key == b"\x04\x05\x06"
|
||||
assert creds[0].device_name == "YubiKey"
|
||||
|
||||
|
||||
async def test_get_webauthn_by_credential_id(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = WebAuthnCredential(
|
||||
user_id=alice.userid,
|
||||
credential_id=b"\x01\x02\x03",
|
||||
public_key=b"\x04\x05\x06",
|
||||
)
|
||||
await credential_repo.create_webauthn(cred)
|
||||
|
||||
fetched = await credential_repo.get_webauthn_by_credential_id(b"\x01\x02\x03")
|
||||
assert fetched is not None
|
||||
assert fetched.user_id == alice.userid
|
||||
|
||||
|
||||
async def test_get_webauthn_by_credential_id_not_found(credential_repo: SQLiteCredentialRepository) -> None:
|
||||
result = await credential_repo.get_webauthn_by_credential_id(b"\xff\xff")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_multiple_webauthn_per_user(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
for i in range(3):
|
||||
cred = WebAuthnCredential(
|
||||
user_id=alice.userid,
|
||||
credential_id=bytes([i]),
|
||||
public_key=b"\x00",
|
||||
device_name=f"Key {i}",
|
||||
)
|
||||
await credential_repo.create_webauthn(cred)
|
||||
|
||||
creds = await credential_repo.get_webauthn_by_user(alice.userid)
|
||||
assert len(creds) == 3
|
||||
|
||||
|
||||
async def test_update_webauthn(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = WebAuthnCredential(
|
||||
user_id=alice.userid,
|
||||
credential_id=b"\x01\x02\x03",
|
||||
public_key=b"\x04\x05\x06",
|
||||
sign_count=0,
|
||||
device_name="Old Name",
|
||||
)
|
||||
await credential_repo.create_webauthn(cred)
|
||||
|
||||
cred.sign_count = 42
|
||||
cred.device_name = "New Name"
|
||||
updated = await credential_repo.update_webauthn(cred)
|
||||
assert updated.sign_count == 42
|
||||
assert updated.device_name == "New Name"
|
||||
|
||||
fetched = await credential_repo.get_webauthn_by_credential_id(b"\x01\x02\x03")
|
||||
assert fetched is not None
|
||||
assert fetched.sign_count == 42
|
||||
assert fetched.device_name == "New Name"
|
||||
|
||||
|
||||
async def test_update_webauthn_not_found(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = WebAuthnCredential(
|
||||
user_id=alice.userid,
|
||||
credential_id=b"\x01\x02\x03",
|
||||
public_key=b"\x04\x05\x06",
|
||||
)
|
||||
with pytest.raises(ValueError, match="WebAuthn credential not found"):
|
||||
await credential_repo.update_webauthn(cred)
|
||||
|
||||
|
||||
async def test_delete_webauthn(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = WebAuthnCredential(
|
||||
user_id=alice.userid,
|
||||
credential_id=b"\x01\x02\x03",
|
||||
public_key=b"\x04\x05\x06",
|
||||
)
|
||||
await credential_repo.create_webauthn(cred)
|
||||
|
||||
deleted = await credential_repo.delete_webauthn(alice.userid, b"\x01\x02\x03")
|
||||
assert deleted is True
|
||||
|
||||
creds = await credential_repo.get_webauthn_by_user(alice.userid)
|
||||
assert len(creds) == 0
|
||||
|
||||
|
||||
async def test_delete_webauthn_not_found(credential_repo: SQLiteCredentialRepository) -> None:
|
||||
deleted = await credential_repo.delete_webauthn("nobody", b"\xff")
|
||||
assert deleted is False
|
||||
|
||||
|
||||
async def test_create_and_get_password(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = PasswordCredential(
|
||||
user_id=alice.userid,
|
||||
password_hash="$argon2id$v=19$m=65536,t=3,p=4$hash",
|
||||
)
|
||||
created = await credential_repo.create_password(cred)
|
||||
assert created.user_id == alice.userid
|
||||
|
||||
fetched = await credential_repo.get_password_by_user(alice.userid)
|
||||
assert fetched is not None
|
||||
assert fetched.password_hash == "$argon2id$v=19$m=65536,t=3,p=4$hash"
|
||||
|
||||
|
||||
async def test_get_password_not_found(credential_repo: SQLiteCredentialRepository) -> None:
|
||||
result = await credential_repo.get_password_by_user("nobody")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_delete_password(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = PasswordCredential(
|
||||
user_id=alice.userid,
|
||||
password_hash="$argon2id$v=19$hash",
|
||||
)
|
||||
await credential_repo.create_password(cred)
|
||||
|
||||
deleted = await credential_repo.delete_password(alice.userid)
|
||||
assert deleted is True
|
||||
|
||||
fetched = await credential_repo.get_password_by_user(alice.userid)
|
||||
assert fetched is None
|
||||
|
||||
|
||||
async def test_delete_password_not_found(credential_repo: SQLiteCredentialRepository) -> None:
|
||||
deleted = await credential_repo.delete_password("nobody")
|
||||
assert deleted is False
|
||||
|
||||
|
||||
async def test_create_duplicate_password(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = PasswordCredential(user_id=alice.userid, password_hash="hash1")
|
||||
await credential_repo.create_password(cred)
|
||||
|
||||
with pytest.raises(DuplicateError):
|
||||
cred2 = PasswordCredential(user_id=alice.userid, password_hash="hash2")
|
||||
await credential_repo.create_password(cred2)
|
||||
|
||||
|
||||
async def test_create_duplicate_webauthn(credential_repo: SQLiteCredentialRepository, alice: User) -> None:
|
||||
cred = WebAuthnCredential(user_id=alice.userid, credential_id=b"\x01", public_key=b"\x02")
|
||||
await credential_repo.create_webauthn(cred)
|
||||
|
||||
with pytest.raises(DuplicateError):
|
||||
await credential_repo.create_webauthn(cred)
|
||||
Loading…
Add table
Add a link
Reference in a new issue