Add authentication mechanism via env variables.

This commit is contained in:
Alejandro Cabeza Romero 2025-12-19 12:26:12 +01:00
parent c3c357d09a
commit 0fd630d6a4
No known key found for this signature in database
GPG Key ID: DA3D14AE478030FD
6 changed files with 64 additions and 9 deletions

View File

@ -2,11 +2,12 @@ from asyncio import Task, gather
from typing import Literal, Optional
from fastapi import FastAPI
from pydantic import Field
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.datastructures import State
from constants import DIR_REPO
from core.authentication import Authentication
from db.blocks import BlockRepository
from db.clients import DbClient
from db.transaction import TransactionRepository
@ -28,6 +29,18 @@ class NBESettings(BaseSettings):
node_api_port: int = Field(alias="NBE_NODE_API_PORT", default=18080)
node_api_timeout: int = Field(alias="NBE_NODE_API_TIMEOUT", default=60)
node_api_protocol: str = Field(alias="NBE_NODE_API_PROTOCOL", default="http")
node_api_auth: Optional[Authentication] = Field(alias="NBE_NODE_API_AUTH", default=None)
@field_validator("node_api_auth", mode="before")
@classmethod
def parse_auth(cls, value: str) -> Optional[Authentication]:
if value is None:
return None
try:
return Authentication.from_string(value)
except Exception as e:
raise ValueError(f"Invalid NBE_NODE_API_AUTH: {value}") from e
class NBEState(State):

View File

@ -0,0 +1,26 @@
import base64
import dataclasses
import httpx
@dataclasses.dataclass
class Authentication:
_raw: str
type: str
credentials: str
@classmethod
def from_string(cls, string: str) -> "Authentication":
(auth_type, credentials) = string.split(" ", 1)
return cls(string, auth_type.lower(), credentials)
def for_requests(self) -> str:
return self._raw
def for_httpx(self) -> httpx.BasicAuth:
if self.type == "basic":
decoded = base64.b64decode(self.credentials).decode("utf-8")
(username, password) = decoded.split(":", 1)
return httpx.BasicAuth(username, password)
raise NotImplementedError

View File

@ -1,11 +1,13 @@
import logging
from typing import TYPE_CHECKING, AsyncIterator, List
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
from urllib.parse import urljoin
import httpx
import requests
from pydantic import ValidationError
from rusty_results import Empty, Option, Some
from third_party import requests
from core.authentication import Authentication
from node.api.base import NodeApi
from node.api.serializers.block import BlockSerializer
from node.api.serializers.health import HealthSerializer
@ -28,6 +30,9 @@ class HttpNodeApi(NodeApi):
self.port: int = settings.node_api_port
self.protocol: str = settings.node_api_protocol or "http"
self.timeout: int = settings.node_api_timeout or 60
self.authentication: Option[Authentication] = (
Some(settings.node_api_auth) if settings.node_api_auth else Empty()
)
@property
def base_url(self):
@ -35,7 +40,7 @@ class HttpNodeApi(NodeApi):
async def get_health(self) -> HealthSerializer:
url = urljoin(self.base_url, self.ENDPOINT_INFO)
response = requests.get(url, timeout=60)
response = requests.get(url, auth=self.authentication, timeout=60)
if response.status_code == 200:
return HealthSerializer.from_healthy()
else:
@ -45,15 +50,15 @@ class HttpNodeApi(NodeApi):
query_string = f"slot_from={slot_from}&slot_to={slot_to}"
endpoint = urljoin(self.base_url, self.ENDPOINT_BLOCKS)
url = f"{endpoint}?{query_string}"
response = requests.get(url, timeout=60)
response = requests.get(url, auth=self.authentication, timeout=60)
python_json = response.json()
blocks = [BlockSerializer.model_validate(item) for item in python_json]
return blocks
async def get_blocks_stream(self) -> AsyncIterator[BlockSerializer]:
url = urljoin(self.base_url, self.ENDPOINT_BLOCKS_STREAM)
async with httpx.AsyncClient(timeout=self.timeout) as client:
auth = self.authentication.map(lambda _auth: _auth.for_httpx()).unwrap_or(None)
async with httpx.AsyncClient(timeout=self.timeout, auth=auth) as client:
async with client.stream("GET", url) as response:
response.raise_for_status() # TODO: Result

View File

@ -23,7 +23,7 @@ class ProofOfLeadershipSerializer(NbeSerializer, EnforceSubclassFromRandom, ABC)
class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
entropy_contribution: BytesFromHex = Field(description="Fr integer.")
leader_key: BytesFromIntArray = Field(description="Bytes in Integer Array format.")
leader_key: BytesFromHex = Field(description="Bytes in Integer Array format.")
proof: BytesFromIntArray = Field(
description="Bytes in Integer Array format.",
)
@ -47,7 +47,7 @@ class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
return cls.model_validate(
{
"entropy_contribution": random_bytes(32).hex(),
"leader_key": list(random_bytes(32)),
"leader_key": random_bytes(32).hex(),
"proof": list(random_bytes(128)),
"public": PublicSerializer.from_random(slot),
"voucher_cm": random_bytes(32).hex(),

0
third_party/__init__.py vendored Normal file
View File

11
third_party/requests.py vendored Normal file
View File

@ -0,0 +1,11 @@
import requests
from rusty_results import Option
from core.authentication import Authentication
def get(url, params=None, auth: Option[Authentication] = None, **kwargs):
headers = kwargs.get("headers", {})
if auth.is_some:
headers["Authorization"] = auth.unwrap().for_requests()
return requests.get(url, params, headers=headers, **kwargs)