fix(csrf): replay request body consumed during token validation
The CSRF middleware read the request body via request.form() to extract the csrf_token form field, then handed the already-consumed ASGI receive to the downstream app. The endpoint's own request.form()/body() then blocked indefinitely waiting for body bytes that would never arrive, hanging the request until the client disconnected (ClientDisconnect). Only native form POSTs hit this path: htmx requests carry the token in the X-CSRF-Token header and skip the body read. The OIDC consent form is a plain form with the token in the body, so authorization consent hung. Buffer the body when falling back to the form token and replay it to the downstream app via a wrapped receive. Header-token requests are unchanged. Adds a regression test where the endpoint reads the body after body-token CSRF validation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
27763d19ea
commit
3c5451b9c2
2 changed files with 48 additions and 3 deletions
|
|
@ -6,7 +6,7 @@ import secrets
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import HTMLResponse
|
from starlette.responses import HTMLResponse
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -73,11 +73,18 @@ class CSRFMiddleware:
|
||||||
# Token validation
|
# Token validation
|
||||||
expected_token: str | None = request.session.get(SESSION_KEY)
|
expected_token: str | None = request.session.get(SESSION_KEY)
|
||||||
|
|
||||||
# Check header first, then fall back to form field
|
# Check header first, then fall back to a form field. Reading the form
|
||||||
|
# consumes the ASGI receive stream, so when we go down that path we
|
||||||
|
# buffer the body and replay it to the downstream app — otherwise the
|
||||||
|
# endpoint's own request.form()/body() would block forever on an
|
||||||
|
# already-consumed stream (manifesting as a hung request).
|
||||||
|
downstream_receive: Receive = receive
|
||||||
submitted_token: str | None = request.headers.get("x-csrf-token")
|
submitted_token: str | None = request.headers.get("x-csrf-token")
|
||||||
if submitted_token is None:
|
if submitted_token is None:
|
||||||
|
body = await request.body()
|
||||||
form = await request.form()
|
form = await request.form()
|
||||||
submitted_token = form.get("csrf_token") # type: ignore[assignment]
|
submitted_token = form.get("csrf_token") # type: ignore[assignment]
|
||||||
|
downstream_receive = _replay_receive(body)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
expected_token is None
|
expected_token is None
|
||||||
|
|
@ -92,4 +99,22 @@ class CSRFMiddleware:
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.app(scope, receive, send)
|
await self.app(scope, downstream_receive, send)
|
||||||
|
|
||||||
|
|
||||||
|
def _replay_receive(body: bytes) -> Receive:
|
||||||
|
"""Build a receive callable that replays an already-read request body.
|
||||||
|
|
||||||
|
The first call returns the buffered body; subsequent calls report a
|
||||||
|
disconnect so a downstream reader cannot block waiting for more data.
|
||||||
|
"""
|
||||||
|
sent = False
|
||||||
|
|
||||||
|
async def receive() -> Message:
|
||||||
|
nonlocal sent
|
||||||
|
if not sent:
|
||||||
|
sent = True
|
||||||
|
return {"type": "http.request", "body": body, "more_body": False}
|
||||||
|
return {"type": "http.disconnect"}
|
||||||
|
|
||||||
|
return receive
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,12 @@ def _make_app(check_origin: str | None = None) -> Starlette:
|
||||||
async def action_post(request: Request) -> PlainTextResponse:
|
async def action_post(request: Request) -> PlainTextResponse:
|
||||||
return PlainTextResponse("ok")
|
return PlainTextResponse("ok")
|
||||||
|
|
||||||
|
async def echo_post(request: Request) -> PlainTextResponse:
|
||||||
|
# Reads the body itself — this only works if the CSRF middleware
|
||||||
|
# replayed the body it consumed to find the form token.
|
||||||
|
form = await request.form()
|
||||||
|
return PlainTextResponse(str(form.get("payload")))
|
||||||
|
|
||||||
async def exempt_post(request: Request) -> PlainTextResponse:
|
async def exempt_post(request: Request) -> PlainTextResponse:
|
||||||
return PlainTextResponse("ok")
|
return PlainTextResponse("ok")
|
||||||
|
|
||||||
|
|
@ -32,6 +38,7 @@ def _make_app(check_origin: str | None = None) -> Starlette:
|
||||||
Route("/get-token", get_token, methods=["GET"]),
|
Route("/get-token", get_token, methods=["GET"]),
|
||||||
Route("/action", action_get, methods=["GET"]),
|
Route("/action", action_get, methods=["GET"]),
|
||||||
Route("/action", action_post, methods=["POST"]),
|
Route("/action", action_post, methods=["POST"]),
|
||||||
|
Route("/echo", echo_post, methods=["POST"]),
|
||||||
Route("/exempt", exempt_post, methods=["POST"]),
|
Route("/exempt", exempt_post, methods=["POST"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -77,6 +84,19 @@ class TestCSRFValidation:
|
||||||
response = await client.post("/action", data={"csrf_token": token})
|
response = await client.post("/action", data={"csrf_token": token})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
async def test_form_token_path_replays_body_to_endpoint(self) -> None:
|
||||||
|
"""Regression: when the token comes from the form body (no header), the
|
||||||
|
middleware must replay the consumed body so the endpoint can still read
|
||||||
|
it. Previously this hung the request until client disconnect."""
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
token, _ = await _get_token_and_cookies(client)
|
||||||
|
response = await client.post(
|
||||||
|
"/echo", data={"csrf_token": token, "payload": "hello"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.text == "hello"
|
||||||
|
|
||||||
async def test_post_with_valid_header_token(self) -> None:
|
async def test_post_with_valid_header_token(self) -> None:
|
||||||
app = _make_app()
|
app = _make_app()
|
||||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue