mirror of
https://github.com/logos-blockchain/logos-blockchain-specs.git
synced 2026-02-17 11:43:13 +00:00
sphinx serde
This commit is contained in:
parent
3a703434da
commit
4aa1ee06fb
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from enum import Enum
|
||||||
from typing import Awaitable, Callable, TypeAlias
|
from typing import Awaitable, Callable, TypeAlias
|
||||||
|
|
||||||
from pysphinx.payload import DEFAULT_PAYLOAD_SIZE
|
from pysphinx.payload import DEFAULT_PAYLOAD_SIZE
|
||||||
@ -15,8 +16,7 @@ from pysphinx.sphinx import (
|
|||||||
from mixnet.config import GlobalConfig, NodeConfig
|
from mixnet.config import GlobalConfig, NodeConfig
|
||||||
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
|
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
|
||||||
|
|
||||||
NetworkPacket: TypeAlias = SphinxPacket | bytes
|
NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes]
|
||||||
NetworkPacketQueue: TypeAlias = asyncio.Queue[NetworkPacket]
|
|
||||||
Connection: TypeAlias = NetworkPacketQueue
|
Connection: TypeAlias = NetworkPacketQueue
|
||||||
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
|
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ class Node:
|
|||||||
|
|
||||||
async def __process_sphinx_packet(
|
async def __process_sphinx_packet(
|
||||||
self, packet: SphinxPacket
|
self, packet: SphinxPacket
|
||||||
) -> NetworkPacket | None:
|
) -> SphinxPacket | None:
|
||||||
try:
|
try:
|
||||||
processed = packet.process(self.config.private_key)
|
processed = packet.process(self.config.private_key)
|
||||||
match processed:
|
match processed:
|
||||||
@ -83,19 +83,19 @@ class Node:
|
|||||||
for packet, _ in PacketBuilder.build_real_packets(
|
for packet, _ in PacketBuilder.build_real_packets(
|
||||||
msg, self.global_config.membership
|
msg, self.global_config.membership
|
||||||
):
|
):
|
||||||
await self.mixgossip_channel.gossip(packet)
|
await self.mixgossip_channel.gossip(build_msg(MsgType.REAL, packet.bytes()))
|
||||||
|
|
||||||
|
|
||||||
class MixGossipChannel:
|
class MixGossipChannel:
|
||||||
peering_degree: int
|
peering_degree: int
|
||||||
conns: list[DuplexConnection]
|
conns: list[DuplexConnection]
|
||||||
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]]
|
handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]]
|
||||||
msg_cache: set[NetworkPacket]
|
msg_cache: set[bytes]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
peer_degree: int,
|
peer_degree: int,
|
||||||
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]],
|
handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]],
|
||||||
):
|
):
|
||||||
self.peering_degree = peer_degree
|
self.peering_degree = peer_degree
|
||||||
self.conns = []
|
self.conns = []
|
||||||
@ -118,26 +118,28 @@ class MixGossipChannel:
|
|||||||
|
|
||||||
async def __process_inbound_conn(self, conn: DuplexConnection):
|
async def __process_inbound_conn(self, conn: DuplexConnection):
|
||||||
while True:
|
while True:
|
||||||
elem = await conn.recv()
|
msg = await conn.recv()
|
||||||
# In practice, data transmitted through connections is going to be always 'bytes'.
|
# Don't process the same message twice.
|
||||||
# But here, we use the SphinxPacket type explicitly for simplicity
|
msg_hash = hashlib.sha256(msg).digest()
|
||||||
# without implementing serde for SphinxPacket.
|
if msg_hash in self.msg_cache:
|
||||||
if isinstance(elem, bytes):
|
|
||||||
assert elem == build_noise_packet()
|
|
||||||
# Drop packet
|
|
||||||
continue
|
continue
|
||||||
elif isinstance(elem, SphinxPacket):
|
self.msg_cache.add(msg_hash)
|
||||||
# Don't process the same message twice.
|
|
||||||
msg_hash = hashlib.sha256(elem.bytes()).digest()
|
|
||||||
if msg_hash in self.msg_cache:
|
|
||||||
continue
|
|
||||||
self.msg_cache.add(msg_hash)
|
|
||||||
# Handle the packet and gossip the result if needed.
|
|
||||||
net_packet = await self.handler(elem)
|
|
||||||
if net_packet is not None:
|
|
||||||
await self.gossip(net_packet)
|
|
||||||
|
|
||||||
async def gossip(self, packet: NetworkPacket):
|
flag, msg = parse_msg(msg)
|
||||||
|
match flag:
|
||||||
|
case MsgType.NOISE:
|
||||||
|
# Drop noise packet
|
||||||
|
continue
|
||||||
|
case MsgType.REAL:
|
||||||
|
# Handle the packet and gossip the result if needed.
|
||||||
|
sphinx_packet = SphinxPacket.from_bytes(msg)
|
||||||
|
new_sphinx_packet = await self.handler(sphinx_packet)
|
||||||
|
if new_sphinx_packet is not None:
|
||||||
|
await self.gossip(
|
||||||
|
build_msg(MsgType.REAL, new_sphinx_packet.bytes())
|
||||||
|
)
|
||||||
|
|
||||||
|
async def gossip(self, packet: bytes):
|
||||||
for conn in self.conns:
|
for conn in self.conns:
|
||||||
await conn.send(packet)
|
await conn.send(packet)
|
||||||
|
|
||||||
@ -150,10 +152,10 @@ class DuplexConnection:
|
|||||||
self.inbound = inbound
|
self.inbound = inbound
|
||||||
self.outbound = outbound
|
self.outbound = outbound
|
||||||
|
|
||||||
async def recv(self) -> NetworkPacket:
|
async def recv(self) -> bytes:
|
||||||
return await self.inbound.get()
|
return await self.inbound.get()
|
||||||
|
|
||||||
async def send(self, packet: NetworkPacket):
|
async def send(self, packet: bytes):
|
||||||
await self.outbound.send(packet)
|
await self.outbound.send(packet)
|
||||||
|
|
||||||
|
|
||||||
@ -178,9 +180,24 @@ class MixOutboundConnection:
|
|||||||
elem = self.queue.get_nowait()
|
elem = self.queue.get_nowait()
|
||||||
await self.conn.put(elem)
|
await self.conn.put(elem)
|
||||||
|
|
||||||
async def send(self, elem: NetworkPacket):
|
async def send(self, elem: bytes):
|
||||||
await self.queue.put(elem)
|
await self.queue.put(elem)
|
||||||
|
|
||||||
|
|
||||||
|
class MsgType(Enum):
|
||||||
|
REAL = b"\x00"
|
||||||
|
NOISE = b"\x01"
|
||||||
|
|
||||||
|
|
||||||
|
def build_msg(flag: MsgType, data: bytes) -> bytes:
|
||||||
|
return flag.value + data
|
||||||
|
|
||||||
|
|
||||||
|
def parse_msg(data: bytes) -> tuple[MsgType, bytes]:
|
||||||
|
if len(data) < 1:
|
||||||
|
raise ValueError("Invalid message format")
|
||||||
|
return (MsgType(data[:1]), data[1:])
|
||||||
|
|
||||||
|
|
||||||
def build_noise_packet() -> bytes:
|
def build_noise_packet() -> bytes:
|
||||||
return bytes(DEFAULT_PAYLOAD_SIZE)
|
return build_msg(MsgType.NOISE, bytes(DEFAULT_PAYLOAD_SIZE))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user