use public key for NodeInfo instead of private key

This commit is contained in:
Youngjoon Lee 2024-06-27 18:04:51 +09:00
parent 3585d3cf86
commit 9cd601c7ba
No known key found for this signature in database
GPG Key ID: 09B750B5BD6F08A2
5 changed files with 50 additions and 24 deletions

View File

@ -48,12 +48,9 @@ class MixMembership:
@dataclass
class NodeInfo:
private_key: X25519PrivateKey
def public_key(self) -> X25519PublicKey:
return self.private_key.public_key()
public_key: X25519PublicKey
def sphinx_node(self) -> SphinxNode:
# TODO: Use a pre-signed incentive tx, instead of NodeAddress
dummy_node_addr = bytes(32)
return SphinxNode(self.private_key, dummy_node_addr)
return SphinxNode(self.public_key, dummy_node_addr)

View File

@ -9,7 +9,7 @@ from mixnet.test_utils import (
class TestNode(IsolatedAsyncioTestCase):
async def test_node(self):
global_config, node_configs = init_mixnet_config(10)
global_config, node_configs, _ = init_mixnet_config(10)
nodes = [Node(node_config, global_config) for node_config in node_configs]
for i, node in enumerate(nodes):
node.connect(nodes[(i + 1) % len(nodes)])

View File

@ -2,7 +2,7 @@ from random import randint
from typing import List
from unittest import TestCase
from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket
from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket, X25519PrivateKey
from mixnet.config import NodeInfo
from mixnet.packet import (
@ -16,29 +16,39 @@ from mixnet.test_utils import init_mixnet_config
class TestPacket(TestCase):
def test_real_packet(self):
membership = init_mixnet_config(10)[0].membership
global_config, _, key_map = init_mixnet_config(10)
msg = self.random_bytes(3500)
packets_and_routes = PacketBuilder.build_real_packets(msg, membership)
packets_and_routes = PacketBuilder.build_real_packets(
msg, global_config.membership
)
self.assertEqual(4, len(packets_and_routes))
reconstructor = MessageReconstructor()
self.assertIsNone(
reconstructor.add(
self.process_packet(packets_and_routes[1][0], packets_and_routes[1][1])
self.process_packet(
packets_and_routes[1][0], packets_and_routes[1][1], key_map
)
),
)
self.assertIsNone(
reconstructor.add(
self.process_packet(packets_and_routes[3][0], packets_and_routes[3][1])
self.process_packet(
packets_and_routes[3][0], packets_and_routes[3][1], key_map
)
),
)
self.assertIsNone(
reconstructor.add(
self.process_packet(packets_and_routes[2][0], packets_and_routes[2][1])
self.process_packet(
packets_and_routes[2][0], packets_and_routes[2][1], key_map
)
),
)
msg_with_flag = reconstructor.add(
self.process_packet(packets_and_routes[0][0], packets_and_routes[0][1])
self.process_packet(
packets_and_routes[0][0], packets_and_routes[0][1], key_map
)
)
assert msg_with_flag is not None
self.assertEqual(
@ -47,14 +57,18 @@ class TestPacket(TestCase):
)
def test_cover_packet(self):
membership = init_mixnet_config(10)[0].membership
global_config, _, key_map = init_mixnet_config(10)
msg = b"cover"
packets_and_routes = PacketBuilder.build_drop_cover_packets(msg, membership)
packets_and_routes = PacketBuilder.build_drop_cover_packets(
msg, global_config.membership
)
self.assertEqual(1, len(packets_and_routes))
reconstructor = MessageReconstructor()
msg_with_flag = reconstructor.add(
self.process_packet(packets_and_routes[0][0], packets_and_routes[0][1])
self.process_packet(
packets_and_routes[0][0], packets_and_routes[0][1], key_map
)
)
assert msg_with_flag is not None
self.assertEqual(
@ -63,14 +77,20 @@ class TestPacket(TestCase):
)
@staticmethod
def process_packet(packet: SphinxPacket, route: List[NodeInfo]) -> Fragment:
processed = packet.process(route[0].private_key)
def process_packet(
packet: SphinxPacket,
route: List[NodeInfo],
key_map: dict[bytes, X25519PrivateKey],
) -> Fragment:
processed = packet.process(key_map[route[0].public_key.public_bytes_raw()])
if isinstance(processed, ProcessedFinalHopPacket):
return Fragment.from_bytes(processed.payload.recover_plain_playload())
else:
processed = processed
for node in route[1:]:
p = processed.next_packet.process(node.private_key)
p = processed.next_packet.process(
key_map[node.public_key.public_bytes_raw()]
)
if isinstance(p, ProcessedFinalHopPacket):
return Fragment.from_bytes(p.payload.recover_plain_playload())
else:

View File

@ -8,7 +8,9 @@ from mixnet.config import (
)
def init_mixnet_config(num_nodes: int) -> tuple[GlobalConfig, list[NodeConfig]]:
def init_mixnet_config(
num_nodes: int,
) -> tuple[GlobalConfig, list[NodeConfig], dict[bytes, X25519PrivateKey]]:
transmission_rate_per_sec = 3
max_mix_path_length = 3
node_configs = [
@ -17,9 +19,16 @@ def init_mixnet_config(num_nodes: int) -> tuple[GlobalConfig, list[NodeConfig]]:
]
global_config = GlobalConfig(
MixMembership(
[NodeInfo(node_config.private_key) for node_config in node_configs]
[
NodeInfo(node_config.private_key.public_key())
for node_config in node_configs
]
),
transmission_rate_per_sec,
max_mix_path_length,
)
return (global_config, node_configs)
key_map = {
node_config.private_key.public_key().public_bytes_raw(): node_config.private_key
for node_config in node_configs
}
return (global_config, node_configs, key_map)

View File

@ -3,7 +3,7 @@ cffi==1.16.0
cryptography==41.0.7
numpy==1.26.3
pycparser==2.21
pysphinx==0.0.1
pysphinx==0.0.2
scipy==1.11.4
black==23.12.1
sympy==1.12
sympy==1.12