diff --git a/src/porchlight/manage/routes.py b/src/porchlight/manage/routes.py
index 69afdff..92a7912 100644
--- a/src/porchlight/manage/routes.py
+++ b/src/porchlight/manage/routes.py
@@ -13,13 +13,6 @@ from porchlight.validation import PasswordChange, PasswordSet, ProfileUpdate, fo
router = APIRouter(prefix="/manage", tags=["manage"])
-async def _count_credentials(cred_repo: object, userid: str) -> int:
- """Count total credentials (password + webauthn) for a user."""
- webauthn = await cred_repo.get_webauthn_by_user(userid) # type: ignore[union-attr]
- password = await cred_repo.get_password_by_user(userid) # type: ignore[union-attr]
- return len(webauthn) + (1 if password else 0)
-
-
@router.get("/credentials", response_class=HTMLResponse)
async def credentials_page(request: Request) -> Response:
session_user = get_session_user(request)
@@ -107,11 +100,9 @@ async def delete_password(request: Request) -> Response:
userid, _username = session_user
cred_repo = request.app.state.credential_repo
- count = await _count_credentials(cred_repo, userid)
- if count <= 1:
+ # Atomic: refuses if this is the user's last credential (not raceable).
+ if not await cred_repo.delete_password_if_not_last(userid):
return HTMLResponse('
Cannot remove your last credential
')
-
- await cred_repo.delete_password(userid)
return HTMLResponse('Password removed
')
@@ -182,11 +173,9 @@ async def delete_webauthn(request: Request, credential_id_b64: str) -> Response:
padded = credential_id_b64 + "=" * (-len(credential_id_b64) % 4)
credential_id = urlsafe_b64decode(padded)
- count = await _count_credentials(cred_repo, userid)
- if count <= 1:
+ # Atomic: refuses if this is the user's last credential (not raceable).
+ if not await cred_repo.delete_webauthn_if_not_last(userid, credential_id):
return HTMLResponse('Cannot remove your last credential
')
-
- await cred_repo.delete_webauthn(userid, credential_id)
return HTMLResponse('Security key removed
')
diff --git a/src/porchlight/store/protocols.py b/src/porchlight/store/protocols.py
index 034d6d9..16d7c1a 100644
--- a/src/porchlight/store/protocols.py
+++ b/src/porchlight/store/protocols.py
@@ -46,6 +46,10 @@ class CredentialRepository(Protocol):
async def delete_password(self, user_id: str) -> bool: ...
+ async def delete_password_if_not_last(self, user_id: str) -> bool: ...
+
+ async def delete_webauthn_if_not_last(self, user_id: str, credential_id: bytes) -> bool: ...
+
@runtime_checkable
class MagicLinkRepository(Protocol):
diff --git a/src/porchlight/store/sqlite/repositories.py b/src/porchlight/store/sqlite/repositories.py
index 6acc941..41799d2 100644
--- a/src/porchlight/store/sqlite/repositories.py
+++ b/src/porchlight/store/sqlite/repositories.py
@@ -266,6 +266,41 @@ class SQLiteCredentialRepository:
await self._db.commit()
return cursor.rowcount > 0
+ async def delete_password_if_not_last(self, user_id: str) -> bool:
+ """Delete the password credential only if it is not the user's last
+ credential. The count and delete happen in one atomic statement, so it
+ is not raceable. Returns True if a row was deleted."""
+ cursor = await self._db.execute(
+ """
+ DELETE FROM password_credentials
+ WHERE user_id = ?
+ AND (
+ (SELECT COUNT(*) FROM password_credentials WHERE user_id = ?)
+ + (SELECT COUNT(*) FROM webauthn_credentials WHERE user_id = ?)
+ ) > 1
+ """,
+ (user_id, user_id, user_id),
+ )
+ await self._db.commit()
+ return cursor.rowcount > 0
+
+ async def delete_webauthn_if_not_last(self, user_id: str, credential_id: bytes) -> bool:
+ """Delete a WebAuthn credential only if it is not the user's last
+ credential, atomically. Returns True if a row was deleted."""
+ cursor = await self._db.execute(
+ """
+ DELETE FROM webauthn_credentials
+ WHERE user_id = ? AND credential_id = ?
+ AND (
+ (SELECT COUNT(*) FROM password_credentials WHERE user_id = ?)
+ + (SELECT COUNT(*) FROM webauthn_credentials WHERE user_id = ?)
+ ) > 1
+ """,
+ (user_id, credential_id, user_id, user_id),
+ )
+ await self._db.commit()
+ return cursor.rowcount > 0
+
class SQLiteMagicLinkRepository:
def __init__(self, db: aiosqlite.Connection) -> None:
diff --git a/tests/test_store/test_sqlite_credential_repo.py b/tests/test_store/test_sqlite_credential_repo.py
index 2c5d33d..1c3396a 100644
--- a/tests/test_store/test_sqlite_credential_repo.py
+++ b/tests/test_store/test_sqlite_credential_repo.py
@@ -182,3 +182,48 @@ async def test_create_duplicate_webauthn(credential_repo: SQLiteCredentialReposi
with pytest.raises(DuplicateError):
await credential_repo.create_webauthn(cred)
+
+
+async def test_delete_password_if_not_last_refuses_last(
+ credential_repo: SQLiteCredentialRepository, alice: User
+) -> None:
+ await credential_repo.create_password(PasswordCredential(user_id=alice.userid, password_hash="h"))
+
+ # Only credential -> must refuse and leave it in place.
+ assert await credential_repo.delete_password_if_not_last(alice.userid) is False
+ assert await credential_repo.get_password_by_user(alice.userid) is not None
+
+
+async def test_delete_password_if_not_last_allows_when_others_exist(
+ credential_repo: SQLiteCredentialRepository, alice: User
+) -> None:
+ await credential_repo.create_password(PasswordCredential(user_id=alice.userid, password_hash="h"))
+ await credential_repo.create_webauthn(
+ WebAuthnCredential(user_id=alice.userid, credential_id=b"\x01", public_key=b"\x02")
+ )
+
+ assert await credential_repo.delete_password_if_not_last(alice.userid) is True
+ assert await credential_repo.get_password_by_user(alice.userid) is None
+
+
+async def test_delete_webauthn_if_not_last_refuses_last(
+ credential_repo: SQLiteCredentialRepository, alice: User
+) -> None:
+ await credential_repo.create_webauthn(
+ WebAuthnCredential(user_id=alice.userid, credential_id=b"\x01", public_key=b"\x02")
+ )
+
+ assert await credential_repo.delete_webauthn_if_not_last(alice.userid, b"\x01") is False
+ assert len(await credential_repo.get_webauthn_by_user(alice.userid)) == 1
+
+
+async def test_delete_webauthn_if_not_last_allows_when_others_exist(
+ credential_repo: SQLiteCredentialRepository, alice: User
+) -> None:
+ await credential_repo.create_webauthn(
+ WebAuthnCredential(user_id=alice.userid, credential_id=b"\x01", public_key=b"\x02")
+ )
+ await credential_repo.create_password(PasswordCredential(user_id=alice.userid, password_hash="h"))
+
+ assert await credential_repo.delete_webauthn_if_not_last(alice.userid, b"\x01") is True
+ assert len(await credential_repo.get_webauthn_by_user(alice.userid)) == 0