refactor nomssip vs node encapsulation

This commit is contained in:
Youngjoon Lee 2024-07-11 14:29:00 +09:00
parent 953b2d6875
commit ab9943d291
No known key found for this signature in database
GPG Key ID: B4253AFBA618BF4D
2 changed files with 123 additions and 93 deletions

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
from enum import Enum
from typing import TypeAlias
from pysphinx.sphinx import (
@ -12,7 +11,6 @@ from pysphinx.sphinx import (
)
from mixnet.config import GlobalConfig, NodeConfig
from mixnet.connection import DuplexConnection, MixSimplexConnection
from mixnet.nomssip import Nomssip
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
@ -32,44 +30,52 @@ class Node:
nomssip: Nomssip
reconstructor: MessageReconstructor
broadcast_channel: BroadcastChannel
# The actual packet size is calculated based on the max length of mix path by Sphinx encoding
# when the node is initialized, so that it can be used to generate noise packets.
packet_size: int
def __init__(self, config: NodeConfig, global_config: GlobalConfig):
self.config = config
self.global_config = global_config
self.nomssip = Nomssip(config.nomssip, self.__process_msg)
self.nomssip = Nomssip(
Nomssip.Config(
global_config.transmission_rate_per_sec,
config.nomssip.peering_degree,
self.__calculate_message_size(global_config),
),
self.__process_msg,
)
self.reconstructor = MessageReconstructor()
self.broadcast_channel = asyncio.Queue()
@staticmethod
def __calculate_message_size(global_config: GlobalConfig) -> int:
"""
Calculate the actual message size to be gossiped, which depends on the maximum length of mix path.
"""
sample_packet, _ = PacketBuilder.build_real_packets(
bytes(1), global_config.membership, self.global_config.max_mix_path_length
bytes(1), global_config.membership, global_config.max_mix_path_length
)[0]
self.packet_size = len(sample_packet.bytes())
return len(sample_packet.bytes())
async def __process_msg(self, msg: bytes) -> bytes | None:
async def __process_msg(self, msg: bytes) -> None:
"""
A handler to process messages received via gossip channel
"""
flag, msg = Node.__parse_msg(msg)
match flag:
case MsgType.NOISE:
# Drop noise packet
return None
case MsgType.REAL:
# Handle the packet and gossip the result if needed.
sphinx_packet = SphinxPacket.from_bytes(msg)
new_sphinx_packet = await self.__process_sphinx_packet(sphinx_packet)
if new_sphinx_packet is None:
return None
return Node.__build_msg(MsgType.REAL, new_sphinx_packet.bytes())
sphinx_packet = SphinxPacket.from_bytes(msg)
result = await self.__process_sphinx_packet(sphinx_packet)
match result:
case SphinxPacket():
# Gossip the next Sphinx packet
await self.nomssip.gossip(result.bytes())
case bytes():
# Broadcast the message fully recovered from Sphinx packets
await self.broadcast_channel.put(result)
case None:
return
async def __process_sphinx_packet(
self, packet: SphinxPacket
) -> SphinxPacket | None:
) -> SphinxPacket | bytes | None:
"""
Unwrap the Sphinx packet and process the next Sphinx packet or the payload.
Unwrap the Sphinx packet and process the next Sphinx packet or the payload if possible
"""
try:
processed = packet.process(self.config.private_key)
@ -77,14 +83,14 @@ class Node:
case ProcessedForwardHopPacket():
return processed.next_packet
case ProcessedFinalHopPacket():
await self.__process_sphinx_payload(processed.payload)
return await self.__process_sphinx_payload(processed.payload)
except ValueError:
# Return SphinxPacket as it is, if it cannot be unwrapped by the private key of this node.
return packet
# Return nothing, if it cannot be unwrapped by the private key of this node.
return None
async def __process_sphinx_payload(self, payload: Payload):
async def __process_sphinx_payload(self, payload: Payload) -> bytes | None:
"""
Process the Sphinx payload and broadcast it if it is a real message.
Process the Sphinx payload if possible
"""
msg_with_flag = self.reconstructor.add(
Fragment.from_bytes(payload.recover_plain_playload())
@ -92,37 +98,18 @@ class Node:
if msg_with_flag is not None:
flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag)
if flag == MessageFlag.MESSAGE_FLAG_REAL:
await self.broadcast_channel.put(msg)
return msg
return None
def connect(self, peer: Node):
"""
Establish a duplex connection with a peer node.
"""
noise_msg = Node.__build_msg(MsgType.NOISE, bytes(self.packet_size))
inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue()
# Register a duplex connection for its own use
self.nomssip.add_conn(
DuplexConnection(
inbound_conn,
MixSimplexConnection(
outbound_conn,
self.global_config.transmission_rate_per_sec,
noise_msg,
),
)
)
# Register the same duplex connection for the peer
peer.nomssip.add_conn(
DuplexConnection(
outbound_conn,
MixSimplexConnection(
inbound_conn,
self.global_config.transmission_rate_per_sec,
noise_msg,
),
)
)
self.nomssip.add_conn(inbound_conn, outbound_conn)
# Register a duplex connection for the peer
peer.nomssip.add_conn(outbound_conn, inbound_conn)
async def send_message(self, msg: bytes):
"""
@ -135,25 +122,4 @@ class Node:
self.global_config.membership,
self.config.mix_path_length,
):
await self.nomssip.gossip(Node.__build_msg(MsgType.REAL, packet.bytes()))
@staticmethod
def __build_msg(flag: MsgType, data: bytes) -> bytes:
"""
Prepend a flag to the message, right before sending it via network channel.
"""
return flag.value + data
@staticmethod
def __parse_msg(data: bytes) -> tuple[MsgType, bytes]:
"""
Parse the message and extract the flag.
"""
if len(data) < 1:
raise ValueError("Invalid message format")
return (MsgType(data[:1]), data[1:])
class MsgType(Enum):
REAL = b"\x00"
NOISE = b"\x01"
await self.nomssip.gossip(packet.bytes())

View File

@ -1,9 +1,10 @@
import asyncio
import hashlib
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable
from mixnet.config import NomssipConfig
from mixnet.connection import DuplexConnection
from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection
class Nomssip:
@ -12,31 +13,49 @@ class Nomssip:
Peers are connected via DuplexConnection.
"""
config: NomssipConfig
@dataclass
class Config:
transmission_rate_per_sec: int
peering_degree: int
msg_size: int
config: Config
conns: list[DuplexConnection]
# A handler to process inbound messages.
handler: Callable[[bytes], Awaitable[bytes | None]]
# A set of message hashes to prevent processing the same message twice.
msg_cache: set[bytes]
handler: Callable[[bytes], Awaitable[None]]
# A set of packet hashes to prevent gossiping/processing the same packet twice.
packet_cache: set[bytes]
def __init__(
self,
config: NomssipConfig,
handler: Callable[[bytes], Awaitable[bytes | None]],
config: Config,
handler: Callable[[bytes], Awaitable[None]],
):
self.config = config
self.conns = []
self.handler = handler
self.msg_cache = set()
self.packet_cache = set()
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.tasks = set()
def add_conn(self, conn: DuplexConnection):
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
if len(self.conns) >= self.config.peering_degree:
# For simplicity of the spec, reject the connection if the peering degree is reached.
raise ValueError("The peering degree is reached.")
noise_packet = self.__build_packet(
self.PacketType.NOISE, bytes(self.config.msg_size)
)
conn = DuplexConnection(
inbound,
MixSimplexConnection(
outbound,
self.config.transmission_rate_per_sec,
noise_packet,
),
)
self.conns.append(conn)
task = asyncio.create_task(self.__process_inbound_conn(conn))
self.tasks.add(task)
@ -45,17 +64,62 @@ class Nomssip:
async def __process_inbound_conn(self, conn: DuplexConnection):
while True:
msg = await conn.recv()
# Don't process the same message twice.
msg_hash = hashlib.sha256(msg).digest()
if msg_hash in self.msg_cache:
packet = await conn.recv()
if self.__check_update_cache(packet):
continue
self.msg_cache.add(msg_hash)
new_msg = await self.handler(msg)
if new_msg is not None:
await self.gossip(new_msg)
flag, msg = self.__parse_packet(packet)
match flag:
case self.PacketType.NOISE:
# Drop noise packet
continue
case self.PacketType.REAL:
await self.__gossip(packet)
await self.handler(msg)
async def gossip(self, packet: bytes):
async def gossip(self, msg: bytes):
"""
Gossip a message to all connected peers with prepending a message flag
"""
# The message size must be fixed.
assert len(msg) == self.config.msg_size
packet = self.__build_packet(self.PacketType.REAL, msg)
await self.__gossip(packet)
async def __gossip(self, packet: bytes):
"""
An internal method to send a flagged packet to all connected peers
"""
for conn in self.conns:
await conn.send(packet)
def __check_update_cache(self, packet: bytes) -> bool:
"""
Add a message to the cache, and return True if the message was already in the cache.
"""
hash = hashlib.sha256(packet).digest()
if hash in self.packet_cache:
return True
self.packet_cache.add(hash)
return False
class PacketType(Enum):
REAL = b"\x00"
NOISE = b"\x01"
@staticmethod
def __build_packet(flag: PacketType, data: bytes) -> bytes:
"""
Prepend a flag to the message, right before sending it via network channel.
"""
return flag.value + data
@staticmethod
def __parse_packet(data: bytes) -> tuple[PacketType, bytes]:
"""
Parse the message and extract the flag.
"""
if len(data) < 1:
raise ValueError("Invalid message format")
return (Nomssip.PacketType(data[:1]), data[1:])