diff --git a/codexdht/private/eth/p2p/discoveryv5/encoding.nim b/codexdht/private/eth/p2p/discoveryv5/encoding.nim index 1743559..65f9e7c 100644 --- a/codexdht/private/eth/p2p/discoveryv5/encoding.nim +++ b/codexdht/private/eth/p2p/discoveryv5/encoding.nim @@ -209,8 +209,9 @@ proc encodeStaticHeader*(flag: Flag, nonce: AESGCMNonce, authSize: int): proc encodeMessagePacket*(rng: var HmacDrbgContext, c: var Codec, toId: NodeId, toAddr: Address, message: openArray[byte]): - (seq[byte], AESGCMNonce) = + (seq[byte], AESGCMNonce, bool) = var nonce: AESGCMNonce + var haskey: bool hmacDrbgGenerate(rng, nonce) # Random AESGCM nonce var iv: array[ivSize, byte] hmacDrbgGenerate(rng, iv) # Random IV @@ -228,6 +229,7 @@ proc encodeMessagePacket*(rng: var HmacDrbgContext, c: var Codec, var messageEncrypted: seq[byte] var initiatorKey, recipientKey: AesKey if c.sessions.load(toId, toAddr, recipientKey, initiatorKey): + haskey = true messageEncrypted = encryptGCM(initiatorKey, nonce, message, @iv & header) discovery_session_lru_cache_hits.inc() else: @@ -238,6 +240,7 @@ proc encodeMessagePacket*(rng: var HmacDrbgContext, c: var Codec, # message. 16 bytes for the gcm tag and 4 bytes for ping with requestId of # 1 byte (e.g "01c20101"). Could increase to 27 for 8 bytes requestId in # case this must not look like a random packet. + haskey = false var randomData: array[gcmTagSize + 4, byte] hmacDrbgGenerate(rng, randomData) messageEncrypted.add(randomData) @@ -250,7 +253,7 @@ proc encodeMessagePacket*(rng: var HmacDrbgContext, c: var Codec, packet.add(maskedHeader) packet.add(messageEncrypted) - return (packet, nonce) + return (packet, nonce, haskey) proc encodeWhoareyouPacket*(rng: var HmacDrbgContext, c: var Codec, toId: NodeId, toAddr: Address, requestNonce: AESGCMNonce, recordSeq: uint64, diff --git a/codexdht/private/eth/p2p/discoveryv5/transport.nim b/codexdht/private/eth/p2p/discoveryv5/transport.nim index 2cf48df..461335c 100644 --- a/codexdht/private/eth/p2p/discoveryv5/transport.nim +++ b/codexdht/private/eth/p2p/discoveryv5/transport.nim @@ -6,7 +6,7 @@ # Everything below the handling of ordinary messages import - std/[tables, options], + std/[tables, options, sets], bearssl/rand, chronos, chronicles, @@ -26,6 +26,8 @@ type bindAddress: Address ## UDP binding address transp: DatagramTransport pendingRequests: Table[AESGCMNonce, PendingRequest] + keyexchangeInProgress: HashSet[NodeId] + pendingRequestsByNode: Table[NodeId, seq[seq[byte]]] codec*: Codec rng: ref HmacDrbgContext @@ -34,6 +36,7 @@ type message: seq[byte] proc sendToA(t: Transport, a: Address, data: seq[byte]) = + trace "Send packet", myport = t.bindAddress.port, address = a let ta = initTAddress(a.ip, a.port) let f = t.transp.sendTo(ta, data) f.callback = proc(data: pointer) {.gcsafe.} = @@ -55,7 +58,7 @@ proc send(t: Transport, n: Node, data: seq[byte]) = t.sendToA(n.address.get(), data) proc sendMessage*(t: Transport, toId: NodeId, toAddr: Address, message: seq[byte]) = - let (data, _) = encodeMessagePacket(t.rng[], t.codec, toId, toAddr, + let (data, _, _) = encodeMessagePacket(t.rng[], t.codec, toId, toAddr, message) t.sendToA(toAddr, data) @@ -73,11 +76,30 @@ proc registerRequest(t: Transport, n: Node, message: seq[byte], proc sendMessage*(t: Transport, toNode: Node, message: seq[byte]) = doAssert(toNode.address.isSome()) let address = toNode.address.get() - let (data, nonce) = encodeMessagePacket(t.rng[], t.codec, + let (data, nonce, haskey) = encodeMessagePacket(t.rng[], t.codec, toNode.id, address, message) - t.registerRequest(toNode, message, nonce) - t.send(toNode, data) + if haskey: + trace "Send message: has key", myport = t.bindAddress.port , dstId = toNode + t.registerRequest(toNode, message, nonce) + t.send(toNode, data) + else: + # we don't have an encryption key for this target, so we should initiate keyexchange + if not (toNode.id in t.keyexchangeInProgress): + trace "Send message: send random to trigger Whoareyou", myport = t.bindAddress.port , dstId = toNode + t.registerRequest(toNode, message, nonce) + t.send(toNode, data) + t.keyexchangeInProgress.incl(toNode.id) + trace "keyexchangeInProgress added", myport = t.bindAddress.port , dstId = toNode + sleepAsync(responseTimeout).addCallback() do(data: pointer): + t.keyexchangeInProgress.excl(toNode.id) + trace "keyexchangeInProgress removed (timeout)", myport = t.bindAddress.port , dstId = toNode + else: + # delay sending this message until whoareyou is received and handshake is sent + # have to reencode once keys are clear + t.pendingRequestsByNode.mgetOrPut(toNode.id, newSeq[seq[byte]]()).add(message) + trace "Send message: Node with this id already has ongoing keyexchage, delaying packet", + myport = t.bindAddress.port , dstId = toNode, qlen=t.pendingRequestsByNode[toNode.id].len proc sendWhoareyou(t: Transport, toId: NodeId, a: Address, requestNonce: AESGCMNonce, node: Option[Node]) = @@ -94,13 +116,28 @@ proc sendWhoareyou(t: Transport, toId: NodeId, a: Address, sleepAsync(handshakeTimeout).addCallback() do(data: pointer): # TODO: should we still provide cancellation in case handshake completes # correctly? - t.codec.handshakes.del(key) + if t.codec.hasHandshake(key): + debug "Handshake timeout", myport = t.bindAddress.port , dstId = toId, address = a + t.codec.handshakes.del(key) trace "Send whoareyou", dstId = toId, address = a t.sendToA(a, data) else: - debug "Node with this id already has ongoing handshake, ignoring packet" + # TODO: is this reasonable to drop it? Should we allow a mini-queue here? + # Queue should be on sender side, as this is random encoded! + debug "Node with this id already has ongoing handshake, queuing packet", myport = t.bindAddress.port , dstId = toId, address = a +proc sendPending(t:Transport, toNode: Node): + Future[void] {.async.} = + if t.pendingRequestsByNode.hasKey(toNode.id): + trace "Found pending request", myport = t.bindAddress.port, src = toNode, len = t.pendingRequestsByNode[toNode.id].len + for message in t.pendingRequestsByNode[toNode.id]: + trace "Sending pending packet", myport = t.bindAddress.port, dstId = toNode.id + let address = toNode.address.get() + let (data, nonce, haskey) = encodeMessagePacket(t.rng[], t.codec, toNode.id, address, message) + t.registerRequest(toNode, message, nonce) + t.send(toNode, data) + t.pendingRequestsByNode.del(toNode.id) proc receive*(t: Transport, a: Address, packet: openArray[byte]) = let decoded = t.codec.decodePacket(a, packet) if decoded.isOk: @@ -109,17 +146,23 @@ proc receive*(t: Transport, a: Address, packet: openArray[byte]) = of OrdinaryMessage: if packet.messageOpt.isSome(): let message = packet.messageOpt.get() - trace "Received message packet", srcId = packet.srcId, address = a, + trace "Received message packet", myport = t.bindAddress.port, srcId = packet.srcId, address = a, kind = message.kind, p = $packet t.client.handleMessage(packet.srcId, a, message) else: - trace "Not decryptable message packet received", + trace "Not decryptable message packet received", myport = t.bindAddress.port, srcId = packet.srcId, address = a + # If we already have a keyexchange in progress, we have a case of simultaneous cross-connect. + # We could try to decide here which should go on, but since we are on top of UDP, a more robust + # choice is to answer here and resolve conflicts in the next stage (reception of Whoareyou), or + # even later (reception of Handshake). + if packet.srcId in t.keyexchangeInProgress: + trace "cross-connect detected, still sending Whoareyou" t.sendWhoareyou(packet.srcId, a, packet.requestNonce, t.client.getNode(packet.srcId)) of Flag.Whoareyou: - trace "Received whoareyou packet", address = a + trace "Received whoareyou packet", myport = t.bindAddress.port, address = a var pr: PendingRequest if t.pendingRequests.take(packet.whoareyou.requestNonce, pr): let toNode = pr.node @@ -136,12 +179,17 @@ proc receive*(t: Transport, a: Address, packet: openArray[byte]) = toNode.pubkey ).expect("Valid handshake packet to encode") - trace "Send handshake message packet", dstId = toNode.id, address + trace "Send handshake message packet", myport = t.bindAddress.port, dstId = toNode.id, address t.send(toNode, data) + # keyexchange ready, we can send queued packets + t.keyexchangeInProgress.excl(toNode.id) + trace "keyexchangeInProgress removed (finished)", myport = t.bindAddress.port, dstId = toNode.id, address + discard t.sendPending(toNode) + else: debug "Timed out or unrequested whoareyou packet", address = a of HandshakeMessage: - trace "Received handshake message packet", srcId = packet.srcIdHs, + trace "Received handshake message packet", myport = t.bindAddress.port, srcId = packet.srcIdHs, address = a, kind = packet.message.kind t.client.handleMessage(packet.srcIdHs, a, packet.message) # For a handshake message it is possible that we received an newer SPR. @@ -157,9 +205,10 @@ proc receive*(t: Transport, a: Address, packet: openArray[byte]) = # sending the 'whoareyou' message to. In that case, we can set 'seen' node.seen = true if t.client.addNode(node): - trace "Added new node to routing table after handshake", node + trace "Added new node to routing table after handshake", node, tablesize=t.client.nodesDiscovered() + discard t.sendPending(node) else: - trace "Packet decoding error", error = decoded.error, address = a + trace "Packet decoding error", myport = t.bindAddress.port, error = decoded.error, address = a proc processClient[T](transp: DatagramTransport, raddr: TransportAddress): Future[void] {.async.} = diff --git a/tests/discv5/test_discoveryv5.nim b/tests/discv5/test_discoveryv5.nim index 1255569..34aeeb1 100644 --- a/tests/discv5/test_discoveryv5.nim +++ b/tests/discv5/test_discoveryv5.nim @@ -629,7 +629,7 @@ suite "Discovery v5 Tests": sendNode = newNode(enrRec).expect("Properly initialized record") var codec = Codec(localNode: sendNode, privKey: privKey, sessions: Sessions.init(5)) - let (packet, _) = encodeMessagePacket(rng[], codec, + let (packet, _, _) = encodeMessagePacket(rng[], codec, receiveNode.localNode.id, receiveNode.localNode.address.get(), @[]) receiveNode.transport.receive(a, packet) @@ -659,7 +659,7 @@ suite "Discovery v5 Tests": var codec = Codec(localNode: sendNode, privKey: privKey, sessions: Sessions.init(5)) for i in 0 ..< 5: let a = localAddress(20303 + i) - let (packet, _) = encodeMessagePacket(rng[], codec, + let (packet, _, _) = encodeMessagePacket(rng[], codec, receiveNode.localNode.id, receiveNode.localNode.address.get(), @[]) receiveNode.transport.receive(a, packet) @@ -691,7 +691,7 @@ suite "Discovery v5 Tests": var firstRequestNonce: AESGCMNonce for i in 0 ..< 5: - let (packet, requestNonce) = encodeMessagePacket(rng[], codec, + let (packet, requestNonce, _) = encodeMessagePacket(rng[], codec, receiveNode.localNode.id, receiveNode.localNode.address.get(), @[]) receiveNode.transport.receive(a, packet) if i == 0: diff --git a/tests/discv5/test_discoveryv5_encoding.nim b/tests/discv5/test_discoveryv5_encoding.nim index e2c87bc..4c50d84 100644 --- a/tests/discv5/test_discoveryv5_encoding.nim +++ b/tests/discv5/test_discoveryv5_encoding.nim @@ -526,7 +526,7 @@ suite "Discovery v5.1 Additional Encode/Decode": reqId = RequestId.init(rng[]) message = encodeMessage(m, reqId) - let (data, nonce) = encodeMessagePacket(rng[], codecA, nodeB.id, + let (data, nonce, _) = encodeMessagePacket(rng[], codecA, nodeB.id, nodeB.address.get(), message) let decoded = codecB.decodePacket(nodeA.address.get(), data) @@ -642,7 +642,7 @@ suite "Discovery v5.1 Additional Encode/Decode": codecB.sessions.store(nodeA.id, nodeA.address.get(), secrets.initiatorKey, secrets.recipientKey) - let (data, nonce) = encodeMessagePacket(rng[], codecA, nodeB.id, + let (data, nonce, _) = encodeMessagePacket(rng[], codecA, nodeB.id, nodeB.address.get(), message) let decoded = codecB.decodePacket(nodeA.address.get(), data)