fix(csrf): replay request body consumed during token validation

The CSRF middleware read the request body via request.form() to extract
the csrf_token form field, then handed the already-consumed ASGI receive
to the downstream app. The endpoint's own request.form()/body() then
blocked indefinitely waiting for body bytes that would never arrive,
hanging the request until the client disconnected (ClientDisconnect).

Only native form POSTs hit this path: htmx requests carry the token in
the X-CSRF-Token header and skip the body read. The OIDC consent form is
a plain form with the token in the body, so authorization consent hung.

Buffer the body when falling back to the form token and replay it to the
downstream app via a wrapped receive. Header-token requests are
unchanged. Adds a regression test where the endpoint reads the body after
body-token CSRF validation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Johan Lundberg 2026-06-10 13:05:18 +02:00
parent 27763d19ea
commit 3c5451b9c2
No known key found for this signature in database
GPG key ID: A6C152738D03C7D1
2 changed files with 48 additions and 3 deletions

View file

@ -6,7 +6,7 @@ import secrets
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send
logger = logging.getLogger(__name__)
@ -73,11 +73,18 @@ class CSRFMiddleware:
# Token validation
expected_token: str | None = request.session.get(SESSION_KEY)
# Check header first, then fall back to form field
# Check header first, then fall back to a form field. Reading the form
# consumes the ASGI receive stream, so when we go down that path we
# buffer the body and replay it to the downstream app — otherwise the
# endpoint's own request.form()/body() would block forever on an
# already-consumed stream (manifesting as a hung request).
downstream_receive: Receive = receive
submitted_token: str | None = request.headers.get("x-csrf-token")
if submitted_token is None:
body = await request.body()
form = await request.form()
submitted_token = form.get("csrf_token") # type: ignore[assignment]
downstream_receive = _replay_receive(body)
if (
expected_token is None
@ -92,4 +99,22 @@ class CSRFMiddleware:
await response(scope, receive, send)
return
await self.app(scope, receive, send)
await self.app(scope, downstream_receive, send)
def _replay_receive(body: bytes) -> Receive:
"""Build a receive callable that replays an already-read request body.
The first call returns the buffered body; subsequent calls report a
disconnect so a downstream reader cannot block waiting for more data.
"""
sent = False
async def receive() -> Message:
nonlocal sent
if not sent:
sent = True
return {"type": "http.request", "body": body, "more_body": False}
return {"type": "http.disconnect"}
return receive