optimize: use generic for messages to reduce the size of msg cache in gossip and discard serde cost

This commit is contained in:
Youngjoon Lee 2024-08-15 02:40:06 +09:00
parent 3a2f3cc079
commit dbf1b78134
No known key found for this signature in database
GPG Key ID: 167546E2D1712F8C
15 changed files with 352 additions and 237 deletions

View File

@ -1,59 +1,62 @@
from __future__ import annotations
import abc
from typing import Generic, TypeVar
from framework import Framework, Queue
from protocol.temporalmix import TemporalMix, TemporalMixConfig
T = TypeVar("T")
class SimplexConnection(abc.ABC):
class SimplexConnection(abc.ABC, Generic[T]):
"""
An abstract class for a simplex connection that can send and receive data in one direction
"""
@abc.abstractmethod
async def send(self, data: bytes) -> None:
async def send(self, data: T) -> None:
pass
@abc.abstractmethod
async def recv(self) -> bytes:
async def recv(self) -> T:
pass
class LocalSimplexConnection(SimplexConnection):
class LocalSimplexConnection(SimplexConnection[T]):
"""
A simplex connection that doesn't have any network latency.
Data sent through this connection can be immediately received from the other end.
"""
def __init__(self, framework: Framework):
self.queue: Queue[bytes] = framework.queue()
self.queue: Queue[T] = framework.queue()
async def send(self, data: bytes) -> None:
async def send(self, data: T) -> None:
await self.queue.put(data)
async def recv(self) -> bytes:
async def recv(self) -> T:
return await self.queue.get()
class DuplexConnection:
class DuplexConnection(Generic[T]):
"""
A duplex connection in which data can be transmitted and received simultaneously in both directions.
This is to mimic duplex communication in a real network (such as TCP or QUIC).
"""
def __init__(self, inbound: SimplexConnection, outbound: SimplexConnection):
def __init__(self, inbound: SimplexConnection[T], outbound: SimplexConnection[T]):
self.inbound = inbound
self.outbound = outbound
async def recv(self) -> bytes:
async def recv(self) -> T:
return await self.inbound.recv()
async def send(self, packet: bytes):
async def send(self, packet: T):
await self.outbound.send(packet)
class MixSimplexConnection(SimplexConnection):
class MixSimplexConnection(SimplexConnection[T]):
"""
Wraps a SimplexConnection to add a transmission rate and noise to the connection.
"""
@ -61,16 +64,16 @@ class MixSimplexConnection(SimplexConnection):
def __init__(
self,
framework: Framework,
conn: SimplexConnection,
conn: SimplexConnection[T],
transmission_rate_per_sec: int,
noise_msg: bytes,
noise_msg: T,
temporal_mix_config: TemporalMixConfig,
# OPTIMIZATION ONLY FOR EXPERIMENTS WITHOUT BANDWIDTH MEASUREMENT
# If True, skip sending a noise even if it's time to send one.
skip_sending_noise: bool,
):
self.framework = framework
self.queue: Queue[bytes] = TemporalMix.queue(
self.queue: Queue[T] = TemporalMix.queue(
temporal_mix_config, framework, noise_msg
)
self.conn = conn
@ -87,8 +90,8 @@ class MixSimplexConnection(SimplexConnection):
continue
await self.conn.send(msg)
async def send(self, data: bytes) -> None:
async def send(self, data: T) -> None:
await self.queue.put(data)
async def recv(self) -> bytes:
async def recv(self) -> T:
return await self.conn.recv()

View File

@ -1,8 +1,7 @@
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Generic, Protocol, TypeVar
from framework import Framework
from protocol.connection import (
@ -18,7 +17,14 @@ class GossipConfig:
peering_degree: int
class Gossip:
class HasId(Protocol):
def id(self) -> int: ...
T = TypeVar("T", bound=HasId)
class Gossip(Generic[T]):
"""
A gossip channel that broadcasts messages to all connected peers.
Peers are connected via DuplexConnection.
@ -28,15 +34,15 @@ class Gossip:
self,
framework: Framework,
config: GossipConfig,
handler: Callable[[bytes], Awaitable[None]],
handler: Callable[[T], Awaitable[None]],
):
self.framework = framework
self.config = config
self.conns: list[DuplexConnection] = []
self.conns: list[DuplexConnection[T]] = []
# A handler to process inbound messages.
self.handler = handler
# msg -> received_cnt
self.packet_cache: dict[bytes, int] = dict()
# msg_id -> received_cnt
self.packet_cache: dict[int, int] = dict()
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.tasks: set[Awaitable] = set()
@ -44,12 +50,12 @@ class Gossip:
def can_accept_conn(self) -> bool:
return len(self.conns) < self.config.peering_degree
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
def add_conn(self, inbound: SimplexConnection[T], outbound: SimplexConnection[T]):
if not self.can_accept_conn():
# For simplicity of the spec, reject the connection if the peering degree is reached.
raise PeeringDegreeReached()
conn = DuplexConnection(
conn = DuplexConnection[T](
inbound,
outbound,
)
@ -57,18 +63,18 @@ class Gossip:
task = self.framework.spawn(self.__process_inbound_conn(conn))
self.tasks.add(task)
async def __process_inbound_conn(self, conn: DuplexConnection):
async def __process_inbound_conn(self, conn: DuplexConnection[T]):
while True:
msg = await conn.recv()
if self._check_update_cache(msg):
continue
await self._process_inbound_msg(msg, conn)
async def _process_inbound_msg(self, msg: bytes, received_from: DuplexConnection):
async def _process_inbound_msg(self, msg: T, received_from: DuplexConnection[T]):
await self._gossip(msg, [received_from])
await self.handler(msg)
async def publish(self, msg: bytes):
async def publish(self, msg: T):
"""
Publish a message to all nodes in the network.
"""
@ -83,7 +89,7 @@ class Gossip:
# which means that we consider that this publisher node received the message.
await self.handler(msg)
async def _gossip(self, msg: bytes, excludes: list[DuplexConnection] = []):
async def _gossip(self, msg: T, excludes: list[DuplexConnection] = []):
"""
Gossip a message to all peers connected to this node.
"""
@ -91,26 +97,26 @@ class Gossip:
if conn not in excludes:
await conn.send(msg)
def _check_update_cache(self, packet: bytes, publishing: bool = False) -> bool:
def _check_update_cache(self, msg: T, publishing: bool = False) -> bool:
"""
Add a message to the cache, and return True if the message was already in the cache.
"""
hash = hashlib.sha256(packet).digest()
seen = hash in self.packet_cache
id = msg.id()
seen = id in self.packet_cache
if publishing:
if not seen:
# Put 0 when publishing, so that the publisher node doesn't gossip the message again
# even when it first receive the message from one of its peers later.
self.packet_cache[hash] = 0
self.packet_cache[id] = 0
else:
if not seen:
self.packet_cache[hash] = 1
self.packet_cache[id] = 1
else:
self.packet_cache[hash] += 1
self.packet_cache[id] += 1
# Remove the message from the cache if it's received from all adjacent peers in the end
# to reduce the size of cache.
if self.packet_cache[hash] >= self.config.peering_degree:
del self.packet_cache[hash]
if self.packet_cache[id] >= self.config.peering_degree:
del self.packet_cache[id]
return seen

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Generic, Protocol, Self, Type, TypeVar
from pysphinx.sphinx import (
ProcessedFinalHopPacket,
@ -13,11 +13,22 @@ from protocol.config import GlobalConfig, NodeConfig
from protocol.connection import SimplexConnection
from protocol.error import PeeringDegreeReached
from protocol.gossip import Gossip
from protocol.nomssip import Nomssip, NomssipConfig
from protocol.nomssip import Nomssip, NomssipConfig, NomssipMessage
from protocol.sphinx import SphinxPacketBuilder
class Node:
class HasIdAndLenAndBytes(Protocol):
def id(self) -> int: ...
def __len__(self) -> int: ...
def __bytes__(self) -> bytes: ...
@classmethod
def from_bytes(cls, data: bytes) -> Self: ...
T = TypeVar("T", bound=HasIdAndLenAndBytes)
class Node(Generic[T]):
"""
This represents any node in the network, which:
- generates/gossips mix messages (Sphinx packets)
@ -31,57 +42,53 @@ class Node:
config: NodeConfig,
global_config: GlobalConfig,
# A handler called when a node receives a broadcasted message originated from the last mix.
broadcasted_msg_handler: Callable[[bytes], Awaitable[None]],
# An optional handler only for the simulation,
# which is called when a message is fully recovered by the last mix
broadcasted_msg_handler: Callable[[T], Awaitable[None]],
# A handler called when a message is fully recovered by the last mix
# and returns a new message to be broadcasted.
recovered_msg_handler: Callable[[bytes], Awaitable[bytes]] | None = None,
recovered_msg_handler: Callable[[bytes], Awaitable[T]],
noise_msg: T,
):
self.framework = framework
self.config = config
self.global_config = global_config
nomssip_config = NomssipConfig(
config.gossip.peering_degree,
global_config.transmission_rate_per_sec,
SphinxPacketBuilder.size(global_config),
config.temporal_mix,
)
self.nomssip = Nomssip(
framework,
NomssipConfig(
config.gossip.peering_degree,
global_config.transmission_rate_per_sec,
self.__calculate_message_size(global_config),
config.temporal_mix,
),
nomssip_config,
self.__process_msg,
noise_msg=NomssipMessage[T](NomssipMessage.Flag.NOISE, noise_msg),
)
self.broadcast = Gossip(framework, config.gossip, broadcasted_msg_handler)
self.broadcast = Gossip[T](framework, config.gossip, broadcasted_msg_handler)
self.recovered_msg_handler = recovered_msg_handler
@staticmethod
def __calculate_message_size(global_config: GlobalConfig) -> int:
"""
Calculate the actual message size to be gossiped, which depends on the maximum length of mix path.
"""
sample_sphinx_packet, _ = SphinxPacketBuilder.build(
bytes(global_config.max_message_size),
global_config,
global_config.max_mix_path_length,
)
return len(sample_sphinx_packet.bytes())
async def __process_msg(self, msg: bytes) -> None:
async def __process_msg(self, msg: NomssipMessage[T]) -> None:
"""
A handler to process messages received via Nomssip channel
"""
assert msg.flag == NomssipMessage.Flag.REAL
sphinx_packet = SphinxPacket.from_bytes(
msg, self.global_config.max_mix_path_length
bytes(msg.message), self.global_config.max_mix_path_length
)
result = await self.__process_sphinx_packet(sphinx_packet)
match result:
case SphinxPacket():
# Gossip the next Sphinx packet
await self.nomssip.publish(result.bytes())
t: Type[T] = type(msg.message)
await self.nomssip.publish(
NomssipMessage[T](
NomssipMessage.Flag.REAL,
t.from_bytes(result.bytes()),
)
)
case bytes():
if self.recovered_msg_handler is not None:
result = await self.recovered_msg_handler(result)
# Broadcast the message fully recovered from Sphinx packets
await self.broadcast.publish(result)
await self.broadcast.publish(await self.recovered_msg_handler(result))
case None:
return
@ -105,31 +112,36 @@ class Node:
def connect_mix(
self,
peer: Node,
inbound_conn: SimplexConnection,
outbound_conn: SimplexConnection,
inbound_conn: SimplexConnection[NomssipMessage[T]],
outbound_conn: SimplexConnection[NomssipMessage[T]],
):
connect_nodes(self.nomssip, peer.nomssip, inbound_conn, outbound_conn)
def connect_broadcast(
self,
peer: Node,
inbound_conn: SimplexConnection,
outbound_conn: SimplexConnection,
inbound_conn: SimplexConnection[T],
outbound_conn: SimplexConnection[T],
):
connect_nodes(self.broadcast, peer.broadcast, inbound_conn, outbound_conn)
async def send_message(self, msg: bytes):
async def send_message(self, msg: T):
"""
Build a Sphinx packet and gossip it to all connected peers.
"""
# Here, we handle the case in which a msg is split into multiple Sphinx packets.
# But, in practice, we expect a message to be small enough to fit in a single Sphinx packet.
sphinx_packet, _ = SphinxPacketBuilder.build(
msg,
bytes(msg),
self.global_config,
self.config.mix_path_length,
)
await self.nomssip.publish(sphinx_packet.bytes())
t: Type[T] = type(msg)
await self.nomssip.publish(
NomssipMessage(
NomssipMessage.Flag.REAL, t.from_bytes(sphinx_packet.bytes())
)
)
def connect_nodes(

View File

@ -1,8 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Self, override
from typing import Awaitable, Callable, Generic, Protocol, TypeVar, override
from framework import Framework
from protocol.connection import (
@ -24,7 +22,31 @@ class NomssipConfig(GossipConfig):
skip_sending_noise: bool = False
class Nomssip(Gossip):
class HasIdAndLen(Protocol):
def id(self) -> int: ...
def __len__(self) -> int: ...
T = TypeVar("T", bound=HasIdAndLen)
class NomssipMessage(Generic[T]):
class Flag(Enum):
REAL = b"\x00"
NOISE = b"\x01"
def __init__(self, flag: Flag, message: T):
self.flag = flag
self.message = message
def id(self) -> int:
return self.message.id()
def __len__(self) -> int:
return len(self.flag.value) + len(self.message)
class Nomssip(Gossip[NomssipMessage[T]]):
"""
A NomMix gossip channel that extends the Gossip channel
by adding global transmission rate and noise generation.
@ -34,72 +56,53 @@ class Nomssip(Gossip):
self,
framework: Framework,
config: NomssipConfig,
handler: Callable[[bytes], Awaitable[None]],
handler: Callable[[NomssipMessage[T]], Awaitable[None]],
noise_msg: NomssipMessage[T],
):
super().__init__(framework, config, handler)
self.config = config
self.noise_msg = noise_msg
@override
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
noise_packet = FlaggedPacket(
FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
).bytes()
def add_conn(
self,
inbound: SimplexConnection[NomssipMessage[T]],
outbound: SimplexConnection[NomssipMessage[T]],
):
super().add_conn(
inbound,
MixSimplexConnection(
MixSimplexConnection[NomssipMessage[T]](
self.framework,
outbound,
self.config.transmission_rate_per_sec,
noise_packet,
self.noise_msg,
self.config.temporal_mix,
self.config.skip_sending_noise,
),
)
@override
async def _process_inbound_msg(self, msg: bytes, received_from: DuplexConnection):
packet = FlaggedPacket.from_bytes(msg)
match packet.flag:
case FlaggedPacket.Flag.NOISE:
async def _process_inbound_msg(
self, msg: NomssipMessage[T], received_from: DuplexConnection
):
match msg.flag:
case NomssipMessage.Flag.NOISE:
# Drop noise packet
return
case FlaggedPacket.Flag.REAL:
self.assert_message_size(packet.message)
case NomssipMessage.Flag.REAL:
self.assert_message_size(msg.message)
await super()._gossip(msg, [received_from])
await self.handler(packet.message)
await self.handler(msg)
@override
async def publish(self, msg: bytes):
self.assert_message_size(msg)
async def publish(self, msg: NomssipMessage[T]):
self.assert_message_size(msg.message)
packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg).bytes()
# Please see comments in super().publish() for the reason of the following line.
if not self._check_update_cache(packet, publishing=True):
await self._gossip(packet)
if not self._check_update_cache(msg, publishing=True):
await self._gossip(msg)
await self.handler(msg)
def assert_message_size(self, msg: bytes):
def assert_message_size(self, msg: T):
# The message size must be fixed.
assert len(msg) == self.config.msg_size, f"{len(msg)} != {self.config.msg_size}"
class FlaggedPacket:
class Flag(Enum):
REAL = b"\x00"
NOISE = b"\x01"
def __init__(self, flag: Flag, message: bytes):
self.flag = flag
self.message = message
def bytes(self) -> bytes:
return self.flag.value + self.message
@classmethod
def from_bytes(cls, packet: bytes) -> Self:
"""
Parse a flagged packet from bytes
"""
if len(packet) < 1:
raise ValueError("Invalid message format")
return cls(cls.Flag(packet[:1]), packet[1:])

View File

@ -31,3 +31,15 @@ class SphinxPacketBuilder:
max_plain_payload_size=global_config.max_message_size,
)
return (packet, route)
@staticmethod
def size(global_config: GlobalConfig) -> int:
"""
Calculate the size of Sphinx packet, which depends on the maximum length of mix path.
"""
sample_sphinx_packet, _ = SphinxPacketBuilder.build(
bytes(global_config.max_message_size),
global_config,
global_config.max_mix_path_length,
)
return len(sample_sphinx_packet.bytes())

View File

@ -1,9 +1,13 @@
import hashlib
from dataclasses import dataclass
from typing import Self
from unittest import IsolatedAsyncioTestCase
import framework.asyncio as asynciofw
from framework.framework import Queue
from protocol.connection import LocalSimplexConnection
from protocol.node import Node
from protocol.nomssip import NomssipMessage
from protocol.test_utils import (
init_mixnet_config,
)
@ -14,31 +18,42 @@ class TestNode(IsolatedAsyncioTestCase):
framework = asynciofw.Framework()
global_config, node_configs, _ = init_mixnet_config(10)
queue: Queue[bytes] = framework.queue()
queue: Queue[Message] = framework.queue()
async def broadcasted_msg_handler(msg: bytes) -> None:
async def broadcasted_msg_handler(msg: Message) -> None:
await queue.put(msg)
async def recovered_msg_handler(msg: bytes) -> Message:
return Message(msg)
nodes = [
Node(framework, node_config, global_config, broadcasted_msg_handler)
Node[Message](
framework,
node_config,
global_config,
broadcasted_msg_handler,
recovered_msg_handler,
noise_msg=Message(b""),
)
for node_config in node_configs
]
for i, node in enumerate(nodes):
try:
node.connect_mix(
nodes[(i + 1) % len(nodes)],
LocalSimplexConnection(framework),
LocalSimplexConnection(framework),
LocalSimplexConnection[NomssipMessage[Message]](framework),
LocalSimplexConnection[NomssipMessage[Message]](framework),
)
node.connect_broadcast(
nodes[(i + 1) % len(nodes)],
LocalSimplexConnection(framework),
LocalSimplexConnection(framework),
LocalSimplexConnection[Message](framework),
LocalSimplexConnection[Message](framework),
)
except ValueError as e:
print(e)
await nodes[0].send_message(b"block selection")
msg = Message(b"block selection")
await nodes[0].send_message(msg)
# Wait for all nodes to receive the broadcast
num_nodes_received_broadcast = 0
@ -47,7 +62,7 @@ class TestNode(IsolatedAsyncioTestCase):
await framework.sleep(1)
while not queue.empty():
self.assertEqual(b"block selection", await queue.get())
self.assertEqual(msg, await queue.get())
num_nodes_received_broadcast += 1
if num_nodes_received_broadcast == len(nodes):
@ -56,3 +71,21 @@ class TestNode(IsolatedAsyncioTestCase):
self.assertEqual(len(nodes), num_nodes_received_broadcast)
# TODO: check noise
@dataclass
class Message:
data: bytes
def id(self) -> int:
return int.from_bytes(hashlib.sha256(self.data).digest(), byteorder="big")
def __len__(self) -> int:
return len(self.data)
def __bytes__(self) -> bytes:
return self.data
@classmethod
def from_bytes(cls, data: bytes) -> Self:
return cls(data)

View File

@ -0,0 +1,30 @@
from dataclasses import dataclass
from framework.framework import Framework
MESSAGE_SIZE = 1
@dataclass
class Message:
_id: int
sent_time: float
def id(self) -> int:
return self._id
def __len__(self) -> int:
# Return any number here, since we don't use Sphinx encoding for queuesim and byte serialization.
# This must be matched with NomssipConfig.msg_size.
return MESSAGE_SIZE
class MessageBuilder:
def __init__(self, framework: Framework):
self.framework = framework
self.next_id = 0
def next(self) -> Message:
msg = Message(self.next_id, self.framework.now())
self.next_id += 1
return msg

View File

@ -5,7 +5,8 @@ from typing import Awaitable, Callable
from framework.framework import Framework
from protocol.connection import SimplexConnection
from protocol.node import connect_nodes
from protocol.nomssip import Nomssip, NomssipConfig
from protocol.nomssip import Nomssip, NomssipConfig, NomssipMessage
from queuesim.message import Message
class Node:
@ -13,20 +14,25 @@ class Node:
self,
framework: Framework,
nomssip_config: NomssipConfig,
msg_handler: Callable[[bytes], Awaitable[None]],
msg_handler: Callable[[NomssipMessage[Message]], Awaitable[None]],
):
self.nomssip = Nomssip(framework, nomssip_config, msg_handler)
self.nomssip = Nomssip(
framework,
nomssip_config,
msg_handler,
noise_msg=NomssipMessage(NomssipMessage.Flag.NOISE, Message(-1, 0)),
)
def connect(
self,
peer: Node,
inbound_conn: SimplexConnection,
outbound_conn: SimplexConnection,
inbound_conn: SimplexConnection[NomssipMessage[Message]],
outbound_conn: SimplexConnection[NomssipMessage[Message]],
):
connect_nodes(self.nomssip, peer.nomssip, inbound_conn, outbound_conn)
async def send_message(self, msg: bytes):
async def send_message(self, msg: Message):
"""
Send the message via Nomos Gossip to all connected peers.
"""
await self.nomssip.publish(msg)
await self.nomssip.publish(NomssipMessage(NomssipMessage.Flag.REAL, msg))

View File

@ -16,6 +16,7 @@ import usim
from protocol.nomssip import NomssipConfig
from protocol.temporalmix import TemporalMixConfig, TemporalMixType
from queuesim.config import Config
from queuesim.message import MESSAGE_SIZE
from queuesim.paramset import (
EXPERIMENT_TITLES,
ExperimentID,
@ -32,7 +33,7 @@ DEFAULT_CONFIG = Config(
nomssip=NomssipConfig(
peering_degree=3,
transmission_rate_per_sec=10,
msg_size=8,
msg_size=MESSAGE_SIZE,
temporal_mix=TemporalMixConfig(
mix_type=TemporalMixType.NONE,
min_queue_size=10,

View File

@ -1,7 +1,5 @@
import csv
import struct
from dataclasses import dataclass
from typing import Counter, Self
from typing import Counter
import pandas as pd
import usim
@ -9,7 +7,9 @@ import usim
from framework.framework import Queue
from framework.usim import Framework
from protocol.connection import LocalSimplexConnection, SimplexConnection
from protocol.nomssip import NomssipMessage
from queuesim.config import Config
from queuesim.message import Message, MessageBuilder
from queuesim.node import Node
from sim.connection import RemoteSimplexConnection
from sim.topology import build_full_random_topology
@ -31,7 +31,7 @@ class Simulation:
self.framework.stop_tasks()
async def __run(self, out_csv_path: str, topology_path: str):
self.received_msg_queue: Queue[tuple[float, bytes]] = self.framework.queue()
self.received_msg_queue: Queue[tuple[float, Message]] = self.framework.queue()
# Run and connect nodes
nodes = self.__run_nodes()
@ -48,7 +48,7 @@ class Simulation:
writer = csv.writer(f)
writer.writerow(["dissemination_time", "sent_time", "all_received_time"])
# To count how many nodes have received each message
received_msg_counters: Counter[bytes] = Counter()
received_msg_counters: Counter[int] = Counter()
# To count how many results (dissemination time) have been collected so far
result_cnt = 0
# Wait until all messages are disseminated to the entire network.
@ -56,13 +56,16 @@ class Simulation:
# Wait until a node notifies that it has received a new message.
received_time, msg = await self.received_msg_queue.get()
# If the message has been received by all nodes, calculate the dissemination time.
received_msg_counters.update([msg])
if received_msg_counters[msg] == len(nodes):
sent_time = Message.from_bytes(msg).sent_time
dissemination_time = received_time - sent_time
received_msg_counters.update([msg.id()])
if received_msg_counters[msg.id()] == len(nodes):
dissemination_time = received_time - msg.sent_time
# Use repr to convert a float to a string with as much precision as Python can provide
writer.writerow(
[repr(dissemination_time), repr(sent_time), repr(received_time)]
[
repr(dissemination_time),
repr(msg.sent_time),
repr(received_time),
]
)
result_cnt += 1
@ -76,13 +79,13 @@ class Simulation:
for _ in range(self.config.num_nodes)
]
async def __process_msg(self, msg: bytes) -> None:
async def __process_msg(self, msg: NomssipMessage[Message]) -> None:
"""
A handler to process messages received via Nomos Gossip channel
"""
# Notify that a new message has been received by the node.
# The received time is also included in the notification.
await self.received_msg_queue.put((self.framework.now(), msg))
await self.received_msg_queue.put((self.framework.now(), msg.message))
def __connect_nodes(self, nodes: list[Node], topology_path: str):
topology = build_full_random_topology(
@ -127,30 +130,5 @@ class Simulation:
for i in range(self.config.num_sent_msgs):
if i > 0:
await self.framework.sleep(self.config.msg_interval_sec)
msg = bytes(self.message_builder.next())
msg = self.message_builder.next()
await sender.send_message(msg)
@dataclass
class Message:
id: int
sent_time: float
def __bytes__(self) -> bytes:
return struct.pack("if", self.id, self.sent_time)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
id, sent_from = struct.unpack("if", data)
return cls(id, sent_from)
class MessageBuilder:
def __init__(self, framework: Framework):
self.framework = framework
self.next_id = 0
def next(self) -> Message:
msg = Message(self.next_id, self.framework.now())
self.next_id += 1
return msg

View File

@ -1,18 +1,19 @@
import math
from abc import abstractmethod
from collections import Counter
from typing import Awaitable
from typing import Protocol, TypeVar
import pandas
from typing_extensions import override
from framework import Framework, Queue
from protocol.connection import SimplexConnection
from sim.config import LatencyConfig, NetworkConfig
from sim.config import LatencyConfig
from sim.state import NodeState
T = TypeVar("T")
class RemoteSimplexConnection(SimplexConnection):
class RemoteSimplexConnection(SimplexConnection[T]):
"""
A simplex connection implementation that simulates network latency.
"""
@ -22,18 +23,18 @@ class RemoteSimplexConnection(SimplexConnection):
# A connection has a random constant latency
self.latency = config.random_latency()
# A queue of tuple(timestamp, msg) where a sender puts messages to be sent
self.send_queue: Queue[tuple[float, bytes]] = framework.queue()
self.send_queue: Queue[tuple[float, T]] = framework.queue()
# A task that reads messages from send_queue, and puts them to recv_queue.
# Before putting messages to recv_queue, the task simulates network latency according to the timestamp of each message.
self.relayer = framework.spawn(self.__run_relayer())
# A queue where a receiver gets messages
self.recv_queue: Queue[bytes] = framework.queue()
self.recv_queue: Queue[T] = framework.queue()
async def send(self, data: bytes) -> None:
async def send(self, data: T) -> None:
await self.send_queue.put((self.framework.now(), data))
self.on_sending(data)
async def recv(self) -> bytes:
async def recv(self) -> T:
return await self.recv_queue.get()
async def __run_relayer(self):
@ -54,16 +55,23 @@ class RemoteSimplexConnection(SimplexConnection):
self.on_receiving(data)
await self.recv_queue.put(data)
def on_sending(self, data: bytes) -> None:
def on_sending(self, data: T) -> None:
# Should be overridden by subclass
pass
def on_receiving(self, data: bytes) -> None:
def on_receiving(self, data: T) -> None:
# Should be overridden by subclass
pass
class MeteredRemoteSimplexConnection(RemoteSimplexConnection):
class HasLen(Protocol):
def __len__(self) -> int: ...
TL = TypeVar("TL", bound=HasLen)
class MeteredRemoteSimplexConnection(RemoteSimplexConnection[TL]):
"""
An extension of RemoteSimplexConnection that measures bandwidth usages.
"""
@ -81,14 +89,14 @@ class MeteredRemoteSimplexConnection(RemoteSimplexConnection):
self.recv_meters: list[int] = []
@override
def on_sending(self, data: bytes) -> None:
def on_sending(self, data: TL) -> None:
"""
Update statistics when sending a message
"""
self.__update_meter(self.send_meters, len(data))
@override
def on_receiving(self, data: bytes) -> None:
def on_receiving(self, data: TL) -> None:
"""
Update statistics when receiving a message
"""
@ -120,7 +128,7 @@ class MeteredRemoteSimplexConnection(RemoteSimplexConnection):
return pandas.Series(meters, name="bandwidth")
class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection):
class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection[TL]):
"""
An extension of MeteredRemoteSimplexConnection that is observed by passive observer.
The observer monitors the node states of the sender and receiver and message sizes.
@ -143,13 +151,13 @@ class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection):
self.msg_sizes: Counter[int] = Counter()
@override
def on_sending(self, data: bytes) -> None:
def on_sending(self, data: TL) -> None:
super().on_sending(data)
self.__update_node_state(self.send_node_states, NodeState.SENDING)
self.msg_sizes.update([len(data)])
@override
def on_receiving(self, data: bytes) -> None:
def on_receiving(self, data: TL) -> None:
super().on_receiving(data)
self.__update_node_state(self.recv_node_states, NodeState.RECEIVING)

View File

@ -1,3 +1,4 @@
import hashlib
import pickle
from dataclasses import dataclass
from typing import Self
@ -8,7 +9,29 @@ class Message:
"""
A message structure for simulation, which will be sent through mix nodes
and eventually broadcasted to all nodes in the network.
"""
# The bytes of Sphinx packet
data: bytes
def id(self) -> int:
return int.from_bytes(hashlib.sha256(self.data).digest(), byteorder="big")
def __len__(self) -> int:
return len(self.data)
def __bytes__(self) -> bytes:
return self.data
@classmethod
def from_bytes(cls, data: bytes) -> Self:
return cls(data)
@dataclass
class InnerMessage:
"""
The inner message that is wrapped by Sphinx packet.
The `id` must ensure the uniqueness of the message.
"""
@ -23,11 +46,8 @@ class Message:
def from_bytes(cls, data: bytes) -> Self:
return pickle.loads(data)
def __hash__(self) -> int:
return self.id
class UniqueMessageBuilder:
class UniqueInnerMessageBuilder:
"""
Builds a unique message with an incremental ID,
assuming that the simulation is run in a single thread.
@ -36,7 +56,7 @@ class UniqueMessageBuilder:
def __init__(self):
self.next_id = 0
def next(self, created_at: float, body: bytes) -> Message:
msg = Message(created_at, self.next_id, body)
def next(self, created_at: float, body: bytes) -> InnerMessage:
msg = InnerMessage(created_at, self.next_id, body)
self.next_id += 1
return msg

View File

@ -1,20 +1,19 @@
from dataclasses import asdict, dataclass
from pprint import pprint
from typing import Self
import usim
from matplotlib import pyplot
import framework.usim as usimfw
from framework import Framework
from protocol.config import GlobalConfig, MixMembership, NodeInfo
from protocol.node import Node, PeeringDegreeReached
from protocol.node import Node
from protocol.nomssip import NomssipMessage
from protocol.sphinx import SphinxPacketBuilder
from sim.config import Config
from sim.connection import (
MeteredRemoteSimplexConnection,
ObservedMeteredRemoteSimplexConnection,
)
from sim.message import Message, UniqueMessageBuilder
from sim.message import InnerMessage, Message, UniqueInnerMessageBuilder
from sim.state import NodeState, NodeStateTable
from sim.stats import ConnectionStats, DisseminationTime
from sim.topology import build_full_random_topology
@ -27,7 +26,7 @@ class Simulation:
def __init__(self, config: Config):
self.config = config
self.msg_builder = UniqueMessageBuilder()
self.inner_msg_builder = UniqueInnerMessageBuilder()
self.dissemination_time = DisseminationTime(self.config.network.num_nodes)
async def run(self):
@ -61,7 +60,7 @@ class Simulation:
# Return analysis tools once the μSim scope is done
return conn_stats, node_state_table
def __init_nodes(self) -> list[Node]:
def __init_nodes(self) -> list[Node[Message]]:
# Initialize node/global configurations
node_configs = self.config.node_configs()
global_config = GlobalConfig(
@ -78,20 +77,22 @@ class Simulation:
)
# Initialize/return Node instances
noise_msg = Message(bytes(SphinxPacketBuilder.size(global_config)))
return [
Node(
Node[Message](
self.framework,
node_config,
global_config,
self.__process_broadcasted_msg,
self.__process_recovered_msg,
noise_msg,
)
for node_config in node_configs
]
def __connect_nodes(
self,
nodes: list[Node],
nodes: list[Node[Message]],
node_state_table: NodeStateTable,
conn_stats: ConnectionStats,
):
@ -144,8 +145,8 @@ class Simulation:
meter_start_time: float,
sender_states: list[NodeState],
receiver_states: list[NodeState],
) -> ObservedMeteredRemoteSimplexConnection:
return ObservedMeteredRemoteSimplexConnection(
) -> ObservedMeteredRemoteSimplexConnection[NomssipMessage[Message]]:
return ObservedMeteredRemoteSimplexConnection[NomssipMessage[Message]](
self.config.network.latency,
self.framework,
meter_start_time,
@ -156,14 +157,14 @@ class Simulation:
def __create_conn(
self,
meter_start_time: float,
) -> MeteredRemoteSimplexConnection:
return MeteredRemoteSimplexConnection(
) -> MeteredRemoteSimplexConnection[Message]:
return MeteredRemoteSimplexConnection[Message](
self.config.network.latency,
self.framework,
meter_start_time,
)
async def __run_node_logic(self, node: Node):
async def __run_node_logic(self, node: Node[Message]):
"""
Runs the lottery periodically to check if the node is selected to send a block.
If the node is selected, creates a block and sends it through mix nodes.
@ -172,27 +173,29 @@ class Simulation:
while True:
await self.framework.sleep(lottery_config.interval_sec)
if lottery_config.seed.random() < lottery_config.probability:
msg = self.msg_builder.next(self.framework.now(), b"selected block")
await node.send_message(bytes(msg))
inner_msg = self.inner_msg_builder.next(
self.framework.now(), b"selected block"
)
await node.send_message(Message(bytes(inner_msg)))
async def __process_broadcasted_msg(self, msg: bytes):
async def __process_broadcasted_msg(self, msg: Message):
"""
Process a broadcasted message originated from the last mix.
"""
message = Message.from_bytes(msg)
elapsed = self.framework.now() - message.created_at
self.dissemination_time.add_broadcasted_msg(message, elapsed)
inner_msg = InnerMessage.from_bytes(msg.data)
elapsed = self.framework.now() - inner_msg.created_at
self.dissemination_time.add_broadcasted_msg(msg, elapsed)
async def __process_recovered_msg(self, msg: bytes) -> bytes:
async def __process_recovered_msg(self, msg: bytes) -> Message:
"""
Process a message fully recovered by the last mix
and returns a new message to be broadcasted.
"""
message = Message.from_bytes(msg)
elapsed = self.framework.now() - message.created_at
inner_msg = InnerMessage.from_bytes(Message.from_bytes(msg).data)
elapsed = self.framework.now() - inner_msg.created_at
self.dissemination_time.add_mix_propagation_time(elapsed)
# Update the timestamp and return the message to be broadcasted,
# so that the broadcast dissemination time can be calculated from now.
message.created_at = self.framework.now()
return bytes(message)
inner_msg.created_at = self.framework.now()
return Message(bytes(inner_msg))

View File

@ -3,7 +3,6 @@ from collections import Counter, defaultdict
import matplotlib.pyplot as plt
import numpy
import pandas
from matplotlib.axes import Axes
from protocol.node import Node
from sim.connection import ObservedMeteredRemoteSimplexConnection
@ -126,16 +125,18 @@ class DisseminationTime:
# A collection of time taken for a message to be broadcasted from the last mix to all nodes in the network
self.broadcast_dissemination_times: list[float] = []
# Data structures to check if a message has been broadcasted to all nodes
self.broadcast_status: Counter[Message] = Counter()
# msg_id (int) is a key.
self.broadcast_status: Counter[int] = Counter()
self.num_nodes: int = num_nodes
def add_mix_propagation_time(self, elapsed: float):
self.mix_propagation_times.append(elapsed)
def add_broadcasted_msg(self, msg: Message, elapsed: float):
assert self.broadcast_status[msg] < self.num_nodes
self.broadcast_status.update([msg])
if self.broadcast_status[msg] == self.num_nodes:
id = msg.id()
assert self.broadcast_status[id] < self.num_nodes
self.broadcast_status.update([id])
if self.broadcast_status[id] == self.num_nodes:
self.broadcast_dissemination_times.append(elapsed)
def analyze(self):

View File

@ -1,23 +1,22 @@
import time
from unittest import TestCase
from sim.message import Message, UniqueMessageBuilder
from sim.message import InnerMessage, UniqueInnerMessageBuilder
class TestMessage(TestCase):
def test_message_serde(self):
msg = Message(time.time(), 10, b"hello")
def test_inner_message_serde(self):
msg = InnerMessage(time.time(), 10, b"hello")
serialized = bytes(msg)
deserialized = Message.from_bytes(serialized)
deserialized = InnerMessage.from_bytes(serialized)
self.assertEqual(msg, deserialized)
class TestUniqueMessageBuilder(TestCase):
class TestUniqueInnerMessageBuilder(TestCase):
def test_uniqueness(self):
builder = UniqueMessageBuilder()
builder = UniqueInnerMessageBuilder()
msg1 = builder.next(time.time(), b"hello")
msg2 = builder.next(time.time(), b"hello")
self.assertEqual(0, msg1.id)
self.assertEqual(1, msg2.id)
self.assertNotEqual(msg1, msg2)
self.assertNotEqual(hash(msg1), hash(msg2))