feat: wire CSRF middleware and harden session cookie

This commit is contained in:
Johan Lundberg 2026-02-19 13:45:58 +01:00
parent b5ea9950a2
commit d1f2b39cb6
No known key found for this signature in database
GPG key ID: A6C152738D03C7D1
4 changed files with 37 additions and 3 deletions

View file

@ -8,11 +8,13 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from porchlight.authn.password import PasswordService from porchlight.authn.password import PasswordService
from porchlight.authn.routes import router as authn_router from porchlight.authn.routes import router as authn_router
from porchlight.authn.webauthn import WebAuthnService from porchlight.authn.webauthn import WebAuthnService
from porchlight.config import Settings, StorageBackend from porchlight.config import Settings, StorageBackend
from porchlight.csrf import CSRFMiddleware, generate_csrf_token
from porchlight.invite.service import MagicLinkService from porchlight.invite.service import MagicLinkService
from porchlight.manage.routes import router as manage_router from porchlight.manage.routes import router as manage_router
from porchlight.oidc.endpoints import router as oidc_router from porchlight.oidc.endpoints import router as oidc_router
@ -107,10 +109,26 @@ def create_app(settings: Settings | None = None) -> FastAPI:
# Session middleware # Session middleware
session_secret = settings.session_secret or secrets.token_hex(32) session_secret = settings.session_secret or secrets.token_hex(32)
app.add_middleware(SessionMiddleware, secret_key=session_secret) # type: ignore[arg-type] app.add_middleware(
CSRFMiddleware,
exempt_paths={"/token", "/userinfo"},
check_origin=settings.issuer,
)
app.add_middleware(
SessionMiddleware,
secret_key=session_secret,
same_site="lax",
https_only=settings.session_https_only,
) # type: ignore[arg-type]
# Templates # Templates
app.state.templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates")) templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates"))
def csrf_token_processor(request: Request) -> str:
return generate_csrf_token(request)
templates.env.globals["csrf_token_processor"] = csrf_token_processor
app.state.templates = templates
# Static files # Static files
app.mount("/static", StaticFiles(directory=str(PACKAGE_DIR / "static")), name="static") app.mount("/static", StaticFiles(directory=str(PACKAGE_DIR / "static")), name="static")

View file

@ -47,6 +47,7 @@ class Settings(BaseSettings):
# Session # Session
session_secret: str | None = None # If None, a random secret is generated per process session_secret: str | None = None # If None, a random secret is generated per process
session_https_only: bool = True
# Magic links # Magic links
invite_ttl: int = 86400 # seconds invite_ttl: int = 86400 # seconds

View file

@ -9,7 +9,7 @@ from porchlight.config import Settings
@pytest.fixture @pytest.fixture
def settings() -> Settings: def settings() -> Settings:
return Settings(issuer="http://localhost:8000", sqlite_path=":memory:") return Settings(issuer="http://localhost:8000", sqlite_path=":memory:", session_https_only=False)
@pytest.fixture @pytest.fixture

View file

@ -145,3 +145,18 @@ class TestGenerateCSRFToken:
response2 = await client.get("/get-token") response2 = await client.get("/get-token")
token2 = response2.json()["token"] token2 = response2.json()["token"]
assert token1 == token2 assert token1 == token2
class TestAppIntegration:
"""Test CSRF middleware is wired into the real app."""
async def test_post_without_csrf_token_returns_403(self, client: AsyncClient) -> None:
"""Any POST to a session-protected endpoint without CSRF token gets 403."""
resp = await client.post("/login/password", data={"username": "x", "password": "y"})
assert resp.status_code == 403
async def test_exempt_token_endpoint(self, client: AsyncClient) -> None:
"""The /token endpoint is exempt from CSRF (uses client auth)."""
resp = await client.post("/token", data={"grant_type": "authorization_code", "code": "fake"})
# Should NOT be 403 — it should fail for auth reasons, not CSRF
assert resp.status_code != 403