diff --git a/eth/p2p/discoveryv5/discovery_db.nim b/eth/p2p/discoveryv5/discovery_db.nim new file mode 100644 index 0000000..0fa0ed1 --- /dev/null +++ b/eth/p2p/discoveryv5/discovery_db.nim @@ -0,0 +1,30 @@ +import types +import eth/trie/db + +type + DiscoveryDB* = ref object of Database + backend: TrieDatabaseRef + + DbKeyKind = enum + kNodeToKeys = 100 + +proc init*(T: type DiscoveryDB, backend: TrieDatabaseRef): DiscoveryDB = + T(backend: backend) + +proc makeKey(id: NodeId, address: int): array[1 + sizeof(id) + sizeof(address), byte] = + result[0] = byte(kNodeToKeys) + copyMem(addr result[1], unsafeAddr id, sizeof(id)) + copyMem(addr result[sizeof(id) + 1], unsafeAddr address, sizeof(address)) + +method storeKeys*(db: DiscoveryDB, id: NodeId, address: int, r, w: array[16, byte]) = + var value: array[sizeof(r) + sizeof(w), byte] + value[0 .. 15] = r + value[16 .. ^1] = w + db.backend.put(makeKey(id, address), value) + +method loadKeys*(db: DiscoveryDB, id: NodeId, address: int, r, w: var array[16, byte]): bool = + let res = db.backend.get(makeKey(id, address)) + if res.len == sizeof(r) + sizeof(w): + copyMem(addr r[0], unsafeAddr res[0], sizeof(r)) + copyMem(addr w[0], unsafeAddr res[sizeof(r)], sizeof(w)) + result = true diff --git a/eth/p2p/discoveryv5/encoding.nim b/eth/p2p/discoveryv5/encoding.nim new file mode 100644 index 0000000..e5cbd1d --- /dev/null +++ b/eth/p2p/discoveryv5/encoding.nim @@ -0,0 +1,274 @@ +import tables +import types, node, enr, hkdf, eth/[rlp, keys], nimcrypto, stint + +const + idNoncePrefix = "discovery-id-nonce" + gcmNonceSize* = 12 + keyAgreementPrefix = "discovery v5 key agreement" + authSchemeName = "gcm" + +type + AuthResponse = object + version: int + signature: array[64, byte] + record: Record + + Codec* = object + localNode*: Node + privKey*: PrivateKey + db*: Database + handshakes*: Table[string, Whoareyou] # TODO: Implement hash for NodeID + + HandshakeSecrets = object + writeKey: array[16, byte] + readKey: array[16, byte] + authRespKey: array[16, byte] + + AuthHeader = object + auth: array[12, byte] + idNonce: array[32, byte] + scheme: string + ephemeralKey: array[64, byte] + response: seq[byte] + + +const + gcmTagSize = 16 + +proc randomBytes(v: var openarray[byte]) = + if nimcrypto.randomBytes(v) != v.len: + raise newException(Exception, "Could not randomize bytes") # TODO: + +proc idNonceHash(nonce, ephkey: openarray[byte]): array[32, byte] = + var ctx: sha256 + ctx.init() + ctx.update(idNoncePrefix) + ctx.update(nonce) + ctx.update(ephkey) + ctx.finish().data + +proc signIDNonce(c: Codec, idNonce, ephKey: openarray[byte]): SignatureNR = + if signRawMessage(idNonceHash(idNonce, ephKey), c.privKey, result) != EthKeysStatus.Success: + raise newException(Exception, "Could not sign idNonce") + +proc deriveKeys(n1, n2: NodeID, priv: PrivateKey, pub: PublicKey, challenge: Whoareyou, result: var HandshakeSecrets) = + var eph: SharedSecretFull + if ecdhAgree(priv, pub, eph) != EthKeysStatus.Success: + raise newException(Exception, "ecdhAgree failed") + + # TODO: Unneeded allocation here + var info = newSeqOfCap[byte](idNoncePrefix.len + 32 * 2) + for i, c in keyAgreementPrefix: info.add(byte(c)) + info.add(n1.toByteArrayBE()) + info.add(n2.toByteArrayBE()) + + # echo "EPH: ", eph.data.toHex, " idNonce: ", challenge.idNonce.toHex, "info: ", info.toHex + + static: assert(sizeof(result) == 16 * 3) + var res = cast[ptr UncheckedArray[byte]](addr result) + hkdf(sha256, eph.data, challenge.idNonce, info, toOpenArray(res, 0, sizeof(result) - 1)) + +proc encryptGCM(key, nonce, pt, authData: openarray[byte]): seq[byte] = + var ectx: GCM[aes128] + ectx.init(key, nonce, authData) + result = newSeq[byte](pt.len + gcmTagSize) + ectx.encrypt(pt, result) + ectx.getTag(result.toOpenArray(pt.len, result.high)) + ectx.clear() + +proc makeAuthHeader(c: Codec, toNode: Node, nonce: array[gcmNonceSize, byte], + handhsakeSecrets: var HandshakeSecrets, challenge: Whoareyou): seq[byte] = + var resp = AuthResponse(version: 5) + let ln = c.localNode + + if challenge.recordSeq < ln.record.sequenceNumber: + resp.record = ln.record + + var remotePubkey: PublicKey + if not toNode.record.get(remotePubkey): + raise newException(Exception, "Could not get public key from remote ENR") # Should not happen! + + let ephKey = newPrivateKey() + let ephPubkey = ephKey.getPublicKey().getRaw + + resp.signature = c.signIDNonce(challenge.idNonce, ephPubkey).getRaw + + deriveKeys(ln.id, toNode.id, ephKey, remotePubkey, challenge, handhsakeSecrets) + + let respRlp = rlp.encode(resp) + + var zeroNonce: array[gcmNonceSize, byte] + let respEnc = encryptGCM(handhsakeSecrets.authRespKey, zeroNonce, respRLP, []) + + let header = AuthHeader(auth: nonce, idNonce: challenge.idNonce, scheme: authSchemeName, + ephemeralKey: ephPubkey, response: respEnc) + rlp.encode(header) + +proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] = + for i in 0 .. a.high: + result[i] = a[i] xor b[i] + +proc packetTag(destNode, srcNode: NodeID): array[32, byte] = + let destId = destNode.toByteArrayBE() + let srcId = srcNode.toByteArrayBE() + let destidHash = sha256.digest(destId) + result = srcId xor destidHash.data + +proc encodeEncrypted*(c: Codec, toNode: Node, toAddr: int, packetData: seq[byte], challenge: Whoareyou): (seq[byte], array[gcmNonceSize, byte]) = + var nonce: array[gcmNonceSize, byte] + randomBytes(nonce) + var headEnc: seq[byte] + + var writeKey: array[16, byte] + + if challenge.isNil: + headEnc = rlp.encode(nonce) + var readKey: array[16, byte] + + # We might not have the node's keys if the handshake hasn't been performed + # yet. That's fine, we will be responded with whoareyou. + discard c.db.loadKeys(toNode.id, toAddr, readKey, writeKey) + else: + var sec: HandshakeSecrets + headEnc = c.makeAuthHeader(toNode, nonce, sec, challenge) + + writeKey = sec.writeKey + + c.db.storeKeys(toNode.id, toAddr, sec.readKey, sec.writeKey) + + var body = packetData + let tag = packetTag(toNode.id, c.localNode.id) + + var headBuf = newSeqOfCap[byte](tag.len + headEnc.len) + headBuf.add(tag) + headBuf.add(headEnc) + + headBuf.add(encryptGCM(writeKey, nonce, body, tag)) + return (headBuf, nonce) + +proc decryptGCM(key: array[16, byte], nonce, ct, authData: openarray[byte]): seq[byte] = + var dctx: GCM[aes128] + dctx.init(key, nonce, authData) + result = newSeq[byte](ct.len - gcmTagSize) + var tag: array[gcmTagSize, byte] + dctx.decrypt(ct.toOpenArray(0, ct.high - gcmTagSize), result) + dctx.getTag(tag) + if tag != ct.toOpenArray(ct.len - gcmTagSize, ct.high): + result = @[] + dctx.clear() + +proc decodePacketBody(typ: byte, body: openarray[byte], res: var Packet): bool = + if typ >= PacketKind.low.byte and typ <= PacketKind.high.byte: + let kind = cast[PacketKind](typ) + res = Packet(kind: kind) + var rlp = rlpFromBytes(@body.toRange) + rlp.enterList() + res.reqId = rlp.read(RequestId) + + proc decode[T](rlp: var Rlp, v: var T) {.inline, nimcall.} = + for k, v in v.fieldPairs: + v = rlp.read(typeof(v)) + + template decode(k: untyped) = + if k == kind: + decode(rlp, res.k) + result = true + + decode(ping) + decode(pong) + decode(findNode) + decode(nodes) + else: + echo "unknown packet: ", typ + + return true + +proc decodeAuthResp(c: Codec, fromId: NodeId, head: AuthHeader, challenge: Whoareyou, secrets: var HandshakeSecrets, newNode: var Node): bool = + if head.scheme != authSchemeName: + echo "Unknown auth scheme" + return false + + var ephKey: PublicKey + if recoverPublicKey(head.ephemeralKey, ephKey) != EthKeysStatus.Success: + return false + + deriveKeys(fromId, c.localNode.id, c.privKey, ephKey, challenge, secrets) + + var zeroNonce: array[gcmNonceSize, byte] + let respData = decryptGCM(secrets.authRespKey, zeroNonce, head.response, []) + let authResp = rlp.decode(respData, AuthResponse) + + newNode = newNode(authResp.record) + return true + +proc decodeEncrypted*(c: var Codec, fromId: NodeID, fromAddr: int, input: seq[byte], authTag: var array[12, byte], newNode: var Node, packet: var Packet): bool = + let input = input.toRange + var r = rlpFromBytes(input[32 .. ^1]) + let authEndPos = r.currentElemEnd + var auth: AuthHeader + var readKey: array[16, byte] + if r.isList: + # Handshake + + # TODO: Auth failure will result in resending whoareyou. Do we really want this? + auth = r.read(AuthHeader) + authTag = auth.auth + + let challenge = c.handshakes.getOrDefault($fromId) + if challenge.isNil: + return false + + if auth.idNonce != challenge.idNonce: + return false + + var sec: HandshakeSecrets + if not c.decodeAuthResp(fromId, auth, challenge, sec, newNode): + return false + c.handshakes.del($fromId) + + # Swap keys to match remote + swap(sec.readKey, sec.writeKey) + c.db.storeKeys(fromId, fromAddr, sec.readKey, sec.writeKey) + readKey = sec.readKey + + else: + authTag = r.read(array[12, byte]) + auth.auth = authTag + var writeKey: array[16, byte] + if not c.db.loadKeys(fromId, fromAddr, readKey, writeKey): + return false + # doAssert(false, "TODO: HANDLE ME!") + + let headSize = 32 + r.position + let bodyEnc = input[headSize .. ^1] + + let body = decryptGCM(readKey, auth.auth, bodyEnc.toOpenArray, input[0 .. 31].toOpenArray) + if body.len > 1: + result = decodePacketBody(body[0], body.toOpenArray(1, body.high), packet) + +proc newRequestId*(): RequestId = + randomBytes(result) + +proc numFields(T: typedesc): int = + for k, v in fieldPairs(default(T)): inc result + +proc encodePacket*[T: SomePacket](p: T, reqId: RequestId): seq[byte] = + result = newSeqOfCap[byte](64) + result.add(packetKind(T).ord) + # result.add(rlp.encode(p)) + + const sz = numFields(T) + var writer = initRlpList(sz + 1) + writer.append(reqId) + for k, v in fieldPairs(p): + writer.append(v) + result.add(writer.finish()) + +proc encodePacket*[T: SomePacket](p: T): seq[byte] = + encodePacket(p, newRequestId()) + +proc makePingPacket*(enrSeq: uint64): seq[byte] = + encodePacket(PingPacket(enrSeq: enrSeq)) + +proc makeFindnodePacket*(distance: uint32): seq[byte] = + encodePacket(FindNodePacket(distance: distance)) diff --git a/eth/p2p/discoveryv5/enr.nim b/eth/p2p/discoveryv5/enr.nim index a7f8259..7aea040 100644 --- a/eth/p2p/discoveryv5/enr.nim +++ b/eth/p2p/discoveryv5/enr.nim @@ -94,7 +94,7 @@ proc get*[T: seq[byte] | string | SomeInteger](r: Record, key: string, typ: type if r.getField(key, f): when typ is SomeInteger: requireKind(f, kNum) - return f.num + return typ(f.num) elif typ is seq[byte]: requireKind(f, kBytes) return f.bytes @@ -223,3 +223,11 @@ proc `$`*(r: Record): string = result &= ": " result &= $v result &= ')' + +proc read*(rlp: var Rlp, T: typedesc[Record]): T {.inline.} = + if not result.fromBytes(rlp.rawData.toOpenArray): + raise newException(ValueError, "Could not deserialize") + rlp.skipElem() + +proc append*(rlpWriter: var RlpWriter, value: Record) = + rlpWriter.appendRawBytes(value.raw.toRange) diff --git a/eth/p2p/discoveryv5/hkdf.nim b/eth/p2p/discoveryv5/hkdf.nim new file mode 100644 index 0000000..325b77c --- /dev/null +++ b/eth/p2p/discoveryv5/hkdf.nim @@ -0,0 +1,182 @@ +import nimcrypto + +proc hkdf*(HashType: typedesc, secret, salt, info: openarray[byte], output: var openarray[byte]) = + var ctx: HMAC[HashType] + ctx.init(salt) + ctx.update(secret) + let prk = ctx.finish().data + const hashLen = HashType.bits div 8 + + var t: MDigest[HashType.bits] + + var numIters = output.len div hashLen + if output.len mod hashLen != 0: + inc numIters + + for i in 0 ..< numIters: + ctx.init(prk) + if i != 0: + ctx.update(t.data) + ctx.update(info) + ctx.update([uint8(i + 1)]) + t = ctx.finish() + let iStart = i * hashLen + var sz = hashLen + if iStart + sz >= output.len: + sz = output.len - iStart + copyMem(addr output[iStart], addr t.data, sz) + + ctx.clear() + +when isMainModule: + import stew/byteutils + + proc hextToBytes(s: string): seq[byte] = + if s.len != 0: return hexToSeqByte(s) + + template test(constants: untyped) = + block: + constants + + let + bikm = hextToBytes(IKM) + bsalt = hextToBytes(salt) + binfo = hextToBytes(info) + bprk {.used.} = hextToBytes(PRK) + bokm = hextToBytes(OKM) + + var output = newSeq[byte](L) + hkdf(HashType, bikm, bsalt, binfo, output) + doAssert(output == bokm) + + test: # 1 + type HashType = sha256 + const + IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b" + salt = "0x000102030405060708090a0b0c" + info = "0xf0f1f2f3f4f5f6f7f8f9" + L = 42 + + PRK = "0x077709362c2e32df0ddc3f0dc47bba63" & + "90b6c73bb50f9c3122ec844ad7c2b3e5" + OKM = "0x3cb25f25faacd57a90434f64d0362f2a" & + "2d2d0a90cf1a5a4c5db02d56ecc4c5bf" & + "34007208d5b887185865" + + test: # 2 + type HashType = sha256 + const + IKM = "0x000102030405060708090a0b0c0d0e0f" & + "101112131415161718191a1b1c1d1e1f" & + "202122232425262728292a2b2c2d2e2f" & + "303132333435363738393a3b3c3d3e3f" & + "404142434445464748494a4b4c4d4e4f" + salt = "0x606162636465666768696a6b6c6d6e6f" & + "707172737475767778797a7b7c7d7e7f" & + "808182838485868788898a8b8c8d8e8f" & + "909192939495969798999a9b9c9d9e9f" & + "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf" + info = "0xb0b1b2b3b4b5b6b7b8b9babbbcbdbebf" & + "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" & + "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" & + "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" & + "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff" + L = 82 + + PRK = "0x06a6b88c5853361a06104c9ceb35b45c" & + "ef760014904671014a193f40c15fc244" + OKM = "0xb11e398dc80327a1c8e7f78c596a4934" & + "4f012eda2d4efad8a050cc4c19afa97c" & + "59045a99cac7827271cb41c65e590e09" & + "da3275600c2f09b8367793a9aca3db71" & + "cc30c58179ec3e87c14c01d5c1f3434f" & + "1d87" + + test: # 3 + type HashType = sha256 + const + IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b" + salt = "" + info = "" + L = 42 + + PRK = "0x19ef24a32c717b167f33a91d6f648bdf" & + "96596776afdb6377ac434c1c293ccb04" + OKM = "0x8da4e775a563c18f715f802a063c5a31" & + "b8a11f5c5ee1879ec3454e5f3c738d2d" & + "9d201395faa4b61a96c8" + + test: # 4 + type HashType = sha1 + const + IKM = "0x0b0b0b0b0b0b0b0b0b0b0b" + salt = "0x000102030405060708090a0b0c" + info = "0xf0f1f2f3f4f5f6f7f8f9" + L = 42 + + PRK = "0x9b6c18c432a7bf8f0e71c8eb88f4b30baa2ba243" + OKM = "0x085a01ea1b10f36933068b56efa5ad81" & + "a4f14b822f5b091568a9cdd4f155fda2" & + "c22e422478d305f3f896" + + test: # 5 + type HashType = sha1 + const + IKM = "0x000102030405060708090a0b0c0d0e0f" & + "101112131415161718191a1b1c1d1e1f" & + "202122232425262728292a2b2c2d2e2f" & + "303132333435363738393a3b3c3d3e3f" & + "404142434445464748494a4b4c4d4e4f" + salt = "0x606162636465666768696a6b6c6d6e6f" & + "707172737475767778797a7b7c7d7e7f" & + "808182838485868788898a8b8c8d8e8f" & + "909192939495969798999a9b9c9d9e9f" & + "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf" + info = "0xb0b1b2b3b4b5b6b7b8b9babbbcbdbebf" & + "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" & + "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" & + "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" & + "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff" + L = 82 + + PRK = "0x8adae09a2a307059478d309b26c4115a224cfaf6" + OKM = "0x0bd770a74d1160f7c9f12cd5912a06eb" & + "ff6adcae899d92191fe4305673ba2ffe" & + "8fa3f1a4e5ad79f3f334b3b202b2173c" & + "486ea37ce3d397ed034c7f9dfeb15c5e" & + "927336d0441f4c4300e2cff0d0900b52" & + "d3b4" + + test: # 6 + type HashType = sha1 + const + IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b" + salt = "" + info = "" + L = 42 + + PRK = "0xda8c8a73c7fa77288ec6f5e7c297786aa0d32d01" + OKM = "0x0ac1af7002b3d761d1e55298da9d0506" & + "b9ae52057220a306e07b6b87e8df21d0" & + "ea00033de03984d34918" + + test: # 7 + type HashType = sha1 + const + IKM = "0x0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c" + salt = "" + info = "" + L = 42 + + PRK = "0x2adccada18779e7c2077ad2eb19d3f3e731385dd" + OKM = "0x2c91117204d745f3500d636a62f64f0a" & + "b3bae548aa53d423b0d1f27ebba6f5e5" & + "673a081d70cce7acfc48" + + block: + var output: array[5, byte] + var secret = [0x01.byte, 0x02, 0x03] + var salt = [0x04.byte, 0x05, 0x06] + var info = [0x07.byte, 0x08, 0x09] + hkdf(sha256, secret, salt, info, output) + doAssert(@output == "D5CF839F63".hextToBytes) diff --git a/eth/p2p/discoveryv5/node.nim b/eth/p2p/discoveryv5/node.nim new file mode 100644 index 0000000..03d4439 --- /dev/null +++ b/eth/p2p/discoveryv5/node.nim @@ -0,0 +1,52 @@ +import std/[net, endians, hashes] +import nimcrypto, stint +import types, enr, eth/keys, ../enode + +type + Node* = ref object + node*: ENode + id*: NodeId + record*: Record + +proc toNodeId*(pk: PublicKey): NodeId = + readUintBE[256](keccak256.digest(pk.getRaw()).data) + +proc newNode*(pk: PublicKey, address: Address): Node = + result.new() + result.node = initENode(pk, address) + result.id = pk.toNodeId() + +proc newNode*(uriString: string): Node = + result.new() + result.node = initENode(uriString) + result.id = result.node.pubkey.toNodeId() + +proc newNode*(enode: ENode): Node = + result.new() + result.node = enode + result.id = result.node.pubkey.toNodeId() + +proc newNode*(r: Record): Node = + var a: Address + var pk: PublicKey + # TODO: Handle IPv6 + var ip = r.get("ip", int32) + + a.ip = IpAddress(family: IpAddressFamily.IPv4) + bigEndian32(addr a.ip.address_v4, addr ip) + + a.udpPort = Port(r.get("udp", int)) + if recoverPublicKey(r.get("secp256k1", seq[byte]), pk) != EthKeysStatus.Success: + echo "Could not recover public key" + + result = newNode(initENode(pk, a)) + result.record = r + +proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.data) +proc `==`*(a, b: Node): bool = (a.isNil and b.isNil) or (not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey) + +proc `$`*(n: Node): string = + if n == nil: + "Node[local]" + else: + "Node[" & $n.node.address.ip & ":" & $n.node.address.udpPort & "]" diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim new file mode 100644 index 0000000..1e3a6e3 --- /dev/null +++ b/eth/p2p/discoveryv5/protocol.nim @@ -0,0 +1,344 @@ +import tables, sets, endians, options, math +import types, encoding, node, routing_table, eth/[rlp, keys], chronicles, chronos, ../enode, stint, enr, byteutils +import nimcrypto except toHex + +type + Protocol* = ref object + transp: DatagramTransport + localNode: Node + privateKey: PrivateKey + whoareyouMagic: array[32, byte] + idHash: array[32, byte] + pendingRequests: Table[array[12, byte], PendingRequest] + db: Database + routingTable: RoutingTable + codec: Codec + awaitedPackets: Table[(Node, RequestId), Future[Option[Packet]]] + + PendingRequest = object + node: Node + packet: seq[byte] + +const + lookupRequestLimit = 15 + findnodeResultLimit = 15 # applies in FINDNODE handler + +proc whoareyouMagic(toNode: NodeId): array[32, byte] = + let srcId = toNode.toByteArrayBE() + var data: seq[byte] + data.add(srcId) + for c in "WHOAREYOU": data.add(byte(c)) + sha256.digest(data).data + +proc newProtocol*(privKey: PrivateKey, db: Database, port: Port): Protocol = + result = Protocol(privateKey: privKey, db: db) + var a: Address + a.ip = parseIpAddress("127.0.0.1") + a.udpPort = port + var ipAddr: int32 + bigEndian32(addr ipAddr, addr a.ip.address_v4) + + result.localNode = newNode(initENode(result.privateKey.getPublicKey(), a)) + result.localNode.record = initRecord(12, result.privateKey, {"udp": int(a.udpPort), "ip": ipAddr}) + + let srcId = result.localNode.id.toByteArrayBE() + result.whoareyouMagic = whoareyouMagic(result.localNode.id) + + result.idHash = sha256.digest(srcId).data + result.routingTable.init(result.localNode) + + result.codec = Codec(localNode: result.localNode, privKey: result.privateKey, db: result.db) + +proc start*(p: Protocol) = + discard + +proc send(d: Protocol, a: Address, data: seq[byte]) = + # echo "Sending ", data.len, " bytes to ", a + let ta = initTAddress(a.ip, a.udpPort) + let f = d.transp.sendTo(ta, data) + f.callback = proc(data: pointer) {.gcsafe.} = + if f.failed: + debug "Discovery send failed", msg = f.readError.msg + +proc send(d: Protocol, n: Node, data: seq[byte]) = + d.send(n.node.address, data) + +proc randomBytes(v: var openarray[byte]) = + if nimcrypto.randomBytes(v) != v.len: + raise newException(Exception, "Could not randomize bytes") # TODO: + +proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] = + for i in 0 .. a.high: + result[i] = a[i] xor b[i] + +proc isWhoAreYou(d: Protocol, msg: Bytes): bool = + if msg.len > d.whoareyouMagic.len: + result = d.whoareyouMagic == msg.toOpenArray(0, 31) + +proc decodeWhoAreYou(d: Protocol, msg: Bytes): Whoareyou = + result = Whoareyou() + result[] = rlp.decode(msg.toRange[32 .. ^1], WhoareyouObj) + +proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: array[12, byte]) = + let challenge = Whoareyou(authTag: authTag, recordSeq: 1) + randomBytes(challenge.idNonce) + d.codec.handshakes[$toNode] = challenge + var data = @(whoareyouMagic(toNode)) + data.add(rlp.encode(challenge[])) + d.send(address, data) + +proc sendNodes(d: Protocol, toNode: Node, reqId: RequestId, nodes: openarray[Node]) = + proc sendNodes(d: Protocol, toNode: Node, packet: NodesPacket, reqId: RequestId) {.nimcall.} = + let (data, _) = d.codec.encodeEncrypted(toNode, 12345, encodePacket(packet, reqId), challenge = nil) + d.send(toNode, data) + + const maxNodesPerPacket = 3 + + var packet: NodesPacket + packet.total = ceil(nodes.len / maxNodesPerPacket).uint32 + + for i in 0 ..< nodes.len: + packet.enrs.add(nodes[i].record) + if packet.enrs.len == 3: + d.sendNodes(toNode, packet, reqId) + packet.enrs.setLen(0) + + if packet.enrs.len != 0: + d.sendNodes(toNode, packet, reqId) + +proc handlePing(d: Protocol, fromNode: Node, a: Address, ping: PingPacket, reqId: RequestId) = + var pong: PongPacket + pong.enrSeq = ping.enrSeq + pong.ip = case a.ip.family + of IpAddressFamily.IPv4: @(a.ip.address_v4) + of IpAddressFamily.IPv6: @(a.ip.address_v6) + pong.port = a.udpPort.uint16 + + let (data, _) = d.codec.encodeEncrypted(fromNode, 12345, encodePacket(pong, reqId), challenge = nil) + d.send(fromNode, data) + +proc handleFindNode(d: Protocol, fromNode: Node, a: Address, fn: FindNodePacket, reqId: RequestId) = + if fn.distance == 0: + d.sendNodes(fromNode, reqId, [d.localNode]) + else: + let distance = min(fn.distance, 256) + d.sendNodes(fromNode, reqId, d.routingTable.neighboursAtDistance(distance)) + +proc receive*(d: Protocol, a: Address, msg: Bytes) {.gcsafe.} = + ## Can raise `DiscProtocolError` and all of `RlpError` + # Note: export only needed for testing + if msg.len < 32: + return # Invalid msg + + try: + # echo "Packet received: ", msg.len + + if d.isWhoAreYou(msg): + let whoareyou = d.decodeWhoAreYou(msg) + var pr: PendingRequest + if d.pendingRequests.take(whoareyou.authTag, pr): + let toNode = pr.node + + let (data, _) = d.codec.encodeEncrypted(toNode, 12345, pr.packet, challenge = whoareyou) + d.send(toNode, data) + + else: + var tag: array[32, byte] + tag[0 .. ^1] = msg.toOpenArray(0, 31) + let senderData = tag xor d.idHash + let sender = readUintBE[256](senderData) + + var authTag: array[12, byte] + var node: Node + var packet: Packet + + if d.codec.decodeEncrypted(sender, 12345, msg, authTag, node, packet): + if node.isNil: + node = d.routingTable.getNode(sender) + else: + echo "Adding new node to routing table" + discard d.routingTable.addNode(node) + + doAssert(not node.isNil, "No node in the routing table (internal error?)") + + case packet.kind + of ping: + d.handlePing(node, a, packet.ping, packet.reqId) + of findNode: + d.handleFindNode(node, a, packet.findNode, packet.reqId) + else: + var waiter: Future[Option[Packet]] + if d.awaitedPackets.take((node, packet.reqId), waiter): + waiter.complete(packet.some) + else: + echo "TODO: handle packet: ", packet.kind, " from ", node + + else: + d.sendWhoareyou(a, sender, authTag) + echo "Could not decode, respond with whoareyou" + + except Exception as e: + echo "Exception: ", e.msg + echo e.getStackTrace() + +proc waitPacket(d: Protocol, fromNode: Node, reqId: RequestId): Future[Option[Packet]] = + result = newFuture[Option[Packet]]("waitPacket") + let res = result + let key = (fromNode, reqId) + sleepAsync(1000).addCallback() do(data: pointer): + d.awaitedPackets.del(key) + if not res.finished: + res.complete(none(Packet)) + d.awaitedPackets[key] = result + +proc addNodesFromENRs(result: var seq[Node], enrs: openarray[Record]) = + for r in enrs: result.add(newNode(r)) + +proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[seq[Node]] {.async.} = + var op = await d.waitPacket(fromNode, reqId) + if op.isSome and op.get.kind == nodes: + result.addNodesFromENRs(op.get.nodes.enrs) + let total = op.get.nodes.total + for i in 1 ..< total: + op = await d.waitPacket(fromNode, reqId) + if op.isSome and op.get.kind == nodes: + result.addNodesFromENRs(op.get.nodes.enrs) + else: + break + +proc findNode(d: Protocol, toNode: Node, distance: uint32): Future[seq[Node]] {.async.} = + let reqId = newRequestId() + let packet = encodePacket(FindNodePacket(distance: distance), reqId) + let (data, nonce) = d.codec.encodeEncrypted(toNode, 12345, packet, challenge = nil) + d.pendingRequests[nonce] = PendingRequest(node: toNode, packet: packet) + d.send(toNode, data) + result = await d.waitNodes(toNode, reqId) + +proc lookupDistances(target, dest: NodeId): seq[uint32] = + let td = logDist(target, dest) + result.add(td) + var i = 1'u32 + while result.len < lookupRequestLimit: + if td + i < 256: + result.add(td + i) + if td - i > 0'u32: + result.add(td - i) + inc i + +proc lookupWorker(p: Protocol, destNode: Node, target: NodeId): Future[seq[Node]] {.async.} = + let dists = lookupDistances(target, destNode.id) + var i = 0 + while i < lookupRequestLimit and result.len < findNodeResultLimit: + let r = await p.findNode(destNode, dists[i]) + # TODO: Handle falures + result.add(r) + inc i + + for n in result: + discard p.routingTable.addNode(n) + +proc lookup(p: Protocol, target: NodeId): Future[seq[Node]] {.async.} = + result = p.routingTable.neighbours(target, 16) + var asked = initHashSet[NodeId]() + asked.incl(p.localNode.id) + var seen = asked + + const alpha = 3 + + var pendingQueries = newSeqOfCap[Future[seq[Node]]](alpha) + + while true: + var i = 0 + while i < result.len and pendingQueries.len < alpha: + let n = result[i] + if not asked.containsOrIncl(n.id): + pendingQueries.add(p.lookupWorker(n, target)) + inc i + + if pendingQueries.len == 0: + break + + let idx = await oneIndex(pendingQueries) + + let nodes = pendingQueries[idx].read + pendingQueries.del(idx) + for n in nodes: + if not seen.containsOrIncl(n.id): + if result.len < BUCKET_SIZE: + result.add(n) + +proc lookupRandom(p: Protocol): Future[seq[Node]] = + var id: NodeId + discard randomBytes(addr id, sizeof(id)) + p.lookup(id) + +proc processClient(transp: DatagramTransport, + raddr: TransportAddress): Future[void] {.async, gcsafe.} = + var proto = getUserData[Protocol](transp) + try: + # TODO: Maybe here better to use `peekMessage()` to avoid allocation, + # but `Bytes` object is just a simple seq[byte], and `ByteRange` object + # do not support custom length. + var buf = transp.getMessage() + let a = Address(ip: raddr.address, udpPort: raddr.port, tcpPort: raddr.port) + proto.receive(a, buf) + except RlpError: + debug "Receive failed", err = getCurrentExceptionMsg() + except: + debug "Receive failed", err = getCurrentExceptionMsg() + raise + +proc open*(d: Protocol) = + # TODO allow binding to specific IP / IPv6 / etc + let ta = initTAddress(IPv4_any(), d.localNode.node.address.udpPort) + d.transp = newDatagramTransport(processClient, udata = d, local = ta) + +when isMainModule: + import discovery_db + import eth/trie/db + + proc genDiscoveries(n: int): seq[Protocol] = + var pks = ["98b3d4d4fe348ac5192d16b46aa36c41f847b9f265ba4d56f6326669449a968b", "88d125288fbb19ecd7b6a355faf3e842e3c6158d38af14bb97ac8d957ec9cb58", "c9a24471d2f84efa103b9abbdedd4c0fea8402f94e5ceb3ca4d9cff951fc407f"] + for i in 0 ..< n: + var pk: PrivateKey + if i < pks.len: + pk = initPrivateKey(pks[i]) + else: + pk = newPrivateKey() + + let d = newProtocol(pk, DiscoveryDB.init(newMemoryDB()), Port(12001 + i)) + d.open() + result.add(d) + + proc addNode(d: Protocol, enr: string) = + var r: Record + let res = r.fromUri(enr) + doAssert(res) + discard d.routingTable.addNode(newNode(r)) + + proc addNode(d: openarray[Protocol], enr: string) = + for dd in d: dd.addNode(enr) + + proc test() {.async.} = + block: + let d = genDiscoveries(3) + d.addNode("enr:-IS4QPvi3TdAUd2Jdrx-8ScRbCzrV1kVsTTM02mfz8Fx7CtrAfYN7AjxTx3MWbY2efRmAhS-Yyv4nhyzKu_YS6jSh08BgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQJeWTAJhJYN2q3BvcQwsyo7pIi8KnfwDIrhNdflCFvqr4N1ZHCCD6A") + + for i, dd in d: + let nodes = await dd.lookupRandom() + echo "NODES ", i, ": ", nodes + + # block: + # var d = genDiscoveries(4) + # let rootD = d[0] + # d.del(0) + + + # d.addNode(rootD.localNode.record.toUri) + + # for i, dd in d: + # let nodes = await dd.lookupRandom() + # echo "NODES ", i, ": ", nodes + + waitFor test() + runForever() \ No newline at end of file diff --git a/eth/p2p/discoveryv5/routing_table.nim b/eth/p2p/discoveryv5/routing_table.nim new file mode 100644 index 0000000..9e08a1f --- /dev/null +++ b/eth/p2p/discoveryv5/routing_table.nim @@ -0,0 +1,198 @@ +import algorithm, times, sequtils, bitops +import types, node +import stint, chronicles + +type + RoutingTable* = object + thisNode: Node + buckets: seq[KBucket] + + KBucket = ref object + istart, iend: NodeId + nodes: seq[Node] + replacementCache: seq[Node] + lastUpdated: float # epochTime + +const + BUCKET_SIZE* = 16 + BITS_PER_HOP = 8 + ID_SIZE = 256 + +proc distanceTo(n: Node, id: NodeId): UInt256 = n.id xor id +proc logDist*(a, b: NodeId): uint32 = + let a = a.toBytes + let b = b.toBytes + var lz = 0 + for i in 0 ..< a.len: + let x = a[i] xor b[i] + if x == 0: + result += 8 + else: + result += bitops.countLeadingZeroBits(x).uint8 + uint32(a.len * 8 - lz) + +proc newKBucket(istart, iend: NodeId): KBucket = + result.new() + result.istart = istart + result.iend = iend + result.nodes = @[] + result.replacementCache = @[] + +proc midpoint(k: KBucket): NodeId = + k.istart + (k.iend - k.istart) div 2.u256 + +proc distanceTo(k: KBucket, id: NodeId): UInt256 = k.midpoint xor id +proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] = + sortedByIt(k.nodes, it.distanceTo(id)) + +proc len(k: KBucket): int {.inline.} = k.nodes.len +proc head(k: KBucket): Node {.inline.} = k.nodes[0] + +proc add(k: KBucket, n: Node): Node = + ## Try to add the given node to this bucket. + + ## If the node is already present, it is moved to the tail of the list, and we return nil. + + ## If the node is not already present and the bucket has fewer than k entries, it is inserted + ## at the tail of the list, and we return nil. + + ## If the bucket is full, we add the node to the bucket's replacement cache and return the + ## node at the head of the list (i.e. the least recently seen), which should be evicted if it + ## fails to respond to a ping. + k.lastUpdated = epochTime() + let nodeIdx = k.nodes.find(n) + if nodeIdx != -1: + k.nodes.delete(nodeIdx) + k.nodes.add(n) + elif k.len < BUCKET_SIZE: + k.nodes.add(n) + else: + k.replacementCache.add(n) + return k.head + return nil + +proc removeNode(k: KBucket, n: Node) = + let i = k.nodes.find(n) + if i != -1: k.nodes.delete(i) + +proc split(k: KBucket): tuple[lower, upper: KBucket] = + ## Split at the median id + let splitid = k.midpoint + result.lower = newKBucket(k.istart, splitid) + result.upper = newKBucket(splitid + 1.u256, k.iend) + for node in k.nodes: + let bucket = if node.id <= splitid: result.lower else: result.upper + discard bucket.add(node) + for node in k.replacementCache: + let bucket = if node.id <= splitid: result.lower else: result.upper + bucket.replacementCache.add(node) + +proc inRange(k: KBucket, n: Node): bool {.inline.} = + k.istart <= n.id and n.id <= k.iend + +proc isFull(k: KBucket): bool = k.len == BUCKET_SIZE + +proc contains(k: KBucket, n: Node): bool = n in k.nodes + +proc binaryGetBucketForNode(buckets: openarray[KBucket], + id: NodeId): KBucket {.inline.} = + ## Given a list of ordered buckets, returns the bucket for a given node. + let bucketPos = lowerBound(buckets, id) do(a: KBucket, b: NodeId) -> int: + cmp(a.iend, b) + # Prevents edge cases where bisect_left returns an out of range index + if bucketPos < buckets.len: + let bucket = buckets[bucketPos] + if bucket.istart <= id and id <= bucket.iend: + result = bucket + + if result.isNil: + raise newException(ValueError, "No bucket found for node with id " & $id) + +proc computeSharedPrefixBits(nodes: openarray[Node]): int = + ## Count the number of prefix bits shared by all nodes. + if nodes.len < 2: + return ID_SIZE + + var mask = zero(UInt256) + let one = one(UInt256) + + for i in 1 .. ID_SIZE: + mask = mask or (one shl (ID_SIZE - i)) + let reference = nodes[0].id and mask + for j in 1 .. nodes.high: + if (nodes[j].id and mask) != reference: return i - 1 + + for n in nodes: + echo n.id.toHex() + + doAssert(false, "Unable to calculate number of shared prefix bits") + +proc init*(r: var RoutingTable, thisNode: Node) {.inline.} = + r.thisNode = thisNode + r.buckets = @[newKBucket(0.u256, high(Uint256))] + +proc splitBucket(r: var RoutingTable, index: int) = + let bucket = r.buckets[index] + let (a, b) = bucket.split() + r.buckets[index] = a + r.buckets.insert(b, index + 1) + +proc bucketForNode(r: RoutingTable, id: NodeId): KBucket = + binaryGetBucketForNode(r.buckets, id) + +proc removeNode(r: var RoutingTable, n: Node) = + r.bucketForNode(n.id).removeNode(n) + +proc addNode*(r: var RoutingTable, n: Node): Node = + if n == r.thisNode: + # warn "Trying to add ourselves to the routing table", node = n + return + let bucket = r.bucketForNode(n.id) + let evictionCandidate = bucket.add(n) + if not evictionCandidate.isNil: + # Split if the bucket has the local node in its range or if the depth is not congruent + # to 0 mod BITS_PER_HOP + + let depth = computeSharedPrefixBits(bucket.nodes) + if bucket.inRange(r.thisNode) or (depth mod BITS_PER_HOP != 0 and depth != ID_SIZE): + r.splitBucket(r.buckets.find(bucket)) + return r.addNode(n) # retry + + # Nothing added, ping evictionCandidate + return evictionCandidate + +proc getNode*(r: RoutingTable, id: NodeId): Node = + let b = binaryGetBucketForNode(r.buckets, id) + for n in b.nodes: + if n.id == id: + return n + +proc contains(r: RoutingTable, n: Node): bool = n in r.bucketForNode(n.id) + +proc bucketsByDistanceTo(r: RoutingTable, id: NodeId): seq[KBucket] = + sortedByIt(r.buckets, it.distanceTo(id)) + +proc notFullBuckets(r: RoutingTable): seq[KBucket] = + r.buckets.filterIt(not it.isFull) + +proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE): seq[Node] = + ## Return up to k neighbours of the given node. + result = newSeqOfCap[Node](k * 2) + for bucket in r.bucketsByDistanceTo(id): + for n in bucket.nodesByDistanceTo(id): + if n.id != id: + result.add(n) + if result.len == k * 2: + break + result = sortedByIt(result, it.distanceTo(id)) + if result.len > k: + result.setLen(k) + +proc idAtDistance(id: NodeId, dist: uint32): NodeId = + id and (Uint256.high shl dist.int) + +proc neighboursAtDistance*(r: RoutingTable, distance: uint32, k: int = BUCKET_SIZE): seq[Node] = + r.neighbours(idAtDistance(r.thisNode.id, distance), k) + +proc len(r: RoutingTable): int = + for b in r.buckets: result += b.len diff --git a/eth/p2p/discoveryv5/types.nim b/eth/p2p/discoveryv5/types.nim new file mode 100644 index 0000000..2cb15e3 --- /dev/null +++ b/eth/p2p/discoveryv5/types.nim @@ -0,0 +1,73 @@ +import hashes +import ../enode, enr, stint + +type + NodeId* = UInt256 + + WhoareyouObj* = object + authTag*: array[12, byte] + idNonce*: array[32, byte] + recordSeq*: uint64 + + Whoareyou* = ref WhoareyouObj + + Database* = ref object of RootRef + + PacketKind* = enum + ping = 0x01 + pong = 0x02 + findnode = 0x03 + nodes = 0x04 + regtopic = 0x05 + ticket = 0x06 + regconfirmation = 0x07 + topicquery = 0x08 + + RequestId* = array[8, byte] + + PingPacket* = object + enrSeq*: uint64 + + PongPacket* = object + enrSeq*: uint64 + ip*: seq[byte] + port*: uint16 + + FindNodePacket* = object + distance*: uint32 + + NodesPacket* = object + total*: uint32 + enrs*: seq[Record] + + SomePacket* = PingPacket or PongPacket or FindNodePacket or NodesPacket + + Packet* = object + reqId*: RequestId + case kind*: PacketKind + of ping: + ping*: PingPacket + of pong: + pong*: PongPacket + of findnode: + findNode*: FindNodePacket + of nodes: + nodes*: NodesPacket + else: + # TODO: Define the rest + discard + +template packetKind*(T: typedesc[SomePacket]): PacketKind = + when T is PingPacket: ping + elif T is PongPacket: pong + elif T is FindNodePacket: findNode + elif T is NodesPacket: nodes + +method storeKeys*(db: Database, id: NodeId, address: int, r, w: array[16, byte]) {.base.} = discard +method loadKeys*(db: Database, id: NodeId, address: int, r, w: var array[16, byte]): bool {.base.} = discard + +proc toBytes*(id: NodeId): array[32, byte] {.inline.} = + id.toByteArrayBE() + +proc hash*(id: NodeId): Hash {.inline.} = + hashData(unsafeAddr id, sizeof(id)) diff --git a/eth/rlp.nim b/eth/rlp.nim index 9e45443..bc075cf 100644 --- a/eth/rlp.nim +++ b/eth/rlp.nim @@ -13,7 +13,7 @@ export type Rlp* = object bytes: BytesRange - position: int + position*: int RlpNodeType* = enum rlpBlob