Use addresses instead of stubs as db keys

This commit is contained in:
Yuriy Glukhov 2019-12-18 12:36:11 +02:00 committed by zah
parent 988d743c9a
commit 9772fbe470
5 changed files with 36 additions and 22 deletions

View File

@ -1,4 +1,5 @@
import types import std/net
import types, ../enode
import eth/trie/db import eth/trie/db
type type
@ -11,18 +12,28 @@ type
proc init*(T: type DiscoveryDB, backend: TrieDatabaseRef): DiscoveryDB = proc init*(T: type DiscoveryDB, backend: TrieDatabaseRef): DiscoveryDB =
T(backend: backend) T(backend: backend)
proc makeKey(id: NodeId, address: int): array[1 + sizeof(id) + sizeof(address), byte] = const keySize = 1 + # unique triedb prefix (kNodeToKeys)
sizeof(NodeId) +
16 + # max size of ip address (ipv6)
2 # Sizeof port
proc makeKey(id: NodeId, address: Address): array[keySize, byte] =
result[0] = byte(kNodeToKeys) result[0] = byte(kNodeToKeys)
copyMem(addr result[1], unsafeAddr id, sizeof(id)) copyMem(addr result[1], unsafeAddr id, sizeof(id))
copyMem(addr result[sizeof(id) + 1], unsafeAddr address, sizeof(address)) case address.ip.family
of IpAddressFamily.IpV4:
copyMem(addr result[sizeof(id) + 1], unsafeAddr address.ip.address_v4, sizeof(address.ip.address_v4))
of IpAddressFamily.IpV6:
copyMem(addr result[sizeof(id) + 1], unsafeAddr address.ip.address_v6, sizeof(address.ip.address_v6))
copyMem(addr result[sizeof(id) + 1 + sizeof(address.ip.address_v6)], unsafeAddr address.udpPort, sizeof(address.udpPort))
method storeKeys*(db: DiscoveryDB, id: NodeId, address: int, r, w: array[16, byte]) = method storeKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: array[16, byte]) =
var value: array[sizeof(r) + sizeof(w), byte] var value: array[sizeof(r) + sizeof(w), byte]
value[0 .. 15] = r value[0 .. 15] = r
value[16 .. ^1] = w value[16 .. ^1] = w
db.backend.put(makeKey(id, address), value) db.backend.put(makeKey(id, address), value)
method loadKeys*(db: DiscoveryDB, id: NodeId, address: int, r, w: var array[16, byte]): bool = method loadKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: var array[16, byte]): bool =
let res = db.backend.get(makeKey(id, address)) let res = db.backend.get(makeKey(id, address))
if res.len == sizeof(r) + sizeof(w): if res.len == sizeof(r) + sizeof(w):
copyMem(addr r[0], unsafeAddr res[0], sizeof(r)) copyMem(addr r[0], unsafeAddr res[0], sizeof(r))

View File

@ -1,5 +1,5 @@
import tables import tables
import types, node, enr, hkdf, eth/[rlp, keys], nimcrypto, stint import types, node, enr, hkdf, ../enode, eth/[rlp, keys], nimcrypto, stint
const const
idNoncePrefix = "discovery-id-nonce" idNoncePrefix = "discovery-id-nonce"
@ -114,7 +114,7 @@ proc packetTag(destNode, srcNode: NodeID): array[32, byte] =
let destidHash = sha256.digest(destId) let destidHash = sha256.digest(destId)
result = srcId xor destidHash.data result = srcId xor destidHash.data
proc encodeEncrypted*(c: Codec, toNode: Node, toAddr: int, packetData: seq[byte], challenge: Whoareyou): (seq[byte], array[gcmNonceSize, byte]) = proc encodeEncrypted*(c: Codec, toNode: Node, packetData: seq[byte], challenge: Whoareyou): (seq[byte], array[gcmNonceSize, byte]) =
var nonce: array[gcmNonceSize, byte] var nonce: array[gcmNonceSize, byte]
randomBytes(nonce) randomBytes(nonce)
var headEnc: seq[byte] var headEnc: seq[byte]
@ -127,14 +127,14 @@ proc encodeEncrypted*(c: Codec, toNode: Node, toAddr: int, packetData: seq[byte]
# We might not have the node's keys if the handshake hasn't been performed # We might not have the node's keys if the handshake hasn't been performed
# yet. That's fine, we will be responded with whoareyou. # yet. That's fine, we will be responded with whoareyou.
discard c.db.loadKeys(toNode.id, toAddr, readKey, writeKey) discard c.db.loadKeys(toNode.id, toNode.address, readKey, writeKey)
else: else:
var sec: HandshakeSecrets var sec: HandshakeSecrets
headEnc = c.makeAuthHeader(toNode, nonce, sec, challenge) headEnc = c.makeAuthHeader(toNode, nonce, sec, challenge)
writeKey = sec.writeKey writeKey = sec.writeKey
c.db.storeKeys(toNode.id, toAddr, sec.readKey, sec.writeKey) c.db.storeKeys(toNode.id, toNode.address, sec.readKey, sec.writeKey)
var body = packetData var body = packetData
let tag = packetTag(toNode.id, c.localNode.id) let tag = packetTag(toNode.id, c.localNode.id)
@ -201,7 +201,7 @@ proc decodeAuthResp(c: Codec, fromId: NodeId, head: AuthHeader, challenge: Whoar
newNode = newNode(authResp.record) newNode = newNode(authResp.record)
return true return true
proc decodeEncrypted*(c: var Codec, fromId: NodeID, fromAddr: int, input: seq[byte], authTag: var array[12, byte], newNode: var Node, packet: var Packet): bool = proc decodeEncrypted*(c: var Codec, fromId: NodeID, fromAddr: Address, input: seq[byte], authTag: var array[12, byte], newNode: var Node, packet: var Packet): bool =
let input = input.toRange let input = input.toRange
var r = rlpFromBytes(input[32 .. ^1]) var r = rlpFromBytes(input[32 .. ^1])
let authEndPos = r.currentElemEnd let authEndPos = r.currentElemEnd

View File

@ -45,6 +45,8 @@ proc newNode*(r: Record): Node =
proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.data) proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.data)
proc `==`*(a, b: Node): bool = (a.isNil and b.isNil) or (not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey) proc `==`*(a, b: Node): bool = (a.isNil and b.isNil) or (not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey)
proc address*(n: Node): Address {.inline.} = n.node.address
proc `$`*(n: Node): string = proc `$`*(n: Node): string =
if n == nil: if n == nil:
"Node[local]" "Node[local]"

View File

@ -92,7 +92,7 @@ proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: array
proc sendNodes(d: Protocol, toNode: Node, reqId: RequestId, nodes: openarray[Node]) = proc sendNodes(d: Protocol, toNode: Node, reqId: RequestId, nodes: openarray[Node]) =
proc sendNodes(d: Protocol, toNode: Node, packet: NodesPacket, reqId: RequestId) {.nimcall.} = proc sendNodes(d: Protocol, toNode: Node, packet: NodesPacket, reqId: RequestId) {.nimcall.} =
let (data, _) = d.codec.encodeEncrypted(toNode, 12345, encodePacket(packet, reqId), challenge = nil) let (data, _) = d.codec.encodeEncrypted(toNode, encodePacket(packet, reqId), challenge = nil)
d.send(toNode, data) d.send(toNode, data)
const maxNodesPerPacket = 3 const maxNodesPerPacket = 3
@ -109,7 +109,8 @@ proc sendNodes(d: Protocol, toNode: Node, reqId: RequestId, nodes: openarray[Nod
if packet.enrs.len != 0: if packet.enrs.len != 0:
d.sendNodes(toNode, packet, reqId) d.sendNodes(toNode, packet, reqId)
proc handlePing(d: Protocol, fromNode: Node, a: Address, ping: PingPacket, reqId: RequestId) = proc handlePing(d: Protocol, fromNode: Node, ping: PingPacket, reqId: RequestId) =
let a = fromNode.address
var pong: PongPacket var pong: PongPacket
pong.enrSeq = ping.enrSeq pong.enrSeq = ping.enrSeq
pong.ip = case a.ip.family pong.ip = case a.ip.family
@ -117,10 +118,10 @@ proc handlePing(d: Protocol, fromNode: Node, a: Address, ping: PingPacket, reqId
of IpAddressFamily.IPv6: @(a.ip.address_v6) of IpAddressFamily.IPv6: @(a.ip.address_v6)
pong.port = a.udpPort.uint16 pong.port = a.udpPort.uint16
let (data, _) = d.codec.encodeEncrypted(fromNode, 12345, encodePacket(pong, reqId), challenge = nil) let (data, _) = d.codec.encodeEncrypted(fromNode, encodePacket(pong, reqId), challenge = nil)
d.send(fromNode, data) d.send(fromNode, data)
proc handleFindNode(d: Protocol, fromNode: Node, a: Address, fn: FindNodePacket, reqId: RequestId) = proc handleFindNode(d: Protocol, fromNode: Node, fn: FindNodePacket, reqId: RequestId) =
if fn.distance == 0: if fn.distance == 0:
d.sendNodes(fromNode, reqId, [d.localNode]) d.sendNodes(fromNode, reqId, [d.localNode])
else: else:
@ -142,7 +143,7 @@ proc receive*(d: Protocol, a: Address, msg: Bytes) {.gcsafe.} =
if d.pendingRequests.take(whoareyou.authTag, pr): if d.pendingRequests.take(whoareyou.authTag, pr):
let toNode = pr.node let toNode = pr.node
let (data, _) = d.codec.encodeEncrypted(toNode, 12345, pr.packet, challenge = whoareyou) let (data, _) = d.codec.encodeEncrypted(toNode, pr.packet, challenge = whoareyou)
d.send(toNode, data) d.send(toNode, data)
else: else:
@ -155,7 +156,7 @@ proc receive*(d: Protocol, a: Address, msg: Bytes) {.gcsafe.} =
var node: Node var node: Node
var packet: Packet var packet: Packet
if d.codec.decodeEncrypted(sender, 12345, msg, authTag, node, packet): if d.codec.decodeEncrypted(sender, a, msg, authTag, node, packet):
if node.isNil: if node.isNil:
node = d.routingTable.getNode(sender) node = d.routingTable.getNode(sender)
else: else:
@ -166,9 +167,9 @@ proc receive*(d: Protocol, a: Address, msg: Bytes) {.gcsafe.} =
case packet.kind case packet.kind
of ping: of ping:
d.handlePing(node, a, packet.ping, packet.reqId) d.handlePing(node, packet.ping, packet.reqId)
of findNode: of findNode:
d.handleFindNode(node, a, packet.findNode, packet.reqId) d.handleFindNode(node, packet.findNode, packet.reqId)
else: else:
var waiter: Future[Option[Packet]] var waiter: Future[Option[Packet]]
if d.awaitedPackets.take((node, packet.reqId), waiter): if d.awaitedPackets.take((node, packet.reqId), waiter):
@ -212,7 +213,7 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[seq[Node]]
proc findNode(d: Protocol, toNode: Node, distance: uint32): Future[seq[Node]] {.async.} = proc findNode(d: Protocol, toNode: Node, distance: uint32): Future[seq[Node]] {.async.} =
let reqId = newRequestId() let reqId = newRequestId()
let packet = encodePacket(FindNodePacket(distance: distance), reqId) let packet = encodePacket(FindNodePacket(distance: distance), reqId)
let (data, nonce) = d.codec.encodeEncrypted(toNode, 12345, packet, challenge = nil) let (data, nonce) = d.codec.encodeEncrypted(toNode, packet, challenge = nil)
d.pendingRequests[nonce] = PendingRequest(node: toNode, packet: packet) d.pendingRequests[nonce] = PendingRequest(node: toNode, packet: packet)
d.send(toNode, data) d.send(toNode, data)
result = await d.waitNodes(toNode, reqId) result = await d.waitNodes(toNode, reqId)
@ -323,7 +324,7 @@ when isMainModule:
result.add(d) result.add(d)
proc addNode(d: openarray[Protocol], enr: string) = proc addNode(d: openarray[Protocol], enr: string) =
for dd in d: dd.addNode(enr) for dd in d: dd.addNode(EnrUri(enr))
proc test() {.async.} = proc test() {.async.} =
block: block:

View File

@ -68,8 +68,8 @@ template packetKind*(T: typedesc[SomePacket]): PacketKind =
elif T is FindNodePacket: findNode elif T is FindNodePacket: findNode
elif T is NodesPacket: nodes elif T is NodesPacket: nodes
method storeKeys*(db: Database, id: NodeId, address: int, r, w: array[16, byte]) {.base.} = discard method storeKeys*(db: Database, id: NodeId, address: Address, r, w: array[16, byte]) {.base.} = discard
method loadKeys*(db: Database, id: NodeId, address: int, r, w: var array[16, byte]): bool {.base.} = discard method loadKeys*(db: Database, id: NodeId, address: Address, r, w: var array[16, byte]): bool {.base.} = discard
proc toBytes*(id: NodeId): array[32, byte] {.inline.} = proc toBytes*(id: NodeId): array[32, byte] {.inline.} =
id.toByteArrayBE() id.toByteArrayBE()