Mixnet: Sphinx packet builder for mix clients (#47)
This commit is contained in:
parent
2263327320
commit
1fc319de9e
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List, TypeAlias
|
||||
|
||||
|
@ -7,6 +8,7 @@ from cryptography.hazmat.primitives.asymmetric.x25519 import (
|
|||
X25519PrivateKey,
|
||||
X25519PublicKey,
|
||||
)
|
||||
from pysphinx.node import Node
|
||||
|
||||
from mixnet.bls import BlsPrivateKey, BlsPublicKey
|
||||
from mixnet.fisheryates import FisherYates
|
||||
|
@ -43,24 +45,29 @@ class Mixnet:
|
|||
layers.append(layer)
|
||||
return MixnetTopology(layers)
|
||||
|
||||
def choose_mixnode(self) -> MixNode:
|
||||
return random.choice(self.mix_nodes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixNode:
|
||||
identity_public_key: BlsPublicKey
|
||||
encryption_public_key: X25519PublicKey
|
||||
identity_private_key: BlsPrivateKey
|
||||
encryption_private_key: X25519PrivateKey
|
||||
addr: NodeAddress
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
identity_private_key: BlsPrivateKey,
|
||||
encryption_private_key: X25519PrivateKey,
|
||||
addr: NodeAddress,
|
||||
):
|
||||
self.identity_public_key = identity_private_key.get_g1()
|
||||
self.encryption_public_key = encryption_private_key.public_key()
|
||||
self.addr = addr
|
||||
def identity_public_key(self) -> BlsPublicKey:
|
||||
return self.identity_private_key.get_g1()
|
||||
|
||||
def encryption_public_key(self) -> X25519PublicKey:
|
||||
return self.encryption_private_key.public_key()
|
||||
|
||||
def sphinx_node(self) -> Node:
|
||||
return Node(self.encryption_private_key, self.addr)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixnetTopology:
|
||||
layers: List[List[MixNode]]
|
||||
|
||||
def generate_route(self) -> list[MixNode]:
|
||||
return [random.choice(layer) for layer in self.layers]
|
||||
|
|
|
@ -0,0 +1,205 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from itertools import batched
|
||||
from typing import Dict, Iterator, List, Self, Tuple, TypeAlias
|
||||
|
||||
from pysphinx.payload import Payload
|
||||
from pysphinx.sphinx import SphinxPacket
|
||||
|
||||
from mixnet.mixnet import Mixnet, MixnetTopology, MixNode
|
||||
|
||||
|
||||
class MessageFlag(Enum):
|
||||
MESSAGE_FLAG_REAL = b"\x00"
|
||||
MESSAGE_FLAG_DROP_COVER = b"\x01"
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return bytes(self.value)
|
||||
|
||||
|
||||
class PacketBuilder:
|
||||
iter: Iterator[Tuple[SphinxPacket, List[MixNode]]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flag: MessageFlag,
|
||||
message: bytes,
|
||||
mixnet: Mixnet,
|
||||
topology: MixnetTopology,
|
||||
):
|
||||
destination = mixnet.choose_mixnode()
|
||||
|
||||
msg_with_flag = flag.bytes() + message
|
||||
# NOTE: We don't encrypt msg_with_flag for destination.
|
||||
# If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag.
|
||||
fragment_set = FragmentSet(msg_with_flag)
|
||||
|
||||
packets_and_routes = []
|
||||
for fragment in fragment_set.fragments:
|
||||
route = topology.generate_route()
|
||||
packet = SphinxPacket.build(
|
||||
fragment.bytes(),
|
||||
[mixnode.sphinx_node() for mixnode in route],
|
||||
destination.sphinx_node(),
|
||||
)
|
||||
packets_and_routes.append((packet, route))
|
||||
|
||||
self.iter = iter(packets_and_routes)
|
||||
|
||||
@classmethod
|
||||
def real(cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology) -> Self:
|
||||
return cls(MessageFlag.MESSAGE_FLAG_REAL, message, mixnet, topology)
|
||||
|
||||
@classmethod
|
||||
def drop_cover(
|
||||
cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology
|
||||
) -> Self:
|
||||
return cls(MessageFlag.MESSAGE_FLAG_DROP_COVER, message, mixnet, topology)
|
||||
|
||||
def next(self) -> Tuple[SphinxPacket, List[MixNode]]:
|
||||
return next(self.iter)
|
||||
|
||||
@staticmethod
|
||||
def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]:
|
||||
"""Remove a MessageFlag from data"""
|
||||
if len(data) < 1:
|
||||
raise ValueError("data is too short")
|
||||
|
||||
return (MessageFlag(data[0:1]), data[1:])
|
||||
|
||||
|
||||
# Unlikely, Nym uses i32 for FragmentSetId, which may cause more collisions.
|
||||
# We will use UUID until figuring out why Nym uses i32.
|
||||
FragmentSetId: TypeAlias = bytes # 128bit UUID v4
|
||||
FragmentId: TypeAlias = int # unsigned 8bit int in big endian
|
||||
|
||||
FRAGMENT_SET_ID_LENGTH: int = 16
|
||||
FRAGMENT_ID_LENGTH: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class FragmentHeader:
|
||||
"""
|
||||
Contain all information for reconstructing a message that was fragmented into the same FragmentSet.
|
||||
"""
|
||||
|
||||
set_id: FragmentSetId
|
||||
total_fragments: FragmentId
|
||||
fragment_id: FragmentId
|
||||
|
||||
SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2
|
||||
|
||||
@staticmethod
|
||||
def max_total_fragments() -> int:
|
||||
return 256 # because total_fragment is u8
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return (
|
||||
self.set_id
|
||||
+ self.total_fragments.to_bytes(1)
|
||||
+ self.fragment_id.to_bytes(1)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
if len(data) != cls.SIZE:
|
||||
raise ValueError("Invalid data length", len(data))
|
||||
|
||||
return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18]))
|
||||
|
||||
|
||||
@dataclass
|
||||
class FragmentSet:
|
||||
"""
|
||||
Represent a set of Fragments that can be reconstructed to a single original message.
|
||||
|
||||
Note that the maximum number of fragments in a FragmentSet is limited for now.
|
||||
"""
|
||||
|
||||
fragments: List[Fragment]
|
||||
|
||||
MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments()
|
||||
|
||||
def __init__(self, message: bytes):
|
||||
"""
|
||||
Build a FragmentSet by chunking a message into Fragments.
|
||||
"""
|
||||
chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE)
|
||||
# For now, we don't support more than max_fragments() fragments.
|
||||
# If needed, we can devise the FragmentSet chaining to support larger messages, like Nym.
|
||||
if len(chunked_messages) > self.MAX_FRAGMENTS:
|
||||
raise ValueError(f"Too long message: {len(chunked_messages)} chunks")
|
||||
|
||||
set_id = uuid.uuid4().bytes
|
||||
self.fragments = [
|
||||
Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk)
|
||||
for i, chunk in enumerate(chunked_messages)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fragment:
|
||||
"""Represent a piece of data that can be transformed to a single SphinxPacket"""
|
||||
|
||||
header: FragmentHeader
|
||||
body: bytes
|
||||
|
||||
MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return self.header.bytes() + self.body
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE])
|
||||
body = data[FragmentHeader.SIZE :]
|
||||
return cls(header, body)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageReconstructor:
|
||||
fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor]
|
||||
|
||||
def __init__(self):
|
||||
self.fragmentSets = {}
|
||||
|
||||
def add(self, fragment: Fragment) -> bytes | None:
|
||||
if fragment.header.set_id not in self.fragmentSets:
|
||||
self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor(
|
||||
fragment.header.total_fragments
|
||||
)
|
||||
|
||||
msg = self.fragmentSets[fragment.header.set_id].add(fragment)
|
||||
if msg is not None:
|
||||
del self.fragmentSets[fragment.header.set_id]
|
||||
return msg
|
||||
|
||||
|
||||
@dataclass
|
||||
class FragmentSetReconstructor:
|
||||
total_fragments: FragmentId
|
||||
fragments: Dict[FragmentId, Fragment]
|
||||
|
||||
def __init__(self, total_fragments: FragmentId):
|
||||
self.total_fragments = total_fragments
|
||||
self.fragments = {}
|
||||
|
||||
def add(self, fragment: Fragment) -> bytes | None:
|
||||
self.fragments[fragment.header.fragment_id] = fragment
|
||||
if len(self.fragments) == self.total_fragments:
|
||||
return self.build_message()
|
||||
else:
|
||||
return None
|
||||
|
||||
def build_message(self) -> bytes:
|
||||
message = b""
|
||||
for i in range(self.total_fragments):
|
||||
message += self.fragments[FragmentId(i)].body
|
||||
return message
|
||||
|
||||
|
||||
def chunks(data: bytes, size: int) -> List[bytes]:
|
||||
return list(map(bytes, batched(data, size)))
|
|
@ -0,0 +1,91 @@
|
|||
from typing import List, Tuple
|
||||
from unittest import TestCase
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||
from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket
|
||||
|
||||
from mixnet.bls import generate_bls
|
||||
from mixnet.mixnet import Mixnet, MixnetTopology, MixNode
|
||||
from mixnet.packet import (
|
||||
Fragment,
|
||||
MessageFlag,
|
||||
MessageReconstructor,
|
||||
PacketBuilder,
|
||||
)
|
||||
from mixnet.utils import random_bytes
|
||||
|
||||
|
||||
class TestPacket(TestCase):
|
||||
def test_real_packet(self):
|
||||
mixnet, topology = self.init()
|
||||
|
||||
msg = random_bytes(3500)
|
||||
builder = PacketBuilder.real(msg, mixnet, topology)
|
||||
packet0, route0 = builder.next()
|
||||
packet1, route1 = builder.next()
|
||||
packet2, route2 = builder.next()
|
||||
packet3, route3 = builder.next()
|
||||
self.assertRaises(StopIteration, builder.next)
|
||||
|
||||
reconstructor = MessageReconstructor()
|
||||
self.assertIsNone(
|
||||
reconstructor.add(self.process_packet(packet1, route1)),
|
||||
)
|
||||
self.assertIsNone(
|
||||
reconstructor.add(self.process_packet(packet3, route3)),
|
||||
)
|
||||
self.assertIsNone(
|
||||
reconstructor.add(self.process_packet(packet2, route2)),
|
||||
)
|
||||
msg_with_flag = reconstructor.add(self.process_packet(packet0, route0))
|
||||
assert msg_with_flag is not None
|
||||
self.assertEqual(
|
||||
PacketBuilder.parse_msg_and_flag(msg_with_flag),
|
||||
(MessageFlag.MESSAGE_FLAG_REAL, msg),
|
||||
)
|
||||
|
||||
def test_cover_packet(self):
|
||||
mixnet, topology = self.init()
|
||||
|
||||
msg = b"cover"
|
||||
builder = PacketBuilder.drop_cover(msg, mixnet, topology)
|
||||
packet, route = builder.next()
|
||||
self.assertRaises(StopIteration, builder.next)
|
||||
|
||||
reconstructor = MessageReconstructor()
|
||||
msg_with_flag = reconstructor.add(self.process_packet(packet, route))
|
||||
assert msg_with_flag is not None
|
||||
self.assertEqual(
|
||||
PacketBuilder.parse_msg_and_flag(msg_with_flag),
|
||||
(MessageFlag.MESSAGE_FLAG_DROP_COVER, msg),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init() -> Tuple[Mixnet, MixnetTopology]:
|
||||
mixnet = Mixnet(
|
||||
[
|
||||
MixNode(
|
||||
generate_bls(),
|
||||
X25519PrivateKey.generate(),
|
||||
random_bytes(32),
|
||||
)
|
||||
for _ in range(12)
|
||||
]
|
||||
)
|
||||
topology = mixnet.build_topology(b"entropy", 3, 3)
|
||||
return mixnet, topology
|
||||
|
||||
@staticmethod
|
||||
def process_packet(packet: SphinxPacket, route: List[MixNode]) -> Fragment:
|
||||
processed = packet.process(route[0].encryption_private_key)
|
||||
if isinstance(processed, ProcessedFinalHopPacket):
|
||||
return Fragment.from_bytes(processed.payload.recover_plain_playload())
|
||||
else:
|
||||
processed = processed
|
||||
for node in route[1:]:
|
||||
p = processed.next_packet.process(node.encryption_private_key)
|
||||
if isinstance(p, ProcessedFinalHopPacket):
|
||||
return Fragment.from_bytes(p.payload.recover_plain_playload())
|
||||
else:
|
||||
processed = p
|
||||
assert False
|
|
@ -3,5 +3,7 @@ cffi==1.16.0
|
|||
cryptography==41.0.7
|
||||
numpy==1.26.3
|
||||
pycparser==2.21
|
||||
pysphinx==0.0.1
|
||||
scipy==1.11.4
|
||||
setuptools==69.0.3
|
||||
wheel==0.42.0
|
||||
|
|
Loading…
Reference in New Issue