feat: add CSRF middleware with synchronizer token pattern

This commit is contained in:
Johan Lundberg 2026-02-19 13:26:33 +01:00
parent b1291c801e
commit f93290d43e
No known key found for this signature in database
GPG key ID: A6C152738D03C7D1
2 changed files with 241 additions and 0 deletions

94
src/porchlight/csrf.py Normal file
View file

@ -0,0 +1,94 @@
"""CSRF middleware implementing the OWASP Synchronizer Token Pattern."""
import hmac
import logging
import secrets
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.types import ASGIApp, Receive, Scope, Send
logger = logging.getLogger(__name__)
SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
def generate_csrf_token(request: Request) -> str:
"""Get or create a CSRF token for the current session.
Stores the token at ``request.session["csrf_token"]``. Returns the
existing token when one is already present (idempotent per session).
"""
token: str | None = request.session.get("csrf_token")
if token is None:
token = secrets.token_urlsafe(32)
request.session["csrf_token"] = token
return token
class CSRFMiddleware:
"""ASGI middleware that enforces CSRF token validation on non-safe requests."""
def __init__(
self,
app: ASGIApp,
exempt_paths: set[str] | None = None,
check_origin: str | None = None,
) -> None:
self.app = app
self.exempt_paths: set[str] = exempt_paths or set()
self.check_origin = check_origin
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope, receive)
method = request.method
# Safe methods pass through
if method in SAFE_METHODS:
await self.app(scope, receive, send)
return
# Exempt paths pass through
if request.url.path in self.exempt_paths:
await self.app(scope, receive, send)
return
# Origin check (defense-in-depth)
if self.check_origin is not None:
origin = request.headers.get("origin")
if origin is not None and origin != "null" and origin != self.check_origin:
logger.warning("CSRF origin mismatch: expected %s, got %s", self.check_origin, origin)
response = HTMLResponse(
"<h1>403 Forbidden</h1><p>Origin mismatch</p>",
status_code=403,
)
await response(scope, receive, send)
return
# Token validation
expected_token: str | None = request.session.get("csrf_token")
# Check header first, then fall back to form field
submitted_token: str | None = request.headers.get("x-csrf-token")
if submitted_token is None:
form = await request.form()
submitted_token = form.get("csrf_token") # type: ignore[assignment]
if (
expected_token is None
or submitted_token is None
or not hmac.compare_digest(expected_token, submitted_token)
):
logger.warning("CSRF validation failed for %s %s", method, request.url.path)
response = HTMLResponse(
"<h1>403 Forbidden</h1><p>CSRF validation failed</p>",
status_code=403,
)
await response(scope, receive, send)
return
await self.app(scope, receive, send)