diff --git a/eth/p2p/discoveryv5/discovery_db.nim b/eth/p2p/discoveryv5/discovery_db.nim index bb9681d..64b8293 100644 --- a/eth/p2p/discoveryv5/discovery_db.nim +++ b/eth/p2p/discoveryv5/discovery_db.nim @@ -27,7 +27,8 @@ proc makeKey(id: NodeId, address: Address): array[keySize, byte] = copyMem(addr result[sizeof(id) + 1], unsafeAddr address.ip.address_v6, sizeof(address.ip.address_v6)) copyMem(addr result[sizeof(id) + 1 + sizeof(address.ip.address_v6)], unsafeAddr address.udpPort, sizeof(address.udpPort)) -method storeKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: array[16, byte]): bool {.raises: [Defect].} = +method storeKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: AesKey): + bool {.raises: [Defect].} = try: var value: array[sizeof(r) + sizeof(w), byte] value[0 .. 15] = r @@ -37,7 +38,8 @@ method storeKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: array[16, except CatchableError: return false -method loadKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: var array[16, byte]): bool {.raises: [Defect].} = +method loadKeys*(db: DiscoveryDB, id: NodeId, address: Address, r, w: var AesKey): + bool {.raises: [Defect].} = try: let res = db.backend.get(makeKey(id, address)) if res.len != sizeof(r) + sizeof(w): diff --git a/eth/p2p/discoveryv5/encoding.nim b/eth/p2p/discoveryv5/encoding.nim index 5ba323e..23e9f3d 100644 --- a/eth/p2p/discoveryv5/encoding.nim +++ b/eth/p2p/discoveryv5/encoding.nim @@ -8,13 +8,12 @@ const authSchemeName* = "gcm" gcmNonceSize* = 12 gcmTagSize = 16 - aesKeySize* = 128 div 8 tagSize* = 32 ## size of the tag where each message (except whoareyou) starts ## with type - AesKey = array[aesKeySize, byte] - PacketTag = array[tagSize, byte] + + PacketTag* = array[tagSize, byte] AuthResponse = object version: int @@ -25,7 +24,7 @@ type localNode*: Node privKey*: PrivateKey db*: Database - handshakes*: Table[string, Whoareyou] # TODO: Implement type & hash for NodeID + address + handshakes*: Table[HandShakeKey, Whoareyou] HandshakeSecrets = object writeKey: AesKey @@ -247,7 +246,8 @@ proc decodeEncrypted*(c: var Codec, auth = r.read(AuthHeader) authTag = auth.auth - let challenge = c.handshakes.getOrDefault($fromId & $fromAddr) + let key = HandShakeKey(nodeId: fromId, address: $fromAddr) + let challenge = c.handshakes.getOrDefault(key) if challenge.isNil: trace "Decoding failed (no challenge)" return HandshakeError @@ -260,7 +260,7 @@ proc decodeEncrypted*(c: var Codec, if not c.decodeAuthResp(fromId, auth, challenge, sec, newNode): trace "Decoding failed (bad auth)" return HandshakeError - c.handshakes.del($fromId & $fromAddr) + c.handshakes.del(key) # Swap keys to match remote swap(sec.readKey, sec.writeKey) @@ -272,7 +272,7 @@ proc decodeEncrypted*(c: var Codec, # Message packet or random packet - rlp bytes (size 12) indicates auth-tag authTag = r.read(AuthTag) auth.auth = authTag - var writeKey: array[aesKeySize, byte] + var writeKey: AesKey if not c.db.loadKeys(fromId, fromAddr, readKey, writeKey): trace "Decoding failed (no keys)" return PacketError diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index 0f1a92f..f8bd3fe 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -102,11 +102,13 @@ proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: AuthT # a loop. # Use toNode + address to make it more difficult for an attacker to occupy # the handshake of another node. - if not d.codec.handshakes.hasKeyOrPut($toNode & $address, challenge): + + let key = HandShakeKey(nodeId: toNode, address: $address) + if not d.codec.handshakes.hasKeyOrPut(key, challenge): sleepAsync(handshakeTimeout).addCallback() do(data: pointer): # TODO: should we still provide cancellation in case handshake completes # correctly? - d.codec.handshakes.del($toNode & $address) + d.codec.handshakes.del(key) var data = @(whoareyouMagic(toNode)) data.add(rlp.encode(challenge[])) diff --git a/eth/p2p/discoveryv5/types.nim b/eth/p2p/discoveryv5/types.nim index f56ea21..a60b01e 100644 --- a/eth/p2p/discoveryv5/types.nim +++ b/eth/p2p/discoveryv5/types.nim @@ -5,11 +5,17 @@ import const authTagSize* = 12 idNonceSize* = 32 + aesKeySize* = 128 div 8 type NodeId* = UInt256 AuthTag* = array[authTagSize, byte] IdNonce* = array[idNonceSize, byte] + AesKey* = array[aesKeySize, byte] + + HandshakeKey* = object + nodeId*: NodeId + address*: string # TODO: Replace with Address, need hash WhoareyouObj* = object authTag*: AuthTag @@ -75,12 +81,23 @@ template packetKind*(T: typedesc[SomePacket]): PacketKind = elif T is FindNodePacket: findNode elif T is NodesPacket: nodes -method storeKeys*(db: Database, id: NodeId, address: Address, r, w: array[16, byte]): bool {.base, raises: [Defect].} = discard +method storeKeys*(db: Database, id: NodeId, address: Address, r, w: AesKey): + bool {.base, raises: [Defect].} = discard -method loadKeys*(db: Database, id: NodeId, address: Address, r, w: var array[16, byte]): bool {.base, raises: [Defect].} = discard +method loadKeys*(db: Database, id: NodeId, address: Address, r, w: var AesKey): + bool {.base, raises: [Defect].} = discard proc toBytes*(id: NodeId): array[32, byte] {.inline.} = id.toByteArrayBE() proc hash*(id: NodeId): Hash {.inline.} = - hashData(unsafeAddr id, sizeof(id)) + result = hashData(unsafeAddr id, sizeof(id)) + +# TODO: To make this work I think we also need to implement `==` due to case +# fields in object +proc hash*(address: Address): Hash {.inline.} = + hashData(unsafeAddr address, sizeof(address)) + +proc hash*(key: HandshakeKey): Hash = + result = key.nodeId.hash !& key.address.hash + result = !$result diff --git a/eth/p2p/enode.nim b/eth/p2p/enode.nim index 342e4b0..66e0f5a 100644 --- a/eth/p2p/enode.nim +++ b/eth/p2p/enode.nim @@ -160,3 +160,8 @@ proc `$`*(n: ENode): string = result.add("?") result.add("discport=") result.add($int(n.address.udpPort)) + +proc `$`*(a: Address): string = + result.add($a.ip) + result.add(":" & $a.udpPort) + result.add(":" & $a.tcpPort) diff --git a/tests/p2p/test_discoveryv5.nim b/tests/p2p/test_discoveryv5.nim index d6c5dac..96516c2 100644 --- a/tests/p2p/test_discoveryv5.nim +++ b/tests/p2p/test_discoveryv5.nim @@ -21,6 +21,18 @@ proc nodeIdInNodes(id: NodeId, nodes: openarray[Node]): bool = for n in nodes: if id == n.id: return true +# Creating a random packet with specific nodeid each time +proc randomPacket(tag: PacketTag): seq[byte] = + var + authTag: AuthTag + msg: array[44, byte] + + randomBytes(authTag) + randomBytes(msg) + result.add(tag) + result.add(rlp.encode(authTag)) + result.add(msg) + suite "Discovery v5 Tests": asyncTest "Random nodes": let @@ -75,37 +87,49 @@ suite "Discovery v5 Tests": for node in nodes: await node.closeWait() - asyncTest "Handshakes": + asyncTest "Handshake cleanup": let node = initDiscoveryNode(newPrivateKey(), localAddress(20302), @[]) - - # Creating a random packet with different nodeid each time - proc randomPacket(): seq[byte] = - var - tag: array[32, byte] - authTag: array[12, byte] - msg: array[44, byte] - - randomBytes(tag) - randomBytes(authTag) - randomBytes(msg) - result.add(tag) - result.add(rlp.encode(authTag)) - result.add(msg) - + var tag: PacketTag let a = localAddress(20303) - for i in 0 ..< 5: - node.receive(a, randomPacket()) + for i in 0 ..< 5: + randomBytes(tag) + node.receive(a, randomPacket(tag)) + + # Checking different nodeIds but same address check node.codec.handshakes.len == 5 + # TODO: Could get rid of the sleep by storing the timeout future of the + # handshake await sleepAsync(handshakeTimeout) # Checking handshake cleanup check node.codec.handshakes.len == 0 - let packet = randomPacket() + await node.closeWait() + + asyncTest "Handshake different address": + let node = initDiscoveryNode(newPrivateKey(), localAddress(20302), @[]) + var tag: PacketTag + for i in 0 ..< 5: - node.receive(a, packet) + let a = localAddress(20303 + i) + node.receive(a, randomPacket(tag)) + + check node.codec.handshakes.len == 5 + + await node.closeWait() + + asyncTest "Handshake duplicates": + let node = initDiscoveryNode(newPrivateKey(), localAddress(20302), @[]) + var tag: PacketTag + let a = localAddress(20303) + + for i in 0 ..< 5: + node.receive(a, randomPacket(tag)) # Checking handshake duplicates check node.codec.handshakes.len == 1 + # TODO: add check that gets the Whoareyou value and checks if its authTag + # is that of the first packet. + await node.closeWait()