remove message fragments

Youngjoon Lee 2024-07-11 12:44:04 +09:00
parent ab9943d291
commit ad669abc6b
6 changed files with 81 additions and 343 deletions


@@ -49,19 +49,11 @@ class MixMembership:
nodes: List[NodeInfo]
def generate_route(self, num_hops: int, last_mix: NodeInfo) -> list[NodeInfo]:
def generate_route(self, length: int) -> list[NodeInfo]:
"""
Generate a mix route for a Sphinx packet.
The pre-selected mix_destination is used as the last mix node in the route,
so that associated packets can be merged back into the original message.
"""
return [*(self.choose() for _ in range(num_hops - 1)), last_mix]
def choose(self) -> NodeInfo:
"""
Choose a mix node as a mix destination that will reconstruct a message from Sphinx packets.
"""
return random.choice(self.nodes)
return [random.choice(self.nodes) for _ in range(length)]
@dataclass
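
For reference, here is a minimal, self-contained sketch of the simplified route selection. Only the uniform per-hop sampling mirrors the new generate_route above; the trimmed-down NodeInfo stand-in and the node names are hypothetical.

import random
from dataclasses import dataclass
from typing import List

@dataclass
class NodeInfo:
    # Hypothetical stand-in; the real NodeInfo carries keys and sphinx_node().
    name: str

@dataclass
class MixMembership:
    nodes: List[NodeInfo]

    def generate_route(self, length: int) -> list[NodeInfo]:
        # Every hop, including the last, is drawn uniformly at random;
        # there is no pre-selected last mix gathering fragments anymore.
        return [random.choice(self.nodes) for _ in range(length)]

membership = MixMembership([NodeInfo(f"node-{i}") for i in range(10)])
assert len(membership.generate_route(3)) == 3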


@@ -4,7 +4,6 @@ import asyncio
from typing import TypeAlias
from pysphinx.sphinx import (
Payload,
ProcessedFinalHopPacket,
ProcessedForwardHopPacket,
SphinxPacket,
@@ -12,7 +11,7 @@ from pysphinx.sphinx import (
from mixnet.config import GlobalConfig, NodeConfig
from mixnet.nomssip import Nomssip
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder
from mixnet.sphinx import SphinxPacketBuilder
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]
@@ -28,7 +27,6 @@ class Node:
config: NodeConfig
global_config: GlobalConfig
nomssip: Nomssip
reconstructor: MessageReconstructor
broadcast_channel: BroadcastChannel
def __init__(self, config: NodeConfig, global_config: GlobalConfig):
@@ -42,7 +40,6 @@
),
self.__process_msg,
)
self.reconstructor = MessageReconstructor()
self.broadcast_channel = asyncio.Queue()
@staticmethod
@@ -50,10 +47,10 @@ class Node:
"""
Calculate the actual message size to be gossiped, which depends on the maximum length of mix path.
"""
sample_packet, _ = PacketBuilder.build_real_packets(
sample_sphinx_packet, _ = SphinxPacketBuilder.build(
bytes(1), global_config.membership, global_config.max_mix_path_length
)[0]
return len(sample_packet.bytes())
)
return len(sample_sphinx_packet.bytes())
async def __process_msg(self, msg: bytes) -> None:
"""
@@ -83,24 +80,11 @@
case ProcessedForwardHopPacket():
return processed.next_packet
case ProcessedFinalHopPacket():
return await self.__process_sphinx_payload(processed.payload)
return processed.payload.recover_plain_playload()
except ValueError:
# Return nothing, if it cannot be unwrapped by the private key of this node.
return None
async def __process_sphinx_payload(self, payload: Payload) -> bytes | None:
"""
Process the Sphinx payload if possible
"""
msg_with_flag = self.reconstructor.add(
Fragment.from_bytes(payload.recover_plain_playload())
)
if msg_with_flag is not None:
flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag)
if flag == MessageFlag.MESSAGE_FLAG_REAL:
return msg
return None
def connect(self, peer: Node):
"""
Establish a duplex connection with a peer node.
@@ -117,9 +101,9 @@
"""
# Here, we handle the case in which a msg is split into multiple Sphinx packets.
# But, in practice, we expect a message to be small enough to fit in a single Sphinx packet.
for packet, _ in PacketBuilder.build_real_packets(
sphinx_packet, _ = SphinxPacketBuilder.build(
msg,
self.global_config.membership,
self.config.mix_path_length,
):
await self.nomssip.gossip(packet.bytes())
)
await self.nomssip.gossip(sphinx_packet.bytes())
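
Taken together, the node's send path no longer fragments messages: each message maps to exactly one Sphinx packet whose plaintext must fit in Payload.max_plain_payload_size(). Below is a hedged sketch of the resulting flow, assuming the method name send_message and the fields visible in this diff.

# Hypothetical consolidation of the post-change send path:
# one message -> one Sphinx packet -> one gossip broadcast.
async def send_message(self, msg: bytes) -> None:
    # SphinxPacketBuilder.build raises ValueError if msg is too long,
    # since there is no fragmentation to fall back on.
    sphinx_packet, _route = SphinxPacketBuilder.build(
        msg,
        self.global_config.membership,
        self.config.mix_path_length,
    )
    await self.nomssip.gossip(sphinx_packet.bytes())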


@@ -1,206 +0,0 @@
from __future__ import annotations
import uuid
from dataclasses import dataclass
from enum import Enum
from itertools import batched
from typing import Dict, List, Self, Tuple, TypeAlias
from pysphinx.payload import Payload
from pysphinx.sphinx import SphinxPacket
from mixnet.config import MixMembership, NodeInfo
class MessageFlag(Enum):
MESSAGE_FLAG_REAL = b"\x00"
MESSAGE_FLAG_DROP_COVER = b"\x01"
def bytes(self) -> bytes:
return bytes(self.value)
class PacketBuilder:
@staticmethod
def build_real_packets(
message: bytes, membership: MixMembership, path_len: int
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
return PacketBuilder.__build_packets(
MessageFlag.MESSAGE_FLAG_REAL, message, membership, path_len
)
@staticmethod
def build_drop_cover_packets(
message: bytes, membership: MixMembership, path_len: int
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
return PacketBuilder.__build_packets(
MessageFlag.MESSAGE_FLAG_DROP_COVER, message, membership, path_len
)
@staticmethod
def __build_packets(
flag: MessageFlag, message: bytes, membership: MixMembership, path_len: int
) -> List[Tuple[SphinxPacket, List[NodeInfo]]]:
if path_len <= 0:
raise ValueError("path_len must be greater than 0")
last_mix = membership.choose()
msg_with_flag = flag.bytes() + message
# NOTE: We don't encrypt msg_with_flag for destination.
# If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag.
fragment_set = FragmentSet(msg_with_flag)
out = []
for fragment in fragment_set.fragments:
route = membership.generate_route(path_len, last_mix)
packet = SphinxPacket.build(
fragment.bytes(),
[mixnode.sphinx_node() for mixnode in route],
last_mix.sphinx_node(),
)
out.append((packet, route))
return out
@staticmethod
def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]:
"""Remove a MessageFlag from data"""
if len(data) < 1:
raise ValueError("data is too short")
return (MessageFlag(data[0:1]), data[1:])
Unlike us, Nym uses i32 for FragmentSetId, which may cause more collisions.
We will use UUID until we figure out why Nym uses i32.
FragmentSetId: TypeAlias = bytes # 128bit UUID v4
FragmentId: TypeAlias = int # unsigned 8bit int in big endian
FRAGMENT_SET_ID_LENGTH: int = 16
FRAGMENT_ID_LENGTH: int = 1
@dataclass
class FragmentHeader:
"""
Contain all information for reconstructing a message that was fragmented into the same FragmentSet.
"""
set_id: FragmentSetId
total_fragments: FragmentId
fragment_id: FragmentId
SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2
@staticmethod
def max_total_fragments() -> int:
return 256 # because total_fragment is u8
def bytes(self) -> bytes:
return (
self.set_id
+ self.total_fragments.to_bytes(1)
+ self.fragment_id.to_bytes(1)
)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
if len(data) != cls.SIZE:
raise ValueError("Invalid data length", len(data))
return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18]))
@dataclass
class FragmentSet:
"""
Represent a set of Fragments that can be reconstructed to a single original message.
Note that the maximum number of fragments in a FragmentSet is limited for now.
"""
fragments: List[Fragment]
MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments()
def __init__(self, message: bytes):
"""
Build a FragmentSet by chunking a message into Fragments.
"""
chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE)
# For now, we don't support more than max_fragments() fragments.
# If needed, we can devise the FragmentSet chaining to support larger messages, like Nym.
if len(chunked_messages) > self.MAX_FRAGMENTS:
raise ValueError(f"Too long message: {len(chunked_messages)} chunks")
set_id = uuid.uuid4().bytes
self.fragments = [
Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk)
for i, chunk in enumerate(chunked_messages)
]
@dataclass
class Fragment:
"""Represent a piece of data that can be transformed to a single SphinxPacket"""
header: FragmentHeader
body: bytes
MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE
def bytes(self) -> bytes:
return self.header.bytes() + self.body
@classmethod
def from_bytes(cls, data: bytes) -> Self:
header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE])
body = data[FragmentHeader.SIZE :]
return cls(header, body)
@dataclass
class MessageReconstructor:
fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor]
def __init__(self):
self.fragmentSets = {}
def add(self, fragment: Fragment) -> bytes | None:
if fragment.header.set_id not in self.fragmentSets:
self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor(
fragment.header.total_fragments
)
msg = self.fragmentSets[fragment.header.set_id].add(fragment)
if msg is not None:
del self.fragmentSets[fragment.header.set_id]
return msg
@dataclass
class FragmentSetReconstructor:
total_fragments: FragmentId
fragments: Dict[FragmentId, Fragment]
def __init__(self, total_fragments: FragmentId):
self.total_fragments = total_fragments
self.fragments = {}
def add(self, fragment: Fragment) -> bytes | None:
self.fragments[fragment.header.fragment_id] = fragment
if len(self.fragments) == self.total_fragments:
return self.build_message()
else:
return None
def build_message(self) -> bytes:
message = b""
for i in range(self.total_fragments):
message += self.fragments[FragmentId(i)].body
return message
def chunks(data: bytes, size: int) -> List[bytes]:
return list(map(bytes, batched(data, size)))

mixnet/sphinx.py (new file, 32 lines)

@@ -0,0 +1,32 @@
from __future__ import annotations
from typing import List, Tuple
from pysphinx.payload import Payload
from pysphinx.sphinx import SphinxPacket
from mixnet.config import MixMembership, NodeInfo
class SphinxPacketBuilder:
@staticmethod
def build(
message: bytes, membership: MixMembership, path_len: int
) -> Tuple[SphinxPacket, List[NodeInfo]]:
if path_len <= 0:
raise ValueError("path_len must be greater than 0")
if len(message) > Payload.max_plain_payload_size():
raise ValueError("message is too long")
route = membership.generate_route(path_len)
# We don't need the destination (defined in the Loopix Sphinx spec)
# because the last mix will broadcast the fully unwrapped message.
# Later, we will optimize the Sphinx according to our requirements.
dummy_destination = route[-1]
packet = SphinxPacket.build(
message,
route=[mixnode.sphinx_node() for mixnode in route],
destination=dummy_destination.sphinx_node(),
)
return (packet, route)
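
A short usage sketch of the new builder, assuming the init_mixnet_config helper used by the tests below; it exercises the single-packet API and the oversize guard shown above.

from pysphinx.payload import Payload
from mixnet.sphinx import SphinxPacketBuilder
from mixnet.test_utils import init_mixnet_config

global_config, _, _ = init_mixnet_config(10)
packet, route = SphinxPacketBuilder.build(b"hello", global_config.membership, 3)
assert len(route) == 3  # one NodeInfo per hop, chosen by MixMembership.generate_route

# With fragmentation removed, an over-sized message is rejected rather than split.
too_long = bytes(Payload.max_plain_payload_size() + 1)
try:
    SphinxPacketBuilder.build(too_long, global_config.membership, 3)
except ValueError:
    pass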


@@ -1,103 +0,0 @@
from random import randint
from typing import List
from unittest import TestCase
from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket, X25519PrivateKey
from mixnet.config import NodeInfo
from mixnet.packet import (
Fragment,
MessageFlag,
MessageReconstructor,
PacketBuilder,
)
from mixnet.test_utils import init_mixnet_config
class TestPacket(TestCase):
def test_real_packet(self):
global_config, _, key_map = init_mixnet_config(10)
msg = self.random_bytes(3500)
packets_and_routes = PacketBuilder.build_real_packets(
msg, global_config.membership, 3
)
self.assertEqual(4, len(packets_and_routes))
reconstructor = MessageReconstructor()
self.assertIsNone(
reconstructor.add(
self.process_packet(
packets_and_routes[1][0], packets_and_routes[1][1], key_map
)
),
)
self.assertIsNone(
reconstructor.add(
self.process_packet(
packets_and_routes[3][0], packets_and_routes[3][1], key_map
)
),
)
self.assertIsNone(
reconstructor.add(
self.process_packet(
packets_and_routes[2][0], packets_and_routes[2][1], key_map
)
),
)
msg_with_flag = reconstructor.add(
self.process_packet(
packets_and_routes[0][0], packets_and_routes[0][1], key_map
)
)
assert msg_with_flag is not None
self.assertEqual(
PacketBuilder.parse_msg_and_flag(msg_with_flag),
(MessageFlag.MESSAGE_FLAG_REAL, msg),
)
def test_cover_packet(self):
global_config, _, key_map = init_mixnet_config(10)
msg = b"cover"
packets_and_routes = PacketBuilder.build_drop_cover_packets(
msg, global_config.membership, 3
)
self.assertEqual(1, len(packets_and_routes))
reconstructor = MessageReconstructor()
msg_with_flag = reconstructor.add(
self.process_packet(
packets_and_routes[0][0], packets_and_routes[0][1], key_map
)
)
assert msg_with_flag is not None
self.assertEqual(
PacketBuilder.parse_msg_and_flag(msg_with_flag),
(MessageFlag.MESSAGE_FLAG_DROP_COVER, msg),
)
@staticmethod
def process_packet(
packet: SphinxPacket,
route: List[NodeInfo],
key_map: dict[bytes, X25519PrivateKey],
) -> Fragment:
processed = packet.process(key_map[route[0].public_key.public_bytes_raw()])
if isinstance(processed, ProcessedFinalHopPacket):
return Fragment.from_bytes(processed.payload.recover_plain_playload())
else:
processed = processed
for node in route[1:]:
p = processed.next_packet.process(
key_map[node.public_key.public_bytes_raw()]
)
if isinstance(p, ProcessedFinalHopPacket):
return Fragment.from_bytes(p.payload.recover_plain_playload())
else:
processed = p
assert False
@staticmethod
def random_bytes(size: int) -> bytes:
assert size >= 0
return bytes([randint(0, 255) for _ in range(size)])

mixnet/test_sphinx.py (new file, 39 lines)

@@ -0,0 +1,39 @@
from random import randint
from typing import cast
from unittest import TestCase
from pysphinx.sphinx import (
ProcessedFinalHopPacket,
ProcessedForwardHopPacket,
)
from mixnet.sphinx import SphinxPacketBuilder
from mixnet.test_utils import init_mixnet_config
class TestSphinxPacketBuilder(TestCase):
def test_builder(self):
global_config, _, key_map = init_mixnet_config(10)
msg = self.random_bytes(500)
packet, route = SphinxPacketBuilder.build(msg, global_config.membership, 3)
self.assertEqual(3, len(route))
processed = packet.process(key_map[route[0].public_key.public_bytes_raw()])
self.assertIsInstance(processed, ProcessedForwardHopPacket)
processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
key_map[route[1].public_key.public_bytes_raw()]
)
self.assertIsInstance(processed, ProcessedForwardHopPacket)
processed = cast(ProcessedForwardHopPacket, processed).next_packet.process(
key_map[route[2].public_key.public_bytes_raw()]
)
self.assertIsInstance(processed, ProcessedFinalHopPacket)
recovered = cast(
ProcessedFinalHopPacket, processed
).payload.recover_plain_playload()
self.assertEqual(msg, recovered)
@staticmethod
def random_bytes(size: int) -> bytes:
assert size >= 0
return bytes([randint(0, 255) for _ in range(size)])