mirror of
https://github.com/codex-storage/nim-codex-dht.git
synced 2025-01-09 11:32:18 +00:00
commit
b8bcb2d08d
@ -40,6 +40,9 @@ declareCounter discovery_session_decrypt_failures, "Session decrypt failures"
|
||||
logScope:
|
||||
topics = "discv5"
|
||||
|
||||
type
|
||||
cipher = aes128
|
||||
|
||||
const
|
||||
version: uint16 = 1
|
||||
idSignatureText = "discovery v5 identity proof"
|
||||
@ -162,7 +165,7 @@ proc deriveKeys*(n1, n2: NodeId, priv: PrivateKey, pub: PublicKey,
|
||||
ok secrets
|
||||
|
||||
proc encryptGCM*(key: AesKey, nonce, pt, authData: openArray[byte]): seq[byte] =
|
||||
var ectx: GCM[aes128]
|
||||
var ectx: GCM[cipher]
|
||||
ectx.init(key, nonce, authData)
|
||||
result = newSeq[byte](pt.len + gcmTagSize)
|
||||
ectx.encrypt(pt, result)
|
||||
@ -175,7 +178,7 @@ proc decryptGCM*(key: AesKey, nonce, ct, authData: openArray[byte]):
|
||||
debug "cipher is missing tag", len = ct.len
|
||||
return
|
||||
|
||||
var dctx: GCM[aes128]
|
||||
var dctx: GCM[cipher]
|
||||
dctx.init(key, nonce, authData)
|
||||
var res = newSeq[byte](ct.len - gcmTagSize)
|
||||
var tag: array[gcmTagSize, byte]
|
||||
@ -189,7 +192,7 @@ proc decryptGCM*(key: AesKey, nonce, ct, authData: openArray[byte]):
|
||||
return some(res)
|
||||
|
||||
proc encryptHeader*(id: NodeId, iv, header: openArray[byte]): seq[byte] =
|
||||
var ectx: CTR[aes128]
|
||||
var ectx: CTR[cipher]
|
||||
ectx.init(id.toByteArrayBE().toOpenArray(0, 15), iv)
|
||||
result = newSeq[byte](header.len)
|
||||
ectx.encrypt(header, result)
|
||||
@ -374,7 +377,7 @@ proc decodeHeader*(id: NodeId, iv, maskedHeader: openArray[byte]):
|
||||
DecodeResult[(StaticHeader, seq[byte])] =
|
||||
# No need to check staticHeader size as that is included in minimum packet
|
||||
# size check in decodePacket
|
||||
var ectx: CTR[aes128]
|
||||
var ectx: CTR[cipher]
|
||||
ectx.init(id.toByteArrayBE().toOpenArray(0, aesKeySize - 1), iv)
|
||||
# Decrypt static-header part of the header
|
||||
var staticHeader = newSeq[byte](staticHeaderSize)
|
||||
|
@ -325,7 +325,7 @@ proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] =
|
||||
pb.write(2, encoded)
|
||||
pb.finish()
|
||||
result.add(pb.buffer)
|
||||
trace "Encoded protobuf message", typ = $T, encoded
|
||||
trace "Encoded protobuf message", typ = $T
|
||||
|
||||
proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] =
|
||||
## Decodes to the specific `Message` type.
|
||||
|
@ -183,6 +183,9 @@ type
|
||||
|
||||
DiscResult*[T] = Result[T, cstring]
|
||||
|
||||
func `$`*(p: Protocol): string =
|
||||
$p.localNode.id
|
||||
|
||||
const
|
||||
defaultDiscoveryConfig* = DiscoveryConfig(
|
||||
tableIpLimits: DefaultTableIpLimits,
|
||||
@ -453,18 +456,41 @@ proc replaceNode(d: Protocol, n: Node) =
|
||||
# peers in the routing table.
|
||||
debug "Message request to bootstrap node failed", src=d.localNode, dst=n
|
||||
|
||||
proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T,
|
||||
reqId: RequestId) =
|
||||
doAssert(toNode.address.isSome())
|
||||
let
|
||||
message = encodeMessage(m, reqId)
|
||||
|
||||
proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId):
|
||||
trace "Send message packet", dstId = toNode.id,
|
||||
address = toNode.address, kind = messageKind(T)
|
||||
discovery_message_requests_outgoing.inc()
|
||||
|
||||
d.transport.sendMessage(toNode, message)
|
||||
|
||||
proc waitResponse*[T: SomeMessage](d: Protocol, node: Node, msg: T):
|
||||
Future[Option[Message]] =
|
||||
let reqId = RequestId.init(d.rng[])
|
||||
result = d.waitMessage(node, reqId)
|
||||
sendRequest(d, node, msg, reqId)
|
||||
|
||||
proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId, timeout = ResponseTimeout):
|
||||
Future[Option[Message]] =
|
||||
result = newFuture[Option[Message]]("waitMessage")
|
||||
let res = result
|
||||
let key = (fromNode.id, reqId)
|
||||
sleepAsync(ResponseTimeout).addCallback() do(data: pointer):
|
||||
sleepAsync(timeout).addCallback() do(data: pointer):
|
||||
d.awaitedMessages.del(key)
|
||||
if not res.finished:
|
||||
res.complete(none(Message))
|
||||
d.awaitedMessages[key] = result
|
||||
|
||||
proc waitNodeResponses*[T: SomeMessage](d: Protocol, node: Node, msg: T):
|
||||
Future[DiscResult[seq[SignedPeerRecord]]] =
|
||||
let reqId = RequestId.init(d.rng[])
|
||||
result = d.waitNodes(node, reqId)
|
||||
sendRequest(d, node, msg, reqId)
|
||||
|
||||
proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
|
||||
Future[DiscResult[seq[SignedPeerRecord]]] {.async.} =
|
||||
## Wait for one or more nodes replies.
|
||||
@ -493,40 +519,14 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
|
||||
discovery_message_requests_outgoing.inc(labelValues = ["no_response"])
|
||||
return err("Nodes message not received in time")
|
||||
|
||||
proc sendRequest*[T: SomeMessage](d: Protocol, toId: NodeId, toAddr: Address, m: T):
|
||||
RequestId =
|
||||
let
|
||||
reqId = RequestId.init(d.rng[])
|
||||
message = encodeMessage(m, reqId)
|
||||
|
||||
trace "Send message packet", dstId = toId, toAddr, kind = messageKind(T)
|
||||
discovery_message_requests_outgoing.inc()
|
||||
|
||||
d.transport.sendMessage(toId, toAddr, message)
|
||||
return reqId
|
||||
|
||||
proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T):
|
||||
RequestId =
|
||||
doAssert(toNode.address.isSome())
|
||||
let
|
||||
reqId = RequestId.init(d.rng[])
|
||||
message = encodeMessage(m, reqId)
|
||||
|
||||
trace "Send message packet", dstId = toNode.id,
|
||||
address = toNode.address, kind = messageKind(T)
|
||||
discovery_message_requests_outgoing.inc()
|
||||
|
||||
d.transport.sendMessage(toNode, message)
|
||||
return reqId
|
||||
|
||||
proc ping*(d: Protocol, toNode: Node):
|
||||
Future[DiscResult[PongMessage]] {.async.} =
|
||||
## Send a discovery ping message.
|
||||
##
|
||||
## Returns the received pong message or an error.
|
||||
let reqId = d.sendRequest(toNode,
|
||||
PingMessage(sprSeq: d.localNode.record.seqNum))
|
||||
let resp = await d.waitMessage(toNode, reqId)
|
||||
let
|
||||
msg = PingMessage(sprSeq: d.localNode.record.seqNum)
|
||||
resp = await d.waitResponse(toNode, msg)
|
||||
|
||||
if resp.isSome():
|
||||
if resp.get().kind == pong:
|
||||
@ -547,8 +547,9 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]):
|
||||
##
|
||||
## Returns the received nodes or an error.
|
||||
## Received SPRs are already validated and converted to `Node`.
|
||||
let reqId = d.sendRequest(toNode, FindNodeMessage(distances: distances))
|
||||
let nodes = await d.waitNodes(toNode, reqId)
|
||||
let
|
||||
msg = FindNodeMessage(distances: distances)
|
||||
nodes = await d.waitNodeResponses(toNode, msg)
|
||||
|
||||
if nodes.isOk:
|
||||
let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit, distances)
|
||||
@ -565,8 +566,9 @@ proc findNodeFast*(d: Protocol, toNode: Node, target: NodeId):
|
||||
##
|
||||
## Returns the received nodes or an error.
|
||||
## Received SPRs are already validated and converted to `Node`.
|
||||
let reqId = d.sendRequest(toNode, FindNodeFastMessage(target: target))
|
||||
let nodes = await d.waitNodes(toNode, reqId)
|
||||
let
|
||||
msg = FindNodeFastMessage(target: target)
|
||||
nodes = await d.waitNodeResponses(toNode, msg)
|
||||
|
||||
if nodes.isOk:
|
||||
let res = verifyNodesRecords(nodes.get(), toNode, FindNodeFastResultLimit)
|
||||
@ -582,9 +584,9 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
|
||||
## Send a discovery talkreq message.
|
||||
##
|
||||
## Returns the received talkresp message or an error.
|
||||
let reqId = d.sendRequest(toNode,
|
||||
TalkReqMessage(protocol: protocol, request: request))
|
||||
let resp = await d.waitMessage(toNode, reqId)
|
||||
let
|
||||
msg = TalkReqMessage(protocol: protocol, request: request)
|
||||
resp = await d.waitResponse(toNode, msg)
|
||||
|
||||
if resp.isSome():
|
||||
if resp.get().kind == talkResp:
|
||||
@ -611,25 +613,18 @@ proc lookupDistances*(target, dest: NodeId): seq[uint16] =
|
||||
result.add(td - uint16(i))
|
||||
inc i
|
||||
|
||||
proc lookupWorker(d: Protocol, destNode: Node, target: NodeId):
|
||||
proc lookupWorker(d: Protocol, destNode: Node, target: NodeId, fast: bool):
|
||||
Future[seq[Node]] {.async.} =
|
||||
let dists = lookupDistances(target, destNode.id)
|
||||
|
||||
# Instead of doing max `LookupRequestLimit` findNode requests, make use
|
||||
# of the discv5.1 functionality to request nodes for multiple distances.
|
||||
let r = await d.findNode(destNode, dists)
|
||||
if r.isOk:
|
||||
result.add(r[])
|
||||
let r =
|
||||
if fast:
|
||||
await d.findNodeFast(destNode, target)
|
||||
else:
|
||||
# Instead of doing max `LookupRequestLimit` findNode requests, make use
|
||||
# of the discv5.1 functionality to request nodes for multiple distances.
|
||||
let dists = lookupDistances(target, destNode.id)
|
||||
await d.findNode(destNode, dists)
|
||||
|
||||
# Attempt to add all nodes discovered
|
||||
for n in result:
|
||||
discard d.addNode(n)
|
||||
|
||||
proc lookupWorkerFast(d: Protocol, destNode: Node, target: NodeId):
|
||||
Future[seq[Node]] {.async.} =
|
||||
## use terget NodeId based find_node
|
||||
|
||||
let r = await d.findNodeFast(destNode, target)
|
||||
if r.isOk:
|
||||
result.add(r[])
|
||||
|
||||
@ -660,10 +655,7 @@ proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]]
|
||||
while i < closestNodes.len and pendingQueries.len < Alpha:
|
||||
let n = closestNodes[i]
|
||||
if not asked.containsOrIncl(n.id):
|
||||
if fast:
|
||||
pendingQueries.add(d.lookupWorkerFast(n, target))
|
||||
else:
|
||||
pendingQueries.add(d.lookupWorker(n, target))
|
||||
pendingQueries.add(d.lookupWorker(n, target, fast))
|
||||
inc i
|
||||
|
||||
trace "discv5 pending queries", total = pendingQueries.len
|
||||
@ -708,7 +700,8 @@ proc addProvider*(
|
||||
res.add(d.localNode)
|
||||
for toNode in res:
|
||||
if toNode != d.localNode:
|
||||
discard d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr))
|
||||
let reqId = RequestId.init(d.rng[])
|
||||
d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr), reqId)
|
||||
else:
|
||||
asyncSpawn d.addProviderLocal(cId, pr)
|
||||
|
||||
@ -721,8 +714,7 @@ proc sendGetProviders(d: Protocol, toNode: Node,
|
||||
trace "sendGetProviders", toNode, msg
|
||||
|
||||
let
|
||||
reqId = d.sendRequest(toNode, msg)
|
||||
resp = await d.waitMessage(toNode, reqId)
|
||||
resp = await d.waitResponse(toNode, msg)
|
||||
|
||||
if resp.isSome():
|
||||
if resp.get().kind == MessageKind.providers:
|
||||
@ -824,7 +816,7 @@ proc query*(d: Protocol, target: NodeId, k = BUCKET_SIZE): Future[seq[Node]]
|
||||
while i < min(queryBuffer.len, k) and pendingQueries.len < Alpha:
|
||||
let n = queryBuffer[i]
|
||||
if not asked.containsOrIncl(n.id):
|
||||
pendingQueries.add(d.lookupWorker(n, target))
|
||||
pendingQueries.add(d.lookupWorker(n, target, false))
|
||||
inc i
|
||||
|
||||
trace "discv5 pending queries", total = pendingQueries.len
|
||||
|
Loading…
x
Reference in New Issue
Block a user