diff --git a/dht/providers.nim b/dht/providers.nim index 0719fe0..8423b77 100644 --- a/dht/providers.nim +++ b/dht/providers.nim @@ -24,11 +24,8 @@ proc addProviderLocal(p: ProvidersProtocol, cId: NodeId, prov: PeerRecord) = trace "adding provider to local db", n=p.discovery.localNode, cId, prov p.providers.mgetOrPut(cId, @[]).add(prov) -proc recvAddProvider(p: ProvidersProtocol, nodeId: NodeId, payload: openArray[byte]) +proc recvAddProvider(p: ProvidersProtocol, nodeId: NodeId, msg: AddProviderMessage) {.raises: [Defect].} = - #TODO: add checks, add signed version - let msg = AddProviderMessage.decode(payload).get() - trace "<<< add_provider ", src = nodeId, dst = p.discovery.localNode.id, cid = msg.cId, prov=msg.prov p.addProviderLocal(msg.cId, msg.prov) #TODO: check that CID is reasonably close to our NodeID @@ -38,7 +35,12 @@ const proc registerAddProvider(p: ProvidersProtocol) = proc handler(protocol: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte] {.gcsafe, raises: [Defect].} = - recvAddProvider(p, fromId, request) + trace "<<< add_provider ", src = nodeId, dst = p.discovery.localNode.id, cid = msg.cId, prov=msg.prov + #TODO: add checks, add signed version + let msg = AddProviderMessage.decode(request).get() + + recvAddProvider(p, fromId, msg) + @[] # talk requires a response let protocol = TalkProtocol(protocolHandler: handler) @@ -122,25 +124,25 @@ proc getProviders*( # trace "getProviders collected: ", result -proc recvGetProviders(p: ProvidersProtocol, nodeId: NodeId, payload: openArray[byte]) : ProvidersMessage +proc recvGetProviders(p: ProvidersProtocol, nodeId: NodeId, msg: GetProvidersMessage) : ProvidersMessage {.raises: [Defect].} = - trace "recvGetProviders" - let msg = GetProvidersMessage.decode(payload).get() - trace "<<< get_providers ", src = nodeId, dst = p.discovery.localNode.id, cid = msg.cId #TODO: add checks, add signed version let provs = p.providers.getOrDefault(msg.cId) - trace "providers:", provs + ##TODO: handle multiple messages ProvidersMessage(total: 1, provs: provs) - proc registerGetProviders(p: ProvidersProtocol) = proc handler(protocol: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte] {.gcsafe, raises: [Defect].} = - let returnMsg = recvGetProviders(p, fromId, request) + trace "<<< get_providers ", src = nodeId, dst = p.discovery.localNode.id, cid = msg.cId + let msg = GetProvidersMessage.decode(request).get() + + let returnMsg = recvGetProviders(p, fromId, msg) trace "returnMsg", returnMsg - returnMsg.encode() # TODO: response + + returnMsg.encode() let protocol = TalkProtocol(protocolHandler: handler) discard p.discovery.registerTalkProtocol(protoIdGetProviders, protocol) #TODO: handle error