67 lines
2.6 KiB
Python
67 lines
2.6 KiB
Python
from random import randint
|
|
from typing import cast
|
|
from unittest import TestCase
|
|
|
|
from pysphinx.sphinx import (
|
|
ProcessedFinalHopPacket,
|
|
ProcessedForwardHopPacket,
|
|
)
|
|
|
|
from protocol.sphinx import SphinxPacketBuilder
|
|
from protocol.test_utils import init_mixnet_config
|
|
|
|
|
|
class TestSphinxPacketBuilder(TestCase):
    """Tests for ``SphinxPacketBuilder.build`` using a generated mixnet config."""

    def test_builder(self):
        """Build a 3-hop packet, process it hop by hop, and recover the message.

        The first two hops must produce a ``ProcessedForwardHopPacket`` and the
        third a ``ProcessedFinalHopPacket`` whose recovered payload equals the
        original message.
        """
        hop_count = 3
        global_config, _, key_map = init_mixnet_config(10)
        msg = self.random_bytes(500)

        packet, route = SphinxPacketBuilder.build(msg, global_config, hop_count)
        self.assertEqual(hop_count, len(route))

        # key_map is indexed by the raw bytes of each route node's public key;
        # the value it yields is what ``process`` needs for that hop
        # (presumably the node's private key -- confirm in init_mixnet_config).
        current = packet
        for hop in range(hop_count - 1):
            outcome = current.process(
                key_map[route[hop].public_key.public_bytes_raw()]
            )
            self.assertIsInstance(outcome, ProcessedForwardHopPacket)
            # cast() only narrows the static type; the runtime check is the
            # assertIsInstance above.
            current = cast(ProcessedForwardHopPacket, outcome).next_packet

        last = current.process(
            key_map[route[hop_count - 1].public_key.public_bytes_raw()]
        )
        self.assertIsInstance(last, ProcessedFinalHopPacket)
        final_hop = cast(ProcessedFinalHopPacket, last)
        # NOTE: "playload" is the spelling of the method this test calls;
        # do not "correct" it here.
        recovered = final_hop.payload.recover_plain_playload()
        self.assertEqual(msg, recovered)

    def test_max_message_size(self):
        """Check packet length at and below ``max_message_size`` and the error above it.

        With ``max_message_size=2000``, packets built from 1500-byte and
        2000-byte messages must serialize to the same length, and a 2001-byte
        message must make ``build`` raise ``ValueError``.
        """
        global_config, _, _ = init_mixnet_config(10, max_message_size=2000)
        hops = global_config.max_mix_path_length

        below_limit = self.random_bytes(1500)
        packet1, _ = SphinxPacketBuilder.build(below_limit, global_config, hops)
        at_limit = self.random_bytes(2000)
        packet2, _ = SphinxPacketBuilder.build(at_limit, global_config, hops)
        self.assertEqual(len(packet1.bytes()), len(packet2.bytes()))

        # One byte past the configured maximum must be rejected.
        oversized = self.random_bytes(2001)
        with self.assertRaises(ValueError):
            SphinxPacketBuilder.build(oversized, global_config, hops)

    def test_max_mix_path_length(self):
        """Check packet length for allowed path lengths and the error above the limit.

        With ``max_mix_path_length=2``, packets built with 1 and 2 hops must
        serialize to the same length, and requesting 3 hops must make ``build``
        raise ``ValueError``.
        """
        global_config, _, _ = init_mixnet_config(10, max_mix_path_length=2)
        msg = self.random_bytes(global_config.max_message_size)

        one_hop_packet, _ = SphinxPacketBuilder.build(msg, global_config, 1)
        two_hop_packet, _ = SphinxPacketBuilder.build(msg, global_config, 2)
        self.assertEqual(
            len(one_hop_packet.bytes()),
            len(two_hop_packet.bytes()),
        )

        # One hop more than the configured maximum must be rejected.
        with self.assertRaises(ValueError):
            SphinxPacketBuilder.build(msg, global_config, 3)

    @staticmethod
    def random_bytes(size: int) -> bytes:
        """Return ``size`` bytes, each drawn independently with ``randint(0, 255)``.

        ``size`` must be non-negative; this is enforced with an ``assert``.
        """
        assert size >= 0
        return bytes(randint(0, 255) for _ in range(size))
|