163 lines
5.6 KiB
Python
Raw Normal View History

2024-08-01 11:07:52 +09:00
from __future__ import annotations
from typing import Awaitable, Callable, Generic, Protocol, Self, Type, TypeVar
2024-08-01 11:07:52 +09:00
from pysphinx.sphinx import (
ProcessedFinalHopPacket,
ProcessedForwardHopPacket,
SphinxPacket,
)
2024-08-07 22:36:21 +09:00
from framework import Framework
2024-08-01 11:07:52 +09:00
from protocol.config import GlobalConfig, NodeConfig
from protocol.connection import SimplexConnection
from protocol.error import PeeringDegreeReached
from protocol.gossip import Gossip
from protocol.nomssip import Nomssip, NomssipConfig, NomssipMessage
2024-08-01 11:07:52 +09:00
from protocol.sphinx import SphinxPacketBuilder
class HasIdAndLenAndBytes(Protocol):
def id(self) -> int: ...
def __len__(self) -> int: ...
def __bytes__(self) -> bytes: ...
@classmethod
def from_bytes(cls, data: bytes) -> Self: ...
# Type variable for the message type a Node is parameterised over.
# Any concrete type substituted for T must satisfy HasIdAndLenAndBytes.
T = TypeVar("T", bound=HasIdAndLenAndBytes)
class Node(Generic[T]):
2024-08-01 11:07:52 +09:00
"""
This represents any node in the network, which:
- generates/gossips mix messages (Sphinx packets)
- performs cryptographic mix (unwrapping Sphinx packets)
- generates noise
"""
def __init__(
self,
framework: Framework,
config: NodeConfig,
global_config: GlobalConfig,
# A handler called when a node receives a broadcasted message originated from the last mix.
broadcasted_msg_handler: Callable[[T], Awaitable[None]],
# A handler called when a message is fully recovered by the last mix
2024-08-01 11:07:52 +09:00
# and returns a new message to be broadcasted.
recovered_msg_handler: Callable[[bytes], Awaitable[T]],
noise_msg: T,
2024-08-01 11:07:52 +09:00
):
self.framework = framework
self.config = config
self.global_config = global_config
nomssip_config = NomssipConfig(
config.gossip.peering_degree,
global_config.transmission_rate_per_sec,
SphinxPacketBuilder.size(global_config),
config.temporal_mix,
)
2024-08-01 11:07:52 +09:00
self.nomssip = Nomssip(
framework,
nomssip_config,
2024-08-01 11:07:52 +09:00
self.__process_msg,
noise_msg=NomssipMessage[T](NomssipMessage.Flag.NOISE, noise_msg),
2024-08-01 11:07:52 +09:00
)
self.broadcast = Gossip[T](framework, config.gossip, broadcasted_msg_handler)
2024-08-01 11:07:52 +09:00
self.recovered_msg_handler = recovered_msg_handler
async def __process_msg(self, msg: NomssipMessage[T]) -> None:
2024-08-01 11:07:52 +09:00
"""
A handler to process messages received via Nomssip channel
"""
assert msg.flag == NomssipMessage.Flag.REAL
2024-08-01 11:07:52 +09:00
sphinx_packet = SphinxPacket.from_bytes(
bytes(msg.message), self.global_config.max_mix_path_length
2024-08-01 11:07:52 +09:00
)
result = await self.__process_sphinx_packet(sphinx_packet)
match result:
case SphinxPacket():
# Gossip the next Sphinx packet
t: Type[T] = type(msg.message)
await self.nomssip.publish(
NomssipMessage[T](
NomssipMessage.Flag.REAL,
t.from_bytes(result.bytes()),
)
)
2024-08-01 11:07:52 +09:00
case bytes():
# Broadcast the message fully recovered from Sphinx packets
await self.broadcast.publish(await self.recovered_msg_handler(result))
2024-08-01 11:07:52 +09:00
case None:
return
async def __process_sphinx_packet(
self, packet: SphinxPacket
) -> SphinxPacket | bytes | None:
"""
Unwrap the Sphinx packet and process the next Sphinx packet or the payload if possible
"""
try:
processed = packet.process(self.config.private_key)
match processed:
case ProcessedForwardHopPacket():
return processed.next_packet
case ProcessedFinalHopPacket():
return processed.payload.recover_plain_playload()
except ValueError:
# Return nothing, if it cannot be unwrapped by the private key of this node.
return None
def connect_mix(
self,
peer: Node,
inbound_conn: SimplexConnection[NomssipMessage[T]],
outbound_conn: SimplexConnection[NomssipMessage[T]],
2024-08-01 11:07:52 +09:00
):
2024-07-24 15:35:41 +09:00
connect_nodes(self.nomssip, peer.nomssip, inbound_conn, outbound_conn)
2024-08-01 11:07:52 +09:00
def connect_broadcast(
self,
peer: Node,
inbound_conn: SimplexConnection[T],
outbound_conn: SimplexConnection[T],
2024-08-01 11:07:52 +09:00
):
2024-07-24 15:35:41 +09:00
connect_nodes(self.broadcast, peer.broadcast, inbound_conn, outbound_conn)
2024-08-01 11:07:52 +09:00
async def send_message(self, msg: T):
2024-08-01 11:07:52 +09:00
"""
Build a Sphinx packet and gossip it to all connected peers.
"""
# Here, we handle the case in which a msg is split into multiple Sphinx packets.
# But, in practice, we expect a message to be small enough to fit in a single Sphinx packet.
sphinx_packet, _ = SphinxPacketBuilder.build(
bytes(msg),
2024-08-01 11:07:52 +09:00
self.global_config,
self.config.mix_path_length,
)
t: Type[T] = type(msg)
await self.nomssip.publish(
NomssipMessage(
NomssipMessage.Flag.REAL, t.from_bytes(sphinx_packet.bytes())
)
)
2024-07-24 15:35:41 +09:00
def connect_nodes(
    self_channel: Gossip,
    peer_channel: Gossip,
    inbound_conn: SimplexConnection,
    outbound_conn: SimplexConnection,
):
    """
    Establish a duplex connection between two channels.

    Both channels must report that they can accept a connection; otherwise
    ``PeeringDegreeReached`` is raised and neither channel is modified.
    The same pair of simplex connections is registered on both sides, with
    the inbound/outbound roles swapped for the peer.
    """
    both_can_accept = (
        self_channel.can_accept_conn() and peer_channel.can_accept_conn()
    )
    if not both_can_accept:
        raise PeeringDegreeReached()

    registrations = (
        # This side: receives on inbound_conn, sends on outbound_conn.
        (self_channel, inbound_conn, outbound_conn),
        # Peer side: the mirror image of the above.
        (peer_channel, outbound_conn, inbound_conn),
    )
    for channel, first_conn, second_conn in registrations:
        channel.add_conn(first_conn, second_conn)