define FlaggedPacket

This commit is contained in:
Youngjoon Lee 2024-07-11 17:32:48 +09:00
parent 693cf63b73
commit 9ed5fd517c
No known key found for this signature in database
GPG Key ID: B4253AFBA618BF4D

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
import hashlib
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Self
from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection
@ -38,9 +40,9 @@ class Nomssip:
# For simplicity of the spec, reject the connection if the peering degree is reached.
raise ValueError("The peering degree is reached.")
noise_packet = self.__build_packet(
self.PacketType.NOISE, bytes(self.config.msg_size)
)
noise_packet = FlaggedPacket(
FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size)
).bytes()
conn = DuplexConnection(
inbound,
MixSimplexConnection(
@ -62,14 +64,14 @@ class Nomssip:
if self.__check_update_cache(packet):
continue
flag, msg = self.__parse_packet(packet)
match flag:
case self.PacketType.NOISE:
packet = FlaggedPacket.from_bytes(packet)
match packet.flag:
case FlaggedPacket.Flag.NOISE:
# Drop noise packet
continue
case self.PacketType.REAL:
await self.__gossip(packet)
await self.handler(msg)
case FlaggedPacket.Flag.REAL:
await self.__gossip_flagged_packet(packet)
await self.handler(packet.message)
async def gossip(self, msg: bytes):
"""
@ -78,15 +80,15 @@ class Nomssip:
# The message size must be fixed.
assert len(msg) == self.config.msg_size
packet = self.__build_packet(self.PacketType.REAL, msg)
await self.__gossip(packet)
packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg)
await self.__gossip_flagged_packet(packet)
async def __gossip(self, packet: bytes):
async def __gossip_flagged_packet(self, packet: FlaggedPacket):
"""
An internal method to send a flagged packet to all connected peers
"""
for conn in self.conns:
await conn.send(packet)
await conn.send(packet.bytes())
def __check_update_cache(self, packet: bytes) -> bool:
"""
@ -98,22 +100,24 @@ class Nomssip:
self.packet_cache.add(hash)
return False
class PacketType(Enum):
class FlaggedPacket:
class Flag(Enum):
REAL = b"\x00"
NOISE = b"\x01"
@staticmethod
def __build_packet(flag: PacketType, data: bytes) -> bytes:
"""
Prepend a flag to the message, right before sending it via network channel.
"""
return flag.value + data
def __init__(self, flag: Flag, message: bytes):
self.flag = flag
self.message = message
@staticmethod
def __parse_packet(data: bytes) -> tuple[PacketType, bytes]:
def bytes(self) -> bytes:
return self.flag.value + self.message
@classmethod
def from_bytes(cls, packet: bytes) -> Self:
"""
Parse the message and extract the flag.
Parse a flagged packet from bytes
"""
if len(data) < 1:
if len(packet) < 1:
raise ValueError("Invalid message format")
return (Nomssip.PacketType(data[:1]), data[1:])
return cls(cls.Flag(packet[:1]), packet[1:])