porchlight/src/porchlight/oidc/endpoints.py
Johan Lundberg b284cf596b
fix(oidc): return 400 instead of 500 on bad token requests
The token endpoint wrapped parse_request in try/except but
  called process_request and do_response unguarded, so a parseable-but-invalid request (e.g. a refresh_token grant missing client_id, or an
  unknown token) made idpyoidc raise, and the exception surfaced as a 500. Wrap both calls so failures return a clean 400 invalid_request and log the traceback
  server-side. Adds a regression test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-10 14:37:01 +02:00

396 lines
14 KiB
Python

"""FastAPI routes wrapping idpyoidc endpoint processing."""
from __future__ import annotations
import html
import json
import logging
from types import SimpleNamespace
from urllib.parse import urlencode
from fastapi import APIRouter, Request, Response
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from idpyoidc.time_util import utc_time_sans_frac
from porchlight.oidc.claims import PorchlightUserInfo, user_to_claims
# Router that the route decorators in this module register their handlers on.
router = APIRouter(tags=["oidc"])
# Module logger, used to record request-processing failures with tracebacks.
logger = logging.getLogger(__name__)
def _error_page(message: object, status_code: int = 400, title: str = "Error") -> HTMLResponse:
    """Build a minimal HTML error response.

    Both *title* and *message* are HTML-escaped because the message is often
    derived from the incoming request (e.g. an exception raised while parsing
    it). *message* may be any object; it is converted with ``str()``.
    """
    safe_title = html.escape(title)
    safe_message = html.escape(str(message))
    markup = "<h1>" + safe_title + "</h1><p>" + safe_message + "</p>"
    return HTMLResponse(markup, status_code=status_code)
# Human-readable text shown on the consent page, keyed by OAuth scope name.
# The consent page falls back to the raw scope name for scopes not listed here.
SCOPE_DESCRIPTIONS: dict[str, str] = {
    "openid": "Sign you in (required)",
    "profile": "Your name and profile information",
    "email": "Your email address",
    "phone": "Your phone number",
}
@router.get("/.well-known/openid-configuration")
async def provider_configuration(request: Request) -> JSONResponse:
    """OIDC Discovery endpoint.

    Runs idpyoidc's ``provider_config`` endpoint through its
    parse -> process -> respond pipeline and returns the resulting document
    as JSON.
    """
    config_endpoint = request.app.state.oidc_server.get_endpoint("provider_config")
    empty_request = config_endpoint.parse_request({})
    processed = config_endpoint.process_request(empty_request)
    rendered = config_endpoint.do_response(
        response_args=processed.get("response_args"),
        request=empty_request,
    )
    payload = rendered["response"]
    # do_response may hand back serialized JSON text or a Message object.
    if isinstance(payload, str):
        return JSONResponse(content=json.loads(payload))
    if hasattr(payload, "to_dict"):
        return JSONResponse(content=payload.to_dict())
    return JSONResponse(content=payload)
@router.get("/jwks")
async def jwks(request: Request) -> JSONResponse:
    """Publish the server's public signing keys as a JWK Set."""
    exported = request.app.state.oidc_server.keyjar.export_jwks()
    return JSONResponse(content=exported)
@router.get("/authorization")
async def authorization(request: Request) -> Response:
    """OIDC Authorization endpoint.

    Validates the incoming authorization request. When a user is already
    logged in, processing continues with the consent check. Otherwise the raw
    query parameters are parked in the session under ``oidc_auth_request``
    and the browser is redirected to ``/login``.
    """
    oidc_server = request.app.state.oidc_server
    authz_endpoint = oidc_server.get_endpoint("authorization")
    params = dict(request.query_params)

    try:
        parsed = authz_endpoint.parse_request(params)
    except Exception as exc:
        return _error_page(exc, title="Invalid Request")

    if "error" in parsed:
        return _error_page(parsed.get("error_description", parsed["error"]))

    session = request.session
    userid = session.get("userid")
    username = session.get("username")
    if not (userid and username):
        # Not authenticated: remember the request and send the user to log in.
        session["oidc_auth_request"] = params
        return RedirectResponse("/login", status_code=303)

    return await _check_consent_or_complete(
        request, oidc_server, authz_endpoint, parsed, userid, username, params
    )
@router.get("/authorization/complete")
async def authorization_complete(request: Request) -> Response:
    """Resume OIDC authorization after login.

    Takes the authorization request that ``/authorization`` parked in the
    session under ``oidc_auth_request``, re-validates it and continues with
    the consent check.

    Returns a 400 page when no request is pending or it fails to parse. When
    no user is logged in, the pending request is put back into the session
    before redirecting to ``/login`` so it is not lost.
    """
    oidc_server = request.app.state.oidc_server
    endpoint = oidc_server.get_endpoint("authorization")
    auth_request_params = request.session.pop("oidc_auth_request", None)
    if auth_request_params is None:
        return _error_page("No pending authorization request")
    userid = request.session.get("userid")
    username = request.session.get("username")
    if not userid or not username:
        # The request was popped above. Without restoring it, bouncing to
        # /login would silently discard the client's pending authorization.
        request.session["oidc_auth_request"] = auth_request_params
        return RedirectResponse("/login", status_code=303)
    try:
        parsed = endpoint.parse_request(auth_request_params)
    except Exception as exc:
        return _error_page(exc)
    if "error" in parsed:
        error_desc = parsed.get("error_description", parsed["error"])
        return _error_page(error_desc)
    return await _check_consent_or_complete(
        request, oidc_server, endpoint, parsed, userid, username, auth_request_params
    )
async def _check_consent_or_complete(  # noqa: PLR0913
    request: Request,
    oidc_server: object,
    endpoint: object,
    parsed: object,
    userid: str,
    username: str,
    auth_params: dict,
) -> Response:
    """Complete authorization now, or redirect to ``/consent`` first.

    No consent prompt is shown to the configured management client
    (``settings.manage_client_id``), nor to a client whose stored consent
    already covers every requested scope. In every other case the
    authorization parameters are saved in the session under
    ``consent_auth_request`` and the browser is redirected to ``/consent``.
    """
    app_state = request.app.state
    client_id = auth_params.get("client_id", "")

    if client_id != app_state.settings.manage_client_id:
        requested = set(auth_params.get("scope", "openid").split())
        granted = await app_state.consent_repo.get_consent(userid, client_id)
        if not granted or not requested <= set(granted.scopes):
            # Nothing (or not enough) approved yet: ask the user.
            request.session["consent_auth_request"] = auth_params
            return RedirectResponse("/consent", status_code=303)

    return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username)
def _build_authz_redirect(return_uri: str, params: dict, *, fragment: bool = False) -> str:
    """Return *return_uri* with the authorization response *params* attached.

    The parameters go into the query string unless *fragment* is true, in which
    case they go into the URI fragment. A query component already present in
    *return_uri* is preserved, as RFC 6749 §3.1.2 requires: new parameters are
    joined with ``&`` instead of producing a second ``?``.
    """
    encoded = urlencode(params)
    if fragment:
        return f"{return_uri}#{encoded}"
    if "?" not in return_uri:
        return f"{return_uri}?{encoded}"
    if return_uri.endswith(("?", "&")):
        return f"{return_uri}{encoded}"
    return f"{return_uri}&{encoded}"


async def _complete_authorization(  # noqa: PLR0913
    request: Request,
    oidc_server: object,
    endpoint: object,
    parsed: object,
    userid: str,
    username: str,
) -> Response:
    """Complete OIDC authorization after user authentication.

    The steps are:
    1. Load the user and cache their claims for idpyoidc's userinfo handling.
    2. Create an idpyoidc session.
    3. Mint the authorization response with ``authz_part2``.
    4. Redirect the browser to the returned ``return_uri``.

    Returns a 400 error page when the user no longer exists or idpyoidc
    reports an error.
    """
    user_repo = request.app.state.user_repo
    user = await user_repo.get_by_userid(userid)
    if user is None:
        return _error_page("User not found")

    # Populate the userinfo claims cache, keyed by username (the same value
    # passed to create_session as user_id below).
    userinfo: PorchlightUserInfo = oidc_server.context.userinfo  # type: ignore[union-attr]
    userinfo.set_user_claims(username, user_to_claims(user))

    # Create idpyoidc session — authn_method needs a kwargs dict
    authn_method = SimpleNamespace(kwargs={})
    session_id = endpoint.create_session(  # type: ignore[union-attr]
        request=parsed,
        user_id=username,
        acr="urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport",
        time_stamp=utc_time_sans_frac(),
        authn_method=authn_method,
    )

    # Complete authorization (mints code, builds redirect)
    result = endpoint.authz_part2(request=parsed, session_id=session_id)  # type: ignore[union-attr]

    # Check for an error both at the top level (idpyoidc may return a bare
    # error message here) and nested inside response_args.
    if "error" in result:
        return _error_page(result.get("error_description", result["error"]))
    response_args = result.get("response_args", {})
    if "error" in response_args:
        return _error_page(response_args.get("error_description", response_args["error"]))

    if hasattr(response_args, "to_dict"):
        params = response_args.to_dict()
    elif isinstance(response_args, dict):
        params = response_args
    else:
        params = dict(response_args)

    # NOTE(review): idpyoidc presumably sets fragment_enc for response types
    # whose parameters belong in the fragment; if absent, use the query string.
    redirect_url = _build_authz_redirect(
        result.get("return_uri", ""),
        params,
        fragment=bool(result.get("fragment_enc")),
    )
    return RedirectResponse(redirect_url, status_code=303)
# RFC 6749 §5.1 requires these headers on token responses. The §5.2 error
# example carries them as well, so every token-endpoint response sets them.
_TOKEN_NO_CACHE_HEADERS = {"Cache-Control": "no-store", "Pragma": "no-cache"}


def _token_json(content: object, status_code: int = 200) -> JSONResponse:
    """Build a token-endpoint JSON response carrying the no-cache headers."""
    return JSONResponse(content, status_code=status_code, headers=_TOKEN_NO_CACHE_HEADERS)


def _token_failure(log_message: str) -> JSONResponse:
    """Log the exception being handled and return a generic 400 ``invalid_request``.

    Must be called from inside an ``except`` block so the traceback is logged.
    """
    logger.exception(log_message)
    return _token_json(
        {"error": "invalid_request", "error_description": "The request could not be processed"},
        status_code=400,
    )


@router.post("/token")
async def token_endpoint(request: Request) -> JSONResponse:  # noqa: PLR0911
    """OIDC Token endpoint.

    Client-caused failures are returned as 400 JSON errors, never as 500s:
    - Errors that idpyoidc reports are passed through.
    - Exceptions raised while processing the request or building the response
      are logged with a traceback and replaced by a generic ``invalid_request``.

    All responses carry ``Cache-Control: no-store`` and ``Pragma: no-cache``.
    """
    oidc_server = request.app.state.oidc_server
    endpoint = oidc_server.get_endpoint("token")
    body = await request.body()
    http_info = {
        "headers": dict(request.headers),
        "url": str(request.url),
    }
    try:
        # Decode inside the try: a non-UTF-8 body is a client error, not a 500.
        parsed = endpoint.parse_request(body.decode("utf-8"), http_info=http_info)
    except Exception as exc:
        return _token_json({"error": "invalid_request", "error_description": str(exc)}, status_code=400)
    if isinstance(parsed, dict) and "error" in parsed:
        return _token_json(parsed, status_code=400)
    if hasattr(parsed, "to_dict") and "error" in parsed:
        return _token_json(parsed.to_dict(), status_code=400)

    # process_request / do_response can raise on malformed-but-parseable
    # requests (e.g. a refresh_token grant missing client_id). Treat those as a
    # bad request rather than letting them surface as a 500.
    try:
        result = endpoint.process_request(parsed)
    except Exception:
        return _token_failure("Token endpoint failed to process request")
    if hasattr(result, "to_dict") and "error" in result:
        return _token_json(result.to_dict(), status_code=400)
    if isinstance(result, dict) and "error" in result:
        return _token_json(result, status_code=400)

    try:
        resp_info = endpoint.do_response(response_args=result.get("response_args"), request=parsed)
    except Exception:
        return _token_failure("Token endpoint failed to build response")

    response_data = resp_info["response"]
    # do_response may return serialized JSON text or a Message object.
    if isinstance(response_data, str):
        response_data = json.loads(response_data)
    elif hasattr(response_data, "to_dict"):
        response_data = response_data.to_dict()
    return _token_json(response_data)
def _bearer_error(error_data: dict) -> JSONResponse:
    """Return *error_data* as a 401 JSON response carrying a Bearer challenge.

    RFC 6750 §3 requires a ``WWW-Authenticate`` header on responses to
    requests that lack a usable access token. The error code is echoed into
    the header only when it is a plain printable-ASCII token without quotes or
    backslashes. Otherwise ``invalid_token`` is used so the header stays
    well-formed.
    """
    code = str(error_data.get("error", "invalid_token"))
    if not code or any(ch in '"\\' or not " " <= ch <= "~" for ch in code):
        code = "invalid_token"
    return JSONResponse(
        error_data,
        status_code=401,
        headers={"WWW-Authenticate": f'Bearer error="{code}"'},
    )


@router.api_route("/userinfo", methods=["GET", "POST"])
async def userinfo_endpoint(request: Request) -> JSONResponse:
    """OIDC UserInfo endpoint.

    Every failure is returned as a 401 JSON error with a Bearer challenge.
    Exceptions raised by idpyoidc while processing the request or building
    the response are logged server-side and reported as a generic
    ``invalid_token`` rather than surfacing as a 500.
    """
    oidc_server = request.app.state.oidc_server
    endpoint = oidc_server.get_endpoint("userinfo")
    http_info = {
        "headers": dict(request.headers),
        "url": str(request.url),
    }
    body = await request.body() if request.method == "POST" else None
    try:
        # Decode inside the try: a non-UTF-8 body is a client error, not a 500.
        request_data = body.decode("utf-8") if body is not None else {}
        parsed = endpoint.parse_request(request_data, http_info=http_info)
    except Exception as exc:
        return _bearer_error({"error": "invalid_token", "error_description": str(exc)})
    if isinstance(parsed, dict) and "error" in parsed:
        return _bearer_error(parsed)
    if hasattr(parsed, "to_dict") and "error" in parsed:
        return _bearer_error(parsed.to_dict())

    generic_failure = {"error": "invalid_token", "error_description": "The request could not be processed"}
    try:
        result = endpoint.process_request(parsed)
    except Exception:
        logger.exception("UserInfo endpoint failed to process request")
        return _bearer_error(generic_failure)
    if hasattr(result, "to_dict") and "error" in result:
        return _bearer_error(result.to_dict())
    if isinstance(result, dict) and "error" in result:
        return _bearer_error(result)

    try:
        resp_info = endpoint.do_response(
            response_args=result.get("response_args"),
            request=parsed,
            client_id=result.get("client_id", ""),
        )
    except Exception:
        logger.exception("UserInfo endpoint failed to build response")
        return _bearer_error(generic_failure)

    response_data = resp_info["response"]
    # do_response may return serialized JSON text or a Message object.
    if isinstance(response_data, str):
        response_data = json.loads(response_data)
    elif hasattr(response_data, "to_dict"):
        response_data = response_data.to_dict()
    return JSONResponse(response_data)
@router.get("/consent")
async def consent_page(request: Request) -> Response:
    """Render the consent form for the authorization request parked in the session.

    Responds with a 400 page when no consent request is pending, and
    redirects to ``/login`` when no user is logged in.
    """
    pending = request.session.get("consent_auth_request")
    if pending is None:
        return _error_page("No pending consent request")
    if not request.session.get("userid"):
        return RedirectResponse("/login", status_code=303)

    scope_names = pending.get("scope", "openid").split()
    # openid is always marked required; unknown scopes show their raw name.
    scope_rows = [
        {
            "name": name,
            "description": SCOPE_DESCRIPTIONS.get(name, name),
            "required": name == "openid",
        }
        for name in scope_names
    ]
    context = {"client_id": pending.get("client_id", ""), "scopes": scope_rows}
    return request.app.state.templates.TemplateResponse(request, "consent.html", context)
@router.post("/consent")
async def consent_submit(request: Request) -> Response:
    """Handle consent form submission.

    The ``action`` form field decides what happens:
    - ``deny``: redirect back to the client's ``redirect_uri`` with
      ``error=access_denied``.
    - ``allow``: store the approved scopes, then complete authorization with
      the pending request narrowed to those scopes. Approved scopes are
      limited to those originally requested, and ``openid`` is always
      included.
    - anything else: return a 400 page.

    If no user is logged in, the pending consent request is put back into the
    session before redirecting to ``/login``.
    """
    auth_params = request.session.pop("consent_auth_request", None)
    if auth_params is None:
        return _error_page("No pending consent request")
    userid = request.session.get("userid")
    username = request.session.get("username")
    if not userid or not username:
        # The request was popped above; restore it so the login detour does
        # not silently discard it.
        request.session["consent_auth_request"] = auth_params
        return RedirectResponse("/login", status_code=303)

    form = await request.form()
    action = form.get("action")
    client_id = auth_params.get("client_id", "")

    if action == "deny":
        denial = {"error": "access_denied"}
        # Echo state only when the client sent one (RFC 6749 §4.1.2.1).
        if "state" in auth_params:
            denial["state"] = auth_params["state"]
        redirect_uri = auth_params.get("redirect_uri", "")
        # Preserve any existing query component of redirect_uri (RFC 6749 §3.1.2).
        separator = "&" if "?" in redirect_uri else "?"
        return RedirectResponse(f"{redirect_uri}{separator}{urlencode(denial)}", status_code=303)
    if action != "allow":
        return _error_page("Invalid action")

    # Allow — collect approved scopes, rejecting anything outside the
    # originally requested set (a forged form must not escalate scope).
    # dict.fromkeys drops repeated form values while keeping their order.
    requested_scopes = set(auth_params.get("scope", "openid").split())
    approved_scopes = [
        scope for scope in dict.fromkeys(str(s) for s in form.getlist("scope")) if scope in requested_scopes
    ]
    if "openid" not in approved_scopes:
        approved_scopes.insert(0, "openid")

    consent_repo = request.app.state.consent_repo
    await consent_repo.set_consent(userid, client_id, list(approved_scopes))

    # Narrow the pending request to the approved scopes, then re-parse and complete.
    auth_params["scope"] = " ".join(approved_scopes)
    oidc_server = request.app.state.oidc_server
    endpoint = oidc_server.get_endpoint("authorization")
    try:
        parsed = endpoint.parse_request(auth_params)
    except Exception as exc:
        return _error_page(exc)
    if "error" in parsed:
        return _error_page(parsed.get("error_description", parsed["error"]))
    return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username)