nomos-simulations/deprecated/mixnet-overall/protocol/test_sphinx.py

67 lines
2.6 KiB
Python

from random import randint
from typing import cast
from unittest import TestCase
from pysphinx.sphinx import (
ProcessedFinalHopPacket,
ProcessedForwardHopPacket,
)
from protocol.sphinx import SphinxPacketBuilder
from protocol.test_utils import init_mixnet_config
class TestSphinxPacketBuilder(TestCase):
def test_builder(self):
global_config, _, key_map = init_mixnet_config(10)
msg = self.random_bytes(500)
packet, route = SphinxPacketBuilder.build(msg, global_config, 3)
self.assertEqual(3, len(route))
processed = packet.process(key_map[route[0].public_key.public_bytes_raw()])
self.assertIsInstance(processed, ProcessedForwardHopPacket)
processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
key_map[route[1].public_key.public_bytes_raw()]
)
self.assertIsInstance(processed, ProcessedForwardHopPacket)
processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
key_map[route[2].public_key.public_bytes_raw()]
)
self.assertIsInstance(processed, ProcessedFinalHopPacket)
recovered = cast(
ProcessedFinalHopPacket, processed
).payload.recover_plain_playload()
self.assertEqual(msg, recovered)
def test_max_message_size(self):
global_config, _, _ = init_mixnet_config(10, max_message_size=2000)
mix_path_length = global_config.max_mix_path_length
packet1, _ = SphinxPacketBuilder.build(
self.random_bytes(1500), global_config, mix_path_length
)
packet2, _ = SphinxPacketBuilder.build(
self.random_bytes(2000), global_config, mix_path_length
)
self.assertEqual(len(packet1.bytes()), len(packet2.bytes()))
msg = self.random_bytes(2001)
with self.assertRaises(ValueError):
_ = SphinxPacketBuilder.build(msg, global_config, mix_path_length)
def test_max_mix_path_length(self):
global_config, _, _ = init_mixnet_config(10, max_mix_path_length=2)
msg = self.random_bytes(global_config.max_message_size)
packet1, _ = SphinxPacketBuilder.build(msg, global_config, 1)
packet2, _ = SphinxPacketBuilder.build(msg, global_config, 2)
self.assertEqual(len(packet1.bytes()), len(packet2.bytes()))
with self.assertRaises(ValueError):
_ = SphinxPacketBuilder.build(msg, global_config, 3)
@staticmethod
def random_bytes(size: int) -> bytes:
assert size >= 0
return bytes([randint(0, 255) for _ in range(size)])