diff --git a/src/porchlight/app.py b/src/porchlight/app.py index dd0f95f..bfd350f 100644 --- a/src/porchlight/app.py +++ b/src/porchlight/app.py @@ -8,11 +8,13 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from starlette.middleware.sessions import SessionMiddleware +from starlette.requests import Request from porchlight.authn.password import PasswordService from porchlight.authn.routes import router as authn_router from porchlight.authn.webauthn import WebAuthnService from porchlight.config import Settings, StorageBackend +from porchlight.csrf import CSRFMiddleware, generate_csrf_token from porchlight.invite.service import MagicLinkService from porchlight.manage.routes import router as manage_router from porchlight.oidc.endpoints import router as oidc_router @@ -107,10 +109,26 @@ def create_app(settings: Settings | None = None) -> FastAPI: # Session middleware session_secret = settings.session_secret or secrets.token_hex(32) - app.add_middleware(SessionMiddleware, secret_key=session_secret) # type: ignore[arg-type] + app.add_middleware( + CSRFMiddleware, + exempt_paths={"/token", "/userinfo"}, + check_origin=settings.issuer, + ) + app.add_middleware( + SessionMiddleware, + secret_key=session_secret, + same_site="lax", + https_only=settings.session_https_only, + ) # type: ignore[arg-type] # Templates - app.state.templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates")) + templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates")) + + def csrf_token_processor(request: Request) -> str: + return generate_csrf_token(request) + + templates.env.globals["csrf_token_processor"] = csrf_token_processor + app.state.templates = templates # Static files app.mount("/static", StaticFiles(directory=str(PACKAGE_DIR / "static")), name="static") diff --git a/src/porchlight/config.py b/src/porchlight/config.py index 3317e67..90e036e 100644 --- a/src/porchlight/config.py +++ b/src/porchlight/config.py @@ -47,6 +47,7 @@ class Settings(BaseSettings): # Session session_secret: str | None = None # If None, a random secret is generated per process + session_https_only: bool = True # Magic links invite_ttl: int = 86400 # seconds diff --git a/tests/conftest.py b/tests/conftest.py index 2595cbf..63f47aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from porchlight.config import Settings @pytest.fixture def settings() -> Settings: - return Settings(issuer="http://localhost:8000", sqlite_path=":memory:") + return Settings(issuer="http://localhost:8000", sqlite_path=":memory:", session_https_only=False) @pytest.fixture diff --git a/tests/test_csrf.py b/tests/test_csrf.py index c4f560d..506b8c3 100644 --- a/tests/test_csrf.py +++ b/tests/test_csrf.py @@ -145,3 +145,18 @@ class TestGenerateCSRFToken: response2 = await client.get("/get-token") token2 = response2.json()["token"] assert token1 == token2 + + +class TestAppIntegration: + """Test CSRF middleware is wired into the real app.""" + + async def test_post_without_csrf_token_returns_403(self, client: AsyncClient) -> None: + """Any POST to a session-protected endpoint without CSRF token gets 403.""" + resp = await client.post("/login/password", data={"username": "x", "password": "y"}) + assert resp.status_code == 403 + + async def test_exempt_token_endpoint(self, client: AsyncClient) -> None: + """The /token endpoint is exempt from CSRF (uses client auth).""" + resp = await client.post("/token", data={"grant_type": "authorization_code", "code": "fake"}) + # Should NOT be 403 — it should fail for auth reasons, not CSRF + assert resp.status_code != 403