diff --git a/spiffworkflow-backend/src/spiffworkflow_backend/routes/openid_blueprint/openid_blueprint.py b/spiffworkflow-backend/src/spiffworkflow_backend/routes/openid_blueprint/openid_blueprint.py index dd8928c63..b16ba46af 100644 --- a/spiffworkflow-backend/src/spiffworkflow_backend/routes/openid_blueprint/openid_blueprint.py +++ b/spiffworkflow-backend/src/spiffworkflow_backend/routes/openid_blueprint/openid_blueprint.py @@ -19,7 +19,7 @@ openid_blueprint = Blueprint( MY_SECRET_CODE = ":this_should_be_some_crazy_code_different_all_the_time" -@openid_blueprint.route("/well-known/openid-configuration", methods=["GET"]) +@openid_blueprint.route("/.well-known/openid-configuration", methods=["GET"]) def well_known(): """OpenID Discovery endpoint -- as these urls can be very different from system to system, this is just a small subset.""" @@ -52,9 +52,10 @@ def form_submit(): state = request.values.get('state') data = { "state": base64.b64encode(bytes(state, 'UTF-8')), - "code": request.values['Uname'] + MY_SECRET_CODE + "code": request.values['Uname'] + MY_SECRET_CODE, + "session_state": "" } - url = request.values.get('redirect_uri') + urlencode(data) + url = request.values.get('redirect_uri') + "?" + urlencode(data) return redirect(url, code=200) else: return render_template('login.html', diff --git a/spiffworkflow-backend/src/spiffworkflow_backend/services/authentication_service.py b/spiffworkflow-backend/src/spiffworkflow_backend/services/authentication_service.py index f8171d88d..5fdedf767 100644 --- a/spiffworkflow-backend/src/spiffworkflow_backend/services/authentication_service.py +++ b/spiffworkflow-backend/src/spiffworkflow_backend/services/authentication_service.py @@ -3,6 +3,7 @@ import base64 import enum import json import time +import typing from typing import Optional import jwt @@ -15,6 +16,7 @@ from werkzeug.wrappers import Response from spiffworkflow_backend.models.refresh_token import RefreshTokenModel + class AuthenticationProviderTypes(enum.Enum): """AuthenticationServiceProviders.""" @@ -24,20 +26,31 @@ class AuthenticationProviderTypes(enum.Enum): class AuthenticationService: """AuthenticationService.""" + ENDPOINT_CACHE = {} # We only need to find the openid endpoints once, then we can cache them. @staticmethod - def get_open_id_args() -> tuple: - """Get_open_id_args.""" - open_id_server_url = current_app.config["OPEN_ID_SERVER_URL"] - open_id_client_id = current_app.config["OPEN_ID_CLIENT_ID"] - open_id_client_secret_key = current_app.config[ - "OPEN_ID_CLIENT_SECRET_KEY" - ] # noqa: S105 - return ( - open_id_server_url, - open_id_client_id, - open_id_client_secret_key, - ) + def client_id(): + return current_app.config["OPEN_ID_CLIENT_ID"] + + @staticmethod + def server_url(): + return current_app.config["OPEN_ID_SERVER_URL"] + + @staticmethod + def secret_key(): + return current_app.config["OPEN_ID_CLIENT_SECRET_KEY"] + + + @classmethod + def open_id_endpoint_for_name(cls, name: str) -> None: + """All openid systems provide a mapping of static names to the full path of that endpoint.""" + if name not in AuthenticationService.ENDPOINT_CACHE: + request_url = f"{cls.server_url()}/.well-known/openid-configuration" + response = requests.get(request_url) + AuthenticationService.ENDPOINT_CACHE = response.json() + if name not in AuthenticationService.ENDPOINT_CACHE: + raise Exception(f"Unknown OpenID Endpoint: {name}") + return AuthenticationService.ENDPOINT_CACHE[name] @staticmethod def get_backend_url() -> str: @@ -49,14 +62,9 @@ class AuthenticationService: if redirect_url is None: redirect_url = "/" return_redirect_url = f"{self.get_backend_url()}/v1.0/logout_return" - ( - open_id_server_url, - open_id_client_id, - open_id_client_secret_key, - ) = AuthenticationService.get_open_id_args() request_url = ( - f"{open_id_server_url}/protocol/openid-connect/logout?" - + f"post_logout_redirect_uri={return_redirect_url}&" + self.open_id_endpoint_for_name("end_session_endpoint") + + f"?post_logout_redirect_uri={return_redirect_url}&" + f"id_token_hint={id_token}" ) @@ -72,17 +80,12 @@ class AuthenticationService: self, state: str, redirect_url: str = "/v1.0/login_return" ) -> str: """Get_login_redirect_url.""" - ( - open_id_server_url, - open_id_client_id, - open_id_client_secret_key, - ) = AuthenticationService.get_open_id_args() return_redirect_url = f"{self.get_backend_url()}{redirect_url}" login_redirect_url = ( - f"{open_id_server_url}/protocol/openid-connect/auth?" - + f"state={state}&" + self.open_id_endpoint_for_name("authorization_endpoint") + + f"?state={state}&" + "response_type=code&" - + f"client_id={open_id_client_id}&" + + f"client_id={self.client_id()}&" + "scope=openid&" + f"redirect_uri={return_redirect_url}" ) @@ -92,13 +95,7 @@ class AuthenticationService: self, code: str, redirect_url: str = "/v1.0/login_return" ) -> dict: """Get_auth_token_object.""" - ( - open_id_server_url, - open_id_client_id, - open_id_client_secret_key, - ) = AuthenticationService.get_open_id_args() - - backend_basic_auth_string = f"{open_id_client_id}:{open_id_client_secret_key}" + backend_basic_auth_string = f"{self.client_id()}:{self.secret_key()}" backend_basic_auth_bytes = bytes(backend_basic_auth_string, encoding="ascii") backend_basic_auth = base64.b64encode(backend_basic_auth_bytes) headers = { @@ -111,7 +108,7 @@ class AuthenticationService: "redirect_uri": f"{self.get_backend_url()}{redirect_url}", } - request_url = f"{open_id_server_url}/protocol/openid-connect/token" + request_url = self.open_id_endpoint_for_name("token_endpoint") response = requests.post(request_url, data=data, headers=headers) auth_token_object: dict = json.loads(response.text) @@ -122,11 +119,6 @@ class AuthenticationService: """Https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.""" valid = True now = time.time() - ( - open_id_server_url, - open_id_client_id, - open_id_client_secret_key, - ) = cls.get_open_id_args() try: decoded_token = jwt.decode(id_token, options={"verify_signature": False}) except Exception as e: @@ -135,15 +127,15 @@ class AuthenticationService: message="Cannot decode id_token", status_code=401, ) from e - if decoded_token["iss"] != open_id_server_url: + if decoded_token["iss"] != cls.server_url(): valid = False elif ( - open_id_client_id not in decoded_token["aud"] + cls.client_id() not in decoded_token["aud"] and "account" not in decoded_token["aud"] ): valid = False elif "azp" in decoded_token and decoded_token["azp"] not in ( - open_id_client_id, + cls.client_id(), "account", ): valid = False @@ -196,14 +188,8 @@ class AuthenticationService: @classmethod def get_auth_token_from_refresh_token(cls, refresh_token: str) -> dict: - """Get a new auth_token from a refresh_token.""" - ( - open_id_server_url, - open_id_client_id, - open_id_client_secret_key, - ) = cls.get_open_id_args() - backend_basic_auth_string = f"{open_id_client_id}:{open_id_client_secret_key}" + backend_basic_auth_string = f"{cls.client_id()}:{cls.secret_key()}" backend_basic_auth_bytes = bytes(backend_basic_auth_string, encoding="ascii") backend_basic_auth = base64.b64encode(backend_basic_auth_bytes) headers = { @@ -214,11 +200,11 @@ class AuthenticationService: data = { "grant_type": "refresh_token", "refresh_token": refresh_token, - "client_id": open_id_client_id, - "client_secret": open_id_client_secret_key, + "client_id": cls.client_id(), + "client_secret": cls.secret_key(), } - request_url = f"{open_id_server_url}/protocol/openid-connect/token" + request_url = cls.open_id_endpoint_for_name("token_endpoint") response = requests.post(request_url, data=data, headers=headers) auth_token_object: dict = json.loads(response.text)