porchlight/tests/test_csrf.py
Johan Lundberg 3c5451b9c2
fix(csrf): replay request body consumed during token validation
The CSRF middleware read the request body via request.form() to extract
the csrf_token form field, then handed the already-consumed ASGI receive
to the downstream app. The endpoint's own request.form()/body() then
blocked indefinitely waiting for body bytes that would never arrive,
hanging the request until the client disconnected (ClientDisconnect).

Only native form POSTs hit this path: htmx requests carry the token in
the X-CSRF-Token header and skip the body read. The OIDC consent form is
a plain form with the token in the body, so authorization consent hung.

Buffer the body when falling back to the form token and replay it to the
downstream app via a wrapped receive. Header-token requests are
unchanged. Adds a regression test where the endpoint reads the body after
body-token CSRF validation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-10 13:07:02 +02:00

182 lines
8.2 KiB
Python

"""Tests for CSRF middleware using a minimal Starlette app."""
import asyncio

from httpx import ASGITransport, AsyncClient
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route

from porchlight.csrf import CSRFMiddleware, generate_csrf_token
def _make_app(check_origin: str | None = None) -> Starlette:
    """Return a throwaway Starlette app wrapped in session + CSRF middleware.

    Routes exposed:
      * ``GET /get-token`` and ``GET /action`` return
        ``{"token": generate_csrf_token(request)}``.
      * ``POST /action`` and ``POST /exempt`` answer ``"ok"`` without touching
        the request body; ``/exempt`` is listed in ``exempt_paths``.
      * ``POST /echo`` parses the form body itself and echoes ``payload``.

    Args:
        check_origin: Passed through unchanged as ``CSRFMiddleware``'s
            ``check_origin`` option (``None`` by default).
    """

    async def get_token(request: Request) -> JSONResponse:
        return JSONResponse({"token": generate_csrf_token(request)})

    async def action_get(request: Request) -> JSONResponse:
        return JSONResponse({"token": generate_csrf_token(request)})

    async def action_post(request: Request) -> PlainTextResponse:
        return PlainTextResponse("ok")

    async def echo_post(request: Request) -> PlainTextResponse:
        # The endpoint consumes the body on its own. That only succeeds when
        # CSRFMiddleware replays whatever body it read while looking for the
        # form token.
        submitted = await request.form()
        return PlainTextResponse(str(submitted.get("payload")))

    async def exempt_post(request: Request) -> PlainTextResponse:
        return PlainTextResponse("ok")

    # SessionMiddleware comes first, which makes it the outermost layer in
    # Starlette's middleware ordering, so the session exists by the time
    # CSRFMiddleware runs.
    middleware_stack = [
        Middleware(SessionMiddleware, secret_key="test-secret"),
        Middleware(CSRFMiddleware, exempt_paths={"/exempt"}, check_origin=check_origin),
    ]
    return Starlette(
        routes=[
            Route("/get-token", get_token, methods=["GET"]),
            Route("/action", action_get, methods=["GET"]),
            Route("/action", action_post, methods=["POST"]),
            Route("/echo", echo_post, methods=["POST"]),
            Route("/exempt", exempt_post, methods=["POST"]),
        ],
        middleware=middleware_stack,
    )
async def _get_token_and_cookies(client: AsyncClient) -> tuple[str, dict[str, str]]:
    """Fetch a CSRF token via ``GET /get-token``.

    Asserts the request succeeded, then returns the token together with a
    ``name -> value`` snapshot of every cookie the client holds afterwards.
    """
    token_response = await client.get("/get-token")
    assert token_response.status_code == 200
    csrf_token = token_response.json()["token"]
    jar_snapshot = {}
    for jar_cookie in client.cookies.jar:
        jar_snapshot[jar_cookie.name] = jar_cookie.value
    return csrf_token, jar_snapshot
class TestCSRFValidation:
    """Token validation performed by CSRFMiddleware on state-changing requests."""

    async def test_safe_methods_pass_through(self) -> None:
        """A GET needs no CSRF token."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            response = await client.get("/action")
            assert response.status_code == 200

    async def test_post_without_token_returns_403(self) -> None:
        """A POST with neither header nor form token is rejected."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            # First GET to establish a session
            await client.get("/get-token")
            response = await client.post("/action")
            assert response.status_code == 403
            assert "CSRF" in response.text

    async def test_post_with_valid_form_token(self) -> None:
        """A token supplied as the ``csrf_token`` form field is accepted."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            token, _ = await _get_token_and_cookies(client)
            response = await client.post("/action", data={"csrf_token": token})
            assert response.status_code == 200

    async def test_form_token_path_replays_body_to_endpoint(self) -> None:
        """Regression: when the token comes from the form body (no header), the
        middleware must replay the consumed body so the endpoint can still read
        it. Previously this hung the request until client disconnect.

        The in-process transport never disconnects on its own. The request is
        therefore bounded by a timeout, so a regression fails this test instead
        of hanging the entire suite.
        """
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            token, _ = await _get_token_and_cookies(client)
            response = await asyncio.wait_for(
                client.post("/echo", data={"csrf_token": token, "payload": "hello"}),
                timeout=5,
            )
            assert response.status_code == 200
            assert response.text == "hello"

    async def test_header_token_path_leaves_body_for_endpoint(self) -> None:
        """With the token in the X-CSRF-Token header, the endpoint must still
        be able to read the form body itself.

        The request is time-bounded for the same reason as the form-token
        regression test.
        """
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            token, _ = await _get_token_and_cookies(client)
            response = await asyncio.wait_for(
                client.post("/echo", headers={"X-CSRF-Token": token}, data={"payload": "hello"}),
                timeout=5,
            )
            assert response.status_code == 200
            assert response.text == "hello"

    async def test_post_with_valid_header_token(self) -> None:
        """A token supplied in the X-CSRF-Token header is accepted."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            token, _ = await _get_token_and_cookies(client)
            response = await client.post("/action", headers={"X-CSRF-Token": token})
            assert response.status_code == 200

    async def test_post_with_wrong_token_returns_403(self) -> None:
        """A header token that does not match the session's token is rejected."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            await _get_token_and_cookies(client)
            response = await client.post("/action", headers={"X-CSRF-Token": "wrong-token"})
            assert response.status_code == 403

    async def test_exempt_path_skips_validation(self) -> None:
        """Paths in ``exempt_paths`` accept a POST without any session or token."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            response = await client.post("/exempt")
            assert response.status_code == 200
class TestOriginCheck:
    """Origin-header enforcement when the app is built with ``check_origin``."""

    async def test_matching_origin_passes(self) -> None:
        """A same-origin POST that also carries a valid header token succeeds."""
        transport = ASGITransport(app=_make_app(check_origin="http://testserver"))
        async with AsyncClient(transport=transport, base_url="http://testserver") as client:
            csrf_token, _cookies = await _get_token_and_cookies(client)
            request_headers = {"Origin": "http://testserver", "X-CSRF-Token": csrf_token}
            resp = await client.post("/action", headers=request_headers)
            assert resp.status_code == 200

    async def test_mismatched_origin_returns_403(self) -> None:
        """A foreign Origin is rejected even when the CSRF token is valid."""
        transport = ASGITransport(app=_make_app(check_origin="http://testserver"))
        async with AsyncClient(transport=transport, base_url="http://testserver") as client:
            csrf_token, _cookies = await _get_token_and_cookies(client)
            request_headers = {"Origin": "http://evil.example.com", "X-CSRF-Token": csrf_token}
            resp = await client.post("/action", headers=request_headers)
            assert resp.status_code == 403
            assert "Origin" in resp.text

    async def test_no_origin_falls_back_to_token_check(self) -> None:
        """Without an Origin header, a valid header token alone is sufficient."""
        transport = ASGITransport(app=_make_app(check_origin="http://testserver"))
        async with AsyncClient(transport=transport, base_url="http://testserver") as client:
            csrf_token, _cookies = await _get_token_and_cookies(client)
            resp = await client.post("/action", headers={"X-CSRF-Token": csrf_token})
            assert resp.status_code == 200
class TestGenerateCSRFToken:
    """Behaviour of ``generate_csrf_token`` as observed through ``/get-token``."""

    async def test_generates_token_and_stores_in_session(self) -> None:
        """A non-trivial string token is returned and the session cookie is set.

        The cookie check verifies the "stores in session" part of this test's
        name. Starlette's SessionMiddleware only emits its cookie (named
        ``session`` by default) when the session is non-empty.
        """
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            response = await client.get("/get-token")
            assert response.status_code == 200
            token = response.json()["token"]
            assert isinstance(token, str)
            assert len(token) > 20
            assert "session" in client.cookies

    async def test_returns_same_token_within_session(self) -> None:
        """Repeated calls within one session return the same token."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            response1 = await client.get("/get-token")
            token1 = response1.json()["token"]
            response2 = await client.get("/get-token")
            token2 = response2.json()["token"]
            assert token1 == token2
class TestAppIntegration:
    """Test CSRF middleware is wired into the real app.

    Both tests take a ``client`` fixture that is not defined in this file.
    Presumably it is an AsyncClient bound to the real application, provided by
    a conftest — confirm.
    """

    async def test_post_without_csrf_token_returns_403(self, client: AsyncClient) -> None:
        """A session-protected form POST that carries no CSRF token is refused."""
        credentials = {"username": "x", "password": "y"}
        response = await client.post("/login/password", data=credentials)
        assert response.status_code == 403

    async def test_exempt_token_endpoint(self, client: AsyncClient) -> None:
        """``/token`` is CSRF-exempt because it relies on client authentication."""
        grant = {"grant_type": "authorization_code", "code": "fake"}
        response = await client.post("/token", data=grant)
        # The bogus code should be rejected by the endpoint's own checks rather
        # than by CSRF, so any status other than 403 is acceptable here.
        assert response.status_code != 403