Discv5 WIP

This commit is contained in:
Yuriy Glukhov 2019-12-16 21:38:45 +02:00 committed by zah
parent edd674662a
commit eda6c2906c
9 changed files with 1163 additions and 2 deletions

View File

@ -0,0 +1,30 @@
import types
import eth/trie/db
type
DiscoveryDB* = ref object of Database
backend: TrieDatabaseRef
DbKeyKind = enum
kNodeToKeys = 100
proc init*(T: type DiscoveryDB, backend: TrieDatabaseRef): DiscoveryDB =
T(backend: backend)
proc makeKey(id: NodeId, address: int): array[1 + sizeof(id) + sizeof(address), byte] =
result[0] = byte(kNodeToKeys)
copyMem(addr result[1], unsafeAddr id, sizeof(id))
copyMem(addr result[sizeof(id) + 1], unsafeAddr address, sizeof(address))
method storeKeys*(db: DiscoveryDB, id: NodeId, address: int, r, w: array[16, byte]) =
var value: array[sizeof(r) + sizeof(w), byte]
value[0 .. 15] = r
value[16 .. ^1] = w
db.backend.put(makeKey(id, address), value)
method loadKeys*(db: DiscoveryDB, id: NodeId, address: int, r, w: var array[16, byte]): bool =
let res = db.backend.get(makeKey(id, address))
if res.len == sizeof(r) + sizeof(w):
copyMem(addr r[0], unsafeAddr res[0], sizeof(r))
copyMem(addr w[0], unsafeAddr res[sizeof(r)], sizeof(w))
result = true

View File

@ -0,0 +1,274 @@
import tables
import types, node, enr, hkdf, eth/[rlp, keys], nimcrypto, stint
const
idNoncePrefix = "discovery-id-nonce"
gcmNonceSize* = 12
keyAgreementPrefix = "discovery v5 key agreement"
authSchemeName = "gcm"
type
AuthResponse = object
version: int
signature: array[64, byte]
record: Record
Codec* = object
localNode*: Node
privKey*: PrivateKey
db*: Database
handshakes*: Table[string, Whoareyou] # TODO: Implement hash for NodeID
HandshakeSecrets = object
writeKey: array[16, byte]
readKey: array[16, byte]
authRespKey: array[16, byte]
AuthHeader = object
auth: array[12, byte]
idNonce: array[32, byte]
scheme: string
ephemeralKey: array[64, byte]
response: seq[byte]
const
gcmTagSize = 16
proc randomBytes(v: var openarray[byte]) =
if nimcrypto.randomBytes(v) != v.len:
raise newException(Exception, "Could not randomize bytes") # TODO:
proc idNonceHash(nonce, ephkey: openarray[byte]): array[32, byte] =
var ctx: sha256
ctx.init()
ctx.update(idNoncePrefix)
ctx.update(nonce)
ctx.update(ephkey)
ctx.finish().data
proc signIDNonce(c: Codec, idNonce, ephKey: openarray[byte]): SignatureNR =
if signRawMessage(idNonceHash(idNonce, ephKey), c.privKey, result) != EthKeysStatus.Success:
raise newException(Exception, "Could not sign idNonce")
proc deriveKeys(n1, n2: NodeID, priv: PrivateKey, pub: PublicKey, challenge: Whoareyou, result: var HandshakeSecrets) =
var eph: SharedSecretFull
if ecdhAgree(priv, pub, eph) != EthKeysStatus.Success:
raise newException(Exception, "ecdhAgree failed")
# TODO: Unneeded allocation here
var info = newSeqOfCap[byte](idNoncePrefix.len + 32 * 2)
for i, c in keyAgreementPrefix: info.add(byte(c))
info.add(n1.toByteArrayBE())
info.add(n2.toByteArrayBE())
# echo "EPH: ", eph.data.toHex, " idNonce: ", challenge.idNonce.toHex, "info: ", info.toHex
static: assert(sizeof(result) == 16 * 3)
var res = cast[ptr UncheckedArray[byte]](addr result)
hkdf(sha256, eph.data, challenge.idNonce, info, toOpenArray(res, 0, sizeof(result) - 1))
proc encryptGCM(key, nonce, pt, authData: openarray[byte]): seq[byte] =
var ectx: GCM[aes128]
ectx.init(key, nonce, authData)
result = newSeq[byte](pt.len + gcmTagSize)
ectx.encrypt(pt, result)
ectx.getTag(result.toOpenArray(pt.len, result.high))
ectx.clear()
proc makeAuthHeader(c: Codec, toNode: Node, nonce: array[gcmNonceSize, byte],
handhsakeSecrets: var HandshakeSecrets, challenge: Whoareyou): seq[byte] =
var resp = AuthResponse(version: 5)
let ln = c.localNode
if challenge.recordSeq < ln.record.sequenceNumber:
resp.record = ln.record
var remotePubkey: PublicKey
if not toNode.record.get(remotePubkey):
raise newException(Exception, "Could not get public key from remote ENR") # Should not happen!
let ephKey = newPrivateKey()
let ephPubkey = ephKey.getPublicKey().getRaw
resp.signature = c.signIDNonce(challenge.idNonce, ephPubkey).getRaw
deriveKeys(ln.id, toNode.id, ephKey, remotePubkey, challenge, handhsakeSecrets)
let respRlp = rlp.encode(resp)
var zeroNonce: array[gcmNonceSize, byte]
let respEnc = encryptGCM(handhsakeSecrets.authRespKey, zeroNonce, respRLP, [])
let header = AuthHeader(auth: nonce, idNonce: challenge.idNonce, scheme: authSchemeName,
ephemeralKey: ephPubkey, response: respEnc)
rlp.encode(header)
proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] =
for i in 0 .. a.high:
result[i] = a[i] xor b[i]
proc packetTag(destNode, srcNode: NodeID): array[32, byte] =
let destId = destNode.toByteArrayBE()
let srcId = srcNode.toByteArrayBE()
let destidHash = sha256.digest(destId)
result = srcId xor destidHash.data
proc encodeEncrypted*(c: Codec, toNode: Node, toAddr: int, packetData: seq[byte], challenge: Whoareyou): (seq[byte], array[gcmNonceSize, byte]) =
var nonce: array[gcmNonceSize, byte]
randomBytes(nonce)
var headEnc: seq[byte]
var writeKey: array[16, byte]
if challenge.isNil:
headEnc = rlp.encode(nonce)
var readKey: array[16, byte]
# We might not have the node's keys if the handshake hasn't been performed
# yet. That's fine, we will be responded with whoareyou.
discard c.db.loadKeys(toNode.id, toAddr, readKey, writeKey)
else:
var sec: HandshakeSecrets
headEnc = c.makeAuthHeader(toNode, nonce, sec, challenge)
writeKey = sec.writeKey
c.db.storeKeys(toNode.id, toAddr, sec.readKey, sec.writeKey)
var body = packetData
let tag = packetTag(toNode.id, c.localNode.id)
var headBuf = newSeqOfCap[byte](tag.len + headEnc.len)
headBuf.add(tag)
headBuf.add(headEnc)
headBuf.add(encryptGCM(writeKey, nonce, body, tag))
return (headBuf, nonce)
proc decryptGCM(key: array[16, byte], nonce, ct, authData: openarray[byte]): seq[byte] =
var dctx: GCM[aes128]
dctx.init(key, nonce, authData)
result = newSeq[byte](ct.len - gcmTagSize)
var tag: array[gcmTagSize, byte]
dctx.decrypt(ct.toOpenArray(0, ct.high - gcmTagSize), result)
dctx.getTag(tag)
if tag != ct.toOpenArray(ct.len - gcmTagSize, ct.high):
result = @[]
dctx.clear()
proc decodePacketBody(typ: byte, body: openarray[byte], res: var Packet): bool =
if typ >= PacketKind.low.byte and typ <= PacketKind.high.byte:
let kind = cast[PacketKind](typ)
res = Packet(kind: kind)
var rlp = rlpFromBytes(@body.toRange)
rlp.enterList()
res.reqId = rlp.read(RequestId)
proc decode[T](rlp: var Rlp, v: var T) {.inline, nimcall.} =
for k, v in v.fieldPairs:
v = rlp.read(typeof(v))
template decode(k: untyped) =
if k == kind:
decode(rlp, res.k)
result = true
decode(ping)
decode(pong)
decode(findNode)
decode(nodes)
else:
echo "unknown packet: ", typ
return true
proc decodeAuthResp(c: Codec, fromId: NodeId, head: AuthHeader, challenge: Whoareyou, secrets: var HandshakeSecrets, newNode: var Node): bool =
if head.scheme != authSchemeName:
echo "Unknown auth scheme"
return false
var ephKey: PublicKey
if recoverPublicKey(head.ephemeralKey, ephKey) != EthKeysStatus.Success:
return false
deriveKeys(fromId, c.localNode.id, c.privKey, ephKey, challenge, secrets)
var zeroNonce: array[gcmNonceSize, byte]
let respData = decryptGCM(secrets.authRespKey, zeroNonce, head.response, [])
let authResp = rlp.decode(respData, AuthResponse)
newNode = newNode(authResp.record)
return true
proc decodeEncrypted*(c: var Codec, fromId: NodeID, fromAddr: int, input: seq[byte], authTag: var array[12, byte], newNode: var Node, packet: var Packet): bool =
let input = input.toRange
var r = rlpFromBytes(input[32 .. ^1])
let authEndPos = r.currentElemEnd
var auth: AuthHeader
var readKey: array[16, byte]
if r.isList:
# Handshake
# TODO: Auth failure will result in resending whoareyou. Do we really want this?
auth = r.read(AuthHeader)
authTag = auth.auth
let challenge = c.handshakes.getOrDefault($fromId)
if challenge.isNil:
return false
if auth.idNonce != challenge.idNonce:
return false
var sec: HandshakeSecrets
if not c.decodeAuthResp(fromId, auth, challenge, sec, newNode):
return false
c.handshakes.del($fromId)
# Swap keys to match remote
swap(sec.readKey, sec.writeKey)
c.db.storeKeys(fromId, fromAddr, sec.readKey, sec.writeKey)
readKey = sec.readKey
else:
authTag = r.read(array[12, byte])
auth.auth = authTag
var writeKey: array[16, byte]
if not c.db.loadKeys(fromId, fromAddr, readKey, writeKey):
return false
# doAssert(false, "TODO: HANDLE ME!")
let headSize = 32 + r.position
let bodyEnc = input[headSize .. ^1]
let body = decryptGCM(readKey, auth.auth, bodyEnc.toOpenArray, input[0 .. 31].toOpenArray)
if body.len > 1:
result = decodePacketBody(body[0], body.toOpenArray(1, body.high), packet)
proc newRequestId*(): RequestId =
randomBytes(result)
proc numFields(T: typedesc): int =
for k, v in fieldPairs(default(T)): inc result
proc encodePacket*[T: SomePacket](p: T, reqId: RequestId): seq[byte] =
result = newSeqOfCap[byte](64)
result.add(packetKind(T).ord)
# result.add(rlp.encode(p))
const sz = numFields(T)
var writer = initRlpList(sz + 1)
writer.append(reqId)
for k, v in fieldPairs(p):
writer.append(v)
result.add(writer.finish())
proc encodePacket*[T: SomePacket](p: T): seq[byte] =
encodePacket(p, newRequestId())
proc makePingPacket*(enrSeq: uint64): seq[byte] =
encodePacket(PingPacket(enrSeq: enrSeq))
proc makeFindnodePacket*(distance: uint32): seq[byte] =
encodePacket(FindNodePacket(distance: distance))

View File

@ -94,7 +94,7 @@ proc get*[T: seq[byte] | string | SomeInteger](r: Record, key: string, typ: type
if r.getField(key, f):
when typ is SomeInteger:
requireKind(f, kNum)
return f.num
return typ(f.num)
elif typ is seq[byte]:
requireKind(f, kBytes)
return f.bytes
@ -223,3 +223,11 @@ proc `$`*(r: Record): string =
result &= ": "
result &= $v
result &= ')'
proc read*(rlp: var Rlp, T: typedesc[Record]): T {.inline.} =
if not result.fromBytes(rlp.rawData.toOpenArray):
raise newException(ValueError, "Could not deserialize")
rlp.skipElem()
proc append*(rlpWriter: var RlpWriter, value: Record) =
rlpWriter.appendRawBytes(value.raw.toRange)

View File

@ -0,0 +1,182 @@
import nimcrypto
proc hkdf*(HashType: typedesc, secret, salt, info: openarray[byte], output: var openarray[byte]) =
var ctx: HMAC[HashType]
ctx.init(salt)
ctx.update(secret)
let prk = ctx.finish().data
const hashLen = HashType.bits div 8
var t: MDigest[HashType.bits]
var numIters = output.len div hashLen
if output.len mod hashLen != 0:
inc numIters
for i in 0 ..< numIters:
ctx.init(prk)
if i != 0:
ctx.update(t.data)
ctx.update(info)
ctx.update([uint8(i + 1)])
t = ctx.finish()
let iStart = i * hashLen
var sz = hashLen
if iStart + sz >= output.len:
sz = output.len - iStart
copyMem(addr output[iStart], addr t.data, sz)
ctx.clear()
when isMainModule:
import stew/byteutils
proc hextToBytes(s: string): seq[byte] =
if s.len != 0: return hexToSeqByte(s)
template test(constants: untyped) =
block:
constants
let
bikm = hextToBytes(IKM)
bsalt = hextToBytes(salt)
binfo = hextToBytes(info)
bprk {.used.} = hextToBytes(PRK)
bokm = hextToBytes(OKM)
var output = newSeq[byte](L)
hkdf(HashType, bikm, bsalt, binfo, output)
doAssert(output == bokm)
test: # 1
type HashType = sha256
const
IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"
salt = "0x000102030405060708090a0b0c"
info = "0xf0f1f2f3f4f5f6f7f8f9"
L = 42
PRK = "0x077709362c2e32df0ddc3f0dc47bba63" &
"90b6c73bb50f9c3122ec844ad7c2b3e5"
OKM = "0x3cb25f25faacd57a90434f64d0362f2a" &
"2d2d0a90cf1a5a4c5db02d56ecc4c5bf" &
"34007208d5b887185865"
test: # 2
type HashType = sha256
const
IKM = "0x000102030405060708090a0b0c0d0e0f" &
"101112131415161718191a1b1c1d1e1f" &
"202122232425262728292a2b2c2d2e2f" &
"303132333435363738393a3b3c3d3e3f" &
"404142434445464748494a4b4c4d4e4f"
salt = "0x606162636465666768696a6b6c6d6e6f" &
"707172737475767778797a7b7c7d7e7f" &
"808182838485868788898a8b8c8d8e8f" &
"909192939495969798999a9b9c9d9e9f" &
"a0a1a2a3a4a5a6a7a8a9aaabacadaeaf"
info = "0xb0b1b2b3b4b5b6b7b8b9babbbcbdbebf" &
"c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" &
"d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" &
"e0e1e2e3e4e5e6e7e8e9eaebecedeeef" &
"f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"
L = 82
PRK = "0x06a6b88c5853361a06104c9ceb35b45c" &
"ef760014904671014a193f40c15fc244"
OKM = "0xb11e398dc80327a1c8e7f78c596a4934" &
"4f012eda2d4efad8a050cc4c19afa97c" &
"59045a99cac7827271cb41c65e590e09" &
"da3275600c2f09b8367793a9aca3db71" &
"cc30c58179ec3e87c14c01d5c1f3434f" &
"1d87"
test: # 3
type HashType = sha256
const
IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"
salt = ""
info = ""
L = 42
PRK = "0x19ef24a32c717b167f33a91d6f648bdf" &
"96596776afdb6377ac434c1c293ccb04"
OKM = "0x8da4e775a563c18f715f802a063c5a31" &
"b8a11f5c5ee1879ec3454e5f3c738d2d" &
"9d201395faa4b61a96c8"
test: # 4
type HashType = sha1
const
IKM = "0x0b0b0b0b0b0b0b0b0b0b0b"
salt = "0x000102030405060708090a0b0c"
info = "0xf0f1f2f3f4f5f6f7f8f9"
L = 42
PRK = "0x9b6c18c432a7bf8f0e71c8eb88f4b30baa2ba243"
OKM = "0x085a01ea1b10f36933068b56efa5ad81" &
"a4f14b822f5b091568a9cdd4f155fda2" &
"c22e422478d305f3f896"
test: # 5
type HashType = sha1
const
IKM = "0x000102030405060708090a0b0c0d0e0f" &
"101112131415161718191a1b1c1d1e1f" &
"202122232425262728292a2b2c2d2e2f" &
"303132333435363738393a3b3c3d3e3f" &
"404142434445464748494a4b4c4d4e4f"
salt = "0x606162636465666768696a6b6c6d6e6f" &
"707172737475767778797a7b7c7d7e7f" &
"808182838485868788898a8b8c8d8e8f" &
"909192939495969798999a9b9c9d9e9f" &
"a0a1a2a3a4a5a6a7a8a9aaabacadaeaf"
info = "0xb0b1b2b3b4b5b6b7b8b9babbbcbdbebf" &
"c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" &
"d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" &
"e0e1e2e3e4e5e6e7e8e9eaebecedeeef" &
"f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"
L = 82
PRK = "0x8adae09a2a307059478d309b26c4115a224cfaf6"
OKM = "0x0bd770a74d1160f7c9f12cd5912a06eb" &
"ff6adcae899d92191fe4305673ba2ffe" &
"8fa3f1a4e5ad79f3f334b3b202b2173c" &
"486ea37ce3d397ed034c7f9dfeb15c5e" &
"927336d0441f4c4300e2cff0d0900b52" &
"d3b4"
test: # 6
type HashType = sha1
const
IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"
salt = ""
info = ""
L = 42
PRK = "0xda8c8a73c7fa77288ec6f5e7c297786aa0d32d01"
OKM = "0x0ac1af7002b3d761d1e55298da9d0506" &
"b9ae52057220a306e07b6b87e8df21d0" &
"ea00033de03984d34918"
test: # 7
type HashType = sha1
const
IKM = "0x0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c"
salt = ""
info = ""
L = 42
PRK = "0x2adccada18779e7c2077ad2eb19d3f3e731385dd"
OKM = "0x2c91117204d745f3500d636a62f64f0a" &
"b3bae548aa53d423b0d1f27ebba6f5e5" &
"673a081d70cce7acfc48"
block:
var output: array[5, byte]
var secret = [0x01.byte, 0x02, 0x03]
var salt = [0x04.byte, 0x05, 0x06]
var info = [0x07.byte, 0x08, 0x09]
hkdf(sha256, secret, salt, info, output)
doAssert(@output == "D5CF839F63".hextToBytes)

View File

@ -0,0 +1,52 @@
import std/[net, endians, hashes]
import nimcrypto, stint
import types, enr, eth/keys, ../enode
type
Node* = ref object
node*: ENode
id*: NodeId
record*: Record
proc toNodeId*(pk: PublicKey): NodeId =
readUintBE[256](keccak256.digest(pk.getRaw()).data)
proc newNode*(pk: PublicKey, address: Address): Node =
result.new()
result.node = initENode(pk, address)
result.id = pk.toNodeId()
proc newNode*(uriString: string): Node =
result.new()
result.node = initENode(uriString)
result.id = result.node.pubkey.toNodeId()
proc newNode*(enode: ENode): Node =
result.new()
result.node = enode
result.id = result.node.pubkey.toNodeId()
proc newNode*(r: Record): Node =
var a: Address
var pk: PublicKey
# TODO: Handle IPv6
var ip = r.get("ip", int32)
a.ip = IpAddress(family: IpAddressFamily.IPv4)
bigEndian32(addr a.ip.address_v4, addr ip)
a.udpPort = Port(r.get("udp", int))
if recoverPublicKey(r.get("secp256k1", seq[byte]), pk) != EthKeysStatus.Success:
echo "Could not recover public key"
result = newNode(initENode(pk, a))
result.record = r
proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.data)
proc `==`*(a, b: Node): bool = (a.isNil and b.isNil) or (not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey)
proc `$`*(n: Node): string =
if n == nil:
"Node[local]"
else:
"Node[" & $n.node.address.ip & ":" & $n.node.address.udpPort & "]"

View File

@ -0,0 +1,344 @@
import tables, sets, endians, options, math
import types, encoding, node, routing_table, eth/[rlp, keys], chronicles, chronos, ../enode, stint, enr, byteutils
import nimcrypto except toHex
type
Protocol* = ref object
transp: DatagramTransport
localNode: Node
privateKey: PrivateKey
whoareyouMagic: array[32, byte]
idHash: array[32, byte]
pendingRequests: Table[array[12, byte], PendingRequest]
db: Database
routingTable: RoutingTable
codec: Codec
awaitedPackets: Table[(Node, RequestId), Future[Option[Packet]]]
PendingRequest = object
node: Node
packet: seq[byte]
const
lookupRequestLimit = 15
findnodeResultLimit = 15 # applies in FINDNODE handler
proc whoareyouMagic(toNode: NodeId): array[32, byte] =
let srcId = toNode.toByteArrayBE()
var data: seq[byte]
data.add(srcId)
for c in "WHOAREYOU": data.add(byte(c))
sha256.digest(data).data
proc newProtocol*(privKey: PrivateKey, db: Database, port: Port): Protocol =
result = Protocol(privateKey: privKey, db: db)
var a: Address
a.ip = parseIpAddress("127.0.0.1")
a.udpPort = port
var ipAddr: int32
bigEndian32(addr ipAddr, addr a.ip.address_v4)
result.localNode = newNode(initENode(result.privateKey.getPublicKey(), a))
result.localNode.record = initRecord(12, result.privateKey, {"udp": int(a.udpPort), "ip": ipAddr})
let srcId = result.localNode.id.toByteArrayBE()
result.whoareyouMagic = whoareyouMagic(result.localNode.id)
result.idHash = sha256.digest(srcId).data
result.routingTable.init(result.localNode)
result.codec = Codec(localNode: result.localNode, privKey: result.privateKey, db: result.db)
proc start*(p: Protocol) =
discard
proc send(d: Protocol, a: Address, data: seq[byte]) =
# echo "Sending ", data.len, " bytes to ", a
let ta = initTAddress(a.ip, a.udpPort)
let f = d.transp.sendTo(ta, data)
f.callback = proc(data: pointer) {.gcsafe.} =
if f.failed:
debug "Discovery send failed", msg = f.readError.msg
proc send(d: Protocol, n: Node, data: seq[byte]) =
d.send(n.node.address, data)
proc randomBytes(v: var openarray[byte]) =
if nimcrypto.randomBytes(v) != v.len:
raise newException(Exception, "Could not randomize bytes") # TODO:
proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] =
for i in 0 .. a.high:
result[i] = a[i] xor b[i]
proc isWhoAreYou(d: Protocol, msg: Bytes): bool =
if msg.len > d.whoareyouMagic.len:
result = d.whoareyouMagic == msg.toOpenArray(0, 31)
proc decodeWhoAreYou(d: Protocol, msg: Bytes): Whoareyou =
result = Whoareyou()
result[] = rlp.decode(msg.toRange[32 .. ^1], WhoareyouObj)
proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: array[12, byte]) =
let challenge = Whoareyou(authTag: authTag, recordSeq: 1)
randomBytes(challenge.idNonce)
d.codec.handshakes[$toNode] = challenge
var data = @(whoareyouMagic(toNode))
data.add(rlp.encode(challenge[]))
d.send(address, data)
proc sendNodes(d: Protocol, toNode: Node, reqId: RequestId, nodes: openarray[Node]) =
proc sendNodes(d: Protocol, toNode: Node, packet: NodesPacket, reqId: RequestId) {.nimcall.} =
let (data, _) = d.codec.encodeEncrypted(toNode, 12345, encodePacket(packet, reqId), challenge = nil)
d.send(toNode, data)
const maxNodesPerPacket = 3
var packet: NodesPacket
packet.total = ceil(nodes.len / maxNodesPerPacket).uint32
for i in 0 ..< nodes.len:
packet.enrs.add(nodes[i].record)
if packet.enrs.len == 3:
d.sendNodes(toNode, packet, reqId)
packet.enrs.setLen(0)
if packet.enrs.len != 0:
d.sendNodes(toNode, packet, reqId)
proc handlePing(d: Protocol, fromNode: Node, a: Address, ping: PingPacket, reqId: RequestId) =
var pong: PongPacket
pong.enrSeq = ping.enrSeq
pong.ip = case a.ip.family
of IpAddressFamily.IPv4: @(a.ip.address_v4)
of IpAddressFamily.IPv6: @(a.ip.address_v6)
pong.port = a.udpPort.uint16
let (data, _) = d.codec.encodeEncrypted(fromNode, 12345, encodePacket(pong, reqId), challenge = nil)
d.send(fromNode, data)
proc handleFindNode(d: Protocol, fromNode: Node, a: Address, fn: FindNodePacket, reqId: RequestId) =
if fn.distance == 0:
d.sendNodes(fromNode, reqId, [d.localNode])
else:
let distance = min(fn.distance, 256)
d.sendNodes(fromNode, reqId, d.routingTable.neighboursAtDistance(distance))
proc receive*(d: Protocol, a: Address, msg: Bytes) {.gcsafe.} =
## Can raise `DiscProtocolError` and all of `RlpError`
# Note: export only needed for testing
if msg.len < 32:
return # Invalid msg
try:
# echo "Packet received: ", msg.len
if d.isWhoAreYou(msg):
let whoareyou = d.decodeWhoAreYou(msg)
var pr: PendingRequest
if d.pendingRequests.take(whoareyou.authTag, pr):
let toNode = pr.node
let (data, _) = d.codec.encodeEncrypted(toNode, 12345, pr.packet, challenge = whoareyou)
d.send(toNode, data)
else:
var tag: array[32, byte]
tag[0 .. ^1] = msg.toOpenArray(0, 31)
let senderData = tag xor d.idHash
let sender = readUintBE[256](senderData)
var authTag: array[12, byte]
var node: Node
var packet: Packet
if d.codec.decodeEncrypted(sender, 12345, msg, authTag, node, packet):
if node.isNil:
node = d.routingTable.getNode(sender)
else:
echo "Adding new node to routing table"
discard d.routingTable.addNode(node)
doAssert(not node.isNil, "No node in the routing table (internal error?)")
case packet.kind
of ping:
d.handlePing(node, a, packet.ping, packet.reqId)
of findNode:
d.handleFindNode(node, a, packet.findNode, packet.reqId)
else:
var waiter: Future[Option[Packet]]
if d.awaitedPackets.take((node, packet.reqId), waiter):
waiter.complete(packet.some)
else:
echo "TODO: handle packet: ", packet.kind, " from ", node
else:
d.sendWhoareyou(a, sender, authTag)
echo "Could not decode, respond with whoareyou"
except Exception as e:
echo "Exception: ", e.msg
echo e.getStackTrace()
proc waitPacket(d: Protocol, fromNode: Node, reqId: RequestId): Future[Option[Packet]] =
result = newFuture[Option[Packet]]("waitPacket")
let res = result
let key = (fromNode, reqId)
sleepAsync(1000).addCallback() do(data: pointer):
d.awaitedPackets.del(key)
if not res.finished:
res.complete(none(Packet))
d.awaitedPackets[key] = result
proc addNodesFromENRs(result: var seq[Node], enrs: openarray[Record]) =
for r in enrs: result.add(newNode(r))
proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[seq[Node]] {.async.} =
var op = await d.waitPacket(fromNode, reqId)
if op.isSome and op.get.kind == nodes:
result.addNodesFromENRs(op.get.nodes.enrs)
let total = op.get.nodes.total
for i in 1 ..< total:
op = await d.waitPacket(fromNode, reqId)
if op.isSome and op.get.kind == nodes:
result.addNodesFromENRs(op.get.nodes.enrs)
else:
break
proc findNode(d: Protocol, toNode: Node, distance: uint32): Future[seq[Node]] {.async.} =
let reqId = newRequestId()
let packet = encodePacket(FindNodePacket(distance: distance), reqId)
let (data, nonce) = d.codec.encodeEncrypted(toNode, 12345, packet, challenge = nil)
d.pendingRequests[nonce] = PendingRequest(node: toNode, packet: packet)
d.send(toNode, data)
result = await d.waitNodes(toNode, reqId)
proc lookupDistances(target, dest: NodeId): seq[uint32] =
let td = logDist(target, dest)
result.add(td)
var i = 1'u32
while result.len < lookupRequestLimit:
if td + i < 256:
result.add(td + i)
if td - i > 0'u32:
result.add(td - i)
inc i
proc lookupWorker(p: Protocol, destNode: Node, target: NodeId): Future[seq[Node]] {.async.} =
let dists = lookupDistances(target, destNode.id)
var i = 0
while i < lookupRequestLimit and result.len < findNodeResultLimit:
let r = await p.findNode(destNode, dists[i])
# TODO: Handle falures
result.add(r)
inc i
for n in result:
discard p.routingTable.addNode(n)
proc lookup(p: Protocol, target: NodeId): Future[seq[Node]] {.async.} =
result = p.routingTable.neighbours(target, 16)
var asked = initHashSet[NodeId]()
asked.incl(p.localNode.id)
var seen = asked
const alpha = 3
var pendingQueries = newSeqOfCap[Future[seq[Node]]](alpha)
while true:
var i = 0
while i < result.len and pendingQueries.len < alpha:
let n = result[i]
if not asked.containsOrIncl(n.id):
pendingQueries.add(p.lookupWorker(n, target))
inc i
if pendingQueries.len == 0:
break
let idx = await oneIndex(pendingQueries)
let nodes = pendingQueries[idx].read
pendingQueries.del(idx)
for n in nodes:
if not seen.containsOrIncl(n.id):
if result.len < BUCKET_SIZE:
result.add(n)
proc lookupRandom(p: Protocol): Future[seq[Node]] =
var id: NodeId
discard randomBytes(addr id, sizeof(id))
p.lookup(id)
proc processClient(transp: DatagramTransport,
raddr: TransportAddress): Future[void] {.async, gcsafe.} =
var proto = getUserData[Protocol](transp)
try:
# TODO: Maybe here better to use `peekMessage()` to avoid allocation,
# but `Bytes` object is just a simple seq[byte], and `ByteRange` object
# do not support custom length.
var buf = transp.getMessage()
let a = Address(ip: raddr.address, udpPort: raddr.port, tcpPort: raddr.port)
proto.receive(a, buf)
except RlpError:
debug "Receive failed", err = getCurrentExceptionMsg()
except:
debug "Receive failed", err = getCurrentExceptionMsg()
raise
proc open*(d: Protocol) =
# TODO allow binding to specific IP / IPv6 / etc
let ta = initTAddress(IPv4_any(), d.localNode.node.address.udpPort)
d.transp = newDatagramTransport(processClient, udata = d, local = ta)
when isMainModule:
import discovery_db
import eth/trie/db
proc genDiscoveries(n: int): seq[Protocol] =
var pks = ["98b3d4d4fe348ac5192d16b46aa36c41f847b9f265ba4d56f6326669449a968b", "88d125288fbb19ecd7b6a355faf3e842e3c6158d38af14bb97ac8d957ec9cb58", "c9a24471d2f84efa103b9abbdedd4c0fea8402f94e5ceb3ca4d9cff951fc407f"]
for i in 0 ..< n:
var pk: PrivateKey
if i < pks.len:
pk = initPrivateKey(pks[i])
else:
pk = newPrivateKey()
let d = newProtocol(pk, DiscoveryDB.init(newMemoryDB()), Port(12001 + i))
d.open()
result.add(d)
proc addNode(d: Protocol, enr: string) =
var r: Record
let res = r.fromUri(enr)
doAssert(res)
discard d.routingTable.addNode(newNode(r))
proc addNode(d: openarray[Protocol], enr: string) =
for dd in d: dd.addNode(enr)
proc test() {.async.} =
block:
let d = genDiscoveries(3)
d.addNode("enr:-IS4QPvi3TdAUd2Jdrx-8ScRbCzrV1kVsTTM02mfz8Fx7CtrAfYN7AjxTx3MWbY2efRmAhS-Yyv4nhyzKu_YS6jSh08BgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQJeWTAJhJYN2q3BvcQwsyo7pIi8KnfwDIrhNdflCFvqr4N1ZHCCD6A")
for i, dd in d:
let nodes = await dd.lookupRandom()
echo "NODES ", i, ": ", nodes
# block:
# var d = genDiscoveries(4)
# let rootD = d[0]
# d.del(0)
# d.addNode(rootD.localNode.record.toUri)
# for i, dd in d:
# let nodes = await dd.lookupRandom()
# echo "NODES ", i, ": ", nodes
waitFor test()
runForever()

View File

@ -0,0 +1,198 @@
import algorithm, times, sequtils, bitops
import types, node
import stint, chronicles
type
RoutingTable* = object
thisNode: Node
buckets: seq[KBucket]
KBucket = ref object
istart, iend: NodeId
nodes: seq[Node]
replacementCache: seq[Node]
lastUpdated: float # epochTime
const
BUCKET_SIZE* = 16
BITS_PER_HOP = 8
ID_SIZE = 256
proc distanceTo(n: Node, id: NodeId): UInt256 = n.id xor id
proc logDist*(a, b: NodeId): uint32 =
let a = a.toBytes
let b = b.toBytes
var lz = 0
for i in 0 ..< a.len:
let x = a[i] xor b[i]
if x == 0:
result += 8
else:
result += bitops.countLeadingZeroBits(x).uint8
uint32(a.len * 8 - lz)
proc newKBucket(istart, iend: NodeId): KBucket =
result.new()
result.istart = istart
result.iend = iend
result.nodes = @[]
result.replacementCache = @[]
proc midpoint(k: KBucket): NodeId =
k.istart + (k.iend - k.istart) div 2.u256
proc distanceTo(k: KBucket, id: NodeId): UInt256 = k.midpoint xor id
proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] =
sortedByIt(k.nodes, it.distanceTo(id))
proc len(k: KBucket): int {.inline.} = k.nodes.len
proc head(k: KBucket): Node {.inline.} = k.nodes[0]
proc add(k: KBucket, n: Node): Node =
## Try to add the given node to this bucket.
## If the node is already present, it is moved to the tail of the list, and we return nil.
## If the node is not already present and the bucket has fewer than k entries, it is inserted
## at the tail of the list, and we return nil.
## If the bucket is full, we add the node to the bucket's replacement cache and return the
## node at the head of the list (i.e. the least recently seen), which should be evicted if it
## fails to respond to a ping.
k.lastUpdated = epochTime()
let nodeIdx = k.nodes.find(n)
if nodeIdx != -1:
k.nodes.delete(nodeIdx)
k.nodes.add(n)
elif k.len < BUCKET_SIZE:
k.nodes.add(n)
else:
k.replacementCache.add(n)
return k.head
return nil
proc removeNode(k: KBucket, n: Node) =
let i = k.nodes.find(n)
if i != -1: k.nodes.delete(i)
proc split(k: KBucket): tuple[lower, upper: KBucket] =
## Split at the median id
let splitid = k.midpoint
result.lower = newKBucket(k.istart, splitid)
result.upper = newKBucket(splitid + 1.u256, k.iend)
for node in k.nodes:
let bucket = if node.id <= splitid: result.lower else: result.upper
discard bucket.add(node)
for node in k.replacementCache:
let bucket = if node.id <= splitid: result.lower else: result.upper
bucket.replacementCache.add(node)
proc inRange(k: KBucket, n: Node): bool {.inline.} =
k.istart <= n.id and n.id <= k.iend
proc isFull(k: KBucket): bool = k.len == BUCKET_SIZE
proc contains(k: KBucket, n: Node): bool = n in k.nodes
proc binaryGetBucketForNode(buckets: openarray[KBucket],
id: NodeId): KBucket {.inline.} =
## Given a list of ordered buckets, returns the bucket for a given node.
let bucketPos = lowerBound(buckets, id) do(a: KBucket, b: NodeId) -> int:
cmp(a.iend, b)
# Prevents edge cases where bisect_left returns an out of range index
if bucketPos < buckets.len:
let bucket = buckets[bucketPos]
if bucket.istart <= id and id <= bucket.iend:
result = bucket
if result.isNil:
raise newException(ValueError, "No bucket found for node with id " & $id)
proc computeSharedPrefixBits(nodes: openarray[Node]): int =
## Count the number of prefix bits shared by all nodes.
if nodes.len < 2:
return ID_SIZE
var mask = zero(UInt256)
let one = one(UInt256)
for i in 1 .. ID_SIZE:
mask = mask or (one shl (ID_SIZE - i))
let reference = nodes[0].id and mask
for j in 1 .. nodes.high:
if (nodes[j].id and mask) != reference: return i - 1
for n in nodes:
echo n.id.toHex()
doAssert(false, "Unable to calculate number of shared prefix bits")
proc init*(r: var RoutingTable, thisNode: Node) {.inline.} =
r.thisNode = thisNode
r.buckets = @[newKBucket(0.u256, high(Uint256))]
proc splitBucket(r: var RoutingTable, index: int) =
let bucket = r.buckets[index]
let (a, b) = bucket.split()
r.buckets[index] = a
r.buckets.insert(b, index + 1)
proc bucketForNode(r: RoutingTable, id: NodeId): KBucket =
binaryGetBucketForNode(r.buckets, id)
proc removeNode(r: var RoutingTable, n: Node) =
r.bucketForNode(n.id).removeNode(n)
proc addNode*(r: var RoutingTable, n: Node): Node =
if n == r.thisNode:
# warn "Trying to add ourselves to the routing table", node = n
return
let bucket = r.bucketForNode(n.id)
let evictionCandidate = bucket.add(n)
if not evictionCandidate.isNil:
# Split if the bucket has the local node in its range or if the depth is not congruent
# to 0 mod BITS_PER_HOP
let depth = computeSharedPrefixBits(bucket.nodes)
if bucket.inRange(r.thisNode) or (depth mod BITS_PER_HOP != 0 and depth != ID_SIZE):
r.splitBucket(r.buckets.find(bucket))
return r.addNode(n) # retry
# Nothing added, ping evictionCandidate
return evictionCandidate
proc getNode*(r: RoutingTable, id: NodeId): Node =
let b = binaryGetBucketForNode(r.buckets, id)
for n in b.nodes:
if n.id == id:
return n
proc contains(r: RoutingTable, n: Node): bool = n in r.bucketForNode(n.id)
proc bucketsByDistanceTo(r: RoutingTable, id: NodeId): seq[KBucket] =
sortedByIt(r.buckets, it.distanceTo(id))
proc notFullBuckets(r: RoutingTable): seq[KBucket] =
r.buckets.filterIt(not it.isFull)
proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
## Return up to k neighbours of the given node.
result = newSeqOfCap[Node](k * 2)
for bucket in r.bucketsByDistanceTo(id):
for n in bucket.nodesByDistanceTo(id):
if n.id != id:
result.add(n)
if result.len == k * 2:
break
result = sortedByIt(result, it.distanceTo(id))
if result.len > k:
result.setLen(k)
proc idAtDistance(id: NodeId, dist: uint32): NodeId =
id and (Uint256.high shl dist.int)
proc neighboursAtDistance*(r: RoutingTable, distance: uint32, k: int = BUCKET_SIZE): seq[Node] =
r.neighbours(idAtDistance(r.thisNode.id, distance), k)
proc len(r: RoutingTable): int =
for b in r.buckets: result += b.len

View File

@ -0,0 +1,73 @@
import hashes
import ../enode, enr, stint
type
NodeId* = UInt256
WhoareyouObj* = object
authTag*: array[12, byte]
idNonce*: array[32, byte]
recordSeq*: uint64
Whoareyou* = ref WhoareyouObj
Database* = ref object of RootRef
PacketKind* = enum
ping = 0x01
pong = 0x02
findnode = 0x03
nodes = 0x04
regtopic = 0x05
ticket = 0x06
regconfirmation = 0x07
topicquery = 0x08
RequestId* = array[8, byte]
PingPacket* = object
enrSeq*: uint64
PongPacket* = object
enrSeq*: uint64
ip*: seq[byte]
port*: uint16
FindNodePacket* = object
distance*: uint32
NodesPacket* = object
total*: uint32
enrs*: seq[Record]
SomePacket* = PingPacket or PongPacket or FindNodePacket or NodesPacket
Packet* = object
reqId*: RequestId
case kind*: PacketKind
of ping:
ping*: PingPacket
of pong:
pong*: PongPacket
of findnode:
findNode*: FindNodePacket
of nodes:
nodes*: NodesPacket
else:
# TODO: Define the rest
discard
template packetKind*(T: typedesc[SomePacket]): PacketKind =
when T is PingPacket: ping
elif T is PongPacket: pong
elif T is FindNodePacket: findNode
elif T is NodesPacket: nodes
method storeKeys*(db: Database, id: NodeId, address: int, r, w: array[16, byte]) {.base.} = discard
method loadKeys*(db: Database, id: NodeId, address: int, r, w: var array[16, byte]): bool {.base.} = discard
proc toBytes*(id: NodeId): array[32, byte] {.inline.} =
id.toByteArrayBE()
proc hash*(id: NodeId): Hash {.inline.} =
hashData(unsafeAddr id, sizeof(id))

View File

@ -13,7 +13,7 @@ export
type
Rlp* = object
bytes: BytesRange
position: int
position*: int
RlpNodeType* = enum
rlpBlob