mirror of
https://github.com/logos-co/nomos-specs.git
synced 2025-02-01 10:06:10 +00:00
124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Awaitable, Callable, Self
|
|
|
|
from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection
|
|
|
|
|
|
class Nomssip:
|
|
"""
|
|
A NomMix gossip channel that broadcasts messages to all connected peers.
|
|
Peers are connected via DuplexConnection.
|
|
"""
|
|
|
|
@dataclass
|
|
class Config:
|
|
transmission_rate_per_sec: int
|
|
peering_degree: int
|
|
msg_size: int
|
|
|
|
def __init__(
|
|
self,
|
|
config: Config,
|
|
handler: Callable[[bytes], Awaitable[None]],
|
|
):
|
|
self.config = config
|
|
self.conns: list[DuplexConnection] = []
|
|
# A handler to process inbound messages.
|
|
self.handler = handler
|
|
self.packet_cache: set[bytes] = set()
|
|
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
|
|
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
|
|
self.tasks: set[asyncio.Task] = set()
|
|
|
|
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
|
|
if len(self.conns) >= self.config.peering_degree:
|
|
# For simplicity of the spec, reject the connection if the peering degree is reached.
|
|
raise ValueError("The peering degree is reached.")
|
|
|
|
noise_packet = FlaggedPacket(
|
|
FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
|
|
).bytes()
|
|
conn = DuplexConnection(
|
|
inbound,
|
|
MixSimplexConnection(
|
|
outbound,
|
|
self.config.transmission_rate_per_sec,
|
|
noise_packet,
|
|
),
|
|
)
|
|
|
|
self.conns.append(conn)
|
|
task = asyncio.create_task(self.__process_inbound_conn(conn))
|
|
self.tasks.add(task)
|
|
# To discard the task from the set automatically when it is done.
|
|
task.add_done_callback(self.tasks.discard)
|
|
|
|
async def __process_inbound_conn(self, conn: DuplexConnection):
|
|
while True:
|
|
packet = await conn.recv()
|
|
if self.__check_update_cache(packet):
|
|
continue
|
|
|
|
packet = FlaggedPacket.from_bytes(packet)
|
|
match packet.flag:
|
|
case FlaggedPacket.Flag.NOISE:
|
|
# Drop noise packet
|
|
continue
|
|
case FlaggedPacket.Flag.REAL:
|
|
await self.__gossip_flagged_packet(packet)
|
|
await self.handler(packet.message)
|
|
|
|
async def gossip(self, msg: bytes):
|
|
"""
|
|
Gossip a message to all connected peers with prepending a message flag
|
|
"""
|
|
# The message size must be fixed.
|
|
assert len(msg) == self.config.msg_size
|
|
|
|
packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg)
|
|
await self.__gossip_flagged_packet(packet)
|
|
|
|
async def __gossip_flagged_packet(self, packet: FlaggedPacket):
|
|
"""
|
|
An internal method to send a flagged packet to all connected peers
|
|
"""
|
|
for conn in self.conns:
|
|
await conn.send(packet.bytes())
|
|
|
|
def __check_update_cache(self, packet: bytes) -> bool:
|
|
"""
|
|
Add a message to the cache, and return True if the message was already in the cache.
|
|
"""
|
|
hash = hashlib.sha256(packet).digest()
|
|
if hash in self.packet_cache:
|
|
return True
|
|
self.packet_cache.add(hash)
|
|
return False
|
|
|
|
|
|
class FlaggedPacket:
|
|
class Flag(Enum):
|
|
REAL = b"\x00"
|
|
NOISE = b"\x01"
|
|
|
|
def __init__(self, flag: Flag, message: bytes):
|
|
self.flag = flag
|
|
self.message = message
|
|
|
|
def bytes(self) -> bytes:
|
|
return self.flag.value + self.message
|
|
|
|
@classmethod
|
|
def from_bytes(cls, packet: bytes) -> Self:
|
|
"""
|
|
Parse a flagged packet from bytes
|
|
"""
|
|
if len(packet) < 1:
|
|
raise ValueError("Invalid message format")
|
|
return cls(cls.Flag(packet[:1]), packet[1:])
|