mirror of
https://github.com/logos-blockchain/logos-blockchain-simulations.git
synced 2026-01-06 23:23:08 +00:00
optimize: use generic for messages to reduce the size of msg cache in gossip and discard serde cost
This commit is contained in:
parent
3a2f3cc079
commit
dbf1b78134
@ -1,59 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from framework import Framework, Queue
|
||||
from protocol.temporalmix import TemporalMix, TemporalMixConfig
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class SimplexConnection(abc.ABC):
|
||||
|
||||
class SimplexConnection(abc.ABC, Generic[T]):
|
||||
"""
|
||||
An abstract class for a simplex connection that can send and receive data in one direction
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, data: bytes) -> None:
|
||||
async def send(self, data: T) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def recv(self) -> bytes:
|
||||
async def recv(self) -> T:
|
||||
pass
|
||||
|
||||
|
||||
class LocalSimplexConnection(SimplexConnection):
|
||||
class LocalSimplexConnection(SimplexConnection[T]):
|
||||
"""
|
||||
A simplex connection that doesn't have any network latency.
|
||||
Data sent through this connection can be immediately received from the other end.
|
||||
"""
|
||||
|
||||
def __init__(self, framework: Framework):
|
||||
self.queue: Queue[bytes] = framework.queue()
|
||||
self.queue: Queue[T] = framework.queue()
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
async def send(self, data: T) -> None:
|
||||
await self.queue.put(data)
|
||||
|
||||
async def recv(self) -> bytes:
|
||||
async def recv(self) -> T:
|
||||
return await self.queue.get()
|
||||
|
||||
|
||||
class DuplexConnection:
|
||||
class DuplexConnection(Generic[T]):
|
||||
"""
|
||||
A duplex connection in which data can be transmitted and received simultaneously in both directions.
|
||||
This is to mimic duplex communication in a real network (such as TCP or QUIC).
|
||||
"""
|
||||
|
||||
def __init__(self, inbound: SimplexConnection, outbound: SimplexConnection):
|
||||
def __init__(self, inbound: SimplexConnection[T], outbound: SimplexConnection[T]):
|
||||
self.inbound = inbound
|
||||
self.outbound = outbound
|
||||
|
||||
async def recv(self) -> bytes:
|
||||
async def recv(self) -> T:
|
||||
return await self.inbound.recv()
|
||||
|
||||
async def send(self, packet: bytes):
|
||||
async def send(self, packet: T):
|
||||
await self.outbound.send(packet)
|
||||
|
||||
|
||||
class MixSimplexConnection(SimplexConnection):
|
||||
class MixSimplexConnection(SimplexConnection[T]):
|
||||
"""
|
||||
Wraps a SimplexConnection to add a transmission rate and noise to the connection.
|
||||
"""
|
||||
@ -61,16 +64,16 @@ class MixSimplexConnection(SimplexConnection):
|
||||
def __init__(
|
||||
self,
|
||||
framework: Framework,
|
||||
conn: SimplexConnection,
|
||||
conn: SimplexConnection[T],
|
||||
transmission_rate_per_sec: int,
|
||||
noise_msg: bytes,
|
||||
noise_msg: T,
|
||||
temporal_mix_config: TemporalMixConfig,
|
||||
# OPTIMIZATION ONLY FOR EXPERIMENTS WITHOUT BANDWIDTH MEASUREMENT
|
||||
# If True, skip sending a noise even if it's time to send one.
|
||||
skip_sending_noise: bool,
|
||||
):
|
||||
self.framework = framework
|
||||
self.queue: Queue[bytes] = TemporalMix.queue(
|
||||
self.queue: Queue[T] = TemporalMix.queue(
|
||||
temporal_mix_config, framework, noise_msg
|
||||
)
|
||||
self.conn = conn
|
||||
@ -87,8 +90,8 @@ class MixSimplexConnection(SimplexConnection):
|
||||
continue
|
||||
await self.conn.send(msg)
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
async def send(self, data: T) -> None:
|
||||
await self.queue.put(data)
|
||||
|
||||
async def recv(self) -> bytes:
|
||||
async def recv(self) -> T:
|
||||
return await self.conn.recv()
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable
|
||||
from typing import Awaitable, Callable, Generic, Protocol, TypeVar
|
||||
|
||||
from framework import Framework
|
||||
from protocol.connection import (
|
||||
@ -18,7 +17,14 @@ class GossipConfig:
|
||||
peering_degree: int
|
||||
|
||||
|
||||
class Gossip:
|
||||
class HasId(Protocol):
|
||||
def id(self) -> int: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=HasId)
|
||||
|
||||
|
||||
class Gossip(Generic[T]):
|
||||
"""
|
||||
A gossip channel that broadcasts messages to all connected peers.
|
||||
Peers are connected via DuplexConnection.
|
||||
@ -28,15 +34,15 @@ class Gossip:
|
||||
self,
|
||||
framework: Framework,
|
||||
config: GossipConfig,
|
||||
handler: Callable[[bytes], Awaitable[None]],
|
||||
handler: Callable[[T], Awaitable[None]],
|
||||
):
|
||||
self.framework = framework
|
||||
self.config = config
|
||||
self.conns: list[DuplexConnection] = []
|
||||
self.conns: list[DuplexConnection[T]] = []
|
||||
# A handler to process inbound messages.
|
||||
self.handler = handler
|
||||
# msg -> received_cnt
|
||||
self.packet_cache: dict[bytes, int] = dict()
|
||||
# msg_id -> received_cnt
|
||||
self.packet_cache: dict[int, int] = dict()
|
||||
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
|
||||
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
|
||||
self.tasks: set[Awaitable] = set()
|
||||
@ -44,12 +50,12 @@ class Gossip:
|
||||
def can_accept_conn(self) -> bool:
|
||||
return len(self.conns) < self.config.peering_degree
|
||||
|
||||
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
|
||||
def add_conn(self, inbound: SimplexConnection[T], outbound: SimplexConnection[T]):
|
||||
if not self.can_accept_conn():
|
||||
# For simplicity of the spec, reject the connection if the peering degree is reached.
|
||||
raise PeeringDegreeReached()
|
||||
|
||||
conn = DuplexConnection(
|
||||
conn = DuplexConnection[T](
|
||||
inbound,
|
||||
outbound,
|
||||
)
|
||||
@ -57,18 +63,18 @@ class Gossip:
|
||||
task = self.framework.spawn(self.__process_inbound_conn(conn))
|
||||
self.tasks.add(task)
|
||||
|
||||
async def __process_inbound_conn(self, conn: DuplexConnection):
|
||||
async def __process_inbound_conn(self, conn: DuplexConnection[T]):
|
||||
while True:
|
||||
msg = await conn.recv()
|
||||
if self._check_update_cache(msg):
|
||||
continue
|
||||
await self._process_inbound_msg(msg, conn)
|
||||
|
||||
async def _process_inbound_msg(self, msg: bytes, received_from: DuplexConnection):
|
||||
async def _process_inbound_msg(self, msg: T, received_from: DuplexConnection[T]):
|
||||
await self._gossip(msg, [received_from])
|
||||
await self.handler(msg)
|
||||
|
||||
async def publish(self, msg: bytes):
|
||||
async def publish(self, msg: T):
|
||||
"""
|
||||
Publish a message to all nodes in the network.
|
||||
"""
|
||||
@ -83,7 +89,7 @@ class Gossip:
|
||||
# which means that we consider that this publisher node received the message.
|
||||
await self.handler(msg)
|
||||
|
||||
async def _gossip(self, msg: bytes, excludes: list[DuplexConnection] = []):
|
||||
async def _gossip(self, msg: T, excludes: list[DuplexConnection] = []):
|
||||
"""
|
||||
Gossip a message to all peers connected to this node.
|
||||
"""
|
||||
@ -91,26 +97,26 @@ class Gossip:
|
||||
if conn not in excludes:
|
||||
await conn.send(msg)
|
||||
|
||||
def _check_update_cache(self, packet: bytes, publishing: bool = False) -> bool:
|
||||
def _check_update_cache(self, msg: T, publishing: bool = False) -> bool:
|
||||
"""
|
||||
Add a message to the cache, and return True if the message was already in the cache.
|
||||
"""
|
||||
hash = hashlib.sha256(packet).digest()
|
||||
seen = hash in self.packet_cache
|
||||
id = msg.id()
|
||||
seen = id in self.packet_cache
|
||||
|
||||
if publishing:
|
||||
if not seen:
|
||||
# Put 0 when publishing, so that the publisher node doesn't gossip the message again
|
||||
# even when it first receive the message from one of its peers later.
|
||||
self.packet_cache[hash] = 0
|
||||
self.packet_cache[id] = 0
|
||||
else:
|
||||
if not seen:
|
||||
self.packet_cache[hash] = 1
|
||||
self.packet_cache[id] = 1
|
||||
else:
|
||||
self.packet_cache[hash] += 1
|
||||
self.packet_cache[id] += 1
|
||||
# Remove the message from the cache if it's received from all adjacent peers in the end
|
||||
# to reduce the size of cache.
|
||||
if self.packet_cache[hash] >= self.config.peering_degree:
|
||||
del self.packet_cache[hash]
|
||||
if self.packet_cache[id] >= self.config.peering_degree:
|
||||
del self.packet_cache[id]
|
||||
|
||||
return seen
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
from typing import Awaitable, Callable, Generic, Protocol, Self, Type, TypeVar
|
||||
|
||||
from pysphinx.sphinx import (
|
||||
ProcessedFinalHopPacket,
|
||||
@ -13,11 +13,22 @@ from protocol.config import GlobalConfig, NodeConfig
|
||||
from protocol.connection import SimplexConnection
|
||||
from protocol.error import PeeringDegreeReached
|
||||
from protocol.gossip import Gossip
|
||||
from protocol.nomssip import Nomssip, NomssipConfig
|
||||
from protocol.nomssip import Nomssip, NomssipConfig, NomssipMessage
|
||||
from protocol.sphinx import SphinxPacketBuilder
|
||||
|
||||
|
||||
class Node:
|
||||
class HasIdAndLenAndBytes(Protocol):
|
||||
def id(self) -> int: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __bytes__(self) -> bytes: ...
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=HasIdAndLenAndBytes)
|
||||
|
||||
|
||||
class Node(Generic[T]):
|
||||
"""
|
||||
This represents any node in the network, which:
|
||||
- generates/gossips mix messages (Sphinx packets)
|
||||
@ -31,57 +42,53 @@ class Node:
|
||||
config: NodeConfig,
|
||||
global_config: GlobalConfig,
|
||||
# A handler called when a node receives a broadcasted message originated from the last mix.
|
||||
broadcasted_msg_handler: Callable[[bytes], Awaitable[None]],
|
||||
# An optional handler only for the simulation,
|
||||
# which is called when a message is fully recovered by the last mix
|
||||
broadcasted_msg_handler: Callable[[T], Awaitable[None]],
|
||||
# A handler called when a message is fully recovered by the last mix
|
||||
# and returns a new message to be broadcasted.
|
||||
recovered_msg_handler: Callable[[bytes], Awaitable[bytes]] | None = None,
|
||||
recovered_msg_handler: Callable[[bytes], Awaitable[T]],
|
||||
noise_msg: T,
|
||||
):
|
||||
self.framework = framework
|
||||
self.config = config
|
||||
self.global_config = global_config
|
||||
nomssip_config = NomssipConfig(
|
||||
config.gossip.peering_degree,
|
||||
global_config.transmission_rate_per_sec,
|
||||
SphinxPacketBuilder.size(global_config),
|
||||
config.temporal_mix,
|
||||
)
|
||||
self.nomssip = Nomssip(
|
||||
framework,
|
||||
NomssipConfig(
|
||||
config.gossip.peering_degree,
|
||||
global_config.transmission_rate_per_sec,
|
||||
self.__calculate_message_size(global_config),
|
||||
config.temporal_mix,
|
||||
),
|
||||
nomssip_config,
|
||||
self.__process_msg,
|
||||
noise_msg=NomssipMessage[T](NomssipMessage.Flag.NOISE, noise_msg),
|
||||
)
|
||||
self.broadcast = Gossip(framework, config.gossip, broadcasted_msg_handler)
|
||||
self.broadcast = Gossip[T](framework, config.gossip, broadcasted_msg_handler)
|
||||
self.recovered_msg_handler = recovered_msg_handler
|
||||
|
||||
@staticmethod
|
||||
def __calculate_message_size(global_config: GlobalConfig) -> int:
|
||||
"""
|
||||
Calculate the actual message size to be gossiped, which depends on the maximum length of mix path.
|
||||
"""
|
||||
sample_sphinx_packet, _ = SphinxPacketBuilder.build(
|
||||
bytes(global_config.max_message_size),
|
||||
global_config,
|
||||
global_config.max_mix_path_length,
|
||||
)
|
||||
return len(sample_sphinx_packet.bytes())
|
||||
|
||||
async def __process_msg(self, msg: bytes) -> None:
|
||||
async def __process_msg(self, msg: NomssipMessage[T]) -> None:
|
||||
"""
|
||||
A handler to process messages received via Nomssip channel
|
||||
"""
|
||||
assert msg.flag == NomssipMessage.Flag.REAL
|
||||
|
||||
sphinx_packet = SphinxPacket.from_bytes(
|
||||
msg, self.global_config.max_mix_path_length
|
||||
bytes(msg.message), self.global_config.max_mix_path_length
|
||||
)
|
||||
result = await self.__process_sphinx_packet(sphinx_packet)
|
||||
match result:
|
||||
case SphinxPacket():
|
||||
# Gossip the next Sphinx packet
|
||||
await self.nomssip.publish(result.bytes())
|
||||
t: Type[T] = type(msg.message)
|
||||
await self.nomssip.publish(
|
||||
NomssipMessage[T](
|
||||
NomssipMessage.Flag.REAL,
|
||||
t.from_bytes(result.bytes()),
|
||||
)
|
||||
)
|
||||
case bytes():
|
||||
if self.recovered_msg_handler is not None:
|
||||
result = await self.recovered_msg_handler(result)
|
||||
# Broadcast the message fully recovered from Sphinx packets
|
||||
await self.broadcast.publish(result)
|
||||
await self.broadcast.publish(await self.recovered_msg_handler(result))
|
||||
case None:
|
||||
return
|
||||
|
||||
@ -105,31 +112,36 @@ class Node:
|
||||
def connect_mix(
|
||||
self,
|
||||
peer: Node,
|
||||
inbound_conn: SimplexConnection,
|
||||
outbound_conn: SimplexConnection,
|
||||
inbound_conn: SimplexConnection[NomssipMessage[T]],
|
||||
outbound_conn: SimplexConnection[NomssipMessage[T]],
|
||||
):
|
||||
connect_nodes(self.nomssip, peer.nomssip, inbound_conn, outbound_conn)
|
||||
|
||||
def connect_broadcast(
|
||||
self,
|
||||
peer: Node,
|
||||
inbound_conn: SimplexConnection,
|
||||
outbound_conn: SimplexConnection,
|
||||
inbound_conn: SimplexConnection[T],
|
||||
outbound_conn: SimplexConnection[T],
|
||||
):
|
||||
connect_nodes(self.broadcast, peer.broadcast, inbound_conn, outbound_conn)
|
||||
|
||||
async def send_message(self, msg: bytes):
|
||||
async def send_message(self, msg: T):
|
||||
"""
|
||||
Build a Sphinx packet and gossip it to all connected peers.
|
||||
"""
|
||||
# Here, we handle the case in which a msg is split into multiple Sphinx packets.
|
||||
# But, in practice, we expect a message to be small enough to fit in a single Sphinx packet.
|
||||
sphinx_packet, _ = SphinxPacketBuilder.build(
|
||||
msg,
|
||||
bytes(msg),
|
||||
self.global_config,
|
||||
self.config.mix_path_length,
|
||||
)
|
||||
await self.nomssip.publish(sphinx_packet.bytes())
|
||||
t: Type[T] = type(msg)
|
||||
await self.nomssip.publish(
|
||||
NomssipMessage(
|
||||
NomssipMessage.Flag.REAL, t.from_bytes(sphinx_packet.bytes())
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def connect_nodes(
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, Self, override
|
||||
from typing import Awaitable, Callable, Generic, Protocol, TypeVar, override
|
||||
|
||||
from framework import Framework
|
||||
from protocol.connection import (
|
||||
@ -24,7 +22,31 @@ class NomssipConfig(GossipConfig):
|
||||
skip_sending_noise: bool = False
|
||||
|
||||
|
||||
class Nomssip(Gossip):
|
||||
class HasIdAndLen(Protocol):
|
||||
def id(self) -> int: ...
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=HasIdAndLen)
|
||||
|
||||
|
||||
class NomssipMessage(Generic[T]):
|
||||
class Flag(Enum):
|
||||
REAL = b"\x00"
|
||||
NOISE = b"\x01"
|
||||
|
||||
def __init__(self, flag: Flag, message: T):
|
||||
self.flag = flag
|
||||
self.message = message
|
||||
|
||||
def id(self) -> int:
|
||||
return self.message.id()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.flag.value) + len(self.message)
|
||||
|
||||
|
||||
class Nomssip(Gossip[NomssipMessage[T]]):
|
||||
"""
|
||||
A NomMix gossip channel that extends the Gossip channel
|
||||
by adding global transmission rate and noise generation.
|
||||
@ -34,72 +56,53 @@ class Nomssip(Gossip):
|
||||
self,
|
||||
framework: Framework,
|
||||
config: NomssipConfig,
|
||||
handler: Callable[[bytes], Awaitable[None]],
|
||||
handler: Callable[[NomssipMessage[T]], Awaitable[None]],
|
||||
noise_msg: NomssipMessage[T],
|
||||
):
|
||||
super().__init__(framework, config, handler)
|
||||
self.config = config
|
||||
self.noise_msg = noise_msg
|
||||
|
||||
@override
|
||||
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
|
||||
noise_packet = FlaggedPacket(
|
||||
FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
|
||||
).bytes()
|
||||
def add_conn(
|
||||
self,
|
||||
inbound: SimplexConnection[NomssipMessage[T]],
|
||||
outbound: SimplexConnection[NomssipMessage[T]],
|
||||
):
|
||||
super().add_conn(
|
||||
inbound,
|
||||
MixSimplexConnection(
|
||||
MixSimplexConnection[NomssipMessage[T]](
|
||||
self.framework,
|
||||
outbound,
|
||||
self.config.transmission_rate_per_sec,
|
||||
noise_packet,
|
||||
self.noise_msg,
|
||||
self.config.temporal_mix,
|
||||
self.config.skip_sending_noise,
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
async def _process_inbound_msg(self, msg: bytes, received_from: DuplexConnection):
|
||||
packet = FlaggedPacket.from_bytes(msg)
|
||||
match packet.flag:
|
||||
case FlaggedPacket.Flag.NOISE:
|
||||
async def _process_inbound_msg(
|
||||
self, msg: NomssipMessage[T], received_from: DuplexConnection
|
||||
):
|
||||
match msg.flag:
|
||||
case NomssipMessage.Flag.NOISE:
|
||||
# Drop noise packet
|
||||
return
|
||||
case FlaggedPacket.Flag.REAL:
|
||||
self.assert_message_size(packet.message)
|
||||
case NomssipMessage.Flag.REAL:
|
||||
self.assert_message_size(msg.message)
|
||||
await super()._gossip(msg, [received_from])
|
||||
await self.handler(packet.message)
|
||||
await self.handler(msg)
|
||||
|
||||
@override
|
||||
async def publish(self, msg: bytes):
|
||||
self.assert_message_size(msg)
|
||||
async def publish(self, msg: NomssipMessage[T]):
|
||||
self.assert_message_size(msg.message)
|
||||
|
||||
packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg).bytes()
|
||||
# Please see comments in super().publish() for the reason of the following line.
|
||||
if not self._check_update_cache(packet, publishing=True):
|
||||
await self._gossip(packet)
|
||||
if not self._check_update_cache(msg, publishing=True):
|
||||
await self._gossip(msg)
|
||||
await self.handler(msg)
|
||||
|
||||
def assert_message_size(self, msg: bytes):
|
||||
def assert_message_size(self, msg: T):
|
||||
# The message size must be fixed.
|
||||
assert len(msg) == self.config.msg_size, f"{len(msg)} != {self.config.msg_size}"
|
||||
|
||||
|
||||
class FlaggedPacket:
|
||||
class Flag(Enum):
|
||||
REAL = b"\x00"
|
||||
NOISE = b"\x01"
|
||||
|
||||
def __init__(self, flag: Flag, message: bytes):
|
||||
self.flag = flag
|
||||
self.message = message
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return self.flag.value + self.message
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, packet: bytes) -> Self:
|
||||
"""
|
||||
Parse a flagged packet from bytes
|
||||
"""
|
||||
if len(packet) < 1:
|
||||
raise ValueError("Invalid message format")
|
||||
return cls(cls.Flag(packet[:1]), packet[1:])
|
||||
|
||||
@ -31,3 +31,15 @@ class SphinxPacketBuilder:
|
||||
max_plain_payload_size=global_config.max_message_size,
|
||||
)
|
||||
return (packet, route)
|
||||
|
||||
@staticmethod
|
||||
def size(global_config: GlobalConfig) -> int:
|
||||
"""
|
||||
Calculate the size of Sphinx packet, which depends on the maximum length of mix path.
|
||||
"""
|
||||
sample_sphinx_packet, _ = SphinxPacketBuilder.build(
|
||||
bytes(global_config.max_message_size),
|
||||
global_config,
|
||||
global_config.max_mix_path_length,
|
||||
)
|
||||
return len(sample_sphinx_packet.bytes())
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
import framework.asyncio as asynciofw
|
||||
from framework.framework import Queue
|
||||
from protocol.connection import LocalSimplexConnection
|
||||
from protocol.node import Node
|
||||
from protocol.nomssip import NomssipMessage
|
||||
from protocol.test_utils import (
|
||||
init_mixnet_config,
|
||||
)
|
||||
@ -14,31 +18,42 @@ class TestNode(IsolatedAsyncioTestCase):
|
||||
framework = asynciofw.Framework()
|
||||
global_config, node_configs, _ = init_mixnet_config(10)
|
||||
|
||||
queue: Queue[bytes] = framework.queue()
|
||||
queue: Queue[Message] = framework.queue()
|
||||
|
||||
async def broadcasted_msg_handler(msg: bytes) -> None:
|
||||
async def broadcasted_msg_handler(msg: Message) -> None:
|
||||
await queue.put(msg)
|
||||
|
||||
async def recovered_msg_handler(msg: bytes) -> Message:
|
||||
return Message(msg)
|
||||
|
||||
nodes = [
|
||||
Node(framework, node_config, global_config, broadcasted_msg_handler)
|
||||
Node[Message](
|
||||
framework,
|
||||
node_config,
|
||||
global_config,
|
||||
broadcasted_msg_handler,
|
||||
recovered_msg_handler,
|
||||
noise_msg=Message(b""),
|
||||
)
|
||||
for node_config in node_configs
|
||||
]
|
||||
for i, node in enumerate(nodes):
|
||||
try:
|
||||
node.connect_mix(
|
||||
nodes[(i + 1) % len(nodes)],
|
||||
LocalSimplexConnection(framework),
|
||||
LocalSimplexConnection(framework),
|
||||
LocalSimplexConnection[NomssipMessage[Message]](framework),
|
||||
LocalSimplexConnection[NomssipMessage[Message]](framework),
|
||||
)
|
||||
node.connect_broadcast(
|
||||
nodes[(i + 1) % len(nodes)],
|
||||
LocalSimplexConnection(framework),
|
||||
LocalSimplexConnection(framework),
|
||||
LocalSimplexConnection[Message](framework),
|
||||
LocalSimplexConnection[Message](framework),
|
||||
)
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
|
||||
await nodes[0].send_message(b"block selection")
|
||||
msg = Message(b"block selection")
|
||||
await nodes[0].send_message(msg)
|
||||
|
||||
# Wait for all nodes to receive the broadcast
|
||||
num_nodes_received_broadcast = 0
|
||||
@ -47,7 +62,7 @@ class TestNode(IsolatedAsyncioTestCase):
|
||||
await framework.sleep(1)
|
||||
|
||||
while not queue.empty():
|
||||
self.assertEqual(b"block selection", await queue.get())
|
||||
self.assertEqual(msg, await queue.get())
|
||||
num_nodes_received_broadcast += 1
|
||||
|
||||
if num_nodes_received_broadcast == len(nodes):
|
||||
@ -56,3 +71,21 @@ class TestNode(IsolatedAsyncioTestCase):
|
||||
self.assertEqual(len(nodes), num_nodes_received_broadcast)
|
||||
|
||||
# TODO: check noise
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
data: bytes
|
||||
|
||||
def id(self) -> int:
|
||||
return int.from_bytes(hashlib.sha256(self.data).digest(), byteorder="big")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.data
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
return cls(data)
|
||||
|
||||
30
mixnet/queuesim/message.py
Normal file
30
mixnet/queuesim/message.py
Normal file
@ -0,0 +1,30 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from framework.framework import Framework
|
||||
|
||||
MESSAGE_SIZE = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
_id: int
|
||||
sent_time: float
|
||||
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
def __len__(self) -> int:
|
||||
# Return any number here, since we don't use Sphinx encoding for queuesim and byte serialization.
|
||||
# This must be matched with NomssipConfig.msg_size.
|
||||
return MESSAGE_SIZE
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
def __init__(self, framework: Framework):
|
||||
self.framework = framework
|
||||
self.next_id = 0
|
||||
|
||||
def next(self) -> Message:
|
||||
msg = Message(self.next_id, self.framework.now())
|
||||
self.next_id += 1
|
||||
return msg
|
||||
@ -5,7 +5,8 @@ from typing import Awaitable, Callable
|
||||
from framework.framework import Framework
|
||||
from protocol.connection import SimplexConnection
|
||||
from protocol.node import connect_nodes
|
||||
from protocol.nomssip import Nomssip, NomssipConfig
|
||||
from protocol.nomssip import Nomssip, NomssipConfig, NomssipMessage
|
||||
from queuesim.message import Message
|
||||
|
||||
|
||||
class Node:
|
||||
@ -13,20 +14,25 @@ class Node:
|
||||
self,
|
||||
framework: Framework,
|
||||
nomssip_config: NomssipConfig,
|
||||
msg_handler: Callable[[bytes], Awaitable[None]],
|
||||
msg_handler: Callable[[NomssipMessage[Message]], Awaitable[None]],
|
||||
):
|
||||
self.nomssip = Nomssip(framework, nomssip_config, msg_handler)
|
||||
self.nomssip = Nomssip(
|
||||
framework,
|
||||
nomssip_config,
|
||||
msg_handler,
|
||||
noise_msg=NomssipMessage(NomssipMessage.Flag.NOISE, Message(-1, 0)),
|
||||
)
|
||||
|
||||
def connect(
|
||||
self,
|
||||
peer: Node,
|
||||
inbound_conn: SimplexConnection,
|
||||
outbound_conn: SimplexConnection,
|
||||
inbound_conn: SimplexConnection[NomssipMessage[Message]],
|
||||
outbound_conn: SimplexConnection[NomssipMessage[Message]],
|
||||
):
|
||||
connect_nodes(self.nomssip, peer.nomssip, inbound_conn, outbound_conn)
|
||||
|
||||
async def send_message(self, msg: bytes):
|
||||
async def send_message(self, msg: Message):
|
||||
"""
|
||||
Send the message via Nomos Gossip to all connected peers.
|
||||
"""
|
||||
await self.nomssip.publish(msg)
|
||||
await self.nomssip.publish(NomssipMessage(NomssipMessage.Flag.REAL, msg))
|
||||
|
||||
@ -16,6 +16,7 @@ import usim
|
||||
from protocol.nomssip import NomssipConfig
|
||||
from protocol.temporalmix import TemporalMixConfig, TemporalMixType
|
||||
from queuesim.config import Config
|
||||
from queuesim.message import MESSAGE_SIZE
|
||||
from queuesim.paramset import (
|
||||
EXPERIMENT_TITLES,
|
||||
ExperimentID,
|
||||
@ -32,7 +33,7 @@ DEFAULT_CONFIG = Config(
|
||||
nomssip=NomssipConfig(
|
||||
peering_degree=3,
|
||||
transmission_rate_per_sec=10,
|
||||
msg_size=8,
|
||||
msg_size=MESSAGE_SIZE,
|
||||
temporal_mix=TemporalMixConfig(
|
||||
mix_type=TemporalMixType.NONE,
|
||||
min_queue_size=10,
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import csv
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from typing import Counter, Self
|
||||
from typing import Counter
|
||||
|
||||
import pandas as pd
|
||||
import usim
|
||||
@ -9,7 +7,9 @@ import usim
|
||||
from framework.framework import Queue
|
||||
from framework.usim import Framework
|
||||
from protocol.connection import LocalSimplexConnection, SimplexConnection
|
||||
from protocol.nomssip import NomssipMessage
|
||||
from queuesim.config import Config
|
||||
from queuesim.message import Message, MessageBuilder
|
||||
from queuesim.node import Node
|
||||
from sim.connection import RemoteSimplexConnection
|
||||
from sim.topology import build_full_random_topology
|
||||
@ -31,7 +31,7 @@ class Simulation:
|
||||
self.framework.stop_tasks()
|
||||
|
||||
async def __run(self, out_csv_path: str, topology_path: str):
|
||||
self.received_msg_queue: Queue[tuple[float, bytes]] = self.framework.queue()
|
||||
self.received_msg_queue: Queue[tuple[float, Message]] = self.framework.queue()
|
||||
|
||||
# Run and connect nodes
|
||||
nodes = self.__run_nodes()
|
||||
@ -48,7 +48,7 @@ class Simulation:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["dissemination_time", "sent_time", "all_received_time"])
|
||||
# To count how many nodes have received each message
|
||||
received_msg_counters: Counter[bytes] = Counter()
|
||||
received_msg_counters: Counter[int] = Counter()
|
||||
# To count how many results (dissemination time) have been collected so far
|
||||
result_cnt = 0
|
||||
# Wait until all messages are disseminated to the entire network.
|
||||
@ -56,13 +56,16 @@ class Simulation:
|
||||
# Wait until a node notifies that it has received a new message.
|
||||
received_time, msg = await self.received_msg_queue.get()
|
||||
# If the message has been received by all nodes, calculate the dissemination time.
|
||||
received_msg_counters.update([msg])
|
||||
if received_msg_counters[msg] == len(nodes):
|
||||
sent_time = Message.from_bytes(msg).sent_time
|
||||
dissemination_time = received_time - sent_time
|
||||
received_msg_counters.update([msg.id()])
|
||||
if received_msg_counters[msg.id()] == len(nodes):
|
||||
dissemination_time = received_time - msg.sent_time
|
||||
# Use repr to convert a float to a string with as much precision as Python can provide
|
||||
writer.writerow(
|
||||
[repr(dissemination_time), repr(sent_time), repr(received_time)]
|
||||
[
|
||||
repr(dissemination_time),
|
||||
repr(msg.sent_time),
|
||||
repr(received_time),
|
||||
]
|
||||
)
|
||||
result_cnt += 1
|
||||
|
||||
@ -76,13 +79,13 @@ class Simulation:
|
||||
for _ in range(self.config.num_nodes)
|
||||
]
|
||||
|
||||
async def __process_msg(self, msg: bytes) -> None:
|
||||
async def __process_msg(self, msg: NomssipMessage[Message]) -> None:
|
||||
"""
|
||||
A handler to process messages received via Nomos Gossip channel
|
||||
"""
|
||||
# Notify that a new message has been received by the node.
|
||||
# The received time is also included in the notification.
|
||||
await self.received_msg_queue.put((self.framework.now(), msg))
|
||||
await self.received_msg_queue.put((self.framework.now(), msg.message))
|
||||
|
||||
def __connect_nodes(self, nodes: list[Node], topology_path: str):
|
||||
topology = build_full_random_topology(
|
||||
@ -127,30 +130,5 @@ class Simulation:
|
||||
for i in range(self.config.num_sent_msgs):
|
||||
if i > 0:
|
||||
await self.framework.sleep(self.config.msg_interval_sec)
|
||||
msg = bytes(self.message_builder.next())
|
||||
msg = self.message_builder.next()
|
||||
await sender.send_message(msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
id: int
|
||||
sent_time: float
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack("if", self.id, self.sent_time)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
id, sent_from = struct.unpack("if", data)
|
||||
return cls(id, sent_from)
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
def __init__(self, framework: Framework):
|
||||
self.framework = framework
|
||||
self.next_id = 0
|
||||
|
||||
def next(self) -> Message:
|
||||
msg = Message(self.next_id, self.framework.now())
|
||||
self.next_id += 1
|
||||
return msg
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from collections import Counter
|
||||
from typing import Awaitable
|
||||
from typing import Protocol, TypeVar
|
||||
|
||||
import pandas
|
||||
from typing_extensions import override
|
||||
|
||||
from framework import Framework, Queue
|
||||
from protocol.connection import SimplexConnection
|
||||
from sim.config import LatencyConfig, NetworkConfig
|
||||
from sim.config import LatencyConfig
|
||||
from sim.state import NodeState
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class RemoteSimplexConnection(SimplexConnection):
|
||||
|
||||
class RemoteSimplexConnection(SimplexConnection[T]):
|
||||
"""
|
||||
A simplex connection implementation that simulates network latency.
|
||||
"""
|
||||
@ -22,18 +23,18 @@ class RemoteSimplexConnection(SimplexConnection):
|
||||
# A connection has a random constant latency
|
||||
self.latency = config.random_latency()
|
||||
# A queue of tuple(timestamp, msg) where a sender puts messages to be sent
|
||||
self.send_queue: Queue[tuple[float, bytes]] = framework.queue()
|
||||
self.send_queue: Queue[tuple[float, T]] = framework.queue()
|
||||
# A task that reads messages from send_queue, and puts them to recv_queue.
|
||||
# Before putting messages to recv_queue, the task simulates network latency according to the timestamp of each message.
|
||||
self.relayer = framework.spawn(self.__run_relayer())
|
||||
# A queue where a receiver gets messages
|
||||
self.recv_queue: Queue[bytes] = framework.queue()
|
||||
self.recv_queue: Queue[T] = framework.queue()
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
async def send(self, data: T) -> None:
|
||||
await self.send_queue.put((self.framework.now(), data))
|
||||
self.on_sending(data)
|
||||
|
||||
async def recv(self) -> bytes:
|
||||
async def recv(self) -> T:
|
||||
return await self.recv_queue.get()
|
||||
|
||||
async def __run_relayer(self):
|
||||
@ -54,16 +55,23 @@ class RemoteSimplexConnection(SimplexConnection):
|
||||
self.on_receiving(data)
|
||||
await self.recv_queue.put(data)
|
||||
|
||||
def on_sending(self, data: bytes) -> None:
|
||||
def on_sending(self, data: T) -> None:
|
||||
# Should be overridden by subclass
|
||||
pass
|
||||
|
||||
def on_receiving(self, data: bytes) -> None:
|
||||
def on_receiving(self, data: T) -> None:
|
||||
# Should be overridden by subclass
|
||||
pass
|
||||
|
||||
|
||||
class MeteredRemoteSimplexConnection(RemoteSimplexConnection):
|
||||
class HasLen(Protocol):
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
|
||||
TL = TypeVar("TL", bound=HasLen)
|
||||
|
||||
|
||||
class MeteredRemoteSimplexConnection(RemoteSimplexConnection[TL]):
|
||||
"""
|
||||
An extension of RemoteSimplexConnection that measures bandwidth usages.
|
||||
"""
|
||||
@ -81,14 +89,14 @@ class MeteredRemoteSimplexConnection(RemoteSimplexConnection):
|
||||
self.recv_meters: list[int] = []
|
||||
|
||||
@override
|
||||
def on_sending(self, data: bytes) -> None:
|
||||
def on_sending(self, data: TL) -> None:
|
||||
"""
|
||||
Update statistics when sending a message
|
||||
"""
|
||||
self.__update_meter(self.send_meters, len(data))
|
||||
|
||||
@override
|
||||
def on_receiving(self, data: bytes) -> None:
|
||||
def on_receiving(self, data: TL) -> None:
|
||||
"""
|
||||
Update statistics when receiving a message
|
||||
"""
|
||||
@ -120,7 +128,7 @@ class MeteredRemoteSimplexConnection(RemoteSimplexConnection):
|
||||
return pandas.Series(meters, name="bandwidth")
|
||||
|
||||
|
||||
class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection):
|
||||
class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection[TL]):
|
||||
"""
|
||||
An extension of MeteredRemoteSimplexConnection that is observed by passive observer.
|
||||
The observer monitors the node states of the sender and receiver and message sizes.
|
||||
@ -143,13 +151,13 @@ class ObservedMeteredRemoteSimplexConnection(MeteredRemoteSimplexConnection):
|
||||
self.msg_sizes: Counter[int] = Counter()
|
||||
|
||||
@override
|
||||
def on_sending(self, data: bytes) -> None:
|
||||
def on_sending(self, data: TL) -> None:
|
||||
super().on_sending(data)
|
||||
self.__update_node_state(self.send_node_states, NodeState.SENDING)
|
||||
self.msg_sizes.update([len(data)])
|
||||
|
||||
@override
|
||||
def on_receiving(self, data: bytes) -> None:
|
||||
def on_receiving(self, data: TL) -> None:
|
||||
super().on_receiving(data)
|
||||
self.__update_node_state(self.recv_node_states, NodeState.RECEIVING)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
@ -8,7 +9,29 @@ class Message:
|
||||
"""
|
||||
A message structure for simulation, which will be sent through mix nodes
|
||||
and eventually broadcasted to all nodes in the network.
|
||||
"""
|
||||
|
||||
# The bytes of Sphinx packet
|
||||
data: bytes
|
||||
|
||||
def id(self) -> int:
|
||||
return int.from_bytes(hashlib.sha256(self.data).digest(), byteorder="big")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.data
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
return cls(data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InnerMessage:
|
||||
"""
|
||||
The inner message that is wrapped by Sphinx packet.
|
||||
The `id` must ensure the uniqueness of the message.
|
||||
"""
|
||||
|
||||
@ -23,11 +46,8 @@ class Message:
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
return pickle.loads(data)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.id
|
||||
|
||||
|
||||
class UniqueMessageBuilder:
|
||||
class UniqueInnerMessageBuilder:
|
||||
"""
|
||||
Builds a unique message with an incremental ID,
|
||||
assuming that the simulation is run in a single thread.
|
||||
@ -36,7 +56,7 @@ class UniqueMessageBuilder:
|
||||
def __init__(self):
|
||||
self.next_id = 0
|
||||
|
||||
def next(self, created_at: float, body: bytes) -> Message:
|
||||
msg = Message(created_at, self.next_id, body)
|
||||
def next(self, created_at: float, body: bytes) -> InnerMessage:
|
||||
msg = InnerMessage(created_at, self.next_id, body)
|
||||
self.next_id += 1
|
||||
return msg
|
||||
|
||||
@ -1,20 +1,19 @@
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pprint
|
||||
from typing import Self
|
||||
|
||||
import usim
|
||||
from matplotlib import pyplot
|
||||
|
||||
import framework.usim as usimfw
|
||||
from framework import Framework
|
||||
from protocol.config import GlobalConfig, MixMembership, NodeInfo
|
||||
from protocol.node import Node, PeeringDegreeReached
|
||||
from protocol.node import Node
|
||||
from protocol.nomssip import NomssipMessage
|
||||
from protocol.sphinx import SphinxPacketBuilder
|
||||
from sim.config import Config
|
||||
from sim.connection import (
|
||||
MeteredRemoteSimplexConnection,
|
||||
ObservedMeteredRemoteSimplexConnection,
|
||||
)
|
||||
from sim.message import Message, UniqueMessageBuilder
|
||||
from sim.message import InnerMessage, Message, UniqueInnerMessageBuilder
|
||||
from sim.state import NodeState, NodeStateTable
|
||||
from sim.stats import ConnectionStats, DisseminationTime
|
||||
from sim.topology import build_full_random_topology
|
||||
@ -27,7 +26,7 @@ class Simulation:
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.msg_builder = UniqueMessageBuilder()
|
||||
self.inner_msg_builder = UniqueInnerMessageBuilder()
|
||||
self.dissemination_time = DisseminationTime(self.config.network.num_nodes)
|
||||
|
||||
async def run(self):
|
||||
@ -61,7 +60,7 @@ class Simulation:
|
||||
# Return analysis tools once the μSim scope is done
|
||||
return conn_stats, node_state_table
|
||||
|
||||
def __init_nodes(self) -> list[Node]:
|
||||
def __init_nodes(self) -> list[Node[Message]]:
|
||||
# Initialize node/global configurations
|
||||
node_configs = self.config.node_configs()
|
||||
global_config = GlobalConfig(
|
||||
@ -78,20 +77,22 @@ class Simulation:
|
||||
)
|
||||
|
||||
# Initialize/return Node instances
|
||||
noise_msg = Message(bytes(SphinxPacketBuilder.size(global_config)))
|
||||
return [
|
||||
Node(
|
||||
Node[Message](
|
||||
self.framework,
|
||||
node_config,
|
||||
global_config,
|
||||
self.__process_broadcasted_msg,
|
||||
self.__process_recovered_msg,
|
||||
noise_msg,
|
||||
)
|
||||
for node_config in node_configs
|
||||
]
|
||||
|
||||
def __connect_nodes(
|
||||
self,
|
||||
nodes: list[Node],
|
||||
nodes: list[Node[Message]],
|
||||
node_state_table: NodeStateTable,
|
||||
conn_stats: ConnectionStats,
|
||||
):
|
||||
@ -144,8 +145,8 @@ class Simulation:
|
||||
meter_start_time: float,
|
||||
sender_states: list[NodeState],
|
||||
receiver_states: list[NodeState],
|
||||
) -> ObservedMeteredRemoteSimplexConnection:
|
||||
return ObservedMeteredRemoteSimplexConnection(
|
||||
) -> ObservedMeteredRemoteSimplexConnection[NomssipMessage[Message]]:
|
||||
return ObservedMeteredRemoteSimplexConnection[NomssipMessage[Message]](
|
||||
self.config.network.latency,
|
||||
self.framework,
|
||||
meter_start_time,
|
||||
@ -156,14 +157,14 @@ class Simulation:
|
||||
def __create_conn(
|
||||
self,
|
||||
meter_start_time: float,
|
||||
) -> MeteredRemoteSimplexConnection:
|
||||
return MeteredRemoteSimplexConnection(
|
||||
) -> MeteredRemoteSimplexConnection[Message]:
|
||||
return MeteredRemoteSimplexConnection[Message](
|
||||
self.config.network.latency,
|
||||
self.framework,
|
||||
meter_start_time,
|
||||
)
|
||||
|
||||
async def __run_node_logic(self, node: Node):
|
||||
async def __run_node_logic(self, node: Node[Message]):
|
||||
"""
|
||||
Runs the lottery periodically to check if the node is selected to send a block.
|
||||
If the node is selected, creates a block and sends it through mix nodes.
|
||||
@ -172,27 +173,29 @@ class Simulation:
|
||||
while True:
|
||||
await self.framework.sleep(lottery_config.interval_sec)
|
||||
if lottery_config.seed.random() < lottery_config.probability:
|
||||
msg = self.msg_builder.next(self.framework.now(), b"selected block")
|
||||
await node.send_message(bytes(msg))
|
||||
inner_msg = self.inner_msg_builder.next(
|
||||
self.framework.now(), b"selected block"
|
||||
)
|
||||
await node.send_message(Message(bytes(inner_msg)))
|
||||
|
||||
async def __process_broadcasted_msg(self, msg: bytes):
|
||||
async def __process_broadcasted_msg(self, msg: Message):
|
||||
"""
|
||||
Process a broadcasted message originated from the last mix.
|
||||
"""
|
||||
message = Message.from_bytes(msg)
|
||||
elapsed = self.framework.now() - message.created_at
|
||||
self.dissemination_time.add_broadcasted_msg(message, elapsed)
|
||||
inner_msg = InnerMessage.from_bytes(msg.data)
|
||||
elapsed = self.framework.now() - inner_msg.created_at
|
||||
self.dissemination_time.add_broadcasted_msg(msg, elapsed)
|
||||
|
||||
async def __process_recovered_msg(self, msg: bytes) -> bytes:
|
||||
async def __process_recovered_msg(self, msg: bytes) -> Message:
|
||||
"""
|
||||
Process a message fully recovered by the last mix
|
||||
and returns a new message to be broadcasted.
|
||||
"""
|
||||
message = Message.from_bytes(msg)
|
||||
elapsed = self.framework.now() - message.created_at
|
||||
inner_msg = InnerMessage.from_bytes(Message.from_bytes(msg).data)
|
||||
elapsed = self.framework.now() - inner_msg.created_at
|
||||
self.dissemination_time.add_mix_propagation_time(elapsed)
|
||||
|
||||
# Update the timestamp and return the message to be broadcasted,
|
||||
# so that the broadcast dissemination time can be calculated from now.
|
||||
message.created_at = self.framework.now()
|
||||
return bytes(message)
|
||||
inner_msg.created_at = self.framework.now()
|
||||
return Message(bytes(inner_msg))
|
||||
|
||||
@ -3,7 +3,6 @@ from collections import Counter, defaultdict
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy
|
||||
import pandas
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
from protocol.node import Node
|
||||
from sim.connection import ObservedMeteredRemoteSimplexConnection
|
||||
@ -126,16 +125,18 @@ class DisseminationTime:
|
||||
# A collection of time taken for a message to be broadcasted from the last mix to all nodes in the network
|
||||
self.broadcast_dissemination_times: list[float] = []
|
||||
# Data structures to check if a message has been broadcasted to all nodes
|
||||
self.broadcast_status: Counter[Message] = Counter()
|
||||
# msg_id (int) is a key.
|
||||
self.broadcast_status: Counter[int] = Counter()
|
||||
self.num_nodes: int = num_nodes
|
||||
|
||||
def add_mix_propagation_time(self, elapsed: float):
|
||||
self.mix_propagation_times.append(elapsed)
|
||||
|
||||
def add_broadcasted_msg(self, msg: Message, elapsed: float):
|
||||
assert self.broadcast_status[msg] < self.num_nodes
|
||||
self.broadcast_status.update([msg])
|
||||
if self.broadcast_status[msg] == self.num_nodes:
|
||||
id = msg.id()
|
||||
assert self.broadcast_status[id] < self.num_nodes
|
||||
self.broadcast_status.update([id])
|
||||
if self.broadcast_status[id] == self.num_nodes:
|
||||
self.broadcast_dissemination_times.append(elapsed)
|
||||
|
||||
def analyze(self):
|
||||
|
||||
@ -1,23 +1,22 @@
|
||||
import time
|
||||
from unittest import TestCase
|
||||
|
||||
from sim.message import Message, UniqueMessageBuilder
|
||||
from sim.message import InnerMessage, UniqueInnerMessageBuilder
|
||||
|
||||
|
||||
class TestMessage(TestCase):
|
||||
def test_message_serde(self):
|
||||
msg = Message(time.time(), 10, b"hello")
|
||||
def test_inner_message_serde(self):
|
||||
msg = InnerMessage(time.time(), 10, b"hello")
|
||||
serialized = bytes(msg)
|
||||
deserialized = Message.from_bytes(serialized)
|
||||
deserialized = InnerMessage.from_bytes(serialized)
|
||||
self.assertEqual(msg, deserialized)
|
||||
|
||||
|
||||
class TestUniqueMessageBuilder(TestCase):
|
||||
class TestUniqueInnerMessageBuilder(TestCase):
|
||||
def test_uniqueness(self):
|
||||
builder = UniqueMessageBuilder()
|
||||
builder = UniqueInnerMessageBuilder()
|
||||
msg1 = builder.next(time.time(), b"hello")
|
||||
msg2 = builder.next(time.time(), b"hello")
|
||||
self.assertEqual(0, msg1.id)
|
||||
self.assertEqual(1, msg2.id)
|
||||
self.assertNotEqual(msg1, msg2)
|
||||
self.assertNotEqual(hash(msg1), hash(msg2))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user