diff --git a/codexdht/private/eth/p2p/discoveryv5/encoding.nim b/codexdht/private/eth/p2p/discoveryv5/encoding.nim index 2520b28..7fd18e7 100644 --- a/codexdht/private/eth/p2p/discoveryv5/encoding.nim +++ b/codexdht/private/eth/p2p/discoveryv5/encoding.nim @@ -40,6 +40,9 @@ declareCounter discovery_session_decrypt_failures, "Session decrypt failures" logScope: topics = "discv5" +type + cipher = aes128 + const version: uint16 = 1 idSignatureText = "discovery v5 identity proof" @@ -162,7 +165,7 @@ proc deriveKeys*(n1, n2: NodeId, priv: PrivateKey, pub: PublicKey, ok secrets proc encryptGCM*(key: AesKey, nonce, pt, authData: openArray[byte]): seq[byte] = - var ectx: GCM[aes128] + var ectx: GCM[cipher] ectx.init(key, nonce, authData) result = newSeq[byte](pt.len + gcmTagSize) ectx.encrypt(pt, result) @@ -175,7 +178,7 @@ proc decryptGCM*(key: AesKey, nonce, ct, authData: openArray[byte]): debug "cipher is missing tag", len = ct.len return - var dctx: GCM[aes128] + var dctx: GCM[cipher] dctx.init(key, nonce, authData) var res = newSeq[byte](ct.len - gcmTagSize) var tag: array[gcmTagSize, byte] @@ -189,7 +192,7 @@ proc decryptGCM*(key: AesKey, nonce, ct, authData: openArray[byte]): return some(res) proc encryptHeader*(id: NodeId, iv, header: openArray[byte]): seq[byte] = - var ectx: CTR[aes128] + var ectx: CTR[cipher] ectx.init(id.toByteArrayBE().toOpenArray(0, 15), iv) result = newSeq[byte](header.len) ectx.encrypt(header, result) @@ -374,7 +377,7 @@ proc decodeHeader*(id: NodeId, iv, maskedHeader: openArray[byte]): DecodeResult[(StaticHeader, seq[byte])] = # No need to check staticHeader size as that is included in minimum packet # size check in decodePacket - var ectx: CTR[aes128] + var ectx: CTR[cipher] ectx.init(id.toByteArrayBE().toOpenArray(0, aesKeySize - 1), iv) # Decrypt static-header part of the header var staticHeader = newSeq[byte](staticHeaderSize) diff --git a/codexdht/private/eth/p2p/discoveryv5/messages_encoding.nim b/codexdht/private/eth/p2p/discoveryv5/messages_encoding.nim index 707c325..09f690d 100644 --- a/codexdht/private/eth/p2p/discoveryv5/messages_encoding.nim +++ b/codexdht/private/eth/p2p/discoveryv5/messages_encoding.nim @@ -325,7 +325,7 @@ proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] = pb.write(2, encoded) pb.finish() result.add(pb.buffer) - trace "Encoded protobuf message", typ = $T, encoded + trace "Encoded protobuf message", typ = $T proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] = ## Decodes to the specific `Message` type. diff --git a/codexdht/private/eth/p2p/discoveryv5/protocol.nim b/codexdht/private/eth/p2p/discoveryv5/protocol.nim index 013f3d9..9deefbb 100644 --- a/codexdht/private/eth/p2p/discoveryv5/protocol.nim +++ b/codexdht/private/eth/p2p/discoveryv5/protocol.nim @@ -183,6 +183,9 @@ type DiscResult*[T] = Result[T, cstring] +func `$`*(p: Protocol): string = + $p.localNode.id + const defaultDiscoveryConfig* = DiscoveryConfig( tableIpLimits: DefaultTableIpLimits, @@ -453,18 +456,41 @@ proc replaceNode(d: Protocol, n: Node) = # peers in the routing table. debug "Message request to bootstrap node failed", src=d.localNode, dst=n +proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T, + reqId: RequestId) = + doAssert(toNode.address.isSome()) + let + message = encodeMessage(m, reqId) -proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId): + trace "Send message packet", dstId = toNode.id, + address = toNode.address, kind = messageKind(T) + discovery_message_requests_outgoing.inc() + + d.transport.sendMessage(toNode, message) + +proc waitResponse*[T: SomeMessage](d: Protocol, node: Node, msg: T): + Future[Option[Message]] = + let reqId = RequestId.init(d.rng[]) + result = d.waitMessage(node, reqId) + sendRequest(d, node, msg, reqId) + +proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId, timeout = ResponseTimeout): Future[Option[Message]] = result = newFuture[Option[Message]]("waitMessage") let res = result let key = (fromNode.id, reqId) - sleepAsync(ResponseTimeout).addCallback() do(data: pointer): + sleepAsync(timeout).addCallback() do(data: pointer): d.awaitedMessages.del(key) if not res.finished: res.complete(none(Message)) d.awaitedMessages[key] = result +proc waitNodeResponses*[T: SomeMessage](d: Protocol, node: Node, msg: T): + Future[DiscResult[seq[SignedPeerRecord]]] = + let reqId = RequestId.init(d.rng[]) + result = d.waitNodes(node, reqId) + sendRequest(d, node, msg, reqId) + proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[DiscResult[seq[SignedPeerRecord]]] {.async.} = ## Wait for one or more nodes replies. @@ -493,40 +519,14 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): discovery_message_requests_outgoing.inc(labelValues = ["no_response"]) return err("Nodes message not received in time") -proc sendRequest*[T: SomeMessage](d: Protocol, toId: NodeId, toAddr: Address, m: T): - RequestId = - let - reqId = RequestId.init(d.rng[]) - message = encodeMessage(m, reqId) - - trace "Send message packet", dstId = toId, toAddr, kind = messageKind(T) - discovery_message_requests_outgoing.inc() - - d.transport.sendMessage(toId, toAddr, message) - return reqId - -proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T): - RequestId = - doAssert(toNode.address.isSome()) - let - reqId = RequestId.init(d.rng[]) - message = encodeMessage(m, reqId) - - trace "Send message packet", dstId = toNode.id, - address = toNode.address, kind = messageKind(T) - discovery_message_requests_outgoing.inc() - - d.transport.sendMessage(toNode, message) - return reqId - proc ping*(d: Protocol, toNode: Node): Future[DiscResult[PongMessage]] {.async.} = ## Send a discovery ping message. ## ## Returns the received pong message or an error. - let reqId = d.sendRequest(toNode, - PingMessage(sprSeq: d.localNode.record.seqNum)) - let resp = await d.waitMessage(toNode, reqId) + let + msg = PingMessage(sprSeq: d.localNode.record.seqNum) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == pong: @@ -547,8 +547,9 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]): ## ## Returns the received nodes or an error. ## Received SPRs are already validated and converted to `Node`. - let reqId = d.sendRequest(toNode, FindNodeMessage(distances: distances)) - let nodes = await d.waitNodes(toNode, reqId) + let + msg = FindNodeMessage(distances: distances) + nodes = await d.waitNodeResponses(toNode, msg) if nodes.isOk: let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit, distances) @@ -565,8 +566,9 @@ proc findNodeFast*(d: Protocol, toNode: Node, target: NodeId): ## ## Returns the received nodes or an error. ## Received SPRs are already validated and converted to `Node`. - let reqId = d.sendRequest(toNode, FindNodeFastMessage(target: target)) - let nodes = await d.waitNodes(toNode, reqId) + let + msg = FindNodeFastMessage(target: target) + nodes = await d.waitNodeResponses(toNode, msg) if nodes.isOk: let res = verifyNodesRecords(nodes.get(), toNode, FindNodeFastResultLimit) @@ -582,9 +584,9 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): ## Send a discovery talkreq message. ## ## Returns the received talkresp message or an error. - let reqId = d.sendRequest(toNode, - TalkReqMessage(protocol: protocol, request: request)) - let resp = await d.waitMessage(toNode, reqId) + let + msg = TalkReqMessage(protocol: protocol, request: request) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == talkResp: @@ -611,25 +613,18 @@ proc lookupDistances*(target, dest: NodeId): seq[uint16] = result.add(td - uint16(i)) inc i -proc lookupWorker(d: Protocol, destNode: Node, target: NodeId): +proc lookupWorker(d: Protocol, destNode: Node, target: NodeId, fast: bool): Future[seq[Node]] {.async.} = - let dists = lookupDistances(target, destNode.id) - # Instead of doing max `LookupRequestLimit` findNode requests, make use - # of the discv5.1 functionality to request nodes for multiple distances. - let r = await d.findNode(destNode, dists) - if r.isOk: - result.add(r[]) + let r = + if fast: + await d.findNodeFast(destNode, target) + else: + # Instead of doing max `LookupRequestLimit` findNode requests, make use + # of the discv5.1 functionality to request nodes for multiple distances. + let dists = lookupDistances(target, destNode.id) + await d.findNode(destNode, dists) - # Attempt to add all nodes discovered - for n in result: - discard d.addNode(n) - -proc lookupWorkerFast(d: Protocol, destNode: Node, target: NodeId): - Future[seq[Node]] {.async.} = - ## use terget NodeId based find_node - - let r = await d.findNodeFast(destNode, target) if r.isOk: result.add(r[]) @@ -660,10 +655,7 @@ proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]] while i < closestNodes.len and pendingQueries.len < Alpha: let n = closestNodes[i] if not asked.containsOrIncl(n.id): - if fast: - pendingQueries.add(d.lookupWorkerFast(n, target)) - else: - pendingQueries.add(d.lookupWorker(n, target)) + pendingQueries.add(d.lookupWorker(n, target, fast)) inc i trace "discv5 pending queries", total = pendingQueries.len @@ -708,7 +700,8 @@ proc addProvider*( res.add(d.localNode) for toNode in res: if toNode != d.localNode: - discard d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr)) + let reqId = RequestId.init(d.rng[]) + d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr), reqId) else: asyncSpawn d.addProviderLocal(cId, pr) @@ -721,8 +714,7 @@ proc sendGetProviders(d: Protocol, toNode: Node, trace "sendGetProviders", toNode, msg let - reqId = d.sendRequest(toNode, msg) - resp = await d.waitMessage(toNode, reqId) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == MessageKind.providers: @@ -824,7 +816,7 @@ proc query*(d: Protocol, target: NodeId, k = BUCKET_SIZE): Future[seq[Node]] while i < min(queryBuffer.len, k) and pendingQueries.len < Alpha: let n = queryBuffer[i] if not asked.containsOrIncl(n.id): - pendingQueries.add(d.lookupWorker(n, target)) + pendingQueries.add(d.lookupWorker(n, target, false)) inc i trace "discv5 pending queries", total = pendingQueries.len