feat: add CSRF middleware with synchronizer token pattern
This commit is contained in:
parent
b1291c801e
commit
f93290d43e
2 changed files with 241 additions and 0 deletions
94
src/porchlight/csrf.py
Normal file
94
src/porchlight/csrf.py
Normal file
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""CSRF middleware implementing the OWASP Synchronizer Token Pattern."""
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import HTMLResponse
|
||||||
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_csrf_token(request: Request) -> str:
|
||||||
|
"""Get or create a CSRF token for the current session.
|
||||||
|
|
||||||
|
Stores the token at ``request.session["csrf_token"]``. Returns the
|
||||||
|
existing token when one is already present (idempotent per session).
|
||||||
|
"""
|
||||||
|
token: str | None = request.session.get("csrf_token")
|
||||||
|
if token is None:
|
||||||
|
token = secrets.token_urlsafe(32)
|
||||||
|
request.session["csrf_token"] = token
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
class CSRFMiddleware:
|
||||||
|
"""ASGI middleware that enforces CSRF token validation on non-safe requests."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
exempt_paths: set[str] | None = None,
|
||||||
|
check_origin: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.app = app
|
||||||
|
self.exempt_paths: set[str] = exempt_paths or set()
|
||||||
|
self.check_origin = check_origin
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
request = Request(scope, receive)
|
||||||
|
method = request.method
|
||||||
|
|
||||||
|
# Safe methods pass through
|
||||||
|
if method in SAFE_METHODS:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Exempt paths pass through
|
||||||
|
if request.url.path in self.exempt_paths:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Origin check (defense-in-depth)
|
||||||
|
if self.check_origin is not None:
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
if origin is not None and origin != "null" and origin != self.check_origin:
|
||||||
|
logger.warning("CSRF origin mismatch: expected %s, got %s", self.check_origin, origin)
|
||||||
|
response = HTMLResponse(
|
||||||
|
"<h1>403 Forbidden</h1><p>Origin mismatch</p>",
|
||||||
|
status_code=403,
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Token validation
|
||||||
|
expected_token: str | None = request.session.get("csrf_token")
|
||||||
|
|
||||||
|
# Check header first, then fall back to form field
|
||||||
|
submitted_token: str | None = request.headers.get("x-csrf-token")
|
||||||
|
if submitted_token is None:
|
||||||
|
form = await request.form()
|
||||||
|
submitted_token = form.get("csrf_token") # type: ignore[assignment]
|
||||||
|
|
||||||
|
if (
|
||||||
|
expected_token is None
|
||||||
|
or submitted_token is None
|
||||||
|
or not hmac.compare_digest(expected_token, submitted_token)
|
||||||
|
):
|
||||||
|
logger.warning("CSRF validation failed for %s %s", method, request.url.path)
|
||||||
|
response = HTMLResponse(
|
||||||
|
"<h1>403 Forbidden</h1><p>CSRF validation failed</p>",
|
||||||
|
status_code=403,
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.app(scope, receive, send)
|
||||||
147
tests/test_csrf.py
Normal file
147
tests/test_csrf.py
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
"""Tests for CSRF middleware using a minimal Starlette app."""
|
||||||
|
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.middleware import Middleware
|
||||||
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, PlainTextResponse
|
||||||
|
from starlette.routing import Route
|
||||||
|
|
||||||
|
from porchlight.csrf import CSRFMiddleware, generate_csrf_token
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app(check_origin: str | None = None) -> Starlette:
|
||||||
|
"""Build a minimal Starlette app with session + CSRF middleware for testing."""
|
||||||
|
|
||||||
|
async def get_token(request: Request) -> JSONResponse:
|
||||||
|
token = generate_csrf_token(request)
|
||||||
|
return JSONResponse({"token": token})
|
||||||
|
|
||||||
|
async def action_get(request: Request) -> JSONResponse:
|
||||||
|
token = generate_csrf_token(request)
|
||||||
|
return JSONResponse({"token": token})
|
||||||
|
|
||||||
|
async def action_post(request: Request) -> PlainTextResponse:
|
||||||
|
return PlainTextResponse("ok")
|
||||||
|
|
||||||
|
async def exempt_post(request: Request) -> PlainTextResponse:
|
||||||
|
return PlainTextResponse("ok")
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
Route("/get-token", get_token, methods=["GET"]),
|
||||||
|
Route("/action", action_get, methods=["GET"]),
|
||||||
|
Route("/action", action_post, methods=["POST"]),
|
||||||
|
Route("/exempt", exempt_post, methods=["POST"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
app = Starlette(
|
||||||
|
routes=routes,
|
||||||
|
middleware=[
|
||||||
|
Middleware(SessionMiddleware, secret_key="test-secret"),
|
||||||
|
Middleware(CSRFMiddleware, exempt_paths={"/exempt"}, check_origin=check_origin),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_token_and_cookies(client: AsyncClient) -> tuple[str, dict[str, str]]:
|
||||||
|
"""GET /get-token and return (token, cookies_dict)."""
|
||||||
|
response = await client.get("/get-token")
|
||||||
|
assert response.status_code == 200
|
||||||
|
token = response.json()["token"]
|
||||||
|
cookies = {cookie.name: cookie.value for cookie in client.cookies.jar}
|
||||||
|
return token, cookies
|
||||||
|
|
||||||
|
|
||||||
|
class TestCSRFValidation:
|
||||||
|
async def test_safe_methods_pass_through(self) -> None:
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.get("/action")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
async def test_post_without_token_returns_403(self) -> None:
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
# First GET to establish a session
|
||||||
|
await client.get("/get-token")
|
||||||
|
response = await client.post("/action")
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert "CSRF" in response.text
|
||||||
|
|
||||||
|
async def test_post_with_valid_form_token(self) -> None:
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
token, _ = await _get_token_and_cookies(client)
|
||||||
|
response = await client.post("/action", data={"csrf_token": token})
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
async def test_post_with_valid_header_token(self) -> None:
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
token, _ = await _get_token_and_cookies(client)
|
||||||
|
response = await client.post("/action", headers={"X-CSRF-Token": token})
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
async def test_post_with_wrong_token_returns_403(self) -> None:
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
await _get_token_and_cookies(client)
|
||||||
|
response = await client.post("/action", headers={"X-CSRF-Token": "wrong-token"})
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
async def test_exempt_path_skips_validation(self) -> None:
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/exempt")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestOriginCheck:
|
||||||
|
async def test_matching_origin_passes(self) -> None:
|
||||||
|
app = _make_app(check_origin="http://testserver")
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
token, _ = await _get_token_and_cookies(client)
|
||||||
|
response = await client.post(
|
||||||
|
"/action",
|
||||||
|
headers={"Origin": "http://testserver", "X-CSRF-Token": token},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
async def test_mismatched_origin_returns_403(self) -> None:
|
||||||
|
app = _make_app(check_origin="http://testserver")
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
token, _ = await _get_token_and_cookies(client)
|
||||||
|
response = await client.post(
|
||||||
|
"/action",
|
||||||
|
headers={"Origin": "http://evil.example.com", "X-CSRF-Token": token},
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert "Origin" in response.text
|
||||||
|
|
||||||
|
async def test_no_origin_falls_back_to_token_check(self) -> None:
|
||||||
|
app = _make_app(check_origin="http://testserver")
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
token, _ = await _get_token_and_cookies(client)
|
||||||
|
response = await client.post("/action", headers={"X-CSRF-Token": token})
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateCSRFToken:
|
||||||
|
async def test_generates_token_and_stores_in_session(self) -> None:
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.get("/get-token")
|
||||||
|
assert response.status_code == 200
|
||||||
|
token = response.json()["token"]
|
||||||
|
assert len(token) > 20
|
||||||
|
|
||||||
|
async def test_returns_same_token_within_session(self) -> None:
|
||||||
|
app = _make_app()
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response1 = await client.get("/get-token")
|
||||||
|
token1 = response1.json()["token"]
|
||||||
|
response2 = await client.get("/get-token")
|
||||||
|
token2 = response2.json()["token"]
|
||||||
|
assert token1 == token2
|
||||||
Loading…
Add table
Add a link
Reference in a new issue