diff --git a/libp2pdht/dht/value_encoding.nim b/libp2pdht/dht/value_encoding.nim index c21d25d..bacdb14 100644 --- a/libp2pdht/dht/value_encoding.nim +++ b/libp2pdht/dht/value_encoding.nim @@ -75,3 +75,23 @@ proc encode*(msg: ValueMessage): seq[byte] = pb.finish() pb.buffer + +proc decode*( + T: typedesc[FindValueMessage], + buffer: openArray[byte]): Result[FindValueMessage, ProtoError] = + + let pb = initProtoBuffer(buffer) + var msg = FindValueMessage() + + ? pb.getRequiredField(1, msg.cId) + + ok(msg) + +proc encode*(msg: FindValueMessage): seq[byte] = + var pb = initProtoBuffer() + + pb.write(1, msg.cId) + + pb.finish() + pb.buffer + diff --git a/libp2pdht/dht/value_messages.nim b/libp2pdht/dht/value_messages.nim index 57e6362..aa34fc4 100644 --- a/libp2pdht/dht/value_messages.nim +++ b/libp2pdht/dht/value_messages.nim @@ -12,3 +12,6 @@ type ValueMessage* = object #total*: uint32 value*: seq[byte] + + FindValueMessage* = object + cId*: NodeId diff --git a/libp2pdht/private/eth/p2p/discoveryv5/messages.nim b/libp2pdht/private/eth/p2p/discoveryv5/messages.nim index 1bbfa9f..de9e8e2 100644 --- a/libp2pdht/private/eth/p2p/discoveryv5/messages.nim +++ b/libp2pdht/private/eth/p2p/discoveryv5/messages.nim @@ -46,6 +46,7 @@ type addValue = 0x0E getValue = 0x0F respValue = 0x10 + findValue = 0x11 findNodeFast = 0x83 RequestId* = object @@ -85,7 +86,7 @@ type SomeMessage* = PingMessage or PongMessage or FindNodeMessage or NodesMessage or TalkReqMessage or TalkRespMessage or AddProviderMessage or GetProvidersMessage or ProvidersMessage or FindNodeFastMessage or - AddValueMessage or GetValueMessage or ValueMessage + AddValueMessage or GetValueMessage or ValueMessage or FindValueMessage Message* = object reqId*: RequestId @@ -124,6 +125,8 @@ type getValue*: GetValueMessage of respValue: value*: ValueMessage + of findValue: + findValue*: FindValueMessage else: discard @@ -141,6 +144,7 @@ template messageKind*(T: typedesc[SomeMessage]): MessageKind = elif T is AddValueMessage: MessageKind.addValue elif T is GetValueMessage: MessageKind.getValue elif T is ValueMessage: MessageKind.respValue + elif T is FindValueMessage: MessageKind.findValue proc hash*(reqId: RequestId): Hash = hash(reqId.id) diff --git a/libp2pdht/private/eth/p2p/discoveryv5/messages_encoding.nim b/libp2pdht/private/eth/p2p/discoveryv5/messages_encoding.nim index 4d6d989..a347152 100644 --- a/libp2pdht/private/eth/p2p/discoveryv5/messages_encoding.nim +++ b/libp2pdht/private/eth/p2p/discoveryv5/messages_encoding.nim @@ -459,6 +459,14 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] = else: return err("Unable to decode ValueMessage") + of findValue: + let res = FindValueMessage.decode(encoded) + if res.isOk: + message.findValue = res.get + return ok(message) + else: + return err("Unable to decode FindValueMessage") + of regTopic, ticket, regConfirmation, topicQuery: # We just pass the empty type of this message without attempting to # decode, so that the protocol knows what was received. diff --git a/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim b/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim index 7f7365b..969ccc9 100644 --- a/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim +++ b/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim @@ -432,6 +432,19 @@ proc handleGetValue( trace "no value in local db", n = d.localNode, cID = getValue.cId # TODO: add noValue response +proc handleFindValue(d: Protocol, fromId: NodeId, fromAddr: Address, + fv: FindValueMessage, reqId: RequestId) {.async.} = + try: + let value = d.valueStore[fv.cId] + trace "retrieved value from local db", n = d.localNode, cID = fv.cId, value + ##TODO: handle multiple messages? + let response = ValueMessage(value: value) + d.sendResponse(fromId, fromAddr, response, reqId) + except KeyError: + d.sendNodes(fromId, fromAddr, reqId, + d.routingTable.neighbours(fv.cId, seenOnly = true, k = FindNodeResultLimit)) + # TODO: if known, maybe we should add exact target even if not yet "seen" + proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address, message: Message) = case message.kind @@ -461,6 +474,9 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address, of getValue: discovery_message_requests_incoming.inc() asyncSpawn d.handleGetValue(srcId, fromAddr, message.getValue, message.reqId) + of findValue: + discovery_message_requests_incoming.inc() + asyncSpawn d.handleFindValue(srcId, fromAddr, message.findValue, message.reqId) of regTopic, topicQuery: discovery_message_requests_incoming.inc() discovery_message_requests_incoming.inc(labelValues = ["no_response"]) @@ -926,6 +942,130 @@ proc getValue*( return err "getValue failed" +proc waitNodesOrValue(d: Protocol, fromNode: Node, reqId: RequestId): + Future[DiscResult[(seq[SignedPeerRecord], seq[byte])]] {.async.} = + + var op = await d.waitMessage(fromNode, reqId) + if op.isSome: + if op.get.kind == MessageKind.nodes: + var res = op.get.nodes.sprs + let total = op.get.nodes.total + for i in 1 ..< total: + op = await d.waitMessage(fromNode, reqId) + if op.isSome and op.get.kind == MessageKind.nodes: + res.add(op.get.nodes.sprs) + else: + # No error on this as we received some nodes. + break + return ok((res, @[])) + elif op.get.kind == MessageKind.respValue: + var res = op.get.value.value + return ok((@[], res)) + else: + discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"]) + return err("Invalid response to find node message") + else: + discovery_message_requests_outgoing.inc(labelValues = ["no_response"]) + return err("Nodes message not received in time") + +proc waitFindValueResponses*[T: SomeMessage](d: Protocol, node: Node, msg: T): + Future[DiscResult[(seq[SignedPeerRecord], seq[byte])]] = + let reqId = RequestId.init(d.rng[]) + result = d.waitNodesOrValue(node, reqId) + sendRequest(d, node, msg, reqId) + +proc sendFindValue*(d: Protocol, toNode: Node, target: NodeId): + Future[DiscResult[(seq[Node], seq[byte])]] {.async.} = + let + msg = FindValueMessage(cId: target) + response = await d.waitFindValueResponses(toNode, msg) + + if response.isOk: + let (nodes, value) = response.get() + if nodes.len > 0: + let res = verifyNodesRecords(nodes, toNode, FindNodeFastResultLimit) + d.routingTable.setJustSeen(toNode) + return ok((res, @[])) + else: + return ok((@[], value)) + else: + d.replaceNode(toNode) + return err(response.error) + +proc findValue*( + d: Protocol, + target: NodeId, + timeout: Duration = 5000.milliseconds # TODO: not used? + ): Future[DiscResult[seq[byte]]] {.async.} = + ## Perform a findValue lookup for the given value, descending on nodes with + ## multiple parallel requests and returning the first instance of the + ## key-value pair found. + + proc worker(d: Protocol, destNode: Node, target: NodeId): + Future[(seq[Node], seq[byte])] {.async.} = + + let r = await d.sendFindValue(destNode, target) + + if r.isOk: + let (nodes, value) = r.get + result = (nodes, value) + + # Attempt to add all nodes discovered + for n in nodes: + discard d.addNode(n) + + var closestNodes = d.routingTable.neighbours(target, BUCKET_SIZE, + seenOnly = false) + + var asked, seen = initHashSet[NodeId]() + asked.incl(d.localNode.id) # No need to ask our own node + seen.incl(d.localNode.id) # No need to discover our own node + for node in closestNodes: + seen.incl(node.id) + + var pendingQueries = newSeqOfCap[Future[(seq[Node], seq[byte])]](Alpha) + + while true: + var i = 0 + # Doing `Alpha` amount of requests at once as long as closer non queried + # nodes are discovered. + while i < closestNodes.len and pendingQueries.len < Alpha: + let n = closestNodes[i] + if not asked.containsOrIncl(n.id): + pendingQueries.add(d.worker(n, target)) + inc i + + trace "discv5 pending queries", total = pendingQueries.len + + if pendingQueries.len == 0: + break + + let query = await one(pendingQueries) + trace "Got discv5 lookup query response" + + let index = pendingQueries.find(query) + if index != -1: + pendingQueries.del(index) + else: + error "Resulting query should have been in the pending queries" + + let (nodes, value) = query.read + # TODO: Remove node on timed-out query? + if value.len > 0: + return ok(value) + for n in nodes: + if not seen.containsOrIncl(n.id): + # If it wasn't seen before, insert node while remaining sorted + closestNodes.insert(n, closestNodes.lowerBound(n, + proc(x: Node, n: Node): int = + cmp(distance(x.id, target), distance(n.id, target)) + )) + + if closestNodes.len > BUCKET_SIZE: + closestNodes.del(closestNodes.high()) + + d.lastLookup = now(chronos.Moment) + proc query*(d: Protocol, target: NodeId, k = BUCKET_SIZE): Future[seq[Node]] {.async.} = ## Query k nodes for the given target, returns all nodes found, including the