sphinx serde

This commit is contained in:
Youngjoon Lee 2024-06-28 18:06:55 +09:00
parent 3a703434da
commit 4aa1ee06fb
No known key found for this signature in database
GPG Key ID: B4253AFBA618BF4D

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import hashlib import hashlib
from enum import Enum
from typing import Awaitable, Callable, TypeAlias from typing import Awaitable, Callable, TypeAlias
from pysphinx.payload import DEFAULT_PAYLOAD_SIZE from pysphinx.payload import DEFAULT_PAYLOAD_SIZE
@ -15,8 +16,7 @@ from pysphinx.sphinx import (
from mixnet.config import GlobalConfig, NodeConfig from mixnet.config import GlobalConfig, NodeConfig
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
NetworkPacket: TypeAlias = SphinxPacket | bytes NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes]
NetworkPacketQueue: TypeAlias = asyncio.Queue[NetworkPacket]
Connection: TypeAlias = NetworkPacketQueue Connection: TypeAlias = NetworkPacketQueue
BroadcastChannel: TypeAlias = asyncio.Queue[bytes] BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
@ -39,7 +39,7 @@ class Node:
async def __process_sphinx_packet( async def __process_sphinx_packet(
self, packet: SphinxPacket self, packet: SphinxPacket
) -> NetworkPacket | None: ) -> SphinxPacket | None:
try: try:
processed = packet.process(self.config.private_key) processed = packet.process(self.config.private_key)
match processed: match processed:
@ -83,19 +83,19 @@ class Node:
for packet, _ in PacketBuilder.build_real_packets( for packet, _ in PacketBuilder.build_real_packets(
msg, self.global_config.membership msg, self.global_config.membership
): ):
await self.mixgossip_channel.gossip(packet) await self.mixgossip_channel.gossip(build_msg(MsgType.REAL, packet.bytes()))
class MixGossipChannel: class MixGossipChannel:
peering_degree: int peering_degree: int
conns: list[DuplexConnection] conns: list[DuplexConnection]
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]] handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]]
msg_cache: set[NetworkPacket] msg_cache: set[bytes]
def __init__( def __init__(
self, self,
peer_degree: int, peer_degree: int,
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]], handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]],
): ):
self.peering_degree = peer_degree self.peering_degree = peer_degree
self.conns = [] self.conns = []
@ -118,26 +118,28 @@ class MixGossipChannel:
async def __process_inbound_conn(self, conn: DuplexConnection): async def __process_inbound_conn(self, conn: DuplexConnection):
while True: while True:
elem = await conn.recv() msg = await conn.recv()
# In practice, data transmitted through connections is going to be always 'bytes'. # Don't process the same message twice.
# But here, we use the SphinxPacket type explicitly for simplicity msg_hash = hashlib.sha256(msg).digest()
# without implementing serde for SphinxPacket. if msg_hash in self.msg_cache:
if isinstance(elem, bytes):
assert elem == build_noise_packet()
# Drop packet
continue continue
elif isinstance(elem, SphinxPacket): self.msg_cache.add(msg_hash)
# Don't process the same message twice.
msg_hash = hashlib.sha256(elem.bytes()).digest()
if msg_hash in self.msg_cache:
continue
self.msg_cache.add(msg_hash)
# Handle the packet and gossip the result if needed.
net_packet = await self.handler(elem)
if net_packet is not None:
await self.gossip(net_packet)
async def gossip(self, packet: NetworkPacket): flag, msg = parse_msg(msg)
match flag:
case MsgType.NOISE:
# Drop noise packet
continue
case MsgType.REAL:
# Handle the packet and gossip the result if needed.
sphinx_packet = SphinxPacket.from_bytes(msg)
new_sphinx_packet = await self.handler(sphinx_packet)
if new_sphinx_packet is not None:
await self.gossip(
build_msg(MsgType.REAL, new_sphinx_packet.bytes())
)
async def gossip(self, packet: bytes):
for conn in self.conns: for conn in self.conns:
await conn.send(packet) await conn.send(packet)
@ -150,10 +152,10 @@ class DuplexConnection:
self.inbound = inbound self.inbound = inbound
self.outbound = outbound self.outbound = outbound
async def recv(self) -> NetworkPacket: async def recv(self) -> bytes:
return await self.inbound.get() return await self.inbound.get()
async def send(self, packet: NetworkPacket): async def send(self, packet: bytes):
await self.outbound.send(packet) await self.outbound.send(packet)
@ -178,9 +180,24 @@ class MixOutboundConnection:
elem = self.queue.get_nowait() elem = self.queue.get_nowait()
await self.conn.put(elem) await self.conn.put(elem)
async def send(self, elem: NetworkPacket): async def send(self, elem: bytes):
await self.queue.put(elem) await self.queue.put(elem)
class MsgType(Enum):
REAL = b"\x00"
NOISE = b"\x01"
def build_msg(flag: MsgType, data: bytes) -> bytes:
return flag.value + data
def parse_msg(data: bytes) -> tuple[MsgType, bytes]:
if len(data) < 1:
raise ValueError("Invalid message format")
return (MsgType(data[:1]), data[1:])
def build_noise_packet() -> bytes: def build_noise_packet() -> bytes:
return bytes(DEFAULT_PAYLOAD_SIZE) return build_msg(MsgType.NOISE, bytes(DEFAULT_PAYLOAD_SIZE))