diff --git a/mixnet/mixnet.py b/mixnet/mixnet.py
index 4fd5cb1..fbc1ece 100644
--- a/mixnet/mixnet.py
+++ b/mixnet/mixnet.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import random
 from dataclasses import dataclass
 from typing import List, TypeAlias
 
@@ -7,6 +8,7 @@ from cryptography.hazmat.primitives.asymmetric.x25519 import (
     X25519PrivateKey,
     X25519PublicKey,
 )
+from pysphinx.node import Node
 
 from mixnet.bls import BlsPrivateKey, BlsPublicKey
 from mixnet.fisheryates import FisherYates
@@ -43,24 +45,36 @@ class Mixnet:
             layers.append(layer)
         return MixnetTopology(layers)
 
+    def choose_mixnode(self) -> MixNode:
+        """Return one of this mixnet's nodes, picked uniformly at random."""
+        # NOTE(review): `random` is not a CSPRNG; consider `secrets.choice` if
+        # destination/route selection must be unpredictable to observers.
+        return random.choice(self.mix_nodes)
+
 
 @dataclass
 class MixNode:
-    identity_public_key: BlsPublicKey
-    encryption_public_key: X25519PublicKey
+    identity_private_key: BlsPrivateKey
+    encryption_private_key: X25519PrivateKey
     addr: NodeAddress
 
-    def __init__(
-        self,
-        identity_private_key: BlsPrivateKey,
-        encryption_private_key: X25519PrivateKey,
-        addr: NodeAddress,
-    ):
-        self.identity_public_key = identity_private_key.get_g1()
-        self.encryption_public_key = encryption_private_key.public_key()
-        self.addr = addr
+    def identity_public_key(self) -> BlsPublicKey:
+        """Derive the BLS public key from the identity private key."""
+        return self.identity_private_key.get_g1()
+
+    def encryption_public_key(self) -> X25519PublicKey:
+        """Derive the X25519 public key from the encryption private key."""
+        return self.encryption_private_key.public_key()
+
+    def sphinx_node(self) -> Node:
+        """Build the pysphinx Node for this mix node's key and address."""
+        return Node(self.encryption_private_key, self.addr)
 
 
 @dataclass
 class MixnetTopology:
     layers: List[List[MixNode]]
+
+    def generate_route(self) -> list[MixNode]:
+        """Return a route made of one randomly chosen node from each layer."""
+        return [random.choice(layer) for layer in self.layers]
diff --git a/mixnet/packet.py b/mixnet/packet.py
new file mode 100644
index 0000000..50b721b
--- /dev/null
+++ b/mixnet/packet.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import uuid
+from dataclasses import dataclass
+from enum import Enum
+from itertools import batched
+from typing import ClassVar, Dict, Iterator, List, Self, Tuple, TypeAlias
+
+from pysphinx.payload import Payload
+from pysphinx.sphinx import SphinxPacket
+
+from mixnet.mixnet import Mixnet, MixnetTopology, MixNode
+
+
+class MessageFlag(Enum):
+    """One-byte tag prepended to a message before it is fragmented."""
+
+    MESSAGE_FLAG_REAL = b"\x00"
+    MESSAGE_FLAG_DROP_COVER = b"\x01"
+
+    def bytes(self) -> bytes:
+        """Return the flag's wire representation."""
+        return bytes(self.value)
+
+
+class PacketBuilder:
+    """
+    Split a flagged message into fragments and wrap each one in a SphinxPacket.
+
+    Every fragment gets its own freshly generated route; all packets of one
+    message share the same destination node.
+    """
+
+    iter: Iterator[Tuple[SphinxPacket, List[MixNode]]]
+
+    def __init__(
+        self,
+        flag: MessageFlag,
+        message: bytes,
+        mixnet: Mixnet,
+        topology: MixnetTopology,
+    ):
+        destination = mixnet.choose_mixnode()
+
+        msg_with_flag = flag.bytes() + message
+        # NOTE: We don't encrypt msg_with_flag for destination.
+        # If encryption is needed, a shared secret must be prepended to the
+        # message along with the MessageFlag.
+        fragment_set = FragmentSet(msg_with_flag)
+
+        packets_and_routes = []
+        for fragment in fragment_set.fragments:
+            route = topology.generate_route()
+            packet = SphinxPacket.build(
+                fragment.bytes(),
+                [mixnode.sphinx_node() for mixnode in route],
+                destination.sphinx_node(),
+            )
+            packets_and_routes.append((packet, route))
+
+        self.iter = iter(packets_and_routes)
+
+    @classmethod
+    def real(cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology) -> Self:
+        """Build packets for a real message."""
+        return cls(MessageFlag.MESSAGE_FLAG_REAL, message, mixnet, topology)
+
+    @classmethod
+    def drop_cover(
+        cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology
+    ) -> Self:
+        """Build packets for a drop cover message."""
+        return cls(MessageFlag.MESSAGE_FLAG_DROP_COVER, message, mixnet, topology)
+
+    def next(self) -> Tuple[SphinxPacket, List[MixNode]]:
+        """Return the next (packet, route) pair; raise StopIteration when done."""
+        return next(self.iter)
+
+    @staticmethod
+    def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]:
+        """
+        Split reconstructed data into its MessageFlag and the message.
+
+        Raises ValueError if data is empty or starts with an unknown flag.
+        """
+        if not data:
+            raise ValueError("data is too short")
+
+        return (MessageFlag(data[0:1]), data[1:])
+
+
+# Unlike Nym, which uses i32 for FragmentSetId and may therefore see more
+# collisions, we use a UUID until we figure out why Nym uses i32.
+FragmentSetId: TypeAlias = bytes  # 128bit UUID v4
+FragmentId: TypeAlias = int  # unsigned 8bit int in big endian
+
+FRAGMENT_SET_ID_LENGTH: int = 16
+FRAGMENT_ID_LENGTH: int = 1
+
+
+@dataclass
+class FragmentHeader:
+    """
+    Contain all information for reconstructing a message that was fragmented
+    into the same FragmentSet.
+
+    Wire format: set_id (16 bytes) | total_fragments (1 byte) | fragment_id (1 byte).
+    """
+
+    set_id: FragmentSetId
+    total_fragments: FragmentId
+    fragment_id: FragmentId
+
+    # ClassVar keeps this constant out of the generated __init__/__eq__/__repr__.
+    SIZE: ClassVar[int] = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2
+
+    @staticmethod
+    def max_total_fragments() -> int:
+        """Return the largest fragment count that fits in the one-byte field."""
+        # total_fragments is serialized as a u8, so 255 is the largest value
+        # that to_bytes(1) can encode; 256 would raise OverflowError.
+        return 255
+
+    def bytes(self) -> bytes:
+        """Serialize the header; raise OverflowError if a counter exceeds a u8."""
+        return (
+            self.set_id
+            + self.total_fragments.to_bytes(FRAGMENT_ID_LENGTH, "big")
+            + self.fragment_id.to_bytes(FRAGMENT_ID_LENGTH, "big")
+        )
+
+    @classmethod
+    def from_bytes(cls, data: bytes) -> Self:
+        """Parse a header from exactly SIZE bytes; raise ValueError otherwise."""
+        if len(data) != cls.SIZE:
+            raise ValueError("Invalid data length", len(data))
+
+        total_offset = FRAGMENT_SET_ID_LENGTH
+        id_offset = total_offset + FRAGMENT_ID_LENGTH
+        return cls(
+            data[:total_offset],
+            int.from_bytes(data[total_offset:id_offset], "big"),
+            int.from_bytes(data[id_offset : cls.SIZE], "big"),
+        )
+
+
+@dataclass
+class FragmentSet:
+    """
+    Represent a set of Fragments that can be reconstructed to a single original message.
+
+    Note that the maximum number of fragments in a FragmentSet is limited for now.
+    """
+
+    fragments: List[Fragment]
+
+    MAX_FRAGMENTS: ClassVar[int] = FragmentHeader.max_total_fragments()
+
+    def __init__(self, message: bytes):
+        """
+        Build a FragmentSet by chunking a message into Fragments.
+
+        Raises ValueError if the message needs more than MAX_FRAGMENTS fragments.
+        """
+        chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE)
+        # For now, we don't support more than MAX_FRAGMENTS fragments.
+        # If needed, we can devise the FragmentSet chaining to support larger
+        # messages, like Nym.
+        if len(chunked_messages) > self.MAX_FRAGMENTS:
+            raise ValueError(f"Too long message: {len(chunked_messages)} chunks")
+
+        set_id = uuid.uuid4().bytes
+        self.fragments = [
+            Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk)
+            for i, chunk in enumerate(chunked_messages)
+        ]
+
+
+@dataclass
+class Fragment:
+    """Represent a piece of data that can be transformed to a single SphinxPacket"""
+
+    header: FragmentHeader
+    body: bytes
+
+    MAX_PAYLOAD_SIZE: ClassVar[int] = (
+        Payload.max_plain_payload_size() - FragmentHeader.SIZE
+    )
+
+    def bytes(self) -> bytes:
+        """Serialize the fragment as header bytes followed by the body."""
+        return self.header.bytes() + self.body
+
+    @classmethod
+    def from_bytes(cls, data: bytes) -> Self:
+        """Parse a fragment; the first FragmentHeader.SIZE bytes are the header."""
+        header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE])
+        body = data[FragmentHeader.SIZE :]
+        return cls(header, body)
+
+
+@dataclass
+class MessageReconstructor:
+    """Collect fragments of any number of FragmentSets and rebuild their messages."""
+
+    fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor]
+
+    def __init__(self):
+        self.fragmentSets = {}
+
+    def add(self, fragment: Fragment) -> bytes | None:
+        """
+        Add a fragment; return the whole message once its set is complete.
+
+        Returns None while fragments of the set are still missing.
+        Raises ValueError if the fragment is inconsistent with its set.
+        """
+        set_id = fragment.header.set_id
+        reconstructor = self.fragmentSets.get(set_id)
+        if reconstructor is None:
+            reconstructor = FragmentSetReconstructor(fragment.header.total_fragments)
+
+        # Validate first: a rejected fragment must not leave a set behind.
+        msg = reconstructor.add(fragment)
+        if msg is None:
+            self.fragmentSets[set_id] = reconstructor
+        else:
+            # Completed sets are dropped so the map does not grow forever.
+            self.fragmentSets.pop(set_id, None)
+        return msg
+
+
+@dataclass
+class FragmentSetReconstructor:
+    """Collect the fragments of one FragmentSet until the message is complete."""
+
+    total_fragments: FragmentId
+    fragments: Dict[FragmentId, Fragment]
+
+    def __init__(self, total_fragments: FragmentId):
+        self.total_fragments = total_fragments
+        self.fragments = {}
+
+    def add(self, fragment: Fragment) -> bytes | None:
+        """
+        Store a fragment; return the message once all fragments have arrived.
+
+        Raises ValueError if the header disagrees with this set's fragment
+        count or carries a fragment_id outside of it.
+        """
+        header = fragment.header
+        if header.total_fragments != self.total_fragments:
+            raise ValueError(
+                "Inconsistent total_fragments",
+                header.total_fragments,
+                self.total_fragments,
+            )
+        if not 0 <= header.fragment_id < self.total_fragments:
+            raise ValueError("fragment_id out of range", header.fragment_id)
+
+        self.fragments[header.fragment_id] = fragment
+        if len(self.fragments) == self.total_fragments:
+            return self.build_message()
+        return None
+
+    def build_message(self) -> bytes:
+        """Concatenate the fragment bodies in fragment_id order."""
+        return b"".join(
+            self.fragments[FragmentId(i)].body for i in range(self.total_fragments)
+        )
+
+
+def chunks(data: bytes, size: int) -> List[bytes]:
+    """Split data into consecutive pieces of at most size bytes."""
+    return list(map(bytes, batched(data, size)))
diff --git a/mixnet/test_packet.py b/mixnet/test_packet.py
new file mode 100644
index 0000000..676da7e
--- /dev/null
+++ b/mixnet/test_packet.py
@@ -0,0 +1,95 @@
+from typing import List, Tuple
+from unittest import TestCase
+
+from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
+from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket
+
+from mixnet.bls import generate_bls
+from mixnet.mixnet import Mixnet, MixnetTopology, MixNode
+from mixnet.packet import (
+    Fragment,
+    FragmentSet,
+    MessageFlag,
+    MessageReconstructor,
+    PacketBuilder,
+)
+from mixnet.utils import random_bytes
+
+
+class TestPacket(TestCase):
+    def test_real_packet(self):
+        mixnet, topology = self.init()
+
+        msg = random_bytes(3500)
+        builder = PacketBuilder.real(msg, mixnet, topology)
+        packet0, route0 = builder.next()
+        packet1, route1 = builder.next()
+        packet2, route2 = builder.next()
+        packet3, route3 = builder.next()
+        self.assertRaises(StopIteration, builder.next)
+
+        reconstructor = MessageReconstructor()
+        self.assertIsNone(
+            reconstructor.add(self.process_packet(packet1, route1)),
+        )
+        self.assertIsNone(
+            reconstructor.add(self.process_packet(packet3, route3)),
+        )
+        self.assertIsNone(
+            reconstructor.add(self.process_packet(packet2, route2)),
+        )
+        msg_with_flag = reconstructor.add(self.process_packet(packet0, route0))
+        assert msg_with_flag is not None
+        self.assertEqual(
+            PacketBuilder.parse_msg_and_flag(msg_with_flag),
+            (MessageFlag.MESSAGE_FLAG_REAL, msg),
+        )
+
+    def test_cover_packet(self):
+        mixnet, topology = self.init()
+
+        msg = b"cover"
+        builder = PacketBuilder.drop_cover(msg, mixnet, topology)
+        packet, route = builder.next()
+        self.assertRaises(StopIteration, builder.next)
+
+        reconstructor = MessageReconstructor()
+        msg_with_flag = reconstructor.add(self.process_packet(packet, route))
+        assert msg_with_flag is not None
+        self.assertEqual(
+            PacketBuilder.parse_msg_and_flag(msg_with_flag),
+            (MessageFlag.MESSAGE_FLAG_DROP_COVER, msg),
+        )
+
+    def test_too_long_message(self):
+        mixnet, topology = self.init()
+
+        # The one-byte flag pushes the message just past the fragment limit.
+        msg = bytes(Fragment.MAX_PAYLOAD_SIZE * FragmentSet.MAX_FRAGMENTS)
+        with self.assertRaises(ValueError):
+            PacketBuilder.real(msg, mixnet, topology)
+
+    @staticmethod
+    def init() -> Tuple[Mixnet, MixnetTopology]:
+        mixnet = Mixnet(
+            [
+                MixNode(
+                    generate_bls(),
+                    X25519PrivateKey.generate(),
+                    random_bytes(32),
+                )
+                for _ in range(12)
+            ]
+        )
+        topology = mixnet.build_topology(b"entropy", 3, 3)
+        return mixnet, topology
+
+    @staticmethod
+    def process_packet(packet: SphinxPacket, route: List[MixNode]) -> Fragment:
+        """Process the packet hop by hop and decode the final payload."""
+        for node in route:
+            processed = packet.process(node.encryption_private_key)
+            if isinstance(processed, ProcessedFinalHopPacket):
+                return Fragment.from_bytes(processed.payload.recover_plain_playload())
+            packet = processed.next_packet
+        raise AssertionError("route ended before the final hop")
diff --git a/requirements.txt b/requirements.txt
index de2bd75..0013983 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,5 +3,7 @@ cffi==1.16.0
 cryptography==41.0.7
 numpy==1.26.3
 pycparser==2.21
+pysphinx==0.0.1
 scipy==1.11.4
 setuptools==69.0.3
+wheel==0.42.0