Mixnet: Sphinx packet builder for mix clients (#47)

This commit is contained in:
Youngjoon Lee 2024-01-15 15:17:35 +09:00
parent 2263327320
commit 1fc319de9e
No known key found for this signature in database
GPG Key ID: FC4855084E0B6A46
4 changed files with 316 additions and 11 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import List, TypeAlias
@ -7,6 +8,7 @@ from cryptography.hazmat.primitives.asymmetric.x25519 import (
X25519PrivateKey,
X25519PublicKey,
)
from pysphinx.node import Node
from mixnet.bls import BlsPrivateKey, BlsPublicKey
from mixnet.fisheryates import FisherYates
@ -43,24 +45,29 @@ class Mixnet:
layers.append(layer)
return MixnetTopology(layers)
def choose_mixnode(self) -> MixNode:
return random.choice(self.mix_nodes)
@dataclass
class MixNode:
identity_public_key: BlsPublicKey
encryption_public_key: X25519PublicKey
identity_private_key: BlsPrivateKey
encryption_private_key: X25519PrivateKey
addr: NodeAddress
def __init__(
self,
identity_private_key: BlsPrivateKey,
encryption_private_key: X25519PrivateKey,
addr: NodeAddress,
):
self.identity_public_key = identity_private_key.get_g1()
self.encryption_public_key = encryption_private_key.public_key()
self.addr = addr
def identity_public_key(self) -> BlsPublicKey:
return self.identity_private_key.get_g1()
def encryption_public_key(self) -> X25519PublicKey:
return self.encryption_private_key.public_key()
def sphinx_node(self) -> Node:
return Node(self.encryption_private_key, self.addr)
@dataclass
class MixnetTopology:
layers: List[List[MixNode]]
def generate_route(self) -> list[MixNode]:
return [random.choice(layer) for layer in self.layers]

205
mixnet/packet.py Normal file
View File

@ -0,0 +1,205 @@
from __future__ import annotations
import uuid
from dataclasses import dataclass
from enum import Enum
from itertools import batched
from typing import Dict, Iterator, List, Self, Tuple, TypeAlias
from pysphinx.payload import Payload
from pysphinx.sphinx import SphinxPacket
from mixnet.mixnet import Mixnet, MixnetTopology, MixNode
class MessageFlag(Enum):
MESSAGE_FLAG_REAL = b"\x00"
MESSAGE_FLAG_DROP_COVER = b"\x01"
def bytes(self) -> bytes:
return bytes(self.value)
class PacketBuilder:
iter: Iterator[Tuple[SphinxPacket, List[MixNode]]]
def __init__(
self,
flag: MessageFlag,
message: bytes,
mixnet: Mixnet,
topology: MixnetTopology,
):
destination = mixnet.choose_mixnode()
msg_with_flag = flag.bytes() + message
# NOTE: We don't encrypt msg_with_flag for destination.
# If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag.
fragment_set = FragmentSet(msg_with_flag)
packets_and_routes = []
for fragment in fragment_set.fragments:
route = topology.generate_route()
packet = SphinxPacket.build(
fragment.bytes(),
[mixnode.sphinx_node() for mixnode in route],
destination.sphinx_node(),
)
packets_and_routes.append((packet, route))
self.iter = iter(packets_and_routes)
@classmethod
def real(cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology) -> Self:
return cls(MessageFlag.MESSAGE_FLAG_REAL, message, mixnet, topology)
@classmethod
def drop_cover(
cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology
) -> Self:
return cls(MessageFlag.MESSAGE_FLAG_DROP_COVER, message, mixnet, topology)
def next(self) -> Tuple[SphinxPacket, List[MixNode]]:
return next(self.iter)
@staticmethod
def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]:
"""Remove a MessageFlag from data"""
if len(data) < 1:
raise ValueError("data is too short")
return (MessageFlag(data[0:1]), data[1:])
# Unlikely, Nym uses i32 for FragmentSetId, which may cause more collisions.
# We will use UUID until figuring out why Nym uses i32.
FragmentSetId: TypeAlias = bytes # 128bit UUID v4
FragmentId: TypeAlias = int # unsigned 8bit int in big endian
FRAGMENT_SET_ID_LENGTH: int = 16
FRAGMENT_ID_LENGTH: int = 1
@dataclass
class FragmentHeader:
"""
Contain all information for reconstructing a message that was fragmented into the same FragmentSet.
"""
set_id: FragmentSetId
total_fragments: FragmentId
fragment_id: FragmentId
SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2
@staticmethod
def max_total_fragments() -> int:
return 256 # because total_fragment is u8
def bytes(self) -> bytes:
return (
self.set_id
+ self.total_fragments.to_bytes(1)
+ self.fragment_id.to_bytes(1)
)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
if len(data) != cls.SIZE:
raise ValueError("Invalid data length", len(data))
return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18]))
@dataclass
class FragmentSet:
"""
Represent a set of Fragments that can be reconstructed to a single original message.
Note that the maximum number of fragments in a FragmentSet is limited for now.
"""
fragments: List[Fragment]
MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments()
def __init__(self, message: bytes):
"""
Build a FragmentSet by chunking a message into Fragments.
"""
chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE)
# For now, we don't support more than max_fragments() fragments.
# If needed, we can devise the FragmentSet chaining to support larger messages, like Nym.
if len(chunked_messages) > self.MAX_FRAGMENTS:
raise ValueError(f"Too long message: {len(chunked_messages)} chunks")
set_id = uuid.uuid4().bytes
self.fragments = [
Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk)
for i, chunk in enumerate(chunked_messages)
]
@dataclass
class Fragment:
"""Represent a piece of data that can be transformed to a single SphinxPacket"""
header: FragmentHeader
body: bytes
MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE
def bytes(self) -> bytes:
return self.header.bytes() + self.body
@classmethod
def from_bytes(cls, data: bytes) -> Self:
header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE])
body = data[FragmentHeader.SIZE :]
return cls(header, body)
@dataclass
class MessageReconstructor:
fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor]
def __init__(self):
self.fragmentSets = {}
def add(self, fragment: Fragment) -> bytes | None:
if fragment.header.set_id not in self.fragmentSets:
self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor(
fragment.header.total_fragments
)
msg = self.fragmentSets[fragment.header.set_id].add(fragment)
if msg is not None:
del self.fragmentSets[fragment.header.set_id]
return msg
@dataclass
class FragmentSetReconstructor:
total_fragments: FragmentId
fragments: Dict[FragmentId, Fragment]
def __init__(self, total_fragments: FragmentId):
self.total_fragments = total_fragments
self.fragments = {}
def add(self, fragment: Fragment) -> bytes | None:
self.fragments[fragment.header.fragment_id] = fragment
if len(self.fragments) == self.total_fragments:
return self.build_message()
else:
return None
def build_message(self) -> bytes:
message = b""
for i in range(self.total_fragments):
message += self.fragments[FragmentId(i)].body
return message
def chunks(data: bytes, size: int) -> List[bytes]:
return list(map(bytes, batched(data, size)))

91
mixnet/test_packet.py Normal file
View File

@ -0,0 +1,91 @@
from typing import List, Tuple
from unittest import TestCase
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket
from mixnet.bls import generate_bls
from mixnet.mixnet import Mixnet, MixnetTopology, MixNode
from mixnet.packet import (
Fragment,
MessageFlag,
MessageReconstructor,
PacketBuilder,
)
from mixnet.utils import random_bytes
class TestPacket(TestCase):
def test_real_packet(self):
mixnet, topology = self.init()
msg = random_bytes(3500)
builder = PacketBuilder.real(msg, mixnet, topology)
packet0, route0 = builder.next()
packet1, route1 = builder.next()
packet2, route2 = builder.next()
packet3, route3 = builder.next()
self.assertRaises(StopIteration, builder.next)
reconstructor = MessageReconstructor()
self.assertIsNone(
reconstructor.add(self.process_packet(packet1, route1)),
)
self.assertIsNone(
reconstructor.add(self.process_packet(packet3, route3)),
)
self.assertIsNone(
reconstructor.add(self.process_packet(packet2, route2)),
)
msg_with_flag = reconstructor.add(self.process_packet(packet0, route0))
assert msg_with_flag is not None
self.assertEqual(
PacketBuilder.parse_msg_and_flag(msg_with_flag),
(MessageFlag.MESSAGE_FLAG_REAL, msg),
)
def test_cover_packet(self):
mixnet, topology = self.init()
msg = b"cover"
builder = PacketBuilder.drop_cover(msg, mixnet, topology)
packet, route = builder.next()
self.assertRaises(StopIteration, builder.next)
reconstructor = MessageReconstructor()
msg_with_flag = reconstructor.add(self.process_packet(packet, route))
assert msg_with_flag is not None
self.assertEqual(
PacketBuilder.parse_msg_and_flag(msg_with_flag),
(MessageFlag.MESSAGE_FLAG_DROP_COVER, msg),
)
@staticmethod
def init() -> Tuple[Mixnet, MixnetTopology]:
mixnet = Mixnet(
[
MixNode(
generate_bls(),
X25519PrivateKey.generate(),
random_bytes(32),
)
for _ in range(12)
]
)
topology = mixnet.build_topology(b"entropy", 3, 3)
return mixnet, topology
@staticmethod
def process_packet(packet: SphinxPacket, route: List[MixNode]) -> Fragment:
processed = packet.process(route[0].encryption_private_key)
if isinstance(processed, ProcessedFinalHopPacket):
return Fragment.from_bytes(processed.payload.recover_plain_playload())
else:
processed = processed
for node in route[1:]:
p = processed.next_packet.process(node.encryption_private_key)
if isinstance(p, ProcessedFinalHopPacket):
return Fragment.from_bytes(p.payload.recover_plain_playload())
else:
processed = p
assert False

View File

@ -3,5 +3,7 @@ cffi==1.16.0
cryptography==41.0.7
numpy==1.26.3
pycparser==2.21
pysphinx==0.0.1
scipy==1.11.4
setuptools==69.0.3
wheel==0.42.0