Use the "well-known" configuration dictionary from openid to get the url endpoints, rather than trying to configure or guess the correct endpoint urls.

This commit is contained in:
Dan 2022-11-30 16:33:44 -05:00
parent d63c410988
commit 8993748934
2 changed files with 43 additions and 56 deletions

View File

@ -19,7 +19,7 @@ openid_blueprint = Blueprint(
MY_SECRET_CODE = ":this_should_be_some_crazy_code_different_all_the_time"
@openid_blueprint.route("/well-known/openid-configuration", methods=["GET"])
@openid_blueprint.route("/.well-known/openid-configuration", methods=["GET"])
def well_known():
"""OpenID Discovery endpoint -- as these urls can be very different from system to system,
this is just a small subset."""
@ -52,9 +52,10 @@ def form_submit():
state = request.values.get('state')
data = {
"state": base64.b64encode(bytes(state, 'UTF-8')),
"code": request.values['Uname'] + MY_SECRET_CODE
"code": request.values['Uname'] + MY_SECRET_CODE,
"session_state": ""
}
url = request.values.get('redirect_uri') + urlencode(data)
url = request.values.get('redirect_uri') + "?" + urlencode(data)
return redirect(url, code=200)
else:
return render_template('login.html',

View File

@ -3,6 +3,7 @@ import base64
import enum
import json
import time
import typing
from typing import Optional
import jwt
@ -15,6 +16,7 @@ from werkzeug.wrappers import Response
from spiffworkflow_backend.models.refresh_token import RefreshTokenModel
class AuthenticationProviderTypes(enum.Enum):
"""AuthenticationServiceProviders."""
@ -24,20 +26,31 @@ class AuthenticationProviderTypes(enum.Enum):
class AuthenticationService:
"""AuthenticationService."""
ENDPOINT_CACHE = {} # We only need to find the openid endpoints once, then we can cache them.
@staticmethod
def get_open_id_args() -> tuple:
"""Get_open_id_args."""
open_id_server_url = current_app.config["OPEN_ID_SERVER_URL"]
open_id_client_id = current_app.config["OPEN_ID_CLIENT_ID"]
open_id_client_secret_key = current_app.config[
"OPEN_ID_CLIENT_SECRET_KEY"
] # noqa: S105
return (
open_id_server_url,
open_id_client_id,
open_id_client_secret_key,
)
def client_id():
return current_app.config["OPEN_ID_CLIENT_ID"]
@staticmethod
def server_url():
return current_app.config["OPEN_ID_SERVER_URL"]
@staticmethod
def secret_key():
return current_app.config["OPEN_ID_CLIENT_SECRET_KEY"]
@classmethod
def open_id_endpoint_for_name(cls, name: str) -> None:
"""All openid systems provide a mapping of static names to the full path of that endpoint."""
if name not in AuthenticationService.ENDPOINT_CACHE:
request_url = f"{cls.server_url()}/.well-known/openid-configuration"
response = requests.get(request_url)
AuthenticationService.ENDPOINT_CACHE = response.json()
if name not in AuthenticationService.ENDPOINT_CACHE:
raise Exception(f"Unknown OpenID Endpoint: {name}")
return AuthenticationService.ENDPOINT_CACHE[name]
@staticmethod
def get_backend_url() -> str:
@ -49,14 +62,9 @@ class AuthenticationService:
if redirect_url is None:
redirect_url = "/"
return_redirect_url = f"{self.get_backend_url()}/v1.0/logout_return"
(
open_id_server_url,
open_id_client_id,
open_id_client_secret_key,
) = AuthenticationService.get_open_id_args()
request_url = (
f"{open_id_server_url}/protocol/openid-connect/logout?"
+ f"post_logout_redirect_uri={return_redirect_url}&"
self.open_id_endpoint_for_name("end_session_endpoint")
+ f"?post_logout_redirect_uri={return_redirect_url}&"
+ f"id_token_hint={id_token}"
)
@ -72,17 +80,12 @@ class AuthenticationService:
self, state: str, redirect_url: str = "/v1.0/login_return"
) -> str:
"""Get_login_redirect_url."""
(
open_id_server_url,
open_id_client_id,
open_id_client_secret_key,
) = AuthenticationService.get_open_id_args()
return_redirect_url = f"{self.get_backend_url()}{redirect_url}"
login_redirect_url = (
f"{open_id_server_url}/protocol/openid-connect/auth?"
+ f"state={state}&"
self.open_id_endpoint_for_name("authorization_endpoint")
+ f"?state={state}&"
+ "response_type=code&"
+ f"client_id={open_id_client_id}&"
+ f"client_id={self.client_id()}&"
+ "scope=openid&"
+ f"redirect_uri={return_redirect_url}"
)
@ -92,13 +95,7 @@ class AuthenticationService:
self, code: str, redirect_url: str = "/v1.0/login_return"
) -> dict:
"""Get_auth_token_object."""
(
open_id_server_url,
open_id_client_id,
open_id_client_secret_key,
) = AuthenticationService.get_open_id_args()
backend_basic_auth_string = f"{open_id_client_id}:{open_id_client_secret_key}"
backend_basic_auth_string = f"{self.client_id()}:{self.secret_key()}"
backend_basic_auth_bytes = bytes(backend_basic_auth_string, encoding="ascii")
backend_basic_auth = base64.b64encode(backend_basic_auth_bytes)
headers = {
@ -111,7 +108,7 @@ class AuthenticationService:
"redirect_uri": f"{self.get_backend_url()}{redirect_url}",
}
request_url = f"{open_id_server_url}/protocol/openid-connect/token"
request_url = self.open_id_endpoint_for_name("token_endpoint")
response = requests.post(request_url, data=data, headers=headers)
auth_token_object: dict = json.loads(response.text)
@ -122,11 +119,6 @@ class AuthenticationService:
"""Https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation."""
valid = True
now = time.time()
(
open_id_server_url,
open_id_client_id,
open_id_client_secret_key,
) = cls.get_open_id_args()
try:
decoded_token = jwt.decode(id_token, options={"verify_signature": False})
except Exception as e:
@ -135,15 +127,15 @@ class AuthenticationService:
message="Cannot decode id_token",
status_code=401,
) from e
if decoded_token["iss"] != open_id_server_url:
if decoded_token["iss"] != cls.server_url():
valid = False
elif (
open_id_client_id not in decoded_token["aud"]
cls.client_id() not in decoded_token["aud"]
and "account" not in decoded_token["aud"]
):
valid = False
elif "azp" in decoded_token and decoded_token["azp"] not in (
open_id_client_id,
cls.client_id(),
"account",
):
valid = False
@ -196,14 +188,8 @@ class AuthenticationService:
@classmethod
def get_auth_token_from_refresh_token(cls, refresh_token: str) -> dict:
"""Get a new auth_token from a refresh_token."""
(
open_id_server_url,
open_id_client_id,
open_id_client_secret_key,
) = cls.get_open_id_args()
backend_basic_auth_string = f"{open_id_client_id}:{open_id_client_secret_key}"
backend_basic_auth_string = f"{cls.client_id()}:{cls.secret_key()}"
backend_basic_auth_bytes = bytes(backend_basic_auth_string, encoding="ascii")
backend_basic_auth = base64.b64encode(backend_basic_auth_bytes)
headers = {
@ -214,11 +200,11 @@ class AuthenticationService:
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": open_id_client_id,
"client_secret": open_id_client_secret_key,
"client_id": cls.client_id(),
"client_secret": cls.secret_key(),
}
request_url = f"{open_id_server_url}/protocol/openid-connect/token"
request_url = cls.open_id_endpoint_for_name("token_endpoint")
response = requests.post(request_url, data=data, headers=headers)
auth_token_object: dict = json.loads(response.text)