Merge pull request #95 from codex-storage/factorize

Factorize code
This commit is contained in:
Csaba Kiraly 2024-10-07 14:06:59 +02:00 committed by GitHub
commit b8bcb2d08d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 67 deletions

View File

@ -40,6 +40,9 @@ declareCounter discovery_session_decrypt_failures, "Session decrypt failures"
logScope:
topics = "discv5"
type
cipher = aes128
const
version: uint16 = 1
idSignatureText = "discovery v5 identity proof"
@ -162,7 +165,7 @@ proc deriveKeys*(n1, n2: NodeId, priv: PrivateKey, pub: PublicKey,
ok secrets
proc encryptGCM*(key: AesKey, nonce, pt, authData: openArray[byte]): seq[byte] =
var ectx: GCM[aes128]
var ectx: GCM[cipher]
ectx.init(key, nonce, authData)
result = newSeq[byte](pt.len + gcmTagSize)
ectx.encrypt(pt, result)
@ -175,7 +178,7 @@ proc decryptGCM*(key: AesKey, nonce, ct, authData: openArray[byte]):
debug "cipher is missing tag", len = ct.len
return
var dctx: GCM[aes128]
var dctx: GCM[cipher]
dctx.init(key, nonce, authData)
var res = newSeq[byte](ct.len - gcmTagSize)
var tag: array[gcmTagSize, byte]
@ -189,7 +192,7 @@ proc decryptGCM*(key: AesKey, nonce, ct, authData: openArray[byte]):
return some(res)
proc encryptHeader*(id: NodeId, iv, header: openArray[byte]): seq[byte] =
var ectx: CTR[aes128]
var ectx: CTR[cipher]
ectx.init(id.toByteArrayBE().toOpenArray(0, 15), iv)
result = newSeq[byte](header.len)
ectx.encrypt(header, result)
@ -374,7 +377,7 @@ proc decodeHeader*(id: NodeId, iv, maskedHeader: openArray[byte]):
DecodeResult[(StaticHeader, seq[byte])] =
# No need to check staticHeader size as that is included in minimum packet
# size check in decodePacket
var ectx: CTR[aes128]
var ectx: CTR[cipher]
ectx.init(id.toByteArrayBE().toOpenArray(0, aesKeySize - 1), iv)
# Decrypt static-header part of the header
var staticHeader = newSeq[byte](staticHeaderSize)

View File

@ -325,7 +325,7 @@ proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] =
pb.write(2, encoded)
pb.finish()
result.add(pb.buffer)
trace "Encoded protobuf message", typ = $T, encoded
trace "Encoded protobuf message", typ = $T
proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] =
## Decodes to the specific `Message` type.

View File

@ -183,6 +183,9 @@ type
DiscResult*[T] = Result[T, cstring]
func `$`*(p: Protocol): string =
$p.localNode.id
const
defaultDiscoveryConfig* = DiscoveryConfig(
tableIpLimits: DefaultTableIpLimits,
@ -453,18 +456,41 @@ proc replaceNode(d: Protocol, n: Node) =
# peers in the routing table.
debug "Message request to bootstrap node failed", src=d.localNode, dst=n
proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T,
reqId: RequestId) =
doAssert(toNode.address.isSome())
let
message = encodeMessage(m, reqId)
proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId):
trace "Send message packet", dstId = toNode.id,
address = toNode.address, kind = messageKind(T)
discovery_message_requests_outgoing.inc()
d.transport.sendMessage(toNode, message)
proc waitResponse*[T: SomeMessage](d: Protocol, node: Node, msg: T):
Future[Option[Message]] =
let reqId = RequestId.init(d.rng[])
result = d.waitMessage(node, reqId)
sendRequest(d, node, msg, reqId)
proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId, timeout = ResponseTimeout):
Future[Option[Message]] =
result = newFuture[Option[Message]]("waitMessage")
let res = result
let key = (fromNode.id, reqId)
sleepAsync(ResponseTimeout).addCallback() do(data: pointer):
sleepAsync(timeout).addCallback() do(data: pointer):
d.awaitedMessages.del(key)
if not res.finished:
res.complete(none(Message))
d.awaitedMessages[key] = result
proc waitNodeResponses*[T: SomeMessage](d: Protocol, node: Node, msg: T):
Future[DiscResult[seq[SignedPeerRecord]]] =
let reqId = RequestId.init(d.rng[])
result = d.waitNodes(node, reqId)
sendRequest(d, node, msg, reqId)
proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
Future[DiscResult[seq[SignedPeerRecord]]] {.async.} =
## Wait for one or more nodes replies.
@ -493,40 +519,14 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
discovery_message_requests_outgoing.inc(labelValues = ["no_response"])
return err("Nodes message not received in time")
proc sendRequest*[T: SomeMessage](d: Protocol, toId: NodeId, toAddr: Address, m: T):
RequestId =
let
reqId = RequestId.init(d.rng[])
message = encodeMessage(m, reqId)
trace "Send message packet", dstId = toId, toAddr, kind = messageKind(T)
discovery_message_requests_outgoing.inc()
d.transport.sendMessage(toId, toAddr, message)
return reqId
proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T):
RequestId =
doAssert(toNode.address.isSome())
let
reqId = RequestId.init(d.rng[])
message = encodeMessage(m, reqId)
trace "Send message packet", dstId = toNode.id,
address = toNode.address, kind = messageKind(T)
discovery_message_requests_outgoing.inc()
d.transport.sendMessage(toNode, message)
return reqId
proc ping*(d: Protocol, toNode: Node):
Future[DiscResult[PongMessage]] {.async.} =
## Send a discovery ping message.
##
## Returns the received pong message or an error.
let reqId = d.sendRequest(toNode,
PingMessage(sprSeq: d.localNode.record.seqNum))
let resp = await d.waitMessage(toNode, reqId)
let
msg = PingMessage(sprSeq: d.localNode.record.seqNum)
resp = await d.waitResponse(toNode, msg)
if resp.isSome():
if resp.get().kind == pong:
@ -547,8 +547,9 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]):
##
## Returns the received nodes or an error.
## Received SPRs are already validated and converted to `Node`.
let reqId = d.sendRequest(toNode, FindNodeMessage(distances: distances))
let nodes = await d.waitNodes(toNode, reqId)
let
msg = FindNodeMessage(distances: distances)
nodes = await d.waitNodeResponses(toNode, msg)
if nodes.isOk:
let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit, distances)
@ -565,8 +566,9 @@ proc findNodeFast*(d: Protocol, toNode: Node, target: NodeId):
##
## Returns the received nodes or an error.
## Received SPRs are already validated and converted to `Node`.
let reqId = d.sendRequest(toNode, FindNodeFastMessage(target: target))
let nodes = await d.waitNodes(toNode, reqId)
let
msg = FindNodeFastMessage(target: target)
nodes = await d.waitNodeResponses(toNode, msg)
if nodes.isOk:
let res = verifyNodesRecords(nodes.get(), toNode, FindNodeFastResultLimit)
@ -582,9 +584,9 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
## Send a discovery talkreq message.
##
## Returns the received talkresp message or an error.
let reqId = d.sendRequest(toNode,
TalkReqMessage(protocol: protocol, request: request))
let resp = await d.waitMessage(toNode, reqId)
let
msg = TalkReqMessage(protocol: protocol, request: request)
resp = await d.waitResponse(toNode, msg)
if resp.isSome():
if resp.get().kind == talkResp:
@ -611,25 +613,18 @@ proc lookupDistances*(target, dest: NodeId): seq[uint16] =
result.add(td - uint16(i))
inc i
proc lookupWorker(d: Protocol, destNode: Node, target: NodeId):
proc lookupWorker(d: Protocol, destNode: Node, target: NodeId, fast: bool):
Future[seq[Node]] {.async.} =
let dists = lookupDistances(target, destNode.id)
# Instead of doing max `LookupRequestLimit` findNode requests, make use
# of the discv5.1 functionality to request nodes for multiple distances.
let r = await d.findNode(destNode, dists)
if r.isOk:
result.add(r[])
let r =
if fast:
await d.findNodeFast(destNode, target)
else:
# Instead of doing max `LookupRequestLimit` findNode requests, make use
# of the discv5.1 functionality to request nodes for multiple distances.
let dists = lookupDistances(target, destNode.id)
await d.findNode(destNode, dists)
# Attempt to add all nodes discovered
for n in result:
discard d.addNode(n)
proc lookupWorkerFast(d: Protocol, destNode: Node, target: NodeId):
Future[seq[Node]] {.async.} =
## use terget NodeId based find_node
let r = await d.findNodeFast(destNode, target)
if r.isOk:
result.add(r[])
@ -660,10 +655,7 @@ proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]]
while i < closestNodes.len and pendingQueries.len < Alpha:
let n = closestNodes[i]
if not asked.containsOrIncl(n.id):
if fast:
pendingQueries.add(d.lookupWorkerFast(n, target))
else:
pendingQueries.add(d.lookupWorker(n, target))
pendingQueries.add(d.lookupWorker(n, target, fast))
inc i
trace "discv5 pending queries", total = pendingQueries.len
@ -708,7 +700,8 @@ proc addProvider*(
res.add(d.localNode)
for toNode in res:
if toNode != d.localNode:
discard d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr))
let reqId = RequestId.init(d.rng[])
d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr), reqId)
else:
asyncSpawn d.addProviderLocal(cId, pr)
@ -721,8 +714,7 @@ proc sendGetProviders(d: Protocol, toNode: Node,
trace "sendGetProviders", toNode, msg
let
reqId = d.sendRequest(toNode, msg)
resp = await d.waitMessage(toNode, reqId)
resp = await d.waitResponse(toNode, msg)
if resp.isSome():
if resp.get().kind == MessageKind.providers:
@ -824,7 +816,7 @@ proc query*(d: Protocol, target: NodeId, k = BUCKET_SIZE): Future[seq[Node]]
while i < min(queryBuffer.len, k) and pendingQueries.len < Alpha:
let n = queryBuffer[i]
if not asked.containsOrIncl(n.id):
pendingQueries.add(d.lookupWorker(n, target))
pendingQueries.add(d.lookupWorker(n, target, false))
inc i
trace "discv5 pending queries", total = pendingQueries.len