diff --git a/src/porchlight/manage/routes.py b/src/porchlight/manage/routes.py index 69afdff..92a7912 100644 --- a/src/porchlight/manage/routes.py +++ b/src/porchlight/manage/routes.py @@ -13,13 +13,6 @@ from porchlight.validation import PasswordChange, PasswordSet, ProfileUpdate, fo router = APIRouter(prefix="/manage", tags=["manage"]) -async def _count_credentials(cred_repo: object, userid: str) -> int: - """Count total credentials (password + webauthn) for a user.""" - webauthn = await cred_repo.get_webauthn_by_user(userid) # type: ignore[union-attr] - password = await cred_repo.get_password_by_user(userid) # type: ignore[union-attr] - return len(webauthn) + (1 if password else 0) - - @router.get("/credentials", response_class=HTMLResponse) async def credentials_page(request: Request) -> Response: session_user = get_session_user(request) @@ -107,11 +100,9 @@ async def delete_password(request: Request) -> Response: userid, _username = session_user cred_repo = request.app.state.credential_repo - count = await _count_credentials(cred_repo, userid) - if count <= 1: + # Atomic: refuses if this is the user's last credential (not raceable). + if not await cred_repo.delete_password_if_not_last(userid): return HTMLResponse('
Cannot remove your last credential
') - - await cred_repo.delete_password(userid) return HTMLResponse('
Password removed
') @@ -182,11 +173,9 @@ async def delete_webauthn(request: Request, credential_id_b64: str) -> Response: padded = credential_id_b64 + "=" * (-len(credential_id_b64) % 4) credential_id = urlsafe_b64decode(padded) - count = await _count_credentials(cred_repo, userid) - if count <= 1: + # Atomic: refuses if this is the user's last credential (not raceable). + if not await cred_repo.delete_webauthn_if_not_last(userid, credential_id): return HTMLResponse('
Cannot remove your last credential
') - - await cred_repo.delete_webauthn(userid, credential_id) return HTMLResponse('
Security key removed
') diff --git a/src/porchlight/store/protocols.py b/src/porchlight/store/protocols.py index 034d6d9..16d7c1a 100644 --- a/src/porchlight/store/protocols.py +++ b/src/porchlight/store/protocols.py @@ -46,6 +46,10 @@ class CredentialRepository(Protocol): async def delete_password(self, user_id: str) -> bool: ... + async def delete_password_if_not_last(self, user_id: str) -> bool: ... + + async def delete_webauthn_if_not_last(self, user_id: str, credential_id: bytes) -> bool: ... + @runtime_checkable class MagicLinkRepository(Protocol): diff --git a/src/porchlight/store/sqlite/repositories.py b/src/porchlight/store/sqlite/repositories.py index 6acc941..41799d2 100644 --- a/src/porchlight/store/sqlite/repositories.py +++ b/src/porchlight/store/sqlite/repositories.py @@ -266,6 +266,41 @@ class SQLiteCredentialRepository: await self._db.commit() return cursor.rowcount > 0 + async def delete_password_if_not_last(self, user_id: str) -> bool: + """Delete the password credential only if it is not the user's last + credential. The count and delete happen in one atomic statement, so it + is not raceable. Returns True if a row was deleted.""" + cursor = await self._db.execute( + """ + DELETE FROM password_credentials + WHERE user_id = ? + AND ( + (SELECT COUNT(*) FROM password_credentials WHERE user_id = ?) + + (SELECT COUNT(*) FROM webauthn_credentials WHERE user_id = ?) + ) > 1 + """, + (user_id, user_id, user_id), + ) + await self._db.commit() + return cursor.rowcount > 0 + + async def delete_webauthn_if_not_last(self, user_id: str, credential_id: bytes) -> bool: + """Delete a WebAuthn credential only if it is not the user's last + credential, atomically. Returns True if a row was deleted.""" + cursor = await self._db.execute( + """ + DELETE FROM webauthn_credentials + WHERE user_id = ? AND credential_id = ? + AND ( + (SELECT COUNT(*) FROM password_credentials WHERE user_id = ?) + + (SELECT COUNT(*) FROM webauthn_credentials WHERE user_id = ?) + ) > 1 + """, + (user_id, credential_id, user_id, user_id), + ) + await self._db.commit() + return cursor.rowcount > 0 + class SQLiteMagicLinkRepository: def __init__(self, db: aiosqlite.Connection) -> None: diff --git a/tests/test_store/test_sqlite_credential_repo.py b/tests/test_store/test_sqlite_credential_repo.py index 2c5d33d..1c3396a 100644 --- a/tests/test_store/test_sqlite_credential_repo.py +++ b/tests/test_store/test_sqlite_credential_repo.py @@ -182,3 +182,48 @@ async def test_create_duplicate_webauthn(credential_repo: SQLiteCredentialReposi with pytest.raises(DuplicateError): await credential_repo.create_webauthn(cred) + + +async def test_delete_password_if_not_last_refuses_last( + credential_repo: SQLiteCredentialRepository, alice: User +) -> None: + await credential_repo.create_password(PasswordCredential(user_id=alice.userid, password_hash="h")) + + # Only credential -> must refuse and leave it in place. + assert await credential_repo.delete_password_if_not_last(alice.userid) is False + assert await credential_repo.get_password_by_user(alice.userid) is not None + + +async def test_delete_password_if_not_last_allows_when_others_exist( + credential_repo: SQLiteCredentialRepository, alice: User +) -> None: + await credential_repo.create_password(PasswordCredential(user_id=alice.userid, password_hash="h")) + await credential_repo.create_webauthn( + WebAuthnCredential(user_id=alice.userid, credential_id=b"\x01", public_key=b"\x02") + ) + + assert await credential_repo.delete_password_if_not_last(alice.userid) is True + assert await credential_repo.get_password_by_user(alice.userid) is None + + +async def test_delete_webauthn_if_not_last_refuses_last( + credential_repo: SQLiteCredentialRepository, alice: User +) -> None: + await credential_repo.create_webauthn( + WebAuthnCredential(user_id=alice.userid, credential_id=b"\x01", public_key=b"\x02") + ) + + assert await credential_repo.delete_webauthn_if_not_last(alice.userid, b"\x01") is False + assert len(await credential_repo.get_webauthn_by_user(alice.userid)) == 1 + + +async def test_delete_webauthn_if_not_last_allows_when_others_exist( + credential_repo: SQLiteCredentialRepository, alice: User +) -> None: + await credential_repo.create_webauthn( + WebAuthnCredential(user_id=alice.userid, credential_id=b"\x01", public_key=b"\x02") + ) + await credential_repo.create_password(PasswordCredential(user_id=alice.userid, password_hash="h")) + + assert await credential_repo.delete_webauthn_if_not_last(alice.userid, b"\x01") is True + assert len(await credential_repo.get_webauthn_by_user(alice.userid)) == 0