fix(csrf): replay request body consumed during token validation #1

Merged
lundberg merged 1 commit from fix/csrf-body-replay into main 2026-06-10 11:08:57 +00:00
2 changed files with 48 additions and 3 deletions

View file

@ -6,7 +6,7 @@ import secrets
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send
logger = logging.getLogger(__name__)
@ -73,11 +73,18 @@ class CSRFMiddleware:
# Token validation
expected_token: str | None = request.session.get(SESSION_KEY)
# Check header first, then fall back to form field
# Check header first, then fall back to a form field. Reading the form
# consumes the ASGI receive stream, so when we go down that path we
# buffer the body and replay it to the downstream app — otherwise the
# endpoint's own request.form()/body() would block forever on an
# already-consumed stream (manifesting as a hung request).
downstream_receive: Receive = receive
submitted_token: str | None = request.headers.get("x-csrf-token")
if submitted_token is None:
body = await request.body()
form = await request.form()
submitted_token = form.get("csrf_token") # type: ignore[assignment]
downstream_receive = _replay_receive(body)
if (
expected_token is None
@ -92,4 +99,22 @@ class CSRFMiddleware:
await response(scope, receive, send)
return
await self.app(scope, receive, send)
await self.app(scope, downstream_receive, send)
def _replay_receive(body: bytes, receive: "Receive | None" = None) -> Receive:
    """Build a receive callable that replays an already-read request body.

    The first call returns the buffered body as a single, final
    ``http.request`` message, so a downstream reader gets the complete body
    and never blocks waiting for more data.

    What later calls return depends on ``receive``:

    * If the original ASGI ``receive`` callable is supplied, later calls are
      forwarded to it. The downstream app then sees ``http.disconnect`` only
      when the server really reports one. This matters for code that listens
      for disconnects while producing a response (streaming responses,
      ``request.is_disconnected()``); a fabricated disconnect would make it
      abort a perfectly healthy connection.
    * If ``receive`` is omitted, later calls immediately return
      ``http.disconnect``. This is the historical behaviour and is kept so
      existing single-argument callers continue to work unchanged.

    Args:
        body: The request body bytes that were already consumed upstream.
        receive: Optional original ASGI receive callable to delegate to once
            the buffered body has been handed out.

    Returns:
        An ASGI receive callable suitable for passing to the downstream app.
    """
    replayed = False

    async def replay() -> Message:
        nonlocal replayed
        if not replayed:
            replayed = True
            # more_body=False marks the body as complete, so readers stop here.
            return {"type": "http.request", "body": body, "more_body": False}
        if receive is not None:
            # Let the server, not this wrapper, decide when the client is gone.
            return await receive()
        return {"type": "http.disconnect"}

    return replay

View file

@ -25,6 +25,12 @@ def _make_app(check_origin: str | None = None) -> Starlette:
# POST handler registered at "/action": ignores the request contents and
# always answers with the plain-text body "ok".
async def action_post(request: Request) -> PlainTextResponse:
return PlainTextResponse("ok")
async def echo_post(request: Request) -> PlainTextResponse:
    """Echo the ``payload`` form field back as plain text.

    The handler parses the form body itself, so it can only answer when the
    body is still readable by the time the request reaches it, i.e. when the
    CSRF middleware replayed the body it consumed while looking for the
    form token.
    """
    fields = await request.form()
    payload = fields.get("payload")
    return PlainTextResponse(str(payload))
# POST handler registered at "/exempt": ignores the request contents and
# always answers with the plain-text body "ok".
# NOTE(review): presumably this is the route excluded from CSRF checks —
# confirm against the middleware configuration in _make_app.
async def exempt_post(request: Request) -> PlainTextResponse:
return PlainTextResponse("ok")
@ -32,6 +38,7 @@ def _make_app(check_origin: str | None = None) -> Starlette:
Route("/get-token", get_token, methods=["GET"]),
Route("/action", action_get, methods=["GET"]),
Route("/action", action_post, methods=["POST"]),
Route("/echo", echo_post, methods=["POST"]),
Route("/exempt", exempt_post, methods=["POST"]),
]
@ -77,6 +84,19 @@ class TestCSRFValidation:
response = await client.post("/action", data={"csrf_token": token})
assert response.status_code == 200
async def test_form_token_path_replays_body_to_endpoint(self) -> None:
"""Regression: when the token comes from the form body (no header), the
middleware must replay the consumed body so the endpoint can still read
it. Previously this hung the request until client disconnect."""
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
token, _ = await _get_token_and_cookies(client)
# No headers are passed here: the token travels only as the "csrf_token"
# form field, alongside the "payload" field the endpoint echoes back.
response = await client.post(
"/echo", data={"csrf_token": token, "payload": "hello"}
)
assert response.status_code == 200
# /echo parses the form itself, so getting "hello" back shows the endpoint
# could still read the request body.
assert response.text == "hello"
async def test_post_with_valid_header_token(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: