Replace memory db with lrucache for temporary storage of sessions (#292)

This commit is contained in:
Kim De Mey 2020-09-10 14:49:48 +02:00 committed by GitHub
parent 8e8c982270
commit c9caafb2a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 261 additions and 125 deletions

View File

@ -50,6 +50,7 @@ proc runP2pTests() =
"test_protocol_handlers",
"test_enr",
"test_hkdf",
"test_lru",
"test_discoveryv5",
"test_discv5_encoding",
"test_routing_table"
@ -94,6 +95,7 @@ proc runDiscv5Tests() =
for filename in [
"test_enr",
"test_hkdf",
"test_lru",
"test_discoveryv5",
"test_discv5_encoding",
"test_routing_table"
@ -103,5 +105,4 @@ proc runDiscv5Tests() =
task test_discv5, "run tests of discovery v5 and its dependencies":
runKeysTests()
runRlpTests()
runTrieTests() # This probably tests a bit to much for what we use it for.
runDiscv5Tests()

View File

@ -2,7 +2,7 @@ import
std/[options, strutils],
chronos, chronicles, chronicles/topics_registry, confutils, metrics,
stew/byteutils, confutils/std/net,
eth/keys, eth/trie/db, eth/net/nat, protocol, discovery_db, enr, node
eth/keys, eth/net/nat, protocol, enr, node
type
DiscoveryCmd* = enum
@ -145,9 +145,7 @@ proc setupNat(conf: DiscoveryConf): tuple[ip: Option[ValidIpAddress],
proc run(config: DiscoveryConf) =
let
(ip, tcpPort, udpPort) = setupNat(config)
ddb = DiscoveryDB.init(newMemoryDB())
# TODO: newProtocol should allow for no tcpPort
d = newProtocol(config.nodeKey, ddb, ip, tcpPort, udpPort,
d = newProtocol(config.nodeKey, ip, tcpPort, udpPort,
bootstrapRecords = config.bootnodes)
d.open()

View File

@ -1,63 +0,0 @@
import
stint, stew/endians2, stew/shims/net,
eth/trie/db, types, node
{.push raises: [Defect].}
type
DiscoveryDB* = ref object of Database
backend: TrieDatabaseRef
DbKeyKind = enum
kNodeToKeys = 100
proc init*(T: type DiscoveryDB, backend: TrieDatabaseRef): DiscoveryDB =
T(backend: backend)
const keySize = 1 + # unique triedb prefix (kNodeToKeys)
sizeof(NodeId) +
16 + # max size of ip address (ipv6)
2 # Sizeof port
proc makeKey(id: NodeId, address: Address): array[keySize, byte] =
result[0] = byte(kNodeToKeys)
var pos = 1
result[pos ..< pos+sizeof(id)] = toBytes(id)
pos.inc(sizeof(id))
case address.ip.family
of IpAddressFamily.IpV4:
result[pos ..< pos+sizeof(address.ip.address_v4)] = address.ip.address_v4
of IpAddressFamily.IpV6:
result[pos ..< pos+sizeof(address.ip.address_v6)] = address.ip.address_v6
pos.inc(sizeof(address.ip.address_v6))
result[pos ..< pos+sizeof(address.port)] = toBytes(address.port.uint16)
method storeKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: AesKey):
bool =
try:
var value: array[sizeof(r) + sizeof(w), byte]
value[0 .. 15] = r
value[16 .. ^1] = w
db.backend.put(makeKey(id, address), value)
return true
except CatchableError:
return false
method loadKeys*(db: DiscoveryDB, id: NodeId, address: Address,
r, w: var AesKey): bool =
try:
let res = db.backend.get(makeKey(id, address))
if res.len != sizeof(r) + sizeof(w):
return false
copyMem(addr r[0], unsafeAddr res[0], sizeof(r))
copyMem(addr w[0], unsafeAddr res[sizeof(r)], sizeof(w))
return true
except CatchableError:
return false
method deleteKeys*(db: DiscoveryDB, id: NodeId, address: Address): bool =
try:
db.backend.del(makeKey(id, address))
return true
except CatchableError:
return false

View File

@ -1,7 +1,7 @@
import
std/[tables, options],
nimcrypto, stint, chronicles, stew/results, bearssl,
eth/[rlp, keys], types, node, enr, hkdf
eth/[rlp, keys], types, node, enr, hkdf, sessions
export keys
@ -27,8 +27,8 @@ type
Codec* = object
localNode*: Node
privKey*: PrivateKey
db*: Database
handshakes*: Table[HandShakeKey, Whoareyou]
sessions*: Sessions
HandshakeSecrets = object
writeKey: AesKey
@ -142,7 +142,7 @@ proc packetTag(destNode, srcNode: NodeID): PacketTag =
proc encodePacket*(
rng: var BrHmacDrbgContext,
c: Codec,
c: var Codec,
toId: NodeID,
toAddr: Address,
message: openarray[byte],
@ -165,9 +165,7 @@ proc encodePacket*(
# TODO: Should we change API to get just the key we need?
var writeKey, readKey: AesKey
# We might not have the node's keys if the handshake hasn't been performed
# yet. That's fine, we will be responded with whoareyou.
if c.db.loadKeys(toId, toAddr, readKey, writeKey):
if c.sessions.load(toId, toAddr, readKey, writeKey):
packet.add(encryptGCM(writeKey, nonce, message, tag))
else:
# We might not have the node's keys if the handshake hasn't been performed
@ -182,8 +180,7 @@ proc encodePacket*(
let (headEnc, secrets) = encodeAuthHeader(rng, c, toId, nonce, challenge)
packet.add(headEnc)
if not c.db.storeKeys(toId, toAddr, secrets.readKey, secrets.writeKey):
warn "Storing of keys for session failed, will have to redo a handshake"
c.sessions.store(toId, toAddr, secrets.readKey, secrets.writeKey)
packet.add(encryptGCM(secrets.writeKey, nonce, message, tag))
@ -340,8 +337,7 @@ proc decodePacket*(c: var Codec,
# Swap keys to match remote
swap(sec.readKey, sec.writeKey)
if not c.db.storeKeys(fromId, fromAddr, sec.readKey, sec.writeKey):
warn "Storing of keys for session failed, will have to redo a handshake"
c.sessions.store(fromId, fromAddr, sec.readKey, sec.writeKey)
readKey = sec.readKey
else:
# Message packet or random packet - rlp bytes (size 12) indicates auth-tag
@ -352,7 +348,7 @@ proc decodePacket*(c: var Codec,
auth.auth = authTag
# TODO: Should we change API to get just the key we need?
var writeKey: AesKey
if not c.db.loadKeys(fromId, fromAddr, readKey, writeKey):
if not c.sessions.load(fromId, fromAddr, readKey, writeKey):
trace "Decoding failed (no keys)"
return err(DecryptError)
@ -363,7 +359,7 @@ proc decodePacket*(c: var Codec,
input.toOpenArray(headSize, input.high),
input.toOpenArray(0, tagSize - 1))
if message.isNone():
discard c.db.deleteKeys(fromId, fromAddr)
c.sessions.del(fromId, fromAddr)
return err(DecryptError)
decodeMessage(message.get())

View File

@ -0,0 +1,41 @@
import std/[tables, lists, options]
{.push raises: [Defect].}
type
LRUCache*[K, V] = object of RootObj
list: DoublyLinkedList[(K, V)] # Head is MRU k:v and tail is LRU k:v
table: Table[K, DoublyLinkedNode[(K, V)]] # DoublyLinkedNode is alraedy ref
capacity: int
proc init*[K, V](T: type LRUCache[K, V], capacity: int): LRUCache[K, V] =
LRUCache[K, V](capacity: capacity) # Table and list init is done default
proc get*[K, V](lru: var LRUCache[K, V], key: K): Option[V] =
let node = lru.table.getOrDefault(key, nil)
if node.isNil:
return none(V)
lru.list.remove(node)
lru.list.prepend(node)
return some(node.value[1])
proc put*[K, V](lru: var LRUCache[K, V], key: K, value: V) =
let node = lru.table.getOrDefault(key, nil)
if not node.isNil:
lru.list.remove(node)
else:
if lru.table.len >= lru.capacity:
lru.table.del(lru.list.tail.value[0])
lru.list.remove(lru.list.tail)
lru.list.prepend((key, value))
lru.table[key] = lru.list.head
proc del*[K, V](lru: var LRUCache[K, V], key: K) =
var node: DoublyLinkedNode[(K, V)]
if lru.table.pop(key, node):
lru.list.remove(node)
proc len*[K, V](lru: LRUCache[K, V]): int =
lru.table.len

View File

@ -77,7 +77,7 @@ import
stew/shims/net as stewNet, json_serialization/std/net,
stew/[byteutils, endians2], chronicles, chronos, stint, bearssl,
eth/[rlp, keys, async_utils],
types, encoding, node, routing_table, enr, random2
types, encoding, node, routing_table, enr, random2, sessions
import nimcrypto except toHex
@ -114,7 +114,6 @@ type
whoareyouMagic: array[magicSize, byte]
idHash: array[32, byte]
pendingRequests: Table[AuthTag, PendingRequest]
db: Database
routingTable: RoutingTable
codec*: Codec
awaitedMessages: Table[(NodeId, RequestId), Future[Option[Message]]]
@ -477,13 +476,6 @@ proc validIp(sender, address: IpAddress): bool {.raises: [Defect].} =
proc replaceNode(d: Protocol, n: Node) =
if n.record notin d.bootstrapRecords:
d.routingTable.replaceNode(n)
# Remove shared secrets when removing the node from routing table.
# TODO: This might be to direct, so we could keep these longer. But better
# would be to simply not remove the nodes immediatly but use an LRU cache.
# Also because some shared secrets will be with nodes not eligable for
# the routing table, and these don't get deleted now, see issue:
# https://github.com/status-im/nim-eth/issues/242
discard d.codec.db.deleteKeys(n.id, n.address.get())
else:
# For now we never remove bootstrap nodes. It might make sense to actually
# do so and to retry them only in case we drop to a really low amount of
@ -758,7 +750,7 @@ proc lookupLoop(d: Protocol) {.async, raises: [Exception, Defect].} =
except CancelledError:
trace "lookupLoop canceled"
proc newProtocol*(privKey: PrivateKey, db: Database,
proc newProtocol*(privKey: PrivateKey,
externalIp: Option[ValidIpAddress], tcpPort, udpPort: Port,
localEnrFields: openarray[(string, seq[byte])] = [],
bootstrapRecords: openarray[Record] = [],
@ -789,12 +781,12 @@ proc newProtocol*(privKey: PrivateKey, db: Database,
result = Protocol(
privateKey: privKey,
db: db,
localNode: node,
bindAddress: Address(ip: ValidIpAddress.init(bindIp), port: udpPort),
whoareyouMagic: whoareyouMagic(node.id),
idHash: sha256.digest(node.id.toByteArrayBE).data,
codec: Codec(localNode: node, privKey: privKey, db: db),
codec: Codec(localNode: node, privKey: privKey,
sessions: Sessions.init(256)),
bootstrapRecords: @bootstrapRecords,
rng: rng)

View File

@ -0,0 +1,48 @@
import
std/options,
stint, stew/endians2, stew/shims/net,
types, node, lru
export lru
{.push raises: [Defect].}
const keySize = sizeof(NodeId) +
16 + # max size of ip address (ipv6)
2 # Sizeof port
type
SessionKey* = array[keySize, byte]
SessionValue* = array[sizeof(AesKey) + sizeof(AesKey), byte]
Sessions* = LRUCache[SessionKey, SessionValue]
proc makeKey(id: NodeId, address: Address): SessionKey =
var pos = 0
result[pos ..< pos+sizeof(id)] = toBytes(id)
pos.inc(sizeof(id))
case address.ip.family
of IpAddressFamily.IpV4:
result[pos ..< pos+sizeof(address.ip.address_v4)] = address.ip.address_v4
of IpAddressFamily.IpV6:
result[pos ..< pos+sizeof(address.ip.address_v6)] = address.ip.address_v6
pos.inc(sizeof(address.ip.address_v6))
result[pos ..< pos+sizeof(address.port)] = toBytes(address.port.uint16)
proc store*(s: var Sessions, id: NodeId, address: Address, r, w: AesKey) =
var value: array[sizeof(r) + sizeof(w), byte]
value[0 .. 15] = r
value[16 .. ^1] = w
s.put(makeKey(id, address), value)
proc load*(s: var Sessions, id: NodeId, address: Address, r, w: var AesKey): bool =
let res = s.get(makeKey(id, address))
if res.isSome():
let val = res.get()
copyMem(addr r[0], unsafeAddr val[0], sizeof(r))
copyMem(addr w[0], unsafeAddr val[sizeof(r)], sizeof(w))
return true
else:
return false
proc del*(s: var Sessions, id: NodeId, address: Address) =
s.del(makeKey(id, address))

View File

@ -27,8 +27,6 @@ type
Whoareyou* = ref WhoareyouObj
Database* = ref object of RootRef
MessageKind* = enum
# TODO This is needed only to make Nim 1.0.4 happy
# Without it, the `MessageKind` type cannot be used as
@ -84,15 +82,6 @@ template messageKind*(T: typedesc[SomeMessage]): MessageKind =
elif T is FindNodeMessage: findNode
elif T is NodesMessage: nodes
method storeKeys*(db: Database, id: NodeId, address: Address,
r, w: AesKey): bool {.base.} = discard
method loadKeys*(db: Database, id: NodeId, address: Address,
r, w: var AesKey): bool {.base.} = discard
method deleteKeys*(db: Database, id: NodeId, address: Address):
bool {.base.} = discard
proc toBytes*(id: NodeId): array[32, byte] {.inline.} =
id.toByteArrayBE()

View File

@ -1,7 +1,7 @@
import
testutils/unittests, stew/shims/net, bearssl,
eth/[keys, rlp, trie/db],
eth/p2p/discoveryv5/[discovery_db, enr, node, types, routing_table, encoding],
eth/[keys, rlp],
eth/p2p/discoveryv5/[enr, node, types, routing_table, encoding],
eth/p2p/discoveryv5/protocol as discv5_protocol
proc localAddress*(port: int): Address =
@ -13,8 +13,7 @@ proc initDiscoveryNode*(rng: ref BrHmacDrbgContext, privKey: PrivateKey,
localEnrFields: openarray[(string, seq[byte])] = [],
previousRecord = none[enr.Record]()):
discv5_protocol.Protocol =
var db = DiscoveryDB.init(newMemoryDB())
result = newProtocol(privKey, db,
result = newProtocol(privKey,
some(address.ip),
address.port, address.port,
bootstrapRecords = bootstrapRecords,

View File

@ -1,7 +1,8 @@
import
chronos, chronicles, tables, stint, testutils/unittests,
stew/shims/net, eth/[keys, trie/db], bearssl,
eth/p2p/discoveryv5/[enr, node, types, routing_table, encoding, discovery_db],
std/tables,
chronos, chronicles, stint, testutils/unittests,
stew/shims/net, eth/keys, bearssl,
eth/p2p/discoveryv5/[enr, node, types, routing_table, encoding],
eth/p2p/discoveryv5/protocol as discv5_protocol,
./discv5_test_helper
@ -330,10 +331,14 @@ procSuite "Discovery v5 Tests":
# updated ENR.
block:
targetNode.open()
# ping to node again to add as it was removed after failed findNode in
# resolve in previous test block.
let pong = await targetNode.ping(mainNode.localNode)
check pong.isOk()
# Request the target ENR and manually add it to the routing table.
# Ping for handshake based ENR passing will not work as our previous
# session will still be in the LRU cache.
let nodes = await mainNode.findNode(targetNode.localNode, 0)
check:
nodes.isOk()
nodes[].len == 1
mainNode.addNode(nodes[][0])
targetSeqNum.inc()
# need to add something to get the enr sequence number incremented
@ -405,13 +410,12 @@ procSuite "Discovery v5 Tests":
privKey = PrivateKey.random(rng[])
ip = some(ValidIpAddress.init("127.0.0.1"))
port = Port(20301)
db = DiscoveryDB.init(newMemoryDB())
node = newProtocol(privKey, db, ip, port, port, rng = rng)
noUpdatesNode = newProtocol(privKey, db, ip, port, port, rng = rng,
node = newProtocol(privKey, ip, port, port, rng = rng)
noUpdatesNode = newProtocol(privKey, ip, port, port, rng = rng,
previousRecord = some(node.getRecord()))
updatesNode = newProtocol(privKey, db, ip, port, Port(20302), rng = rng,
updatesNode = newProtocol(privKey, ip, port, Port(20302), rng = rng,
previousRecord = some(noUpdatesNode.getRecord()))
moreUpdatesNode = newProtocol(privKey, db, ip, port, port, rng = rng,
moreUpdatesNode = newProtocol(privKey, ip, port, port, rng = rng,
localEnrFields = {"addfield": @[byte 0]},
previousRecord = some(updatesNode.getRecord()))
check:
@ -423,7 +427,7 @@ procSuite "Discovery v5 Tests":
# Defect (for now?) on incorrect key use
expect ResultDefect:
let incorrectKeyUpdates = newProtocol(PrivateKey.random(rng[]),
db, ip, port, port, rng = rng,
ip, port, port, rng = rng,
previousRecord = some(updatesNode.getRecord()))
asyncTest "Update node record with revalidate":

View File

@ -1,5 +1,6 @@
import
unittest, options, sequtils, stint, stew/byteutils, stew/shims/net,
std/[unittest, options, sequtils],
stint, stew/byteutils, stew/shims/net,
eth/[rlp, keys] , eth/p2p/discoveryv5/[types, encoding, enr, node]
# According to test vectors:

View File

@ -1,5 +1,5 @@
import
unittest, options, sequtils,
std/[unittest, options, sequtils],
nimcrypto/utils, stew/shims/net,
eth/p2p/enode, eth/p2p/discoveryv5/enr, eth/[keys, rlp]

131
tests/p2p/test_lru.nim Normal file
View File

@ -0,0 +1,131 @@
import
std/[unittest, options],
eth/p2p/discoveryv5/lru
suite "LRUCache":
const
capacity = 10
target = 4
test "LRU value gets removed":
var lru = LRUCache[int, int].init(capacity = capacity)
# Fully fill the LRU
for i in 0..<capacity:
lru.put(i, i) # new key, so new put
# Get value for each key
for i in 0..<capacity:
let val = lru.get(i)
check:
val.isSome()
val.get() == i
check lru.len() == capacity
# Add one new key
lru.put(capacity, 0)
# Oldest one should be gone
check:
lru.len() == capacity
lru.get(0).isNone()
lru.get(capacity).isSome()
test "LRU renew oldest by get":
var lru = LRUCache[int, int].init(capacity = capacity)
for i in 0..<capacity:
lru.put(i, i)
var val = lru.get(0)
check:
val.isSome
val.get() == 0
lru.put(capacity, 0)
val = lru.get(0)
check:
lru.len() == capacity
val.isSome()
val.get() == 0
test "LRU renew oldest by put":
var lru = LRUCache[int, int].init(capacity = capacity)
for i in 0..<capacity:
lru.put(i, i)
lru.put(0, 1)
check lru.len() == capacity
lru.put(capacity, 0)
let val = lru.get(0)
check:
lru.len() == capacity
val.isSome()
val.get() == 1
test "LRU renew by put":
var lru = LRUCache[int, int].init(capacity = capacity)
for i in 0..<capacity:
lru.put(i, i)
lru.put(target, 1)
check lru.len() == capacity
lru.put(capacity, 0)
let val = lru.get(target)
check:
lru.len() == capacity
val.isSome()
val.get() == 1
test "LRU renew by get":
var lru = LRUCache[int, int].init(capacity = capacity)
for i in 0..<capacity:
lru.put(i, i)
var val = lru.get(target)
check:
val.isSome
val.get() == target
lru.put(capacity, 0)
val = lru.get(target)
check:
lru.len() == capacity
val.isSome()
val.get() == target
test "LRU delete oldest and add":
var lru = LRUCache[int, int].init(capacity = capacity)
for i in 0..<capacity:
lru.put(i, i)
lru.del(0)
check lru.len == capacity - 1
lru.put(0, 1)
check lru.len == capacity
lru.put(capacity, 0)
let val = lru.get(0)
check:
lru.len() == capacity
val.isSome()
val.get() == 1
test "LRU delete not existing":
var lru = LRUCache[int, int].init(capacity = capacity)
for i in 0..<capacity:
lru.put(i, i)
lru.del(capacity)
check lru.len == capacity

View File

@ -1,7 +1,6 @@
import
unittest, bearssl,
eth/keys,
eth/p2p/discoveryv5/[routing_table, node],
std/unittest,
bearssl, eth/keys, eth/p2p/discoveryv5/[routing_table, node],
./discv5_test_helper
suite "Routing Table Tests":