from __future__ import annotations import hashlib import random from dataclasses import dataclass from enum import Enum from typing import Awaitable, Callable, Self, override from framework import Framework from protocol.connection import ( DuplexConnection, MixSimplexConnection, SimplexConnection, ) from protocol.error import PeeringDegreeReached from protocol.gossip import Gossip, GossipConfig from protocol.temporalmix import TemporalMixConfig @dataclass class NomssipConfig(GossipConfig): transmission_rate_per_sec: int msg_size: int temporal_mix: TemporalMixConfig class Nomssip(Gossip): """ A NomMix gossip channel that extends the Gossip channel by adding global transmission rate and noise generation. """ def __init__( self, framework: Framework, config: NomssipConfig, handler: Callable[[bytes], Awaitable[None]], ): super().__init__(framework, config, handler) self.config = config @override def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection): noise_packet = FlaggedPacket( FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size) ).bytes() super().add_conn( inbound, MixSimplexConnection( self.framework, outbound, self.config.transmission_rate_per_sec, noise_packet, self.config.temporal_mix, ), ) @override async def process_inbound_msg(self, msg: bytes): packet = FlaggedPacket.from_bytes(msg) match packet.flag: case FlaggedPacket.Flag.NOISE: # Drop noise packet return case FlaggedPacket.Flag.REAL: await self.__gossip_flagged_packet(packet) await self.handler(packet.message) @override async def gossip(self, msg: bytes): """ Gossip a message to all connected peers with prepending a message flag """ # The message size must be fixed. assert len(msg) == self.config.msg_size packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg) await self.__gossip_flagged_packet(packet) async def __gossip_flagged_packet(self, packet: FlaggedPacket): """ An internal method to send a flagged packet to all connected peers """ await super().gossip(packet.bytes()) class FlaggedPacket: class Flag(Enum): REAL = b"\x00" NOISE = b"\x01" def __init__(self, flag: Flag, message: bytes): self.flag = flag self.message = message def bytes(self) -> bytes: return self.flag.value + self.message @classmethod def from_bytes(cls, packet: bytes) -> Self: """ Parse a flagged packet from bytes """ if len(packet) < 1: raise ValueError("Invalid message format") return cls(cls.Flag(packet[:1]), packet[1:])