# Source-viewer residue (not part of the module): 92 lines, 2.6 KiB, Python

import hashlib
from dataclasses import dataclass
from typing import Self
# (source-viewer blame timestamp: 2024-08-01 11:07:52 +09:00)
from unittest import IsolatedAsyncioTestCase
import framework.asyncio as asynciofw
from framework.framework import Queue
from protocol.connection import LocalSimplexConnection
from protocol.node import Node
from protocol.nomssip import NomssipMessage
# (source-viewer blame timestamp: 2024-08-01 11:07:52 +09:00)
from protocol.test_utils import (
init_mixnet_config,
)
class TestNode(IsolatedAsyncioTestCase):
    """Integration test: a message sent by one node is broadcast by every node."""

    async def test_node(self):
        """Send one message from the first node and expect a broadcast per node.

        Nodes are built from ``init_mixnet_config(10)`` and each node is
        connected to its successor, wrapping around at the end of the list.
        Every node shares the same broadcast handler, which pushes the
        received message onto a single queue; the test passes once that queue
        has yielded exactly ``len(nodes)`` copies of the sent message.
        """
        framework = asynciofw.Framework()
        global_config, node_configs, _ = init_mixnet_config(10)

        # Single sink shared by all nodes: one queue item per broadcast seen.
        queue: Queue[Message] = framework.queue()

        async def broadcasted_msg_handler(msg: Message) -> None:
            await queue.put(msg)

        async def recovered_msg_handler(msg: bytes) -> Message:
            return Message(msg)

        nodes = [
            Node[Message](
                framework,
                node_config,
                global_config,
                broadcasted_msg_handler,
                recovered_msg_handler,
                noise_msg=Message(b""),
            )
            for node_config in node_configs
        ]

        # Wire node i to node i+1 (modulo the node count) on both layers.
        for i, node in enumerate(nodes):
            try:
                node.connect_mix(
                    nodes[(i + 1) % len(nodes)],
                    LocalSimplexConnection[NomssipMessage[Message]](framework),
                    LocalSimplexConnection[NomssipMessage[Message]](framework),
                )
                node.connect_broadcast(
                    nodes[(i + 1) % len(nodes)],
                    LocalSimplexConnection[Message](framework),
                    LocalSimplexConnection[Message](framework),
                )
            except ValueError as e:
                # Deliberately best-effort: a rejected connection is reported
                # but does not abort the test setup.
                # NOTE(review): presumably raised when a peer cannot accept
                # another connection -- confirm against Node.connect_*.
                print(e)

        msg = Message(b"block selection")
        await nodes[0].send_message(msg)

        # Wait for all nodes to receive the broadcast
        num_nodes_received_broadcast = 0
        timeout = 15
        for _ in range(timeout):
            await framework.sleep(1)

            # Drain everything that arrived during this tick.
            while not queue.empty():
                self.assertEqual(msg, await queue.get())
                num_nodes_received_broadcast += 1

            if num_nodes_received_broadcast == len(nodes):
                break

        self.assertEqual(len(nodes), num_nodes_received_broadcast)

        # TODO: check noise


@dataclass
class Message:
data: bytes
def id(self) -> int:
return int.from_bytes(hashlib.sha256(self.data).digest(), byteorder="big")
def __len__(self) -> int:
return len(self.data)
def __bytes__(self) -> bytes:
return self.data
@classmethod
def from_bytes(cls, data: bytes) -> Self:
return cls(data)