diff --git a/src/porchlight/csrf.py b/src/porchlight/csrf.py new file mode 100644 index 0000000..aa82096 --- /dev/null +++ b/src/porchlight/csrf.py @@ -0,0 +1,94 @@ +"""CSRF middleware implementing the OWASP Synchronizer Token Pattern.""" + +import hmac +import logging +import secrets + +from starlette.requests import Request +from starlette.responses import HTMLResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +logger = logging.getLogger(__name__) + +SAFE_METHODS = {"GET", "HEAD", "OPTIONS"} + + +def generate_csrf_token(request: Request) -> str: + """Get or create a CSRF token for the current session. + + Stores the token at ``request.session["csrf_token"]``. Returns the + existing token when one is already present (idempotent per session). + """ + token: str | None = request.session.get("csrf_token") + if token is None: + token = secrets.token_urlsafe(32) + request.session["csrf_token"] = token + return token + + +class CSRFMiddleware: + """ASGI middleware that enforces CSRF token validation on non-safe requests.""" + + def __init__( + self, + app: ASGIApp, + exempt_paths: set[str] | None = None, + check_origin: str | None = None, + ) -> None: + self.app = app + self.exempt_paths: set[str] = exempt_paths or set() + self.check_origin = check_origin + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive) + method = request.method + + # Safe methods pass through + if method in SAFE_METHODS: + await self.app(scope, receive, send) + return + + # Exempt paths pass through + if request.url.path in self.exempt_paths: + await self.app(scope, receive, send) + return + + # Origin check (defense-in-depth) + if self.check_origin is not None: + origin = request.headers.get("origin") + if origin is not None and origin != "null" and origin != self.check_origin: + logger.warning("CSRF origin mismatch: expected %s, got %s", self.check_origin, origin) + response = HTMLResponse( + "

403 Forbidden

Origin mismatch

", + status_code=403, + ) + await response(scope, receive, send) + return + + # Token validation + expected_token: str | None = request.session.get("csrf_token") + + # Check header first, then fall back to form field + submitted_token: str | None = request.headers.get("x-csrf-token") + if submitted_token is None: + form = await request.form() + submitted_token = form.get("csrf_token") # type: ignore[assignment] + + if ( + expected_token is None + or submitted_token is None + or not hmac.compare_digest(expected_token, submitted_token) + ): + logger.warning("CSRF validation failed for %s %s", method, request.url.path) + response = HTMLResponse( + "

403 Forbidden

CSRF validation failed

", + status_code=403, + ) + await response(scope, receive, send) + return + + await self.app(scope, receive, send) diff --git a/tests/test_csrf.py b/tests/test_csrf.py new file mode 100644 index 0000000..c4f560d --- /dev/null +++ b/tests/test_csrf.py @@ -0,0 +1,147 @@ +"""Tests for CSRF middleware using a minimal Starlette app.""" + +from httpx import ASGITransport, AsyncClient +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.sessions import SessionMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse +from starlette.routing import Route + +from porchlight.csrf import CSRFMiddleware, generate_csrf_token + + +def _make_app(check_origin: str | None = None) -> Starlette: + """Build a minimal Starlette app with session + CSRF middleware for testing.""" + + async def get_token(request: Request) -> JSONResponse: + token = generate_csrf_token(request) + return JSONResponse({"token": token}) + + async def action_get(request: Request) -> JSONResponse: + token = generate_csrf_token(request) + return JSONResponse({"token": token}) + + async def action_post(request: Request) -> PlainTextResponse: + return PlainTextResponse("ok") + + async def exempt_post(request: Request) -> PlainTextResponse: + return PlainTextResponse("ok") + + routes = [ + Route("/get-token", get_token, methods=["GET"]), + Route("/action", action_get, methods=["GET"]), + Route("/action", action_post, methods=["POST"]), + Route("/exempt", exempt_post, methods=["POST"]), + ] + + app = Starlette( + routes=routes, + middleware=[ + Middleware(SessionMiddleware, secret_key="test-secret"), + Middleware(CSRFMiddleware, exempt_paths={"/exempt"}, check_origin=check_origin), + ], + ) + return app + + +async def _get_token_and_cookies(client: AsyncClient) -> tuple[str, dict[str, str]]: + """GET /get-token and return (token, cookies_dict).""" + response = await client.get("/get-token") + assert response.status_code == 200 + token = response.json()["token"] + cookies = {cookie.name: cookie.value for cookie in client.cookies.jar} + return token, cookies + + +class TestCSRFValidation: + async def test_safe_methods_pass_through(self) -> None: + app = _make_app() + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.get("/action") + assert response.status_code == 200 + + async def test_post_without_token_returns_403(self) -> None: + app = _make_app() + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + # First GET to establish a session + await client.get("/get-token") + response = await client.post("/action") + assert response.status_code == 403 + assert "CSRF" in response.text + + async def test_post_with_valid_form_token(self) -> None: + app = _make_app() + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + token, _ = await _get_token_and_cookies(client) + response = await client.post("/action", data={"csrf_token": token}) + assert response.status_code == 200 + + async def test_post_with_valid_header_token(self) -> None: + app = _make_app() + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + token, _ = await _get_token_and_cookies(client) + response = await client.post("/action", headers={"X-CSRF-Token": token}) + assert response.status_code == 200 + + async def test_post_with_wrong_token_returns_403(self) -> None: + app = _make_app() + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + await _get_token_and_cookies(client) + response = await client.post("/action", headers={"X-CSRF-Token": "wrong-token"}) + assert response.status_code == 403 + + async def test_exempt_path_skips_validation(self) -> None: + app = _make_app() + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/exempt") + assert response.status_code == 200 + + +class TestOriginCheck: + async def test_matching_origin_passes(self) -> None: + app = _make_app(check_origin="http://testserver") + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + token, _ = await _get_token_and_cookies(client) + response = await client.post( + "/action", + headers={"Origin": "http://testserver", "X-CSRF-Token": token}, + ) + assert response.status_code == 200 + + async def test_mismatched_origin_returns_403(self) -> None: + app = _make_app(check_origin="http://testserver") + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + token, _ = await _get_token_and_cookies(client) + response = await client.post( + "/action", + headers={"Origin": "http://evil.example.com", "X-CSRF-Token": token}, + ) + assert response.status_code == 403 + assert "Origin" in response.text + + async def test_no_origin_falls_back_to_token_check(self) -> None: + app = _make_app(check_origin="http://testserver") + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + token, _ = await _get_token_and_cookies(client) + response = await client.post("/action", headers={"X-CSRF-Token": token}) + assert response.status_code == 200 + + +class TestGenerateCSRFToken: + async def test_generates_token_and_stores_in_session(self) -> None: + app = _make_app() + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.get("/get-token") + assert response.status_code == 200 + token = response.json()["token"] + assert len(token) > 20 + + async def test_returns_same_token_within_session(self) -> None: + app = _make_app() + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + response1 = await client.get("/get-token") + token1 = response1.json()["token"] + response2 = await client.get("/get-token") + token2 = response2.json()["token"] + assert token1 == token2