diff --git a/eth/p2p/providers.nim b/eth/p2p/providers.nim index a8744e7..3bb75dc 100644 --- a/eth/p2p/providers.nim +++ b/eth/p2p/providers.nim @@ -6,8 +6,9 @@ import chronos, + chronos/timer, chronicles, - std/tables, + std/tables, sequtils, stew/byteutils, # toBytes discoveryv5/[protocol, node], libp2p/routing_record, @@ -26,8 +27,8 @@ type cId: NodeId ProvidersMessage* = object - total*: uint32 - enrs*: seq[PeerRecord] + total: uint32 + provs: seq[PeerRecord] func getField*(pb: ProtoBuffer, field: int, nid: var NodeId): ProtoResult[bool] {.inline.} = @@ -63,6 +64,24 @@ func write*(pb: var ProtoBuffer, field: int, pr: PeerRecord) = ## Write PeerRecord value ``pr`` to object ``pb`` using ProtoBuf's encoding. write(pb, field, pr.encode()) +proc getRepeatedField*(pb: ProtoBuffer, field: int, + value: var seq[PeerRecord]): ProtoResult[bool] {. + inline.} = + var items: seq[seq[byte]] + value.setLen(0) + let res = ? pb.getRepeatedField(field, items) + if not(res): + ok(false) + else: + for item in items: + let ma = PeerRecord.decode(item) + if ma.isOk(): + value.add(ma.get()) + else: + value.setLen(0) + return err(ProtoError.IncorrectBlob) + ok(true) + proc decode*( T: typedesc[AddProviderMessage], buffer: openArray[byte]): Result[AddProviderMessage, ProtoError] = @@ -124,3 +143,134 @@ proc addProvider*(p: ProvidersProtocol, cId: NodeId, pr: PeerRecord): Future[seq p.sendAddProvider(n, cId, pr) else: p.addProviderLocal(cId, pr) + +## ---- GetProviders ---- + +const + protoIdGetProviders = "GP".toBytes() + +proc decode*( + T: typedesc[GetProvidersMessage], + buffer: openArray[byte]): Result[GetProvidersMessage, ProtoError] = + + let pb = initProtoBuffer(buffer) + var msg = GetProvidersMessage() + + ? pb.getRequiredField(1, msg.cId) + + ok(msg) + +proc encode*(msg: GetProvidersMessage): seq[byte] = + var pb = initProtoBuffer() + + pb.write(1, msg.cId) + + pb.finish() + pb.buffer + +proc decode*( + T: typedesc[ProvidersMessage], + buffer: openArray[byte]): Result[ProvidersMessage, ProtoError] = + + let pb = initProtoBuffer(buffer) + var msg = ProvidersMessage() + + ? pb.getRequiredField(1, msg.total) + discard ? pb.getRepeatedField(2, msg.provs) + + ok(msg) + +proc encode*(msg: ProvidersMessage): seq[byte] = + var pb = initProtoBuffer() + + pb.write(1, msg.total) + for prov in msg.provs: + pb.write(2, prov) + + pb.finish() + pb.buffer + +proc sendGetProviders(p: ProvidersProtocol, dst: Node, + cId: NodeId): Future[ProvidersMessage] + {.async.} = + let msg = GetProvidersMessage(cId: cId) + trace "sendGetProviders", msg + let respbytes = await p.discovery.talkReq(dst, protoIdGetProviders, msg.encode()) + if respbytes.isOK(): + let a = respbytes.get() + result = ProvidersMessage.decode(a).get() + else: + trace "sendGetProviders", msg + result = ProvidersMessage() #TODO: add error handling + +proc getProvidersLocal*( + p: ProvidersProtocol, + cId: NodeId, + maxitems: int = 5, + ): seq[PeerRecord] {.raises: [KeyError,Defect].}= + result = if (cId in p.providers): p.providers[cId] else: @[] + +proc getProviders*( + p: ProvidersProtocol, + cId: NodeId, + maxitems: int = 5, + timeout: timer.Duration = chronos.milliseconds(5000) + ): Future[seq[PeerRecord]] {.async.} = + ## Search for providers of the given cId. + + # What providers do we know about? + result = p.getProvidersLocal(cId, maxitems) + trace "local providers:", result + + let nodesNearby = await p.discovery.lookup(cId) + trace "nearby:", nodesNearby + var providersFut: seq[Future[ProvidersMessage]] + for n in nodesNearby: + if n != p.discovery.localNode: + providersFut.add(p.sendGetProviders(n, cId)) + + while providersFut.len > 0: + let providersMsg = await one(providersFut) + # trace "Got providers response", providersMsg + + let index = providersFut.find(providersMsg) + if index != -1: + providersFut.del(index) + + let providersMsg2 = await providersMsg + trace "2", providersMsg2 + + let providers = providersMsg.read.provs + result = result.concat(providers).deduplicate + # TODO: hsndle timeout + # + trace "getProviders collected: ", result + +proc recvGetProviders(p: ProvidersProtocol, nodeId: NodeId, payload: openArray[byte]) : ProvidersMessage + {.raises: [Defect].} = + trace "recvGetProviders" + let msg = GetProvidersMessage.decode(payload).get() + trace "<<< get_providers ", src = nodeId, dst = p.discovery.localNode.id, cid = msg.cId + #TODO: add checks, add signed version + let provs = p.providers.getOrDefault(msg.cId) + + trace "providers:", provs + ##TODO: handle multiple messages + ProvidersMessage(total: 1, provs: provs) + + +proc registerGetProviders(p: ProvidersProtocol) = + proc handler(protocol: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte] + {.gcsafe, raises: [Defect].} = + let returnMsg = recvGetProviders(p, fromId, request) + trace "returnMsg", returnMsg + returnMsg.encode() # TODO: response + + let protocol = TalkProtocol(protocolHandler: handler) + discard p.discovery.registerTalkProtocol(protoIdGetProviders, protocol) #TODO: handle error + +proc newProvidersProtocol*(d: protocol.Protocol) : ProvidersProtocol = + result.new() + result.discovery = d + result.registerAddProvider() + result.registerGetProviders()