diff --git a/mixnet/config.py b/mixnet/config.py index a61e9d3..716a72c 100644 --- a/mixnet/config.py +++ b/mixnet/config.py @@ -22,6 +22,8 @@ class GlobalConfig: @dataclass class NodeConfig: private_key: X25519PrivateKey + # The max number of peers a node should maintain in its p2p network + peering_degree: int mix_path_length: int # TODO: use this when creating Sphinx packets diff --git a/mixnet/node.py b/mixnet/node.py index 42fe6d2..a7e7f12 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -30,7 +30,9 @@ class Node: def __init__(self, config: NodeConfig, global_config: GlobalConfig): self.config = config self.global_config = global_config - self.mixgossip_channel = MixGossipChannel(self.__process_sphinx_packet) + self.mixgossip_channel = MixGossipChannel( + config.peering_degree, self.__process_sphinx_packet + ) self.reconstructor = MessageReconstructor() self.broadcast_channel = asyncio.Queue() @@ -58,10 +60,22 @@ class Node: await self.broadcast_channel.put(msg) def connect(self, peer: Node): - conn = asyncio.Queue() - peer.mixgossip_channel.add_inbound(conn) - self.mixgossip_channel.add_outbound( - MixOutboundConnection(conn, self.global_config.transmission_rate_per_sec) + inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue() + self.mixgossip_channel.add_conn( + DuplexConnection( + inbound_conn, + MixOutboundConnection( + outbound_conn, self.global_config.transmission_rate_per_sec + ), + ) + ) + peer.mixgossip_channel.add_conn( + DuplexConnection( + outbound_conn, + MixOutboundConnection( + inbound_conn, self.global_config.transmission_rate_per_sec + ), + ) ) async def send_message(self, msg: bytes): @@ -72,34 +86,36 @@ class Node: class MixGossipChannel: - inbound_conns: list[Connection] - outbound_conns: list[MixOutboundConnection] + peering_degree: int + conns: list[DuplexConnection] handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]] def __init__( self, + peer_degree: int, handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]], ): - self.inbound_conns = [] - self.outbound_conns = [] + self.peering_degree = peer_degree + self.conns = [] self.handler = handler # A set just for gathering a reference of tasks to prevent them from being garbage collected. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task self.tasks = set() - def add_inbound(self, conn: Connection): - self.inbound_conns.append(conn) + def add_conn(self, conn: DuplexConnection): + if len(self.conns) >= self.peering_degree: + # For simplicity of the spec, reject the connection if the peering degree is reached. + raise ValueError("The peering degree is reached.") + + self.conns.append(conn) task = asyncio.create_task(self.__process_inbound_conn(conn)) self.tasks.add(task) # To discard the task from the set automatically when it is done. task.add_done_callback(self.tasks.discard) - def add_outbound(self, conn: MixOutboundConnection): - self.outbound_conns.append(conn) - - async def __process_inbound_conn(self, conn: Connection): + async def __process_inbound_conn(self, conn: DuplexConnection): while True: - elem = await conn.get() + elem = await conn.recv() # In practice, data transmitted through connections is going to be always 'bytes'. # But here, we use the SphinxPacket type explicitly for simplicity # without implementing serde for SphinxPacket. @@ -113,10 +129,25 @@ class MixGossipChannel: await self.gossip(net_packet) async def gossip(self, packet: NetworkPacket): - for conn in self.outbound_conns: + for conn in self.conns: await conn.send(packet) +class DuplexConnection: + inbound: Connection + outbound: MixOutboundConnection + + def __init__(self, inbound: Connection, outbound: MixOutboundConnection): + self.inbound = inbound + self.outbound = outbound + + async def recv(self) -> NetworkPacket: + return await self.inbound.get() + + async def send(self, packet: NetworkPacket): + await self.outbound.send(packet) + + class MixOutboundConnection: queue: NetworkPacketQueue conn: Connection diff --git a/mixnet/test_node.py b/mixnet/test_node.py index 53e49f3..f4ba644 100644 --- a/mixnet/test_node.py +++ b/mixnet/test_node.py @@ -12,7 +12,10 @@ class TestNode(IsolatedAsyncioTestCase): global_config, node_configs, _ = init_mixnet_config(10) nodes = [Node(node_config, global_config) for node_config in node_configs] for i, node in enumerate(nodes): - node.connect(nodes[(i + 1) % len(nodes)]) + try: + node.connect(nodes[(i + 1) % len(nodes)]) + except ValueError as e: + print(e) await nodes[0].send_message(b"block selection") diff --git a/mixnet/test_utils.py b/mixnet/test_utils.py index be3541e..bdbd220 100644 --- a/mixnet/test_utils.py +++ b/mixnet/test_utils.py @@ -12,9 +12,10 @@ def init_mixnet_config( num_nodes: int, ) -> tuple[GlobalConfig, list[NodeConfig], dict[bytes, X25519PrivateKey]]: transmission_rate_per_sec = 3 + peering_degree = 6 max_mix_path_length = 3 node_configs = [ - NodeConfig(X25519PrivateKey.generate(), max_mix_path_length) + NodeConfig(X25519PrivateKey.generate(), peering_degree, max_mix_path_length) for _ in range(num_nodes) ] global_config = GlobalConfig(