"""FastAPI routes wrapping idpyoidc endpoint processing.""" from __future__ import annotations import json from types import SimpleNamespace from urllib.parse import urlencode from fastapi import APIRouter, Request, Response from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from idpyoidc.time_util import utc_time_sans_frac from fastapi_oidc_op.oidc.claims import PorchlightUserInfo, user_to_claims router = APIRouter(tags=["oidc"]) @router.get("/.well-known/openid-configuration") async def provider_configuration(request: Request) -> JSONResponse: """OIDC Discovery endpoint.""" oidc_server = request.app.state.oidc_server endpoint = oidc_server.get_endpoint("provider_config") parsed = endpoint.parse_request({}) result = endpoint.process_request(parsed) response_info = endpoint.do_response(response_args=result.get("response_args"), request=parsed) response_data = response_info["response"] if isinstance(response_data, str): response_data = json.loads(response_data) elif hasattr(response_data, "to_dict"): response_data = response_data.to_dict() return JSONResponse(content=response_data) @router.get("/jwks") async def jwks(request: Request) -> JSONResponse: """Public signing keys (JWK Set).""" oidc_server = request.app.state.oidc_server keys = oidc_server.keyjar.export_jwks() return JSONResponse(content=keys) @router.get("/authorization") async def authorization(request: Request) -> Response: """OIDC Authorization endpoint.""" oidc_server = request.app.state.oidc_server endpoint = oidc_server.get_endpoint("authorization") query_params = dict(request.query_params) try: parsed = endpoint.parse_request(query_params) except Exception as exc: return HTMLResponse(f"

Invalid Request

{exc}

", status_code=400) if "error" in parsed: error_desc = parsed.get("error_description", parsed["error"]) return HTMLResponse(f"

Error

{error_desc}

", status_code=400) # Check if user is authenticated userid = request.session.get("userid") username = request.session.get("username") if userid and username: return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) # Not authenticated — store and redirect to login request.session["oidc_auth_request"] = query_params return RedirectResponse("/login", status_code=303) @router.get("/authorization/complete") async def authorization_complete(request: Request) -> Response: """Resume OIDC authorization after login.""" oidc_server = request.app.state.oidc_server endpoint = oidc_server.get_endpoint("authorization") auth_request_params = request.session.pop("oidc_auth_request", None) if auth_request_params is None: return HTMLResponse("

Error

No pending authorization request

", status_code=400) userid = request.session.get("userid") username = request.session.get("username") if not userid or not username: return RedirectResponse("/login", status_code=303) try: parsed = endpoint.parse_request(auth_request_params) except Exception as exc: return HTMLResponse(f"

Error

{exc}

", status_code=400) if "error" in parsed: error_desc = parsed.get("error_description", parsed["error"]) return HTMLResponse(f"

Error

{error_desc}

", status_code=400) return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) async def _complete_authorization( request: Request, oidc_server: object, endpoint: object, parsed: object, userid: str, username: str, ) -> Response: """Complete OIDC authorization after user authentication.""" # Populate userinfo cache user_repo = request.app.state.user_repo user = await user_repo.get_by_userid(userid) if user is None: return HTMLResponse("

Error

User not found

", status_code=400) userinfo: PorchlightUserInfo = oidc_server.context.userinfo # type: ignore[union-attr] claims = user_to_claims(user) userinfo.set_user_claims(username, claims) # Create idpyoidc session — authn_method needs a kwargs dict authn_method = SimpleNamespace(kwargs={}) session_id = endpoint.create_session( # type: ignore[union-attr] request=parsed, user_id=username, acr="urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport", time_stamp=utc_time_sans_frac(), authn_method=authn_method, ) # Complete authorization (mints code, builds redirect) result = endpoint.authz_part2(request=parsed, session_id=session_id) # type: ignore[union-attr] if "error" in result.get("response_args", {}): response_args = result["response_args"] error_desc = response_args.get("error_description", response_args["error"]) return HTMLResponse(f"

Error

{error_desc}

", status_code=400) # Build redirect URL response_args = result.get("response_args", {}) return_uri = result.get("return_uri", "") if hasattr(response_args, "to_dict"): params = response_args.to_dict() elif isinstance(response_args, dict): params = response_args else: params = dict(response_args) redirect_url = f"{return_uri}?{urlencode(params)}" return RedirectResponse(redirect_url, status_code=303) @router.post("/token") async def token_endpoint(request: Request) -> JSONResponse: """OIDC Token endpoint.""" oidc_server = request.app.state.oidc_server endpoint = oidc_server.get_endpoint("token") body = await request.body() body_str = body.decode("utf-8") http_info = { "headers": dict(request.headers), "url": str(request.url), } try: parsed = endpoint.parse_request(body_str, http_info=http_info) except Exception as exc: return JSONResponse({"error": "invalid_request", "error_description": str(exc)}, status_code=400) if isinstance(parsed, dict) and "error" in parsed: return JSONResponse(parsed, status_code=400) elif hasattr(parsed, "to_dict") and "error" in parsed: return JSONResponse(parsed.to_dict(), status_code=400) result = endpoint.process_request(parsed) if hasattr(result, "to_dict") and "error" in result: return JSONResponse(result.to_dict(), status_code=400) elif isinstance(result, dict) and "error" in result: return JSONResponse(result, status_code=400) resp_info = endpoint.do_response(response_args=result.get("response_args"), request=parsed) response_data = resp_info["response"] if isinstance(response_data, str): response_data = json.loads(response_data) elif hasattr(response_data, "to_dict"): response_data = response_data.to_dict() return JSONResponse(response_data) @router.api_route("/userinfo", methods=["GET", "POST"]) async def userinfo_endpoint(request: Request) -> JSONResponse: """OIDC UserInfo endpoint.""" oidc_server = request.app.state.oidc_server endpoint = oidc_server.get_endpoint("userinfo") http_info = { "headers": dict(request.headers), "url": str(request.url), } if request.method == "POST": body = await request.body() request_data = body.decode("utf-8") else: request_data = {} try: parsed = endpoint.parse_request(request_data, http_info=http_info) except Exception as exc: return JSONResponse( {"error": "invalid_token", "error_description": str(exc)}, status_code=401, ) if isinstance(parsed, dict) and "error" in parsed: error_data = parsed elif hasattr(parsed, "to_dict") and "error" in parsed: error_data = parsed.to_dict() else: error_data = None if error_data is not None: return JSONResponse(error_data, status_code=401) result = endpoint.process_request(parsed) if hasattr(result, "to_dict") and "error" in result: return JSONResponse(result.to_dict(), status_code=401) elif isinstance(result, dict) and "error" in result: return JSONResponse(result, status_code=401) resp_info = endpoint.do_response( response_args=result.get("response_args"), request=parsed, client_id=result.get("client_id", ""), ) response_data = resp_info["response"] if isinstance(response_data, str): response_data = json.loads(response_data) elif hasattr(response_data, "to_dict"): response_data = response_data.to_dict() return JSONResponse(response_data)