diff --git a/src/fastapi_oidc_op/app.py b/src/fastapi_oidc_op/app.py index beb0896..5d22ace 100644 --- a/src/fastapi_oidc_op/app.py +++ b/src/fastapi_oidc_op/app.py @@ -16,6 +16,7 @@ from fastapi_oidc_op.authn.webauthn import WebAuthnService from fastapi_oidc_op.config import Settings, StorageBackend from fastapi_oidc_op.invite.service import MagicLinkService from fastapi_oidc_op.manage.routes import router as manage_router +from fastapi_oidc_op.oidc.endpoints import router as oidc_router from fastapi_oidc_op.oidc.provider import create_oidc_server from fastapi_oidc_op.store.sqlite.migrations import run_migrations from fastapi_oidc_op.store.sqlite.repositories import ( @@ -109,6 +110,7 @@ def create_app(settings: Settings | None = None) -> FastAPI: # Routers app.include_router(authn_router) app.include_router(manage_router) + app.include_router(oidc_router) @app.get("/health") async def health() -> dict[str, str]: diff --git a/src/fastapi_oidc_op/oidc/endpoints.py b/src/fastapi_oidc_op/oidc/endpoints.py new file mode 100644 index 0000000..a068cbb --- /dev/null +++ b/src/fastapi_oidc_op/oidc/endpoints.py @@ -0,0 +1,77 @@ +"""FastAPI routes wrapping idpyoidc endpoint processing.""" + +from __future__ import annotations + +import json + +from fastapi import APIRouter, Request, Response +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse + +router = APIRouter(tags=["oidc"]) + + +@router.get("/.well-known/openid-configuration") +async def provider_configuration(request: Request) -> JSONResponse: + """OIDC Discovery endpoint.""" + oidc_server = request.app.state.oidc_server + endpoint = oidc_server.get_endpoint("provider_config") + + parsed = endpoint.parse_request({}) + result = endpoint.process_request(parsed) + response_info = endpoint.do_response(response_args=result.get("response_args"), request=parsed) + + response_data = response_info["response"] + if isinstance(response_data, str): + response_data = json.loads(response_data) + elif hasattr(response_data, "to_dict"): + response_data = response_data.to_dict() + + return JSONResponse(content=response_data) + + +@router.get("/jwks") +async def jwks(request: Request) -> JSONResponse: + """Public signing keys (JWK Set).""" + oidc_server = request.app.state.oidc_server + keys = oidc_server.keyjar.export_jwks() + return JSONResponse(content=keys) + + +@router.get("/authorization") +async def authorization(request: Request) -> Response: + """OIDC Authorization endpoint.""" + oidc_server = request.app.state.oidc_server + endpoint = oidc_server.get_endpoint("authorization") + query_params = dict(request.query_params) + + try: + parsed = endpoint.parse_request(query_params) + except Exception as exc: + return HTMLResponse(f"

Invalid Request

{exc}

", status_code=400) + + if "error" in parsed: + error_desc = parsed.get("error_description", parsed["error"]) + return HTMLResponse(f"

Error

{error_desc}

", status_code=400) + + # Check if user is authenticated + userid = request.session.get("userid") + username = request.session.get("username") + + if userid and username: + return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) + + # Not authenticated — store and redirect to login + request.session["oidc_auth_request"] = query_params + return RedirectResponse("/login", status_code=303) + + +async def _complete_authorization( + request: Request, + oidc_server: object, + endpoint: object, + parsed: object, + userid: str, + username: str, +) -> Response: + """Complete the authorization after authentication. Placeholder — implemented in Task 7.""" + return HTMLResponse("Authorization completion not yet implemented", status_code=501) diff --git a/tests/test_oidc/test_authorization.py b/tests/test_oidc/test_authorization.py new file mode 100644 index 0000000..5994134 --- /dev/null +++ b/tests/test_oidc/test_authorization.py @@ -0,0 +1,78 @@ +import secrets + +from httpx import AsyncClient + + +def _register_test_client( + client: AsyncClient, + client_id: str = "test-rp", + redirect_uri: str = "http://localhost:9000/callback", +) -> str: + """Register a test client in the OIDC server. Returns client_secret.""" + app = client._transport.app # type: ignore[union-attr] + oidc_server = app.state.oidc_server + client_secret = secrets.token_hex(16) + oidc_server.context.cdb[client_id] = { + "client_id": client_id, + "client_secret": client_secret, + "redirect_uris": [(redirect_uri, {})], + "response_types_supported": ["code"], + "token_endpoint_auth_method": "client_secret_basic", + "scope": ["openid", "profile", "email"], + "allowed_scopes": ["openid", "profile", "email"], + "client_salt": secrets.token_hex(8), + } + oidc_server.keyjar.add_symmetric(client_id, client_secret) + return client_secret + + +async def test_authorization_redirects_to_login_when_unauthenticated(client: AsyncClient) -> None: + _register_test_client(client) + res = await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": "test-rp", + "redirect_uri": "http://localhost:9000/callback", + "scope": "openid", + "state": "test-state", + "nonce": "test-nonce", + }, + follow_redirects=False, + ) + assert res.status_code in (302, 303) + assert "/login" in res.headers["location"] + + +async def test_authorization_stores_auth_request_in_session(client: AsyncClient) -> None: + _register_test_client(client) + res = await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": "test-rp", + "redirect_uri": "http://localhost:9000/callback", + "scope": "openid", + "state": "test-state", + "nonce": "test-nonce", + }, + follow_redirects=False, + ) + assert res.status_code in (302, 303) + login_res = await client.get("/login") + assert login_res.status_code == 200 + + +async def test_authorization_invalid_client_returns_error(client: AsyncClient) -> None: + res = await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": "nonexistent", + "redirect_uri": "http://evil.com/callback", + "scope": "openid", + "state": "test-state", + }, + follow_redirects=False, + ) + assert res.status_code >= 400 or "error" in res.text.lower() diff --git a/tests/test_oidc/test_discovery.py b/tests/test_oidc/test_discovery.py new file mode 100644 index 0000000..df83257 --- /dev/null +++ b/tests/test_oidc/test_discovery.py @@ -0,0 +1,34 @@ +from httpx import AsyncClient + + +async def test_discovery_endpoint_returns_metadata(client: AsyncClient) -> None: + res = await client.get("/.well-known/openid-configuration") + assert res.status_code == 200 + data = res.json() + assert data["issuer"] == "http://localhost:8000" + assert "authorization_endpoint" in data + assert "token_endpoint" in data + assert "userinfo_endpoint" in data + assert "jwks_uri" in data + + +async def test_discovery_response_types_supported(client: AsyncClient) -> None: + res = await client.get("/.well-known/openid-configuration") + data = res.json() + assert "code" in data["response_types_supported"] + + +async def test_discovery_scopes_supported(client: AsyncClient) -> None: + res = await client.get("/.well-known/openid-configuration") + data = res.json() + assert "openid" in data["scopes_supported"] + + +async def test_jwks_endpoint_returns_keys(client: AsyncClient) -> None: + res = await client.get("/jwks") + assert res.status_code == 200 + data = res.json() + assert "keys" in data + assert len(data["keys"]) > 0 + for key in data["keys"]: + assert "kty" in key