fixing some typing issues.

This commit is contained in:
Dan 2022-12-01 14:40:59 -05:00
parent 186727e371
commit 64e30358aa
3 changed files with 22 additions and 20 deletions

View File

@ -5,6 +5,7 @@ This is just here to make local development, testing, and demonstration easier.
"""
import base64
import time
from typing import Any
from urllib.parse import urlencode
import jwt
@ -15,6 +16,7 @@ from flask import redirect
from flask import render_template
from flask import request
from flask import url_for
from werkzeug.wrappers import Response
openid_blueprint = Blueprint(
"openid", __name__, template_folder="templates", static_folder="static"
@ -24,7 +26,7 @@ MY_SECRET_CODE = ":this_is_not_secure_do_not_use_in_production"
@openid_blueprint.route("/.well-known/openid-configuration", methods=["GET"])
def well_known():
def well_known() -> dict:
"""OpenID Discovery endpoint -- as these urls can be very different from system to system,
this is just a small subset."""
host_url = request.host_url.strip("/")
@ -37,7 +39,7 @@ def well_known():
@openid_blueprint.route("/auth", methods=["GET"])
def auth():
def auth() -> str:
"""Accepts a series of parameters"""
return render_template(
"login.html",
@ -51,7 +53,7 @@ def auth():
@openid_blueprint.route("/form_submit", methods=["POST"])
def form_submit():
def form_submit() -> Response | str:
users = get_users()
if (
request.values["Uname"] in users
@ -79,7 +81,7 @@ def form_submit():
@openid_blueprint.route("/token", methods=["POST"])
def token():
def token() -> dict:
"""Url that will return a valid token, given the super secret sauce"""
request.values.get("grant_type")
code = request.values.get("code")
@ -90,7 +92,7 @@ def token():
user_details = get_users()[user_name]
"""Get authentication from headers."""
authorization = request.headers.get("Authorization")
authorization = request.headers.get("Authorization", "Basic ")
authorization = authorization[6:] # Remove "Basic"
authorization = base64.b64decode(authorization).decode("utf-8")
client_id, client_secret = authorization.split(":")
@ -120,21 +122,21 @@ def token():
@openid_blueprint.route("/end_session", methods=["GET"])
def end_session():
redirect_url = request.args.get("post_logout_redirect_uri")
def end_session() -> Response:
redirect_url = request.args.get("post_logout_redirect_uri", "http://localhost")
request.args.get("id_token_hint")
return redirect(redirect_url)
@openid_blueprint.route("/refresh", methods=["POST"])
def refresh():
pass
def refresh() -> str:
return ""
permission_cache = None
def get_users():
def get_users() -> Any:
global permission_cache
if not permission_cache:
with open(current_app.config["PERMISSIONS_FILE_FULLPATH"]) as file:

View File

@ -198,7 +198,7 @@ def login(redirect_url: str = "/") -> Response:
return redirect(login_redirect_url)
def parse_id_token(token: str) -> dict:
def parse_id_token(token: str) -> Any:
parts = token.split(".")
if len(parts) != 3:
raise Exception("Incorrect id token format")

View File

@ -26,24 +26,24 @@ class AuthenticationProviderTypes(enum.Enum):
class AuthenticationService:
"""AuthenticationService."""
ENDPOINT_CACHE = (
ENDPOINT_CACHE: dict = (
{}
) # We only need to find the openid endpoints once, then we can cache them.
@staticmethod
def client_id():
return current_app.config["OPEN_ID_CLIENT_ID"]
def client_id() -> str:
return current_app.config.get("OPEN_ID_CLIENT_ID", "")
@staticmethod
def server_url():
return current_app.config["OPEN_ID_SERVER_URL"]
def server_url() -> str:
return current_app.config.get("OPEN_ID_SERVER_URL","")
@staticmethod
def secret_key():
return current_app.config["OPEN_ID_CLIENT_SECRET_KEY"]
def secret_key() -> str:
return current_app.config.get("OPEN_ID_CLIENT_SECRET_KEY", "")
@classmethod
def open_id_endpoint_for_name(cls, name: str) -> None:
def open_id_endpoint_for_name(cls, name: str) -> str:
"""All openid systems provide a mapping of static names to the full path of that endpoint."""
if name not in AuthenticationService.ENDPOINT_CACHE:
request_url = f"{cls.server_url()}/.well-known/openid-configuration"
@ -51,7 +51,7 @@ class AuthenticationService:
AuthenticationService.ENDPOINT_CACHE = response.json()
if name not in AuthenticationService.ENDPOINT_CACHE:
raise Exception(f"Unknown OpenID Endpoint: {name}")
return AuthenticationService.ENDPOINT_CACHE[name]
return AuthenticationService.ENDPOINT_CACHE.get(name, "")
@staticmethod
def get_backend_url() -> str: