92 lines
2.6 KiB
Python

import hashlib
from dataclasses import dataclass
from typing import Self
from unittest import IsolatedAsyncioTestCase
import framework.asyncio as asynciofw
from framework.framework import Queue
from protocol.connection import LocalSimplexConnection
from protocol.node import Node
from protocol.nomssip import NomssipMessage
from protocol.test_utils import (
init_mixnet_config,
)
class TestNode(IsolatedAsyncioTestCase):
    """Integration test: a message sent by one node reaches every node's broadcast handler."""

    async def test_node(self):
        """Connect 10 nodes in a ring and check that all of them broadcast one message.

        Node ``i`` is wired to node ``(i + 1) % N`` over both the mix and the
        broadcast channels. Node 0 then sends a single message. The test drains
        the shared queue after each ``framework.sleep(1)``, for at most 15
        rounds, until it has counted one delivery per node.
        """
        fw = asynciofw.Framework()
        global_config, node_configs, _ = init_mixnet_config(10)

        # Every node's broadcast handler pushes into this single queue, so the
        # number of items drained from it is the number of deliveries observed.
        received: Queue[Message] = fw.queue()

        async def broadcasted_msg_handler(msg: Message) -> None:
            await received.put(msg)

        async def recovered_msg_handler(msg: bytes) -> Message:
            return Message(msg)

        nodes = [
            Node[Message](
                fw,
                cfg,
                global_config,
                broadcasted_msg_handler,
                recovered_msg_handler,
                noise_msg=Message(b""),
            )
            for cfg in node_configs
        ]
        total = len(nodes)

        # Ring topology: each node connects to its successor, wrapping at the end.
        for idx, node in enumerate(nodes):
            neighbour = nodes[(idx + 1) % total]
            try:
                node.connect_mix(
                    neighbour,
                    LocalSimplexConnection[NomssipMessage[Message]](fw),
                    LocalSimplexConnection[NomssipMessage[Message]](fw),
                )
                node.connect_broadcast(
                    neighbour,
                    LocalSimplexConnection[Message](fw),
                    LocalSimplexConnection[Message](fw),
                )
            except ValueError as err:
                # Connection errors are only reported; setup continues with the next node.
                print(err)

        expected = Message(b"block selection")
        await nodes[0].send_message(expected)

        # Drain the queue after each sleep; stop early once every node has delivered.
        seen = 0
        max_rounds = 15
        for _round in range(max_rounds):
            await fw.sleep(1)
            while not received.empty():
                self.assertEqual(expected, await received.get())
                seen += 1
            if seen == total:
                break

        self.assertEqual(total, seen)
        # TODO: check noise
@dataclass
class Message:
data: bytes
def id(self) -> int:
return int.from_bytes(hashlib.sha256(self.data).digest(), byteorder="big")
def __len__(self) -> int:
return len(self.data)
def __bytes__(self) -> bytes:
return self.data
@classmethod
def from_bytes(cls, data: bytes) -> Self:
return cls(data)