diff --git a/eth/p2p/providers.nim b/eth/p2p/providers.nim new file mode 100644 index 0000000..d4dd8d3 --- /dev/null +++ b/eth/p2p/providers.nim @@ -0,0 +1,103 @@ +# Copyright (c) 2020-2022 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + chronos, + chronicles, + std/tables, + stew/byteutils, # toBytes + discoveryv5/[protocol, node], + libp2p/routing_record, + libp2p/protobuf/minprotobuf + +type + ProvidersProtocol = ref object + providers: Table[NodeId, seq[PeerRecord]] + discovery: protocol.Protocol + + AddProviderMessage* = object + cId: NodeId + prov: PeerRecord + + GetProvidersMessage* = object + cId: NodeId + + ProvidersMessage* = object + total*: uint32 + enrs*: seq[Record] + +func getField*(pb: ProtoBuffer, field: int, + nid: var NodeId): ProtoResult[bool] {.inline.} = + ## Read ``NodeId`` from ProtoBuf's message and validate it + var buffer: seq[byte] + let res = ? pb.getField(field, buffer) + if not(res): + ok(false) + else: + nid = readUintBE[256](buffer) + ok(true) + +func getField*(pb: ProtoBuffer, field: int, + pr: var PeerRecord): ProtoResult[bool] {.inline.} = + ## Read ``NodeId`` from ProtoBuf's message and validate it + var buffer: seq[byte] + let res = ? pb.getField(field, buffer) + if not(res): + ok(false) + else: + let res2 = PeerRecord.decode(buffer) + if res2.isOk(): + pr = res2.get() + ok(true) + else: + err(ProtoError.IncorrectBlob) + +proc decode*( + T: typedesc[AddProviderMessage], + buffer: openArray[byte]): Result[AddProviderMessage, ProtoError] = + + let pb = initProtoBuffer(buffer) + var msg = AddProviderMessage() + + ? pb.getRequiredField(1, msg.cId) + ? pb.getRequiredField(2, msg.prov) + + ok(msg) + +proc addProviderLocal(p: ProvidersProtocol, cId: NodeId, prov: PeerRecord) = + p.providers.mgetOrPut(cId, @[]).add(prov) + +proc recvAddProvider(p: ProvidersProtocol, nodeId: NodeId, payload: openArray[byte]) + {.raises: [Defect].} = + #TODO: add checks, add signed version + let msg = AddProviderMessage.decode(payload).get() + p.addProviderLocal(msg.cId, msg.prov) + #TODO: check that CID is reasonably close to our NodeID + +const + protoIdAddProvider = "AP".toBytes() + +proc registerAddProvider(p: ProvidersProtocol) = + proc handler(protocol: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte] + {.gcsafe, raises: [Defect].} = + recvAddProvider(p, fromId, request) + @[] # talk requires a response + + let protocol = TalkProtocol(protocolHandler: handler) + discard p.discovery.registerTalkProtocol(protoIdAddProvider, protocol) #TODO: handle error + +proc sendAddProvider*(p: ProvidersProtocol, dst: Node, cId: NodeId, pr: PeerRecord) = + #type NodeDesc = tuple[ip: IpAddress, udpPort, tcpPort: Port, pk: PublicKey] + let msg = AddProviderMessage(cId: cId, prov: pr) + discard p.discovery.talkReq(dst, protoIdAddProvider, pr.encode()) + +proc addProvider*(p: ProvidersProtocol, cId: NodeId, pr: PeerRecord): Future[seq[Node]] {.async.} = + result = await p.discovery.lookup(cId) + for n in result: + if n != p.discovery.localNode: + p.sendAddProvider(n, cId, pr) + else: + p.addProviderLocal(cId, pr)