mirror of
https://github.com/logos-blockchain/logos-blockchain-block-explorer-template.git
synced 2026-01-02 05:03:07 +00:00
Add authentication mechanism via env variables.
This commit is contained in:
parent
c3c357d09a
commit
0fd630d6a4
@ -2,11 +2,12 @@ from asyncio import Task, gather
|
||||
from typing import Literal, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import Field
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from starlette.datastructures import State
|
||||
|
||||
from constants import DIR_REPO
|
||||
from core.authentication import Authentication
|
||||
from db.blocks import BlockRepository
|
||||
from db.clients import DbClient
|
||||
from db.transaction import TransactionRepository
|
||||
@ -28,6 +29,18 @@ class NBESettings(BaseSettings):
|
||||
node_api_port: int = Field(alias="NBE_NODE_API_PORT", default=18080)
|
||||
node_api_timeout: int = Field(alias="NBE_NODE_API_TIMEOUT", default=60)
|
||||
node_api_protocol: str = Field(alias="NBE_NODE_API_PROTOCOL", default="http")
|
||||
node_api_auth: Optional[Authentication] = Field(alias="NBE_NODE_API_AUTH", default=None)
|
||||
|
||||
@field_validator("node_api_auth", mode="before")
|
||||
@classmethod
|
||||
def parse_auth(cls, value: str) -> Optional[Authentication]:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return Authentication.from_string(value)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid NBE_NODE_API_AUTH: {value}") from e
|
||||
|
||||
|
||||
class NBEState(State):
|
||||
|
||||
26
src/core/authentication.py
Normal file
26
src/core/authentication.py
Normal file
@ -0,0 +1,26 @@
|
||||
import base64
|
||||
import dataclasses
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Authentication:
|
||||
_raw: str
|
||||
type: str
|
||||
credentials: str
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string: str) -> "Authentication":
|
||||
(auth_type, credentials) = string.split(" ", 1)
|
||||
return cls(string, auth_type.lower(), credentials)
|
||||
|
||||
def for_requests(self) -> str:
|
||||
return self._raw
|
||||
|
||||
def for_httpx(self) -> httpx.BasicAuth:
|
||||
if self.type == "basic":
|
||||
decoded = base64.b64decode(self.credentials).decode("utf-8")
|
||||
(username, password) = decoded.split(":", 1)
|
||||
return httpx.BasicAuth(username, password)
|
||||
raise NotImplementedError
|
||||
@ -1,11 +1,13 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, AsyncIterator, List
|
||||
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from pydantic import ValidationError
|
||||
from rusty_results import Empty, Option, Some
|
||||
from third_party import requests
|
||||
|
||||
from core.authentication import Authentication
|
||||
from node.api.base import NodeApi
|
||||
from node.api.serializers.block import BlockSerializer
|
||||
from node.api.serializers.health import HealthSerializer
|
||||
@ -28,6 +30,9 @@ class HttpNodeApi(NodeApi):
|
||||
self.port: int = settings.node_api_port
|
||||
self.protocol: str = settings.node_api_protocol or "http"
|
||||
self.timeout: int = settings.node_api_timeout or 60
|
||||
self.authentication: Option[Authentication] = (
|
||||
Some(settings.node_api_auth) if settings.node_api_auth else Empty()
|
||||
)
|
||||
|
||||
@property
|
||||
def base_url(self):
|
||||
@ -35,7 +40,7 @@ class HttpNodeApi(NodeApi):
|
||||
|
||||
async def get_health(self) -> HealthSerializer:
|
||||
url = urljoin(self.base_url, self.ENDPOINT_INFO)
|
||||
response = requests.get(url, timeout=60)
|
||||
response = requests.get(url, auth=self.authentication, timeout=60)
|
||||
if response.status_code == 200:
|
||||
return HealthSerializer.from_healthy()
|
||||
else:
|
||||
@ -45,15 +50,15 @@ class HttpNodeApi(NodeApi):
|
||||
query_string = f"slot_from={slot_from}&slot_to={slot_to}"
|
||||
endpoint = urljoin(self.base_url, self.ENDPOINT_BLOCKS)
|
||||
url = f"{endpoint}?{query_string}"
|
||||
response = requests.get(url, timeout=60)
|
||||
response = requests.get(url, auth=self.authentication, timeout=60)
|
||||
python_json = response.json()
|
||||
blocks = [BlockSerializer.model_validate(item) for item in python_json]
|
||||
return blocks
|
||||
|
||||
async def get_blocks_stream(self) -> AsyncIterator[BlockSerializer]:
|
||||
url = urljoin(self.base_url, self.ENDPOINT_BLOCKS_STREAM)
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
auth = self.authentication.map(lambda _auth: _auth.for_httpx()).unwrap_or(None)
|
||||
async with httpx.AsyncClient(timeout=self.timeout, auth=auth) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
response.raise_for_status() # TODO: Result
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class ProofOfLeadershipSerializer(NbeSerializer, EnforceSubclassFromRandom, ABC)
|
||||
|
||||
class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
|
||||
entropy_contribution: BytesFromHex = Field(description="Fr integer.")
|
||||
leader_key: BytesFromIntArray = Field(description="Bytes in Integer Array format.")
|
||||
leader_key: BytesFromHex = Field(description="Bytes in Integer Array format.")
|
||||
proof: BytesFromIntArray = Field(
|
||||
description="Bytes in Integer Array format.",
|
||||
)
|
||||
@ -47,7 +47,7 @@ class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
|
||||
return cls.model_validate(
|
||||
{
|
||||
"entropy_contribution": random_bytes(32).hex(),
|
||||
"leader_key": list(random_bytes(32)),
|
||||
"leader_key": random_bytes(32).hex(),
|
||||
"proof": list(random_bytes(128)),
|
||||
"public": PublicSerializer.from_random(slot),
|
||||
"voucher_cm": random_bytes(32).hex(),
|
||||
|
||||
0
third_party/__init__.py
vendored
Normal file
0
third_party/__init__.py
vendored
Normal file
11
third_party/requests.py
vendored
Normal file
11
third_party/requests.py
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
import requests
|
||||
from rusty_results import Option
|
||||
|
||||
from core.authentication import Authentication
|
||||
|
||||
|
||||
def get(url, params=None, auth: Option[Authentication] = None, **kwargs):
|
||||
headers = kwargs.get("headers", {})
|
||||
if auth.is_some:
|
||||
headers["Authorization"] = auth.unwrap().for_requests()
|
||||
return requests.get(url, params, headers=headers, **kwargs)
|
||||
Loading…
x
Reference in New Issue
Block a user