"""Tests for CSRF middleware using a minimal Starlette app."""

from httpx import ASGITransport, AsyncClient
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route

from porchlight.csrf import CSRFMiddleware, generate_csrf_token

# Origin the in-process test client talks to; also used as the trusted Origin.
_BASE_URL = "http://testserver"


def _make_app(check_origin: str | None = None) -> Starlette:
    """Build a minimal Starlette app with session + CSRF middleware for testing.

    Routes:
        GET  /get-token -- return the session's CSRF token as ``{"token": ...}``.
        GET  /action    -- same as /get-token, on a path whose POST is protected.
        POST /action    -- protected endpoint; responds "ok" when CSRF checks pass.
        POST /exempt    -- listed in ``exempt_paths`` of the CSRF middleware.

    Args:
        check_origin: Passed straight through to ``CSRFMiddleware``.

    Returns:
        The configured Starlette application.
    """

    async def token_endpoint(request: Request) -> JSONResponse:
        return JSONResponse({"token": generate_csrf_token(request)})

    async def ok_endpoint(request: Request) -> PlainTextResponse:
        return PlainTextResponse("ok")

    routes = [
        Route("/get-token", token_endpoint, methods=["GET"]),
        Route("/action", token_endpoint, methods=["GET"]),
        Route("/action", ok_endpoint, methods=["POST"]),
        Route("/exempt", ok_endpoint, methods=["POST"]),
    ]
    return Starlette(
        routes=routes,
        # The first middleware listed is the outermost, so the session is
        # already loaded when CSRFMiddleware runs.
        middleware=[
            Middleware(SessionMiddleware, secret_key="test-secret"),
            Middleware(CSRFMiddleware, exempt_paths={"/exempt"}, check_origin=check_origin),
        ],
    )


def _client(app: Starlette) -> AsyncClient:
    """Return an AsyncClient that dispatches requests to *app* in-process."""
    return AsyncClient(transport=ASGITransport(app=app), base_url=_BASE_URL)


async def _get_token_and_cookies(client: AsyncClient) -> tuple[str, dict[str, str]]:
    """GET /get-token and return (token, cookies_dict).

    The request also establishes the session, so later requests made with the
    same *client* carry the session cookie.
    """
    response = await client.get("/get-token")
    assert response.status_code == 200
    token = response.json()["token"]
    cookies = {cookie.name: cookie.value for cookie in client.cookies.jar}
    return token, cookies


class TestCSRFValidation:
    async def test_safe_methods_pass_through(self) -> None:
        async with _client(_make_app()) as client:
            response = await client.get("/action")
            assert response.status_code == 200

    async def test_post_without_token_returns_403(self) -> None:
        async with _client(_make_app()) as client:
            # Establish a session first so the rejection is due to the missing
            # token rather than a missing session.
            await client.get("/get-token")
            response = await client.post("/action")
            assert response.status_code == 403
            assert "CSRF" in response.text

    async def test_post_with_valid_form_token(self) -> None:
        async with _client(_make_app()) as client:
            token, _ = await _get_token_and_cookies(client)
            response = await client.post("/action", data={"csrf_token": token})
            assert response.status_code == 200

    async def test_post_with_valid_header_token(self) -> None:
        async with _client(_make_app()) as client:
            token, _ = await _get_token_and_cookies(client)
            response = await client.post("/action", headers={"X-CSRF-Token": token})
            assert response.status_code == 200

    async def test_post_with_wrong_token_returns_403(self) -> None:
        async with _client(_make_app()) as client:
            # The session holds a real token; the request sends a different one.
            await _get_token_and_cookies(client)
            response = await client.post("/action", headers={"X-CSRF-Token": "wrong-token"})
            assert response.status_code == 403

    async def test_exempt_path_skips_validation(self) -> None:
        async with _client(_make_app()) as client:
            # No session and no token at all: only the exemption lets this through.
            response = await client.post("/exempt")
            assert response.status_code == 200


class TestOriginCheck:
    async def test_matching_origin_passes(self) -> None:
        async with _client(_make_app(check_origin=_BASE_URL)) as client:
            token, _ = await _get_token_and_cookies(client)
            response = await client.post(
                "/action",
                headers={"Origin": _BASE_URL, "X-CSRF-Token": token},
            )
            assert response.status_code == 200

    async def test_mismatched_origin_returns_403(self) -> None:
        async with _client(_make_app(check_origin=_BASE_URL)) as client:
            # A valid token is sent, so the 403 can only come from the Origin check.
            token, _ = await _get_token_and_cookies(client)
            response = await client.post(
                "/action",
                headers={"Origin": "http://evil.example.com", "X-CSRF-Token": token},
            )
            assert response.status_code == 403
            assert "Origin" in response.text

    async def test_no_origin_falls_back_to_token_check(self) -> None:
        async with _client(_make_app(check_origin=_BASE_URL)) as client:
            token, _ = await _get_token_and_cookies(client)
            # No Origin header is sent; the valid token alone must suffice.
            response = await client.post("/action", headers={"X-CSRF-Token": token})
            assert response.status_code == 200


class TestGenerateCSRFToken:
    async def test_generates_token_and_stores_in_session(self) -> None:
        async with _client(_make_app()) as client:
            response = await client.get("/get-token")
            assert response.status_code == 200
            token = response.json()["token"]
            assert len(token) > 20
            # SessionMiddleware only emits its cookie ("session" by default)
            # when the session is non-empty, i.e. the token was stored in it.
            assert "session" in client.cookies

    async def test_returns_same_token_within_session(self) -> None:
        async with _client(_make_app()) as client:
            response1 = await client.get("/get-token")
            token1 = response1.json()["token"]
            response2 = await client.get("/get-token")
            token2 = response2.json()["token"]
            assert token1 == token2