sphinx serde
This commit is contained in:
parent
3a703434da
commit
4aa1ee06fb
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import hashlib
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, TypeAlias
|
||||
|
||||
from pysphinx.payload import DEFAULT_PAYLOAD_SIZE
|
||||
|
@ -15,8 +16,7 @@ from pysphinx.sphinx import (
|
|||
from mixnet.config import GlobalConfig, NodeConfig
|
||||
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
|
||||
|
||||
NetworkPacket: TypeAlias = SphinxPacket | bytes
|
||||
NetworkPacketQueue: TypeAlias = asyncio.Queue[NetworkPacket]
|
||||
NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes]
|
||||
Connection: TypeAlias = NetworkPacketQueue
|
||||
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
|
||||
|
||||
|
@ -39,7 +39,7 @@ class Node:
|
|||
|
||||
async def __process_sphinx_packet(
|
||||
self, packet: SphinxPacket
|
||||
) -> NetworkPacket | None:
|
||||
) -> SphinxPacket | None:
|
||||
try:
|
||||
processed = packet.process(self.config.private_key)
|
||||
match processed:
|
||||
|
@ -83,19 +83,19 @@ class Node:
|
|||
for packet, _ in PacketBuilder.build_real_packets(
|
||||
msg, self.global_config.membership
|
||||
):
|
||||
await self.mixgossip_channel.gossip(packet)
|
||||
await self.mixgossip_channel.gossip(build_msg(MsgType.REAL, packet.bytes()))
|
||||
|
||||
|
||||
class MixGossipChannel:
|
||||
peering_degree: int
|
||||
conns: list[DuplexConnection]
|
||||
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]]
|
||||
msg_cache: set[NetworkPacket]
|
||||
handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]]
|
||||
msg_cache: set[bytes]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
peer_degree: int,
|
||||
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]],
|
||||
handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]],
|
||||
):
|
||||
self.peering_degree = peer_degree
|
||||
self.conns = []
|
||||
|
@ -118,26 +118,28 @@ class MixGossipChannel:
|
|||
|
||||
async def __process_inbound_conn(self, conn: DuplexConnection):
|
||||
while True:
|
||||
elem = await conn.recv()
|
||||
# In practice, data transmitted through connections is going to be always 'bytes'.
|
||||
# But here, we use the SphinxPacket type explicitly for simplicity
|
||||
# without implementing serde for SphinxPacket.
|
||||
if isinstance(elem, bytes):
|
||||
assert elem == build_noise_packet()
|
||||
# Drop packet
|
||||
msg = await conn.recv()
|
||||
# Don't process the same message twice.
|
||||
msg_hash = hashlib.sha256(msg).digest()
|
||||
if msg_hash in self.msg_cache:
|
||||
continue
|
||||
elif isinstance(elem, SphinxPacket):
|
||||
# Don't process the same message twice.
|
||||
msg_hash = hashlib.sha256(elem.bytes()).digest()
|
||||
if msg_hash in self.msg_cache:
|
||||
continue
|
||||
self.msg_cache.add(msg_hash)
|
||||
# Handle the packet and gossip the result if needed.
|
||||
net_packet = await self.handler(elem)
|
||||
if net_packet is not None:
|
||||
await self.gossip(net_packet)
|
||||
self.msg_cache.add(msg_hash)
|
||||
|
||||
async def gossip(self, packet: NetworkPacket):
|
||||
flag, msg = parse_msg(msg)
|
||||
match flag:
|
||||
case MsgType.NOISE:
|
||||
# Drop noise packet
|
||||
continue
|
||||
case MsgType.REAL:
|
||||
# Handle the packet and gossip the result if needed.
|
||||
sphinx_packet = SphinxPacket.from_bytes(msg)
|
||||
new_sphinx_packet = await self.handler(sphinx_packet)
|
||||
if new_sphinx_packet is not None:
|
||||
await self.gossip(
|
||||
build_msg(MsgType.REAL, new_sphinx_packet.bytes())
|
||||
)
|
||||
|
||||
async def gossip(self, packet: bytes):
|
||||
for conn in self.conns:
|
||||
await conn.send(packet)
|
||||
|
||||
|
@ -150,10 +152,10 @@ class DuplexConnection:
|
|||
self.inbound = inbound
|
||||
self.outbound = outbound
|
||||
|
||||
async def recv(self) -> NetworkPacket:
|
||||
async def recv(self) -> bytes:
|
||||
return await self.inbound.get()
|
||||
|
||||
async def send(self, packet: NetworkPacket):
|
||||
async def send(self, packet: bytes):
|
||||
await self.outbound.send(packet)
|
||||
|
||||
|
||||
|
@ -178,9 +180,24 @@ class MixOutboundConnection:
|
|||
elem = self.queue.get_nowait()
|
||||
await self.conn.put(elem)
|
||||
|
||||
async def send(self, elem: NetworkPacket):
|
||||
async def send(self, elem: bytes):
|
||||
await self.queue.put(elem)
|
||||
|
||||
|
||||
class MsgType(Enum):
|
||||
REAL = b"\x00"
|
||||
NOISE = b"\x01"
|
||||
|
||||
|
||||
def build_msg(flag: MsgType, data: bytes) -> bytes:
|
||||
return flag.value + data
|
||||
|
||||
|
||||
def parse_msg(data: bytes) -> tuple[MsgType, bytes]:
|
||||
if len(data) < 1:
|
||||
raise ValueError("Invalid message format")
|
||||
return (MsgType(data[:1]), data[1:])
|
||||
|
||||
|
||||
def build_noise_packet() -> bytes:
|
||||
return bytes(DEFAULT_PAYLOAD_SIZE)
|
||||
return build_msg(MsgType.NOISE, bytes(DEFAULT_PAYLOAD_SIZE))
|
||||
|
|
Loading…
Reference in New Issue