porchlight/examples/rp-reference/oidc_client.py
Johan Lundberg 8e8c33a407
reference RP
2026-06-29 09:23:22 +02:00

266 lines
9.7 KiB
Python

"""The hand-rolled OIDC logic.
This module deliberately does NOT use an OIDC client library. Each protocol
step is a small function so you can read exactly what an RP sends and checks.
PyJWT is used only for the RS256 *signature* primitive, JWK parsing and the
clock-based exp / iat checks — every OIDC-specific claim check (iss / aud /
nonce) is written out explicitly below, because that is the part worth seeing.
Flow overview:
discovery -> build_authorization_url -> (browser redirect to OP)
-> exchange_code -> verify_id_token -> fetch_userinfo
-> refresh_tokens
References:
- OpenID Connect Core 1.0, section 3.1 (Authorization Code Flow)
- RFC 7636 (PKCE)
"""
from __future__ import annotations
import base64
import hashlib
import json
import secrets
import time
from typing import Any
import httpx
import jwt
from jwt.algorithms import RSAAlgorithm
# --------------------------------------------------------------------------
# 0. PKCE helpers (RFC 7636)
# --------------------------------------------------------------------------
def generate_pkce_pair() -> tuple[str, str]:
    """Create a fresh PKCE (code_verifier, code_challenge) pair for S256.

    The verifier is 32 random bytes, base64url-encoded. It stays secret in our
    session and is sent only to the token endpoint. The challenge is the
    base64url SHA-256 of the verifier's ASCII text and travels in the
    authorization request. The OP later hashes the verifier we present and
    compares the result with the challenge it stored.
    """
    verifier = _b64url(secrets.token_bytes(32))
    challenge = _b64url(hashlib.sha256(verifier.encode("ascii")).digest())
    return verifier, challenge


def _b64url(raw: bytes) -> str:
    """Encode *raw* as unpadded base64url text (the JOSE/PKCE convention)."""
    # Decoding before stripping gives the same result as stripping b"=" first:
    # padding characters are plain ASCII either way.
    encoded = base64.urlsafe_b64encode(raw).decode("ascii")
    return encoded.rstrip("=")
# --------------------------------------------------------------------------
# 1. Discovery
# --------------------------------------------------------------------------
async def fetch_discovery(issuer: str) -> dict[str, Any]:
    """Fetch the OP's provider metadata (OpenID Connect Discovery 1.0).

    The well-known URL is the issuer with a fixed suffix appended. The returned
    document gives the real authorization / token / userinfo / jwks URLs, so
    we never hard-code endpoint paths.

    `issuer` must be the OP's issuer identifier exactly as it appears in the
    `iss` claim of its ID tokens. Raises httpx.HTTPStatusError if the
    well-known endpoint answers with an error status. Raises ValueError if the
    document advertises a different issuer.
    """
    url = issuer.rstrip("/") + "/.well-known/openid-configuration"
    async with httpx.AsyncClient() as http:
        resp = await http.get(url)
        resp.raise_for_status()
        metadata = resp.json()
    # Discovery 1.0 section 4.3: the `issuer` in the document MUST be identical
    # to the issuer we used to locate it. Without this check, an endpoint could
    # hand us endpoints and keys that belong to a different OP.
    if metadata.get("issuer") != issuer:
        raise ValueError(
            f"discovery issuer mismatch: {metadata.get('issuer')!r} != {issuer!r}"
        )
    return metadata
# --------------------------------------------------------------------------
# 2. Authorization request
# --------------------------------------------------------------------------
def build_authorization_url(
    discovery: dict[str, Any],
    *,
    client_id: str,
    redirect_uri: str,
    scope: str,
) -> tuple[str, dict[str, str]]:
    """Build the URL we redirect the browser to, plus the per-request secrets.

    Returns (authorization_url, session_state). session_state holds the values
    we must remember to validate the callback: the CSRF `state`, the
    replay-protection `nonce`, and the PKCE `code_verifier`.

    Any query string already present on the OP's `authorization_endpoint` is
    kept, and our parameters are merged into it. If the endpoint carries a
    parameter with the same name as one of ours, our value wins.
    """
    state = secrets.token_urlsafe(24)  # CSRF protection for the redirect
    nonce = secrets.token_urlsafe(24)  # binds the ID token to this request
    code_verifier, code_challenge = generate_pkce_pair()
    params = {
        "response_type": "code",
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "scope": scope,
        "state": state,
        "nonce": nonce,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
    }
    # OIDC Core section 11 (Offline Access): when `offline_access` is requested,
    # prompt=consent MUST be used (unless other conditions permit offline
    # access), so the user explicitly approves long-lived access. The OP
    # (idpyoidc) rejects the request otherwise ("consent in prompt").
    if "offline_access" in scope.split():
        params["prompt"] = "consent"
    # RFC 6749 section 3.1: a query component on the authorization endpoint URL
    # MUST be retained when adding parameters. `httpx.URL(..., params=...)`
    # would replace it, so merge instead. httpx percent-encodes each value.
    endpoint = httpx.URL(discovery["authorization_endpoint"])
    url = str(endpoint.copy_merge_params(params))
    session_state = {"state": state, "nonce": nonce, "code_verifier": code_verifier}
    return url, session_state
# --------------------------------------------------------------------------
# 3. Token exchange (authorization code -> tokens)
# --------------------------------------------------------------------------
async def exchange_code(
    discovery: dict[str, Any],
    *,
    code: str,
    code_verifier: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
) -> dict[str, Any]:
    """Trade the authorization `code` for tokens at the OP's token endpoint.

    Client authentication is client_secret_basic: the client_id and
    client_secret travel in the HTTP Basic Authorization header rather than in
    the form body. The PKCE code_verifier goes along too, so the OP can check
    that whoever finishes the flow is whoever started it.

    Returns the decoded JSON token response. Raises httpx.HTTPStatusError on a
    non-2xx answer.
    """
    token_endpoint = discovery["token_endpoint"]
    form = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "code_verifier": code_verifier,
        # Strictly redundant for this grant, because Basic auth already
        # identifies us. The OP does need it for the refresh grant, though, so
        # we send it on both requests for consistency.
        "client_id": client_id,
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(
            token_endpoint,
            data=form,
            auth=(client_id, client_secret),  # client_secret_basic
            headers={"Accept": "application/json"},
        )
    response.raise_for_status()
    return response.json()
# --------------------------------------------------------------------------
# 4. ID token verification
# --------------------------------------------------------------------------
async def verify_id_token(
    id_token: str,
    *,
    discovery: dict[str, Any],
    issuer: str,
    client_id: str,
    expected_nonce: str,
    leeway: int,
) -> dict[str, Any]:
    """Verify the ID token signature, then check its claims by hand.

    PyJWT verifies the RS256 signature and the exp/iat timing for us (those need
    crypto / the current clock). Everything that is OIDC-specific — iss, aud,
    azp, sub, nonce — we check explicitly so the rules are visible (OpenID
    Connect Core 1.0, section 3.1.3.7).

    `leeway` is the clock-skew allowance in seconds passed to PyJWT for exp/iat.
    Returns the verified claims dict.

    Raises:
    - ValueError: no expected nonce is supplied, no JWK matches the token's
      `kid`, or an OIDC claim check fails.
    - httpx.HTTPStatusError: the JWKS fetch fails.
    - jwt.PyJWTError: the signature or exp/iat checks fail.
    """
    # Refuse to validate without a nonce to compare against. If the session
    # lost it (None / ""), a token that also lacks a nonce would pass the
    # equality check below, and replay protection would silently disappear.
    if not expected_nonce:
        raise ValueError("no expected nonce: cannot bind the ID token to a login")

    # (a) Read the unverified header to learn which key (kid) signed it.
    header = jwt.get_unverified_header(id_token)
    kid = header.get("kid")

    # (b) Fetch the OP's public keys and pick the matching one.
    async with httpx.AsyncClient() as http:
        resp = await http.get(discovery["jwks_uri"])
        resp.raise_for_status()
        jwks = resp.json()
    jwk = next((k for k in jwks["keys"] if k.get("kid") == kid), None)
    if jwk is None:
        raise ValueError(f"no JWK in OP key set matches token kid={kid!r}")
    public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))

    # (c) Verify the signature + exp/iat. We turn off PyJWT's own iss/aud checks
    # so we can do them ourselves below.
    claims: dict[str, Any] = jwt.decode(
        id_token,
        key=public_key,
        algorithms=["RS256"],
        leeway=leeway,
        options={"verify_aud": False, "verify_iss": False, "require": ["exp", "iat"]},
    )

    # (d) Explicit OIDC claim checks — the heart of RP validation.
    if claims.get("iss") != issuer:
        raise ValueError(f"iss mismatch: {claims.get('iss')!r} != {issuer!r}")
    aud = claims.get("aud")
    audiences = aud if isinstance(aud, list) else [aud]
    if client_id not in audiences:
        raise ValueError(f"aud {audiences!r} does not contain client_id {client_id!r}")
    # Core 3.1.3.7 steps 4-5: a token with several audiences should name the
    # party it was issued to (azp), and when azp is present it must be us.
    azp = claims.get("azp")
    if len(audiences) > 1 and azp is None:
        raise ValueError(f"aud {audiences!r} has several entries but no azp claim")
    if azp is not None and azp != client_id:
        raise ValueError(f"azp mismatch: {azp!r} != client_id {client_id!r}")
    # `sub` is a REQUIRED ID token claim (Core section 2). It is the stable
    # user identifier the RP keys its session on.
    if not claims.get("sub"):
        raise ValueError("ID token has no sub claim")
    if claims.get("nonce") != expected_nonce:
        raise ValueError("nonce mismatch: ID token does not belong to this login")
    return claims
# --------------------------------------------------------------------------
# 5. UserInfo
# --------------------------------------------------------------------------
async def fetch_userinfo(
    discovery: dict[str, Any],
    *,
    access_token: str,
    expected_sub: str | None = None,
) -> dict[str, Any]:
    """Call the UserInfo endpoint with the access token as a Bearer token.

    Pass `expected_sub` (the `sub` of the verified ID token) to enforce OIDC
    Core section 5.3.2: the UserInfo `sub` MUST exactly match it, or the
    response must not be used. This guards against token-substitution attacks.
    When `expected_sub` is None (the default), no check is made.

    Returns the decoded JSON claims. Raises httpx.HTTPStatusError on a non-2xx
    answer, and ValueError if `expected_sub` is given and does not match.
    """
    async with httpx.AsyncClient() as http:
        resp = await http.get(
            discovery["userinfo_endpoint"],
            headers={"Authorization": f"Bearer {access_token}"},
        )
        resp.raise_for_status()
        userinfo = resp.json()
    if expected_sub is not None and userinfo.get("sub") != expected_sub:
        raise ValueError(
            f"userinfo sub mismatch: {userinfo.get('sub')!r} != {expected_sub!r}"
        )
    return userinfo
# --------------------------------------------------------------------------
# 6. Refresh
# --------------------------------------------------------------------------
async def refresh_tokens(
    discovery: dict[str, Any],
    *,
    refresh_token: str,
    client_id: str,
    client_secret: str,
) -> dict[str, Any]:
    """Redeem a refresh token for a new access token.

    Porchlight specifics: the OP rotates refresh tokens, so a new one comes
    back with each call. It does not mint a new ID token on refresh, because
    only re-authentication issues ID tokens. Expect access_token plus a fresh
    refresh_token in the response, but no id_token.

    Returns the decoded JSON token response. Raises httpx.HTTPStatusError on a
    non-2xx answer.
    """
    token_endpoint = discovery["token_endpoint"]
    # The OP reads client_id from the form body for this grant and does not
    # fall back to the Basic-auth identity, so leaving it out triggers a
    # server error.
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(
            token_endpoint,
            data=form,
            auth=(client_id, client_secret),
            headers={"Accept": "application/json"},
        )
    response.raise_for_status()
    return response.json()
def now() -> int:
    """Return the current Unix time in whole seconds (used to show token age)."""
    seconds = time.time()
    return int(seconds)