From 9cd601c7ba89aea0a3b52f4320a21f991f56eaa0 Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Thu, 27 Jun 2024 18:04:51 +0900 Subject: [PATCH] use public key for NodeInfo instead of private key --- mixnet/config.py | 7 ++----- mixnet/test_node.py | 2 +- mixnet/test_packet.py | 46 +++++++++++++++++++++++++++++++------------ mixnet/test_utils.py | 15 +++++++++++--- requirements.txt | 4 ++-- 5 files changed, 50 insertions(+), 24 deletions(-) diff --git a/mixnet/config.py b/mixnet/config.py index 6360cc3..a61e9d3 100644 --- a/mixnet/config.py +++ b/mixnet/config.py @@ -48,12 +48,9 @@ class MixMembership: @dataclass class NodeInfo: - private_key: X25519PrivateKey - - def public_key(self) -> X25519PublicKey: - return self.private_key.public_key() + public_key: X25519PublicKey def sphinx_node(self) -> SphinxNode: # TODO: Use a pre-signed incentive tx, instead of NodeAddress dummy_node_addr = bytes(32) - return SphinxNode(self.private_key, dummy_node_addr) + return SphinxNode(self.public_key, dummy_node_addr) diff --git a/mixnet/test_node.py b/mixnet/test_node.py index 9a67e25..53e49f3 100644 --- a/mixnet/test_node.py +++ b/mixnet/test_node.py @@ -9,7 +9,7 @@ from mixnet.test_utils import ( class TestNode(IsolatedAsyncioTestCase): async def test_node(self): - global_config, node_configs = init_mixnet_config(10) + global_config, node_configs, _ = init_mixnet_config(10) nodes = [Node(node_config, global_config) for node_config in node_configs] for i, node in enumerate(nodes): node.connect(nodes[(i + 1) % len(nodes)]) diff --git a/mixnet/test_packet.py b/mixnet/test_packet.py index 5cf8098..77e1a03 100644 --- a/mixnet/test_packet.py +++ b/mixnet/test_packet.py @@ -2,7 +2,7 @@ from random import randint from typing import List from unittest import TestCase -from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket +from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket, X25519PrivateKey from mixnet.config import NodeInfo from mixnet.packet import ( @@ -16,29 +16,39 @@ from mixnet.test_utils import init_mixnet_config class TestPacket(TestCase): def test_real_packet(self): - membership = init_mixnet_config(10)[0].membership + global_config, _, key_map = init_mixnet_config(10) msg = self.random_bytes(3500) - packets_and_routes = PacketBuilder.build_real_packets(msg, membership) + packets_and_routes = PacketBuilder.build_real_packets( + msg, global_config.membership + ) self.assertEqual(4, len(packets_and_routes)) reconstructor = MessageReconstructor() self.assertIsNone( reconstructor.add( - self.process_packet(packets_and_routes[1][0], packets_and_routes[1][1]) + self.process_packet( + packets_and_routes[1][0], packets_and_routes[1][1], key_map + ) ), ) self.assertIsNone( reconstructor.add( - self.process_packet(packets_and_routes[3][0], packets_and_routes[3][1]) + self.process_packet( + packets_and_routes[3][0], packets_and_routes[3][1], key_map + ) ), ) self.assertIsNone( reconstructor.add( - self.process_packet(packets_and_routes[2][0], packets_and_routes[2][1]) + self.process_packet( + packets_and_routes[2][0], packets_and_routes[2][1], key_map + ) ), ) msg_with_flag = reconstructor.add( - self.process_packet(packets_and_routes[0][0], packets_and_routes[0][1]) + self.process_packet( + packets_and_routes[0][0], packets_and_routes[0][1], key_map + ) ) assert msg_with_flag is not None self.assertEqual( @@ -47,14 +57,18 @@ class TestPacket(TestCase): ) def test_cover_packet(self): - membership = init_mixnet_config(10)[0].membership + global_config, _, key_map = init_mixnet_config(10) msg = b"cover" - packets_and_routes = PacketBuilder.build_drop_cover_packets(msg, membership) + packets_and_routes = PacketBuilder.build_drop_cover_packets( + msg, global_config.membership + ) self.assertEqual(1, len(packets_and_routes)) reconstructor = MessageReconstructor() msg_with_flag = reconstructor.add( - self.process_packet(packets_and_routes[0][0], packets_and_routes[0][1]) + self.process_packet( + packets_and_routes[0][0], packets_and_routes[0][1], key_map + ) ) assert msg_with_flag is not None self.assertEqual( @@ -63,14 +77,20 @@ class TestPacket(TestCase): ) @staticmethod - def process_packet(packet: SphinxPacket, route: List[NodeInfo]) -> Fragment: - processed = packet.process(route[0].private_key) + def process_packet( + packet: SphinxPacket, + route: List[NodeInfo], + key_map: dict[bytes, X25519PrivateKey], + ) -> Fragment: + processed = packet.process(key_map[route[0].public_key.public_bytes_raw()]) if isinstance(processed, ProcessedFinalHopPacket): return Fragment.from_bytes(processed.payload.recover_plain_playload()) else: processed = processed for node in route[1:]: - p = processed.next_packet.process(node.private_key) + p = processed.next_packet.process( + key_map[node.public_key.public_bytes_raw()] + ) if isinstance(p, ProcessedFinalHopPacket): return Fragment.from_bytes(p.payload.recover_plain_playload()) else: diff --git a/mixnet/test_utils.py b/mixnet/test_utils.py index 151a693..be3541e 100644 --- a/mixnet/test_utils.py +++ b/mixnet/test_utils.py @@ -8,7 +8,9 @@ from mixnet.config import ( ) -def init_mixnet_config(num_nodes: int) -> tuple[GlobalConfig, list[NodeConfig]]: +def init_mixnet_config( + num_nodes: int, +) -> tuple[GlobalConfig, list[NodeConfig], dict[bytes, X25519PrivateKey]]: transmission_rate_per_sec = 3 max_mix_path_length = 3 node_configs = [ @@ -17,9 +19,16 @@ def init_mixnet_config(num_nodes: int) -> tuple[GlobalConfig, list[NodeConfig]]: ] global_config = GlobalConfig( MixMembership( - [NodeInfo(node_config.private_key) for node_config in node_configs] + [ + NodeInfo(node_config.private_key.public_key()) + for node_config in node_configs + ] ), transmission_rate_per_sec, max_mix_path_length, ) - return (global_config, node_configs) + key_map = { + node_config.private_key.public_key().public_bytes_raw(): node_config.private_key + for node_config in node_configs + } + return (global_config, node_configs, key_map) diff --git a/requirements.txt b/requirements.txt index 5e802ca..68d4bfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ cffi==1.16.0 cryptography==41.0.7 numpy==1.26.3 pycparser==2.21 -pysphinx==0.0.1 +pysphinx==0.0.2 scipy==1.11.4 black==23.12.1 -sympy==1.12 \ No newline at end of file +sympy==1.12