fix(csrf): replay request body consumed during token validation
The CSRF middleware read the request body via request.form() to extract the csrf_token form field, then handed the already-consumed ASGI receive to the downstream app. The endpoint's own request.form()/body() then blocked indefinitely waiting for body bytes that would never arrive, hanging the request until the client disconnected (ClientDisconnect). Only native form POSTs hit this path: htmx requests carry the token in the X-CSRF-Token header and skip the body read. The OIDC consent form is a plain form with the token in the body, so authorization consent hung. Buffer the body when falling back to the form token and replay it to the downstream app via a wrapped receive. Header-token requests are unchanged. Adds a regression test where the endpoint reads the body after body-token CSRF validation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
27763d19ea
commit
3c5451b9c2
2 changed files with 48 additions and 3 deletions
|
|
@ -6,7 +6,7 @@ import secrets
|
|||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -73,11 +73,18 @@ class CSRFMiddleware:
|
|||
# Token validation
|
||||
expected_token: str | None = request.session.get(SESSION_KEY)
|
||||
|
||||
# Check header first, then fall back to form field
|
||||
# Check header first, then fall back to a form field. Reading the form
|
||||
# consumes the ASGI receive stream, so when we go down that path we
|
||||
# buffer the body and replay it to the downstream app — otherwise the
|
||||
# endpoint's own request.form()/body() would block forever on an
|
||||
# already-consumed stream (manifesting as a hung request).
|
||||
downstream_receive: Receive = receive
|
||||
submitted_token: str | None = request.headers.get("x-csrf-token")
|
||||
if submitted_token is None:
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
submitted_token = form.get("csrf_token") # type: ignore[assignment]
|
||||
downstream_receive = _replay_receive(body)
|
||||
|
||||
if (
|
||||
expected_token is None
|
||||
|
|
@ -92,4 +99,22 @@ class CSRFMiddleware:
|
|||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
await self.app(scope, downstream_receive, send)
|
||||
|
||||
|
||||
def _replay_receive(body: bytes) -> Receive:
|
||||
"""Build a receive callable that replays an already-read request body.
|
||||
|
||||
The first call returns the buffered body; subsequent calls report a
|
||||
disconnect so a downstream reader cannot block waiting for more data.
|
||||
"""
|
||||
sent = False
|
||||
|
||||
async def receive() -> Message:
|
||||
nonlocal sent
|
||||
if not sent:
|
||||
sent = True
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
return receive
|
||||
|
|
|
|||
|
|
@ -25,6 +25,12 @@ def _make_app(check_origin: str | None = None) -> Starlette:
|
|||
async def action_post(request: Request) -> PlainTextResponse:
|
||||
return PlainTextResponse("ok")
|
||||
|
||||
async def echo_post(request: Request) -> PlainTextResponse:
|
||||
# Reads the body itself — this only works if the CSRF middleware
|
||||
# replayed the body it consumed to find the form token.
|
||||
form = await request.form()
|
||||
return PlainTextResponse(str(form.get("payload")))
|
||||
|
||||
async def exempt_post(request: Request) -> PlainTextResponse:
|
||||
return PlainTextResponse("ok")
|
||||
|
||||
|
|
@ -32,6 +38,7 @@ def _make_app(check_origin: str | None = None) -> Starlette:
|
|||
Route("/get-token", get_token, methods=["GET"]),
|
||||
Route("/action", action_get, methods=["GET"]),
|
||||
Route("/action", action_post, methods=["POST"]),
|
||||
Route("/echo", echo_post, methods=["POST"]),
|
||||
Route("/exempt", exempt_post, methods=["POST"]),
|
||||
]
|
||||
|
||||
|
|
@ -77,6 +84,19 @@ class TestCSRFValidation:
|
|||
response = await client.post("/action", data={"csrf_token": token})
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_form_token_path_replays_body_to_endpoint(self) -> None:
|
||||
"""Regression: when the token comes from the form body (no header), the
|
||||
middleware must replay the consumed body so the endpoint can still read
|
||||
it. Previously this hung the request until client disconnect."""
|
||||
app = _make_app()
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
token, _ = await _get_token_and_cookies(client)
|
||||
response = await client.post(
|
||||
"/echo", data={"csrf_token": token, "payload": "hello"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.text == "hello"
|
||||
|
||||
async def test_post_with_valid_header_token(self) -> None:
|
||||
app = _make_app()
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue