diff --git a/src/porchlight/app.py b/src/porchlight/app.py index f0bc34f..dd0f95f 100644 --- a/src/porchlight/app.py +++ b/src/porchlight/app.py @@ -19,6 +19,7 @@ from porchlight.oidc.endpoints import router as oidc_router from porchlight.oidc.provider import create_oidc_server from porchlight.store.sqlite.db import open_db from porchlight.store.sqlite.repositories import ( + SQLiteConsentRepository, SQLiteCredentialRepository, SQLiteMagicLinkRepository, SQLiteUserRepository, @@ -36,6 +37,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: app.state.user_repo = SQLiteUserRepository(db) app.state.credential_repo = SQLiteCredentialRepository(db) app.state.magic_link_repo = SQLiteMagicLinkRepository(db) + app.state.consent_repo = SQLiteConsentRepository(db) # Auth services app.state.password_service = PasswordService() diff --git a/src/porchlight/models.py b/src/porchlight/models.py index a232ddc..3357b0b 100644 --- a/src/porchlight/models.py +++ b/src/porchlight/models.py @@ -56,3 +56,11 @@ class MagicLink(BaseModel): used: bool = False created_by: str | None = None note: str | None = None + + +class Consent(BaseModel): + userid: str + client_id: str + scopes: list[str] + created_at: datetime = Field(default_factory=_utcnow) + updated_at: datetime = Field(default_factory=_utcnow) diff --git a/src/porchlight/oidc/endpoints.py b/src/porchlight/oidc/endpoints.py index 380242b..ee221ee 100644 --- a/src/porchlight/oidc/endpoints.py +++ b/src/porchlight/oidc/endpoints.py @@ -14,6 +14,13 @@ from porchlight.oidc.claims import PorchlightUserInfo, user_to_claims router = APIRouter(tags=["oidc"]) +SCOPE_DESCRIPTIONS: dict[str, str] = { + "openid": "Sign you in (required)", + "profile": "Your name and profile information", + "email": "Your email address", + "phone": "Your phone number", +} + @router.get("/.well-known/openid-configuration") async def provider_configuration(request: Request) -> JSONResponse: @@ -63,7 +70,7 @@ async def authorization(request: Request) -> Response: username = request.session.get("username") if userid and username: - return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) + return await _check_consent_or_complete(request, oidc_server, endpoint, parsed, userid, username, query_params) # Not authenticated — store and redirect to login request.session["oidc_auth_request"] = query_params @@ -94,7 +101,40 @@ async def authorization_complete(request: Request) -> Response: error_desc = parsed.get("error_description", parsed["error"]) return HTMLResponse(f"

Error

{error_desc}

", status_code=400) - return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) + return await _check_consent_or_complete( + request, oidc_server, endpoint, parsed, userid, username, auth_request_params + ) + + +async def _check_consent_or_complete( + request: Request, + oidc_server: object, + endpoint: object, + parsed: object, + userid: str, + username: str, + auth_params: dict, +) -> Response: + """Check if consent is needed; if so redirect to /consent, otherwise complete.""" + settings = request.app.state.settings + client_id = auth_params.get("client_id", "") + + # Manage-app bypasses consent + if client_id == settings.manage_client_id: + return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) + + # Check stored consent + consent_repo = request.app.state.consent_repo + requested_scopes = auth_params.get("scope", "openid").split() + stored_consent = await consent_repo.get_consent(userid, client_id) + + if stored_consent and set(requested_scopes) <= set(stored_consent.scopes): + # All requested scopes already approved + return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) + + # Consent needed — store auth state and redirect + request.session["consent_auth_request"] = auth_params + return RedirectResponse("/consent", status_code=303) async def _complete_authorization( @@ -246,3 +286,82 @@ async def userinfo_endpoint(request: Request) -> JSONResponse: response_data = response_data.to_dict() return JSONResponse(response_data) + + +@router.get("/consent") +async def consent_page(request: Request) -> Response: + """Show the consent form.""" + auth_params = request.session.get("consent_auth_request") + if auth_params is None: + return HTMLResponse("

Error

No pending consent request

", status_code=400) + + userid = request.session.get("userid") + if not userid: + return RedirectResponse("/login", status_code=303) + + client_id = auth_params.get("client_id", "") + requested_scopes = auth_params.get("scope", "openid").split() + + scope_info = [ + {"name": s, "description": SCOPE_DESCRIPTIONS.get(s, s), "required": s == "openid"} for s in requested_scopes + ] + + templates = request.app.state.templates + return templates.TemplateResponse( + request, + "consent.html", + {"client_id": client_id, "scopes": scope_info}, + ) + + +@router.post("/consent") +async def consent_submit(request: Request) -> Response: + """Handle consent form submission.""" + auth_params = request.session.pop("consent_auth_request", None) + if auth_params is None: + return HTMLResponse("

Error

No pending consent request

", status_code=400) + + userid = request.session.get("userid") + username = request.session.get("username") + if not userid or not username: + return RedirectResponse("/login", status_code=303) + + form = await request.form() + action = form.get("action") + client_id = auth_params.get("client_id", "") + redirect_uri = auth_params.get("redirect_uri", "") + state = auth_params.get("state", "") + + if action == "deny": + params = urlencode({"error": "access_denied", "state": state}) + return RedirectResponse(f"{redirect_uri}?{params}", status_code=303) + + if action != "allow": + return HTMLResponse("

Error

Invalid action

", status_code=400) + + # Allow — collect approved scopes + approved_scopes: list[str] = [str(s) for s in form.getlist("scope")] + if "openid" not in approved_scopes: + approved_scopes = ["openid", *list(approved_scopes)] + + # Save consent + consent_repo = request.app.state.consent_repo + await consent_repo.set_consent(userid, client_id, list(approved_scopes)) + + # Filter auth request scopes to only approved + auth_params["scope"] = " ".join(approved_scopes) + + # Re-parse and complete + oidc_server = request.app.state.oidc_server + endpoint = oidc_server.get_endpoint("authorization") + + try: + parsed = endpoint.parse_request(auth_params) + except Exception as exc: + return HTMLResponse(f"

Error

{exc}

", status_code=400) + + if "error" in parsed: + error_desc = parsed.get("error_description", parsed["error"]) + return HTMLResponse(f"

Error

{error_desc}

", status_code=400) + + return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) diff --git a/src/porchlight/store/protocols.py b/src/porchlight/store/protocols.py index 2e21a11..b272318 100644 --- a/src/porchlight/store/protocols.py +++ b/src/porchlight/store/protocols.py @@ -1,6 +1,7 @@ from typing import Protocol, runtime_checkable from porchlight.models import ( + Consent, MagicLink, PasswordCredential, User, @@ -51,3 +52,14 @@ class MagicLinkRepository(Protocol): async def mark_used(self, token: str) -> bool: ... async def delete_expired(self) -> int: ... + + +@runtime_checkable +class ConsentRepository(Protocol): + async def get_consent(self, userid: str, client_id: str) -> Consent | None: ... + + async def set_consent(self, userid: str, client_id: str, scopes: list[str]) -> None: ... + + async def delete_consent(self, userid: str, client_id: str) -> bool: ... + + async def list_consents(self, userid: str) -> list[Consent]: ... diff --git a/src/porchlight/store/sqlite/migrations/002_user_consents.sql b/src/porchlight/store/sqlite/migrations/002_user_consents.sql new file mode 100644 index 0000000..e3859c1 --- /dev/null +++ b/src/porchlight/store/sqlite/migrations/002_user_consents.sql @@ -0,0 +1,8 @@ +CREATE TABLE user_consents ( + userid TEXT NOT NULL REFERENCES users(userid) ON DELETE CASCADE, + client_id TEXT NOT NULL, + scopes TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + PRIMARY KEY (userid, client_id) +); diff --git a/src/porchlight/store/sqlite/repositories.py b/src/porchlight/store/sqlite/repositories.py index 2854873..22075e0 100644 --- a/src/porchlight/store/sqlite/repositories.py +++ b/src/porchlight/store/sqlite/repositories.py @@ -1,8 +1,9 @@ +import json from datetime import UTC, datetime import aiosqlite -from porchlight.models import MagicLink, PasswordCredential, User, WebAuthnCredential +from porchlight.models import Consent, MagicLink, PasswordCredential, User, WebAuthnCredential from porchlight.store.exceptions import DuplicateError @@ -289,3 +290,56 @@ class SQLiteMagicLinkRepository: cursor = await self._db.execute("DELETE FROM magic_links WHERE expires_at < ? AND used = 0", (now,)) await self._db.commit() return cursor.rowcount + + +class SQLiteConsentRepository: + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + def _row_to_consent(self, row: aiosqlite.Row) -> Consent: + return Consent( + userid=row["userid"], + client_id=row["client_id"], + scopes=json.loads(row["scopes"]), + created_at=datetime.fromisoformat(row["created_at"]), + updated_at=datetime.fromisoformat(row["updated_at"]), + ) + + async def get_consent(self, userid: str, client_id: str) -> Consent | None: + async with self._db.execute( + "SELECT * FROM user_consents WHERE userid = ? AND client_id = ?", + (userid, client_id), + ) as cursor: + row = await cursor.fetchone() + if row is None: + return None + return self._row_to_consent(row) + + async def set_consent(self, userid: str, client_id: str, scopes: list[str]) -> None: + now = datetime.now(UTC).isoformat() + await self._db.execute( + """ + INSERT INTO user_consents (userid, client_id, scopes, created_at, updated_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (userid, client_id) + DO UPDATE SET scopes = excluded.scopes, updated_at = excluded.updated_at + """, + (userid, client_id, json.dumps(scopes), now, now), + ) + await self._db.commit() + + async def delete_consent(self, userid: str, client_id: str) -> bool: + cursor = await self._db.execute( + "DELETE FROM user_consents WHERE userid = ? AND client_id = ?", + (userid, client_id), + ) + await self._db.commit() + return cursor.rowcount > 0 + + async def list_consents(self, userid: str) -> list[Consent]: + async with self._db.execute( + "SELECT * FROM user_consents WHERE userid = ? ORDER BY client_id", + (userid,), + ) as cursor: + rows = await cursor.fetchall() + return [self._row_to_consent(row) for row in rows] diff --git a/src/porchlight/templates/consent.html b/src/porchlight/templates/consent.html new file mode 100644 index 0000000..395ba5d --- /dev/null +++ b/src/porchlight/templates/consent.html @@ -0,0 +1,35 @@ +{% extends "base.html" %} + +{% block title %}Authorize — Porchlight{% endblock %} + +{% block content %} + +{% endblock %} diff --git a/tests/test_oidc/test_consent_flow.py b/tests/test_oidc/test_consent_flow.py new file mode 100644 index 0000000..c1a0993 --- /dev/null +++ b/tests/test_oidc/test_consent_flow.py @@ -0,0 +1,268 @@ +import secrets +from datetime import UTC, datetime +from urllib.parse import parse_qs, urlparse + +from argon2 import PasswordHasher +from httpx import AsyncClient + +from porchlight.authn.password import PasswordService +from porchlight.models import PasswordCredential, User + + +async def test_authorization_shows_consent_for_new_client(client: AsyncClient) -> None: + """First-time authorization for an RP should redirect to /consent.""" + app = client._transport.app # type: ignore[union-attr] + _register_test_rp(app) + await _create_test_user(app) + + # Login + await client.post( + "/login/password", + data={"username": "consentuser", "password": "testpass"}, + headers={"HX-Request": "true"}, + ) + + # Authorization request + res = await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": "consent-rp", + "redirect_uri": "http://localhost:9000/callback", + "scope": "openid profile", + "state": "teststate", + }, + follow_redirects=False, + ) + assert res.status_code == 303 + assert "/consent" in res.headers["location"] + + +async def test_consent_page_renders(client: AsyncClient) -> None: + """GET /consent should render the consent form.""" + app = client._transport.app # type: ignore[union-attr] + _register_test_rp(app) + await _create_test_user(app) + await _login_and_start_auth(client) + + res = await client.get("/consent") + assert res.status_code == 200 + assert "consent-rp" in res.text + assert "profile" in res.text.lower() + + +async def test_consent_allow_redirects_with_code(client: AsyncClient) -> None: + """Approving consent should complete the authorization flow.""" + app = client._transport.app # type: ignore[union-attr] + _register_test_rp(app) + await _create_test_user(app) + await _login_and_start_auth(client) + + res = await client.post( + "/consent", + data={"action": "allow", "scope": ["openid", "profile"]}, + follow_redirects=False, + ) + assert res.status_code == 303 + location = res.headers["location"] + parsed = urlparse(location) + params = parse_qs(parsed.query) + assert "code" in params + + +async def test_consent_deny_redirects_with_error(client: AsyncClient) -> None: + """Denying consent should redirect with access_denied error.""" + app = client._transport.app # type: ignore[union-attr] + _register_test_rp(app) + await _create_test_user(app) + await _login_and_start_auth(client) + + res = await client.post( + "/consent", + data={"action": "deny"}, + follow_redirects=False, + ) + assert res.status_code == 303 + location = res.headers["location"] + parsed = urlparse(location) + params = parse_qs(parsed.query) + assert params["error"] == ["access_denied"] + + +async def test_saved_consent_skips_consent_screen(client: AsyncClient) -> None: + """Second authorization with same scopes should skip consent.""" + app = client._transport.app # type: ignore[union-attr] + _register_test_rp(app) + await _create_test_user(app) + + # First flow: login, authorize, consent + await _login_and_start_auth(client) + await client.post( + "/consent", + data={"action": "allow", "scope": ["openid", "profile"]}, + follow_redirects=False, + ) + + # Second flow: same scopes, should skip consent + res = await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": "consent-rp", + "redirect_uri": "http://localhost:9000/callback", + "scope": "openid profile", + "state": "teststate2", + }, + follow_redirects=False, + ) + assert res.status_code == 303 + location = res.headers["location"] + # Should redirect directly to callback, not /consent + assert "callback" in location + assert "code" in location + + +async def test_new_scopes_reshows_consent(client: AsyncClient) -> None: + """If RP requests new scopes, consent screen should reappear.""" + app = client._transport.app # type: ignore[union-attr] + _register_test_rp(app) + await _create_test_user(app) + + # First flow: consent to openid only + await _login_and_start_auth(client, scope="openid") + await client.post( + "/consent", + data={"action": "allow", "scope": ["openid"]}, + follow_redirects=False, + ) + + # Second flow: request openid + profile (new scope) + res = await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": "consent-rp", + "redirect_uri": "http://localhost:9000/callback", + "scope": "openid profile", + "state": "teststate2", + }, + follow_redirects=False, + ) + assert res.status_code == 303 + assert "/consent" in res.headers["location"] + + +async def test_manage_app_skips_consent(client: AsyncClient) -> None: + """The manage-app client should bypass consent entirely.""" + app = client._transport.app # type: ignore[union-attr] + settings = app.state.settings + await _create_test_user(app) + + await client.post( + "/login/password", + data={"username": "consentuser", "password": "testpass"}, + headers={"HX-Request": "true"}, + ) + + manage_cdb = app.state.oidc_server.context.cdb[settings.manage_client_id] + redirect_uri = manage_cdb["redirect_uris"][0][0] + + res = await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": settings.manage_client_id, + "redirect_uri": redirect_uri, + "scope": "openid profile email", + "state": "teststate", + }, + follow_redirects=False, + ) + assert res.status_code == 303 + location = res.headers["location"] + # Should redirect directly to callback, not /consent + assert "code" in location + assert "/consent" not in location + + +async def test_partial_consent_filters_scopes(client: AsyncClient) -> None: + """User can approve only some scopes (partial consent).""" + app = client._transport.app # type: ignore[union-attr] + _register_test_rp(app) + await _create_test_user(app) + + # Request openid + profile + email, approve only openid + profile + await _login_and_start_auth(client, scope="openid profile email") + res = await client.post( + "/consent", + data={"action": "allow", "scope": ["openid", "profile"]}, + follow_redirects=False, + ) + assert res.status_code == 303 + location = res.headers["location"] + assert "code" in location + + # Verify consent was saved with only the approved scopes + consent_repo = app.state.consent_repo + consent = await consent_repo.get_consent("lusab-consent", "consent-rp") + assert consent is not None + assert set(consent.scopes) == {"openid", "profile"} + + +# -- Test helpers -- + + +def _register_test_rp(app) -> None: + oidc_server = app.state.oidc_server + if "consent-rp" in oidc_server.context.cdb: + return + client_id = "consent-rp" + client_secret = "consent-secret-0123456789abcdef" + oidc_server.context.cdb[client_id] = { + "client_id": client_id, + "client_secret": client_secret, + "redirect_uris": [("http://localhost:9000/callback", {})], + "response_types_supported": ["code"], + "token_endpoint_auth_method": "client_secret_basic", + "scope": ["openid", "profile", "email"], + "allowed_scopes": ["openid", "profile", "email"], + "client_salt": secrets.token_hex(8), + } + oidc_server.keyjar.add_symmetric(client_id, client_secret) + + +async def _create_test_user(app) -> None: + user_repo = app.state.user_repo + existing = await user_repo.get_by_username("consentuser") + if existing: + return + user = User( + userid="lusab-consent", + username="consentuser", + email="consent@example.com", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + await user_repo.create(user) + svc = PasswordService(hasher=PasswordHasher(time_cost=1, memory_cost=8192)) + cred_repo = app.state.credential_repo + await cred_repo.create_password(PasswordCredential(user_id=user.userid, password_hash=svc.hash("testpass"))) + + +async def _login_and_start_auth(client: AsyncClient, scope: str = "openid profile") -> None: + await client.post( + "/login/password", + data={"username": "consentuser", "password": "testpass"}, + headers={"HX-Request": "true"}, + ) + await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": "consent-rp", + "redirect_uri": "http://localhost:9000/callback", + "scope": scope, + "state": "teststate", + }, + follow_redirects=False, + ) diff --git a/tests/test_oidc/test_e2e_flow.py b/tests/test_oidc/test_e2e_flow.py index 0761bc5..9626a94 100644 --- a/tests/test_oidc/test_e2e_flow.py +++ b/tests/test_oidc/test_e2e_flow.py @@ -87,13 +87,22 @@ async def test_full_authorization_code_flow(client: AsyncClient) -> None: f"Expected HX-Redirect to /authorization/complete, got '{hx_redirect}'" ) - # -- Step 3: Complete authorization → redirect to callback with code + state -- + # -- Step 3: Complete authorization → redirect to consent -- complete_res = await client.get("/authorization/complete", follow_redirects=False) assert complete_res.status_code in (302, 303), ( - f"Expected redirect to callback, got {complete_res.status_code}: {complete_res.text}" + f"Expected redirect to /consent, got {complete_res.status_code}: {complete_res.text}" ) + assert "/consent" in complete_res.headers["location"] - location = complete_res.headers["location"] + # -- Step 3b: Approve consent → redirect to callback with code + state -- + consent_res = await client.post( + "/consent", + data={"action": "allow", "scope": ["openid", "profile", "email"]}, + follow_redirects=False, + ) + assert consent_res.status_code in (302, 303) + + location = consent_res.headers["location"] parsed = urlparse(location) assert parsed.netloc == "localhost:9000" assert parsed.path == "/callback" diff --git a/tests/test_oidc/test_token.py b/tests/test_oidc/test_token.py index 3d5a38a..7d086ad 100644 --- a/tests/test_oidc/test_token.py +++ b/tests/test_oidc/test_token.py @@ -59,6 +59,8 @@ async def _get_authorization_code(client: AsyncClient) -> str: """Run full auth flow and extract the authorization code.""" _register_test_client(client) + app = client._transport.app # type: ignore[union-attr] + # Start authorization (unauthenticated — stores in session) await client.get( "/authorization", @@ -74,9 +76,13 @@ async def _get_authorization_code(client: AsyncClient) -> str: ) # Create user and log in - await _create_user_and_login(client) + userid = await _create_user_and_login(client) - # Complete authorization (now authenticated, session has oidc_auth_request) + # Pre-seed consent so the consent screen is skipped + consent_repo = app.state.consent_repo + await consent_repo.set_consent(userid, "test-rp", ["openid", "profile", "email"]) + + # Complete authorization (now authenticated, consent exists → redirects to callback) complete_res = await client.get("/authorization/complete", follow_redirects=False) assert complete_res.status_code in (302, 303), ( f"Expected redirect, got {complete_res.status_code}: {complete_res.text}" diff --git a/tests/test_oidc/test_userinfo.py b/tests/test_oidc/test_userinfo.py index 828f1c0..63f217b 100644 --- a/tests/test_oidc/test_userinfo.py +++ b/tests/test_oidc/test_userinfo.py @@ -61,6 +61,8 @@ async def _get_access_token(client: AsyncClient) -> str: """Run full auth + token flow and return the access_token.""" client_secret = _register_test_client(client) + app = client._transport.app # type: ignore[union-attr] + # Start authorization (unauthenticated — stores in session) await client.get( "/authorization", @@ -76,9 +78,13 @@ async def _get_access_token(client: AsyncClient) -> str: ) # Create user and log in - await _create_user_and_login(client) + userid = await _create_user_and_login(client) - # Complete authorization (now authenticated, session has oidc_auth_request) + # Pre-seed consent so the consent screen is skipped + consent_repo = app.state.consent_repo + await consent_repo.set_consent(userid, "test-rp", ["openid", "profile", "email"]) + + # Complete authorization (now authenticated, consent exists → redirects to callback) complete_res = await client.get("/authorization/complete", follow_redirects=False) assert complete_res.status_code in (302, 303), ( f"Expected redirect, got {complete_res.status_code}: {complete_res.text}" diff --git a/tests/test_store/test_migrations.py b/tests/test_store/test_migrations.py index 37c8c72..386aa5a 100644 --- a/tests/test_store/test_migrations.py +++ b/tests/test_store/test_migrations.py @@ -13,7 +13,7 @@ async def test_run_migrations_applies_initial() -> None: async with aiosqlite.connect(":memory:") as db: await db.execute("PRAGMA foreign_keys=ON") count = await run_migrations(db, MIGRATIONS_DIR) - assert count == 1 + assert count == 2 async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") as cursor: row = await cursor.fetchone() assert row is not None @@ -24,7 +24,7 @@ async def test_run_migrations_skips_already_applied() -> None: await db.execute("PRAGMA foreign_keys=ON") first_count = await run_migrations(db, MIGRATIONS_DIR) second_count = await run_migrations(db, MIGRATIONS_DIR) - assert first_count == 1 + assert first_count == 2 assert second_count == 0 @@ -39,4 +39,5 @@ async def test_run_migrations_creates_all_tables() -> None: assert "webauthn_credentials" in tables assert "password_credentials" in tables assert "magic_links" in tables + assert "user_consents" in tables assert "_migrations" in tables diff --git a/tests/test_store/test_protocols.py b/tests/test_store/test_protocols.py index 9d99dd0..fb8c943 100644 --- a/tests/test_store/test_protocols.py +++ b/tests/test_store/test_protocols.py @@ -1,6 +1,7 @@ from typing import runtime_checkable from porchlight.store.protocols import ( + ConsentRepository, CredentialRepository, MagicLinkRepository, UserRepository, @@ -11,3 +12,4 @@ def test_protocols_are_runtime_checkable() -> None: assert runtime_checkable(UserRepository) # type: ignore[arg-type] assert runtime_checkable(CredentialRepository) # type: ignore[arg-type] assert runtime_checkable(MagicLinkRepository) # type: ignore[arg-type] + assert runtime_checkable(ConsentRepository) # type: ignore[arg-type] diff --git a/tests/test_store/test_sqlite_consent_repo.py b/tests/test_store/test_sqlite_consent_repo.py new file mode 100644 index 0000000..d009510 --- /dev/null +++ b/tests/test_store/test_sqlite_consent_repo.py @@ -0,0 +1,112 @@ +from datetime import UTC, datetime + +from porchlight.models import User +from porchlight.store.protocols import ConsentRepository +from porchlight.store.sqlite.repositories import SQLiteConsentRepository, SQLiteUserRepository + + +async def _create_user(db) -> User: + """Helper to create a test user.""" + user_repo = SQLiteUserRepository(db) + user = User( + userid="test-user-id", + username="testuser", + email="test@example.com", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + return await user_repo.create(user) + + +async def test_implements_protocol(db) -> None: + repo = SQLiteConsentRepository(db) + assert isinstance(repo, ConsentRepository) + + +async def test_set_and_get_consent(db) -> None: + user = await _create_user(db) + repo = SQLiteConsentRepository(db) + await repo.set_consent(user.userid, "test-rp", ["openid", "profile"]) + + consent = await repo.get_consent(user.userid, "test-rp") + assert consent is not None + assert consent.userid == user.userid + assert consent.client_id == "test-rp" + assert consent.scopes == ["openid", "profile"] + assert isinstance(consent.created_at, datetime) + assert isinstance(consent.updated_at, datetime) + + +async def test_get_consent_not_found(db) -> None: + user = await _create_user(db) + repo = SQLiteConsentRepository(db) + consent = await repo.get_consent(user.userid, "nonexistent") + assert consent is None + + +async def test_set_consent_upserts(db) -> None: + user = await _create_user(db) + repo = SQLiteConsentRepository(db) + await repo.set_consent(user.userid, "test-rp", ["openid"]) + + original = await repo.get_consent(user.userid, "test-rp") + assert original is not None + + await repo.set_consent(user.userid, "test-rp", ["openid", "profile", "email"]) + + consent = await repo.get_consent(user.userid, "test-rp") + assert consent is not None + assert consent.scopes == ["openid", "profile", "email"] + assert consent.created_at == original.created_at + assert consent.updated_at >= original.updated_at + + +async def test_delete_consent(db) -> None: + user = await _create_user(db) + repo = SQLiteConsentRepository(db) + await repo.set_consent(user.userid, "test-rp", ["openid"]) + + result = await repo.delete_consent(user.userid, "test-rp") + assert result is True + + consent = await repo.get_consent(user.userid, "test-rp") + assert consent is None + + +async def test_delete_consent_not_found(db) -> None: + user = await _create_user(db) + repo = SQLiteConsentRepository(db) + result = await repo.delete_consent(user.userid, "nonexistent") + assert result is False + + +async def test_list_consents(db) -> None: + user = await _create_user(db) + repo = SQLiteConsentRepository(db) + await repo.set_consent(user.userid, "rp-a", ["openid"]) + await repo.set_consent(user.userid, "rp-b", ["openid", "profile"]) + + consents = await repo.list_consents(user.userid) + assert len(consents) == 2 + client_ids = {c.client_id for c in consents} + assert client_ids == {"rp-a", "rp-b"} + + +async def test_list_consents_empty(db) -> None: + user = await _create_user(db) + repo = SQLiteConsentRepository(db) + consents = await repo.list_consents(user.userid) + assert consents == [] + + +async def test_consent_deleted_on_user_cascade(db) -> None: + """Consent records are deleted when the user is deleted (CASCADE).""" + user = await _create_user(db) + user_repo = SQLiteUserRepository(db) + repo = SQLiteConsentRepository(db) + + await repo.set_consent(user.userid, "test-rp", ["openid"]) + await user_repo.delete(user.userid) + + consent = await repo.get_consent(user.userid, "test-rp") + assert consent is None