From 4aa1ee06fb8a8a3f7bb57f4d9846a05fb665cece Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Fri, 28 Jun 2024 18:06:55 +0900 Subject: [PATCH] sphinx serde --- mixnet/node.py | 75 +++++++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/mixnet/node.py b/mixnet/node.py index e914a68..f15056f 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import hashlib +from enum import Enum from typing import Awaitable, Callable, TypeAlias from pysphinx.payload import DEFAULT_PAYLOAD_SIZE @@ -15,8 +16,7 @@ from pysphinx.sphinx import ( from mixnet.config import GlobalConfig, NodeConfig from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder -NetworkPacket: TypeAlias = SphinxPacket | bytes -NetworkPacketQueue: TypeAlias = asyncio.Queue[NetworkPacket] +NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes] Connection: TypeAlias = NetworkPacketQueue BroadcastChannel: TypeAlias = asyncio.Queue[bytes] @@ -39,7 +39,7 @@ class Node: async def __process_sphinx_packet( self, packet: SphinxPacket - ) -> NetworkPacket | None: + ) -> SphinxPacket | None: try: processed = packet.process(self.config.private_key) match processed: @@ -83,19 +83,19 @@ class Node: for packet, _ in PacketBuilder.build_real_packets( msg, self.global_config.membership ): - await self.mixgossip_channel.gossip(packet) + await self.mixgossip_channel.gossip(build_msg(MsgType.REAL, packet.bytes())) class MixGossipChannel: peering_degree: int conns: list[DuplexConnection] - handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]] - msg_cache: set[NetworkPacket] + handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]] + msg_cache: set[bytes] def __init__( self, peer_degree: int, - handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]], + handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]], ): self.peering_degree = peer_degree self.conns = [] @@ -118,26 +118,28 @@ class MixGossipChannel: async def __process_inbound_conn(self, conn: DuplexConnection): while True: - elem = await conn.recv() - # In practice, data transmitted through connections is going to be always 'bytes'. - # But here, we use the SphinxPacket type explicitly for simplicity - # without implementing serde for SphinxPacket. - if isinstance(elem, bytes): - assert elem == build_noise_packet() - # Drop packet + msg = await conn.recv() + # Don't process the same message twice. + msg_hash = hashlib.sha256(msg).digest() + if msg_hash in self.msg_cache: continue - elif isinstance(elem, SphinxPacket): - # Don't process the same message twice. - msg_hash = hashlib.sha256(elem.bytes()).digest() - if msg_hash in self.msg_cache: - continue - self.msg_cache.add(msg_hash) - # Handle the packet and gossip the result if needed. - net_packet = await self.handler(elem) - if net_packet is not None: - await self.gossip(net_packet) + self.msg_cache.add(msg_hash) - async def gossip(self, packet: NetworkPacket): + flag, msg = parse_msg(msg) + match flag: + case MsgType.NOISE: + # Drop noise packet + continue + case MsgType.REAL: + # Handle the packet and gossip the result if needed. + sphinx_packet = SphinxPacket.from_bytes(msg) + new_sphinx_packet = await self.handler(sphinx_packet) + if new_sphinx_packet is not None: + await self.gossip( + build_msg(MsgType.REAL, new_sphinx_packet.bytes()) + ) + + async def gossip(self, packet: bytes): for conn in self.conns: await conn.send(packet) @@ -150,10 +152,10 @@ class DuplexConnection: self.inbound = inbound self.outbound = outbound - async def recv(self) -> NetworkPacket: + async def recv(self) -> bytes: return await self.inbound.get() - async def send(self, packet: NetworkPacket): + async def send(self, packet: bytes): await self.outbound.send(packet) @@ -178,9 +180,24 @@ class MixOutboundConnection: elem = self.queue.get_nowait() await self.conn.put(elem) - async def send(self, elem: NetworkPacket): + async def send(self, elem: bytes): await self.queue.put(elem) +class MsgType(Enum): + REAL = b"\x00" + NOISE = b"\x01" + + +def build_msg(flag: MsgType, data: bytes) -> bytes: + return flag.value + data + + +def parse_msg(data: bytes) -> tuple[MsgType, bytes]: + if len(data) < 1: + raise ValueError("Invalid message format") + return (MsgType(data[:1]), data[1:]) + + def build_noise_packet() -> bytes: - return bytes(DEFAULT_PAYLOAD_SIZE) + return build_msg(MsgType.NOISE, bytes(DEFAULT_PAYLOAD_SIZE))