fixing some typing issues.
This commit is contained in:
parent
186727e371
commit
64e30358aa
|
@ -5,6 +5,7 @@ This is just here to make local development, testing, and demonstration easier.
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
@ -15,6 +16,7 @@ from flask import redirect
|
||||||
from flask import render_template
|
from flask import render_template
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask import url_for
|
from flask import url_for
|
||||||
|
from werkzeug.wrappers import Response
|
||||||
|
|
||||||
openid_blueprint = Blueprint(
|
openid_blueprint = Blueprint(
|
||||||
"openid", __name__, template_folder="templates", static_folder="static"
|
"openid", __name__, template_folder="templates", static_folder="static"
|
||||||
|
@ -24,7 +26,7 @@ MY_SECRET_CODE = ":this_is_not_secure_do_not_use_in_production"
|
||||||
|
|
||||||
|
|
||||||
@openid_blueprint.route("/.well-known/openid-configuration", methods=["GET"])
|
@openid_blueprint.route("/.well-known/openid-configuration", methods=["GET"])
|
||||||
def well_known():
|
def well_known() -> dict:
|
||||||
"""OpenID Discovery endpoint -- as these urls can be very different from system to system,
|
"""OpenID Discovery endpoint -- as these urls can be very different from system to system,
|
||||||
this is just a small subset."""
|
this is just a small subset."""
|
||||||
host_url = request.host_url.strip("/")
|
host_url = request.host_url.strip("/")
|
||||||
|
@ -37,7 +39,7 @@ def well_known():
|
||||||
|
|
||||||
|
|
||||||
@openid_blueprint.route("/auth", methods=["GET"])
|
@openid_blueprint.route("/auth", methods=["GET"])
|
||||||
def auth():
|
def auth() -> str:
|
||||||
"""Accepts a series of parameters"""
|
"""Accepts a series of parameters"""
|
||||||
return render_template(
|
return render_template(
|
||||||
"login.html",
|
"login.html",
|
||||||
|
@ -51,7 +53,7 @@ def auth():
|
||||||
|
|
||||||
|
|
||||||
@openid_blueprint.route("/form_submit", methods=["POST"])
|
@openid_blueprint.route("/form_submit", methods=["POST"])
|
||||||
def form_submit():
|
def form_submit() -> Response | str:
|
||||||
users = get_users()
|
users = get_users()
|
||||||
if (
|
if (
|
||||||
request.values["Uname"] in users
|
request.values["Uname"] in users
|
||||||
|
@ -79,7 +81,7 @@ def form_submit():
|
||||||
|
|
||||||
|
|
||||||
@openid_blueprint.route("/token", methods=["POST"])
|
@openid_blueprint.route("/token", methods=["POST"])
|
||||||
def token():
|
def token() -> dict:
|
||||||
"""Url that will return a valid token, given the super secret sauce"""
|
"""Url that will return a valid token, given the super secret sauce"""
|
||||||
request.values.get("grant_type")
|
request.values.get("grant_type")
|
||||||
code = request.values.get("code")
|
code = request.values.get("code")
|
||||||
|
@ -90,7 +92,7 @@ def token():
|
||||||
user_details = get_users()[user_name]
|
user_details = get_users()[user_name]
|
||||||
|
|
||||||
"""Get authentication from headers."""
|
"""Get authentication from headers."""
|
||||||
authorization = request.headers.get("Authorization")
|
authorization = request.headers.get("Authorization", "Basic ")
|
||||||
authorization = authorization[6:] # Remove "Basic"
|
authorization = authorization[6:] # Remove "Basic"
|
||||||
authorization = base64.b64decode(authorization).decode("utf-8")
|
authorization = base64.b64decode(authorization).decode("utf-8")
|
||||||
client_id, client_secret = authorization.split(":")
|
client_id, client_secret = authorization.split(":")
|
||||||
|
@ -120,21 +122,21 @@ def token():
|
||||||
|
|
||||||
|
|
||||||
@openid_blueprint.route("/end_session", methods=["GET"])
|
@openid_blueprint.route("/end_session", methods=["GET"])
|
||||||
def end_session():
|
def end_session() -> Response:
|
||||||
redirect_url = request.args.get("post_logout_redirect_uri")
|
redirect_url = request.args.get("post_logout_redirect_uri", "http://localhost")
|
||||||
request.args.get("id_token_hint")
|
request.args.get("id_token_hint")
|
||||||
return redirect(redirect_url)
|
return redirect(redirect_url)
|
||||||
|
|
||||||
|
|
||||||
@openid_blueprint.route("/refresh", methods=["POST"])
|
@openid_blueprint.route("/refresh", methods=["POST"])
|
||||||
def refresh():
|
def refresh() -> str:
|
||||||
pass
|
return ""
|
||||||
|
|
||||||
|
|
||||||
permission_cache = None
|
permission_cache = None
|
||||||
|
|
||||||
|
|
||||||
def get_users():
|
def get_users() -> Any:
|
||||||
global permission_cache
|
global permission_cache
|
||||||
if not permission_cache:
|
if not permission_cache:
|
||||||
with open(current_app.config["PERMISSIONS_FILE_FULLPATH"]) as file:
|
with open(current_app.config["PERMISSIONS_FILE_FULLPATH"]) as file:
|
||||||
|
|
|
@ -198,7 +198,7 @@ def login(redirect_url: str = "/") -> Response:
|
||||||
return redirect(login_redirect_url)
|
return redirect(login_redirect_url)
|
||||||
|
|
||||||
|
|
||||||
def parse_id_token(token: str) -> dict:
|
def parse_id_token(token: str) -> Any:
|
||||||
parts = token.split(".")
|
parts = token.split(".")
|
||||||
if len(parts) != 3:
|
if len(parts) != 3:
|
||||||
raise Exception("Incorrect id token format")
|
raise Exception("Incorrect id token format")
|
||||||
|
|
|
@ -26,24 +26,24 @@ class AuthenticationProviderTypes(enum.Enum):
|
||||||
class AuthenticationService:
|
class AuthenticationService:
|
||||||
"""AuthenticationService."""
|
"""AuthenticationService."""
|
||||||
|
|
||||||
ENDPOINT_CACHE = (
|
ENDPOINT_CACHE: dict = (
|
||||||
{}
|
{}
|
||||||
) # We only need to find the openid endpoints once, then we can cache them.
|
) # We only need to find the openid endpoints once, then we can cache them.
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def client_id():
|
def client_id() -> str:
|
||||||
return current_app.config["OPEN_ID_CLIENT_ID"]
|
return current_app.config.get("OPEN_ID_CLIENT_ID", "")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def server_url():
|
def server_url() -> str:
|
||||||
return current_app.config["OPEN_ID_SERVER_URL"]
|
return current_app.config.get("OPEN_ID_SERVER_URL","")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def secret_key():
|
def secret_key() -> str:
|
||||||
return current_app.config["OPEN_ID_CLIENT_SECRET_KEY"]
|
return current_app.config.get("OPEN_ID_CLIENT_SECRET_KEY", "")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def open_id_endpoint_for_name(cls, name: str) -> None:
|
def open_id_endpoint_for_name(cls, name: str) -> str:
|
||||||
"""All openid systems provide a mapping of static names to the full path of that endpoint."""
|
"""All openid systems provide a mapping of static names to the full path of that endpoint."""
|
||||||
if name not in AuthenticationService.ENDPOINT_CACHE:
|
if name not in AuthenticationService.ENDPOINT_CACHE:
|
||||||
request_url = f"{cls.server_url()}/.well-known/openid-configuration"
|
request_url = f"{cls.server_url()}/.well-known/openid-configuration"
|
||||||
|
@ -51,7 +51,7 @@ class AuthenticationService:
|
||||||
AuthenticationService.ENDPOINT_CACHE = response.json()
|
AuthenticationService.ENDPOINT_CACHE = response.json()
|
||||||
if name not in AuthenticationService.ENDPOINT_CACHE:
|
if name not in AuthenticationService.ENDPOINT_CACHE:
|
||||||
raise Exception(f"Unknown OpenID Endpoint: {name}")
|
raise Exception(f"Unknown OpenID Endpoint: {name}")
|
||||||
return AuthenticationService.ENDPOINT_CACHE[name]
|
return AuthenticationService.ENDPOINT_CACHE.get(name, "")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_backend_url() -> str:
|
def get_backend_url() -> str:
|
||||||
|
|
Loading…
Reference in New Issue