Support jwks rotation (#1900)

* support jwks rotation

* force refresh if not in cache

* cleanup

* dedup

* Update spiffworkflow-backend/src/spiffworkflow_backend/services/authentication_service.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* more types

* lint

---------

Co-authored-by: burnettk <burnettk@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
Kevin Burnett 2024-07-10 20:00:18 +00:00 committed by GitHub
parent 8b26848ec9
commit 6d16438816
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -28,6 +28,7 @@ else:
from typing import NotRequired
from typing import TypedDict
import jwt
import requests
from flask import current_app
@ -48,6 +49,21 @@ from spiffworkflow_backend.services.authorization_service import AuthorizationSe
from spiffworkflow_backend.services.user_service import UserService
class JWKSKeyConfig(TypedDict):
kid: str
kty: str
use: str
n: str
e: str
x5c: NotRequired[list[str]]
alg: str
x5t: str
class JWKSConfigs(TypedDict):
keys: NotRequired[list[JWKSKeyConfig]]
class AuthenticationProviderTypes(enum.Enum):
open_id = "open_id"
internal = "internal"
@ -71,7 +87,7 @@ class AuthenticationOptionNotFoundError(Exception):
class AuthenticationService:
ENDPOINT_CACHE: dict[str, dict[str, str]] = {} # We only need to find the openid endpoints once, then we can cache them.
JSON_WEB_KEYSET_CACHE: dict[str, dict[str, str]] = {}
JSON_WEB_KEYSET_CACHE: dict[str, JWKSConfigs] = {}
@classmethod
def authentication_options_for_api(cls) -> list[AuthenticationOptionForApi]:
@ -139,24 +155,30 @@ class AuthenticationService:
return config
@classmethod
def get_jwks_config_from_uri(cls, jwks_uri: str) -> dict:
if jwks_uri not in AuthenticationService.JSON_WEB_KEYSET_CACHE:
def get_jwks_config_from_uri(cls, jwks_uri: str, force_refresh: bool = False) -> JWKSConfigs:
if jwks_uri not in cls.JSON_WEB_KEYSET_CACHE or force_refresh:
try:
jwt_ks_response = safe_requests.get(jwks_uri, timeout=HTTP_REQUEST_TIMEOUT_SECONDS)
AuthenticationService.JSON_WEB_KEYSET_CACHE[jwks_uri] = jwt_ks_response.json()
cls.JSON_WEB_KEYSET_CACHE[jwks_uri] = jwt_ks_response.json()
except requests.exceptions.ConnectionError as ce:
raise OpenIdConnectionError(f"Cannot connect to given jwks url: {jwks_uri}") from ce
return AuthenticationService.JSON_WEB_KEYSET_CACHE[jwks_uri]
@classmethod
def jwks_public_key_for_key_id(cls, authentication_identifier: str, key_id: str) -> dict:
def jwks_public_key_for_key_id(cls, authentication_identifier: str, key_id: str) -> JWKSKeyConfig:
jwks_uri = cls.open_id_endpoint_for_name("jwks_uri", authentication_identifier)
jwks_configs = cls.get_jwks_config_from_uri(jwks_uri)
json_key_configs: dict = next(jk for jk in jwks_configs["keys"] if jk["kid"] == key_id)
json_key_configs: JWKSKeyConfig | None = cls.get_key_config(jwks_configs, key_id)
if not json_key_configs:
# Refetch the JWKS keys from the source if key_id is not found in cache
jwks_configs = cls.get_jwks_config_from_uri(jwks_uri, force_refresh=True)
json_key_configs = cls.get_key_config(jwks_configs, key_id)
if not json_key_configs:
raise KeyError(f"Key ID {key_id} not found in JWKS even after refreshing from {jwks_uri}.")
return json_key_configs
@classmethod
def public_key_from_rsa_public_numbers(cls, json_key_configs: dict) -> Any:
def public_key_from_rsa_public_numbers(cls, json_key_configs: JWKSKeyConfig) -> Any:
modulus = base64.urlsafe_b64decode(json_key_configs["n"] + "===")
exponent = base64.urlsafe_b64decode(json_key_configs["e"] + "===")
public_key_numbers = rsa.RSAPublicNumbers(
@ -165,7 +187,7 @@ class AuthenticationService:
return public_key_numbers.public_key(backend=default_backend())
@classmethod
def public_key_from_x5c(cls, key_id: str, json_key_configs: dict) -> Any:
def public_key_from_x5c(cls, key_id: str, json_key_configs: JWKSKeyConfig) -> Any:
x5c = json_key_configs["x5c"][0]
decoded_certificate = base64.b64decode(x5c)
@ -385,6 +407,13 @@ class AuthenticationService:
return refresh_token_object.token
return None
@classmethod
def get_key_config(cls, jwks_configs: JWKSConfigs, key_id: str) -> JWKSKeyConfig | None:
for jk in jwks_configs["keys"]:
if jk["kid"] == key_id:
return jk
return None
@classmethod
def get_auth_token_from_refresh_token(cls, refresh_token: str, authentication_identifier: str) -> dict:
"""Converts a refresh token to an Auth Token by calling the openid's auth endpoint."""