248 lines
8.6 KiB
Python

"""Authentication_service."""
import base64
import enum
import json
import time
from typing import Optional
import jwt
import requests
from flask import current_app
from flask import redirect
from werkzeug.wrappers import Response
from spiffworkflow_backend.models.db import db
from spiffworkflow_backend.models.refresh_token import RefreshTokenModel
class MissingAccessTokenError(Exception):
    """Error type for the case where an access token is required but absent."""
class NotAuthorizedError(Exception):
    """Error type for a caller that is not authorized for the attempted action."""
class RefreshTokenStorageError(Exception):
    """Raised when persisting a user's refresh token to the database fails."""
class UserNotLoggedInError(Exception):
    """Error type for an operation that needs a logged-in user when there is none."""
# These could be either 'id' OR 'access' tokens and we can't always know which
class TokenExpiredError(Exception):
    """Raised when an otherwise valid id/access token is past its ``exp`` time."""
class TokenInvalidError(Exception):
    """Raised when an id/access token cannot be decoded as a JWT."""
class TokenNotProvidedError(Exception):
    """Error type for a request that did not supply any token."""
class AuthenticationProviderTypes(enum.Enum):
    """Closed set of authentication provider kinds; each value equals its member name."""

    # Authentication delegated to an OpenID provider.
    open_id = "open_id"
    # Authentication handled by the application itself.
    internal = "internal"
class AuthenticationService:
    """OpenID Connect helper: endpoint discovery, login/logout URLs, token exchange and validation.

    All configuration (client id, client secret, OpenID server URL, backend URL)
    is read from ``current_app.config`` at call time.
    """

    # Discovery document returned by the OpenID server (endpoint name -> URL).
    ENDPOINT_CACHE: dict = {}  # We only need to find the openid endpoints once, then we can cache them.

    # Upper bound, in seconds, for every HTTP call made to the OpenID server.
    # Without a timeout ``requests`` can block forever on an unresponsive server.
    HTTP_REQUEST_TIMEOUT_SECONDS = 15

    # Tolerated gap, in seconds, between the token issuer's clock and ours when
    # checking that a token was not issued in the future.
    IAT_CLOCK_SKEW_LEEWAY_SECONDS = 5

    @staticmethod
    def client_id() -> str:
        """Returns the client id from the config."""
        return current_app.config.get("SPIFFWORKFLOW_BACKEND_OPEN_ID_CLIENT_ID", "")

    @staticmethod
    def server_url() -> str:
        """Returns the server url from the config."""
        return current_app.config.get("SPIFFWORKFLOW_BACKEND_OPEN_ID_SERVER_URL", "")

    @staticmethod
    def secret_key() -> str:
        """Returns the secret key from the config."""
        return current_app.config.get("SPIFFWORKFLOW_BACKEND_OPEN_ID_CLIENT_SECRET_KEY", "")

    @classmethod
    def open_id_endpoint_for_name(cls, name: str) -> str:
        """Return the URL the OpenID server advertises for the endpoint ``name``.

        All openid systems provide a mapping of static names to the full path of
        that endpoint. The mapping is fetched from
        ``<server_url>/.well-known/openid-configuration`` the first time a name is
        not found in ``ENDPOINT_CACHE`` and replaces the cache wholesale.

        Raises:
            Exception: if ``name`` is still unknown after fetching the mapping.
        """
        openid_config_url = f"{cls.server_url()}/.well-known/openid-configuration"
        if name not in AuthenticationService.ENDPOINT_CACHE:
            response = requests.get(openid_config_url, timeout=cls.HTTP_REQUEST_TIMEOUT_SECONDS)
            AuthenticationService.ENDPOINT_CACHE = response.json()
        if name not in AuthenticationService.ENDPOINT_CACHE:
            raise Exception(f"Unknown OpenID Endpoint: {name}. Tried to get from {openid_config_url}")
        return AuthenticationService.ENDPOINT_CACHE.get(name, "")

    @staticmethod
    def get_backend_url() -> str:
        """Return the configured ``SPIFFWORKFLOW_BACKEND_URL`` as a string."""
        return str(current_app.config["SPIFFWORKFLOW_BACKEND_URL"])

    def logout(self, id_token: str, redirect_url: Optional[str] = None) -> Response:
        """Redirect the browser to the OpenID server's ``end_session_endpoint``.

        Args:
            id_token: sent to the server as ``id_token_hint``.
            redirect_url: where the server should send the user afterwards;
                defaults to ``<backend_url>/v1.0/logout_return``.
        """
        if redirect_url is None:
            redirect_url = f"{self.get_backend_url()}/v1.0/logout_return"
        # NOTE(review): query values are interpolated without URL-encoding.
        request_url = (
            self.open_id_endpoint_for_name("end_session_endpoint")
            + f"?post_logout_redirect_uri={redirect_url}&"
            + f"id_token_hint={id_token}"
        )
        return redirect(request_url)

    @staticmethod
    def generate_state(redirect_url: str) -> bytes:
        """Return the base64 encoding of ``str({"redirect_url": redirect_url})``.

        The payload is the Python ``str()`` rendering of the dict, not JSON, so a
        consumer has to parse it accordingly.
        """
        state = base64.b64encode(bytes(str({"redirect_url": redirect_url}), "UTF-8"))
        return state

    def get_login_redirect_url(self, state: str, redirect_url: str = "/v1.0/login_return") -> str:
        """Build the ``authorization_endpoint`` URL that starts the login flow.

        Args:
            state: opaque value echoed back by the server.
                NOTE(review): ``generate_state`` returns bytes; presumably callers
                decode it before passing it here - confirm.
            redirect_url: backend path the server should return to; it is
                appended to the configured backend URL.
        """
        return_redirect_url = f"{self.get_backend_url()}{redirect_url}"
        # NOTE(review): query values (including the spaces in ``scope``) are
        # interpolated without URL-encoding.
        login_redirect_url = (
            self.open_id_endpoint_for_name("authorization_endpoint")
            + f"?state={state}&"
            + "response_type=code&"
            + f"client_id={self.client_id()}&"
            + "scope=openid profile email&"
            + f"redirect_uri={return_redirect_url}"
        )
        return login_redirect_url

    @staticmethod
    def _basic_auth_headers(client_id: str, secret_key: str) -> dict:
        """Return the form-encoded content type plus HTTP Basic ``Authorization`` headers."""
        credentials = bytes(f"{client_id}:{secret_key}", encoding="ascii")
        return {
            "Content-Type": "application/x-www-form-urlencoded",
            "Authorization": f"Basic {base64.b64encode(credentials).decode('utf-8')}",
        }

    def get_auth_token_object(self, code: str, redirect_url: str = "/v1.0/login_return") -> dict:
        """Exchange an authorization ``code`` at the ``token_endpoint``.

        Args:
            code: authorization code received on the login return.
            redirect_url: backend path sent as ``redirect_uri`` (prefixed with
                the configured backend URL).

        Returns:
            The JSON body of the server's response, parsed into a dict.
        """
        headers = self._basic_auth_headers(self.client_id(), self.secret_key())
        data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": f"{self.get_backend_url()}{redirect_url}",
        }
        request_url = self.open_id_endpoint_for_name("token_endpoint")
        response = requests.post(
            request_url, data=data, headers=headers, timeout=self.HTTP_REQUEST_TIMEOUT_SECONDS
        )
        auth_token_object: dict = json.loads(response.text)
        return auth_token_object

    @staticmethod
    def _token_claims_are_valid(
        decoded_token: dict,
        server_url: str,
        client_id: str,
        now: int,
        iat_clock_skew_leeway: int,
    ) -> bool:
        """Check the ``iss``, ``aud``, ``azp``, ``iat`` and ``exp`` claims of a decoded token.

        A missing or wrongly typed required claim makes the token invalid
        instead of raising. Expiry itself is NOT evaluated here; this only
        guarantees that ``exp`` is a number the caller can compare against.
        """
        accepted_parties = (client_id, "account")

        if decoded_token.get("iss") != server_url:
            return False

        # aud could be an array or a string
        aud = decoded_token.get("aud")
        if isinstance(aud, str):
            audiences = [aud]
        elif isinstance(aud, (list, tuple)):
            audiences = list(aud)
        else:
            audiences = []
        if not any(audience in accepted_parties for audience in audiences):
            return False

        azp = decoded_token.get("azp")
        if azp and azp not in accepted_parties:
            return False

        # make sure issued at time is not in the future
        iat = decoded_token.get("iat")
        if not isinstance(iat, (int, float)) or now + iat_clock_skew_leeway < iat:
            return False

        return isinstance(decoded_token.get("exp"), (int, float))

    @classmethod
    def validate_id_or_access_token(cls, token: str) -> bool:
        """Validate the claims of an id or access token.

        See https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.

        Returns:
            True when issuer, audience, authorized party and issued-at time are
            acceptable; False otherwise (the details are logged as an error).

        Raises:
            TokenInvalidError: the token cannot be decoded.
            TokenExpiredError: the claims are valid but ``exp`` is in the past.
        """
        now = round(time.time())
        try:
            # NOTE(review): the signature is deliberately not verified here, so
            # this method only checks claims, not authenticity.
            decoded_token = jwt.decode(token, options={"verify_signature": False})
        except Exception as e:
            raise TokenInvalidError("Cannot decode token") from e

        # give a 5 second leeway to iat in case keycloak server time doesn't match backend server
        valid = cls._token_claims_are_valid(
            decoded_token,
            cls.server_url(),
            cls.client_id(),
            now,
            cls.IAT_CLOCK_SKEW_LEEWAY_SECONDS,
        )

        if valid and now > decoded_token["exp"]:
            raise TokenExpiredError("Your token is expired. Please Login")
        if not valid:
            # EXP is included because a missing/non-numeric exp is now reported
            # as invalid rather than crashing with KeyError.
            current_app.logger.error(
                "TOKEN INVALID: details: "
                f"ISS: {decoded_token.get('iss')} "
                f"AUD: {decoded_token.get('aud')} "
                f"AZP: {decoded_token.get('azp')} "
                f"IAT: {decoded_token.get('iat')} "
                f"EXP: {decoded_token.get('exp')} "
                f"SERVER_URL: {cls.server_url()} "
                f"CLIENT_ID: {cls.client_id()} "
                f"NOW: {now}"
            )
        return valid

    @staticmethod
    def store_refresh_token(user_id: int, refresh_token: str) -> None:
        """Insert or update the refresh token row for ``user_id`` and commit.

        Raises:
            RefreshTokenStorageError: the commit failed; the session is rolled
                back before raising.
        """
        refresh_token_model = RefreshTokenModel.query.filter(RefreshTokenModel.user_id == user_id).first()
        if refresh_token_model:
            refresh_token_model.token = refresh_token
        else:
            refresh_token_model = RefreshTokenModel(user_id=user_id, token=refresh_token)
        db.session.add(refresh_token_model)
        try:
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            raise RefreshTokenStorageError(
                f"We could not store the refresh token. Original error is {e}",
            ) from e

    @staticmethod
    def get_refresh_token(user_id: int) -> Optional[str]:
        """Return the stored refresh token for ``user_id``, or None when there is no row."""
        refresh_token_object: RefreshTokenModel = RefreshTokenModel.query.filter(
            RefreshTokenModel.user_id == user_id
        ).first()
        if refresh_token_object:
            return refresh_token_object.token
        return None

    @classmethod
    def get_auth_token_from_refresh_token(cls, refresh_token: str) -> dict:
        """Converts a refresh token to an Auth Token by calling the openid's auth endpoint.

        The client credentials are sent both as HTTP Basic auth and as form
        fields.

        Returns:
            The JSON body of the server's response, parsed into a dict.
        """
        headers = cls._basic_auth_headers(cls.client_id(), cls.secret_key())
        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": cls.client_id(),
            "client_secret": cls.secret_key(),
        }
        request_url = cls.open_id_endpoint_for_name("token_endpoint")
        response = requests.post(
            request_url, data=data, headers=headers, timeout=cls.HTTP_REQUEST_TIMEOUT_SECONDS
        )
        auth_token_object: dict = json.loads(response.text)
        return auth_token_object