feat: add CSRF middleware with synchronizer token pattern

This commit is contained in:
Johan Lundberg 2026-02-19 13:26:33 +01:00
parent b1291c801e
commit f93290d43e
No known key found for this signature in database
GPG key ID: A6C152738D03C7D1
2 changed files with 241 additions and 0 deletions

94
src/porchlight/csrf.py Normal file
View file

@ -0,0 +1,94 @@
"""CSRF middleware implementing the OWASP Synchronizer Token Pattern."""
import hmac
import logging
import secrets
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.types import ASGIApp, Receive, Scope, Send
logger = logging.getLogger(__name__)
SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
def generate_csrf_token(request: Request) -> str:
"""Get or create a CSRF token for the current session.
Stores the token at ``request.session["csrf_token"]``. Returns the
existing token when one is already present (idempotent per session).
"""
token: str | None = request.session.get("csrf_token")
if token is None:
token = secrets.token_urlsafe(32)
request.session["csrf_token"] = token
return token
class CSRFMiddleware:
"""ASGI middleware that enforces CSRF token validation on non-safe requests."""
def __init__(
self,
app: ASGIApp,
exempt_paths: set[str] | None = None,
check_origin: str | None = None,
) -> None:
self.app = app
self.exempt_paths: set[str] = exempt_paths or set()
self.check_origin = check_origin
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope, receive)
method = request.method
# Safe methods pass through
if method in SAFE_METHODS:
await self.app(scope, receive, send)
return
# Exempt paths pass through
if request.url.path in self.exempt_paths:
await self.app(scope, receive, send)
return
# Origin check (defense-in-depth)
if self.check_origin is not None:
origin = request.headers.get("origin")
if origin is not None and origin != "null" and origin != self.check_origin:
logger.warning("CSRF origin mismatch: expected %s, got %s", self.check_origin, origin)
response = HTMLResponse(
"<h1>403 Forbidden</h1><p>Origin mismatch</p>",
status_code=403,
)
await response(scope, receive, send)
return
# Token validation
expected_token: str | None = request.session.get("csrf_token")
# Check header first, then fall back to form field
submitted_token: str | None = request.headers.get("x-csrf-token")
if submitted_token is None:
form = await request.form()
submitted_token = form.get("csrf_token") # type: ignore[assignment]
if (
expected_token is None
or submitted_token is None
or not hmac.compare_digest(expected_token, submitted_token)
):
logger.warning("CSRF validation failed for %s %s", method, request.url.path)
response = HTMLResponse(
"<h1>403 Forbidden</h1><p>CSRF validation failed</p>",
status_code=403,
)
await response(scope, receive, send)
return
await self.app(scope, receive, send)

147
tests/test_csrf.py Normal file
View file

@ -0,0 +1,147 @@
"""Tests for CSRF middleware using a minimal Starlette app."""
from httpx import ASGITransport, AsyncClient
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route
from porchlight.csrf import CSRFMiddleware, generate_csrf_token
def _make_app(check_origin: str | None = None) -> Starlette:
"""Build a minimal Starlette app with session + CSRF middleware for testing."""
async def get_token(request: Request) -> JSONResponse:
token = generate_csrf_token(request)
return JSONResponse({"token": token})
async def action_get(request: Request) -> JSONResponse:
token = generate_csrf_token(request)
return JSONResponse({"token": token})
async def action_post(request: Request) -> PlainTextResponse:
return PlainTextResponse("ok")
async def exempt_post(request: Request) -> PlainTextResponse:
return PlainTextResponse("ok")
routes = [
Route("/get-token", get_token, methods=["GET"]),
Route("/action", action_get, methods=["GET"]),
Route("/action", action_post, methods=["POST"]),
Route("/exempt", exempt_post, methods=["POST"]),
]
app = Starlette(
routes=routes,
middleware=[
Middleware(SessionMiddleware, secret_key="test-secret"),
Middleware(CSRFMiddleware, exempt_paths={"/exempt"}, check_origin=check_origin),
],
)
return app
async def _get_token_and_cookies(client: AsyncClient) -> tuple[str, dict[str, str]]:
"""GET /get-token and return (token, cookies_dict)."""
response = await client.get("/get-token")
assert response.status_code == 200
token = response.json()["token"]
cookies = {cookie.name: cookie.value for cookie in client.cookies.jar}
return token, cookies
class TestCSRFValidation:
async def test_safe_methods_pass_through(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.get("/action")
assert response.status_code == 200
async def test_post_without_token_returns_403(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
# First GET to establish a session
await client.get("/get-token")
response = await client.post("/action")
assert response.status_code == 403
assert "CSRF" in response.text
async def test_post_with_valid_form_token(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
token, _ = await _get_token_and_cookies(client)
response = await client.post("/action", data={"csrf_token": token})
assert response.status_code == 200
async def test_post_with_valid_header_token(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
token, _ = await _get_token_and_cookies(client)
response = await client.post("/action", headers={"X-CSRF-Token": token})
assert response.status_code == 200
async def test_post_with_wrong_token_returns_403(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
await _get_token_and_cookies(client)
response = await client.post("/action", headers={"X-CSRF-Token": "wrong-token"})
assert response.status_code == 403
async def test_exempt_path_skips_validation(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/exempt")
assert response.status_code == 200
class TestOriginCheck:
async def test_matching_origin_passes(self) -> None:
app = _make_app(check_origin="http://testserver")
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
token, _ = await _get_token_and_cookies(client)
response = await client.post(
"/action",
headers={"Origin": "http://testserver", "X-CSRF-Token": token},
)
assert response.status_code == 200
async def test_mismatched_origin_returns_403(self) -> None:
app = _make_app(check_origin="http://testserver")
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
token, _ = await _get_token_and_cookies(client)
response = await client.post(
"/action",
headers={"Origin": "http://evil.example.com", "X-CSRF-Token": token},
)
assert response.status_code == 403
assert "Origin" in response.text
async def test_no_origin_falls_back_to_token_check(self) -> None:
app = _make_app(check_origin="http://testserver")
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
token, _ = await _get_token_and_cookies(client)
response = await client.post("/action", headers={"X-CSRF-Token": token})
assert response.status_code == 200
class TestGenerateCSRFToken:
async def test_generates_token_and_stores_in_session(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.get("/get-token")
assert response.status_code == 200
token = response.json()["token"]
assert len(token) > 20
async def test_returns_same_token_within_session(self) -> None:
app = _make_app()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
response1 = await client.get("/get-token")
token1 = response1.json()["token"]
response2 = await client.get("/get-token")
token2 = response2.json()["token"]
assert token1 == token2