Mixnet: Sphinx packet builder for mix clients (#47)
This commit is contained in:
parent
2263327320
commit
1fc319de9e
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, TypeAlias
|
from typing import List, TypeAlias
|
||||||
|
|
||||||
|
@ -7,6 +8,7 @@ from cryptography.hazmat.primitives.asymmetric.x25519 import (
|
||||||
X25519PrivateKey,
|
X25519PrivateKey,
|
||||||
X25519PublicKey,
|
X25519PublicKey,
|
||||||
)
|
)
|
||||||
|
from pysphinx.node import Node
|
||||||
|
|
||||||
from mixnet.bls import BlsPrivateKey, BlsPublicKey
|
from mixnet.bls import BlsPrivateKey, BlsPublicKey
|
||||||
from mixnet.fisheryates import FisherYates
|
from mixnet.fisheryates import FisherYates
|
||||||
|
@ -43,24 +45,29 @@ class Mixnet:
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
return MixnetTopology(layers)
|
return MixnetTopology(layers)
|
||||||
|
|
||||||
|
def choose_mixnode(self) -> MixNode:
|
||||||
|
return random.choice(self.mix_nodes)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MixNode:
|
class MixNode:
|
||||||
identity_public_key: BlsPublicKey
|
identity_private_key: BlsPrivateKey
|
||||||
encryption_public_key: X25519PublicKey
|
encryption_private_key: X25519PrivateKey
|
||||||
addr: NodeAddress
|
addr: NodeAddress
|
||||||
|
|
||||||
def __init__(
|
def identity_public_key(self) -> BlsPublicKey:
|
||||||
self,
|
return self.identity_private_key.get_g1()
|
||||||
identity_private_key: BlsPrivateKey,
|
|
||||||
encryption_private_key: X25519PrivateKey,
|
def encryption_public_key(self) -> X25519PublicKey:
|
||||||
addr: NodeAddress,
|
return self.encryption_private_key.public_key()
|
||||||
):
|
|
||||||
self.identity_public_key = identity_private_key.get_g1()
|
def sphinx_node(self) -> Node:
|
||||||
self.encryption_public_key = encryption_private_key.public_key()
|
return Node(self.encryption_private_key, self.addr)
|
||||||
self.addr = addr
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MixnetTopology:
|
class MixnetTopology:
|
||||||
layers: List[List[MixNode]]
|
layers: List[List[MixNode]]
|
||||||
|
|
||||||
|
def generate_route(self) -> list[MixNode]:
|
||||||
|
return [random.choice(layer) for layer in self.layers]
|
||||||
|
|
|
@ -0,0 +1,205 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from itertools import batched
|
||||||
|
from typing import Dict, Iterator, List, Self, Tuple, TypeAlias
|
||||||
|
|
||||||
|
from pysphinx.payload import Payload
|
||||||
|
from pysphinx.sphinx import SphinxPacket
|
||||||
|
|
||||||
|
from mixnet.mixnet import Mixnet, MixnetTopology, MixNode
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFlag(Enum):
|
||||||
|
MESSAGE_FLAG_REAL = b"\x00"
|
||||||
|
MESSAGE_FLAG_DROP_COVER = b"\x01"
|
||||||
|
|
||||||
|
def bytes(self) -> bytes:
|
||||||
|
return bytes(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
class PacketBuilder:
|
||||||
|
iter: Iterator[Tuple[SphinxPacket, List[MixNode]]]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
flag: MessageFlag,
|
||||||
|
message: bytes,
|
||||||
|
mixnet: Mixnet,
|
||||||
|
topology: MixnetTopology,
|
||||||
|
):
|
||||||
|
destination = mixnet.choose_mixnode()
|
||||||
|
|
||||||
|
msg_with_flag = flag.bytes() + message
|
||||||
|
# NOTE: We don't encrypt msg_with_flag for destination.
|
||||||
|
# If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag.
|
||||||
|
fragment_set = FragmentSet(msg_with_flag)
|
||||||
|
|
||||||
|
packets_and_routes = []
|
||||||
|
for fragment in fragment_set.fragments:
|
||||||
|
route = topology.generate_route()
|
||||||
|
packet = SphinxPacket.build(
|
||||||
|
fragment.bytes(),
|
||||||
|
[mixnode.sphinx_node() for mixnode in route],
|
||||||
|
destination.sphinx_node(),
|
||||||
|
)
|
||||||
|
packets_and_routes.append((packet, route))
|
||||||
|
|
||||||
|
self.iter = iter(packets_and_routes)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def real(cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology) -> Self:
|
||||||
|
return cls(MessageFlag.MESSAGE_FLAG_REAL, message, mixnet, topology)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def drop_cover(
|
||||||
|
cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology
|
||||||
|
) -> Self:
|
||||||
|
return cls(MessageFlag.MESSAGE_FLAG_DROP_COVER, message, mixnet, topology)
|
||||||
|
|
||||||
|
def next(self) -> Tuple[SphinxPacket, List[MixNode]]:
|
||||||
|
return next(self.iter)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]:
|
||||||
|
"""Remove a MessageFlag from data"""
|
||||||
|
if len(data) < 1:
|
||||||
|
raise ValueError("data is too short")
|
||||||
|
|
||||||
|
return (MessageFlag(data[0:1]), data[1:])
|
||||||
|
|
||||||
|
|
||||||
|
# Unlikely, Nym uses i32 for FragmentSetId, which may cause more collisions.
|
||||||
|
# We will use UUID until figuring out why Nym uses i32.
|
||||||
|
FragmentSetId: TypeAlias = bytes # 128bit UUID v4
|
||||||
|
FragmentId: TypeAlias = int # unsigned 8bit int in big endian
|
||||||
|
|
||||||
|
FRAGMENT_SET_ID_LENGTH: int = 16
|
||||||
|
FRAGMENT_ID_LENGTH: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FragmentHeader:
|
||||||
|
"""
|
||||||
|
Contain all information for reconstructing a message that was fragmented into the same FragmentSet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
set_id: FragmentSetId
|
||||||
|
total_fragments: FragmentId
|
||||||
|
fragment_id: FragmentId
|
||||||
|
|
||||||
|
SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def max_total_fragments() -> int:
|
||||||
|
return 256 # because total_fragment is u8
|
||||||
|
|
||||||
|
def bytes(self) -> bytes:
|
||||||
|
return (
|
||||||
|
self.set_id
|
||||||
|
+ self.total_fragments.to_bytes(1)
|
||||||
|
+ self.fragment_id.to_bytes(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, data: bytes) -> Self:
|
||||||
|
if len(data) != cls.SIZE:
|
||||||
|
raise ValueError("Invalid data length", len(data))
|
||||||
|
|
||||||
|
return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18]))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FragmentSet:
|
||||||
|
"""
|
||||||
|
Represent a set of Fragments that can be reconstructed to a single original message.
|
||||||
|
|
||||||
|
Note that the maximum number of fragments in a FragmentSet is limited for now.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fragments: List[Fragment]
|
||||||
|
|
||||||
|
MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments()
|
||||||
|
|
||||||
|
def __init__(self, message: bytes):
|
||||||
|
"""
|
||||||
|
Build a FragmentSet by chunking a message into Fragments.
|
||||||
|
"""
|
||||||
|
chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE)
|
||||||
|
# For now, we don't support more than max_fragments() fragments.
|
||||||
|
# If needed, we can devise the FragmentSet chaining to support larger messages, like Nym.
|
||||||
|
if len(chunked_messages) > self.MAX_FRAGMENTS:
|
||||||
|
raise ValueError(f"Too long message: {len(chunked_messages)} chunks")
|
||||||
|
|
||||||
|
set_id = uuid.uuid4().bytes
|
||||||
|
self.fragments = [
|
||||||
|
Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk)
|
||||||
|
for i, chunk in enumerate(chunked_messages)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Fragment:
|
||||||
|
"""Represent a piece of data that can be transformed to a single SphinxPacket"""
|
||||||
|
|
||||||
|
header: FragmentHeader
|
||||||
|
body: bytes
|
||||||
|
|
||||||
|
MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE
|
||||||
|
|
||||||
|
def bytes(self) -> bytes:
|
||||||
|
return self.header.bytes() + self.body
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, data: bytes) -> Self:
|
||||||
|
header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE])
|
||||||
|
body = data[FragmentHeader.SIZE :]
|
||||||
|
return cls(header, body)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MessageReconstructor:
|
||||||
|
fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.fragmentSets = {}
|
||||||
|
|
||||||
|
def add(self, fragment: Fragment) -> bytes | None:
|
||||||
|
if fragment.header.set_id not in self.fragmentSets:
|
||||||
|
self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor(
|
||||||
|
fragment.header.total_fragments
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = self.fragmentSets[fragment.header.set_id].add(fragment)
|
||||||
|
if msg is not None:
|
||||||
|
del self.fragmentSets[fragment.header.set_id]
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FragmentSetReconstructor:
|
||||||
|
total_fragments: FragmentId
|
||||||
|
fragments: Dict[FragmentId, Fragment]
|
||||||
|
|
||||||
|
def __init__(self, total_fragments: FragmentId):
|
||||||
|
self.total_fragments = total_fragments
|
||||||
|
self.fragments = {}
|
||||||
|
|
||||||
|
def add(self, fragment: Fragment) -> bytes | None:
|
||||||
|
self.fragments[fragment.header.fragment_id] = fragment
|
||||||
|
if len(self.fragments) == self.total_fragments:
|
||||||
|
return self.build_message()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def build_message(self) -> bytes:
|
||||||
|
message = b""
|
||||||
|
for i in range(self.total_fragments):
|
||||||
|
message += self.fragments[FragmentId(i)].body
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def chunks(data: bytes, size: int) -> List[bytes]:
|
||||||
|
return list(map(bytes, batched(data, size)))
|
|
@ -0,0 +1,91 @@
|
||||||
|
from typing import List, Tuple
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||||
|
from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket
|
||||||
|
|
||||||
|
from mixnet.bls import generate_bls
|
||||||
|
from mixnet.mixnet import Mixnet, MixnetTopology, MixNode
|
||||||
|
from mixnet.packet import (
|
||||||
|
Fragment,
|
||||||
|
MessageFlag,
|
||||||
|
MessageReconstructor,
|
||||||
|
PacketBuilder,
|
||||||
|
)
|
||||||
|
from mixnet.utils import random_bytes
|
||||||
|
|
||||||
|
|
||||||
|
class TestPacket(TestCase):
|
||||||
|
def test_real_packet(self):
|
||||||
|
mixnet, topology = self.init()
|
||||||
|
|
||||||
|
msg = random_bytes(3500)
|
||||||
|
builder = PacketBuilder.real(msg, mixnet, topology)
|
||||||
|
packet0, route0 = builder.next()
|
||||||
|
packet1, route1 = builder.next()
|
||||||
|
packet2, route2 = builder.next()
|
||||||
|
packet3, route3 = builder.next()
|
||||||
|
self.assertRaises(StopIteration, builder.next)
|
||||||
|
|
||||||
|
reconstructor = MessageReconstructor()
|
||||||
|
self.assertIsNone(
|
||||||
|
reconstructor.add(self.process_packet(packet1, route1)),
|
||||||
|
)
|
||||||
|
self.assertIsNone(
|
||||||
|
reconstructor.add(self.process_packet(packet3, route3)),
|
||||||
|
)
|
||||||
|
self.assertIsNone(
|
||||||
|
reconstructor.add(self.process_packet(packet2, route2)),
|
||||||
|
)
|
||||||
|
msg_with_flag = reconstructor.add(self.process_packet(packet0, route0))
|
||||||
|
assert msg_with_flag is not None
|
||||||
|
self.assertEqual(
|
||||||
|
PacketBuilder.parse_msg_and_flag(msg_with_flag),
|
||||||
|
(MessageFlag.MESSAGE_FLAG_REAL, msg),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_cover_packet(self):
|
||||||
|
mixnet, topology = self.init()
|
||||||
|
|
||||||
|
msg = b"cover"
|
||||||
|
builder = PacketBuilder.drop_cover(msg, mixnet, topology)
|
||||||
|
packet, route = builder.next()
|
||||||
|
self.assertRaises(StopIteration, builder.next)
|
||||||
|
|
||||||
|
reconstructor = MessageReconstructor()
|
||||||
|
msg_with_flag = reconstructor.add(self.process_packet(packet, route))
|
||||||
|
assert msg_with_flag is not None
|
||||||
|
self.assertEqual(
|
||||||
|
PacketBuilder.parse_msg_and_flag(msg_with_flag),
|
||||||
|
(MessageFlag.MESSAGE_FLAG_DROP_COVER, msg),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init() -> Tuple[Mixnet, MixnetTopology]:
|
||||||
|
mixnet = Mixnet(
|
||||||
|
[
|
||||||
|
MixNode(
|
||||||
|
generate_bls(),
|
||||||
|
X25519PrivateKey.generate(),
|
||||||
|
random_bytes(32),
|
||||||
|
)
|
||||||
|
for _ in range(12)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
topology = mixnet.build_topology(b"entropy", 3, 3)
|
||||||
|
return mixnet, topology
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_packet(packet: SphinxPacket, route: List[MixNode]) -> Fragment:
|
||||||
|
processed = packet.process(route[0].encryption_private_key)
|
||||||
|
if isinstance(processed, ProcessedFinalHopPacket):
|
||||||
|
return Fragment.from_bytes(processed.payload.recover_plain_playload())
|
||||||
|
else:
|
||||||
|
processed = processed
|
||||||
|
for node in route[1:]:
|
||||||
|
p = processed.next_packet.process(node.encryption_private_key)
|
||||||
|
if isinstance(p, ProcessedFinalHopPacket):
|
||||||
|
return Fragment.from_bytes(p.payload.recover_plain_playload())
|
||||||
|
else:
|
||||||
|
processed = p
|
||||||
|
assert False
|
|
@ -3,5 +3,7 @@ cffi==1.16.0
|
||||||
cryptography==41.0.7
|
cryptography==41.0.7
|
||||||
numpy==1.26.3
|
numpy==1.26.3
|
||||||
pycparser==2.21
|
pycparser==2.21
|
||||||
|
pysphinx==0.0.1
|
||||||
scipy==1.11.4
|
scipy==1.11.4
|
||||||
setuptools==69.0.3
|
setuptools==69.0.3
|
||||||
|
wheel==0.42.0
|
||||||
|
|
Loading…
Reference in New Issue