fix(oidc): return 400 instead of 500 on bad token requests
The token endpoint wrapped parse_request in try/except but called process_request and do_response unguarded, so a parseable-but-invalid request (e.g. a refresh_token grant missing client_id, or an unknown token) made idpyoidc raise and surfaced as a 500. Wrap both so failures return a clean 400 invalid_request and log the traceback server-side. Adds a regression test. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
3c5451b9c2
commit
b284cf596b
2 changed files with 42 additions and 3 deletions
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import html
|
import html
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
|
@ -13,6 +14,8 @@ from idpyoidc.time_util import utc_time_sans_frac
|
||||||
|
|
||||||
from porchlight.oidc.claims import PorchlightUserInfo, user_to_claims
|
from porchlight.oidc.claims import PorchlightUserInfo, user_to_claims
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(tags=["oidc"])
|
router = APIRouter(tags=["oidc"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -200,7 +203,7 @@ async def _complete_authorization( # noqa: PLR0913
|
||||||
|
|
||||||
|
|
||||||
@router.post("/token")
|
@router.post("/token")
|
||||||
async def token_endpoint(request: Request) -> JSONResponse:
|
async def token_endpoint(request: Request) -> JSONResponse: # noqa: PLR0911
|
||||||
"""OIDC Token endpoint."""
|
"""OIDC Token endpoint."""
|
||||||
oidc_server = request.app.state.oidc_server
|
oidc_server = request.app.state.oidc_server
|
||||||
endpoint = oidc_server.get_endpoint("token")
|
endpoint = oidc_server.get_endpoint("token")
|
||||||
|
|
@ -223,14 +226,31 @@ async def token_endpoint(request: Request) -> JSONResponse:
|
||||||
elif hasattr(parsed, "to_dict") and "error" in parsed:
|
elif hasattr(parsed, "to_dict") and "error" in parsed:
|
||||||
return JSONResponse(parsed.to_dict(), status_code=400)
|
return JSONResponse(parsed.to_dict(), status_code=400)
|
||||||
|
|
||||||
|
# process_request / do_response can raise on malformed-but-parseable
|
||||||
|
# requests (e.g. a refresh_token grant missing client_id). Treat those as a
|
||||||
|
# bad request rather than letting them surface as a 500.
|
||||||
|
try:
|
||||||
result = endpoint.process_request(parsed)
|
result = endpoint.process_request(parsed)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Token endpoint failed to process request")
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "invalid_request", "error_description": "The request could not be processed"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
if hasattr(result, "to_dict") and "error" in result:
|
if hasattr(result, "to_dict") and "error" in result:
|
||||||
return JSONResponse(result.to_dict(), status_code=400)
|
return JSONResponse(result.to_dict(), status_code=400)
|
||||||
elif isinstance(result, dict) and "error" in result:
|
elif isinstance(result, dict) and "error" in result:
|
||||||
return JSONResponse(result, status_code=400)
|
return JSONResponse(result, status_code=400)
|
||||||
|
|
||||||
|
try:
|
||||||
resp_info = endpoint.do_response(response_args=result.get("response_args"), request=parsed)
|
resp_info = endpoint.do_response(response_args=result.get("response_args"), request=parsed)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Token endpoint failed to build response")
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "invalid_request", "error_description": "The request could not be processed"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
response_data = resp_info["response"]
|
response_data = resp_info["response"]
|
||||||
if isinstance(response_data, str):
|
if isinstance(response_data, str):
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,25 @@ async def test_token_endpoint_exchanges_code(client: AsyncClient) -> None:
|
||||||
assert data["token_type"].lower() == "bearer"
|
assert data["token_type"].lower() == "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_grant_failure_returns_400_not_500(client: AsyncClient) -> None:
|
||||||
|
"""A refresh_token request that idpyoidc cannot process (here: unknown token,
|
||||||
|
no client_id in body) must surface as a 400, not a server 500."""
|
||||||
|
_register_test_client(client)
|
||||||
|
client_secret = "test-secret-0123456789abcdef"
|
||||||
|
|
||||||
|
auth_header = b64encode(f"test-rp:{client_secret}".encode()).decode()
|
||||||
|
token_res = await client.post(
|
||||||
|
"/token",
|
||||||
|
data={"grant_type": "refresh_token", "refresh_token": "bogus-token"},
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Basic {auth_header}",
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert token_res.status_code == 400, f"Expected 400, got {token_res.status_code}: {token_res.text}"
|
||||||
|
assert "error" in token_res.json()
|
||||||
|
|
||||||
|
|
||||||
async def test_token_endpoint_invalid_code_returns_error(client: AsyncClient) -> None:
|
async def test_token_endpoint_invalid_code_returns_error(client: AsyncClient) -> None:
|
||||||
_register_test_client(client)
|
_register_test_client(client)
|
||||||
client_secret = "test-secret-0123456789abcdef"
|
client_secret = "test-secret-0123456789abcdef"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue