diff --git a/mixnet/node.py b/mixnet/node.py index 994a41c..00b213b 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -5,7 +5,6 @@ import hashlib from enum import Enum from typing import Awaitable, Callable, TypeAlias -from pysphinx.payload import DEFAULT_PAYLOAD_SIZE from pysphinx.sphinx import ( Payload, ProcessedFinalHopPacket, @@ -27,6 +26,7 @@ class Node: mixgossip_channel: MixGossipChannel reconstructor: MessageReconstructor broadcast_channel: BroadcastChannel + packet_size: int def __init__(self, config: NodeConfig, global_config: GlobalConfig): self.config = config @@ -37,6 +37,11 @@ class Node: self.reconstructor = MessageReconstructor() self.broadcast_channel = asyncio.Queue() + sample_packet, _ = PacketBuilder.build_real_packets( + bytes(1), global_config.membership + )[0] + self.packet_size = len(sample_packet.bytes()) + async def __process_sphinx_packet( self, packet: SphinxPacket ) -> SphinxPacket | None: @@ -61,12 +66,15 @@ class Node: await self.broadcast_channel.put(msg) def connect(self, peer: Node): + noise_msg = build_msg(MsgType.NOISE, bytes(self.packet_size)) inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue() self.mixgossip_channel.add_conn( DuplexConnection( inbound_conn, MixOutboundConnection( - outbound_conn, self.global_config.transmission_rate_per_sec + outbound_conn, + self.global_config.transmission_rate_per_sec, + noise_msg, ), ) ) @@ -74,7 +82,9 @@ class Node: DuplexConnection( outbound_conn, MixOutboundConnection( - inbound_conn, self.global_config.transmission_rate_per_sec + inbound_conn, + self.global_config.transmission_rate_per_sec, + noise_msg, ), ) ) @@ -163,11 +173,15 @@ class MixOutboundConnection: queue: NetworkPacketQueue conn: Connection transmission_rate_per_sec: int + noise_msg: bytes - def __init__(self, conn: Connection, transmission_rate_per_sec: int): + def __init__( + self, conn: Connection, transmission_rate_per_sec: int, noise_msg: bytes + ): self.queue = asyncio.Queue() self.conn = conn self.transmission_rate_per_sec = transmission_rate_per_sec + self.noise_msg = noise_msg self.task = asyncio.create_task(self.__run()) async def __run(self): @@ -175,7 +189,7 @@ class MixOutboundConnection: await asyncio.sleep(1 / self.transmission_rate_per_sec) # TODO: time mixing if self.queue.empty(): - elem = build_noise_packet() + elem = self.noise_msg else: elem = self.queue.get_nowait() await self.conn.put(elem) @@ -197,7 +211,3 @@ def parse_msg(data: bytes) -> tuple[MsgType, bytes]: if len(data) < 1: raise ValueError("Invalid message format") return (MsgType(data[:1]), data[1:]) - - -def build_noise_packet() -> bytes: - return build_msg(MsgType.NOISE, bytes(DEFAULT_PAYLOAD_SIZE))