diff --git a/mixnet/v2/sim/p2p.py b/mixnet/v2/sim/p2p.py index c82c055..3a44e95 100644 --- a/mixnet/v2/sim/p2p.py +++ b/mixnet/v2/sim/p2p.py @@ -43,7 +43,7 @@ class P2P(ABC): yield self.env.timeout(0) @abstractmethod - def send(self, msg: SphinxPacket | bytes, receiver: "Node"): + def send(self, msg: SphinxPacket | bytes, sender: "Node", receiver: "Node"): # simulate network latency yield self.env.timeout(random.uniform(0, self.config.p2p.max_network_latency)) # Measurement and adversary @@ -64,12 +64,12 @@ class NaiveBroadcastP2P(P2P): def broadcast(self, sender: "Node", msg: SphinxPacket | bytes): yield from super().broadcast(sender, msg) self.log(f"Node:{sender.id}: Broadcasting a msg: {len(msg)} bytes") - for node in self.nodes: + for receiver in self.nodes: self.measurement.measure_egress(sender, msg) - self.env.process(self.send(msg, node)) + self.env.process(self.send(msg, sender, receiver)) - def send(self, msg: SphinxPacket | bytes, receiver: "Node"): - yield from super().send(msg, receiver) + def send(self, msg: SphinxPacket | bytes, sender: "Node", receiver: "Node"): + yield from super().send(msg, sender, receiver) self.env.process(receiver.receive_message(msg)) @@ -77,7 +77,7 @@ class GossipP2P(P2P): def __init__(self, env: simpy.Environment, config: Config): super().__init__(env, config) self.topology = defaultdict(set) - self.message_cache = defaultdict(set) + self.message_cache = defaultdict(dict) def set_nodes(self, nodes: list["Node"]): super().set_nodes(nodes) @@ -104,18 +104,24 @@ class GossipP2P(P2P): yield from super().broadcast(sender, msg) self.log(f"Node:{sender.id}: Gossiping a msg: {len(msg)} bytes") + # if the msg is created originally by the sender (not forwarded from others), cache it with the sender itself. msg_hash = hashlib.sha256(bytes(msg)).digest() - self.message_cache[sender].add(msg_hash) + if msg_hash not in self.message_cache[sender]: + self.message_cache[sender][msg_hash] = sender for receiver in self.topology[sender]: - self.measurement.measure_egress(sender, msg) - self.env.process(self.send(msg, receiver)) + # Don't gossip the message if it was received from the node who is going to be the receiver, + # which means that the node already knows the message. + if receiver != self.message_cache[sender][msg_hash]: + self.measurement.measure_egress(sender, msg) + self.env.process(self.send(msg, sender, receiver)) - def send(self, msg: SphinxPacket | bytes, receiver: "Node"): - yield from super().send(msg, receiver) - # receive the msg only if it hasn't been received before + def send(self, msg: SphinxPacket | bytes, sender: "Node", receiver: "Node"): + yield from super().send(msg, sender, receiver) + # Receive/gossip the msg only if it hasn't been received before. + # i.e. each message is received/gossiped at most once by each node. msg_hash = hashlib.sha256(bytes(msg)).digest() if msg_hash not in self.message_cache[receiver]: - self.message_cache[receiver].add(msg_hash) + self.message_cache[receiver][msg_hash] = sender self.env.process(receiver.receive_message(msg)) self.env.process(self.broadcast(receiver, msg))