204 lines
6.1 KiB
Python
204 lines
6.1 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from itertools import batched
|
|
from typing import Dict, List, Self, Tuple, TypeAlias
|
|
|
|
from pysphinx.payload import Payload
|
|
from pysphinx.sphinx import SphinxPacket
|
|
|
|
from mixnet.config import MixnetTopology, MixNodeInfo
|
|
|
|
|
|
class MessageFlag(Enum):
    """Single-byte marker that PacketBuilder puts in front of a message."""

    MESSAGE_FLAG_REAL = b"\x00"
    MESSAGE_FLAG_DROP_COVER = b"\x01"

    def bytes(self) -> bytes:
        """Return this flag's value as a bytes object."""
        # Inside the method body `bytes` is the builtin, not this method.
        raw_value = self.value
        return bytes(raw_value)
|
|
|
|
|
|
class PacketBuilder:
    """Build Sphinx packets from a message, one packet per fragment."""

    @staticmethod
    def build_real_packets(
        message: bytes, topology: MixnetTopology
    ) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]:
        """Build packets for `message` tagged with MESSAGE_FLAG_REAL.

        Returns a list of (packet, route) pairs, one per fragment.
        """
        flag = MessageFlag.MESSAGE_FLAG_REAL
        return PacketBuilder.__build_packets(flag, message, topology)

    @staticmethod
    def build_drop_cover_packets(
        message: bytes, topology: MixnetTopology
    ) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]:
        """Build packets for `message` tagged with MESSAGE_FLAG_DROP_COVER.

        Returns a list of (packet, route) pairs, one per fragment.
        """
        flag = MessageFlag.MESSAGE_FLAG_DROP_COVER
        return PacketBuilder.__build_packets(flag, message, topology)

    @staticmethod
    def __build_packets(
        flag: MessageFlag, message: bytes, topology: MixnetTopology
    ) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]:
        """Prefix `message` with `flag`, fragment it, and wrap each fragment in a packet.

        The destination is chosen once and shared by all fragments; a route
        is generated separately for every fragment.
        """
        destination = topology.choose_mix_destination()

        # NOTE: The flagged message is not encrypted for the destination.
        # If encryption is needed, a shared secret must be placed in front of
        # the message along with the MessageFlag.
        tagged_message = flag.bytes() + message
        fragments = FragmentSet(tagged_message).fragments

        def packet_with_route(fragment):
            # Same call order as a plain loop: route, payload, hops, exit node.
            route = topology.generate_route(destination)
            payload = fragment.bytes()
            hops = [node.sphinx_node() for node in route]
            exit_node = destination.sphinx_node()
            return SphinxPacket.build(payload, hops, exit_node), route

        return [packet_with_route(fragment) for fragment in fragments]

    @staticmethod
    def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]:
        """Split `data` into its leading MessageFlag and the remaining bytes.

        Raises:
            ValueError: if `data` is empty. Looking up the first byte in
                MessageFlag also raises ValueError when it matches no member.
        """
        if not data:
            raise ValueError("data is too short")

        flag_byte, remainder = data[:1], data[1:]
        return (MessageFlag(flag_byte), remainder)
|
|
|
|
|
|
# Unlike Nym, which uses i32 for FragmentSetId (and may therefore see more collisions),
# we use a UUID until we figure out why Nym uses i32.
|
|
FragmentSetId: TypeAlias = bytes # 128bit UUID v4
|
|
FragmentId: TypeAlias = int # unsigned 8bit int in big endian
|
|
|
|
FRAGMENT_SET_ID_LENGTH: int = 16
|
|
FRAGMENT_ID_LENGTH: int = 1
|
|
|
|
|
|
@dataclass
class FragmentHeader:
    """
    Contain all information for reconstructing a message that was fragmented into the same FragmentSet.

    Serialized layout (SIZE bytes in total): set_id (FRAGMENT_SET_ID_LENGTH
    bytes), then total_fragments and fragment_id (FRAGMENT_ID_LENGTH bytes
    each, unsigned, big endian).
    """

    set_id: FragmentSetId
    total_fragments: FragmentId
    fragment_id: FragmentId

    # NOTE(review): SIZE is annotated without ClassVar, so @dataclass also
    # treats it as an init field with a default. Left as-is so the constructor
    # signature does not change.
    SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2

    @staticmethod
    def max_total_fragments() -> int:
        """Return the largest fragment count that total_fragments can encode."""
        # total_fragments is serialized into FRAGMENT_ID_LENGTH byte(s), so the
        # largest encodable count is 2**8 - 1 = 255. Allowing 256 would let a
        # set pass the size check and then fail in bytes() with OverflowError.
        return 2 ** (8 * FRAGMENT_ID_LENGTH) - 1

    def bytes(self) -> bytes:
        """Serialize the header as set_id + total_fragments + fragment_id.

        Raises:
            OverflowError: if total_fragments or fragment_id does not fit in
                FRAGMENT_ID_LENGTH unsigned byte(s).
        """
        return (
            self.set_id
            + self.total_fragments.to_bytes(FRAGMENT_ID_LENGTH, "big")
            + self.fragment_id.to_bytes(FRAGMENT_ID_LENGTH, "big")
        )

    @classmethod
    def from_bytes(cls, data: bytes) -> Self:
        """Parse a header from exactly SIZE bytes.

        Raises:
            ValueError: if `data` is not exactly SIZE bytes long.
        """
        if len(data) != cls.SIZE:
            raise ValueError("Invalid data length", len(data))

        # Field offsets are derived from the length constants, not hard-coded.
        total_start = FRAGMENT_SET_ID_LENGTH
        id_start = total_start + FRAGMENT_ID_LENGTH
        id_end = id_start + FRAGMENT_ID_LENGTH
        return cls(
            data[:total_start],
            int.from_bytes(data[total_start:id_start], "big"),
            int.from_bytes(data[id_start:id_end], "big"),
        )
|
|
|
|
|
|
@dataclass
class FragmentSet:
    """
    A group of Fragments that together carry one original message.

    The number of fragments in one FragmentSet is capped at MAX_FRAGMENTS for now.
    """

    fragments: List[Fragment]

    MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments()

    def __init__(self, message: bytes):
        """
        Split `message` into pieces of at most Fragment.MAX_PAYLOAD_SIZE bytes
        and wrap each piece in a Fragment.

        Raises:
            ValueError: if the message needs more than MAX_FRAGMENTS pieces.
        """
        pieces = chunks(message, Fragment.MAX_PAYLOAD_SIZE)
        piece_count = len(pieces)
        # Messages needing more than MAX_FRAGMENTS pieces are not supported yet.
        # If needed, FragmentSets could be chained to carry larger messages, like Nym.
        if piece_count > self.MAX_FRAGMENTS:
            raise ValueError(f"Too long message: {piece_count} chunks")

        # Every fragment of this set carries the same freshly generated UUID.
        set_id = uuid.uuid4().bytes
        built = []
        for index, piece in enumerate(pieces):
            header = FragmentHeader(set_id, piece_count, index)
            built.append(Fragment(header, piece))
        self.fragments = built
|
|
|
|
|
|
@dataclass
class Fragment:
    """A header plus a piece of message body that fits into a single SphinxPacket."""

    header: FragmentHeader
    body: bytes

    # Payload room left for the body once the header is accounted for.
    MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE

    def bytes(self) -> bytes:
        """Serialize as the encoded header immediately followed by the body."""
        encoded_header = self.header.bytes()
        return encoded_header + self.body

    @classmethod
    def from_bytes(cls, data: bytes) -> Self:
        """Parse a Fragment: the first FragmentHeader.SIZE bytes are the header, the rest is the body."""
        split_at = FragmentHeader.SIZE
        raw_header, body = data[:split_at], data[split_at:]
        return cls(FragmentHeader.from_bytes(raw_header), body)
|
|
|
|
|
|
@dataclass
class MessageReconstructor:
    """Collect fragments grouped by set id and return each message once its set completes."""

    fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor]

    def __init__(self):
        """Start with no fragment sets in progress."""
        self.fragmentSets = {}

    def add(self, fragment: Fragment) -> bytes | None:
        """Record `fragment`; return the full message if this completes its set, else None."""
        set_id = fragment.header.set_id
        if set_id not in self.fragmentSets:
            # First fragment seen for this set: its header fixes the expected count.
            expected = fragment.header.total_fragments
            self.fragmentSets[set_id] = FragmentSetReconstructor(expected)

        message = self.fragmentSets[set_id].add(fragment)
        if message is None:
            return None

        # The set is complete, so its bookkeeping is dropped.
        del self.fragmentSets[set_id]
        return message
|
|
|
|
|
|
@dataclass
class FragmentSetReconstructor:
    """Accumulate the fragments of one set and rebuild the message once all have arrived."""

    total_fragments: FragmentId
    fragments: Dict[FragmentId, Fragment]

    def __init__(self, total_fragments: FragmentId):
        """Expect `total_fragments` fragments, with ids 0 .. total_fragments - 1."""
        self.total_fragments = total_fragments
        self.fragments = {}

    def add(self, fragment: Fragment) -> bytes | None:
        """Store `fragment`; return the rebuilt message when the set is complete, else None.

        A fragment whose id was already stored replaces the earlier one.

        Raises:
            ValueError: if the fragment id is outside 0 .. total_fragments - 1.
        """
        fragment_id = fragment.header.fragment_id
        # An out-of-range id must not count toward completion; otherwise the
        # set could look complete while an expected id is still missing, and
        # build_message() would fail with KeyError.
        if not 0 <= fragment_id < self.total_fragments:
            raise ValueError("Invalid fragment id", fragment_id)

        self.fragments[fragment_id] = fragment
        if len(self.fragments) == self.total_fragments:
            return self.build_message()
        return None

    def build_message(self) -> bytes:
        """Concatenate the fragment bodies in fragment-id order.

        Raises:
            KeyError: if called before every fragment id has been stored.
        """
        # One join instead of repeated `+=`, which re-copies the buffer each time.
        return b"".join(
            self.fragments[index].body for index in range(self.total_fragments)
        )
|
|
|
|
|
|
def chunks(data: bytes, size: int) -> List[bytes]:
    """Split `data` into consecutive pieces of at most `size` bytes.

    Every piece has exactly `size` bytes except possibly the last one.
    Empty `data` yields an empty list.

    Raises:
        ValueError: if `size` is smaller than one.
    """
    if size < 1:
        raise ValueError("n must be at least one")
    # Slice directly instead of unpacking every byte into a tuple of ints and
    # re-packing it; bytes() keeps the element type for bytes-like inputs.
    return [bytes(data[start : start + size]) for start in range(0, len(data), size)]
|