diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..a0fdc61 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,61 @@ +# Set the maximum line length to 120. +line-length = 120 +target-version = "py313" + +[lint] +select = [ + "A", + "ANN", + "ASYNC", + "B", + "C4", + "DTZ", + "E", + "ERA", + "F", + "FAST", + "FLY", + "FURB", + "I", + "ISC", + "PERF", + "PGH", + "PIE", + "PL", + "PT", + "RUF", + "SIM", + "UP", + "W", +] + +ignore = [ + "SIM102", # collapsible-if + "SIM103", # return-bool-condition-directly (keeping explicit if/else for clarity) + "SIM108", # if-else-block-instead-of-if-exp + + # Since we use ruff as a formatter, the following rules should be ignored + # See: https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules + "W191", + "E111", + "E114", + "E117", + "E501", + "D206", + "D300", + "Q000", + "Q001", + "Q002", + "Q003", + "COM812", + "COM819", + "ISC002", +] + +allowed-confusables = ["’", "х"] + +[lint.flake8-annotations] +allow-star-arg-any = true + +[lint.per-file-ignores] +"**/test_*.py" = ["PLR2004"] # magic-value-comparison - allow magic numbers in tests diff --git a/src/porchlight/admin/routes.py b/src/porchlight/admin/routes.py index b6692cd..4bcc0a8 100644 --- a/src/porchlight/admin/routes.py +++ b/src/porchlight/admin/routes.py @@ -1,4 +1,5 @@ from base64 import urlsafe_b64decode +from typing import Annotated from fastapi import APIRouter, Form, Request, Response from fastapi.responses import HTMLResponse, RedirectResponse @@ -6,7 +7,7 @@ from pydantic import ValidationError from porchlight.dependencies import get_session_user from porchlight.models import User -from porchlight.validation import ProfileUpdate +from porchlight.validation import GroupListInput, ProfileUpdate, UsernameInput, format_validation_errors router = APIRouter(prefix="/admin", tags=["admin"]) @@ -99,7 +100,7 @@ async def user_detail(request: Request, userid: str) -> Response: @router.post("/invite", response_class=HTMLResponse) async def create_invite( request: Request, - username: str = Form(), + username: Annotated[str, Form()], ) -> Response: session_user = get_session_user(request) if session_user is None: @@ -109,17 +110,19 @@ async def create_invite( if admin is None: return HTMLResponse("Forbidden", status_code=403) - username = username.strip() - if not username: - return HTMLResponse('
Username is required
') + try: + validated = UsernameInput(username=username) + except ValidationError as exc: + return HTMLResponse(format_validation_errors(exc)) magic_link_service = request.app.state.magic_link_service settings = request.app.state.settings - link = await magic_link_service.create(username=username, created_by=admin.username, note="admin invite") + link = await magic_link_service.create(username=validated.username, created_by=admin.username, note="admin invite") url = f"{settings.issuer}/register/{link.token}" return HTMLResponse( - f'
Invite created for {username}:
{url}
' + f'
Invite created for {validated.username}:
' + f'
{url}
' ) @@ -128,13 +131,6 @@ async def create_invite( async def update_user_profile( request: Request, userid: str, - given_name: str = Form(""), - family_name: str = Form(""), - preferred_username: str = Form(""), - email: str = Form(""), - phone_number: str = Form(""), - picture: str = Form(""), - locale: str = Form(""), ) -> Response: session_user = get_session_user(request) if session_user is None: @@ -144,32 +140,19 @@ async def update_user_profile( return HTMLResponse("Forbidden", status_code=403) # Validate + form = await request.form() try: profile = ProfileUpdate( - given_name=given_name, - family_name=family_name, - preferred_username=preferred_username, - email=email, - phone_number=phone_number, - picture=picture, - locale=locale, + given_name=str(form.get("given_name", "")), + family_name=str(form.get("family_name", "")), + preferred_username=str(form.get("preferred_username", "")), + email=str(form.get("email", "")), + phone_number=str(form.get("phone_number", "")), + picture=str(form.get("picture", "")), + locale=str(form.get("locale", "")), ) except ValidationError as exc: - error = exc.errors()[0] - field = error["loc"][-1] if error["loc"] else "input" - msg = error["msg"] - labels = { - "given_name": "Given name", - "family_name": "Family name", - "preferred_username": "Display name", - "email": "Email", - "phone_number": "Phone number", - "picture": "Picture URL", - "locale": "Locale", - } - label = labels.get(str(field), str(field)) - display_msg = msg.removeprefix("Value error, ") if error["type"] == "value_error" else f"{label}: {msg}" - return HTMLResponse(f'
{display_msg}
') + return HTMLResponse(format_validation_errors(exc)) user_repo = request.app.state.user_repo user = await user_repo.get_by_userid(userid) @@ -196,7 +179,7 @@ async def update_user_profile( async def update_user_groups( request: Request, userid: str, - groups: str = Form(""), + groups: Annotated[str, Form()] = "", ) -> Response: session_user = get_session_user(request) if session_user is None: @@ -205,13 +188,17 @@ async def update_user_groups( if admin is None: return HTMLResponse("Forbidden", status_code=403) + try: + validated = GroupListInput(groups=groups) + except ValidationError as exc: + return HTMLResponse(format_validation_errors(exc)) + user_repo = request.app.state.user_repo user = await user_repo.get_by_userid(userid) if user is None: return HTMLResponse("User not found", status_code=404) - group_list = [g.strip() for g in groups.split(",") if g.strip()] - updated = user.model_copy(update={"groups": group_list}) + updated = user.model_copy(update={"groups": validated.group_list}) await user_repo.update(updated) return HTMLResponse('
Groups updated
') @@ -353,7 +340,7 @@ async def delete_user(request: Request, userid: str) -> Response: return HTMLResponse("Forbidden", status_code=403) # Prevent self-deletion - admin_userid, _ = get_session_user(request) + admin_userid, _ = session_user if userid == admin_userid: return HTMLResponse('
Cannot delete your own account
') diff --git a/src/porchlight/app.py b/src/porchlight/app.py index acf3a70..91bb490 100644 --- a/src/porchlight/app.py +++ b/src/porchlight/app.py @@ -8,8 +8,11 @@ from urllib.parse import urlparse from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from slowapi.errors import RateLimitExceeded from starlette.middleware.sessions import SessionMiddleware from starlette.requests import Request +from starlette.responses import HTMLResponse as StarletteHTMLResponse +from starlette.responses import Response from porchlight.admin.routes import router as admin_router from porchlight.authn.password import PasswordService @@ -21,6 +24,7 @@ from porchlight.invite.service import MagicLinkService from porchlight.manage.routes import router as manage_router from porchlight.oidc.endpoints import router as oidc_router from porchlight.oidc.provider import create_oidc_server +from porchlight.rate_limit import limiter from porchlight.store.sqlite.db import open_db from porchlight.store.sqlite.repositories import ( SQLiteConsentRepository, @@ -123,6 +127,16 @@ def create_app(settings: Settings | None = None) -> FastAPI: https_only=settings.session_https_only, ) + # Rate limiting + app.state.limiter = limiter + + @app.exception_handler(RateLimitExceeded) + async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> StarletteHTMLResponse: + return StarletteHTMLResponse( + '
Too many attempts. Please try again later.
', + status_code=429, + ) + # Templates templates = Jinja2Templates(directory=str(PACKAGE_DIR / "templates")) @@ -147,7 +161,7 @@ def create_app(settings: Settings | None = None) -> FastAPI: return {"status": "ok"} @app.get("/") - async def landing(request: Request): # type: ignore[no-untyped-def] + async def landing(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") return app diff --git a/src/porchlight/authn/routes.py b/src/porchlight/authn/routes.py index 2ef2cea..c710f3f 100644 --- a/src/porchlight/authn/routes.py +++ b/src/porchlight/authn/routes.py @@ -1,10 +1,12 @@ from base64 import urlsafe_b64decode +from typing import Annotated from fastapi import APIRouter, Form, Request, Response from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fido2.webauthn import AttestedCredentialData, AuthenticationResponse from porchlight.models import User +from porchlight.rate_limit import limiter from porchlight.userid import generate_unique_userid router = APIRouter(tags=["authn"]) @@ -28,10 +30,11 @@ async def login_page(request: Request) -> HTMLResponse: @router.post("/login/password", response_class=HTMLResponse) +@limiter.limit("5/minute") async def login_password( request: Request, - username: str = Form(), - password: str = Form(), + username: Annotated[str, Form()], + password: Annotated[str, Form()], ) -> Response: user_repo = request.app.state.user_repo cred_repo = request.app.state.credential_repo @@ -50,6 +53,9 @@ async def login_password( if not password_service.verify(credential.password_hash, password): return HTMLResponse(error_html) + if not user.active: + return HTMLResponse(error_html) + request.session["userid"] = user.userid request.session["username"] = user.username @@ -77,6 +83,8 @@ async def register_magic_link(request: Request, token: str) -> Response: existing_user = await user_repo.get_by_username(link.username) if existing_user is not None: + if not existing_user.active: + return HTMLResponse("

This account has been deactivated.

", status_code=400) user = existing_user else: userid = await generate_unique_userid(user_repo) @@ -102,6 +110,7 @@ async def login_webauthn_begin(request: Request) -> Response: @router.post("/login/webauthn/complete") +@limiter.limit("10/minute") async def login_webauthn_complete(request: Request) -> Response: webauthn_service = request.app.state.webauthn_service user_repo = request.app.state.user_repo @@ -144,8 +153,8 @@ async def login_webauthn_complete(request: Request) -> Response: await cred_repo.update_webauthn(stored) user = await user_repo.get_by_userid(userid) - if user is None: - return JSONResponse({"error": "User not found"}, status_code=400) + if user is None or not user.active: + return JSONResponse({"error": "Authentication failed"}, status_code=400) request.session["userid"] = user.userid request.session["username"] = user.username diff --git a/src/porchlight/csrf.py b/src/porchlight/csrf.py index 630c723..f9b8a68 100644 --- a/src/porchlight/csrf.py +++ b/src/porchlight/csrf.py @@ -61,7 +61,7 @@ class CSRFMiddleware: # Origin check (defense-in-depth) if self.check_origin is not None: origin = request.headers.get("origin") - if origin is not None and origin != "null" and origin != self.check_origin: + if origin is not None and origin not in ("null", self.check_origin): logger.warning("CSRF origin mismatch: expected %s, got %s", self.check_origin, origin) response = HTMLResponse( "

403 Forbidden

Origin mismatch

", diff --git a/src/porchlight/manage/routes.py b/src/porchlight/manage/routes.py index a3c4887..69afdff 100644 --- a/src/porchlight/manage/routes.py +++ b/src/porchlight/manage/routes.py @@ -1,4 +1,5 @@ from base64 import urlsafe_b64decode +from typing import Annotated from fastapi import APIRouter, Form, Request, Response from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse @@ -7,7 +8,7 @@ from pydantic import ValidationError from porchlight.dependencies import get_session_user from porchlight.models import PasswordCredential, WebAuthnCredential -from porchlight.validation import ProfileUpdate +from porchlight.validation import PasswordChange, PasswordSet, ProfileUpdate, format_validation_errors router = APIRouter(prefix="/manage", tags=["manage"]) @@ -53,8 +54,9 @@ async def credentials_page(request: Request) -> Response: @router.post("/credentials/password", response_class=HTMLResponse) async def set_password( request: Request, - password: str = Form(), - confirm: str = Form(), + password: Annotated[str, Form()], + confirm: Annotated[str, Form()], + current_password: Annotated[str, Form()] = "", ) -> Response: session_user = get_session_user(request) if session_user is None: @@ -64,15 +66,30 @@ async def set_password( cred_repo = request.app.state.credential_repo password_service = request.app.state.password_service - if password != confirm: - return HTMLResponse('
Passwords do not match
') - - if len(password) < 8: - return HTMLResponse('
Password must be at least 8 characters
') - - password_hash = password_service.hash(password) - existing = await cred_repo.get_password_by_user(userid) + has_password = existing is not None + + # Validate input + try: + if has_password: + validated = PasswordChange( + current_password=current_password, + password=password, + confirm=confirm, + ) + else: + validated = PasswordSet(password=password, confirm=confirm) + except ValidationError as exc: + return HTMLResponse(format_validation_errors(exc)) + + # Verify current password if changing + if has_password and isinstance(validated, PasswordChange): + if not password_service.verify(existing.password_hash, validated.current_password): + return HTMLResponse('
Current password is incorrect
') + + # Store new password + password_hash = password_service.hash(validated.password) + if existing is not None: await cred_repo.delete_password(userid) @@ -200,13 +217,6 @@ async def profile_page(request: Request) -> Response: @router.post("/profile", response_class=HTMLResponse) async def update_profile( request: Request, - given_name: str = Form(""), - family_name: str = Form(""), - preferred_username: str = Form(""), - email: str = Form(""), - phone_number: str = Form(""), - picture: str = Form(""), - locale: str = Form(""), ) -> Response: session_user = get_session_user(request) if session_user is None: @@ -214,34 +224,19 @@ async def update_profile( userid, _username = session_user + form = await request.form() try: profile = ProfileUpdate( - given_name=given_name, - family_name=family_name, - preferred_username=preferred_username, - email=email, - phone_number=phone_number, - picture=picture, - locale=locale, + given_name=str(form.get("given_name", "")), + family_name=str(form.get("family_name", "")), + preferred_username=str(form.get("preferred_username", "")), + email=str(form.get("email", "")), + phone_number=str(form.get("phone_number", "")), + picture=str(form.get("picture", "")), + locale=str(form.get("locale", "")), ) except ValidationError as exc: - error = exc.errors()[0] - field = error["loc"][-1] if error["loc"] else "input" - msg = error["msg"] - # Produce user-friendly field labels - labels = { - "given_name": "Given name", - "family_name": "Family name", - "preferred_username": "Display name", - "email": "Email", - "phone_number": "Phone number", - "picture": "Picture URL", - "locale": "Locale", - } - label = labels.get(str(field), str(field)) - # Use custom message for value errors (e.g. picture URL), generic pydantic message otherwise - display_msg = msg.removeprefix("Value error, ") if error["type"] == "value_error" else f"{label}: {msg}" - return HTMLResponse(f'
{display_msg}
') + return HTMLResponse(format_validation_errors(exc)) user_repo = request.app.state.user_repo user = await user_repo.get_by_userid(userid) diff --git a/src/porchlight/oidc/claims.py b/src/porchlight/oidc/claims.py index 45987c0..edf2373 100644 --- a/src/porchlight/oidc/claims.py +++ b/src/porchlight/oidc/claims.py @@ -1,5 +1,7 @@ """OIDC claims mapping and UserInfo source.""" +from typing import Any + from idpyoidc.server.user_info import UserInfo from porchlight.models import User @@ -28,9 +30,7 @@ def user_to_claims(user: User) -> dict: "locale": user.locale, } - for claim_name, value in optional_fields.items(): - if value is not None: - claims[claim_name] = value + claims.update({claim_name: value for claim_name, value in optional_fields.items() if value is not None}) # updated_at as Unix timestamp (OIDC spec requires number) if user.updated_at: @@ -46,7 +46,7 @@ class PorchlightUserInfo(UserInfo): idpyoidc calls __call__() synchronously to look up claims. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(db={}, **kwargs) def set_user_claims(self, user_id: str, claims: dict) -> None: diff --git a/src/porchlight/oidc/endpoints.py b/src/porchlight/oidc/endpoints.py index ee221ee..67d4968 100644 --- a/src/porchlight/oidc/endpoints.py +++ b/src/porchlight/oidc/endpoints.py @@ -106,7 +106,7 @@ async def authorization_complete(request: Request) -> Response: ) -async def _check_consent_or_complete( +async def _check_consent_or_complete( # noqa: PLR0913 request: Request, oidc_server: object, endpoint: object, @@ -137,7 +137,7 @@ async def _check_consent_or_complete( return RedirectResponse("/consent", status_code=303) -async def _complete_authorization( +async def _complete_authorization( # noqa: PLR0913 request: Request, oidc_server: object, endpoint: object, @@ -332,11 +332,10 @@ async def consent_submit(request: Request) -> Response: redirect_uri = auth_params.get("redirect_uri", "") state = auth_params.get("state", "") - if action == "deny": - params = urlencode({"error": "access_denied", "state": state}) - return RedirectResponse(f"{redirect_uri}?{params}", status_code=303) - if action != "allow": + if action == "deny": + params = urlencode({"error": "access_denied", "state": state}) + return RedirectResponse(f"{redirect_uri}?{params}", status_code=303) return HTMLResponse("

Error

Invalid action

", status_code=400) # Allow — collect approved scopes @@ -357,11 +356,9 @@ async def consent_submit(request: Request) -> Response: try: parsed = endpoint.parse_request(auth_params) + if "error" in parsed: + raise ValueError(parsed.get("error_description", parsed["error"])) except Exception as exc: return HTMLResponse(f"

Error

{exc}

", status_code=400) - if "error" in parsed: - error_desc = parsed.get("error_description", parsed["error"]) - return HTMLResponse(f"

Error

{error_desc}

", status_code=400) - return await _complete_authorization(request, oidc_server, endpoint, parsed, userid, username) diff --git a/src/porchlight/rate_limit.py b/src/porchlight/rate_limit.py new file mode 100644 index 0000000..38404a8 --- /dev/null +++ b/src/porchlight/rate_limit.py @@ -0,0 +1,4 @@ +from slowapi import Limiter +from slowapi.util import get_remote_address + +limiter = Limiter(key_func=get_remote_address) diff --git a/src/porchlight/store/sqlite/migrations.py b/src/porchlight/store/sqlite/migrations.py index afc5628..2bfbb32 100644 --- a/src/porchlight/store/sqlite/migrations.py +++ b/src/porchlight/store/sqlite/migrations.py @@ -1,11 +1,13 @@ from pathlib import Path import aiosqlite +import anyio async def run_migrations(db: aiosqlite.Connection, migrations_dir: Path) -> int: """Apply unapplied SQL migration files in order. Returns count of newly applied migrations.""" - if not migrations_dir.is_dir(): + async_dir = anyio.Path(migrations_dir) + if not await async_dir.is_dir(): raise FileNotFoundError(f"Migrations directory not found: {migrations_dir}") await db.execute( @@ -22,19 +24,22 @@ async def run_migrations(db: aiosqlite.Connection, migrations_dir: Path) -> int: async with db.execute("SELECT filename FROM _migrations") as cursor: applied = {row[0] async for row in cursor} - migration_files = sorted(migrations_dir.glob("*.sql")) + migration_files = sorted( + [f async for f in async_dir.iterdir() if f.suffix == ".sql"], + key=lambda f: f.name, + ) count = 0 for migration_file in migration_files: if migration_file.name in applied: continue - sql = migration_file.read_text(encoding="utf-8") + sql = await migration_file.read_text(encoding="utf-8") await db.execute("BEGIN") try: for statement in sql.split(";"): - statement = statement.strip() - if statement: - await db.execute(statement) + cleaned = statement.strip() + if cleaned: + await db.execute(cleaned) await db.execute( "INSERT INTO _migrations (filename) VALUES (?)", (migration_file.name,), diff --git a/src/porchlight/templates/admin/user_detail.html b/src/porchlight/templates/admin/user_detail.html index f65717e..99138db 100644 --- a/src/porchlight/templates/admin/user_detail.html +++ b/src/porchlight/templates/admin/user_detail.html @@ -11,6 +11,7 @@

Profile

+
@@ -48,6 +49,7 @@

Groups

+
{% for group in target_user.groups %} {{ group }} diff --git a/src/porchlight/templates/admin/users.html b/src/porchlight/templates/admin/users.html index 26d660b..675dd11 100644 --- a/src/porchlight/templates/admin/users.html +++ b/src/porchlight/templates/admin/users.html @@ -8,8 +8,11 @@

Create invite

+ diff --git a/src/porchlight/templates/manage/credentials.html b/src/porchlight/templates/manage/credentials.html index 05071c7..6d4c269 100644 --- a/src/porchlight/templates/manage/credentials.html +++ b/src/porchlight/templates/manage/credentials.html @@ -40,13 +40,19 @@ {% endif %}
+ {% if has_password %} +
+ + +
+ {% endif %}
- +
- +
diff --git a/src/porchlight/validation.py b/src/porchlight/validation.py index 512ad02..5d3b48c 100644 --- a/src/porchlight/validation.py +++ b/src/porchlight/validation.py @@ -1,8 +1,10 @@ +import re from typing import Annotated from urllib.parse import urlparse -from pydantic import BaseModel, EmailStr, Field, field_validator +from pydantic import BaseModel, EmailStr, Field, ValidationError, field_validator, model_validator from pydantic_extra_types.phone_numbers import PhoneNumberValidator +from zxcvbn import zxcvbn E164Phone = Annotated[str, PhoneNumberValidator(number_format="E164")] @@ -40,3 +42,134 @@ class ProfileUpdate(BaseModel): if parsed.scheme not in ("http", "https") or not parsed.netloc: raise ValueError("Picture URL must be a valid HTTP or HTTPS URL") return v + + @field_validator("locale", mode="before") + @classmethod + def validate_locale(cls, v: str) -> str: + if isinstance(v, str): + v = v.strip() + if v == "": + return "" + if not re.match(r"^[a-z]{2,3}(-[A-Z][a-z]{3})?(-[A-Z]{2})?$", v): + raise ValueError("Locale must be a valid BCP 47 language tag (e.g. en, sv-SE, zh-Hans-CN)") + return v + + +class UsernameInput(BaseModel): + username: str = Field(max_length=255) + + @field_validator("username", mode="before") + @classmethod + def validate_username(cls, v: str) -> str: + if isinstance(v, str): + v = v.strip() + if not v: + raise ValueError("Username is required") + if not re.match(r"^[a-zA-Z0-9_.@-]+$", v): + raise ValueError("Username may only contain letters, digits, dots, hyphens, underscores, and @") + return v + + +class GroupListInput(BaseModel): + groups: str = "" + + @property + def group_list(self) -> list[str]: + """Parse comma-separated groups into a deduplicated list.""" + seen: set[str] = set() + result: list[str] = [] + for g in (g.strip() for g in self.groups.split(",") if g.strip()): + if g not in seen: + seen.add(g) + result.append(g) + return result + + @field_validator("groups", mode="before") + @classmethod + def validate_groups(cls, v: str) -> str: + if isinstance(v, str): + names = [g.strip() for g in v.split(",") if g.strip()] + for name in names: + if not re.match(r"^[a-z0-9_-]{1,64}$", name): + raise ValueError( + f"Invalid group name '{name}'. " + "Groups must be 1-64 lowercase letters, digits, hyphens, or underscores." + ) + return v + + +MIN_PASSWORD_STRENGTH = 2 + + +class PasswordSet(BaseModel): + password: str = Field(min_length=8, max_length=256) + confirm: str + + @model_validator(mode="after") + def validate_password(self) -> "PasswordSet": + if self.password != self.confirm: + raise ValueError("Passwords do not match") + result = zxcvbn(self.password) + if result["score"] < MIN_PASSWORD_STRENGTH: + feedback = result.get("feedback", {}) + warning = feedback.get("warning", "") + suggestions = feedback.get("suggestions", []) + msg = "Password is too easily guessed." + if warning: + msg += f" {warning}." + if suggestions: + msg += " " + " ".join(suggestions) + raise ValueError(msg) + return self + + +class PasswordChange(PasswordSet): + current_password: str + + @field_validator("current_password", mode="before") + @classmethod + def validate_current_password(cls, v: str) -> str: + if isinstance(v, str) and v.strip() == "": + raise ValueError("Current password is required") + return v + + +FIELD_LABELS: dict[str, str] = { + "given_name": "Given name", + "family_name": "Family name", + "preferred_username": "Display name", + "email": "Email", + "phone_number": "Phone number", + "picture": "Picture URL", + "locale": "Locale", + "username": "Username", + "groups": "Groups", + "password": "Password", + "confirm": "Confirm password", + "current_password": "Current password", +} + + +def format_validation_errors(exc: ValidationError) -> str: + """Format Pydantic ValidationError into user-friendly HTML.""" + messages: list[str] = [] + for error in exc.errors(): + field = str(error["loc"][-1]) if error["loc"] else "input" + label = FIELD_LABELS.get(field, field) + msg = error["msg"] + if error["type"] == "value_error": + raw = msg.removeprefix("Value error, ") + # If the message already starts with the label, don't duplicate it + if raw.startswith(label): + display_msg = raw + else: + display_msg = f"{label}: {raw}" + else: + display_msg = f"{label}: {msg}" + messages.append(display_msg) + + if len(messages) == 1: + return f'
{messages[0]}
' + + items = "".join(f"
  • {m}
  • " for m in messages) + return f'
      {items}
    ' diff --git a/tests/conftest.py b/tests/conftest.py index 3506c70..cc9eee3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from httpx import ASGITransport, AsyncClient from porchlight.app import create_app from porchlight.config import Settings +from porchlight.rate_limit import limiter @pytest.fixture @@ -21,6 +22,12 @@ async def client(settings: Settings) -> AsyncIterator[AsyncClient]: yield ac +@pytest.fixture(autouse=True) +def _reset_rate_limiter() -> None: + """Reset the rate limiter storage before each test.""" + limiter.reset() + + async def get_csrf_token(client: AsyncClient) -> str: """Get a CSRF token by visiting the login page. diff --git a/tests/e2e/setup_db.py b/tests/e2e/setup_db.py index bb43e8e..5dac7aa 100644 --- a/tests/e2e/setup_db.py +++ b/tests/e2e/setup_db.py @@ -24,6 +24,107 @@ from porchlight.store.sqlite.repositories import ( ) +async def _create_user_with_password( + user_repo: SQLiteUserRepository, + cred_repo: SQLiteCredentialRepository, + password_service: PasswordService, + user: User, + password: str, +) -> None: + """Helper to create a user and set their password credential.""" + await user_repo.create(user) + password_hash = password_service.hash(password) + await cred_repo.create_password(PasswordCredential(user_id=user.userid, password_hash=password_hash)) + + +async def _seed_test_users( + user_repo: SQLiteUserRepository, + cred_repo: SQLiteCredentialRepository, + password_service: PasswordService, + result: dict[str, str], +) -> None: + """Create all test users with passwords.""" + # Login test user + await _create_user_with_password( + user_repo, + cred_repo, + password_service, + User(userid="test-user-01", username="testuser", groups=["users"]), + "testpassword123", + ) + result["login_username"] = "testuser" + result["login_password"] = "testpassword123" + + # Credentials management test user + await _create_user_with_password( + user_repo, + cred_repo, + password_service, + User(userid="test-user-02", username="creduser", groups=["users"]), + "credpassword123", + ) + result["cred_username"] = "creduser" + result["cred_password"] = "credpassword123" + + # WebAuthn registration test user + await _create_user_with_password( + user_repo, + cred_repo, + password_service, + User(userid="test-user-03", username="webauthnuser", groups=["users"]), + "webauthnpass123", + ) + result["webauthn_username"] = "webauthnuser" + result["webauthn_password"] = "webauthnpass123" + result["webauthn_userid"] = "test-user-03" + + # Profile management test user + await _create_user_with_password( + user_repo, + cred_repo, + password_service, + User( + userid="test-user-04", + username="profileuser", + given_name="Alice", + family_name="Smith", + preferred_username="asmith", + email="alice@example.com", + phone_number="+12025551234", + picture="https://example.com/alice.jpg", + locale="en", + groups=["users"], + ), + "profilepass123", + ) + result["profile_username"] = "profileuser" + result["profile_password"] = "profilepass123" + + # Admin user for admin page tests + await _create_user_with_password( + user_repo, + cred_repo, + password_service, + User( + userid="test-user-05", + username="adminuser", + given_name="Admin", + family_name="User", + email="admin@example.com", + groups=["admin", "users"], + ), + "adminpass123", + ) + result["admin_username"] = "adminuser" + result["admin_password"] = "adminpass123" + result["admin_userid"] = "test-user-05" + + # Disposable user for admin delete test + await user_repo.create(User(userid="test-user-06", username="disposableuser", groups=["users"])) + result["disposable_userid"] = "test-user-06" + result["disposable_username"] = "disposableuser" + + async def seed() -> None: db_path = os.environ.get("OIDC_OP_SQLITE_PATH") if not db_path: @@ -39,89 +140,21 @@ async def seed() -> None: password_service = PasswordService() magic_link_service = MagicLinkService(repo=magic_link_repo) - result = {} + result: dict[str, str] = {} - # 1. Create a magic link for registration test + # Create magic link for registration test link = await magic_link_service.create(username="newuser") result["register_token"] = link.token result["register_username"] = "newuser" - # 2. Create a user with a password for login test - user = User(userid="test-user-01", username="testuser", groups=["users"]) - await user_repo.create(user) - password_hash = password_service.hash("testpassword123") - await cred_repo.create_password(PasswordCredential(user_id=user.userid, password_hash=password_hash)) - result["login_username"] = "testuser" - result["login_password"] = "testpassword123" + # Create all test users + await _seed_test_users(user_repo, cred_repo, password_service, result) - # 3. Create a separate user for credentials management test - cred_user = User(userid="test-user-02", username="creduser", groups=["users"]) - await user_repo.create(cred_user) - cred_password_hash = password_service.hash("credpassword123") - await cred_repo.create_password(PasswordCredential(user_id=cred_user.userid, password_hash=cred_password_hash)) - result["cred_username"] = "creduser" - result["cred_password"] = "credpassword123" - - # 5. Create a user with password for WebAuthn registration tests - # (login with password first, then register a passkey) - webauthn_user = User(userid="test-user-03", username="webauthnuser", groups=["users"]) - await user_repo.create(webauthn_user) - webauthn_password_hash = password_service.hash("webauthnpass123") - await cred_repo.create_password( - PasswordCredential(user_id=webauthn_user.userid, password_hash=webauthn_password_hash) - ) - result["webauthn_username"] = "webauthnuser" - result["webauthn_password"] = "webauthnpass123" - result["webauthn_userid"] = "test-user-03" - - # 4. Create an expired/used magic link for negative test + # Create an expired/used magic link for negative test expired_link = await magic_link_service.create(username="expired") await magic_link_service.mark_used(expired_link.token) result["used_token"] = expired_link.token - # 5. Create a user with profile data for profile management tests - profile_user = User( - userid="test-user-04", - username="profileuser", - given_name="Alice", - family_name="Smith", - preferred_username="asmith", - email="alice@example.com", - phone_number="+12025551234", - picture="https://example.com/alice.jpg", - locale="en", - groups=["users"], - ) - await user_repo.create(profile_user) - profile_password_hash = password_service.hash("profilepass123") - await cred_repo.create_password( - PasswordCredential(user_id=profile_user.userid, password_hash=profile_password_hash) - ) - result["profile_username"] = "profileuser" - result["profile_password"] = "profilepass123" - - # 6. Admin user for admin page tests - admin_user = User( - userid="test-user-05", - username="adminuser", - given_name="Admin", - family_name="User", - email="admin@example.com", - groups=["admin", "users"], - ) - await user_repo.create(admin_user) - admin_password_hash = password_service.hash("adminpass123") - await cred_repo.create_password(PasswordCredential(user_id=admin_user.userid, password_hash=admin_password_hash)) - result["admin_username"] = "adminuser" - result["admin_password"] = "adminpass123" - result["admin_userid"] = "test-user-05" - - # 7. Disposable user for admin delete test (not used by any other tests) - disposable_user = User(userid="test-user-06", username="disposableuser", groups=["users"]) - await user_repo.create(disposable_user) - result["disposable_userid"] = "test-user-06" - result["disposable_username"] = "disposableuser" - await db.commit() await db.close() print(json.dumps(result)) diff --git a/tests/test_admin/test_admin_routes.py b/tests/test_admin/test_admin_routes.py index b579a44..603af60 100644 --- a/tests/test_admin/test_admin_routes.py +++ b/tests/test_admin/test_admin_routes.py @@ -1,3 +1,4 @@ +from base64 import urlsafe_b64encode from datetime import UTC, datetime import pytest @@ -46,7 +47,7 @@ async def _login( ) -async def _create_target_user( +async def _create_target_user( # noqa: PLR0913 client: AsyncClient, *, userid: str = "target-user-01", @@ -365,8 +366,6 @@ async def test_delete_webauthn_credential(client: AsyncClient) -> None: ) # URL uses base64url without padding - from base64 import urlsafe_b64encode - credential_id_b64 = urlsafe_b64encode(credential_id).decode().rstrip("=") token = await get_csrf_token(client) diff --git a/tests/test_admin_groups_validation.py b/tests/test_admin_groups_validation.py new file mode 100644 index 0000000..ff662e9 --- /dev/null +++ b/tests/test_admin_groups_validation.py @@ -0,0 +1,85 @@ +from datetime import UTC, datetime + +import pytest +from httpx import AsyncClient + +from porchlight.authn.password import PasswordHasher, PasswordService +from porchlight.models import PasswordCredential, User +from tests.conftest import get_csrf_token + + +async def _setup_admin_and_target(client: AsyncClient) -> tuple[str, str]: + """Create admin + target user, login as admin, return (token, target_userid).""" + app = client._transport.app + user_repo = app.state.user_repo + cred_repo = app.state.credential_repo + + admin = User( + userid="admin-g01", + username="admin_g", + groups=["admin", "users"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + await user_repo.create(admin) + + target = User( + userid="target-g01", + username="target_g", + groups=["users"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + await user_repo.create(target) + + svc = PasswordService(hasher=PasswordHasher(time_cost=1, memory_cost=8192)) + await cred_repo.create_password(PasswordCredential(user_id=admin.userid, password_hash=svc.hash("AdminPass123!"))) + + token = await get_csrf_token(client) + await client.post( + "/login/password", + data={"username": "admin_g", "password": "AdminPass123!"}, + headers={"HX-Request": "true", "X-CSRF-Token": token}, + ) + return token, target.userid + + +@pytest.mark.asyncio +async def test_valid_groups(client: AsyncClient) -> None: + token, userid = await _setup_admin_and_target(client) + response = await client.post( + f"/admin/users/{userid}/groups", + data={"groups": "users, staff"}, + headers={"X-CSRF-Token": token}, + ) + assert "Groups updated" in response.text + + app = client._transport.app + user = await app.state.user_repo.get_by_userid(userid) + assert sorted(user.groups) == ["staff", "users"] + + +@pytest.mark.asyncio +async def test_invalid_group_name_rejected(client: AsyncClient) -> None: + token, userid = await _setup_admin_and_target(client) + response = await client.post( + f"/admin/users/{userid}/groups", + data={"groups": "users, Bad Group!"}, + headers={"X-CSRF-Token": token}, + ) + assert "alert" in response.text + + +@pytest.mark.asyncio +async def test_empty_groups_clears(client: AsyncClient) -> None: + token, userid = await _setup_admin_and_target(client) + response = await client.post( + f"/admin/users/{userid}/groups", + data={"groups": ""}, + headers={"X-CSRF-Token": token}, + ) + assert "Groups updated" in response.text + + app = client._transport.app + user = await app.state.user_repo.get_by_userid(userid) + assert user.groups == [] diff --git a/tests/test_admin_invite_validation.py b/tests/test_admin_invite_validation.py new file mode 100644 index 0000000..6bf1b68 --- /dev/null +++ b/tests/test_admin_invite_validation.py @@ -0,0 +1,71 @@ +from datetime import UTC, datetime + +import pytest +from httpx import AsyncClient + +from porchlight.authn.password import PasswordHasher, PasswordService +from porchlight.models import PasswordCredential, User +from tests.conftest import get_csrf_token + + +async def _login_admin(client: AsyncClient) -> str: + """Create and login as admin user, return CSRF token.""" + app = client._transport.app + user_repo = app.state.user_repo + cred_repo = app.state.credential_repo + + user = User( + userid="admin-01", + username="admin", + groups=["admin", "users"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + await user_repo.create(user) + + svc = PasswordService(hasher=PasswordHasher(time_cost=1, memory_cost=8192)) + await cred_repo.create_password(PasswordCredential(user_id=user.userid, password_hash=svc.hash("AdminPass123!"))) + + token = await get_csrf_token(client) + await client.post( + "/login/password", + data={"username": "admin", "password": "AdminPass123!"}, + headers={"HX-Request": "true", "X-CSRF-Token": token}, + ) + return token + + +@pytest.mark.asyncio +async def test_invite_valid_username(client: AsyncClient) -> None: + token = await _login_admin(client) + response = await client.post( + "/admin/invite", + data={"username": "newuser@example.com"}, + headers={"X-CSRF-Token": token}, + ) + assert response.status_code == 200 + assert "Invite created" in response.text + + +@pytest.mark.asyncio +async def test_invite_empty_username_rejected(client: AsyncClient) -> None: + token = await _login_admin(client) + response = await client.post( + "/admin/invite", + data={"username": ""}, + headers={"X-CSRF-Token": token}, + ) + # Empty username is rejected — either by FastAPI (422) or validation (alert) + assert response.status_code == 422 or "alert" in response.text + + +@pytest.mark.asyncio +async def test_invite_invalid_username_rejected(client: AsyncClient) -> None: + token = await _login_admin(client) + response = await client.post( + "/admin/invite", + data={"username": "bad user