107 lines
3.0 KiB
Python
Raw Normal View History

2024-08-01 11:07:52 +09:00
from __future__ import annotations
import hashlib
import random
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Self, override
from framework import Framework
from protocol.connection import (
DuplexConnection,
MixSimplexConnection,
SimplexConnection,
)
from protocol.error import PeeringDegreeReached
from protocol.gossip import Gossip, GossipConfig
from protocol.temporalmix import TemporalMixConfig
@dataclass
class NomssipConfig(GossipConfig):
    """
    Settings for a Nomssip channel.

    Extends GossipConfig with the values Nomssip needs to build its mix
    connections and its fixed-size packets.
    """

    # Rate handed to every MixSimplexConnection created in Nomssip.add_conn.
    transmission_rate_per_sec: int
    # Exact payload length in bytes: REAL payloads are checked against it and
    # NOISE payloads are created as this many zero bytes.
    msg_size: int
    # Temporal-mixing settings forwarded to every MixSimplexConnection.
    temporal_mix: TemporalMixConfig
class Nomssip(Gossip):
"""
A NomMix gossip channel that extends the Gossip channel
by adding global transmission rate and noise generation.
"""
def __init__(
self,
framework: Framework,
config: NomssipConfig,
handler: Callable[[bytes], Awaitable[None]],
):
super().__init__(framework, config, handler)
self.config = config
@override
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
noise_packet = FlaggedPacket(
FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
).bytes()
super().add_conn(
inbound,
MixSimplexConnection(
self.framework,
outbound,
self.config.transmission_rate_per_sec,
noise_packet,
self.config.temporal_mix,
),
)
@override
async def process_inbound_msg(self, msg: bytes):
packet = FlaggedPacket.from_bytes(msg)
match packet.flag:
case FlaggedPacket.Flag.NOISE:
# Drop noise packet
return
case FlaggedPacket.Flag.REAL:
await self.__gossip_flagged_packet(packet)
await self.handler(packet.message)
@override
async def gossip(self, msg: bytes):
"""
Gossip a message to all connected peers with prepending a message flag
"""
# The message size must be fixed.
assert len(msg) == self.config.msg_size
packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg)
await self.__gossip_flagged_packet(packet)
async def __gossip_flagged_packet(self, packet: FlaggedPacket):
"""
An internal method to send a flagged packet to all connected peers
"""
await super().gossip(packet.bytes())
class FlaggedPacket:
    """
    A wire packet: a one-byte flag followed by a message payload.

    Serialized form is ``flag.value + message``.
    """

    class Flag(Enum):
        # REAL packets carry an application message.
        REAL = b"\x00"
        # NOISE packets carry a filler payload; Nomssip drops them on receipt.
        NOISE = b"\x01"

    def __init__(self, flag: Flag, message: bytes):
        self.message = message
        self.flag = flag

    def bytes(self) -> bytes:
        """Serialize to wire form: the flag byte, then the message."""
        return b"".join((self.flag.value, self.message))

    @classmethod
    def from_bytes(cls, packet: bytes) -> Self:
        """
        Parse a flagged packet from bytes.

        The first byte selects the Flag; everything after it is the message.

        :raises ValueError: if ``packet`` is empty, or if its first byte is not
            a known Flag value (raised by the Enum value lookup).
        """
        if len(packet) == 0:
            raise ValueError("Invalid message format")
        flag_byte, payload = packet[:1], packet[1:]
        return cls(cls.Flag(flag_byte), payload)