mirror of https://github.com/status-im/nim-eth.git
Discv5 WIP
This commit is contained in:
parent
edd674662a
commit
eda6c2906c
|
@ -0,0 +1,30 @@
|
|||
import types
|
||||
import eth/trie/db
|
||||
|
||||
type
|
||||
DiscoveryDB* = ref object of Database
|
||||
backend: TrieDatabaseRef
|
||||
|
||||
DbKeyKind = enum
|
||||
kNodeToKeys = 100
|
||||
|
||||
proc init*(T: type DiscoveryDB, backend: TrieDatabaseRef): DiscoveryDB =
|
||||
T(backend: backend)
|
||||
|
||||
proc makeKey(id: NodeId, address: int): array[1 + sizeof(id) + sizeof(address), byte] =
|
||||
result[0] = byte(kNodeToKeys)
|
||||
copyMem(addr result[1], unsafeAddr id, sizeof(id))
|
||||
copyMem(addr result[sizeof(id) + 1], unsafeAddr address, sizeof(address))
|
||||
|
||||
method storeKeys*(db: DiscoveryDB, id: NodeId, address: int, r, w: array[16, byte]) =
|
||||
var value: array[sizeof(r) + sizeof(w), byte]
|
||||
value[0 .. 15] = r
|
||||
value[16 .. ^1] = w
|
||||
db.backend.put(makeKey(id, address), value)
|
||||
|
||||
method loadKeys*(db: DiscoveryDB, id: NodeId, address: int, r, w: var array[16, byte]): bool =
|
||||
let res = db.backend.get(makeKey(id, address))
|
||||
if res.len == sizeof(r) + sizeof(w):
|
||||
copyMem(addr r[0], unsafeAddr res[0], sizeof(r))
|
||||
copyMem(addr w[0], unsafeAddr res[sizeof(r)], sizeof(w))
|
||||
result = true
|
|
@ -0,0 +1,274 @@
|
|||
import tables
|
||||
import types, node, enr, hkdf, eth/[rlp, keys], nimcrypto, stint
|
||||
|
||||
const
|
||||
idNoncePrefix = "discovery-id-nonce"
|
||||
gcmNonceSize* = 12
|
||||
keyAgreementPrefix = "discovery v5 key agreement"
|
||||
authSchemeName = "gcm"
|
||||
|
||||
type
|
||||
AuthResponse = object
|
||||
version: int
|
||||
signature: array[64, byte]
|
||||
record: Record
|
||||
|
||||
Codec* = object
|
||||
localNode*: Node
|
||||
privKey*: PrivateKey
|
||||
db*: Database
|
||||
handshakes*: Table[string, Whoareyou] # TODO: Implement hash for NodeID
|
||||
|
||||
HandshakeSecrets = object
|
||||
writeKey: array[16, byte]
|
||||
readKey: array[16, byte]
|
||||
authRespKey: array[16, byte]
|
||||
|
||||
AuthHeader = object
|
||||
auth: array[12, byte]
|
||||
idNonce: array[32, byte]
|
||||
scheme: string
|
||||
ephemeralKey: array[64, byte]
|
||||
response: seq[byte]
|
||||
|
||||
|
||||
const
|
||||
gcmTagSize = 16
|
||||
|
||||
proc randomBytes(v: var openarray[byte]) =
|
||||
if nimcrypto.randomBytes(v) != v.len:
|
||||
raise newException(Exception, "Could not randomize bytes") # TODO:
|
||||
|
||||
proc idNonceHash(nonce, ephkey: openarray[byte]): array[32, byte] =
|
||||
var ctx: sha256
|
||||
ctx.init()
|
||||
ctx.update(idNoncePrefix)
|
||||
ctx.update(nonce)
|
||||
ctx.update(ephkey)
|
||||
ctx.finish().data
|
||||
|
||||
proc signIDNonce(c: Codec, idNonce, ephKey: openarray[byte]): SignatureNR =
|
||||
if signRawMessage(idNonceHash(idNonce, ephKey), c.privKey, result) != EthKeysStatus.Success:
|
||||
raise newException(Exception, "Could not sign idNonce")
|
||||
|
||||
proc deriveKeys(n1, n2: NodeID, priv: PrivateKey, pub: PublicKey, challenge: Whoareyou, result: var HandshakeSecrets) =
|
||||
var eph: SharedSecretFull
|
||||
if ecdhAgree(priv, pub, eph) != EthKeysStatus.Success:
|
||||
raise newException(Exception, "ecdhAgree failed")
|
||||
|
||||
# TODO: Unneeded allocation here
|
||||
var info = newSeqOfCap[byte](idNoncePrefix.len + 32 * 2)
|
||||
for i, c in keyAgreementPrefix: info.add(byte(c))
|
||||
info.add(n1.toByteArrayBE())
|
||||
info.add(n2.toByteArrayBE())
|
||||
|
||||
# echo "EPH: ", eph.data.toHex, " idNonce: ", challenge.idNonce.toHex, "info: ", info.toHex
|
||||
|
||||
static: assert(sizeof(result) == 16 * 3)
|
||||
var res = cast[ptr UncheckedArray[byte]](addr result)
|
||||
hkdf(sha256, eph.data, challenge.idNonce, info, toOpenArray(res, 0, sizeof(result) - 1))
|
||||
|
||||
proc encryptGCM(key, nonce, pt, authData: openarray[byte]): seq[byte] =
|
||||
var ectx: GCM[aes128]
|
||||
ectx.init(key, nonce, authData)
|
||||
result = newSeq[byte](pt.len + gcmTagSize)
|
||||
ectx.encrypt(pt, result)
|
||||
ectx.getTag(result.toOpenArray(pt.len, result.high))
|
||||
ectx.clear()
|
||||
|
||||
proc makeAuthHeader(c: Codec, toNode: Node, nonce: array[gcmNonceSize, byte],
|
||||
handhsakeSecrets: var HandshakeSecrets, challenge: Whoareyou): seq[byte] =
|
||||
var resp = AuthResponse(version: 5)
|
||||
let ln = c.localNode
|
||||
|
||||
if challenge.recordSeq < ln.record.sequenceNumber:
|
||||
resp.record = ln.record
|
||||
|
||||
var remotePubkey: PublicKey
|
||||
if not toNode.record.get(remotePubkey):
|
||||
raise newException(Exception, "Could not get public key from remote ENR") # Should not happen!
|
||||
|
||||
let ephKey = newPrivateKey()
|
||||
let ephPubkey = ephKey.getPublicKey().getRaw
|
||||
|
||||
resp.signature = c.signIDNonce(challenge.idNonce, ephPubkey).getRaw
|
||||
|
||||
deriveKeys(ln.id, toNode.id, ephKey, remotePubkey, challenge, handhsakeSecrets)
|
||||
|
||||
let respRlp = rlp.encode(resp)
|
||||
|
||||
var zeroNonce: array[gcmNonceSize, byte]
|
||||
let respEnc = encryptGCM(handhsakeSecrets.authRespKey, zeroNonce, respRLP, [])
|
||||
|
||||
let header = AuthHeader(auth: nonce, idNonce: challenge.idNonce, scheme: authSchemeName,
|
||||
ephemeralKey: ephPubkey, response: respEnc)
|
||||
rlp.encode(header)
|
||||
|
||||
proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] =
|
||||
for i in 0 .. a.high:
|
||||
result[i] = a[i] xor b[i]
|
||||
|
||||
proc packetTag(destNode, srcNode: NodeID): array[32, byte] =
|
||||
let destId = destNode.toByteArrayBE()
|
||||
let srcId = srcNode.toByteArrayBE()
|
||||
let destidHash = sha256.digest(destId)
|
||||
result = srcId xor destidHash.data
|
||||
|
||||
proc encodeEncrypted*(c: Codec, toNode: Node, toAddr: int, packetData: seq[byte], challenge: Whoareyou): (seq[byte], array[gcmNonceSize, byte]) =
|
||||
var nonce: array[gcmNonceSize, byte]
|
||||
randomBytes(nonce)
|
||||
var headEnc: seq[byte]
|
||||
|
||||
var writeKey: array[16, byte]
|
||||
|
||||
if challenge.isNil:
|
||||
headEnc = rlp.encode(nonce)
|
||||
var readKey: array[16, byte]
|
||||
|
||||
# We might not have the node's keys if the handshake hasn't been performed
|
||||
# yet. That's fine, we will be responded with whoareyou.
|
||||
discard c.db.loadKeys(toNode.id, toAddr, readKey, writeKey)
|
||||
else:
|
||||
var sec: HandshakeSecrets
|
||||
headEnc = c.makeAuthHeader(toNode, nonce, sec, challenge)
|
||||
|
||||
writeKey = sec.writeKey
|
||||
|
||||
c.db.storeKeys(toNode.id, toAddr, sec.readKey, sec.writeKey)
|
||||
|
||||
var body = packetData
|
||||
let tag = packetTag(toNode.id, c.localNode.id)
|
||||
|
||||
var headBuf = newSeqOfCap[byte](tag.len + headEnc.len)
|
||||
headBuf.add(tag)
|
||||
headBuf.add(headEnc)
|
||||
|
||||
headBuf.add(encryptGCM(writeKey, nonce, body, tag))
|
||||
return (headBuf, nonce)
|
||||
|
||||
proc decryptGCM(key: array[16, byte], nonce, ct, authData: openarray[byte]): seq[byte] =
|
||||
var dctx: GCM[aes128]
|
||||
dctx.init(key, nonce, authData)
|
||||
result = newSeq[byte](ct.len - gcmTagSize)
|
||||
var tag: array[gcmTagSize, byte]
|
||||
dctx.decrypt(ct.toOpenArray(0, ct.high - gcmTagSize), result)
|
||||
dctx.getTag(tag)
|
||||
if tag != ct.toOpenArray(ct.len - gcmTagSize, ct.high):
|
||||
result = @[]
|
||||
dctx.clear()
|
||||
|
||||
proc decodePacketBody(typ: byte, body: openarray[byte], res: var Packet): bool =
|
||||
if typ >= PacketKind.low.byte and typ <= PacketKind.high.byte:
|
||||
let kind = cast[PacketKind](typ)
|
||||
res = Packet(kind: kind)
|
||||
var rlp = rlpFromBytes(@body.toRange)
|
||||
rlp.enterList()
|
||||
res.reqId = rlp.read(RequestId)
|
||||
|
||||
proc decode[T](rlp: var Rlp, v: var T) {.inline, nimcall.} =
|
||||
for k, v in v.fieldPairs:
|
||||
v = rlp.read(typeof(v))
|
||||
|
||||
template decode(k: untyped) =
|
||||
if k == kind:
|
||||
decode(rlp, res.k)
|
||||
result = true
|
||||
|
||||
decode(ping)
|
||||
decode(pong)
|
||||
decode(findNode)
|
||||
decode(nodes)
|
||||
else:
|
||||
echo "unknown packet: ", typ
|
||||
|
||||
return true
|
||||
|
||||
proc decodeAuthResp(c: Codec, fromId: NodeId, head: AuthHeader, challenge: Whoareyou, secrets: var HandshakeSecrets, newNode: var Node): bool =
|
||||
if head.scheme != authSchemeName:
|
||||
echo "Unknown auth scheme"
|
||||
return false
|
||||
|
||||
var ephKey: PublicKey
|
||||
if recoverPublicKey(head.ephemeralKey, ephKey) != EthKeysStatus.Success:
|
||||
return false
|
||||
|
||||
deriveKeys(fromId, c.localNode.id, c.privKey, ephKey, challenge, secrets)
|
||||
|
||||
var zeroNonce: array[gcmNonceSize, byte]
|
||||
let respData = decryptGCM(secrets.authRespKey, zeroNonce, head.response, [])
|
||||
let authResp = rlp.decode(respData, AuthResponse)
|
||||
|
||||
newNode = newNode(authResp.record)
|
||||
return true
|
||||
|
||||
proc decodeEncrypted*(c: var Codec, fromId: NodeID, fromAddr: int, input: seq[byte], authTag: var array[12, byte], newNode: var Node, packet: var Packet): bool =
|
||||
let input = input.toRange
|
||||
var r = rlpFromBytes(input[32 .. ^1])
|
||||
let authEndPos = r.currentElemEnd
|
||||
var auth: AuthHeader
|
||||
var readKey: array[16, byte]
|
||||
if r.isList:
|
||||
# Handshake
|
||||
|
||||
# TODO: Auth failure will result in resending whoareyou. Do we really want this?
|
||||
auth = r.read(AuthHeader)
|
||||
authTag = auth.auth
|
||||
|
||||
let challenge = c.handshakes.getOrDefault($fromId)
|
||||
if challenge.isNil:
|
||||
return false
|
||||
|
||||
if auth.idNonce != challenge.idNonce:
|
||||
return false
|
||||
|
||||
var sec: HandshakeSecrets
|
||||
if not c.decodeAuthResp(fromId, auth, challenge, sec, newNode):
|
||||
return false
|
||||
c.handshakes.del($fromId)
|
||||
|
||||
# Swap keys to match remote
|
||||
swap(sec.readKey, sec.writeKey)
|
||||
c.db.storeKeys(fromId, fromAddr, sec.readKey, sec.writeKey)
|
||||
readKey = sec.readKey
|
||||
|
||||
else:
|
||||
authTag = r.read(array[12, byte])
|
||||
auth.auth = authTag
|
||||
var writeKey: array[16, byte]
|
||||
if not c.db.loadKeys(fromId, fromAddr, readKey, writeKey):
|
||||
return false
|
||||
# doAssert(false, "TODO: HANDLE ME!")
|
||||
|
||||
let headSize = 32 + r.position
|
||||
let bodyEnc = input[headSize .. ^1]
|
||||
|
||||
let body = decryptGCM(readKey, auth.auth, bodyEnc.toOpenArray, input[0 .. 31].toOpenArray)
|
||||
if body.len > 1:
|
||||
result = decodePacketBody(body[0], body.toOpenArray(1, body.high), packet)
|
||||
|
||||
proc newRequestId*(): RequestId =
|
||||
randomBytes(result)
|
||||
|
||||
proc numFields(T: typedesc): int =
|
||||
for k, v in fieldPairs(default(T)): inc result
|
||||
|
||||
proc encodePacket*[T: SomePacket](p: T, reqId: RequestId): seq[byte] =
|
||||
result = newSeqOfCap[byte](64)
|
||||
result.add(packetKind(T).ord)
|
||||
# result.add(rlp.encode(p))
|
||||
|
||||
const sz = numFields(T)
|
||||
var writer = initRlpList(sz + 1)
|
||||
writer.append(reqId)
|
||||
for k, v in fieldPairs(p):
|
||||
writer.append(v)
|
||||
result.add(writer.finish())
|
||||
|
||||
proc encodePacket*[T: SomePacket](p: T): seq[byte] =
|
||||
encodePacket(p, newRequestId())
|
||||
|
||||
proc makePingPacket*(enrSeq: uint64): seq[byte] =
|
||||
encodePacket(PingPacket(enrSeq: enrSeq))
|
||||
|
||||
proc makeFindnodePacket*(distance: uint32): seq[byte] =
|
||||
encodePacket(FindNodePacket(distance: distance))
|
|
@ -94,7 +94,7 @@ proc get*[T: seq[byte] | string | SomeInteger](r: Record, key: string, typ: type
|
|||
if r.getField(key, f):
|
||||
when typ is SomeInteger:
|
||||
requireKind(f, kNum)
|
||||
return f.num
|
||||
return typ(f.num)
|
||||
elif typ is seq[byte]:
|
||||
requireKind(f, kBytes)
|
||||
return f.bytes
|
||||
|
@ -223,3 +223,11 @@ proc `$`*(r: Record): string =
|
|||
result &= ": "
|
||||
result &= $v
|
||||
result &= ')'
|
||||
|
||||
proc read*(rlp: var Rlp, T: typedesc[Record]): T {.inline.} =
|
||||
if not result.fromBytes(rlp.rawData.toOpenArray):
|
||||
raise newException(ValueError, "Could not deserialize")
|
||||
rlp.skipElem()
|
||||
|
||||
proc append*(rlpWriter: var RlpWriter, value: Record) =
|
||||
rlpWriter.appendRawBytes(value.raw.toRange)
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
import nimcrypto
|
||||
|
||||
proc hkdf*(HashType: typedesc, secret, salt, info: openarray[byte], output: var openarray[byte]) =
|
||||
var ctx: HMAC[HashType]
|
||||
ctx.init(salt)
|
||||
ctx.update(secret)
|
||||
let prk = ctx.finish().data
|
||||
const hashLen = HashType.bits div 8
|
||||
|
||||
var t: MDigest[HashType.bits]
|
||||
|
||||
var numIters = output.len div hashLen
|
||||
if output.len mod hashLen != 0:
|
||||
inc numIters
|
||||
|
||||
for i in 0 ..< numIters:
|
||||
ctx.init(prk)
|
||||
if i != 0:
|
||||
ctx.update(t.data)
|
||||
ctx.update(info)
|
||||
ctx.update([uint8(i + 1)])
|
||||
t = ctx.finish()
|
||||
let iStart = i * hashLen
|
||||
var sz = hashLen
|
||||
if iStart + sz >= output.len:
|
||||
sz = output.len - iStart
|
||||
copyMem(addr output[iStart], addr t.data, sz)
|
||||
|
||||
ctx.clear()
|
||||
|
||||
when isMainModule:
|
||||
import stew/byteutils
|
||||
|
||||
proc hextToBytes(s: string): seq[byte] =
|
||||
if s.len != 0: return hexToSeqByte(s)
|
||||
|
||||
template test(constants: untyped) =
|
||||
block:
|
||||
constants
|
||||
|
||||
let
|
||||
bikm = hextToBytes(IKM)
|
||||
bsalt = hextToBytes(salt)
|
||||
binfo = hextToBytes(info)
|
||||
bprk {.used.} = hextToBytes(PRK)
|
||||
bokm = hextToBytes(OKM)
|
||||
|
||||
var output = newSeq[byte](L)
|
||||
hkdf(HashType, bikm, bsalt, binfo, output)
|
||||
doAssert(output == bokm)
|
||||
|
||||
test: # 1
|
||||
type HashType = sha256
|
||||
const
|
||||
IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"
|
||||
salt = "0x000102030405060708090a0b0c"
|
||||
info = "0xf0f1f2f3f4f5f6f7f8f9"
|
||||
L = 42
|
||||
|
||||
PRK = "0x077709362c2e32df0ddc3f0dc47bba63" &
|
||||
"90b6c73bb50f9c3122ec844ad7c2b3e5"
|
||||
OKM = "0x3cb25f25faacd57a90434f64d0362f2a" &
|
||||
"2d2d0a90cf1a5a4c5db02d56ecc4c5bf" &
|
||||
"34007208d5b887185865"
|
||||
|
||||
test: # 2
|
||||
type HashType = sha256
|
||||
const
|
||||
IKM = "0x000102030405060708090a0b0c0d0e0f" &
|
||||
"101112131415161718191a1b1c1d1e1f" &
|
||||
"202122232425262728292a2b2c2d2e2f" &
|
||||
"303132333435363738393a3b3c3d3e3f" &
|
||||
"404142434445464748494a4b4c4d4e4f"
|
||||
salt = "0x606162636465666768696a6b6c6d6e6f" &
|
||||
"707172737475767778797a7b7c7d7e7f" &
|
||||
"808182838485868788898a8b8c8d8e8f" &
|
||||
"909192939495969798999a9b9c9d9e9f" &
|
||||
"a0a1a2a3a4a5a6a7a8a9aaabacadaeaf"
|
||||
info = "0xb0b1b2b3b4b5b6b7b8b9babbbcbdbebf" &
|
||||
"c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" &
|
||||
"d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" &
|
||||
"e0e1e2e3e4e5e6e7e8e9eaebecedeeef" &
|
||||
"f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"
|
||||
L = 82
|
||||
|
||||
PRK = "0x06a6b88c5853361a06104c9ceb35b45c" &
|
||||
"ef760014904671014a193f40c15fc244"
|
||||
OKM = "0xb11e398dc80327a1c8e7f78c596a4934" &
|
||||
"4f012eda2d4efad8a050cc4c19afa97c" &
|
||||
"59045a99cac7827271cb41c65e590e09" &
|
||||
"da3275600c2f09b8367793a9aca3db71" &
|
||||
"cc30c58179ec3e87c14c01d5c1f3434f" &
|
||||
"1d87"
|
||||
|
||||
test: # 3
|
||||
type HashType = sha256
|
||||
const
|
||||
IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"
|
||||
salt = ""
|
||||
info = ""
|
||||
L = 42
|
||||
|
||||
PRK = "0x19ef24a32c717b167f33a91d6f648bdf" &
|
||||
"96596776afdb6377ac434c1c293ccb04"
|
||||
OKM = "0x8da4e775a563c18f715f802a063c5a31" &
|
||||
"b8a11f5c5ee1879ec3454e5f3c738d2d" &
|
||||
"9d201395faa4b61a96c8"
|
||||
|
||||
test: # 4
|
||||
type HashType = sha1
|
||||
const
|
||||
IKM = "0x0b0b0b0b0b0b0b0b0b0b0b"
|
||||
salt = "0x000102030405060708090a0b0c"
|
||||
info = "0xf0f1f2f3f4f5f6f7f8f9"
|
||||
L = 42
|
||||
|
||||
PRK = "0x9b6c18c432a7bf8f0e71c8eb88f4b30baa2ba243"
|
||||
OKM = "0x085a01ea1b10f36933068b56efa5ad81" &
|
||||
"a4f14b822f5b091568a9cdd4f155fda2" &
|
||||
"c22e422478d305f3f896"
|
||||
|
||||
test: # 5
|
||||
type HashType = sha1
|
||||
const
|
||||
IKM = "0x000102030405060708090a0b0c0d0e0f" &
|
||||
"101112131415161718191a1b1c1d1e1f" &
|
||||
"202122232425262728292a2b2c2d2e2f" &
|
||||
"303132333435363738393a3b3c3d3e3f" &
|
||||
"404142434445464748494a4b4c4d4e4f"
|
||||
salt = "0x606162636465666768696a6b6c6d6e6f" &
|
||||
"707172737475767778797a7b7c7d7e7f" &
|
||||
"808182838485868788898a8b8c8d8e8f" &
|
||||
"909192939495969798999a9b9c9d9e9f" &
|
||||
"a0a1a2a3a4a5a6a7a8a9aaabacadaeaf"
|
||||
info = "0xb0b1b2b3b4b5b6b7b8b9babbbcbdbebf" &
|
||||
"c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" &
|
||||
"d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" &
|
||||
"e0e1e2e3e4e5e6e7e8e9eaebecedeeef" &
|
||||
"f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"
|
||||
L = 82
|
||||
|
||||
PRK = "0x8adae09a2a307059478d309b26c4115a224cfaf6"
|
||||
OKM = "0x0bd770a74d1160f7c9f12cd5912a06eb" &
|
||||
"ff6adcae899d92191fe4305673ba2ffe" &
|
||||
"8fa3f1a4e5ad79f3f334b3b202b2173c" &
|
||||
"486ea37ce3d397ed034c7f9dfeb15c5e" &
|
||||
"927336d0441f4c4300e2cff0d0900b52" &
|
||||
"d3b4"
|
||||
|
||||
test: # 6
|
||||
type HashType = sha1
|
||||
const
|
||||
IKM = "0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"
|
||||
salt = ""
|
||||
info = ""
|
||||
L = 42
|
||||
|
||||
PRK = "0xda8c8a73c7fa77288ec6f5e7c297786aa0d32d01"
|
||||
OKM = "0x0ac1af7002b3d761d1e55298da9d0506" &
|
||||
"b9ae52057220a306e07b6b87e8df21d0" &
|
||||
"ea00033de03984d34918"
|
||||
|
||||
test: # 7
|
||||
type HashType = sha1
|
||||
const
|
||||
IKM = "0x0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c"
|
||||
salt = ""
|
||||
info = ""
|
||||
L = 42
|
||||
|
||||
PRK = "0x2adccada18779e7c2077ad2eb19d3f3e731385dd"
|
||||
OKM = "0x2c91117204d745f3500d636a62f64f0a" &
|
||||
"b3bae548aa53d423b0d1f27ebba6f5e5" &
|
||||
"673a081d70cce7acfc48"
|
||||
|
||||
block:
|
||||
var output: array[5, byte]
|
||||
var secret = [0x01.byte, 0x02, 0x03]
|
||||
var salt = [0x04.byte, 0x05, 0x06]
|
||||
var info = [0x07.byte, 0x08, 0x09]
|
||||
hkdf(sha256, secret, salt, info, output)
|
||||
doAssert(@output == "D5CF839F63".hextToBytes)
|
|
@ -0,0 +1,52 @@
|
|||
import std/[net, endians, hashes]
|
||||
import nimcrypto, stint
|
||||
import types, enr, eth/keys, ../enode
|
||||
|
||||
type
|
||||
Node* = ref object
|
||||
node*: ENode
|
||||
id*: NodeId
|
||||
record*: Record
|
||||
|
||||
proc toNodeId*(pk: PublicKey): NodeId =
|
||||
readUintBE[256](keccak256.digest(pk.getRaw()).data)
|
||||
|
||||
proc newNode*(pk: PublicKey, address: Address): Node =
|
||||
result.new()
|
||||
result.node = initENode(pk, address)
|
||||
result.id = pk.toNodeId()
|
||||
|
||||
proc newNode*(uriString: string): Node =
|
||||
result.new()
|
||||
result.node = initENode(uriString)
|
||||
result.id = result.node.pubkey.toNodeId()
|
||||
|
||||
proc newNode*(enode: ENode): Node =
|
||||
result.new()
|
||||
result.node = enode
|
||||
result.id = result.node.pubkey.toNodeId()
|
||||
|
||||
proc newNode*(r: Record): Node =
|
||||
var a: Address
|
||||
var pk: PublicKey
|
||||
# TODO: Handle IPv6
|
||||
var ip = r.get("ip", int32)
|
||||
|
||||
a.ip = IpAddress(family: IpAddressFamily.IPv4)
|
||||
bigEndian32(addr a.ip.address_v4, addr ip)
|
||||
|
||||
a.udpPort = Port(r.get("udp", int))
|
||||
if recoverPublicKey(r.get("secp256k1", seq[byte]), pk) != EthKeysStatus.Success:
|
||||
echo "Could not recover public key"
|
||||
|
||||
result = newNode(initENode(pk, a))
|
||||
result.record = r
|
||||
|
||||
proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.data)
|
||||
proc `==`*(a, b: Node): bool = (a.isNil and b.isNil) or (not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey)
|
||||
|
||||
proc `$`*(n: Node): string =
|
||||
if n == nil:
|
||||
"Node[local]"
|
||||
else:
|
||||
"Node[" & $n.node.address.ip & ":" & $n.node.address.udpPort & "]"
|
|
@ -0,0 +1,344 @@
|
|||
import tables, sets, endians, options, math
|
||||
import types, encoding, node, routing_table, eth/[rlp, keys], chronicles, chronos, ../enode, stint, enr, byteutils
|
||||
import nimcrypto except toHex
|
||||
|
||||
type
|
||||
Protocol* = ref object
|
||||
transp: DatagramTransport
|
||||
localNode: Node
|
||||
privateKey: PrivateKey
|
||||
whoareyouMagic: array[32, byte]
|
||||
idHash: array[32, byte]
|
||||
pendingRequests: Table[array[12, byte], PendingRequest]
|
||||
db: Database
|
||||
routingTable: RoutingTable
|
||||
codec: Codec
|
||||
awaitedPackets: Table[(Node, RequestId), Future[Option[Packet]]]
|
||||
|
||||
PendingRequest = object
|
||||
node: Node
|
||||
packet: seq[byte]
|
||||
|
||||
const
|
||||
lookupRequestLimit = 15
|
||||
findnodeResultLimit = 15 # applies in FINDNODE handler
|
||||
|
||||
proc whoareyouMagic(toNode: NodeId): array[32, byte] =
|
||||
let srcId = toNode.toByteArrayBE()
|
||||
var data: seq[byte]
|
||||
data.add(srcId)
|
||||
for c in "WHOAREYOU": data.add(byte(c))
|
||||
sha256.digest(data).data
|
||||
|
||||
proc newProtocol*(privKey: PrivateKey, db: Database, port: Port): Protocol =
|
||||
result = Protocol(privateKey: privKey, db: db)
|
||||
var a: Address
|
||||
a.ip = parseIpAddress("127.0.0.1")
|
||||
a.udpPort = port
|
||||
var ipAddr: int32
|
||||
bigEndian32(addr ipAddr, addr a.ip.address_v4)
|
||||
|
||||
result.localNode = newNode(initENode(result.privateKey.getPublicKey(), a))
|
||||
result.localNode.record = initRecord(12, result.privateKey, {"udp": int(a.udpPort), "ip": ipAddr})
|
||||
|
||||
let srcId = result.localNode.id.toByteArrayBE()
|
||||
result.whoareyouMagic = whoareyouMagic(result.localNode.id)
|
||||
|
||||
result.idHash = sha256.digest(srcId).data
|
||||
result.routingTable.init(result.localNode)
|
||||
|
||||
result.codec = Codec(localNode: result.localNode, privKey: result.privateKey, db: result.db)
|
||||
|
||||
proc start*(p: Protocol) =
|
||||
discard
|
||||
|
||||
proc send(d: Protocol, a: Address, data: seq[byte]) =
|
||||
# echo "Sending ", data.len, " bytes to ", a
|
||||
let ta = initTAddress(a.ip, a.udpPort)
|
||||
let f = d.transp.sendTo(ta, data)
|
||||
f.callback = proc(data: pointer) {.gcsafe.} =
|
||||
if f.failed:
|
||||
debug "Discovery send failed", msg = f.readError.msg
|
||||
|
||||
proc send(d: Protocol, n: Node, data: seq[byte]) =
|
||||
d.send(n.node.address, data)
|
||||
|
||||
proc randomBytes(v: var openarray[byte]) =
|
||||
if nimcrypto.randomBytes(v) != v.len:
|
||||
raise newException(Exception, "Could not randomize bytes") # TODO:
|
||||
|
||||
proc `xor`[N: static[int], T](a, b: array[N, T]): array[N, T] =
|
||||
for i in 0 .. a.high:
|
||||
result[i] = a[i] xor b[i]
|
||||
|
||||
proc isWhoAreYou(d: Protocol, msg: Bytes): bool =
|
||||
if msg.len > d.whoareyouMagic.len:
|
||||
result = d.whoareyouMagic == msg.toOpenArray(0, 31)
|
||||
|
||||
proc decodeWhoAreYou(d: Protocol, msg: Bytes): Whoareyou =
|
||||
result = Whoareyou()
|
||||
result[] = rlp.decode(msg.toRange[32 .. ^1], WhoareyouObj)
|
||||
|
||||
proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: array[12, byte]) =
|
||||
let challenge = Whoareyou(authTag: authTag, recordSeq: 1)
|
||||
randomBytes(challenge.idNonce)
|
||||
d.codec.handshakes[$toNode] = challenge
|
||||
var data = @(whoareyouMagic(toNode))
|
||||
data.add(rlp.encode(challenge[]))
|
||||
d.send(address, data)
|
||||
|
||||
proc sendNodes(d: Protocol, toNode: Node, reqId: RequestId, nodes: openarray[Node]) =
|
||||
proc sendNodes(d: Protocol, toNode: Node, packet: NodesPacket, reqId: RequestId) {.nimcall.} =
|
||||
let (data, _) = d.codec.encodeEncrypted(toNode, 12345, encodePacket(packet, reqId), challenge = nil)
|
||||
d.send(toNode, data)
|
||||
|
||||
const maxNodesPerPacket = 3
|
||||
|
||||
var packet: NodesPacket
|
||||
packet.total = ceil(nodes.len / maxNodesPerPacket).uint32
|
||||
|
||||
for i in 0 ..< nodes.len:
|
||||
packet.enrs.add(nodes[i].record)
|
||||
if packet.enrs.len == 3:
|
||||
d.sendNodes(toNode, packet, reqId)
|
||||
packet.enrs.setLen(0)
|
||||
|
||||
if packet.enrs.len != 0:
|
||||
d.sendNodes(toNode, packet, reqId)
|
||||
|
||||
proc handlePing(d: Protocol, fromNode: Node, a: Address, ping: PingPacket, reqId: RequestId) =
|
||||
var pong: PongPacket
|
||||
pong.enrSeq = ping.enrSeq
|
||||
pong.ip = case a.ip.family
|
||||
of IpAddressFamily.IPv4: @(a.ip.address_v4)
|
||||
of IpAddressFamily.IPv6: @(a.ip.address_v6)
|
||||
pong.port = a.udpPort.uint16
|
||||
|
||||
let (data, _) = d.codec.encodeEncrypted(fromNode, 12345, encodePacket(pong, reqId), challenge = nil)
|
||||
d.send(fromNode, data)
|
||||
|
||||
proc handleFindNode(d: Protocol, fromNode: Node, a: Address, fn: FindNodePacket, reqId: RequestId) =
|
||||
if fn.distance == 0:
|
||||
d.sendNodes(fromNode, reqId, [d.localNode])
|
||||
else:
|
||||
let distance = min(fn.distance, 256)
|
||||
d.sendNodes(fromNode, reqId, d.routingTable.neighboursAtDistance(distance))
|
||||
|
||||
proc receive*(d: Protocol, a: Address, msg: Bytes) {.gcsafe.} =
|
||||
## Can raise `DiscProtocolError` and all of `RlpError`
|
||||
# Note: export only needed for testing
|
||||
if msg.len < 32:
|
||||
return # Invalid msg
|
||||
|
||||
try:
|
||||
# echo "Packet received: ", msg.len
|
||||
|
||||
if d.isWhoAreYou(msg):
|
||||
let whoareyou = d.decodeWhoAreYou(msg)
|
||||
var pr: PendingRequest
|
||||
if d.pendingRequests.take(whoareyou.authTag, pr):
|
||||
let toNode = pr.node
|
||||
|
||||
let (data, _) = d.codec.encodeEncrypted(toNode, 12345, pr.packet, challenge = whoareyou)
|
||||
d.send(toNode, data)
|
||||
|
||||
else:
|
||||
var tag: array[32, byte]
|
||||
tag[0 .. ^1] = msg.toOpenArray(0, 31)
|
||||
let senderData = tag xor d.idHash
|
||||
let sender = readUintBE[256](senderData)
|
||||
|
||||
var authTag: array[12, byte]
|
||||
var node: Node
|
||||
var packet: Packet
|
||||
|
||||
if d.codec.decodeEncrypted(sender, 12345, msg, authTag, node, packet):
|
||||
if node.isNil:
|
||||
node = d.routingTable.getNode(sender)
|
||||
else:
|
||||
echo "Adding new node to routing table"
|
||||
discard d.routingTable.addNode(node)
|
||||
|
||||
doAssert(not node.isNil, "No node in the routing table (internal error?)")
|
||||
|
||||
case packet.kind
|
||||
of ping:
|
||||
d.handlePing(node, a, packet.ping, packet.reqId)
|
||||
of findNode:
|
||||
d.handleFindNode(node, a, packet.findNode, packet.reqId)
|
||||
else:
|
||||
var waiter: Future[Option[Packet]]
|
||||
if d.awaitedPackets.take((node, packet.reqId), waiter):
|
||||
waiter.complete(packet.some)
|
||||
else:
|
||||
echo "TODO: handle packet: ", packet.kind, " from ", node
|
||||
|
||||
else:
|
||||
d.sendWhoareyou(a, sender, authTag)
|
||||
echo "Could not decode, respond with whoareyou"
|
||||
|
||||
except Exception as e:
|
||||
echo "Exception: ", e.msg
|
||||
echo e.getStackTrace()
|
||||
|
||||
proc waitPacket(d: Protocol, fromNode: Node, reqId: RequestId): Future[Option[Packet]] =
|
||||
result = newFuture[Option[Packet]]("waitPacket")
|
||||
let res = result
|
||||
let key = (fromNode, reqId)
|
||||
sleepAsync(1000).addCallback() do(data: pointer):
|
||||
d.awaitedPackets.del(key)
|
||||
if not res.finished:
|
||||
res.complete(none(Packet))
|
||||
d.awaitedPackets[key] = result
|
||||
|
||||
proc addNodesFromENRs(result: var seq[Node], enrs: openarray[Record]) =
|
||||
for r in enrs: result.add(newNode(r))
|
||||
|
||||
proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[seq[Node]] {.async.} =
|
||||
var op = await d.waitPacket(fromNode, reqId)
|
||||
if op.isSome and op.get.kind == nodes:
|
||||
result.addNodesFromENRs(op.get.nodes.enrs)
|
||||
let total = op.get.nodes.total
|
||||
for i in 1 ..< total:
|
||||
op = await d.waitPacket(fromNode, reqId)
|
||||
if op.isSome and op.get.kind == nodes:
|
||||
result.addNodesFromENRs(op.get.nodes.enrs)
|
||||
else:
|
||||
break
|
||||
|
||||
proc findNode(d: Protocol, toNode: Node, distance: uint32): Future[seq[Node]] {.async.} =
|
||||
let reqId = newRequestId()
|
||||
let packet = encodePacket(FindNodePacket(distance: distance), reqId)
|
||||
let (data, nonce) = d.codec.encodeEncrypted(toNode, 12345, packet, challenge = nil)
|
||||
d.pendingRequests[nonce] = PendingRequest(node: toNode, packet: packet)
|
||||
d.send(toNode, data)
|
||||
result = await d.waitNodes(toNode, reqId)
|
||||
|
||||
proc lookupDistances(target, dest: NodeId): seq[uint32] =
|
||||
let td = logDist(target, dest)
|
||||
result.add(td)
|
||||
var i = 1'u32
|
||||
while result.len < lookupRequestLimit:
|
||||
if td + i < 256:
|
||||
result.add(td + i)
|
||||
if td - i > 0'u32:
|
||||
result.add(td - i)
|
||||
inc i
|
||||
|
||||
proc lookupWorker(p: Protocol, destNode: Node, target: NodeId): Future[seq[Node]] {.async.} =
|
||||
let dists = lookupDistances(target, destNode.id)
|
||||
var i = 0
|
||||
while i < lookupRequestLimit and result.len < findNodeResultLimit:
|
||||
let r = await p.findNode(destNode, dists[i])
|
||||
# TODO: Handle falures
|
||||
result.add(r)
|
||||
inc i
|
||||
|
||||
for n in result:
|
||||
discard p.routingTable.addNode(n)
|
||||
|
||||
proc lookup(p: Protocol, target: NodeId): Future[seq[Node]] {.async.} =
|
||||
result = p.routingTable.neighbours(target, 16)
|
||||
var asked = initHashSet[NodeId]()
|
||||
asked.incl(p.localNode.id)
|
||||
var seen = asked
|
||||
|
||||
const alpha = 3
|
||||
|
||||
var pendingQueries = newSeqOfCap[Future[seq[Node]]](alpha)
|
||||
|
||||
while true:
|
||||
var i = 0
|
||||
while i < result.len and pendingQueries.len < alpha:
|
||||
let n = result[i]
|
||||
if not asked.containsOrIncl(n.id):
|
||||
pendingQueries.add(p.lookupWorker(n, target))
|
||||
inc i
|
||||
|
||||
if pendingQueries.len == 0:
|
||||
break
|
||||
|
||||
let idx = await oneIndex(pendingQueries)
|
||||
|
||||
let nodes = pendingQueries[idx].read
|
||||
pendingQueries.del(idx)
|
||||
for n in nodes:
|
||||
if not seen.containsOrIncl(n.id):
|
||||
if result.len < BUCKET_SIZE:
|
||||
result.add(n)
|
||||
|
||||
proc lookupRandom(p: Protocol): Future[seq[Node]] =
|
||||
var id: NodeId
|
||||
discard randomBytes(addr id, sizeof(id))
|
||||
p.lookup(id)
|
||||
|
||||
proc processClient(transp: DatagramTransport,
|
||||
raddr: TransportAddress): Future[void] {.async, gcsafe.} =
|
||||
var proto = getUserData[Protocol](transp)
|
||||
try:
|
||||
# TODO: Maybe here better to use `peekMessage()` to avoid allocation,
|
||||
# but `Bytes` object is just a simple seq[byte], and `ByteRange` object
|
||||
# do not support custom length.
|
||||
var buf = transp.getMessage()
|
||||
let a = Address(ip: raddr.address, udpPort: raddr.port, tcpPort: raddr.port)
|
||||
proto.receive(a, buf)
|
||||
except RlpError:
|
||||
debug "Receive failed", err = getCurrentExceptionMsg()
|
||||
except:
|
||||
debug "Receive failed", err = getCurrentExceptionMsg()
|
||||
raise
|
||||
|
||||
proc open*(d: Protocol) =
|
||||
# TODO allow binding to specific IP / IPv6 / etc
|
||||
let ta = initTAddress(IPv4_any(), d.localNode.node.address.udpPort)
|
||||
d.transp = newDatagramTransport(processClient, udata = d, local = ta)
|
||||
|
||||
when isMainModule:
|
||||
import discovery_db
|
||||
import eth/trie/db
|
||||
|
||||
proc genDiscoveries(n: int): seq[Protocol] =
|
||||
var pks = ["98b3d4d4fe348ac5192d16b46aa36c41f847b9f265ba4d56f6326669449a968b", "88d125288fbb19ecd7b6a355faf3e842e3c6158d38af14bb97ac8d957ec9cb58", "c9a24471d2f84efa103b9abbdedd4c0fea8402f94e5ceb3ca4d9cff951fc407f"]
|
||||
for i in 0 ..< n:
|
||||
var pk: PrivateKey
|
||||
if i < pks.len:
|
||||
pk = initPrivateKey(pks[i])
|
||||
else:
|
||||
pk = newPrivateKey()
|
||||
|
||||
let d = newProtocol(pk, DiscoveryDB.init(newMemoryDB()), Port(12001 + i))
|
||||
d.open()
|
||||
result.add(d)
|
||||
|
||||
proc addNode(d: Protocol, enr: string) =
|
||||
var r: Record
|
||||
let res = r.fromUri(enr)
|
||||
doAssert(res)
|
||||
discard d.routingTable.addNode(newNode(r))
|
||||
|
||||
proc addNode(d: openarray[Protocol], enr: string) =
|
||||
for dd in d: dd.addNode(enr)
|
||||
|
||||
proc test() {.async.} =
|
||||
block:
|
||||
let d = genDiscoveries(3)
|
||||
d.addNode("enr:-IS4QPvi3TdAUd2Jdrx-8ScRbCzrV1kVsTTM02mfz8Fx7CtrAfYN7AjxTx3MWbY2efRmAhS-Yyv4nhyzKu_YS6jSh08BgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQJeWTAJhJYN2q3BvcQwsyo7pIi8KnfwDIrhNdflCFvqr4N1ZHCCD6A")
|
||||
|
||||
for i, dd in d:
|
||||
let nodes = await dd.lookupRandom()
|
||||
echo "NODES ", i, ": ", nodes
|
||||
|
||||
# block:
|
||||
# var d = genDiscoveries(4)
|
||||
# let rootD = d[0]
|
||||
# d.del(0)
|
||||
|
||||
|
||||
# d.addNode(rootD.localNode.record.toUri)
|
||||
|
||||
# for i, dd in d:
|
||||
# let nodes = await dd.lookupRandom()
|
||||
# echo "NODES ", i, ": ", nodes
|
||||
|
||||
waitFor test()
|
||||
runForever()
|
|
@ -0,0 +1,198 @@
|
|||
import algorithm, times, sequtils, bitops
|
||||
import types, node
|
||||
import stint, chronicles
|
||||
|
||||
type
|
||||
RoutingTable* = object
|
||||
thisNode: Node
|
||||
buckets: seq[KBucket]
|
||||
|
||||
KBucket = ref object
|
||||
istart, iend: NodeId
|
||||
nodes: seq[Node]
|
||||
replacementCache: seq[Node]
|
||||
lastUpdated: float # epochTime
|
||||
|
||||
const
|
||||
BUCKET_SIZE* = 16
|
||||
BITS_PER_HOP = 8
|
||||
ID_SIZE = 256
|
||||
|
||||
proc distanceTo(n: Node, id: NodeId): UInt256 = n.id xor id
|
||||
proc logDist*(a, b: NodeId): uint32 =
|
||||
let a = a.toBytes
|
||||
let b = b.toBytes
|
||||
var lz = 0
|
||||
for i in 0 ..< a.len:
|
||||
let x = a[i] xor b[i]
|
||||
if x == 0:
|
||||
result += 8
|
||||
else:
|
||||
result += bitops.countLeadingZeroBits(x).uint8
|
||||
uint32(a.len * 8 - lz)
|
||||
|
||||
proc newKBucket(istart, iend: NodeId): KBucket =
|
||||
result.new()
|
||||
result.istart = istart
|
||||
result.iend = iend
|
||||
result.nodes = @[]
|
||||
result.replacementCache = @[]
|
||||
|
||||
proc midpoint(k: KBucket): NodeId =
|
||||
k.istart + (k.iend - k.istart) div 2.u256
|
||||
|
||||
proc distanceTo(k: KBucket, id: NodeId): UInt256 = k.midpoint xor id
|
||||
proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] =
|
||||
sortedByIt(k.nodes, it.distanceTo(id))
|
||||
|
||||
proc len(k: KBucket): int {.inline.} = k.nodes.len
|
||||
proc head(k: KBucket): Node {.inline.} = k.nodes[0]
|
||||
|
||||
proc add(k: KBucket, n: Node): Node =
|
||||
## Try to add the given node to this bucket.
|
||||
|
||||
## If the node is already present, it is moved to the tail of the list, and we return nil.
|
||||
|
||||
## If the node is not already present and the bucket has fewer than k entries, it is inserted
|
||||
## at the tail of the list, and we return nil.
|
||||
|
||||
## If the bucket is full, we add the node to the bucket's replacement cache and return the
|
||||
## node at the head of the list (i.e. the least recently seen), which should be evicted if it
|
||||
## fails to respond to a ping.
|
||||
k.lastUpdated = epochTime()
|
||||
let nodeIdx = k.nodes.find(n)
|
||||
if nodeIdx != -1:
|
||||
k.nodes.delete(nodeIdx)
|
||||
k.nodes.add(n)
|
||||
elif k.len < BUCKET_SIZE:
|
||||
k.nodes.add(n)
|
||||
else:
|
||||
k.replacementCache.add(n)
|
||||
return k.head
|
||||
return nil
|
||||
|
||||
proc removeNode(k: KBucket, n: Node) =
|
||||
let i = k.nodes.find(n)
|
||||
if i != -1: k.nodes.delete(i)
|
||||
|
||||
proc split(k: KBucket): tuple[lower, upper: KBucket] =
|
||||
## Split at the median id
|
||||
let splitid = k.midpoint
|
||||
result.lower = newKBucket(k.istart, splitid)
|
||||
result.upper = newKBucket(splitid + 1.u256, k.iend)
|
||||
for node in k.nodes:
|
||||
let bucket = if node.id <= splitid: result.lower else: result.upper
|
||||
discard bucket.add(node)
|
||||
for node in k.replacementCache:
|
||||
let bucket = if node.id <= splitid: result.lower else: result.upper
|
||||
bucket.replacementCache.add(node)
|
||||
|
||||
proc inRange(k: KBucket, n: Node): bool {.inline.} =
|
||||
k.istart <= n.id and n.id <= k.iend
|
||||
|
||||
proc isFull(k: KBucket): bool = k.len == BUCKET_SIZE
|
||||
|
||||
proc contains(k: KBucket, n: Node): bool = n in k.nodes
|
||||
|
||||
proc binaryGetBucketForNode(buckets: openarray[KBucket],
|
||||
id: NodeId): KBucket {.inline.} =
|
||||
## Given a list of ordered buckets, returns the bucket for a given node.
|
||||
let bucketPos = lowerBound(buckets, id) do(a: KBucket, b: NodeId) -> int:
|
||||
cmp(a.iend, b)
|
||||
# Prevents edge cases where bisect_left returns an out of range index
|
||||
if bucketPos < buckets.len:
|
||||
let bucket = buckets[bucketPos]
|
||||
if bucket.istart <= id and id <= bucket.iend:
|
||||
result = bucket
|
||||
|
||||
if result.isNil:
|
||||
raise newException(ValueError, "No bucket found for node with id " & $id)
|
||||
|
||||
proc computeSharedPrefixBits(nodes: openarray[Node]): int =
|
||||
## Count the number of prefix bits shared by all nodes.
|
||||
if nodes.len < 2:
|
||||
return ID_SIZE
|
||||
|
||||
var mask = zero(UInt256)
|
||||
let one = one(UInt256)
|
||||
|
||||
for i in 1 .. ID_SIZE:
|
||||
mask = mask or (one shl (ID_SIZE - i))
|
||||
let reference = nodes[0].id and mask
|
||||
for j in 1 .. nodes.high:
|
||||
if (nodes[j].id and mask) != reference: return i - 1
|
||||
|
||||
for n in nodes:
|
||||
echo n.id.toHex()
|
||||
|
||||
doAssert(false, "Unable to calculate number of shared prefix bits")
|
||||
|
||||
proc init*(r: var RoutingTable, thisNode: Node) {.inline.} =
|
||||
r.thisNode = thisNode
|
||||
r.buckets = @[newKBucket(0.u256, high(Uint256))]
|
||||
|
||||
proc splitBucket(r: var RoutingTable, index: int) =
|
||||
let bucket = r.buckets[index]
|
||||
let (a, b) = bucket.split()
|
||||
r.buckets[index] = a
|
||||
r.buckets.insert(b, index + 1)
|
||||
|
||||
proc bucketForNode(r: RoutingTable, id: NodeId): KBucket =
|
||||
binaryGetBucketForNode(r.buckets, id)
|
||||
|
||||
proc removeNode(r: var RoutingTable, n: Node) =
|
||||
r.bucketForNode(n.id).removeNode(n)
|
||||
|
||||
proc addNode*(r: var RoutingTable, n: Node): Node =
|
||||
if n == r.thisNode:
|
||||
# warn "Trying to add ourselves to the routing table", node = n
|
||||
return
|
||||
let bucket = r.bucketForNode(n.id)
|
||||
let evictionCandidate = bucket.add(n)
|
||||
if not evictionCandidate.isNil:
|
||||
# Split if the bucket has the local node in its range or if the depth is not congruent
|
||||
# to 0 mod BITS_PER_HOP
|
||||
|
||||
let depth = computeSharedPrefixBits(bucket.nodes)
|
||||
if bucket.inRange(r.thisNode) or (depth mod BITS_PER_HOP != 0 and depth != ID_SIZE):
|
||||
r.splitBucket(r.buckets.find(bucket))
|
||||
return r.addNode(n) # retry
|
||||
|
||||
# Nothing added, ping evictionCandidate
|
||||
return evictionCandidate
|
||||
|
||||
proc getNode*(r: RoutingTable, id: NodeId): Node =
|
||||
let b = binaryGetBucketForNode(r.buckets, id)
|
||||
for n in b.nodes:
|
||||
if n.id == id:
|
||||
return n
|
||||
|
||||
proc contains(r: RoutingTable, n: Node): bool = n in r.bucketForNode(n.id)
|
||||
|
||||
proc bucketsByDistanceTo(r: RoutingTable, id: NodeId): seq[KBucket] =
|
||||
sortedByIt(r.buckets, it.distanceTo(id))
|
||||
|
||||
proc notFullBuckets(r: RoutingTable): seq[KBucket] =
|
||||
r.buckets.filterIt(not it.isFull)
|
||||
|
||||
proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
|
||||
## Return up to k neighbours of the given node.
|
||||
result = newSeqOfCap[Node](k * 2)
|
||||
for bucket in r.bucketsByDistanceTo(id):
|
||||
for n in bucket.nodesByDistanceTo(id):
|
||||
if n.id != id:
|
||||
result.add(n)
|
||||
if result.len == k * 2:
|
||||
break
|
||||
result = sortedByIt(result, it.distanceTo(id))
|
||||
if result.len > k:
|
||||
result.setLen(k)
|
||||
|
||||
proc idAtDistance(id: NodeId, dist: uint32): NodeId =
|
||||
id and (Uint256.high shl dist.int)
|
||||
|
||||
proc neighboursAtDistance*(r: RoutingTable, distance: uint32, k: int = BUCKET_SIZE): seq[Node] =
|
||||
r.neighbours(idAtDistance(r.thisNode.id, distance), k)
|
||||
|
||||
proc len(r: RoutingTable): int =
|
||||
for b in r.buckets: result += b.len
|
|
@ -0,0 +1,73 @@
|
|||
import hashes
|
||||
import ../enode, enr, stint
|
||||
|
||||
type
|
||||
NodeId* = UInt256
|
||||
|
||||
WhoareyouObj* = object
|
||||
authTag*: array[12, byte]
|
||||
idNonce*: array[32, byte]
|
||||
recordSeq*: uint64
|
||||
|
||||
Whoareyou* = ref WhoareyouObj
|
||||
|
||||
Database* = ref object of RootRef
|
||||
|
||||
PacketKind* = enum
|
||||
ping = 0x01
|
||||
pong = 0x02
|
||||
findnode = 0x03
|
||||
nodes = 0x04
|
||||
regtopic = 0x05
|
||||
ticket = 0x06
|
||||
regconfirmation = 0x07
|
||||
topicquery = 0x08
|
||||
|
||||
RequestId* = array[8, byte]
|
||||
|
||||
PingPacket* = object
|
||||
enrSeq*: uint64
|
||||
|
||||
PongPacket* = object
|
||||
enrSeq*: uint64
|
||||
ip*: seq[byte]
|
||||
port*: uint16
|
||||
|
||||
FindNodePacket* = object
|
||||
distance*: uint32
|
||||
|
||||
NodesPacket* = object
|
||||
total*: uint32
|
||||
enrs*: seq[Record]
|
||||
|
||||
SomePacket* = PingPacket or PongPacket or FindNodePacket or NodesPacket
|
||||
|
||||
Packet* = object
|
||||
reqId*: RequestId
|
||||
case kind*: PacketKind
|
||||
of ping:
|
||||
ping*: PingPacket
|
||||
of pong:
|
||||
pong*: PongPacket
|
||||
of findnode:
|
||||
findNode*: FindNodePacket
|
||||
of nodes:
|
||||
nodes*: NodesPacket
|
||||
else:
|
||||
# TODO: Define the rest
|
||||
discard
|
||||
|
||||
template packetKind*(T: typedesc[SomePacket]): PacketKind =
|
||||
when T is PingPacket: ping
|
||||
elif T is PongPacket: pong
|
||||
elif T is FindNodePacket: findNode
|
||||
elif T is NodesPacket: nodes
|
||||
|
||||
method storeKeys*(db: Database, id: NodeId, address: int, r, w: array[16, byte]) {.base.} = discard
|
||||
method loadKeys*(db: Database, id: NodeId, address: int, r, w: var array[16, byte]): bool {.base.} = discard
|
||||
|
||||
proc toBytes*(id: NodeId): array[32, byte] {.inline.} =
|
||||
id.toByteArrayBE()
|
||||
|
||||
proc hash*(id: NodeId): Hash {.inline.} =
|
||||
hashData(unsafeAddr id, sizeof(id))
|
|
@ -13,7 +13,7 @@ export
|
|||
type
|
||||
Rlp* = object
|
||||
bytes: BytesRange
|
||||
position: int
|
||||
position*: int
|
||||
|
||||
RlpNodeType* = enum
|
||||
rlpBlob
|
||||
|
|
Loading…
Reference in New Issue