fixing some typing issues.

This commit is contained in:
Dan 2022-12-01 14:40:59 -05:00
parent 186727e371
commit 64e30358aa
3 changed files with 22 additions and 20 deletions

View File

@ -5,6 +5,7 @@ This is just here to make local development, testing, and demonstration easier.
""" """
import base64 import base64
import time import time
from typing import Any
from urllib.parse import urlencode from urllib.parse import urlencode
import jwt import jwt
@ -15,6 +16,7 @@ from flask import redirect
from flask import render_template from flask import render_template
from flask import request from flask import request
from flask import url_for from flask import url_for
from werkzeug.wrappers import Response
openid_blueprint = Blueprint( openid_blueprint = Blueprint(
"openid", __name__, template_folder="templates", static_folder="static" "openid", __name__, template_folder="templates", static_folder="static"
@ -24,7 +26,7 @@ MY_SECRET_CODE = ":this_is_not_secure_do_not_use_in_production"
@openid_blueprint.route("/.well-known/openid-configuration", methods=["GET"]) @openid_blueprint.route("/.well-known/openid-configuration", methods=["GET"])
def well_known(): def well_known() -> dict:
"""OpenID Discovery endpoint -- as these urls can be very different from system to system, """OpenID Discovery endpoint -- as these urls can be very different from system to system,
this is just a small subset.""" this is just a small subset."""
host_url = request.host_url.strip("/") host_url = request.host_url.strip("/")
@ -37,7 +39,7 @@ def well_known():
@openid_blueprint.route("/auth", methods=["GET"]) @openid_blueprint.route("/auth", methods=["GET"])
def auth(): def auth() -> str:
"""Accepts a series of parameters""" """Accepts a series of parameters"""
return render_template( return render_template(
"login.html", "login.html",
@ -51,7 +53,7 @@ def auth():
@openid_blueprint.route("/form_submit", methods=["POST"]) @openid_blueprint.route("/form_submit", methods=["POST"])
def form_submit(): def form_submit() -> Response | str:
users = get_users() users = get_users()
if ( if (
request.values["Uname"] in users request.values["Uname"] in users
@ -79,7 +81,7 @@ def form_submit():
@openid_blueprint.route("/token", methods=["POST"]) @openid_blueprint.route("/token", methods=["POST"])
def token(): def token() -> dict:
"""Url that will return a valid token, given the super secret sauce""" """Url that will return a valid token, given the super secret sauce"""
request.values.get("grant_type") request.values.get("grant_type")
code = request.values.get("code") code = request.values.get("code")
@ -90,7 +92,7 @@ def token():
user_details = get_users()[user_name] user_details = get_users()[user_name]
"""Get authentication from headers.""" """Get authentication from headers."""
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization", "Basic ")
authorization = authorization[6:] # Remove "Basic" authorization = authorization[6:] # Remove "Basic"
authorization = base64.b64decode(authorization).decode("utf-8") authorization = base64.b64decode(authorization).decode("utf-8")
client_id, client_secret = authorization.split(":") client_id, client_secret = authorization.split(":")
@ -120,21 +122,21 @@ def token():
@openid_blueprint.route("/end_session", methods=["GET"]) @openid_blueprint.route("/end_session", methods=["GET"])
def end_session(): def end_session() -> Response:
redirect_url = request.args.get("post_logout_redirect_uri") redirect_url = request.args.get("post_logout_redirect_uri", "http://localhost")
request.args.get("id_token_hint") request.args.get("id_token_hint")
return redirect(redirect_url) return redirect(redirect_url)
@openid_blueprint.route("/refresh", methods=["POST"]) @openid_blueprint.route("/refresh", methods=["POST"])
def refresh(): def refresh() -> str:
pass return ""
permission_cache = None permission_cache = None
def get_users(): def get_users() -> Any:
global permission_cache global permission_cache
if not permission_cache: if not permission_cache:
with open(current_app.config["PERMISSIONS_FILE_FULLPATH"]) as file: with open(current_app.config["PERMISSIONS_FILE_FULLPATH"]) as file:

View File

@ -198,7 +198,7 @@ def login(redirect_url: str = "/") -> Response:
return redirect(login_redirect_url) return redirect(login_redirect_url)
def parse_id_token(token: str) -> dict: def parse_id_token(token: str) -> Any:
parts = token.split(".") parts = token.split(".")
if len(parts) != 3: if len(parts) != 3:
raise Exception("Incorrect id token format") raise Exception("Incorrect id token format")

View File

@ -26,24 +26,24 @@ class AuthenticationProviderTypes(enum.Enum):
class AuthenticationService: class AuthenticationService:
"""AuthenticationService.""" """AuthenticationService."""
ENDPOINT_CACHE = ( ENDPOINT_CACHE: dict = (
{} {}
) # We only need to find the openid endpoints once, then we can cache them. ) # We only need to find the openid endpoints once, then we can cache them.
@staticmethod @staticmethod
def client_id(): def client_id() -> str:
return current_app.config["OPEN_ID_CLIENT_ID"] return current_app.config.get("OPEN_ID_CLIENT_ID", "")
@staticmethod @staticmethod
def server_url(): def server_url() -> str:
return current_app.config["OPEN_ID_SERVER_URL"] return current_app.config.get("OPEN_ID_SERVER_URL","")
@staticmethod @staticmethod
def secret_key(): def secret_key() -> str:
return current_app.config["OPEN_ID_CLIENT_SECRET_KEY"] return current_app.config.get("OPEN_ID_CLIENT_SECRET_KEY", "")
@classmethod @classmethod
def open_id_endpoint_for_name(cls, name: str) -> None: def open_id_endpoint_for_name(cls, name: str) -> str:
"""All openid systems provide a mapping of static names to the full path of that endpoint.""" """All openid systems provide a mapping of static names to the full path of that endpoint."""
if name not in AuthenticationService.ENDPOINT_CACHE: if name not in AuthenticationService.ENDPOINT_CACHE:
request_url = f"{cls.server_url()}/.well-known/openid-configuration" request_url = f"{cls.server_url()}/.well-known/openid-configuration"
@ -51,7 +51,7 @@ class AuthenticationService:
AuthenticationService.ENDPOINT_CACHE = response.json() AuthenticationService.ENDPOINT_CACHE = response.json()
if name not in AuthenticationService.ENDPOINT_CACHE: if name not in AuthenticationService.ENDPOINT_CACHE:
raise Exception(f"Unknown OpenID Endpoint: {name}") raise Exception(f"Unknown OpenID Endpoint: {name}")
return AuthenticationService.ENDPOINT_CACHE[name] return AuthenticationService.ENDPOINT_CACHE.get(name, "")
@staticmethod @staticmethod
def get_backend_url() -> str: def get_backend_url() -> str: