mirror of
https://github.com/sartography/spiff-arena.git
synced 2025-01-27 17:55:04 +00:00
Support jwks rotation (#1900)
* support jwks rotation * force refresh if not in cache * cleanup * dedup * Update spiffworkflow-backend/src/spiffworkflow_backend/services/authentication_service.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * more types * lint --------- Co-authored-by: burnettk <burnettk@users.noreply.github.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
parent
8b26848ec9
commit
6d16438816
@ -28,6 +28,7 @@ else:
|
||||
from typing import NotRequired
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from flask import current_app
|
||||
@ -48,6 +49,21 @@ from spiffworkflow_backend.services.authorization_service import AuthorizationSe
|
||||
from spiffworkflow_backend.services.user_service import UserService
|
||||
|
||||
|
||||
class JWKSKeyConfig(TypedDict):
|
||||
kid: str
|
||||
kty: str
|
||||
use: str
|
||||
n: str
|
||||
e: str
|
||||
x5c: NotRequired[list[str]]
|
||||
alg: str
|
||||
x5t: str
|
||||
|
||||
|
||||
class JWKSConfigs(TypedDict):
|
||||
keys: NotRequired[list[JWKSKeyConfig]]
|
||||
|
||||
|
||||
class AuthenticationProviderTypes(enum.Enum):
|
||||
open_id = "open_id"
|
||||
internal = "internal"
|
||||
@ -71,7 +87,7 @@ class AuthenticationOptionNotFoundError(Exception):
|
||||
|
||||
class AuthenticationService:
|
||||
ENDPOINT_CACHE: dict[str, dict[str, str]] = {} # We only need to find the openid endpoints once, then we can cache them.
|
||||
JSON_WEB_KEYSET_CACHE: dict[str, dict[str, str]] = {}
|
||||
JSON_WEB_KEYSET_CACHE: dict[str, JWKSConfigs] = {}
|
||||
|
||||
@classmethod
|
||||
def authentication_options_for_api(cls) -> list[AuthenticationOptionForApi]:
|
||||
@ -139,24 +155,30 @@ class AuthenticationService:
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def get_jwks_config_from_uri(cls, jwks_uri: str) -> dict:
|
||||
if jwks_uri not in AuthenticationService.JSON_WEB_KEYSET_CACHE:
|
||||
def get_jwks_config_from_uri(cls, jwks_uri: str, force_refresh: bool = False) -> JWKSConfigs:
|
||||
if jwks_uri not in cls.JSON_WEB_KEYSET_CACHE or force_refresh:
|
||||
try:
|
||||
jwt_ks_response = safe_requests.get(jwks_uri, timeout=HTTP_REQUEST_TIMEOUT_SECONDS)
|
||||
AuthenticationService.JSON_WEB_KEYSET_CACHE[jwks_uri] = jwt_ks_response.json()
|
||||
cls.JSON_WEB_KEYSET_CACHE[jwks_uri] = jwt_ks_response.json()
|
||||
except requests.exceptions.ConnectionError as ce:
|
||||
raise OpenIdConnectionError(f"Cannot connect to given jwks url: {jwks_uri}") from ce
|
||||
return AuthenticationService.JSON_WEB_KEYSET_CACHE[jwks_uri]
|
||||
|
||||
@classmethod
|
||||
def jwks_public_key_for_key_id(cls, authentication_identifier: str, key_id: str) -> dict:
|
||||
def jwks_public_key_for_key_id(cls, authentication_identifier: str, key_id: str) -> JWKSKeyConfig:
|
||||
jwks_uri = cls.open_id_endpoint_for_name("jwks_uri", authentication_identifier)
|
||||
jwks_configs = cls.get_jwks_config_from_uri(jwks_uri)
|
||||
json_key_configs: dict = next(jk for jk in jwks_configs["keys"] if jk["kid"] == key_id)
|
||||
json_key_configs: JWKSKeyConfig | None = cls.get_key_config(jwks_configs, key_id)
|
||||
if not json_key_configs:
|
||||
# Refetch the JWKS keys from the source if key_id is not found in cache
|
||||
jwks_configs = cls.get_jwks_config_from_uri(jwks_uri, force_refresh=True)
|
||||
json_key_configs = cls.get_key_config(jwks_configs, key_id)
|
||||
if not json_key_configs:
|
||||
raise KeyError(f"Key ID {key_id} not found in JWKS even after refreshing from {jwks_uri}.")
|
||||
return json_key_configs
|
||||
|
||||
@classmethod
|
||||
def public_key_from_rsa_public_numbers(cls, json_key_configs: dict) -> Any:
|
||||
def public_key_from_rsa_public_numbers(cls, json_key_configs: JWKSKeyConfig) -> Any:
|
||||
modulus = base64.urlsafe_b64decode(json_key_configs["n"] + "===")
|
||||
exponent = base64.urlsafe_b64decode(json_key_configs["e"] + "===")
|
||||
public_key_numbers = rsa.RSAPublicNumbers(
|
||||
@ -165,7 +187,7 @@ class AuthenticationService:
|
||||
return public_key_numbers.public_key(backend=default_backend())
|
||||
|
||||
@classmethod
|
||||
def public_key_from_x5c(cls, key_id: str, json_key_configs: dict) -> Any:
|
||||
def public_key_from_x5c(cls, key_id: str, json_key_configs: JWKSKeyConfig) -> Any:
|
||||
x5c = json_key_configs["x5c"][0]
|
||||
decoded_certificate = base64.b64decode(x5c)
|
||||
|
||||
@ -385,6 +407,13 @@ class AuthenticationService:
|
||||
return refresh_token_object.token
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_key_config(cls, jwks_configs: JWKSConfigs, key_id: str) -> JWKSKeyConfig | None:
|
||||
for jk in jwks_configs["keys"]:
|
||||
if jk["kid"] == key_id:
|
||||
return jk
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_auth_token_from_refresh_token(cls, refresh_token: str, authentication_identifier: str) -> dict:
|
||||
"""Converts a refresh token to an Auth Token by calling the openid's auth endpoint."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user