The CSRF middleware read the request body via request.form() to extract the csrf_token form field, then handed the already-consumed ASGI receive to the downstream app. The endpoint's own request.form()/body() then blocked indefinitely waiting for body bytes that would never arrive, hanging the request until the client disconnected (ClientDisconnect). Only native form POSTs hit this path: htmx requests carry the token in the X-CSRF-Token header and skip the body read. The OIDC consent form is a plain form with the token in the body, so authorization consent hung. Buffer the body when falling back to the form token and replay it to the downstream app via a wrapped receive. Header-token requests are unchanged. Adds a regression test where the endpoint reads the body after body-token CSRF validation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
182 lines · 8.2 KiB · Python
"""Tests for CSRF middleware using a minimal Starlette app."""
|
|
|
|
import asyncio

from httpx import ASGITransport, AsyncClient
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route

from porchlight.csrf import CSRFMiddleware, generate_csrf_token
|
|
|
|
|
|
def _make_app(check_origin: str | None = None) -> Starlette:
    """Assemble a tiny Starlette app guarded by session + CSRF middleware.

    Args:
        check_origin: Forwarded unchanged to ``CSRFMiddleware``'s
            ``check_origin`` option.

    Returns:
        A ``Starlette`` app exposing:

        * ``GET /get-token`` and ``GET /action`` -- return ``{"token": ...}``
          produced by ``generate_csrf_token``.
        * ``POST /action`` -- replies ``"ok"``.
        * ``POST /echo`` -- parses the form body itself and echoes ``payload``.
        * ``POST /exempt`` -- replies ``"ok"``; listed in ``exempt_paths``.
    """

    # Handler names are kept stable: Starlette derives each Route's name from
    # the endpoint function's __name__.
    async def get_token(request: Request) -> JSONResponse:
        return JSONResponse({"token": generate_csrf_token(request)})

    async def action_get(request: Request) -> JSONResponse:
        return JSONResponse({"token": generate_csrf_token(request)})

    async def action_post(request: Request) -> PlainTextResponse:
        return PlainTextResponse("ok")

    async def echo_post(request: Request) -> PlainTextResponse:
        # This handler consumes the request body on its own, so it can only
        # succeed if CSRFMiddleware replayed whatever body it read while
        # searching for the form token.
        submitted = await request.form()
        payload = submitted.get("payload")
        return PlainTextResponse(str(payload))

    async def exempt_post(request: Request) -> PlainTextResponse:
        return PlainTextResponse("ok")

    session_layer = Middleware(SessionMiddleware, secret_key="test-secret")
    csrf_layer = Middleware(CSRFMiddleware, exempt_paths={"/exempt"}, check_origin=check_origin)

    return Starlette(
        routes=[
            Route("/get-token", get_token, methods=["GET"]),
            Route("/action", action_get, methods=["GET"]),
            Route("/action", action_post, methods=["POST"]),
            Route("/echo", echo_post, methods=["POST"]),
            Route("/exempt", exempt_post, methods=["POST"]),
        ],
        middleware=[session_layer, csrf_layer],
    )
|
|
|
|
|
|
async def _get_token_and_cookies(client: AsyncClient) -> tuple[str, dict[str, str]]:
    """Fetch a CSRF token from ``/get-token``.

    Returns:
        The token, paired with a ``{name: value}`` snapshot of every cookie
        currently held in ``client``'s cookie jar.
    """
    resp = await client.get("/get-token")
    assert resp.status_code == 200
    csrf = resp.json()["token"]
    jar_snapshot: dict[str, str] = {}
    for c in client.cookies.jar:
        jar_snapshot[c.name] = c.value
    return csrf, jar_snapshot
|
|
|
|
|
|
class TestCSRFValidation:
    """CSRF token enforcement on unsafe methods (header token or form field)."""

    # Upper bound, in seconds, for requests whose regression mode is a hang
    # rather than an error response. httpx's ASGITransport does not apply the
    # client's timeouts, so without this a regression would block the suite.
    _HANG_TIMEOUT = 5.0

    async def test_safe_methods_pass_through(self) -> None:
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            response = await client.get("/action")
            assert response.status_code == 200

    async def test_post_without_token_returns_403(self) -> None:
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            # First GET to establish a session
            await client.get("/get-token")
            response = await client.post("/action")
            assert response.status_code == 403
            assert "CSRF" in response.text

    async def test_post_with_valid_form_token(self) -> None:
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            token, _ = await _get_token_and_cookies(client)
            response = await client.post("/action", data={"csrf_token": token})
            assert response.status_code == 200

    async def test_post_with_wrong_form_token_returns_403(self) -> None:
        """A mismatching token in the form body is rejected just like a bad header."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            await _get_token_and_cookies(client)
            response = await client.post("/action", data={"csrf_token": "wrong-token"})
            assert response.status_code == 403

    async def test_form_token_path_replays_body_to_endpoint(self) -> None:
        """Regression: when the token comes from the form body (no header), the
        middleware must replay the consumed body so the endpoint can still read
        it. Previously this hung the request until client disconnect."""
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            token, _ = await _get_token_and_cookies(client)
            # Bound the request so a regression (endpoint waiting forever on a
            # body that was already consumed) fails the test instead of hanging.
            response = await asyncio.wait_for(
                client.post("/echo", data={"csrf_token": token, "payload": "hello"}),
                timeout=self._HANG_TIMEOUT,
            )
            assert response.status_code == 200
            assert response.text == "hello"

    async def test_post_with_valid_header_token(self) -> None:
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            token, _ = await _get_token_and_cookies(client)
            response = await client.post("/action", headers={"X-CSRF-Token": token})
            assert response.status_code == 200

    async def test_post_with_wrong_token_returns_403(self) -> None:
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            await _get_token_and_cookies(client)
            response = await client.post("/action", headers={"X-CSRF-Token": "wrong-token"})
            assert response.status_code == 403

    async def test_exempt_path_skips_validation(self) -> None:
        app = _make_app()
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
            response = await client.post("/exempt")
            assert response.status_code == 200
|
|
|
|
|
|
class TestOriginCheck:
    """Origin-header enforcement when CSRFMiddleware is built with check_origin."""

    async def test_matching_origin_passes(self) -> None:
        transport = ASGITransport(app=_make_app(check_origin="http://testserver"))
        async with AsyncClient(transport=transport, base_url="http://testserver") as client:
            csrf, _ = await _get_token_and_cookies(client)
            request_headers = {"Origin": "http://testserver", "X-CSRF-Token": csrf}
            resp = await client.post("/action", headers=request_headers)
            assert resp.status_code == 200

    async def test_mismatched_origin_returns_403(self) -> None:
        transport = ASGITransport(app=_make_app(check_origin="http://testserver"))
        async with AsyncClient(transport=transport, base_url="http://testserver") as client:
            csrf, _ = await _get_token_and_cookies(client)
            # A valid token must not rescue a request from a foreign origin.
            request_headers = {"Origin": "http://evil.example.com", "X-CSRF-Token": csrf}
            resp = await client.post("/action", headers=request_headers)
            assert resp.status_code == 403
            assert "Origin" in resp.text

    async def test_no_origin_falls_back_to_token_check(self) -> None:
        transport = ASGITransport(app=_make_app(check_origin="http://testserver"))
        async with AsyncClient(transport=transport, base_url="http://testserver") as client:
            csrf, _ = await _get_token_and_cookies(client)
            # No Origin header at all: only the token is checked.
            resp = await client.post("/action", headers={"X-CSRF-Token": csrf})
            assert resp.status_code == 200
|
|
|
|
|
|
class TestGenerateCSRFToken:
    """Behaviour of generate_csrf_token as observed through GET /get-token."""

    async def test_generates_token_and_stores_in_session(self) -> None:
        transport = ASGITransport(app=_make_app())
        async with AsyncClient(transport=transport, base_url="http://testserver") as client:
            resp = await client.get("/get-token")
            assert resp.status_code == 200
            issued = resp.json()["token"]
            assert len(issued) > 20

    async def test_returns_same_token_within_session(self) -> None:
        transport = ASGITransport(app=_make_app())
        async with AsyncClient(transport=transport, base_url="http://testserver") as client:
            # Two sequential requests share the client's session cookie.
            issued: list[str] = []
            for _ in range(2):
                resp = await client.get("/get-token")
                issued.append(resp.json()["token"])
            assert issued[0] == issued[1]
|
|
|
|
|
|
class TestAppIntegration:
    """Checks that CSRFMiddleware is active in the real application.

    Uses a ``client`` fixture defined outside this module (presumably in
    conftest.py -- confirm).
    """

    async def test_post_without_csrf_token_returns_403(self, client: AsyncClient) -> None:
        """A form POST to /login/password that carries no CSRF token is rejected with 403."""
        form = {"username": "x", "password": "y"}
        response = await client.post("/login/password", data=form)
        assert response.status_code == 403

    async def test_exempt_token_endpoint(self, client: AsyncClient) -> None:
        """/token is exempt from CSRF (it uses client auth), so it must not answer 403."""
        form = {"grant_type": "authorization_code", "code": "fake"}
        response = await client.post("/token", data=form)
        # The fake code should fail for auth reasons, never with a CSRF 403.
        assert response.status_code != 403
|