nomos-specs/mixnet/nomssip.py

124 lines
3.9 KiB
Python
Raw Normal View History

2024-07-11 17:32:48 +09:00
from __future__ import annotations
2024-07-10 09:23:33 +09:00
import asyncio
import hashlib
2024-07-11 14:29:00 +09:00
from dataclasses import dataclass
from enum import Enum
2024-07-11 17:32:48 +09:00
from typing import Awaitable, Callable, Self
2024-07-10 09:23:33 +09:00
2024-07-11 14:29:00 +09:00
from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection
2024-07-10 09:23:33 +09:00
2024-07-11 10:10:29 +09:00
class Nomssip:
2024-07-10 09:55:51 +09:00
"""
2024-07-11 10:10:29 +09:00
A NomMix gossip channel that broadcasts messages to all connected peers.
2024-07-10 09:55:51 +09:00
Peers are connected via DuplexConnection.
"""
2024-07-11 14:29:00 +09:00
@dataclass
class Config:
transmission_rate_per_sec: int
peering_degree: int
msg_size: int
2024-07-10 09:23:33 +09:00
def __init__(
self,
2024-07-11 14:29:00 +09:00
config: Config,
handler: Callable[[bytes], Awaitable[None]],
2024-07-10 09:23:33 +09:00
):
self.config = config
2024-07-11 16:56:19 +09:00
self.conns: list[DuplexConnection] = []
# A handler to process inbound messages.
2024-07-10 09:23:33 +09:00
self.handler = handler
2024-07-11 16:56:19 +09:00
self.packet_cache: set[bytes] = set()
2024-07-10 09:23:33 +09:00
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
2024-07-11 16:56:19 +09:00
self.tasks: set[asyncio.Task] = set()
2024-07-10 09:23:33 +09:00
2024-07-11 14:29:00 +09:00
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
2024-07-10 09:23:33 +09:00
if len(self.conns) >= self.config.peering_degree:
# For simplicity of the spec, reject the connection if the peering degree is reached.
raise ValueError("The peering degree is reached.")
2024-07-11 17:32:48 +09:00
noise_packet = FlaggedPacket(
FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
).bytes()
2024-07-11 14:29:00 +09:00
conn = DuplexConnection(
inbound,
MixSimplexConnection(
outbound,
self.config.transmission_rate_per_sec,
noise_packet,
),
)
2024-07-10 09:23:33 +09:00
self.conns.append(conn)
task = asyncio.create_task(self.__process_inbound_conn(conn))
self.tasks.add(task)
# To discard the task from the set automatically when it is done.
task.add_done_callback(self.tasks.discard)
async def __process_inbound_conn(self, conn: DuplexConnection):
while True:
2024-07-11 14:29:00 +09:00
packet = await conn.recv()
if self.__check_update_cache(packet):
2024-07-10 09:23:33 +09:00
continue
2024-07-11 17:32:48 +09:00
packet = FlaggedPacket.from_bytes(packet)
match packet.flag:
case FlaggedPacket.Flag.NOISE:
2024-07-11 14:29:00 +09:00
# Drop noise packet
continue
2024-07-11 17:32:48 +09:00
case FlaggedPacket.Flag.REAL:
await self.__gossip_flagged_packet(packet)
await self.handler(packet.message)
2024-07-11 14:29:00 +09:00
async def gossip(self, msg: bytes):
"""
Gossip a message to all connected peers with prepending a message flag
"""
# The message size must be fixed.
assert len(msg) == self.config.msg_size
2024-07-11 17:32:48 +09:00
packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg)
await self.__gossip_flagged_packet(packet)
2024-07-10 09:23:33 +09:00
2024-07-11 17:32:48 +09:00
async def __gossip_flagged_packet(self, packet: FlaggedPacket):
2024-07-11 14:29:00 +09:00
"""
An internal method to send a flagged packet to all connected peers
"""
2024-07-10 09:23:33 +09:00
for conn in self.conns:
2024-07-11 17:32:48 +09:00
await conn.send(packet.bytes())
2024-07-11 14:29:00 +09:00
def __check_update_cache(self, packet: bytes) -> bool:
"""
Add a message to the cache, and return True if the message was already in the cache.
"""
hash = hashlib.sha256(packet).digest()
if hash in self.packet_cache:
return True
self.packet_cache.add(hash)
return False
2024-07-11 17:32:48 +09:00
class FlaggedPacket:
class Flag(Enum):
2024-07-11 14:29:00 +09:00
REAL = b"\x00"
NOISE = b"\x01"
2024-07-11 17:32:48 +09:00
def __init__(self, flag: Flag, message: bytes):
self.flag = flag
self.message = message
def bytes(self) -> bytes:
return self.flag.value + self.message
2024-07-11 14:29:00 +09:00
2024-07-11 17:32:48 +09:00
@classmethod
def from_bytes(cls, packet: bytes) -> Self:
2024-07-11 14:29:00 +09:00
"""
2024-07-11 17:32:48 +09:00
Parse a flagged packet from bytes
2024-07-11 14:29:00 +09:00
"""
2024-07-11 17:32:48 +09:00
if len(packet) < 1:
2024-07-11 14:29:00 +09:00
raise ValueError("Invalid message format")
2024-07-11 17:32:48 +09:00
return cls(cls.Flag(packet[:1]), packet[1:])