Use the "well-known" configuration dictionary from openid to get the url endpoints, rather than trying to configure or guess the correct endpoint urls.
This commit is contained in:
parent
d63c410988
commit
8993748934
|
@ -19,7 +19,7 @@ openid_blueprint = Blueprint(
|
|||
|
||||
MY_SECRET_CODE = ":this_should_be_some_crazy_code_different_all_the_time"
|
||||
|
||||
@openid_blueprint.route("/well-known/openid-configuration", methods=["GET"])
|
||||
@openid_blueprint.route("/.well-known/openid-configuration", methods=["GET"])
|
||||
def well_known():
|
||||
"""OpenID Discovery endpoint -- as these urls can be very different from system to system,
|
||||
this is just a small subset."""
|
||||
|
@ -52,9 +52,10 @@ def form_submit():
|
|||
state = request.values.get('state')
|
||||
data = {
|
||||
"state": base64.b64encode(bytes(state, 'UTF-8')),
|
||||
"code": request.values['Uname'] + MY_SECRET_CODE
|
||||
"code": request.values['Uname'] + MY_SECRET_CODE,
|
||||
"session_state": ""
|
||||
}
|
||||
url = request.values.get('redirect_uri') + urlencode(data)
|
||||
url = request.values.get('redirect_uri') + "?" + urlencode(data)
|
||||
return redirect(url, code=200)
|
||||
else:
|
||||
return render_template('login.html',
|
||||
|
|
|
@ -3,6 +3,7 @@ import base64
|
|||
import enum
|
||||
import json
|
||||
import time
|
||||
import typing
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
|
@ -15,6 +16,7 @@ from werkzeug.wrappers import Response
|
|||
|
||||
from spiffworkflow_backend.models.refresh_token import RefreshTokenModel
|
||||
|
||||
|
||||
class AuthenticationProviderTypes(enum.Enum):
|
||||
"""AuthenticationServiceProviders."""
|
||||
|
||||
|
@ -24,20 +26,31 @@ class AuthenticationProviderTypes(enum.Enum):
|
|||
|
||||
class AuthenticationService:
|
||||
"""AuthenticationService."""
|
||||
ENDPOINT_CACHE = {} # We only need to find the openid endpoints once, then we can cache them.
|
||||
|
||||
@staticmethod
|
||||
def get_open_id_args() -> tuple:
|
||||
"""Get_open_id_args."""
|
||||
open_id_server_url = current_app.config["OPEN_ID_SERVER_URL"]
|
||||
open_id_client_id = current_app.config["OPEN_ID_CLIENT_ID"]
|
||||
open_id_client_secret_key = current_app.config[
|
||||
"OPEN_ID_CLIENT_SECRET_KEY"
|
||||
] # noqa: S105
|
||||
return (
|
||||
open_id_server_url,
|
||||
open_id_client_id,
|
||||
open_id_client_secret_key,
|
||||
)
|
||||
def client_id():
|
||||
return current_app.config["OPEN_ID_CLIENT_ID"]
|
||||
|
||||
@staticmethod
|
||||
def server_url():
|
||||
return current_app.config["OPEN_ID_SERVER_URL"]
|
||||
|
||||
@staticmethod
|
||||
def secret_key():
|
||||
return current_app.config["OPEN_ID_CLIENT_SECRET_KEY"]
|
||||
|
||||
|
||||
@classmethod
|
||||
def open_id_endpoint_for_name(cls, name: str) -> None:
|
||||
"""All openid systems provide a mapping of static names to the full path of that endpoint."""
|
||||
if name not in AuthenticationService.ENDPOINT_CACHE:
|
||||
request_url = f"{cls.server_url()}/.well-known/openid-configuration"
|
||||
response = requests.get(request_url)
|
||||
AuthenticationService.ENDPOINT_CACHE = response.json()
|
||||
if name not in AuthenticationService.ENDPOINT_CACHE:
|
||||
raise Exception(f"Unknown OpenID Endpoint: {name}")
|
||||
return AuthenticationService.ENDPOINT_CACHE[name]
|
||||
|
||||
@staticmethod
|
||||
def get_backend_url() -> str:
|
||||
|
@ -49,14 +62,9 @@ class AuthenticationService:
|
|||
if redirect_url is None:
|
||||
redirect_url = "/"
|
||||
return_redirect_url = f"{self.get_backend_url()}/v1.0/logout_return"
|
||||
(
|
||||
open_id_server_url,
|
||||
open_id_client_id,
|
||||
open_id_client_secret_key,
|
||||
) = AuthenticationService.get_open_id_args()
|
||||
request_url = (
|
||||
f"{open_id_server_url}/protocol/openid-connect/logout?"
|
||||
+ f"post_logout_redirect_uri={return_redirect_url}&"
|
||||
self.open_id_endpoint_for_name("end_session_endpoint")
|
||||
+ f"?post_logout_redirect_uri={return_redirect_url}&"
|
||||
+ f"id_token_hint={id_token}"
|
||||
)
|
||||
|
||||
|
@ -72,17 +80,12 @@ class AuthenticationService:
|
|||
self, state: str, redirect_url: str = "/v1.0/login_return"
|
||||
) -> str:
|
||||
"""Get_login_redirect_url."""
|
||||
(
|
||||
open_id_server_url,
|
||||
open_id_client_id,
|
||||
open_id_client_secret_key,
|
||||
) = AuthenticationService.get_open_id_args()
|
||||
return_redirect_url = f"{self.get_backend_url()}{redirect_url}"
|
||||
login_redirect_url = (
|
||||
f"{open_id_server_url}/protocol/openid-connect/auth?"
|
||||
+ f"state={state}&"
|
||||
self.open_id_endpoint_for_name("authorization_endpoint")
|
||||
+ f"?state={state}&"
|
||||
+ "response_type=code&"
|
||||
+ f"client_id={open_id_client_id}&"
|
||||
+ f"client_id={self.client_id()}&"
|
||||
+ "scope=openid&"
|
||||
+ f"redirect_uri={return_redirect_url}"
|
||||
)
|
||||
|
@ -92,13 +95,7 @@ class AuthenticationService:
|
|||
self, code: str, redirect_url: str = "/v1.0/login_return"
|
||||
) -> dict:
|
||||
"""Get_auth_token_object."""
|
||||
(
|
||||
open_id_server_url,
|
||||
open_id_client_id,
|
||||
open_id_client_secret_key,
|
||||
) = AuthenticationService.get_open_id_args()
|
||||
|
||||
backend_basic_auth_string = f"{open_id_client_id}:{open_id_client_secret_key}"
|
||||
backend_basic_auth_string = f"{self.client_id()}:{self.secret_key()}"
|
||||
backend_basic_auth_bytes = bytes(backend_basic_auth_string, encoding="ascii")
|
||||
backend_basic_auth = base64.b64encode(backend_basic_auth_bytes)
|
||||
headers = {
|
||||
|
@ -111,7 +108,7 @@ class AuthenticationService:
|
|||
"redirect_uri": f"{self.get_backend_url()}{redirect_url}",
|
||||
}
|
||||
|
||||
request_url = f"{open_id_server_url}/protocol/openid-connect/token"
|
||||
request_url = self.open_id_endpoint_for_name("token_endpoint")
|
||||
|
||||
response = requests.post(request_url, data=data, headers=headers)
|
||||
auth_token_object: dict = json.loads(response.text)
|
||||
|
@ -122,11 +119,6 @@ class AuthenticationService:
|
|||
"""Https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation."""
|
||||
valid = True
|
||||
now = time.time()
|
||||
(
|
||||
open_id_server_url,
|
||||
open_id_client_id,
|
||||
open_id_client_secret_key,
|
||||
) = cls.get_open_id_args()
|
||||
try:
|
||||
decoded_token = jwt.decode(id_token, options={"verify_signature": False})
|
||||
except Exception as e:
|
||||
|
@ -135,15 +127,15 @@ class AuthenticationService:
|
|||
message="Cannot decode id_token",
|
||||
status_code=401,
|
||||
) from e
|
||||
if decoded_token["iss"] != open_id_server_url:
|
||||
if decoded_token["iss"] != cls.server_url():
|
||||
valid = False
|
||||
elif (
|
||||
open_id_client_id not in decoded_token["aud"]
|
||||
cls.client_id() not in decoded_token["aud"]
|
||||
and "account" not in decoded_token["aud"]
|
||||
):
|
||||
valid = False
|
||||
elif "azp" in decoded_token and decoded_token["azp"] not in (
|
||||
open_id_client_id,
|
||||
cls.client_id(),
|
||||
"account",
|
||||
):
|
||||
valid = False
|
||||
|
@ -196,14 +188,8 @@ class AuthenticationService:
|
|||
|
||||
@classmethod
|
||||
def get_auth_token_from_refresh_token(cls, refresh_token: str) -> dict:
|
||||
"""Get a new auth_token from a refresh_token."""
|
||||
(
|
||||
open_id_server_url,
|
||||
open_id_client_id,
|
||||
open_id_client_secret_key,
|
||||
) = cls.get_open_id_args()
|
||||
|
||||
backend_basic_auth_string = f"{open_id_client_id}:{open_id_client_secret_key}"
|
||||
backend_basic_auth_string = f"{cls.client_id()}:{cls.secret_key()}"
|
||||
backend_basic_auth_bytes = bytes(backend_basic_auth_string, encoding="ascii")
|
||||
backend_basic_auth = base64.b64encode(backend_basic_auth_bytes)
|
||||
headers = {
|
||||
|
@ -214,11 +200,11 @@ class AuthenticationService:
|
|||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": open_id_client_id,
|
||||
"client_secret": open_id_client_secret_key,
|
||||
"client_id": cls.client_id(),
|
||||
"client_secret": cls.secret_key(),
|
||||
}
|
||||
|
||||
request_url = f"{open_id_server_url}/protocol/openid-connect/token"
|
||||
request_url = cls.open_id_endpoint_for_name("token_endpoint")
|
||||
|
||||
response = requests.post(request_url, data=data, headers=headers)
|
||||
auth_token_object: dict = json.loads(response.text)
|
||||
|
|
Loading…
Reference in New Issue