mirror of
https://github.com/sartography/spiff-arena.git
synced 2025-03-03 02:20:50 +00:00
248 lines
8.6 KiB
Python
248 lines
8.6 KiB
Python
"""Authentication_service."""
|
|
import base64
import enum
import json
import time
from typing import Optional
from urllib.parse import urlencode

import jwt
import requests
from flask import current_app
from flask import redirect
from werkzeug.wrappers import Response

from spiffworkflow_backend.models.db import db
from spiffworkflow_backend.models.refresh_token import RefreshTokenModel
|
|
|
|
|
|
class MissingAccessTokenError(Exception):
    """Error type signalling that an expected access token is missing."""
|
|
|
|
|
|
class NotAuthorizedError(Exception):
    """Error type signalling that the caller is not authorized."""
|
|
|
|
|
|
class RefreshTokenStorageError(Exception):
    """Raised when a refresh token could not be persisted to the database."""
|
|
|
|
|
|
class UserNotLoggedInError(Exception):
    """Error type signalling that no user is logged in."""
|
|
|
|
|
|
# The token errors below may concern either an 'id' token or an 'access' token; we can't always tell which kind we were given.
|
|
|
|
|
|
class TokenExpiredError(Exception):
    """Raised by token validation when the token's ``exp`` time has passed."""
|
|
|
|
|
|
class TokenInvalidError(Exception):
    """Raised by token validation when a token cannot be decoded or is malformed."""
|
|
|
|
|
|
class TokenNotProvidedError(Exception):
    """Error type signalling that no token was provided."""
|
|
|
|
|
|
class AuthenticationProviderTypes(enum.Enum):
    """Enumerates the supported authentication provider types.

    Every member's value is identical to its member name.
    """

    # OpenID Connect based provider.
    open_id = "open_id"
    # Internal provider.
    internal = "internal"
|
|
|
|
|
|
class AuthenticationService:
    """Authenticate users against an OpenID Connect server.

    Builds login/logout redirect urls, exchanges authorization codes and refresh
    tokens at the provider's token endpoint, checks token claims, and stores a
    refresh token per user. All settings are read from ``current_app.config``.
    """

    ENDPOINT_CACHE: dict = {}  # We only need to find the openid endpoints once, then we can cache them.
    # NOTE(review): the cache is shared by every caller and is not keyed by server url, so a
    # server url changed at runtime keeps being served endpoints that are already cached.

    # Seconds to wait for the OpenID server; without a timeout ``requests`` can block forever.
    REQUEST_TIMEOUT_SECONDS = 30

    @staticmethod
    def client_id() -> str:
        """Return the OpenID client id from the config ("" when unset)."""
        return current_app.config.get("SPIFFWORKFLOW_BACKEND_OPEN_ID_CLIENT_ID", "")

    @staticmethod
    def server_url() -> str:
        """Return the OpenID server url from the config ("" when unset)."""
        return current_app.config.get("SPIFFWORKFLOW_BACKEND_OPEN_ID_SERVER_URL", "")

    @staticmethod
    def secret_key() -> str:
        """Return the OpenID client secret from the config ("" when unset)."""
        return current_app.config.get("SPIFFWORKFLOW_BACKEND_OPEN_ID_CLIENT_SECRET_KEY", "")

    @classmethod
    def open_id_endpoint_for_name(cls, name: str) -> str:
        """Return the full url of a named endpoint from the provider's discovery document.

        All openid systems publish ``<server_url>/.well-known/openid-configuration``,
        mapping static names (e.g. "token_endpoint") to full urls. The document is
        fetched only when ``name`` is not cached yet, and it then replaces the cache.

        Raises:
            Exception: if the discovery document has no entry for ``name``.
        """
        openid_config_url = f"{cls.server_url()}/.well-known/openid-configuration"
        if name not in AuthenticationService.ENDPOINT_CACHE:
            response = requests.get(openid_config_url, timeout=cls.REQUEST_TIMEOUT_SECONDS)
            AuthenticationService.ENDPOINT_CACHE = response.json()
        if name not in AuthenticationService.ENDPOINT_CACHE:
            raise Exception(f"Unknown OpenID Endpoint: {name}. Tried to get from {openid_config_url}")
        return AuthenticationService.ENDPOINT_CACHE.get(name, "")

    @staticmethod
    def get_backend_url() -> str:
        """Return the configured url of this backend (KeyError if not configured)."""
        return str(current_app.config["SPIFFWORKFLOW_BACKEND_URL"])

    @staticmethod
    def _url_with_query(base_url: str, params: dict) -> str:
        """Append ``params`` to ``base_url`` as a percent-encoded query string.

        Uses "&" instead of "?" when ``base_url`` already carries a query, since
        some providers publish endpoints such as ``.../authorize?p=policy``.
        """
        separator = "&" if "?" in base_url else "?"
        return f"{base_url}{separator}{urlencode(params)}"

    def logout(self, id_token: str, redirect_url: Optional[str] = None) -> Response:
        """Redirect the browser to the provider's end-session endpoint.

        Args:
            id_token: sent to the provider as ``id_token_hint``.
            redirect_url: sent as ``post_logout_redirect_uri``; defaults to this
                backend's ``/v1.0/logout_return``.

        Returns:
            A redirect response to the end-session url.
        """
        if redirect_url is None:
            redirect_url = f"{self.get_backend_url()}/v1.0/logout_return"

        # Values are percent-encoded so a redirect url containing "?", "&" or "#"
        # cannot break out of its query parameter.
        request_url = self._url_with_query(
            self.open_id_endpoint_for_name("end_session_endpoint"),
            {
                "post_logout_redirect_uri": redirect_url,
                "id_token_hint": id_token,
            },
        )
        return redirect(request_url)

    @staticmethod
    def generate_state(redirect_url: str) -> bytes:
        """Encode ``redirect_url`` into a base64 ``state`` value.

        The payload is the Python ``str()`` of ``{"redirect_url": redirect_url}``
        (not JSON); whatever decodes the state must parse that same form.
        """
        state = base64.b64encode(bytes(str({"redirect_url": redirect_url}), "UTF-8"))
        return state

    def get_login_redirect_url(self, state: str, redirect_url: str = "/v1.0/login_return") -> str:
        """Build the provider authorization url that starts the code flow.

        Args:
            state: opaque value the provider hands back (see ``generate_state``).
            redirect_url: backend path the provider should return the user to.

        Returns:
            The authorization endpoint url with the ``state``, ``response_type``,
            ``client_id``, ``scope`` and ``redirect_uri`` query parameters.
        """
        return_redirect_url = f"{self.get_backend_url()}{redirect_url}"
        # Percent-encoding matters here: the base64 ``state`` may contain "+", which an
        # unencoded query string would hand back to us as a space.
        return self._url_with_query(
            self.open_id_endpoint_for_name("authorization_endpoint"),
            {
                "state": state,
                "response_type": "code",
                "client_id": self.client_id(),
                "scope": "openid profile email",
                "redirect_uri": return_redirect_url,
            },
        )

    @classmethod
    def _token_request_headers(cls) -> dict:
        """Return headers for a token-endpoint POST that authenticates with HTTP Basic client credentials."""
        # UTF-8 is identical to ASCII for ASCII credentials and does not crash on others.
        credentials = f"{cls.client_id()}:{cls.secret_key()}".encode("utf-8")
        return {
            "Content-Type": "application/x-www-form-urlencoded",
            "Authorization": f"Basic {base64.b64encode(credentials).decode('utf-8')}",
        }

    @classmethod
    def _post_to_token_endpoint(cls, data: dict) -> dict:
        """POST form ``data`` to the provider's token endpoint and return the parsed JSON reply."""
        headers = cls._token_request_headers()
        request_url = cls.open_id_endpoint_for_name("token_endpoint")
        response = requests.post(request_url, data=data, headers=headers, timeout=cls.REQUEST_TIMEOUT_SECONDS)
        auth_token_object: dict = json.loads(response.text)
        return auth_token_object

    def get_auth_token_object(self, code: str, redirect_url: str = "/v1.0/login_return") -> dict:
        """Exchange an authorization ``code`` for tokens at the token endpoint.

        Args:
            code: the authorization code the provider returned.
            redirect_url: backend path sent as ``redirect_uri``; it should match the
                ``redirect_url`` used in ``get_login_redirect_url`` (the defaults agree).

        Returns:
            The token endpoint's JSON response.
        """
        data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": f"{self.get_backend_url()}{redirect_url}",
        }
        return self._post_to_token_endpoint(data)

    @classmethod
    def validate_id_or_access_token(cls, token: str) -> bool:
        """Check the claims of an id or access token.

        See https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.

        The token is valid when all of the following hold:
        - ``iss`` equals the configured server url;
        - ``aud`` (a string or a list) contains the client id or "account";
        - ``azp``, if present and non-empty, is one of those values;
        - ``iat`` is not more than 5 seconds in the future.

        Returns:
            True when valid; otherwise logs the claim details and returns False.

        Raises:
            TokenInvalidError: if the token cannot be decoded or lacks any of the
                ``iss``, ``aud``, ``iat`` or ``exp`` claims.
            TokenExpiredError: if the claims are valid but ``exp`` has passed.
        """
        valid = True
        now = round(time.time())
        try:
            # NOTE(review): verify_signature=False means the signature is not checked here and
            # only the claims below are validated. Confirm it is verified elsewhere; otherwise a
            # crafted token with matching claims would be accepted.
            decoded_token = jwt.decode(token, options={"verify_signature": False})
        except Exception as e:
            raise TokenInvalidError("Cannot decode token") from e

        # Report missing claims as an invalid token instead of letting a bare KeyError escape.
        missing_claims = [claim for claim in ("iss", "aud", "iat", "exp") if claim not in decoded_token]
        if missing_claims:
            raise TokenInvalidError(f"Token is missing required claim(s): {', '.join(missing_claims)}")

        # give a 5 second leeway to iat in case keycloak server time doesn't match backend server
        iat_clock_skew_leeway = 5

        iss = decoded_token["iss"]
        aud = decoded_token["aud"]
        azp = decoded_token.get("azp")
        iat = decoded_token["iat"]

        valid_audience_values = (cls.client_id(), "account")
        # aud could be an array or a string; any other type carries no usable audience.
        if isinstance(aud, str):
            audience_values_in_token = [aud]
        elif isinstance(aud, list):
            audience_values_in_token = aud
        else:
            audience_values_in_token = []

        if iss != cls.server_url():
            valid = False
        elif not any(value in valid_audience_values for value in audience_values_in_token):
            valid = False
        elif azp and azp not in valid_audience_values:
            valid = False
        # make sure issued at time is not in the future
        elif now + iat_clock_skew_leeway < iat:
            valid = False

        if valid and now > decoded_token["exp"]:
            raise TokenExpiredError("Your token is expired. Please Login")
        elif not valid:
            current_app.logger.error(
                "TOKEN INVALID: details: "
                f"ISS: {iss} "
                f"AUD: {aud} "
                f"AZP: {azp} "
                f"IAT: {iat} "
                f"SERVER_URL: {cls.server_url()} "
                f"CLIENT_ID: {cls.client_id()} "
                f"NOW: {now}"
            )

        return valid

    @staticmethod
    def store_refresh_token(user_id: int, refresh_token: str) -> None:
        """Insert or update the single stored refresh token for ``user_id``.

        Raises:
            RefreshTokenStorageError: if the commit fails (the session is rolled back first).
        """
        refresh_token_model = RefreshTokenModel.query.filter(RefreshTokenModel.user_id == user_id).first()
        if refresh_token_model:
            refresh_token_model.token = refresh_token
        else:
            refresh_token_model = RefreshTokenModel(user_id=user_id, token=refresh_token)
        db.session.add(refresh_token_model)
        try:
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            raise RefreshTokenStorageError(
                f"We could not store the refresh token. Original error is {e}",
            ) from e

    @staticmethod
    def get_refresh_token(user_id: int) -> Optional[str]:
        """Return the stored refresh token for ``user_id``, or None if there is none."""
        refresh_token_object: Optional[RefreshTokenModel] = RefreshTokenModel.query.filter(
            RefreshTokenModel.user_id == user_id
        ).first()
        if refresh_token_object:
            return refresh_token_object.token
        return None

    @classmethod
    def get_auth_token_from_refresh_token(cls, refresh_token: str) -> dict:
        """Convert a refresh token into a fresh auth token via the provider's token endpoint.

        Returns:
            The token endpoint's JSON response.
        """
        # NOTE(review): client credentials go both in the Basic header and in the form body;
        # RFC 6749 section 2.3.1 says a client MUST NOT use more than one authentication method
        # per request. Kept unchanged so the provider receives the same request — confirm
        # before dropping one of them.
        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": cls.client_id(),
            "client_secret": cls.secret_key(),
        }
        return cls._post_to_token_endpoint(data)
|