mirror of
https://github.com/logos-blockchain/logos-blockchain-block-explorer-template.git
synced 2026-01-02 05:03:07 +00:00
Add authentication mechanism via env variables.
This commit is contained in:
parent
c3c357d09a
commit
0fd630d6a4
@ -2,11 +2,12 @@ from asyncio import Task, gather
|
|||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pydantic import Field
|
from pydantic import Field, field_validator
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from starlette.datastructures import State
|
from starlette.datastructures import State
|
||||||
|
|
||||||
from constants import DIR_REPO
|
from constants import DIR_REPO
|
||||||
|
from core.authentication import Authentication
|
||||||
from db.blocks import BlockRepository
|
from db.blocks import BlockRepository
|
||||||
from db.clients import DbClient
|
from db.clients import DbClient
|
||||||
from db.transaction import TransactionRepository
|
from db.transaction import TransactionRepository
|
||||||
@ -28,6 +29,18 @@ class NBESettings(BaseSettings):
|
|||||||
node_api_port: int = Field(alias="NBE_NODE_API_PORT", default=18080)
|
node_api_port: int = Field(alias="NBE_NODE_API_PORT", default=18080)
|
||||||
node_api_timeout: int = Field(alias="NBE_NODE_API_TIMEOUT", default=60)
|
node_api_timeout: int = Field(alias="NBE_NODE_API_TIMEOUT", default=60)
|
||||||
node_api_protocol: str = Field(alias="NBE_NODE_API_PROTOCOL", default="http")
|
node_api_protocol: str = Field(alias="NBE_NODE_API_PROTOCOL", default="http")
|
||||||
|
node_api_auth: Optional[Authentication] = Field(alias="NBE_NODE_API_AUTH", default=None)
|
||||||
|
|
||||||
|
@field_validator("node_api_auth", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def parse_auth(cls, value: str) -> Optional[Authentication]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return Authentication.from_string(value)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid NBE_NODE_API_AUTH: {value}") from e
|
||||||
|
|
||||||
|
|
||||||
class NBEState(State):
|
class NBEState(State):
|
||||||
|
|||||||
26
src/core/authentication.py
Normal file
26
src/core/authentication.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import base64
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Authentication:
|
||||||
|
_raw: str
|
||||||
|
type: str
|
||||||
|
credentials: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, string: str) -> "Authentication":
|
||||||
|
(auth_type, credentials) = string.split(" ", 1)
|
||||||
|
return cls(string, auth_type.lower(), credentials)
|
||||||
|
|
||||||
|
def for_requests(self) -> str:
|
||||||
|
return self._raw
|
||||||
|
|
||||||
|
def for_httpx(self) -> httpx.BasicAuth:
|
||||||
|
if self.type == "basic":
|
||||||
|
decoded = base64.b64decode(self.credentials).decode("utf-8")
|
||||||
|
(username, password) = decoded.split(":", 1)
|
||||||
|
return httpx.BasicAuth(username, password)
|
||||||
|
raise NotImplementedError
|
||||||
@ -1,11 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, AsyncIterator, List
|
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import requests
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
from rusty_results import Empty, Option, Some
|
||||||
|
from third_party import requests
|
||||||
|
|
||||||
|
from core.authentication import Authentication
|
||||||
from node.api.base import NodeApi
|
from node.api.base import NodeApi
|
||||||
from node.api.serializers.block import BlockSerializer
|
from node.api.serializers.block import BlockSerializer
|
||||||
from node.api.serializers.health import HealthSerializer
|
from node.api.serializers.health import HealthSerializer
|
||||||
@ -28,6 +30,9 @@ class HttpNodeApi(NodeApi):
|
|||||||
self.port: int = settings.node_api_port
|
self.port: int = settings.node_api_port
|
||||||
self.protocol: str = settings.node_api_protocol or "http"
|
self.protocol: str = settings.node_api_protocol or "http"
|
||||||
self.timeout: int = settings.node_api_timeout or 60
|
self.timeout: int = settings.node_api_timeout or 60
|
||||||
|
self.authentication: Option[Authentication] = (
|
||||||
|
Some(settings.node_api_auth) if settings.node_api_auth else Empty()
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_url(self):
|
def base_url(self):
|
||||||
@ -35,7 +40,7 @@ class HttpNodeApi(NodeApi):
|
|||||||
|
|
||||||
async def get_health(self) -> HealthSerializer:
|
async def get_health(self) -> HealthSerializer:
|
||||||
url = urljoin(self.base_url, self.ENDPOINT_INFO)
|
url = urljoin(self.base_url, self.ENDPOINT_INFO)
|
||||||
response = requests.get(url, timeout=60)
|
response = requests.get(url, auth=self.authentication, timeout=60)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return HealthSerializer.from_healthy()
|
return HealthSerializer.from_healthy()
|
||||||
else:
|
else:
|
||||||
@ -45,15 +50,15 @@ class HttpNodeApi(NodeApi):
|
|||||||
query_string = f"slot_from={slot_from}&slot_to={slot_to}"
|
query_string = f"slot_from={slot_from}&slot_to={slot_to}"
|
||||||
endpoint = urljoin(self.base_url, self.ENDPOINT_BLOCKS)
|
endpoint = urljoin(self.base_url, self.ENDPOINT_BLOCKS)
|
||||||
url = f"{endpoint}?{query_string}"
|
url = f"{endpoint}?{query_string}"
|
||||||
response = requests.get(url, timeout=60)
|
response = requests.get(url, auth=self.authentication, timeout=60)
|
||||||
python_json = response.json()
|
python_json = response.json()
|
||||||
blocks = [BlockSerializer.model_validate(item) for item in python_json]
|
blocks = [BlockSerializer.model_validate(item) for item in python_json]
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
async def get_blocks_stream(self) -> AsyncIterator[BlockSerializer]:
|
async def get_blocks_stream(self) -> AsyncIterator[BlockSerializer]:
|
||||||
url = urljoin(self.base_url, self.ENDPOINT_BLOCKS_STREAM)
|
url = urljoin(self.base_url, self.ENDPOINT_BLOCKS_STREAM)
|
||||||
|
auth = self.authentication.map(lambda _auth: _auth.for_httpx()).unwrap_or(None)
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
async with httpx.AsyncClient(timeout=self.timeout, auth=auth) as client:
|
||||||
async with client.stream("GET", url) as response:
|
async with client.stream("GET", url) as response:
|
||||||
response.raise_for_status() # TODO: Result
|
response.raise_for_status() # TODO: Result
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,7 @@ class ProofOfLeadershipSerializer(NbeSerializer, EnforceSubclassFromRandom, ABC)
|
|||||||
|
|
||||||
class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
|
class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
|
||||||
entropy_contribution: BytesFromHex = Field(description="Fr integer.")
|
entropy_contribution: BytesFromHex = Field(description="Fr integer.")
|
||||||
leader_key: BytesFromIntArray = Field(description="Bytes in Integer Array format.")
|
leader_key: BytesFromHex = Field(description="Bytes in Integer Array format.")
|
||||||
proof: BytesFromIntArray = Field(
|
proof: BytesFromIntArray = Field(
|
||||||
description="Bytes in Integer Array format.",
|
description="Bytes in Integer Array format.",
|
||||||
)
|
)
|
||||||
@ -47,7 +47,7 @@ class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
|
|||||||
return cls.model_validate(
|
return cls.model_validate(
|
||||||
{
|
{
|
||||||
"entropy_contribution": random_bytes(32).hex(),
|
"entropy_contribution": random_bytes(32).hex(),
|
||||||
"leader_key": list(random_bytes(32)),
|
"leader_key": random_bytes(32).hex(),
|
||||||
"proof": list(random_bytes(128)),
|
"proof": list(random_bytes(128)),
|
||||||
"public": PublicSerializer.from_random(slot),
|
"public": PublicSerializer.from_random(slot),
|
||||||
"voucher_cm": random_bytes(32).hex(),
|
"voucher_cm": random_bytes(32).hex(),
|
||||||
|
|||||||
0
third_party/__init__.py
vendored
Normal file
0
third_party/__init__.py
vendored
Normal file
11
third_party/requests.py
vendored
Normal file
11
third_party/requests.py
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import requests
|
||||||
|
from rusty_results import Option
|
||||||
|
|
||||||
|
from core.authentication import Authentication
|
||||||
|
|
||||||
|
|
||||||
|
def get(url, params=None, auth: Option[Authentication] = None, **kwargs):
|
||||||
|
headers = kwargs.get("headers", {})
|
||||||
|
if auth.is_some:
|
||||||
|
headers["Authorization"] = auth.unwrap().for_requests()
|
||||||
|
return requests.get(url, params, headers=headers, **kwargs)
|
||||||
Loading…
x
Reference in New Issue
Block a user