diff --git a/src/porchlight/oidc/endpoints.py b/src/porchlight/oidc/endpoints.py index 12af079..2390802 100644 --- a/src/porchlight/oidc/endpoints.py +++ b/src/porchlight/oidc/endpoints.py @@ -4,6 +4,7 @@ from __future__ import annotations import html import json +import logging from types import SimpleNamespace from urllib.parse import urlencode @@ -13,6 +14,8 @@ from idpyoidc.time_util import utc_time_sans_frac from porchlight.oidc.claims import PorchlightUserInfo, user_to_claims +logger = logging.getLogger(__name__) + router = APIRouter(tags=["oidc"]) @@ -200,7 +203,7 @@ async def _complete_authorization( # noqa: PLR0913 @router.post("/token") -async def token_endpoint(request: Request) -> JSONResponse: +async def token_endpoint(request: Request) -> JSONResponse: # noqa: PLR0911 """OIDC Token endpoint.""" oidc_server = request.app.state.oidc_server endpoint = oidc_server.get_endpoint("token") @@ -223,14 +226,31 @@ async def token_endpoint(request: Request) -> JSONResponse: elif hasattr(parsed, "to_dict") and "error" in parsed: return JSONResponse(parsed.to_dict(), status_code=400) - result = endpoint.process_request(parsed) + # process_request / do_response can raise on malformed-but-parseable + # requests (e.g. a refresh_token grant missing client_id). Treat those as a + # bad request rather than letting them surface as a 500. + try: + result = endpoint.process_request(parsed) + except Exception: + logger.exception("Token endpoint failed to process request") + return JSONResponse( + {"error": "invalid_request", "error_description": "The request could not be processed"}, + status_code=400, + ) if hasattr(result, "to_dict") and "error" in result: return JSONResponse(result.to_dict(), status_code=400) elif isinstance(result, dict) and "error" in result: return JSONResponse(result, status_code=400) - resp_info = endpoint.do_response(response_args=result.get("response_args"), request=parsed) + try: + resp_info = endpoint.do_response(response_args=result.get("response_args"), request=parsed) + except Exception: + logger.exception("Token endpoint failed to build response") + return JSONResponse( + {"error": "invalid_request", "error_description": "The request could not be processed"}, + status_code=400, + ) response_data = resp_info["response"] if isinstance(response_data, str): diff --git a/tests/test_oidc/test_token.py b/tests/test_oidc/test_token.py index e15a4b6..485e2e3 100644 --- a/tests/test_oidc/test_token.py +++ b/tests/test_oidc/test_token.py @@ -126,6 +126,25 @@ async def test_token_endpoint_exchanges_code(client: AsyncClient) -> None: assert data["token_type"].lower() == "bearer" +async def test_refresh_grant_failure_returns_400_not_500(client: AsyncClient) -> None: + """A refresh_token request that idpyoidc cannot process (here: unknown token, + no client_id in body) must surface as a 400, not a server 500.""" + _register_test_client(client) + client_secret = "test-secret-0123456789abcdef" + + auth_header = b64encode(f"test-rp:{client_secret}".encode()).decode() + token_res = await client.post( + "/token", + data={"grant_type": "refresh_token", "refresh_token": "bogus-token"}, + headers={ + "Authorization": f"Basic {auth_header}", + "Content-Type": "application/x-www-form-urlencoded", + }, + ) + assert token_res.status_code == 400, f"Expected 400, got {token_res.status_code}: {token_res.text}" + assert "error" in token_res.json() + + async def test_token_endpoint_invalid_code_returns_error(client: AsyncClient) -> None: _register_test_client(client) client_secret = "test-secret-0123456789abcdef"