"""Tests for the OIDC ``/authorization`` endpoint."""

import secrets
from urllib.parse import urlsplit

from httpx import AsyncClient

# Identity of the relying party registered by ``_register_test_client``.
TEST_CLIENT_ID = "test-rp"
TEST_REDIRECT_URI = "http://localhost:9000/callback"


def _register_test_client(
    client: AsyncClient,
    client_id: str = TEST_CLIENT_ID,
    redirect_uri: str = TEST_REDIRECT_URI,
) -> str:
    """Register a test client in the OIDC server. Returns client_secret.

    The client record is written straight into the server's client database
    (``context.cdb``) and its secret is added to the key jar as a symmetric
    key, bypassing dynamic client registration.

    Args:
        client: Test client whose transport wraps the application under test.
        client_id: Identifier to register the client under.
        redirect_uri: The single redirect URI the client is allowed to use.

    Returns:
        The randomly generated client secret.
    """
    # Reaches into httpx's private transport to get the ASGI app; this
    # assumes the fixture builds the client with an app-backed transport.
    app = client._transport.app  # type: ignore[union-attr]
    oidc_server = app.state.oidc_server
    client_secret = secrets.token_hex(16)
    oidc_server.context.cdb[client_id] = {
        "client_id": client_id,
        "client_secret": client_secret,
        "redirect_uris": [(redirect_uri, {})],
        "response_types_supported": ["code"],
        "token_endpoint_auth_method": "client_secret_basic",
        "scope": ["openid", "profile", "email"],
        "allowed_scopes": ["openid", "profile", "email"],
        "client_salt": secrets.token_hex(8),
    }
    oidc_server.keyjar.add_symmetric(client_id, client_secret)
    return client_secret


def _authorization_params(**overrides: str) -> dict:
    """Build query params for a standard code-flow authorization request.

    Keyword arguments replace (or add) individual parameters.
    """
    params = {
        "response_type": "code",
        "client_id": TEST_CLIENT_ID,
        "redirect_uri": TEST_REDIRECT_URI,
        "scope": "openid",
        "state": "test-state",
        "nonce": "test-nonce",
    }
    params.update(overrides)
    return params


async def test_authorization_redirects_to_login_when_unauthenticated(
    client: AsyncClient,
) -> None:
    """An unauthenticated authorization request is redirected to login."""
    _register_test_client(client)
    res = await client.get(
        "/authorization",
        params=_authorization_params(),
        follow_redirects=False,
    )
    assert res.status_code in (302, 303)
    assert "/login" in res.headers["location"]


async def test_authorization_stores_auth_request_in_session(
    client: AsyncClient,
) -> None:
    """After the authorization redirect, the login page can be loaded.

    NOTE(review): this only checks the session indirectly (the login page
    renders in the same cookie-carrying client); it does not inspect the
    stored authorization request itself.
    """
    _register_test_client(client)
    res = await client.get(
        "/authorization",
        params=_authorization_params(),
        follow_redirects=False,
    )
    assert res.status_code in (302, 303)
    assert "/login" in res.headers["location"]

    login_res = await client.get("/login")
    assert login_res.status_code == 200


async def test_authorization_invalid_client_returns_error(
    client: AsyncClient,
) -> None:
    """An unknown client is rejected without redirecting to its redirect URI."""
    res = await client.get(
        "/authorization",
        params={
            "response_type": "code",
            "client_id": "nonexistent",
            "redirect_uri": "http://evil.com/callback",
            "scope": "openid",
            "state": "test-state",
        },
        follow_redirects=False,
    )
    # RFC 6749 §4.1.2.1: with an invalid client_id / unverified redirect URI,
    # the server MUST NOT automatically redirect the user-agent there.
    location = res.headers.get("location", "")
    assert urlsplit(location).hostname != "evil.com"
    assert res.status_code >= 400 or "error" in res.text.lower()