implement findValue
retrieve a value from the DHT Signed-off-by: Csaba Kiraly <csaba.kiraly@gmail.com>
This commit is contained in:
parent
bca26eb059
commit
d7de86060c
|
@ -75,3 +75,23 @@ proc encode*(msg: ValueMessage): seq[byte] =
|
|||
|
||||
pb.finish()
|
||||
pb.buffer
|
||||
|
||||
proc decode*(
|
||||
T: typedesc[FindValueMessage],
|
||||
buffer: openArray[byte]): Result[FindValueMessage, ProtoError] =
|
||||
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var msg = FindValueMessage()
|
||||
|
||||
? pb.getRequiredField(1, msg.cId)
|
||||
|
||||
ok(msg)
|
||||
|
||||
proc encode*(msg: FindValueMessage): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
|
||||
pb.write(1, msg.cId)
|
||||
|
||||
pb.finish()
|
||||
pb.buffer
|
||||
|
||||
|
|
|
@ -12,3 +12,6 @@ type
|
|||
ValueMessage* = object
|
||||
#total*: uint32
|
||||
value*: seq[byte]
|
||||
|
||||
FindValueMessage* = object
|
||||
cId*: NodeId
|
||||
|
|
|
@ -46,6 +46,7 @@ type
|
|||
addValue = 0x0E
|
||||
getValue = 0x0F
|
||||
respValue = 0x10
|
||||
findValue = 0x11
|
||||
findNodeFast = 0x83
|
||||
|
||||
RequestId* = object
|
||||
|
@ -85,7 +86,7 @@ type
|
|||
SomeMessage* = PingMessage or PongMessage or FindNodeMessage or NodesMessage or
|
||||
TalkReqMessage or TalkRespMessage or AddProviderMessage or GetProvidersMessage or
|
||||
ProvidersMessage or FindNodeFastMessage or
|
||||
AddValueMessage or GetValueMessage or ValueMessage
|
||||
AddValueMessage or GetValueMessage or ValueMessage or FindValueMessage
|
||||
|
||||
Message* = object
|
||||
reqId*: RequestId
|
||||
|
@ -124,6 +125,8 @@ type
|
|||
getValue*: GetValueMessage
|
||||
of respValue:
|
||||
value*: ValueMessage
|
||||
of findValue:
|
||||
findValue*: FindValueMessage
|
||||
else:
|
||||
discard
|
||||
|
||||
|
@ -141,6 +144,7 @@ template messageKind*(T: typedesc[SomeMessage]): MessageKind =
|
|||
elif T is AddValueMessage: MessageKind.addValue
|
||||
elif T is GetValueMessage: MessageKind.getValue
|
||||
elif T is ValueMessage: MessageKind.respValue
|
||||
elif T is FindValueMessage: MessageKind.findValue
|
||||
|
||||
proc hash*(reqId: RequestId): Hash =
|
||||
hash(reqId.id)
|
||||
|
|
|
@ -459,6 +459,14 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] =
|
|||
else:
|
||||
return err("Unable to decode ValueMessage")
|
||||
|
||||
of findValue:
|
||||
let res = FindValueMessage.decode(encoded)
|
||||
if res.isOk:
|
||||
message.findValue = res.get
|
||||
return ok(message)
|
||||
else:
|
||||
return err("Unable to decode FindValueMessage")
|
||||
|
||||
of regTopic, ticket, regConfirmation, topicQuery:
|
||||
# We just pass the empty type of this message without attempting to
|
||||
# decode, so that the protocol knows what was received.
|
||||
|
|
|
@ -432,6 +432,18 @@ proc handleGetValue(
|
|||
trace "no value in local db", n = d.localNode, cID = getValue.cId
|
||||
# TODO: add noValue response
|
||||
|
||||
proc handleFindValue(d: Protocol, fromId: NodeId, fromAddr: Address,
|
||||
fv: FindValueMessage, reqId: RequestId) {.async.} =
|
||||
try:
|
||||
let value = d.valueStore[fv.cId]
|
||||
trace "retrieved value from local db", n = d.localNode, cID = fv.cId, value
|
||||
##TODO: handle multiple messages?
|
||||
let response = ValueMessage(value: value)
|
||||
d.sendResponse(fromId, fromAddr, response, reqId)
|
||||
except KeyError:
|
||||
d.sendNodes(fromId, fromAddr, reqId,
|
||||
d.routingTable.neighbours(fv.cId, seenOnly = true, k = FindNodeFastResultLimit))
|
||||
|
||||
proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address,
|
||||
message: Message) =
|
||||
case message.kind
|
||||
|
@ -461,6 +473,9 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address,
|
|||
of getValue:
|
||||
discovery_message_requests_incoming.inc()
|
||||
asyncSpawn d.handleGetValue(srcId, fromAddr, message.getValue, message.reqId)
|
||||
of findValue:
|
||||
discovery_message_requests_incoming.inc()
|
||||
asyncSpawn d.handleFindValue(srcId, fromAddr, message.findValue, message.reqId)
|
||||
of regTopic, topicQuery:
|
||||
discovery_message_requests_incoming.inc()
|
||||
discovery_message_requests_incoming.inc(labelValues = ["no_response"])
|
||||
|
@ -935,6 +950,130 @@ proc getValue*(
|
|||
|
||||
return err "getValue failed"
|
||||
|
||||
proc waitNodesOrValue(d: Protocol, fromNode: Node, reqId: RequestId):
|
||||
Future[DiscResult[(seq[SignedPeerRecord], seq[byte])]] {.async.} =
|
||||
|
||||
var op = await d.waitMessage(fromNode, reqId)
|
||||
if op.isSome:
|
||||
if op.get.kind == MessageKind.nodes:
|
||||
var res = op.get.nodes.sprs
|
||||
let total = op.get.nodes.total
|
||||
for i in 1 ..< total:
|
||||
op = await d.waitMessage(fromNode, reqId)
|
||||
if op.isSome and op.get.kind == MessageKind.nodes:
|
||||
res.add(op.get.nodes.sprs)
|
||||
else:
|
||||
# No error on this as we received some nodes.
|
||||
break
|
||||
return ok((res, @[]))
|
||||
elif op.get.kind == MessageKind.respValue:
|
||||
var res = op.get.value.value
|
||||
return ok((@[], res))
|
||||
else:
|
||||
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
|
||||
return err("Invalid response to find node message")
|
||||
else:
|
||||
discovery_message_requests_outgoing.inc(labelValues = ["no_response"])
|
||||
return err("Nodes message not received in time")
|
||||
|
||||
proc waitFindValueResponses*[T: SomeMessage](d: Protocol, node: Node, msg: T):
|
||||
Future[DiscResult[(seq[SignedPeerRecord], seq[byte])]] =
|
||||
let reqId = RequestId.init(d.rng[])
|
||||
result = d.waitNodesOrValue(node, reqId)
|
||||
sendRequest(d, node, msg, reqId)
|
||||
|
||||
proc sendFindValue*(d: Protocol, toNode: Node, target: NodeId):
|
||||
Future[DiscResult[(seq[Node], seq[byte])]] {.async.} =
|
||||
let
|
||||
msg = FindValueMessage(cId: target)
|
||||
response = await d.waitFindValueResponses(toNode, msg)
|
||||
|
||||
if response.isOk:
|
||||
let (nodes, value) = response.get()
|
||||
if nodes.len > 0:
|
||||
let res = verifyNodesRecords(nodes, toNode, FindNodeFastResultLimit)
|
||||
d.routingTable.setJustSeen(toNode)
|
||||
return ok((res, @[]))
|
||||
else:
|
||||
return ok((@[], value))
|
||||
else:
|
||||
d.replaceNode(toNode)
|
||||
return err(response.error)
|
||||
|
||||
proc findValue*(
|
||||
d: Protocol,
|
||||
target: NodeId,
|
||||
timeout: Duration = 5000.milliseconds # TODO: not used?
|
||||
): Future[DiscResult[seq[byte]]] {.async.} =
|
||||
## Perform a findValue lookup for the given value, descending on nodes with
|
||||
## multiple parallel requests and returning the first instance of the
|
||||
## key-value pair found.
|
||||
|
||||
proc worker(d: Protocol, destNode: Node, target: NodeId):
|
||||
Future[(seq[Node], seq[byte])] {.async.} =
|
||||
|
||||
let r = await d.sendFindValue(destNode, target)
|
||||
|
||||
if r.isOk:
|
||||
let (nodes, value) = r.get
|
||||
result = (nodes, value)
|
||||
|
||||
# Attempt to add all nodes discovered
|
||||
for n in nodes:
|
||||
discard d.addNode(n)
|
||||
|
||||
var closestNodes = d.routingTable.neighbours(target, BUCKET_SIZE,
|
||||
seenOnly = false)
|
||||
|
||||
var asked, seen = initHashSet[NodeId]()
|
||||
asked.incl(d.localNode.id) # No need to ask our own node
|
||||
seen.incl(d.localNode.id) # No need to discover our own node
|
||||
for node in closestNodes:
|
||||
seen.incl(node.id)
|
||||
|
||||
var pendingQueries = newSeqOfCap[Future[(seq[Node], seq[byte])]](Alpha)
|
||||
|
||||
while true:
|
||||
var i = 0
|
||||
# Doing `Alpha` amount of requests at once as long as closer non queried
|
||||
# nodes are discovered.
|
||||
while i < closestNodes.len and pendingQueries.len < Alpha:
|
||||
let n = closestNodes[i]
|
||||
if not asked.containsOrIncl(n.id):
|
||||
pendingQueries.add(d.worker(n, target))
|
||||
inc i
|
||||
|
||||
trace "discv5 pending queries", total = pendingQueries.len
|
||||
|
||||
if pendingQueries.len == 0:
|
||||
break
|
||||
|
||||
let query = await one(pendingQueries)
|
||||
trace "Got discv5 lookup query response"
|
||||
|
||||
let index = pendingQueries.find(query)
|
||||
if index != -1:
|
||||
pendingQueries.del(index)
|
||||
else:
|
||||
error "Resulting query should have been in the pending queries"
|
||||
|
||||
let (nodes, value) = query.read
|
||||
# TODO: Remove node on timed-out query?
|
||||
if value.len > 0:
|
||||
return ok(value)
|
||||
for n in nodes:
|
||||
if not seen.containsOrIncl(n.id):
|
||||
# If it wasn't seen before, insert node while remaining sorted
|
||||
closestNodes.insert(n, closestNodes.lowerBound(n,
|
||||
proc(x: Node, n: Node): int =
|
||||
cmp(distance(x.id, target), distance(n.id, target))
|
||||
))
|
||||
|
||||
if closestNodes.len > BUCKET_SIZE:
|
||||
closestNodes.del(closestNodes.high())
|
||||
|
||||
d.lastLookup = now(chronos.Moment)
|
||||
|
||||
proc query*(d: Protocol, target: NodeId, k = BUCKET_SIZE): Future[seq[Node]]
|
||||
{.async.} =
|
||||
## Query k nodes for the given target, returns all nodes found, including the
|
||||
|
|
Loading…
Reference in New Issue