mirror of
https://github.com/logos-co/nomos-specs.git
synced 2025-01-20 20:40:15 +00:00
remove message fragments
This commit is contained in:
parent
ab9943d291
commit
ad669abc6b
@ -49,19 +49,11 @@ class MixMembership:
|
||||
|
||||
nodes: List[NodeInfo]
|
||||
|
||||
def generate_route(self, num_hops: int, last_mix: NodeInfo) -> list[NodeInfo]:
|
||||
def generate_route(self, length: int) -> list[NodeInfo]:
|
||||
"""
|
||||
Generate a mix route for a Sphinx packet.
|
||||
The pre-selected mix_destination is used as a last mix node in the route,
|
||||
so that associated packets can be merged together into a original message.
|
||||
"""
|
||||
return [*(self.choose() for _ in range(num_hops - 1)), last_mix]
|
||||
|
||||
def choose(self) -> NodeInfo:
|
||||
"""
|
||||
Choose a mix node as a mix destination that will reconstruct a message from Sphinx packets.
|
||||
"""
|
||||
return random.choice(self.nodes)
|
||||
return [random.choice(self.nodes) for _ in range(length)]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -4,7 +4,6 @@ import asyncio
|
||||
from typing import TypeAlias
|
||||
|
||||
from pysphinx.sphinx import (
|
||||
Payload,
|
||||
ProcessedFinalHopPacket,
|
||||
ProcessedForwardHopPacket,
|
||||
SphinxPacket,
|
||||
@ -12,7 +11,7 @@ from pysphinx.sphinx import (
|
||||
|
||||
from mixnet.config import GlobalConfig, NodeConfig
|
||||
from mixnet.nomssip import Nomssip
|
||||
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
|
||||
from mixnet.sphinx import SphinxPacketBuilder
|
||||
|
||||
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
|
||||
|
||||
@ -28,7 +27,6 @@ class Node:
|
||||
config: NodeConfig
|
||||
global_config: GlobalConfig
|
||||
nomssip: Nomssip
|
||||
reconstructor: MessageReconstructor
|
||||
broadcast_channel: BroadcastChannel
|
||||
|
||||
def __init__(self, config: NodeConfig, global_config: GlobalConfig):
|
||||
@ -42,7 +40,6 @@ class Node:
|
||||
),
|
||||
self.__process_msg,
|
||||
)
|
||||
self.reconstructor = MessageReconstructor()
|
||||
self.broadcast_channel = asyncio.Queue()
|
||||
|
||||
@staticmethod
|
||||
@ -50,10 +47,10 @@ class Node:
|
||||
"""
|
||||
Calculate the actual message size to be gossiped, which depends on the maximum length of mix path.
|
||||
"""
|
||||
sample_packet, _ = PacketBuilder.build_real_packets(
|
||||
sample_sphinx_packet, _ = SphinxPacketBuilder.build(
|
||||
bytes(1), global_config.membership, global_config.max_mix_path_length
|
||||
)[0]
|
||||
return len(sample_packet.bytes())
|
||||
)
|
||||
return len(sample_sphinx_packet.bytes())
|
||||
|
||||
async def __process_msg(self, msg: bytes) -> None:
|
||||
"""
|
||||
@ -83,24 +80,11 @@ class Node:
|
||||
case ProcessedForwardHopPacket():
|
||||
return processed.next_packet
|
||||
case ProcessedFinalHopPacket():
|
||||
return await self.__process_sphinx_payload(processed.payload)
|
||||
return processed.payload.recover_plain_playload()
|
||||
except ValueError:
|
||||
# Return nothing, if it cannot be unwrapped by the private key of this node.
|
||||
return None
|
||||
|
||||
async def __process_sphinx_payload(self, payload: Payload) -> bytes | None:
|
||||
"""
|
||||
Process the Sphinx payload if possible
|
||||
"""
|
||||
msg_with_flag = self.reconstructor.add(
|
||||
Fragment.from_bytes(payload.recover_plain_playload())
|
||||
)
|
||||
if msg_with_flag is not None:
|
||||
flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag)
|
||||
if flag == MessageFlag.MESSAGE_FLAG_REAL:
|
||||
return msg
|
||||
return None
|
||||
|
||||
def connect(self, peer: Node):
|
||||
"""
|
||||
Establish a duplex connection with a peer node.
|
||||
@ -117,9 +101,9 @@ class Node:
|
||||
"""
|
||||
# Here, we handle the case in which a msg is split into multiple Sphinx packets.
|
||||
# But, in practice, we expect a message to be small enough to fit in a single Sphinx packet.
|
||||
for packet, _ in PacketBuilder.build_real_packets(
|
||||
sphinx_packet, _ = SphinxPacketBuilder.build(
|
||||
msg,
|
||||
self.global_config.membership,
|
||||
self.config.mix_path_length,
|
||||
):
|
||||
await self.nomssip.gossip(packet.bytes())
|
||||
)
|
||||
await self.nomssip.gossip(sphinx_packet.bytes())
|
||||
|
206
mixnet/packet.py
206
mixnet/packet.py
@ -1,206 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from itertools import batched
|
||||
from typing import Dict, List, Self, Tuple, TypeAlias
|
||||
|
||||
from pysphinx.payload import Payload
|
||||
from pysphinx.sphinx import SphinxPacket
|
||||
|
||||
from mixnet.config import MixMembership, NodeInfo
|
||||
|
||||
|
||||
class MessageFlag(Enum):
|
||||
MESSAGE_FLAG_REAL = b"\x00"
|
||||
MESSAGE_FLAG_DROP_COVER = b"\x01"
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return bytes(self.value)
|
||||
|
||||
|
||||
class PacketBuilder:
|
||||
@staticmethod
|
||||
def build_real_packets(
|
||||
message: bytes, membership: MixMembership, path_len: int
|
||||
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
|
||||
return PacketBuilder.__build_packets(
|
||||
MessageFlag.MESSAGE_FLAG_REAL, message, membership, path_len
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_drop_cover_packets(
|
||||
message: bytes, membership: MixMembership, path_len: int
|
||||
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
|
||||
return PacketBuilder.__build_packets(
|
||||
MessageFlag.MESSAGE_FLAG_DROP_COVER, message, membership, path_len
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __build_packets(
|
||||
flag: MessageFlag, message: bytes, membership: MixMembership, path_len: int
|
||||
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
|
||||
if path_len <= 0:
|
||||
raise ValueError("path_len must be greater than 0")
|
||||
|
||||
last_mix = membership.choose()
|
||||
|
||||
msg_with_flag = flag.bytes() + message
|
||||
# NOTE: We don't encrypt msg_with_flag for destination.
|
||||
# If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag.
|
||||
fragment_set = FragmentSet(msg_with_flag)
|
||||
|
||||
out = []
|
||||
for fragment in fragment_set.fragments:
|
||||
route = membership.generate_route(path_len, last_mix)
|
||||
packet = SphinxPacket.build(
|
||||
fragment.bytes(),
|
||||
[mixnode.sphinx_node() for mixnode in route],
|
||||
last_mix.sphinx_node(),
|
||||
)
|
||||
out.append((packet, route))
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]:
|
||||
"""Remove a MessageFlag from data"""
|
||||
if len(data) < 1:
|
||||
raise ValueError("data is too short")
|
||||
|
||||
return (MessageFlag(data[0:1]), data[1:])
|
||||
|
||||
|
||||
# Unlikely, Nym uses i32 for FragmentSetId, which may cause more collisions.
|
||||
# We will use UUID until figuring out why Nym uses i32.
|
||||
FragmentSetId: TypeAlias = bytes # 128bit UUID v4
|
||||
FragmentId: TypeAlias = int # unsigned 8bit int in big endian
|
||||
|
||||
FRAGMENT_SET_ID_LENGTH: int = 16
|
||||
FRAGMENT_ID_LENGTH: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class FragmentHeader:
|
||||
"""
|
||||
Contain all information for reconstructing a message that was fragmented into the same FragmentSet.
|
||||
"""
|
||||
|
||||
set_id: FragmentSetId
|
||||
total_fragments: FragmentId
|
||||
fragment_id: FragmentId
|
||||
|
||||
SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2
|
||||
|
||||
@staticmethod
|
||||
def max_total_fragments() -> int:
|
||||
return 256 # because total_fragment is u8
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return (
|
||||
self.set_id
|
||||
+ self.total_fragments.to_bytes(1)
|
||||
+ self.fragment_id.to_bytes(1)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
if len(data) != cls.SIZE:
|
||||
raise ValueError("Invalid data length", len(data))
|
||||
|
||||
return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18]))
|
||||
|
||||
|
||||
@dataclass
|
||||
class FragmentSet:
|
||||
"""
|
||||
Represent a set of Fragments that can be reconstructed to a single original message.
|
||||
|
||||
Note that the maximum number of fragments in a FragmentSet is limited for now.
|
||||
"""
|
||||
|
||||
fragments: List[Fragment]
|
||||
|
||||
MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments()
|
||||
|
||||
def __init__(self, message: bytes):
|
||||
"""
|
||||
Build a FragmentSet by chunking a message into Fragments.
|
||||
"""
|
||||
chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE)
|
||||
# For now, we don't support more than max_fragments() fragments.
|
||||
# If needed, we can devise the FragmentSet chaining to support larger messages, like Nym.
|
||||
if len(chunked_messages) > self.MAX_FRAGMENTS:
|
||||
raise ValueError(f"Too long message: {len(chunked_messages)} chunks")
|
||||
|
||||
set_id = uuid.uuid4().bytes
|
||||
self.fragments = [
|
||||
Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk)
|
||||
for i, chunk in enumerate(chunked_messages)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fragment:
|
||||
"""Represent a piece of data that can be transformed to a single SphinxPacket"""
|
||||
|
||||
header: FragmentHeader
|
||||
body: bytes
|
||||
|
||||
MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return self.header.bytes() + self.body
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE])
|
||||
body = data[FragmentHeader.SIZE :]
|
||||
return cls(header, body)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageReconstructor:
|
||||
fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor]
|
||||
|
||||
def __init__(self):
|
||||
self.fragmentSets = {}
|
||||
|
||||
def add(self, fragment: Fragment) -> bytes | None:
|
||||
if fragment.header.set_id not in self.fragmentSets:
|
||||
self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor(
|
||||
fragment.header.total_fragments
|
||||
)
|
||||
|
||||
msg = self.fragmentSets[fragment.header.set_id].add(fragment)
|
||||
if msg is not None:
|
||||
del self.fragmentSets[fragment.header.set_id]
|
||||
return msg
|
||||
|
||||
|
||||
@dataclass
|
||||
class FragmentSetReconstructor:
|
||||
total_fragments: FragmentId
|
||||
fragments: Dict[FragmentId, Fragment]
|
||||
|
||||
def __init__(self, total_fragments: FragmentId):
|
||||
self.total_fragments = total_fragments
|
||||
self.fragments = {}
|
||||
|
||||
def add(self, fragment: Fragment) -> bytes | None:
|
||||
self.fragments[fragment.header.fragment_id] = fragment
|
||||
if len(self.fragments) == self.total_fragments:
|
||||
return self.build_message()
|
||||
else:
|
||||
return None
|
||||
|
||||
def build_message(self) -> bytes:
|
||||
message = b""
|
||||
for i in range(self.total_fragments):
|
||||
message += self.fragments[FragmentId(i)].body
|
||||
return message
|
||||
|
||||
|
||||
def chunks(data: bytes, size: int) -> List[bytes]:
|
||||
return list(map(bytes, batched(data, size)))
|
32
mixnet/sphinx.py
Normal file
32
mixnet/sphinx.py
Normal file
@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
from pysphinx.payload import Payload
|
||||
from pysphinx.sphinx import SphinxPacket
|
||||
|
||||
from mixnet.config import MixMembership, NodeInfo
|
||||
|
||||
|
||||
class SphinxPacketBuilder:
|
||||
@staticmethod
|
||||
def build(
|
||||
message: bytes, membership: MixMembership, path_len: int
|
||||
) -> Tuple[SphinxPacket, List[NodeInfo]]:
|
||||
if path_len <= 0:
|
||||
raise ValueError("path_len must be greater than 0")
|
||||
if len(message) > Payload.max_plain_payload_size():
|
||||
raise ValueError("message is too long")
|
||||
|
||||
route = membership.generate_route(path_len)
|
||||
# We don't need the destination (defined in the Loopix Sphinx spec)
|
||||
# because the last mix will broadcast the fully unwrapped message.
|
||||
# Later, we will optimize the Sphinx according to our requirements.
|
||||
dummy_destination = route[-1]
|
||||
|
||||
packet = SphinxPacket.build(
|
||||
message,
|
||||
route=[mixnode.sphinx_node() for mixnode in route],
|
||||
destination=dummy_destination.sphinx_node(),
|
||||
)
|
||||
return (packet, route)
|
@ -1,103 +0,0 @@
|
||||
from random import randint
|
||||
from typing import List
|
||||
from unittest import TestCase
|
||||
|
||||
from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket, X25519PrivateKey
|
||||
|
||||
from mixnet.config import NodeInfo
|
||||
from mixnet.packet import (
|
||||
Fragment,
|
||||
MessageFlag,
|
||||
MessageReconstructor,
|
||||
PacketBuilder,
|
||||
)
|
||||
from mixnet.test_utils import init_mixnet_config
|
||||
|
||||
|
||||
class TestPacket(TestCase):
|
||||
def test_real_packet(self):
|
||||
global_config, _, key_map = init_mixnet_config(10)
|
||||
msg = self.random_bytes(3500)
|
||||
packets_and_routes = PacketBuilder.build_real_packets(
|
||||
msg, global_config.membership, 3
|
||||
)
|
||||
self.assertEqual(4, len(packets_and_routes))
|
||||
|
||||
reconstructor = MessageReconstructor()
|
||||
self.assertIsNone(
|
||||
reconstructor.add(
|
||||
self.process_packet(
|
||||
packets_and_routes[1][0], packets_and_routes[1][1], key_map
|
||||
)
|
||||
),
|
||||
)
|
||||
self.assertIsNone(
|
||||
reconstructor.add(
|
||||
self.process_packet(
|
||||
packets_and_routes[3][0], packets_and_routes[3][1], key_map
|
||||
)
|
||||
),
|
||||
)
|
||||
self.assertIsNone(
|
||||
reconstructor.add(
|
||||
self.process_packet(
|
||||
packets_and_routes[2][0], packets_and_routes[2][1], key_map
|
||||
)
|
||||
),
|
||||
)
|
||||
msg_with_flag = reconstructor.add(
|
||||
self.process_packet(
|
||||
packets_and_routes[0][0], packets_and_routes[0][1], key_map
|
||||
)
|
||||
)
|
||||
assert msg_with_flag is not None
|
||||
self.assertEqual(
|
||||
PacketBuilder.parse_msg_and_flag(msg_with_flag),
|
||||
(MessageFlag.MESSAGE_FLAG_REAL, msg),
|
||||
)
|
||||
|
||||
def test_cover_packet(self):
|
||||
global_config, _, key_map = init_mixnet_config(10)
|
||||
msg = b"cover"
|
||||
packets_and_routes = PacketBuilder.build_drop_cover_packets(
|
||||
msg, global_config.membership, 3
|
||||
)
|
||||
self.assertEqual(1, len(packets_and_routes))
|
||||
|
||||
reconstructor = MessageReconstructor()
|
||||
msg_with_flag = reconstructor.add(
|
||||
self.process_packet(
|
||||
packets_and_routes[0][0], packets_and_routes[0][1], key_map
|
||||
)
|
||||
)
|
||||
assert msg_with_flag is not None
|
||||
self.assertEqual(
|
||||
PacketBuilder.parse_msg_and_flag(msg_with_flag),
|
||||
(MessageFlag.MESSAGE_FLAG_DROP_COVER, msg),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def process_packet(
|
||||
packet: SphinxPacket,
|
||||
route: List[NodeInfo],
|
||||
key_map: dict[bytes, X25519PrivateKey],
|
||||
) -> Fragment:
|
||||
processed = packet.process(key_map[route[0].public_key.public_bytes_raw()])
|
||||
if isinstance(processed, ProcessedFinalHopPacket):
|
||||
return Fragment.from_bytes(processed.payload.recover_plain_playload())
|
||||
else:
|
||||
processed = processed
|
||||
for node in route[1:]:
|
||||
p = processed.next_packet.process(
|
||||
key_map[node.public_key.public_bytes_raw()]
|
||||
)
|
||||
if isinstance(p, ProcessedFinalHopPacket):
|
||||
return Fragment.from_bytes(p.payload.recover_plain_playload())
|
||||
else:
|
||||
processed = p
|
||||
assert False
|
||||
|
||||
@staticmethod
|
||||
def random_bytes(size: int) -> bytes:
|
||||
assert size >= 0
|
||||
return bytes([randint(0, 255) for _ in range(size)])
|
39
mixnet/test_sphinx.py
Normal file
39
mixnet/test_sphinx.py
Normal file
@ -0,0 +1,39 @@
|
||||
from random import randint
|
||||
from typing import cast
|
||||
from unittest import TestCase
|
||||
|
||||
from pysphinx.sphinx import (
|
||||
ProcessedFinalHopPacket,
|
||||
ProcessedForwardHopPacket,
|
||||
)
|
||||
|
||||
from mixnet.sphinx import SphinxPacketBuilder
|
||||
from mixnet.test_utils import init_mixnet_config
|
||||
|
||||
|
||||
class TestSphinxPacketBuilder(TestCase):
|
||||
def test_builder(self):
|
||||
global_config, _, key_map = init_mixnet_config(10)
|
||||
msg = self.random_bytes(500)
|
||||
packet, route = SphinxPacketBuilder.build(msg, global_config.membership, 3)
|
||||
self.assertEqual(3, len(route))
|
||||
|
||||
processed = packet.process(key_map[route[0].public_key.public_bytes_raw()])
|
||||
self.assertIsInstance(processed, ProcessedForwardHopPacket)
|
||||
processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
|
||||
key_map[route[1].public_key.public_bytes_raw()]
|
||||
)
|
||||
self.assertIsInstance(processed, ProcessedForwardHopPacket)
|
||||
processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
|
||||
key_map[route[2].public_key.public_bytes_raw()]
|
||||
)
|
||||
self.assertIsInstance(processed, ProcessedFinalHopPacket)
|
||||
recovered = cast(
|
||||
ProcessedFinalHopPacket, processed
|
||||
).payload.recover_plain_playload()
|
||||
self.assertEqual(msg, recovered)
|
||||
|
||||
@staticmethod
|
||||
def random_bytes(size: int) -> bytes:
|
||||
assert size >= 0
|
||||
return bytes([randint(0, 255) for _ in range(size)])
|
Loading…
x
Reference in New Issue
Block a user