diff --git a/src/fastapi_oidc_op/authn/routes.py b/src/fastapi_oidc_op/authn/routes.py index 2bca6b0..dcb4fbf 100644 --- a/src/fastapi_oidc_op/authn/routes.py +++ b/src/fastapi_oidc_op/authn/routes.py @@ -13,6 +13,17 @@ from fastapi_oidc_op.userid import generate_unique_userid router = APIRouter(tags=["authn"]) +def _login_redirect_target(request: Request) -> str: + """Determine where to redirect after successful login. + + If there's a pending OIDC authorization request, redirect to complete it. + Otherwise, redirect to credential management. + """ + if "oidc_auth_request" in request.session: + return "/authorization/complete" + return "/manage/credentials" + + @router.get("/login", response_class=HTMLResponse) async def login_page(request: Request) -> HTMLResponse: templates = request.app.state.templates @@ -46,7 +57,7 @@ async def login_password( request.session["username"] = user.username response = Response() - response.headers["HX-Redirect"] = "/manage/credentials" + response.headers["HX-Redirect"] = _login_redirect_target(request) return response @@ -150,5 +161,5 @@ async def login_webauthn_complete(request: Request) -> Response: request.session["username"] = user.username response = Response() - response.headers["HX-Redirect"] = "/manage/credentials" + response.headers["HX-Redirect"] = _login_redirect_target(request) return response diff --git a/tests/test_oidc/test_login_oidc_redirect.py b/tests/test_oidc/test_login_oidc_redirect.py new file mode 100644 index 0000000..dd51d71 --- /dev/null +++ b/tests/test_oidc/test_login_oidc_redirect.py @@ -0,0 +1,77 @@ +import secrets +from datetime import UTC, datetime + +from argon2 import PasswordHasher +from httpx import AsyncClient + +from fastapi_oidc_op.authn.password import PasswordService +from fastapi_oidc_op.models import PasswordCredential, User + + +def _register_test_client(client: AsyncClient) -> None: + app = client._transport.app # type: ignore[union-attr] + oidc_server = app.state.oidc_server + oidc_server.context.cdb["test-rp"] = { + "client_id": "test-rp", + "client_secret": "test-secret", + "redirect_uris": [("http://localhost:9000/callback", {})], + "response_types_supported": ["code"], + "token_endpoint_auth_method": "client_secret_basic", + "scope": ["openid", "profile", "email"], + "allowed_scopes": ["openid", "profile", "email"], + "client_salt": secrets.token_hex(8), + } + oidc_server.keyjar.add_symmetric("test-rp", "test-secret") + + +async def _create_user(client: AsyncClient) -> None: + app = client._transport.app # type: ignore[union-attr] + user_repo = app.state.user_repo + cred_repo = app.state.credential_repo + user = User(userid="lusab-bansen", username="alice", created_at=datetime.now(UTC), updated_at=datetime.now(UTC)) + await user_repo.create(user) + svc = PasswordService(hasher=PasswordHasher(time_cost=1, memory_cost=8192)) + await cred_repo.create_password(PasswordCredential(user_id=user.userid, password_hash=svc.hash("testpass"))) + + +async def test_login_with_pending_oidc_redirects_to_authorization_complete(client: AsyncClient) -> None: + _register_test_client(client) + await _create_user(client) + + # Step 1: Start OIDC authorization (stores request in session, redirects to /login) + auth_res = await client.get( + "/authorization", + params={ + "response_type": "code", + "client_id": "test-rp", + "redirect_uri": "http://localhost:9000/callback", + "scope": "openid", + "state": "test-state", + "nonce": "test-nonce", + }, + follow_redirects=False, + ) + assert auth_res.status_code in (302, 303) + + # Step 2: Login via password + login_res = await client.post( + "/login/password", + data={"username": "alice", "password": "testpass"}, + headers={"HX-Request": "true"}, + ) + assert login_res.status_code == 200 + redirect_target = login_res.headers.get("HX-Redirect", "") + assert "/authorization/complete" in redirect_target + + +async def test_login_without_pending_oidc_redirects_to_manage(client: AsyncClient) -> None: + await _create_user(client) + + login_res = await client.post( + "/login/password", + data={"username": "alice", "password": "testpass"}, + headers={"HX-Request": "true"}, + ) + assert login_res.status_code == 200 + redirect_target = login_res.headers.get("HX-Redirect", "") + assert redirect_target == "/manage/credentials"