diff --git a/src/porchlight/authn/routes.py b/src/porchlight/authn/routes.py index c710f3f..529e1e3 100644 --- a/src/porchlight/authn/routes.py +++ b/src/porchlight/authn/routes.py @@ -77,7 +77,8 @@ async def register_magic_link(request: Request, token: str) -> Response: magic_link_service = request.app.state.magic_link_service user_repo = request.app.state.user_repo - link = await magic_link_service.validate(token) + # Atomically validate and consume the token (single-use, no replay race). + link = await magic_link_service.consume(token) if link is None: return HTMLResponse("
Invalid or expired registration link.
", status_code=400) @@ -91,8 +92,6 @@ async def register_magic_link(request: Request, token: str) -> Response: user = User(userid=userid, username=link.username, groups=["users"]) await user_repo.create(user) - await magic_link_service.mark_used(token) - request.session["userid"] = user.userid request.session["username"] = user.username diff --git a/src/porchlight/invite/service.py b/src/porchlight/invite/service.py index b8b5ef8..4766442 100644 --- a/src/porchlight/invite/service.py +++ b/src/porchlight/invite/service.py @@ -66,6 +66,19 @@ class MagicLinkService: """Mark a magic link as used. Returns True if found and marked.""" return await self._repo.mark_used(self._hash_token(token)) + async def consume(self, token: str) -> MagicLink | None: + """Atomically validate and consume a token in one step. + + Prefer this over validate()+mark_used(): it closes the replay race + where two concurrent requests could both pass validation before either + marks the token used. Returns the link (raw token re-attached) on + success, else None. + """ + link = await self._repo.consume(self._hash_token(token)) + if link is None: + return None + return link.model_copy(update={"token": token}) + async def cleanup_expired(self) -> int: """Delete expired unused links. Returns count deleted.""" return await self._repo.delete_expired() diff --git a/src/porchlight/store/protocols.py b/src/porchlight/store/protocols.py index ba1af34..034d6d9 100644 --- a/src/porchlight/store/protocols.py +++ b/src/porchlight/store/protocols.py @@ -55,6 +55,8 @@ class MagicLinkRepository(Protocol): async def mark_used(self, token: str) -> bool: ... + async def consume(self, token: str) -> MagicLink | None: ... + async def delete_expired(self) -> int: ... diff --git a/src/porchlight/store/sqlite/repositories.py b/src/porchlight/store/sqlite/repositories.py index 37d02bb..6acc941 100644 --- a/src/porchlight/store/sqlite/repositories.py +++ b/src/porchlight/store/sqlite/repositories.py @@ -311,6 +311,25 @@ class SQLiteMagicLinkRepository: await self._db.commit() return cursor.rowcount > 0 + async def consume(self, token: str) -> MagicLink | None: + """Atomically validate-and-mark a token as used. + + The conditional UPDATE is the single point of decision, so two + concurrent requests cannot both consume the same token. Returns the + link if this call won the race (unused and unexpired), else None. + """ + now = datetime.now(UTC).isoformat() + cursor = await self._db.execute( + "UPDATE magic_links SET used = 1 WHERE token = ? AND used = 0 AND expires_at > ?", + (token, now), + ) + await self._db.commit() + if cursor.rowcount == 0: + return None + async with self._db.execute("SELECT * FROM magic_links WHERE token = ?", (token,)) as c: + row = await c.fetchone() + return self._row_to_magic_link(row) if row is not None else None + async def delete_expired(self) -> int: now = datetime.now(UTC).isoformat() cursor = await self._db.execute("DELETE FROM magic_links WHERE expires_at < ? AND used = 0", (now,)) diff --git a/tests/test_invite/test_service.py b/tests/test_invite/test_service.py index e7a20bd..166e7cd 100644 --- a/tests/test_invite/test_service.py +++ b/tests/test_invite/test_service.py @@ -102,6 +102,22 @@ async def test_validate_expired_token(service: MagicLinkService, repo: SQLiteMag assert result is None +async def test_consume_validates_and_is_single_use(service: MagicLinkService) -> None: + link = await service.create(username="alice") + consumed = await service.consume(link.token) + assert consumed is not None + assert consumed.username == "alice" + assert consumed.token == link.token # raw token re-attached + # A second consume of the same token fails. + assert await service.consume(link.token) is None + + +async def test_consume_expired_returns_none(service: MagicLinkService, repo: SQLiteMagicLinkRepository) -> None: + expired_service = MagicLinkService(repo=repo, ttl=-1) + link = await expired_service.create(username="alice") + assert await service.consume(link.token) is None + + async def test_mark_used_returns_true(service: MagicLinkService) -> None: link = await service.create(username="alice") result = await service.mark_used(link.token) diff --git a/tests/test_store/test_sqlite_magic_link_repo.py b/tests/test_store/test_sqlite_magic_link_repo.py index c2c9571..bb4f8b6 100644 --- a/tests/test_store/test_sqlite_magic_link_repo.py +++ b/tests/test_store/test_sqlite_magic_link_repo.py @@ -63,6 +63,34 @@ async def test_mark_used_not_found(magic_link_repo: SQLiteMagicLinkRepository) - assert marked is False +async def test_consume_marks_used_and_returns_link(magic_link_repo: SQLiteMagicLinkRepository) -> None: + await magic_link_repo.create(_make_link()) + + consumed = await magic_link_repo.consume("abc123") + assert consumed is not None + assert consumed.used is True + assert consumed.username == "alice" + + +async def test_consume_is_single_use(magic_link_repo: SQLiteMagicLinkRepository) -> None: + await magic_link_repo.create(_make_link()) + + first = await magic_link_repo.consume("abc123") + assert first is not None + # A second consume of the same token must fail (atomic single-use). + second = await magic_link_repo.consume("abc123") + assert second is None + + +async def test_consume_expired_returns_none(magic_link_repo: SQLiteMagicLinkRepository) -> None: + await magic_link_repo.create(_make_link(token="exp", expires_at=datetime.now(UTC) - timedelta(hours=1))) + assert await magic_link_repo.consume("exp") is None + + +async def test_consume_nonexistent_returns_none(magic_link_repo: SQLiteMagicLinkRepository) -> None: + assert await magic_link_repo.consume("nope") is None + + async def test_delete_expired(magic_link_repo: SQLiteMagicLinkRepository) -> None: expired = _make_link(token="expired", expires_at=datetime.now(UTC) - timedelta(hours=1)) await magic_link_repo.create(expired)