diff --git a/src/core/app.py b/src/core/app.py index b9c8b65..eb4a835 100644 --- a/src/core/app.py +++ b/src/core/app.py @@ -2,11 +2,12 @@ from asyncio import Task, gather from typing import Literal, Optional from fastapi import FastAPI -from pydantic import Field +from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.datastructures import State from constants import DIR_REPO +from core.authentication import Authentication from db.blocks import BlockRepository from db.clients import DbClient from db.transaction import TransactionRepository @@ -28,6 +29,18 @@ class NBESettings(BaseSettings): node_api_port: int = Field(alias="NBE_NODE_API_PORT", default=18080) node_api_timeout: int = Field(alias="NBE_NODE_API_TIMEOUT", default=60) node_api_protocol: str = Field(alias="NBE_NODE_API_PROTOCOL", default="http") + node_api_auth: Optional[Authentication] = Field(alias="NBE_NODE_API_AUTH", default=None) + + @field_validator("node_api_auth", mode="before") + @classmethod + def parse_auth(cls, value: str) -> Optional[Authentication]: + if value is None: + return None + + try: + return Authentication.from_string(value) + except Exception as e: + raise ValueError(f"Invalid NBE_NODE_API_AUTH: {value}") from e class NBEState(State): diff --git a/src/core/authentication.py b/src/core/authentication.py new file mode 100644 index 0000000..4089663 --- /dev/null +++ b/src/core/authentication.py @@ -0,0 +1,26 @@ +import base64 +import dataclasses + +import httpx + + +@dataclasses.dataclass +class Authentication: + _raw: str + type: str + credentials: str + + @classmethod + def from_string(cls, string: str) -> "Authentication": + (auth_type, credentials) = string.split(" ", 1) + return cls(string, auth_type.lower(), credentials) + + def for_requests(self) -> str: + return self._raw + + def for_httpx(self) -> httpx.BasicAuth: + if self.type == "basic": + decoded = base64.b64decode(self.credentials).decode("utf-8") + (username, password) = decoded.split(":", 1) + return httpx.BasicAuth(username, password) + raise NotImplementedError diff --git a/src/node/api/http.py b/src/node/api/http.py index 9b1ff6c..2e83b03 100644 --- a/src/node/api/http.py +++ b/src/node/api/http.py @@ -1,11 +1,13 @@ import logging -from typing import TYPE_CHECKING, AsyncIterator, List +from typing import TYPE_CHECKING, AsyncIterator, List, Optional from urllib.parse import urljoin import httpx -import requests from pydantic import ValidationError +from rusty_results import Empty, Option, Some +from third_party import requests +from core.authentication import Authentication from node.api.base import NodeApi from node.api.serializers.block import BlockSerializer from node.api.serializers.health import HealthSerializer @@ -28,6 +30,9 @@ class HttpNodeApi(NodeApi): self.port: int = settings.node_api_port self.protocol: str = settings.node_api_protocol or "http" self.timeout: int = settings.node_api_timeout or 60 + self.authentication: Option[Authentication] = ( + Some(settings.node_api_auth) if settings.node_api_auth else Empty() + ) @property def base_url(self): @@ -35,7 +40,7 @@ class HttpNodeApi(NodeApi): async def get_health(self) -> HealthSerializer: url = urljoin(self.base_url, self.ENDPOINT_INFO) - response = requests.get(url, timeout=60) + response = requests.get(url, auth=self.authentication, timeout=60) if response.status_code == 200: return HealthSerializer.from_healthy() else: @@ -45,15 +50,15 @@ class HttpNodeApi(NodeApi): query_string = f"slot_from={slot_from}&slot_to={slot_to}" endpoint = urljoin(self.base_url, self.ENDPOINT_BLOCKS) url = f"{endpoint}?{query_string}" - response = requests.get(url, timeout=60) + response = requests.get(url, auth=self.authentication, timeout=60) python_json = response.json() blocks = [BlockSerializer.model_validate(item) for item in python_json] return blocks async def get_blocks_stream(self) -> AsyncIterator[BlockSerializer]: url = urljoin(self.base_url, self.ENDPOINT_BLOCKS_STREAM) - - async with httpx.AsyncClient(timeout=self.timeout) as client: + auth = self.authentication.map(lambda _auth: _auth.for_httpx()).unwrap_or(None) + async with httpx.AsyncClient(timeout=self.timeout, auth=auth) as client: async with client.stream("GET", url) as response: response.raise_for_status() # TODO: Result diff --git a/src/node/api/serializers/proof_of_leadership.py b/src/node/api/serializers/proof_of_leadership.py index 4a550b4..5636ca4 100644 --- a/src/node/api/serializers/proof_of_leadership.py +++ b/src/node/api/serializers/proof_of_leadership.py @@ -23,7 +23,7 @@ class ProofOfLeadershipSerializer(NbeSerializer, EnforceSubclassFromRandom, ABC) class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer): entropy_contribution: BytesFromHex = Field(description="Fr integer.") - leader_key: BytesFromIntArray = Field(description="Bytes in Integer Array format.") + leader_key: BytesFromHex = Field(description="Bytes in Integer Array format.") proof: BytesFromIntArray = Field( description="Bytes in Integer Array format.", ) @@ -47,7 +47,7 @@ class Groth16LeaderProofSerializer(ProofOfLeadershipSerializer, NbeSerializer): return cls.model_validate( { "entropy_contribution": random_bytes(32).hex(), - "leader_key": list(random_bytes(32)), + "leader_key": random_bytes(32).hex(), "proof": list(random_bytes(128)), "public": PublicSerializer.from_random(slot), "voucher_cm": random_bytes(32).hex(), diff --git a/third_party/__init__.py b/third_party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/third_party/requests.py b/third_party/requests.py new file mode 100644 index 0000000..3b2a5ad --- /dev/null +++ b/third_party/requests.py @@ -0,0 +1,11 @@ +import requests +from rusty_results import Option + +from core.authentication import Authentication + + +def get(url, params=None, auth: Option[Authentication] = None, **kwargs): + headers = kwargs.get("headers", {}) + if auth.is_some: + headers["Authorization"] = auth.unwrap().for_requests() + return requests.get(url, params, headers=headers, **kwargs)