feat: wire CSRF middleware and harden session cookie
This commit is contained in:
parent
b5ea9950a2
commit
d1f2b39cb6
4 changed files with 37 additions and 3 deletions
|
|
@ -8,11 +8,13 @@ from fastapi import FastAPI
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.requests import Request
|
||||
|
||||
from porchlight.authn.password import PasswordService
|
||||
from porchlight.authn.routes import router as authn_router
|
||||
from porchlight.authn.webauthn import WebAuthnService
|
||||
from porchlight.config import Settings, StorageBackend
|
||||
from porchlight.csrf import CSRFMiddleware, generate_csrf_token
|
||||
from porchlight.invite.service import MagicLinkService
|
||||
from porchlight.manage.routes import router as manage_router
|
||||
from porchlight.oidc.endpoints import router as oidc_router
|
||||
|
|
@ -107,10 +109,26 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||
|
||||
# Session middleware
|
||||
session_secret = settings.session_secret or secrets.token_hex(32)
|
||||
app.add_middleware(SessionMiddleware, secret_key=session_secret) # type: ignore[arg-type]
|
||||
app.add_middleware(
|
||||
CSRFMiddleware,
|
||||
exempt_paths={"/token", "/userinfo"},
|
||||
check_origin=settings.issuer,
|
||||
)
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=session_secret,
|
||||
same_site="lax",
|
||||
https_only=settings.session_https_only,
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
# Templates
|
||||
app.state.templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates"))
|
||||
templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates"))
|
||||
|
||||
def csrf_token_processor(request: Request) -> str:
|
||||
return generate_csrf_token(request)
|
||||
|
||||
templates.env.globals["csrf_token_processor"] = csrf_token_processor
|
||||
app.state.templates = templates
|
||||
|
||||
# Static files
|
||||
app.mount("/static", StaticFiles(directory=str(PACKAGE_DIR / "static")), name="static")
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class Settings(BaseSettings):
|
|||
|
||||
# Session
|
||||
session_secret: str | None = None # If None, a random secret is generated per process
|
||||
session_https_only: bool = True
|
||||
|
||||
# Magic links
|
||||
invite_ttl: int = 86400 # seconds
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from porchlight.config import Settings
|
|||
|
||||
@pytest.fixture
|
||||
def settings() -> Settings:
|
||||
return Settings(issuer="http://localhost:8000", sqlite_path=":memory:")
|
||||
return Settings(issuer="http://localhost:8000", sqlite_path=":memory:", session_https_only=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -145,3 +145,18 @@ class TestGenerateCSRFToken:
|
|||
response2 = await client.get("/get-token")
|
||||
token2 = response2.json()["token"]
|
||||
assert token1 == token2
|
||||
|
||||
|
||||
class TestAppIntegration:
|
||||
"""Test CSRF middleware is wired into the real app."""
|
||||
|
||||
async def test_post_without_csrf_token_returns_403(self, client: AsyncClient) -> None:
|
||||
"""Any POST to a session-protected endpoint without CSRF token gets 403."""
|
||||
resp = await client.post("/login/password", data={"username": "x", "password": "y"})
|
||||
assert resp.status_code == 403
|
||||
|
||||
async def test_exempt_token_endpoint(self, client: AsyncClient) -> None:
|
||||
"""The /token endpoint is exempt from CSRF (uses client auth)."""
|
||||
resp = await client.post("/token", data={"grant_type": "authorization_code", "code": "fake"})
|
||||
# Should NOT be 403 — it should fail for auth reasons, not CSRF
|
||||
assert resp.status_code != 403
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue