Add authentication mechanism via env variables.

This commit is contained in:
Alejandro Cabeza Romero 2025-12-19 12:26:12 +01:00
parent c3c357d09a
commit 0fd630d6a4
No known key found for this signature in database
GPG Key ID: DA3D14AE478030FD
6 changed files with 64 additions and 9 deletions

View File

@ -2,11 +2,12 @@ from asyncio import Task, gather
from typing import Literal, Optional from typing import Literal, Optional
from fastapi import FastAPI from fastapi import FastAPI
from pydantic import Field from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.datastructures import State from starlette.datastructures import State
from constants import DIR_REPO from constants import DIR_REPO
from core.authentication import Authentication
from db.blocks import BlockRepository from db.blocks import BlockRepository
from db.clients import DbClient from db.clients import DbClient
from db.transaction import TransactionRepository from db.transaction import TransactionRepository
@ -28,6 +29,18 @@ class NBESettings(BaseSettings):
node_api_port: int = Field(alias="NBE_NODE_API_PORT", default=18080) node_api_port: int = Field(alias="NBE_NODE_API_PORT", default=18080)
node_api_timeout: int = Field(alias="NBE_NODE_API_TIMEOUT", default=60) node_api_timeout: int = Field(alias="NBE_NODE_API_TIMEOUT", default=60)
node_api_protocol: str = Field(alias="NBE_NODE_API_PROTOCOL", default="http") node_api_protocol: str = Field(alias="NBE_NODE_API_PROTOCOL", default="http")
node_api_auth: Optional[Authentication] = Field(alias="NBE_NODE_API_AUTH", default=None)
@field_validator("node_api_auth", mode="before")
@classmethod
def parse_auth(cls, value: str) -> Optional[Authentication]:
if value is None:
return None
try:
return Authentication.from_string(value)
except Exception as e:
raise ValueError(f"Invalid NBE_NODE_API_AUTH: {value}") from e
class NBEState(State): class NBEState(State):

View File

@ -0,0 +1,26 @@
import base64
import dataclasses
import httpx
@dataclasses.dataclass
class Authentication:
_raw: str
type: str
credentials: str
@classmethod
def from_string(cls, string: str) -> "Authentication":
(auth_type, credentials) = string.split(" ", 1)
return cls(string, auth_type.lower(), credentials)
def for_requests(self) -> str:
return self._raw
def for_httpx(self) -> httpx.BasicAuth:
if self.type == "basic":
decoded = base64.b64decode(self.credentials).decode("utf-8")
(username, password) = decoded.split(":", 1)
return httpx.BasicAuth(username, password)
raise NotImplementedError

View File

@ -1,11 +1,13 @@
import logging import logging
from typing import TYPE_CHECKING, AsyncIterator, List from typing import TYPE_CHECKING, AsyncIterator, List, Optional
from urllib.parse import urljoin from urllib.parse import urljoin
import httpx import httpx
import requests
from pydantic import ValidationError from pydantic import ValidationError
from rusty_results import Empty, Option, Some
from third_party import requests
from core.authentication import Authentication
from node.api.base import NodeApi from node.api.base import NodeApi
from node.api.serializers.block import BlockSerializer from node.api.serializers.block import BlockSerializer
from node.api.serializers.health import HealthSerializer from node.api.serializers.health import HealthSerializer
@ -28,6 +30,9 @@ class HttpNodeApi(NodeApi):
self.port: int = settings.node_api_port self.port: int = settings.node_api_port
self.protocol: str = settings.node_api_protocol or "http" self.protocol: str = settings.node_api_protocol or "http"
self.timeout: int = settings.node_api_timeout or 60 self.timeout: int = settings.node_api_timeout or 60
self.authentication: Option[Authentication] = (
Some(settings.node_api_auth) if settings.node_api_auth else Empty()
)
@property @property
def base_url(self): def base_url(self):
@ -35,7 +40,7 @@ class HttpNodeApi(NodeApi):
async def get_health(self) -> HealthSerializer: async def get_health(self) -> HealthSerializer:
url = urljoin(self.base_url, self.ENDPOINT_INFO) url = urljoin(self.base_url, self.ENDPOINT_INFO)
response = requests.get(url, timeout=60) response = requests.get(url, auth=self.authentication, timeout=60)
if response.status_code == 200: if response.status_code == 200:
return HealthSerializer.from_healthy() return HealthSerializer.from_healthy()
else: else:
@ -45,15 +50,15 @@ class HttpNodeApi(NodeApi):
query_string = f"slot_from={slot_from}&slot_to={slot_to}" query_string = f"slot_from={slot_from}&slot_to={slot_to}"
endpoint = urljoin(self.base_url, self.ENDPOINT_BLOCKS) endpoint = urljoin(self.base_url, self.ENDPOINT_BLOCKS)
url = f"{endpoint}?{query_string}" url = f"{endpoint}?{query_string}"
response = requests.get(url, timeout=60) response = requests.get(url, auth=self.authentication, timeout=60)
python_json = response.json() python_json = response.json()
blocks = [BlockSerializer.model_validate(item) for item in python_json] blocks = [BlockSerializer.model_validate(item) for item in python_json]
return blocks return blocks
async def get_blocks_stream(self) -> AsyncIterator[BlockSerializer]: async def get_blocks_stream(self) -> AsyncIterator[BlockSerializer]:
url = urljoin(self.base_url, self.ENDPOINT_BLOCKS_STREAM) url = urljoin(self.base_url, self.ENDPOINT_BLOCKS_STREAM)
auth = self.authentication.map(lambda _auth: _auth.for_httpx()).unwrap_or(None)
async with httpx.AsyncClient(timeout=self.timeout) as client: async with httpx.AsyncClient(timeout=self.timeout, auth=auth) as client:
async with client.stream("GET", url) as response: async with client.stream("GET", url) as response:
response.raise_for_status() # TODO: Result response.raise_for_status() # TODO: Result

View File

@ -23,7 +23,7 @@ class ProofOfLeadershipSerializer(NbeSerializer, EnforceSubclassFromRandom, ABC)
class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer): class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
entropy_contribution: BytesFromHex = Field(description="Fr integer.") entropy_contribution: BytesFromHex = Field(description="Fr integer.")
leader_key: BytesFromIntArray = Field(description="Bytes in Integer Array format.") leader_key: BytesFromHex = Field(description="Bytes in Integer Array format.")
proof: BytesFromIntArray = Field( proof: BytesFromIntArray = Field(
description="Bytes in Integer Array format.", description="Bytes in Integer Array format.",
) )
@ -47,7 +47,7 @@ class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer):
return cls.model_validate( return cls.model_validate(
{ {
"entropy_contribution": random_bytes(32).hex(), "entropy_contribution": random_bytes(32).hex(),
"leader_key": list(random_bytes(32)), "leader_key": random_bytes(32).hex(),
"proof": list(random_bytes(128)), "proof": list(random_bytes(128)),
"public": PublicSerializer.from_random(slot), "public": PublicSerializer.from_random(slot),
"voucher_cm": random_bytes(32).hex(), "voucher_cm": random_bytes(32).hex(),

0
third_party/__init__.py vendored Normal file
View File

11
third_party/requests.py vendored Normal file
View File

@ -0,0 +1,11 @@
import requests
from rusty_results import Option
from core.authentication import Authentication
def get(url, params=None, auth: Option[Authentication] = None, **kwargs):
headers = kwargs.get("headers", {})
if auth.is_some:
headers["Authorization"] = auth.unwrap().for_requests()
return requests.get(url, params, headers=headers, **kwargs)