feat: wire CSRF middleware and harden session cookie
This commit is contained in:
parent
b5ea9950a2
commit
d1f2b39cb6
4 changed files with 37 additions and 3 deletions
|
|
@ -8,11 +8,13 @@ from fastapi import FastAPI
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
from porchlight.authn.password import PasswordService
|
from porchlight.authn.password import PasswordService
|
||||||
from porchlight.authn.routes import router as authn_router
|
from porchlight.authn.routes import router as authn_router
|
||||||
from porchlight.authn.webauthn import WebAuthnService
|
from porchlight.authn.webauthn import WebAuthnService
|
||||||
from porchlight.config import Settings, StorageBackend
|
from porchlight.config import Settings, StorageBackend
|
||||||
|
from porchlight.csrf import CSRFMiddleware, generate_csrf_token
|
||||||
from porchlight.invite.service import MagicLinkService
|
from porchlight.invite.service import MagicLinkService
|
||||||
from porchlight.manage.routes import router as manage_router
|
from porchlight.manage.routes import router as manage_router
|
||||||
from porchlight.oidc.endpoints import router as oidc_router
|
from porchlight.oidc.endpoints import router as oidc_router
|
||||||
|
|
@ -107,10 +109,26 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||||
|
|
||||||
# Session middleware
|
# Session middleware
|
||||||
session_secret = settings.session_secret or secrets.token_hex(32)
|
session_secret = settings.session_secret or secrets.token_hex(32)
|
||||||
app.add_middleware(SessionMiddleware, secret_key=session_secret) # type: ignore[arg-type]
|
app.add_middleware(
|
||||||
|
CSRFMiddleware,
|
||||||
|
exempt_paths={"/token", "/userinfo"},
|
||||||
|
check_origin=settings.issuer,
|
||||||
|
)
|
||||||
|
app.add_middleware(
|
||||||
|
SessionMiddleware,
|
||||||
|
secret_key=session_secret,
|
||||||
|
same_site="lax",
|
||||||
|
https_only=settings.session_https_only,
|
||||||
|
) # type: ignore[arg-type]
|
||||||
|
|
||||||
# Templates
|
# Templates
|
||||||
app.state.templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates"))
|
templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates"))
|
||||||
|
|
||||||
|
def csrf_token_processor(request: Request) -> str:
|
||||||
|
return generate_csrf_token(request)
|
||||||
|
|
||||||
|
templates.env.globals["csrf_token_processor"] = csrf_token_processor
|
||||||
|
app.state.templates = templates
|
||||||
|
|
||||||
# Static files
|
# Static files
|
||||||
app.mount("/static", StaticFiles(directory=str(PACKAGE_DIR / "static")), name="static")
|
app.mount("/static", StaticFiles(directory=str(PACKAGE_DIR / "static")), name="static")
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ class Settings(BaseSettings):
|
||||||
|
|
||||||
# Session
|
# Session
|
||||||
session_secret: str | None = None # If None, a random secret is generated per process
|
session_secret: str | None = None # If None, a random secret is generated per process
|
||||||
|
session_https_only: bool = True
|
||||||
|
|
||||||
# Magic links
|
# Magic links
|
||||||
invite_ttl: int = 86400 # seconds
|
invite_ttl: int = 86400 # seconds
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from porchlight.config import Settings
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def settings() -> Settings:
|
def settings() -> Settings:
|
||||||
return Settings(issuer="http://localhost:8000", sqlite_path=":memory:")
|
return Settings(issuer="http://localhost:8000", sqlite_path=":memory:", session_https_only=False)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -145,3 +145,18 @@ class TestGenerateCSRFToken:
|
||||||
response2 = await client.get("/get-token")
|
response2 = await client.get("/get-token")
|
||||||
token2 = response2.json()["token"]
|
token2 = response2.json()["token"]
|
||||||
assert token1 == token2
|
assert token1 == token2
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppIntegration:
|
||||||
|
"""Test CSRF middleware is wired into the real app."""
|
||||||
|
|
||||||
|
async def test_post_without_csrf_token_returns_403(self, client: AsyncClient) -> None:
|
||||||
|
"""Any POST to a session-protected endpoint without CSRF token gets 403."""
|
||||||
|
resp = await client.post("/login/password", data={"username": "x", "password": "y"})
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
async def test_exempt_token_endpoint(self, client: AsyncClient) -> None:
|
||||||
|
"""The /token endpoint is exempt from CSRF (uses client auth)."""
|
||||||
|
resp = await client.post("/token", data={"grant_type": "authorization_code", "code": "fake"})
|
||||||
|
# Should NOT be 403 — it should fail for auth reasons, not CSRF
|
||||||
|
assert resp.status_code != 403
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue