Rlp experimental (#227)

* rlp: remove experimental features

* avoid range library

* trie: avoid reference-unsafe bitrange type
This commit is contained in:
Jacek Sieka 2020-04-20 20:14:39 +02:00 committed by GitHub
parent 1646d78d83
commit fd6caa0fdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
56 changed files with 1032 additions and 1025 deletions

View File

@ -10,15 +10,9 @@ and [Wiki](https://github.com/ethereum/wiki/wiki/RLP).
### Reading RLP data ### Reading RLP data
The `Rlp` type provided by this library represents a cursor over an RLP-encoded The `Rlp` type provided by this library represents a cursor over an RLP-encoded
byte stream. Before instantiating such a cursor, you must convert your byte stream.
input data a `BytesRange` value provided by the [nim-ranges][RNG] library,
which represents an immutable and thus cheap-to-copy sub-range view over an
underlying `seq[byte]` instance:
[RNG]: https://github.com/status-im/nim-ranges
``` nim ``` nim
proc rlpFromBytes*(data: BytesRange): Rlp proc rlpFromBytes*(data: openArray[byte]): Rlp
``` ```
### Streaming API ### Streaming API
@ -67,7 +61,7 @@ type
RlpNode* = object RlpNode* = object
case kind*: RlpNodeType case kind*: RlpNodeType
of rlpBlob: of rlpBlob:
bytes*: BytesRange bytes*: seq[byte]
of rlpList: of rlpList:
elems*: seq[RlpNode] elems*: seq[RlpNode]
``` ```
@ -86,7 +80,7 @@ proc inspect*(self: Rlp, indent = 0): string
The `RlpWriter` type can be used to encode RLP data. Instances are created The `RlpWriter` type can be used to encode RLP data. Instances are created
with the `initRlpWriter` proc. This should be followed by one or more calls with the `initRlpWriter` proc. This should be followed by one or more calls
to `append` which is overloaded to accept arbitrary values. Finally, you can to `append` which is overloaded to accept arbitrary values. Finally, you can
call `finish` to obtain the final `BytesRange`. call `finish` to obtain the final `seq[byte]`.
If the end result should be a RLP list of particular length, you can replace If the end result should be a RLP list of particular length, you can replace
the initial call to `initRlpWriter` with `initRlpList(n)`. Calling `finish` the initial call to `initRlpWriter` with `initRlpList(n)`. Calling `finish`

View File

@ -62,12 +62,9 @@ The primary API for Binary-trie is `set` and `get`.
* set(key, value) --- _store a value associated with a key_ * set(key, value) --- _store a value associated with a key_
* get(key): value --- _get a value using a key_ * get(key): value --- _get a value using a key_
Both `key` and `value` are of `BytesRange` type. And they cannot have zero length. Both `key` and `value` are of `seq[byte]` type. And they cannot have zero length.
You can also use convenience API `get` and `set` which accepts
`Bytes` or `string` (a `string` is conceptually wrong in this context
and may costlier than a `BytesRange`, but it is good for testing purpose).
Getting a non-existent key will return zero length BytesRange. Getting a non-existent key will return zero length seq[byte].
Binary-trie also provide dictionary syntax API for `set` and `get`. Binary-trie also provide dictionary syntax API for `set` and `get`.
* trie[key] = value -- same as `set` * trie[key] = value -- same as `set`
@ -81,11 +78,11 @@ Additional APIs are:
that starts with the same key prefix that starts with the same key prefix
* rootNode() -- get root node * rootNode() -- get root node
* rootNode(node) -- replace the root node * rootNode(node) -- replace the root node
* getRootHash(): `KeccakHash` with `BytesRange` type * getRootHash(): `KeccakHash` with `seq[byte]` type
* getDB(): `DB` -- get flat-db pointer * getDB(): `DB` -- get flat-db pointer
Constructor API: Constructor API:
* initBinaryTrie(DB, rootHash[optional]) -- rootHash has `BytesRange` or KeccakHash type * initBinaryTrie(DB, rootHash[optional]) -- rootHash has `seq[byte]` or KeccakHash type
* init(BinaryTrie, DB, rootHash[optional]) * init(BinaryTrie, DB, rootHash[optional])
Normally you would not set the rootHash when constructing an empty Binary-trie. Normally you would not set the rootHash when constructing an empty Binary-trie.
@ -103,17 +100,17 @@ var db = newMemoryDB()
var trie = initBinaryTrie(db) var trie = initBinaryTrie(db)
trie.set("key1", "value1") trie.set("key1", "value1")
trie.set("key2", "value2") trie.set("key2", "value2")
doAssert trie.get("key1") == "value1".toRange doAssert trie.get("key1") == "value1".toBytes
doAssert trie.get("key2") == "value2".toRange doAssert trie.get("key2") == "value2".toBytes
# delete all subtrie with key prefixes "key" # delete all subtrie with key prefixes "key"
trie.deleteSubtrie("key") trie.deleteSubtrie("key")
doAssert trie.get("key1") == zeroBytesRange doAssert trie.get("key1") == []
doAssert trie.get("key2") == zeroBytesRange doAssert trie.get("key2") == []]
trie["moon"] = "sun" trie["moon"] = "sun"
doAssert "moon" in trie doAssert "moon" in trie
doAssert trie["moon"] == "sun".toRange doAssert trie["moon"] == "sun".toBytes
``` ```
Remember, `set` and `get` are trie operations. A single `set` operation may invoke Remember, `set` and `get` are trie operations. A single `set` operation may invoke
@ -142,12 +139,12 @@ The branch utils consist of these API:
* getTrieNodes(DB; nodeHash): branch * getTrieNodes(DB; nodeHash): branch
`keyPrefix`, `key`, and `value` are bytes container with length greater than zero. `keyPrefix`, `key`, and `value` are bytes container with length greater than zero.
They can be BytesRange, Bytes, and string(again, for convenience and testing purpose). They can be openArray[byte].
`rootHash` and `nodeHash` also bytes container, `rootHash` and `nodeHash` also bytes container,
but they have constraint: must be 32 bytes in length, and it must be a keccak_256 hash value. but they have constraint: must be 32 bytes in length, and it must be a keccak_256 hash value.
`branch` is a list of nodes, or in this case a seq[BytesRange]. `branch` is a list of nodes, or in this case a `seq[seq[byte]]`.
A list? yes, the structure is stored along with the encoded node. A list? yes, the structure is stored along with the encoded node.
Therefore a list is enough to reconstruct the entire trie/branch. Therefore a list is enough to reconstruct the entire trie/branch.
@ -303,14 +300,14 @@ let
trie.set(key1, "value1") trie.set(key1, "value1")
trie.set(key2, "value2") trie.set(key2, "value2")
doAssert trie.get(key1) == "value1".toRange doAssert trie.get(key1) == "value1".toBytes
doAssert trie.get(key2) == "value2".toRange doAssert trie.get(key2) == "value2".toBytes
trie.delete(key1) trie.delete(key1)
doAssert trie.get(key1) == zeroBytesRange doAssert trie.get(key1) == []
trie.delete(key2) trie.delete(key2)
doAssert trie[key2] == zeroBytesRange doAssert trie[key2] == []
``` ```
Remember, `set` and `get` are trie operations. A single `set` operation may invoke Remember, `set` and `get` are trie operations. A single `set` operation may invoke

View File

@ -244,7 +244,7 @@ proc read*(rlp: var Rlp, T: typedesc[StUint]): T {.inline.} =
if bytes.len > 0: if bytes.len > 0:
# be sure the amount of bytes matches the size of the stint # be sure the amount of bytes matches the size of the stint
if bytes.len <= sizeof(result): if bytes.len <= sizeof(result):
result.initFromBytesBE(bytes.toOpenArray) result.initFromBytesBE(bytes)
else: else:
raise newException(RlpTypeMismatch, "Unsigned integer expected, but the source RLP has the wrong length") raise newException(RlpTypeMismatch, "Unsigned integer expected, but the source RLP has the wrong length")
else: else:
@ -375,7 +375,7 @@ method getTrieDB*(db: AbstractChainDB): TrieDatabaseRef {.base, gcsafe.} =
method getCodeByHash*(db: AbstractChainDB, hash: KeccakHash): Blob {.base, gcsafe.} = method getCodeByHash*(db: AbstractChainDB, hash: KeccakHash): Blob {.base, gcsafe.} =
notImplemented() notImplemented()
method getSetting*(db: AbstractChainDB, key: string): Bytes {.base, gcsafe.} = method getSetting*(db: AbstractChainDB, key: string): seq[byte] {.base, gcsafe.} =
notImplemented() notImplemented()
method setSetting*(db: AbstractChainDB, key: string, val: openarray[byte]) {.base, gcsafe.} = method setSetting*(db: AbstractChainDB, key: string, val: openarray[byte]) {.base, gcsafe.} =

View File

@ -6,7 +6,7 @@ proc getAccount*(db: TrieDatabaseRef,
rootHash: KeccakHash, rootHash: KeccakHash,
account: EthAddress): Account = account: EthAddress): Account =
let trie = initSecureHexaryTrie(db, rootHash) let trie = initSecureHexaryTrie(db, rootHash)
let data = trie.get(unnecessary_OpenArrayToRange account) let data = trie.get(account)
if data.len > 0: if data.len > 0:
result = rlp.decode(data, Account) result = rlp.decode(data, Account)
else: else:
@ -21,6 +21,3 @@ proc getContractCode*(chain: AbstractChainDB, req: ContractCodeRequest): Blob {.
proc getStorageNode*(chain: AbstractChainDB, hash: KeccakHash): Blob = proc getStorageNode*(chain: AbstractChainDB, hash: KeccakHash): Blob =
let db = chain.getTrieDB let db = chain.getTrieDB
return db.get(hash.data) return db.get(hash.data)
# let trie = initSecureHexaryTrie(db, emptyRlpHash) # TODO emptyRlpHash is not correct here
# return trie.get(unnecessary_OpenArrayToRange hash.data)

View File

@ -14,7 +14,7 @@
import eth/[keys, rlp], nimcrypto import eth/[keys, rlp], nimcrypto
import ecies import ecies
import stew/[byteutils, endians2, results] import stew/[byteutils, endians2, objects, results]
export results export results
@ -362,9 +362,6 @@ proc decodeAuthMessageV4(h: var Handshake, m: openarray[byte]): AuthResult[void]
proc decodeAuthMessageEip8(h: var Handshake, m: openarray[byte]): AuthResult[void] = proc decodeAuthMessageEip8(h: var Handshake, m: openarray[byte]): AuthResult[void] =
## Decodes EIP-8 AuthMessage. ## Decodes EIP-8 AuthMessage.
var
nonce: Nonce
let size = uint16.fromBytesBE(m) let size = uint16.fromBytesBE(m)
h.expectedLength = int(size) + 2 h.expectedLength = int(size) + 2
if h.expectedLength > len(m): if h.expectedLength > len(m):
@ -374,7 +371,7 @@ proc decodeAuthMessageEip8(h: var Handshake, m: openarray[byte]): AuthResult[voi
toa(m, 0, 2)).isErr: toa(m, 0, 2)).isErr:
return err(EciesError) return err(EciesError)
try: try:
var reader = rlpFromBytes(buffer.toRange()) var reader = rlpFromBytes(buffer)
if not reader.isList() or reader.listLen() < 4: if not reader.isList() or reader.listLen() < 4:
return err(InvalidAuth) return err(InvalidAuth)
if reader.listElem(0).blobLen != RawSignatureSize: if reader.listElem(0).blobLen != RawSignatureSize:
@ -385,29 +382,28 @@ proc decodeAuthMessageEip8(h: var Handshake, m: openarray[byte]): AuthResult[voi
return err(InvalidAuth) return err(InvalidAuth)
if reader.listElem(3).blobLen != 1: if reader.listElem(3).blobLen != 1:
return err(InvalidAuth) return err(InvalidAuth)
var signatureBr = reader.listElem(0).toBytes() let
var pubkeyBr = reader.listElem(1).toBytes() signatureBr = reader.listElem(0).toBytes()
var nonceBr = reader.listElem(2).toBytes() pubkeyBr = reader.listElem(1).toBytes()
var versionBr = reader.listElem(3).toBytes() nonceBr = reader.listElem(2).toBytes()
versionBr = reader.listElem(3).toBytes()
let pubkey = let
? PublicKey.fromRaw(pubkeyBr.toOpenArray()).mapErrTo(InvalidPubKey) signature = ? Signature.fromRaw(signatureBr).mapErrTo(SignatureError)
pubkey = ? PublicKey.fromRaw(pubkeyBr).mapErrTo(InvalidPubKey)
copyMem(addr nonce[0], nonceBr.baseAddr, KeyLength) nonce = toArray(KeyLength, nonceBr)
var secret = ? ecdhRaw(h.host.seckey, pubkey).mapErrTo(EcdhError) var secret = ? ecdhRaw(h.host.seckey, pubkey).mapErrTo(EcdhError)
let xornonce = nonce xor secret.data let xornonce = nonce xor secret.data
secret.clear() secret.clear()
let signature =
? Signature.fromRaw(signatureBr.toOpenArray()).mapErrTo(SignatureError)
h.remoteEPubkey = h.remoteEPubkey =
? recover(signature, SkMessage(data: xornonce)).mapErrTo(SignatureError) ? recover(signature, SkMessage(data: xornonce)).mapErrTo(SignatureError)
h.initiatorNonce = nonce h.initiatorNonce = nonce
h.remoteHPubkey = pubkey h.remoteHPubkey = pubkey
h.version = cast[ptr byte](versionBr.baseAddr)[] h.version = versionBr[0]
ok() ok()
except CatchableError: except CatchableError:
err(RlpError) err(RlpError)
@ -424,7 +420,7 @@ proc decodeAckMessageEip8*(h: var Handshake, m: openarray[byte]): AuthResult[voi
toa(m, 0, 2)).isErr: toa(m, 0, 2)).isErr:
return err(EciesError) return err(EciesError)
try: try:
var reader = rlpFromBytes(buffer.toRange()) var reader = rlpFromBytes(buffer)
if not reader.isList() or reader.listLen() < 3: if not reader.isList() or reader.listLen() < 3:
return err(InvalidAck) return err(InvalidAck)
if reader.listElem(0).blobLen != RawPublicKeySize: if reader.listElem(0).blobLen != RawPublicKeySize:
@ -433,14 +429,14 @@ proc decodeAckMessageEip8*(h: var Handshake, m: openarray[byte]): AuthResult[voi
return err(InvalidAck) return err(InvalidAck)
if reader.listElem(2).blobLen != 1: if reader.listElem(2).blobLen != 1:
return err(InvalidAck) return err(InvalidAck)
let pubkeyBr = reader.listElem(0).toBytes() let
let nonceBr = reader.listElem(1).toBytes() pubkeyBr = reader.listElem(0).toBytes()
let versionBr = reader.listElem(2).toBytes() nonceBr = reader.listElem(1).toBytes()
h.remoteEPubkey = versionBr = reader.listElem(2).toBytes()
? PublicKey.fromRaw(pubkeyBr.toOpenArray()).mapErrTo(InvalidPubKey)
copyMem(addr h.responderNonce[0], nonceBr.baseAddr, KeyLength) h.remoteEPubkey = ? PublicKey.fromRaw(pubkeyBr).mapErrTo(InvalidPubKey)
h.version = cast[ptr byte](versionBr.baseAddr)[] h.responderNonce = toArray(KeyLength, nonceBr)
h.version = versionBr[0]
ok() ok()
except CatchableError: except CatchableError:

View File

@ -13,7 +13,7 @@ import
chronos, stint, nimcrypto, chronicles, chronos, stint, nimcrypto, chronicles,
eth/[keys, rlp], eth/[keys, rlp],
kademlia, enode, kademlia, enode,
stew/results stew/[objects, results]
export export
Node, results Node, results
@ -53,27 +53,27 @@ const MinListLen: array[CommandId, int] = [4, 3, 2, 2]
proc append*(w: var RlpWriter, a: IpAddress) = proc append*(w: var RlpWriter, a: IpAddress) =
case a.family case a.family
of IpAddressFamily.IPv6: of IpAddressFamily.IPv6:
w.append(a.address_v6.toMemRange) w.append(a.address_v6)
of IpAddressFamily.IPv4: of IpAddressFamily.IPv4:
w.append(a.address_v4.toMemRange) w.append(a.address_v4)
proc append(w: var RlpWriter, p: Port) {.inline.} = w.append(p.int) proc append(w: var RlpWriter, p: Port) {.inline.} = w.append(p.int)
proc append(w: var RlpWriter, pk: PublicKey) {.inline.} = w.append(pk.toRaw()) proc append(w: var RlpWriter, pk: PublicKey) {.inline.} = w.append(pk.toRaw())
proc append(w: var RlpWriter, h: MDigest[256]) {.inline.} = w.append(h.data) proc append(w: var RlpWriter, h: MDigest[256]) {.inline.} = w.append(h.data)
proc pack(cmdId: CommandId, payload: BytesRange, pk: PrivateKey): Bytes = proc pack(cmdId: CommandId, payload: openArray[byte], pk: PrivateKey): seq[byte] =
## Create and sign a UDP message to be sent to a remote node. ## Create and sign a UDP message to be sent to a remote node.
## ##
## See https://github.com/ethereum/devp2p/blob/master/rlpx.md#node-discovery for information on ## See https://github.com/ethereum/devp2p/blob/master/rlpx.md#node-discovery for information on
## how UDP packets are structured. ## how UDP packets are structured.
# TODO: There is a lot of unneeded allocations here # TODO: There is a lot of unneeded allocations here
let encodedData = @[cmdId.byte] & payload.toSeq() let encodedData = @[cmdId.byte] & @payload
let signature = @(pk.sign(encodedData).tryGet().toRaw()) let signature = @(pk.sign(encodedData).tryGet().toRaw())
let msgHash = keccak256.digest(signature & encodedData) let msgHash = keccak256.digest(signature & encodedData)
result = @(msgHash.data) & signature & encodedData result = @(msgHash.data) & signature & encodedData
proc validateMsgHash(msg: Bytes): DiscResult[MDigest[256]] = proc validateMsgHash(msg: openArray[byte]): DiscResult[MDigest[256]] =
if msg.len > HEAD_SIZE: if msg.len > HEAD_SIZE:
var ret: MDigest[256] var ret: MDigest[256]
ret.data[0 .. ^1] = msg.toOpenArray(0, ret.data.high) ret.data[0 .. ^1] = msg.toOpenArray(0, ret.data.high)
@ -90,7 +90,7 @@ proc recoverMsgPublicKey(msg: openArray[byte]): DiscResult[PublicKey] =
let sig = ? Signature.fromRaw(msg.toOpenArray(MAC_SIZE, HEAD_SIZE)) let sig = ? Signature.fromRaw(msg.toOpenArray(MAC_SIZE, HEAD_SIZE))
recover(sig, msg.toOpenArray(HEAD_SIZE, msg.high)) recover(sig, msg.toOpenArray(HEAD_SIZE, msg.high))
proc unpack(msg: Bytes): tuple[cmdId: CommandId, payload: Bytes] = proc unpack(msg: openArray[byte]): tuple[cmdId: CommandId, payload: seq[byte]] =
# Check against possible RangeError # Check against possible RangeError
if msg[HEAD_SIZE].int < CommandId.low.ord or if msg[HEAD_SIZE].int < CommandId.low.ord or
msg[HEAD_SIZE].int > CommandId.high.ord: msg[HEAD_SIZE].int > CommandId.high.ord:
@ -112,14 +112,14 @@ proc send(d: DiscoveryProtocol, n: Node, data: seq[byte]) =
proc sendPing*(d: DiscoveryProtocol, n: Node): seq[byte] = proc sendPing*(d: DiscoveryProtocol, n: Node): seq[byte] =
let payload = rlp.encode((PROTO_VERSION, d.address, n.node.address, let payload = rlp.encode((PROTO_VERSION, d.address, n.node.address,
expiration())).toRange expiration()))
let msg = pack(cmdPing, payload, d.privKey) let msg = pack(cmdPing, payload, d.privKey)
result = msg[0 ..< MAC_SIZE] result = msg[0 ..< MAC_SIZE]
trace ">>> ping ", n trace ">>> ping ", n
d.send(n, msg) d.send(n, msg)
proc sendPong*(d: DiscoveryProtocol, n: Node, token: MDigest[256]) = proc sendPong*(d: DiscoveryProtocol, n: Node, token: MDigest[256]) =
let payload = rlp.encode((n.node.address, token, expiration())).toRange let payload = rlp.encode((n.node.address, token, expiration()))
let msg = pack(cmdPong, payload, d.privKey) let msg = pack(cmdPong, payload, d.privKey)
trace ">>> pong ", n trace ">>> pong ", n
d.send(n, msg) d.send(n, msg)
@ -127,7 +127,7 @@ proc sendPong*(d: DiscoveryProtocol, n: Node, token: MDigest[256]) =
proc sendFindNode*(d: DiscoveryProtocol, n: Node, targetNodeId: NodeId) = proc sendFindNode*(d: DiscoveryProtocol, n: Node, targetNodeId: NodeId) =
var data: array[64, byte] var data: array[64, byte]
data[32 .. ^1] = targetNodeId.toByteArrayBE() data[32 .. ^1] = targetNodeId.toByteArrayBE()
let payload = rlp.encode((data, expiration())).toRange let payload = rlp.encode((data, expiration()))
let msg = pack(cmdFindNode, payload, d.privKey) let msg = pack(cmdFindNode, payload, d.privKey)
trace ">>> find_node to ", n#, ": ", msg.toHex() trace ">>> find_node to ", n#, ": ", msg.toHex()
d.send(n, msg) d.send(n, msg)
@ -140,7 +140,7 @@ proc sendNeighbours*(d: DiscoveryProtocol, node: Node, neighbours: seq[Node]) =
template flush() = template flush() =
block: block:
let payload = rlp.encode((nodes, expiration())).toRange let payload = rlp.encode((nodes, expiration()))
let msg = pack(cmdNeighbours, payload, d.privkey) let msg = pack(cmdNeighbours, payload, d.privkey)
trace "Neighbours to", node, nodes trace "Neighbours to", node, nodes
d.send(node, msg) d.send(node, msg)
@ -169,14 +169,14 @@ proc recvPing(d: DiscoveryProtocol, node: Node,
msgHash: MDigest[256]) {.inline.} = msgHash: MDigest[256]) {.inline.} =
d.kademlia.recvPing(node, msgHash) d.kademlia.recvPing(node, msgHash)
proc recvPong(d: DiscoveryProtocol, node: Node, payload: Bytes) {.inline.} = proc recvPong(d: DiscoveryProtocol, node: Node, payload: seq[byte]) {.inline.} =
let rlp = rlpFromBytes(payload.toRange) let rlp = rlpFromBytes(payload)
let tok = rlp.listElem(1).toBytes().toSeq() let tok = rlp.listElem(1).toBytes()
d.kademlia.recvPong(node, tok) d.kademlia.recvPong(node, tok)
proc recvNeighbours(d: DiscoveryProtocol, node: Node, proc recvNeighbours(d: DiscoveryProtocol, node: Node,
payload: Bytes) {.inline.} = payload: seq[byte]) {.inline.} =
let rlp = rlpFromBytes(payload.toRange) let rlp = rlpFromBytes(payload)
let neighboursList = rlp.listElem(0) let neighboursList = rlp.listElem(0)
let sz = neighboursList.listLen() let sz = neighboursList.listLen()
@ -187,18 +187,18 @@ proc recvNeighbours(d: DiscoveryProtocol, node: Node,
var ip: IpAddress var ip: IpAddress
case ipBlob.len case ipBlob.len
of 4: of 4:
ip = IpAddress(family: IpAddressFamily.IPv4) ip = IpAddress(
copyMem(addr ip.address_v4[0], baseAddr ipBlob, 4) family: IpAddressFamily.IPv4, address_v4: toArray(4, ipBlob))
of 16: of 16:
ip = IpAddress(family: IpAddressFamily.IPv6) ip = IpAddress(
copyMem(addr ip.address_v6[0], baseAddr ipBlob, 16) family: IpAddressFamily.IPv6, address_v6: toArray(16, ipBlob))
else: else:
error "Wrong ip address length!" error "Wrong ip address length!"
continue continue
let udpPort = n.listElem(1).toInt(uint16).Port let udpPort = n.listElem(1).toInt(uint16).Port
let tcpPort = n.listElem(2).toInt(uint16).Port let tcpPort = n.listElem(2).toInt(uint16).Port
let pk = PublicKey.fromRaw(n.listElem(3).toBytes.toOpenArray()) let pk = PublicKey.fromRaw(n.listElem(3).toBytes)
if pk.isErr: if pk.isErr:
warn "Could not parse public key" warn "Could not parse public key"
continue continue
@ -206,24 +206,24 @@ proc recvNeighbours(d: DiscoveryProtocol, node: Node,
neighbours.add(newNode(pk[], Address(ip: ip, udpPort: udpPort, tcpPort: tcpPort))) neighbours.add(newNode(pk[], Address(ip: ip, udpPort: udpPort, tcpPort: tcpPort)))
d.kademlia.recvNeighbours(node, neighbours) d.kademlia.recvNeighbours(node, neighbours)
proc recvFindNode(d: DiscoveryProtocol, node: Node, payload: Bytes) {.inline, gcsafe.} = proc recvFindNode(d: DiscoveryProtocol, node: Node, payload: openArray[byte]) {.inline, gcsafe.} =
let rlp = rlpFromBytes(payload.toRange) let rlp = rlpFromBytes(payload)
trace "<<< find_node from ", node trace "<<< find_node from ", node
let rng = rlp.listElem(0).toBytes let rng = rlp.listElem(0).toBytes
# Check for pubkey len # Check for pubkey len
if rng.len == 64: if rng.len == 64:
let nodeId = readUIntBE[256](rng[32 .. ^1].toOpenArray()) let nodeId = readUIntBE[256](rng[32 .. ^1])
d.kademlia.recvFindNode(node, nodeId) d.kademlia.recvFindNode(node, nodeId)
else: else:
trace "Invalid target public key received" trace "Invalid target public key received"
proc expirationValid(cmdId: CommandId, rlpEncodedPayload: seq[byte]): proc expirationValid(cmdId: CommandId, rlpEncodedPayload: openArray[byte]):
bool {.inline, raises:[DiscProtocolError, RlpError].} = bool {.inline, raises:[DiscProtocolError, RlpError].} =
## Can only raise `DiscProtocolError` and all of `RlpError` ## Can only raise `DiscProtocolError` and all of `RlpError`
# Check if there is a payload # Check if there is a payload
if rlpEncodedPayload.len <= 0: if rlpEncodedPayload.len <= 0:
raise newException(DiscProtocolError, "RLP stream is empty") raise newException(DiscProtocolError, "RLP stream is empty")
let rlp = rlpFromBytes(rlpEncodedPayload.toRange) let rlp = rlpFromBytes(rlpEncodedPayload)
# Check payload is an RLP list and if the list has the minimum items required # Check payload is an RLP list and if the list has the minimum items required
# for this packet type # for this packet type
if rlp.isList and rlp.listLen >= MinListLen[cmdId]: if rlp.isList and rlp.listLen >= MinListLen[cmdId]:
@ -233,7 +233,7 @@ proc expirationValid(cmdId: CommandId, rlpEncodedPayload: seq[byte]):
else: else:
raise newException(DiscProtocolError, "Invalid RLP list for this packet id") raise newException(DiscProtocolError, "Invalid RLP list for this packet id")
proc receive*(d: DiscoveryProtocol, a: Address, msg: Bytes) {.gcsafe.} = proc receive*(d: DiscoveryProtocol, a: Address, msg: openArray[byte]) {.gcsafe.} =
## Can raise `DiscProtocolError` and all of `RlpError` ## Can raise `DiscProtocolError` and all of `RlpError`
# Note: export only needed for testing # Note: export only needed for testing
let msgHash = validateMsgHash(msg) let msgHash = validateMsgHash(msg)

View File

@ -202,7 +202,7 @@ proc decodePacketBody(typ: byte,
let kind = cast[PacketKind](typ) let kind = cast[PacketKind](typ)
res = Packet(kind: kind) res = Packet(kind: kind)
var rlp = rlpFromBytes(@body.toRange) var rlp = rlpFromBytes(body)
if rlp.enterList: if rlp.enterList:
res.reqId = rlp.read(RequestId) res.reqId = rlp.read(RequestId)
@ -258,12 +258,11 @@ proc decodeAuthResp(c: Codec, fromId: NodeId, head: AuthHeader,
proc decodeEncrypted*(c: var Codec, proc decodeEncrypted*(c: var Codec,
fromId: NodeID, fromId: NodeID,
fromAddr: Address, fromAddr: Address,
input: seq[byte], input: openArray[byte],
authTag: var AuthTag, authTag: var AuthTag,
newNode: var Node, newNode: var Node,
packet: var Packet): DecodeStatus = packet: var Packet): DecodeStatus =
let input = input.toRange var r = rlpFromBytes(input.toOpenArray(tagSize, input.high))
var r = rlpFromBytes(input[tagSize .. ^1])
var auth: AuthHeader var auth: AuthHeader
var readKey: AesKey var readKey: AesKey
@ -312,10 +311,11 @@ proc decodeEncrypted*(c: var Codec,
# doAssert(false, "TODO: HANDLE ME!") # doAssert(false, "TODO: HANDLE ME!")
let headSize = tagSize + r.position let headSize = tagSize + r.position
let bodyEnc = input[headSize .. ^1]
let body = decryptGCM(readKey, auth.auth, bodyEnc.toOpenArray, let body = decryptGCM(
input[0 .. tagSize - 1].toOpenArray) readKey, auth.auth,
input.toOpenArray(headSize, input.high),
input.toOpenArray(0, tagSize - 1))
if body.isNone(): if body.isNone():
discard c.db.deleteKeys(fromId, fromAddr) discard c.db.deleteKeys(fromId, fromAddr)
return DecryptError return DecryptError

View File

@ -214,11 +214,11 @@ proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]):
return verify(sig[], h, publicKey.get) return verify(sig[], h, publicKey.get)
proc verifySignature(r: Record): bool = proc verifySignature(r: Record): bool =
var rlp = rlpFromBytes(r.raw.toRange) var rlp = rlpFromBytes(r.raw)
let sz = rlp.listLen let sz = rlp.listLen
if not rlp.enterList: if not rlp.enterList:
return false return false
let sigData = rlp.read(Bytes) let sigData = rlp.read(seq[byte])
let content = block: let content = block:
var writer = initRlpList(sz - 1) var writer = initRlpList(sz - 1)
var reader = rlp var reader = rlp
@ -240,7 +240,7 @@ proc fromBytesAux(r: var Record): bool =
if r.raw.len > maxEnrSize: if r.raw.len > maxEnrSize:
return false return false
var rlp = rlpFromBytes(r.raw.toRange) var rlp = rlpFromBytes(r.raw)
if not rlp.isList: if not rlp.isList:
return false return false
@ -270,7 +270,7 @@ proc fromBytesAux(r: var Record): bool =
let v = rlp.read(uint16) let v = rlp.read(uint16)
r.pairs.add((k, Field(kind: kNum, num: v))) r.pairs.add((k, Field(kind: kNum, num: v)))
else: else:
r.pairs.add((k, Field(kind: kBytes, bytes: rlp.read(Bytes)))) r.pairs.add((k, Field(kind: kBytes, bytes: rlp.read(seq[byte]))))
verifySignature(r) verifySignature(r)
@ -329,9 +329,9 @@ proc `$`*(r: Record): string =
proc `==`*(a, b: Record): bool = a.raw == b.raw proc `==`*(a, b: Record): bool = a.raw == b.raw
proc read*(rlp: var Rlp, T: typedesc[Record]): T {.inline.} = proc read*(rlp: var Rlp, T: typedesc[Record]): T {.inline.} =
if not result.fromBytes(rlp.rawData.toOpenArray): if not result.fromBytes(rlp.rawData):
raise newException(ValueError, "Could not deserialize") raise newException(ValueError, "Could not deserialize")
rlp.skipElem() rlp.skipElem()
proc append*(rlpWriter: var RlpWriter, value: Record) = proc append*(rlpWriter: var RlpWriter, value: Record) =
rlpWriter.appendRawBytes(value.raw.toRange) rlpWriter.appendRawBytes(value.raw)

View File

@ -97,13 +97,13 @@ proc whoareyouMagic(toNode: NodeId): array[magicSize, byte] =
for i, c in prefix: data[sizeof(toNode) + i] = byte(c) for i, c in prefix: data[sizeof(toNode) + i] = byte(c)
sha256.digest(data).data sha256.digest(data).data
proc isWhoAreYou(d: Protocol, msg: Bytes): bool = proc isWhoAreYou(d: Protocol, msg: openArray[byte]): bool =
if msg.len > d.whoareyouMagic.len: if msg.len > d.whoareyouMagic.len:
result = d.whoareyouMagic == msg.toOpenArray(0, magicSize - 1) result = d.whoareyouMagic == msg.toOpenArray(0, magicSize - 1)
proc decodeWhoAreYou(d: Protocol, msg: Bytes): Whoareyou = proc decodeWhoAreYou(d: Protocol, msg: openArray[byte]): Whoareyou =
result = Whoareyou() result = Whoareyou()
result[] = rlp.decode(msg.toRange[magicSize .. ^1], WhoareyouObj) result[] = rlp.decode(msg.toOpenArray(magicSize, msg.high), WhoareyouObj)
proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: AuthTag) = proc sendWhoareyou(d: Protocol, address: Address, toNode: NodeId, authTag: AuthTag) =
trace "sending who are you", to = $toNode, toAddress = $address trace "sending who are you", to = $toNode, toAddress = $address
@ -172,7 +172,7 @@ proc handleFindNode(d: Protocol, fromId: NodeId, fromAddr: Address,
d.sendNodes(fromId, fromAddr, reqId, d.sendNodes(fromId, fromAddr, reqId,
d.routingTable.neighboursAtDistance(distance)) d.routingTable.neighboursAtDistance(distance))
proc receive*(d: Protocol, a: Address, msg: Bytes) {.gcsafe, proc receive*(d: Protocol, a: Address, msg: openArray[byte]) {.gcsafe,
raises: [ raises: [
Defect, Defect,
# TODO This is now coming from Chronos's callSoon # TODO This is now coming from Chronos's callSoon

View File

@ -255,7 +255,7 @@ proc invokeThunk*(peer: Peer, msgId: int, msgData: var Rlp): Future[void] =
return thunk(peer, msgId, msgData) return thunk(peer, msgId, msgData)
template compressMsg(peer: Peer, data: Bytes): Bytes = template compressMsg(peer: Peer, data: seq[byte]): seq[byte] =
when useSnappy: when useSnappy:
if peer.snappyEnabled: if peer.snappyEnabled:
snappy.compress(data) snappy.compress(data)
@ -263,7 +263,7 @@ template compressMsg(peer: Peer, data: Bytes): Bytes =
else: else:
data data
proc sendMsg*(peer: Peer, data: Bytes) {.gcsafe, async.} = proc sendMsg*(peer: Peer, data: seq[byte]) {.gcsafe, async.} =
try: try:
var cipherText = encryptMsg(peer.compressMsg(data), peer.secretsState) var cipherText = encryptMsg(peer.compressMsg(data), peer.secretsState)
var res = await peer.transport.write(cipherText) var res = await peer.transport.write(cipherText)
@ -426,7 +426,7 @@ proc recvMsg*(peer: Peer): Future[tuple[msgId: int, msgData: Rlp]] {.async.} =
if decryptedBytes.len == 0: if decryptedBytes.len == 0:
await peer.disconnectAndRaise(BreachOfProtocol, await peer.disconnectAndRaise(BreachOfProtocol,
"Snappy uncompress encountered malformed data") "Snappy uncompress encountered malformed data")
var rlp = rlpFromBytes(decryptedBytes.toRange) var rlp = rlpFromBytes(decryptedBytes)
try: try:
# int32 as this seems more than big enough for the amount of msgIds # int32 as this seems more than big enough for the amount of msgIds
@ -561,7 +561,6 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend =
EthereumNode = bindSym "EthereumNode" EthereumNode = bindSym "EthereumNode"
initRlpWriter = bindSym "initRlpWriter" initRlpWriter = bindSym "initRlpWriter"
rlpFromBytes = bindSym "rlpFromBytes"
append = bindSym("append", brForceOpen) append = bindSym("append", brForceOpen)
read = bindSym("read", brForceOpen) read = bindSym("read", brForceOpen)
checkedRlpRead = bindSym "checkedRlpRead" checkedRlpRead = bindSym "checkedRlpRead"

View File

@ -103,7 +103,7 @@ proc loadMessageStats*(network: LesNetwork,
break readFromDB break readFromDB
try: try:
var statsRlp = rlpFromBytes(stats.toRange) var statsRlp = rlpFromBytes(stats)
if not statsRlp.enterList: if not statsRlp.enterList:
notice "Found a corrupted LES stats record" notice "Found a corrupted LES stats record"
break readFromDB break readFromDB

View File

@ -15,12 +15,12 @@ const
requestCompleteTimeout = chronos.seconds(5) requestCompleteTimeout = chronos.seconds(5)
type type
Cursor = Bytes Cursor = seq[byte]
MailRequest* = object MailRequest* = object
lower*: uint32 ## Unix timestamp; oldest requested envelope's creation time lower*: uint32 ## Unix timestamp; oldest requested envelope's creation time
upper*: uint32 ## Unix timestamp; newest requested envelope's creation time upper*: uint32 ## Unix timestamp; newest requested envelope's creation time
bloom*: Bytes ## Bloom filter to apply on the envelopes bloom*: seq[byte] ## Bloom filter to apply on the envelopes
limit*: uint32 ## Maximum amount of envelopes to return limit*: uint32 ## Maximum amount of envelopes to return
cursor*: Cursor ## Optional cursor cursor*: Cursor ## Optional cursor

View File

@ -144,7 +144,7 @@ proc append*(rlpWriter: var RlpWriter, value: StatusOptions) =
let bytes = list.finish() let bytes = list.finish()
rlpWriter.append(rlpFromBytes(bytes.toRange)) rlpWriter.append(rlpFromBytes(bytes))
proc read*(rlp: var Rlp, T: typedesc[StatusOptions]): T = proc read*(rlp: var Rlp, T: typedesc[StatusOptions]): T =
if not rlp.isList(): if not rlp.isList():
@ -379,7 +379,7 @@ p2pProtocol Waku(version = wakuVersion,
proc p2pRequestComplete(peer: Peer, requestId: Hash, lastEnvelopeHash: Hash, proc p2pRequestComplete(peer: Peer, requestId: Hash, lastEnvelopeHash: Hash,
cursor: Bytes) = discard cursor: seq[byte]) = discard
# TODO: # TODO:
# In the current specification the parameters are not wrapped in a regular # In the current specification the parameters are not wrapped in a regular
# envelope as is done for the P2P Request packet. If we could alter this in # envelope as is done for the P2P Request packet. If we could alter this in
@ -488,8 +488,8 @@ proc queueMessage(node: EthereumNode, msg: Message): bool =
proc postMessage*(node: EthereumNode, pubKey = none[PublicKey](), proc postMessage*(node: EthereumNode, pubKey = none[PublicKey](),
symKey = none[SymKey](), src = none[PrivateKey](), symKey = none[SymKey](), src = none[PrivateKey](),
ttl: uint32, topic: Topic, payload: Bytes, ttl: uint32, topic: Topic, payload: seq[byte],
padding = none[Bytes](), powTime = 1'f, padding = none[seq[byte]](), powTime = 1'f,
powTarget = defaultMinPow, powTarget = defaultMinPow,
targetPeer = none[NodeId]()): bool = targetPeer = none[NodeId]()): bool =
## Post a message on the message queue which will be processed at the ## Post a message on the message queue which will be processed at the

View File

@ -9,7 +9,7 @@
# #
import import
algorithm, bitops, math, options, strutils, tables, times, chronicles, hashes, algorithm, bitops, math, options, tables, times, chronicles, hashes, strutils,
stew/[byteutils, endians2], metrics, stew/[byteutils, endians2], metrics,
nimcrypto/[bcmode, hash, keccak, rijndael, sysrand], nimcrypto/[bcmode, hash, keccak, rijndael, sysrand],
eth/[keys, rlp, p2p], eth/p2p/ecies eth/[keys, rlp, p2p], eth/p2p/ecies
@ -56,16 +56,16 @@ type
src*: Option[PrivateKey] ## Optional key used for signing message src*: Option[PrivateKey] ## Optional key used for signing message
dst*: Option[PublicKey] ## Optional key used for asymmetric encryption dst*: Option[PublicKey] ## Optional key used for asymmetric encryption
symKey*: Option[SymKey] ## Optional key used for symmetric encryption symKey*: Option[SymKey] ## Optional key used for symmetric encryption
payload*: Bytes ## Application data / message contents payload*: seq[byte] ## Application data / message contents
padding*: Option[Bytes] ## Padding - if unset, will automatically pad up to padding*: Option[seq[byte]] ## Padding - if unset, will automatically pad up to
## nearest maxPadLen-byte boundary ## nearest maxPadLen-byte boundary
DecodedPayload* = object DecodedPayload* = object
## The decoded payload of a received message. ## The decoded payload of a received message.
src*: Option[PublicKey] ## If the message was signed, this is the public key src*: Option[PublicKey] ## If the message was signed, this is the public key
## of the source ## of the source
payload*: Bytes ## Application data / message contents payload*: seq[byte] ## Application data / message contents
padding*: Option[Bytes] ## Message padding padding*: Option[seq[byte]] ## Message padding
Envelope* = object Envelope* = object
## What goes on the wire in the whisper protocol - a payload and some ## What goes on the wire in the whisper protocol - a payload and some
@ -74,7 +74,7 @@ type
expiry*: uint32 ## Unix timestamp when message expires expiry*: uint32 ## Unix timestamp when message expires
ttl*: uint32 ## Time-to-live, seconds - message was created at (expiry - ttl) ttl*: uint32 ## Time-to-live, seconds - message was created at (expiry - ttl)
topic*: Topic topic*: Topic
data*: Bytes ## Payload, as given by user data*: seq[byte] ## Payload, as given by user
nonce*: uint64 ## Nonce used for proof-of-work calculation nonce*: uint64 ## Nonce used for proof-of-work calculation
Message* = object Message* = object
@ -177,7 +177,7 @@ proc `or`(a, b: Bloom): Bloom =
for i in 0..<a.len: for i in 0..<a.len:
result[i] = a[i] or b[i] result[i] = a[i] or b[i]
proc bytesCopy*(bloom: var Bloom, b: Bytes) = proc bytesCopy*(bloom: var Bloom, b: openArray[byte]) =
doAssert b.len == bloomSize doAssert b.len == bloomSize
copyMem(addr bloom[0], unsafeAddr b[0], bloomSize) copyMem(addr bloom[0], unsafeAddr b[0], bloomSize)
@ -198,7 +198,7 @@ proc fullBloom*(): Bloom =
result[i] = 0xFF result[i] = 0xFF
proc encryptAesGcm(plain: openarray[byte], key: SymKey, proc encryptAesGcm(plain: openarray[byte], key: SymKey,
iv: array[gcmIVLen, byte]): Bytes = iv: array[gcmIVLen, byte]): seq[byte] =
## Encrypt using AES-GCM, making sure to append tag and iv, in that order ## Encrypt using AES-GCM, making sure to append tag and iv, in that order
var gcm: GCM[aes256] var gcm: GCM[aes256]
result = newSeqOfCap[byte](plain.len + gcmTagLen + iv.len) result = newSeqOfCap[byte](plain.len + gcmTagLen + iv.len)
@ -210,7 +210,7 @@ proc encryptAesGcm(plain: openarray[byte], key: SymKey,
result.add tag result.add tag
result.add iv result.add iv
proc decryptAesGcm(cipher: openarray[byte], key: SymKey): Option[Bytes] = proc decryptAesGcm(cipher: openarray[byte], key: SymKey): Option[seq[byte]] =
## Decrypt AES-GCM ciphertext and validate authenticity - assumes ## Decrypt AES-GCM ciphertext and validate authenticity - assumes
## cipher-tag-iv format of the buffer ## cipher-tag-iv format of the buffer
if cipher.len < gcmTagLen + gcmIVLen: if cipher.len < gcmTagLen + gcmIVLen:
@ -237,7 +237,7 @@ proc decryptAesGcm(cipher: openarray[byte], key: SymKey): Option[Bytes] =
# simply because that makes it closer to EIP 627 - see also: # simply because that makes it closer to EIP 627 - see also:
# https://github.com/paritytech/parity-ethereum/issues/9652 # https://github.com/paritytech/parity-ethereum/issues/9652
proc encode*(self: Payload): Option[Bytes] = proc encode*(self: Payload): Option[seq[byte]] =
## Encode a payload according so as to make it suitable to put in an Envelope ## Encode a payload according so as to make it suitable to put in an Envelope
## The format follows EIP 627 - https://eips.ethereum.org/EIPS/eip-627 ## The format follows EIP 627 - https://eips.ethereum.org/EIPS/eip-627
@ -333,7 +333,7 @@ proc decode*(data: openarray[byte], dst = none[PrivateKey](),
var res: DecodedPayload var res: DecodedPayload
var plain: Bytes var plain: seq[byte]
if dst.isSome(): if dst.isSome():
# XXX: eciesDecryptedLength is pretty fragile, API-wise.. is this really the # XXX: eciesDecryptedLength is pretty fragile, API-wise.. is this really the
# way to check for errors / sufficient length? # way to check for errors / sufficient length?
@ -426,11 +426,11 @@ proc valid*(self: Envelope, now = epochTime()): bool =
proc len(self: Envelope): int = 20 + self.data.len proc len(self: Envelope): int = 20 + self.data.len
proc toShortRlp*(self: Envelope): Bytes = proc toShortRlp*(self: Envelope): seq[byte] =
## RLP-encoded message without nonce is used during proof-of-work calculations ## RLP-encoded message without nonce is used during proof-of-work calculations
rlp.encodeList(self.expiry, self.ttl, self.topic, self.data) rlp.encodeList(self.expiry, self.ttl, self.topic, self.data)
proc toRlp(self: Envelope): Bytes = proc toRlp(self: Envelope): seq[byte] =
## What gets sent out over the wire includes the nonce ## What gets sent out over the wire includes the nonce
rlp.encode(self) rlp.encode(self)

View File

@ -167,7 +167,7 @@ p2pProtocol Whisper(version = whisperVersion,
proc status(peer: Peer, proc status(peer: Peer,
protocolVersion: uint, protocolVersion: uint,
powConverted: uint64, powConverted: uint64,
bloom: Bytes, bloom: seq[byte],
isLightNode: bool) isLightNode: bool)
proc messages(peer: Peer, envelopes: openarray[Envelope]) = proc messages(peer: Peer, envelopes: openarray[Envelope]) =
@ -220,7 +220,7 @@ p2pProtocol Whisper(version = whisperVersion,
peer.state.powRequirement = cast[float64](value) peer.state.powRequirement = cast[float64](value)
proc bloomFilterExchange(peer: Peer, bloom: Bytes) = proc bloomFilterExchange(peer: Peer, bloom: openArray[byte]) =
if not peer.state.initialized: if not peer.state.initialized:
warn "Handshake not completed yet, discarding bloomFilterExchange" warn "Handshake not completed yet, discarding bloomFilterExchange"
return return
@ -343,8 +343,8 @@ proc queueMessage(node: EthereumNode, msg: Message): bool =
proc postMessage*(node: EthereumNode, pubKey = none[PublicKey](), proc postMessage*(node: EthereumNode, pubKey = none[PublicKey](),
symKey = none[SymKey](), src = none[PrivateKey](), symKey = none[SymKey](), src = none[PrivateKey](),
ttl: uint32, topic: Topic, payload: Bytes, ttl: uint32, topic: Topic, payload: seq[byte],
padding = none[Bytes](), powTime = 1'f, padding = none[seq[byte]](), powTime = 1'f,
powTarget = defaultMinPow, powTarget = defaultMinPow,
targetPeer = none[NodeId]()): bool = targetPeer = none[NodeId]()): bool =
## Post a message on the message queue which will be processed at the ## Post a message on the message queue which will be processed at the

View File

@ -12,7 +12,7 @@
{.push raises: [Defect].} {.push raises: [Defect].}
import stew/ranges/stackarrays, eth/rlp/types, nimcrypto, stew/results import stew/ranges/stackarrays, nimcrypto, stew/results
from auth import ConnectionSecret from auth import ConnectionSecret
export results export results

View File

@ -3,16 +3,16 @@
## https://ethereum.github.io/yellowpaper/paper.pdf ## https://ethereum.github.io/yellowpaper/paper.pdf
import import
macros, strutils, parseutils, macros, strutils, stew/byteutils,
rlp/[types, writer, object_serialization], rlp/[writer, object_serialization],
rlp/priv/defs rlp/priv/defs
export export
types, writer, object_serialization writer, object_serialization
type type
Rlp* = object Rlp* = object
bytes: BytesRange bytes: seq[byte]
position*: int position*: int
RlpNodeType* = enum RlpNodeType* = enum
@ -22,7 +22,7 @@ type
RlpNode* = object RlpNode* = object
case kind*: RlpNodeType case kind*: RlpNodeType
of rlpBlob: of rlpBlob:
bytes*: BytesRange bytes*: seq[byte]
of rlpList: of rlpList:
elems*: seq[RlpNode] elems*: seq[RlpNode]
@ -31,101 +31,86 @@ type
UnsupportedRlpError* = object of RlpError UnsupportedRlpError* = object of RlpError
RlpTypeMismatch* = object of RlpError RlpTypeMismatch* = object of RlpError
proc rlpFromBytes*(data: BytesRange): Rlp = proc rlpFromBytes*(data: seq[byte]): Rlp =
result.bytes = data Rlp(bytes: data, position: 0)
result.position = 0
proc rlpFromBytes*(data: openArray[byte]): Rlp =
rlpFromBytes(@data)
const zeroBytesRlp* = Rlp() const zeroBytesRlp* = Rlp()
proc rlpFromHex*(input: string): Rlp = proc rlpFromHex*(input: string): Rlp =
doAssert input.len mod 2 == 0, rlpFromBytes(hexToSeqByte(input))
"rlpFromHex expects a string with even number of characters (assuming two characters per byte)"
var startByte = if input.len >= 2 and input[0] == '0' and input[1] == 'x': 2
else: 0
let totalBytes = (input.len - startByte) div 2
var backingStore = newSeq[byte](totalBytes)
for i in 0 ..< totalBytes:
var nextByte: int
if parseHex(input, nextByte, startByte + i*2, 2) == 2:
backingStore[i] = byte(nextByte)
else:
doAssert false, "rlpFromHex expects a hexademical string, but the input contains non hexademical characters"
result.bytes = backingStore.toRange()
{.this: self.}
proc hasData*(self: Rlp): bool = proc hasData*(self: Rlp): bool =
position < bytes.len self.position < self.bytes.len
proc currentElemEnd*(self: Rlp): int {.gcsafe.} proc currentElemEnd*(self: Rlp): int {.gcsafe.}
proc rawData*(self: Rlp): BytesRange = template rawData*(self: Rlp): openArray[byte] =
return self.bytes[position ..< self.currentElemEnd] self.bytes.toOpenArray(self.position, self.currentElemEnd - 1)
proc isBlob*(self: Rlp): bool = proc isBlob*(self: Rlp): bool =
hasData() and bytes[position] < LIST_START_MARKER self.hasData() and self.bytes[self.position] < LIST_START_MARKER
proc isEmpty*(self: Rlp): bool = proc isEmpty*(self: Rlp): bool =
### Contains a blob or a list of zero length ### Contains a blob or a list of zero length
hasData() and (bytes[position] == BLOB_START_MARKER or self.hasData() and (self.bytes[self.position] == BLOB_START_MARKER or
bytes[position] == LIST_START_MARKER) self.bytes[self.position] == LIST_START_MARKER)
proc isList*(self: Rlp): bool = proc isList*(self: Rlp): bool =
hasData() and bytes[position] >= LIST_START_MARKER self.hasData() and self.bytes[self.position] >= LIST_START_MARKER
template eosError = template eosError =
raise newException(MalformedRlpError, "Read past the end of the RLP stream") raise newException(MalformedRlpError, "Read past the end of the RLP stream")
template requireData {.dirty.} = template requireData {.dirty.} =
if not hasData(): if not self.hasData():
raise newException(MalformedRlpError, "Illegal operation over an empty RLP stream") raise newException(MalformedRlpError, "Illegal operation over an empty RLP stream")
proc getType*(self: Rlp): RlpNodeType = proc getType*(self: Rlp): RlpNodeType =
requireData() requireData()
return if isBlob(): rlpBlob else: rlpList return if self.isBlob(): rlpBlob else: rlpList
proc lengthBytesCount(self: Rlp): int = proc lengthBytesCount(self: Rlp): int =
var marker = bytes[position] var marker = self.bytes[self.position]
if isBlob() and marker > LEN_PREFIXED_BLOB_MARKER: if self.isBlob() and marker > LEN_PREFIXED_BLOB_MARKER:
return int(marker - LEN_PREFIXED_BLOB_MARKER) return int(marker - LEN_PREFIXED_BLOB_MARKER)
if isList() and marker > LEN_PREFIXED_LIST_MARKER: if self.isList() and marker > LEN_PREFIXED_LIST_MARKER:
return int(marker - LEN_PREFIXED_LIST_MARKER) return int(marker - LEN_PREFIXED_LIST_MARKER)
return 0 return 0
proc isSingleByte*(self: Rlp): bool = proc isSingleByte*(self: Rlp): bool =
hasData() and bytes[position] < BLOB_START_MARKER self.hasData() and self.bytes[self.position] < BLOB_START_MARKER
proc getByteValue*(self: Rlp): byte = proc getByteValue*(self: Rlp): byte =
doAssert self.isSingleByte() doAssert self.isSingleByte()
return bytes[position] return self.bytes[self.position]
proc payloadOffset(self: Rlp): int = proc payloadOffset(self: Rlp): int =
if isSingleByte(): 0 else: 1 + lengthBytesCount() if self.isSingleByte(): 0 else: 1 + self.lengthBytesCount()
template readAheadCheck(numberOfBytes) = template readAheadCheck(numberOfBytes: int) =
# important to add nothing to the left side of the equation as `numberOfBytes` # important to add nothing to the left side of the equation as `numberOfBytes`
# can in theory be at max size of its type already # can in theory be at max size of its type already
if numberOfBytes > bytes.len - position - payloadOffset(): eosError() if numberOfBytes > self.bytes.len - self.position - self.payloadOffset():
eosError()
template nonCanonicalNumberError = template nonCanonicalNumberError =
raise newException(MalformedRlpError, "Small number encoded in a non-canonical way") raise newException(MalformedRlpError, "Small number encoded in a non-canonical way")
proc payloadBytesCount(self: Rlp): int = proc payloadBytesCount(self: Rlp): int =
if not hasData(): if not self.hasData():
return 0 return 0
var marker = bytes[position] var marker = self.bytes[self.position]
if marker < BLOB_START_MARKER: if marker < BLOB_START_MARKER:
return 1 return 1
if marker <= LEN_PREFIXED_BLOB_MARKER: if marker <= LEN_PREFIXED_BLOB_MARKER:
result = int(marker - BLOB_START_MARKER) result = int(marker - BLOB_START_MARKER)
readAheadCheck(result) readAheadCheck(result)
if result == 1: if result == 1:
if bytes[position + 1] < BLOB_START_MARKER: if self.bytes[self.position + 1] < BLOB_START_MARKER:
nonCanonicalNumberError() nonCanonicalNumberError()
return return
@ -162,22 +147,22 @@ proc payloadBytesCount(self: Rlp): int =
readAheadCheck(result) readAheadCheck(result)
proc blobLen*(self: Rlp): int = proc blobLen*(self: Rlp): int =
if isBlob(): payloadBytesCount() else: 0 if self.isBlob(): self.payloadBytesCount() else: 0
proc isInt*(self: Rlp): bool = proc isInt*(self: Rlp): bool =
if not hasData(): if not self.hasData():
return false return false
var marker = bytes[position] var marker = self.bytes[self.position]
if marker < BLOB_START_MARKER: if marker < BLOB_START_MARKER:
return marker != 0 return marker != 0
if marker == BLOB_START_MARKER: if marker == BLOB_START_MARKER:
return true return true
if marker <= LEN_PREFIXED_BLOB_MARKER: if marker <= LEN_PREFIXED_BLOB_MARKER:
return bytes[position + 1] != 0 return self.bytes[self.position + 1] != 0
if marker < LIST_START_MARKER: if marker < LIST_START_MARKER:
let offset = position + int(marker + 1 - LEN_PREFIXED_BLOB_MARKER) let offset = self.position + int(marker + 1 - LEN_PREFIXED_BLOB_MARKER)
if offset >= bytes.len: eosError() if offset >= self.bytes.len: eosError()
return bytes[offset] != 0 return self.bytes[offset] != 0
return false return false
template maxBytes*(o: type[Ordinal | uint64 | uint]): int = sizeof(o) template maxBytes*(o: type[Ordinal | uint64 | uint]): int = sizeof(o)
@ -206,83 +191,83 @@ proc toInt*(self: Rlp, IntType: type): IntType =
result = (result shl 8) or OutputType(self.bytes[self.position + i]) result = (result shl 8) or OutputType(self.bytes[self.position + i])
proc toString*(self: Rlp): string = proc toString*(self: Rlp): string =
if not isBlob(): if not self.isBlob():
raise newException(RlpTypeMismatch, "String expected, but the source RLP is not a blob") raise newException(RlpTypeMismatch, "String expected, but the source RLP is not a blob")
let let
payloadOffset = payloadOffset() payloadOffset = self.payloadOffset()
payloadLen = payloadBytesCount() payloadLen = self.payloadBytesCount()
result = newString(payloadLen) result = newString(payloadLen)
for i in 0 ..< payloadLen: for i in 0 ..< payloadLen:
# XXX: switch to copyMem here # XXX: switch to copyMem here
result[i] = char(bytes[position + payloadOffset + i]) result[i] = char(self.bytes[self.position + payloadOffset + i])
proc toBytes*(self: Rlp): BytesRange = proc toBytes*(self: Rlp): seq[byte] =
if not isBlob(): if not self.isBlob():
raise newException(RlpTypeMismatch, raise newException(RlpTypeMismatch,
"Bytes expected, but the source RLP in not a blob") "Bytes expected, but the source RLP in not a blob")
let payloadLen = payloadBytesCount() let payloadLen = self.payloadBytesCount()
if payloadLen > 0: if payloadLen > 0:
let let
payloadOffset = payloadOffset() payloadOffset = self.payloadOffset()
ibegin = position + payloadOffset ibegin = self.position + payloadOffset
iend = ibegin + payloadLen - 1 iend = ibegin + payloadLen - 1
result = bytes.slice(ibegin, iend) result = self.bytes[ibegin..iend]
proc currentElemEnd*(self: Rlp): int = proc currentElemEnd*(self: Rlp): int =
doAssert hasData() doAssert self.hasData()
result = position result = self.position
if isSingleByte(): if self.isSingleByte():
result += 1 result += 1
elif isBlob() or isList(): elif self.isBlob() or self.isList():
result += payloadOffset() + payloadBytesCount() result += self.payloadOffset() + self.payloadBytesCount()
proc enterList*(self: var Rlp): bool = proc enterList*(self: var Rlp): bool =
if not isList(): if not self.isList():
return false return false
position += payloadOffset() self.position += self.payloadOffset()
return true return true
proc tryEnterList*(self: var Rlp) = proc tryEnterList*(self: var Rlp) =
if not enterList(): if not self.enterList():
raise newException(RlpTypeMismatch, "List expected, but source RLP is not a list") raise newException(RlpTypeMismatch, "List expected, but source RLP is not a list")
proc skipElem*(rlp: var Rlp) = proc skipElem*(rlp: var Rlp) =
rlp.position = rlp.currentElemEnd rlp.position = rlp.currentElemEnd
iterator items*(self: var Rlp): var Rlp = iterator items*(self: var Rlp): var Rlp =
doAssert isList() doAssert self.isList()
var var
payloadOffset = payloadOffset() payloadOffset = self.payloadOffset()
payloadEnd = position + payloadOffset + payloadBytesCount() payloadEnd = self.position + payloadOffset + self.payloadBytesCount()
if payloadEnd > bytes.len: if payloadEnd > self.bytes.len:
raise newException(MalformedRlpError, "List length extends past the end of the stream") raise newException(MalformedRlpError, "List length extends past the end of the stream")
position += payloadOffset self.position += payloadOffset
while position < payloadEnd: while self.position < payloadEnd:
let elemEnd = currentElemEnd() let elemEnd = self.currentElemEnd()
yield self yield self
position = elemEnd self.position = elemEnd
proc listElem*(self: Rlp, i: int): Rlp = proc listElem*(self: Rlp, i: int): Rlp =
doAssert isList() doAssert self.isList()
let let
payloadOffset = payloadOffset() payloadOffset = self.payloadOffset()
# This will only check if there is some data, not if it is correct according # This will only check if there is some data, not if it is correct according
# to list length. Could also run here payloadBytesCount() instead. # to list length. Could also run here payloadBytesCount() instead.
if position + payloadOffset + 1 > bytes.len: eosError() if self.position + payloadOffset + 1 > self.bytes.len: eosError()
let payload = bytes.slice(position + payloadOffset) let payload = self.bytes[self.position + payloadOffset..^1]
result = rlpFromBytes payload result = rlpFromBytes payload
var pos = 0 var pos = 0
while pos < i and result.hasData: while pos < i and result.hasData:
@ -290,7 +275,7 @@ proc listElem*(self: Rlp, i: int): Rlp =
inc pos inc pos
proc listLen*(self: Rlp): int = proc listLen*(self: Rlp): int =
if not isList(): if not self.isList():
return 0 return 0
var rlp = self var rlp = self
@ -336,7 +321,7 @@ proc readImpl[R, E](rlp: var Rlp, T: type array[R, E]): T =
if result.len != bytes.len: if result.len != bytes.len:
raise newException(RlpTypeMismatch, "Fixed-size array expected, but the source RLP contains a blob of different length") raise newException(RlpTypeMismatch, "Fixed-size array expected, but the source RLP contains a blob of different length")
copyMem(addr result[0], bytes.baseAddr, bytes.len) copyMem(addr result[0], unsafeAddr bytes[0], bytes.len)
rlp.skipElem rlp.skipElem
@ -356,10 +341,7 @@ proc readImpl[E](rlp: var Rlp, T: type seq[E]): T =
mixin read mixin read
when E is (byte or char): when E is (byte or char):
var bytes = rlp.toBytes result = rlp.toBytes
if bytes.len != 0:
result = newSeq[byte](bytes.len)
copyMem(addr result[0], bytes.baseAddr, bytes.len)
rlp.skipElem rlp.skipElem
else: else:
if not rlp.isList: if not rlp.isList:
@ -383,7 +365,9 @@ proc readImpl(rlp: var Rlp, T: type[object|tuple],
"List expected, but the source RLP is not a list.") "List expected, but the source RLP is not a list.")
var var
payloadOffset = rlp.payloadOffset() payloadOffset = rlp.payloadOffset()
payloadEnd = rlp.position + payloadOffset + rlp.payloadBytesCount()
# there's an exception-raising side effect in there *sigh*
discard rlp.payloadBytesCount()
rlp.position += payloadOffset rlp.position += payloadOffset
@ -398,16 +382,16 @@ proc readImpl(rlp: var Rlp, T: type[object|tuple],
proc toNodes*(self: var Rlp): RlpNode = proc toNodes*(self: var Rlp): RlpNode =
requireData() requireData()
if isList(): if self.isList():
result.kind = rlpList result.kind = rlpList
newSeq result.elems, 0 newSeq result.elems, 0
for e in self: for e in self:
result.elems.add e.toNodes result.elems.add e.toNodes
else: else:
doAssert isBlob() doAssert self.isBlob()
result.kind = rlpBlob result.kind = rlpBlob
result.bytes = toBytes() result.bytes = self.toBytes()
position = currentElemEnd() self.position = self.currentElemEnd()
# We define a single `read` template with a pretty low specifity # We define a single `read` template with a pretty low specifity
# score in order to facilitate easier overloading with user types: # score in order to facilitate easier overloading with user types:
@ -422,22 +406,18 @@ template readRecordType*(rlp: var Rlp, T: type, wrappedInList: bool): auto =
readImpl(rlp, T, wrappedInList) readImpl(rlp, T, wrappedInList)
proc decode*(bytes: openarray[byte]): RlpNode = proc decode*(bytes: openarray[byte]): RlpNode =
var var rlp = rlpFromBytes(bytes)
bytesCopy = @bytes rlp.toNodes
rlp = rlpFromBytes(bytesCopy.toRange())
return rlp.toNodes
template decode*(bytes: BytesRange, T: type): untyped = template decode*(bytes: openArray[byte], T: type): untyped =
mixin read mixin read
var rlp = rlpFromBytes(bytes) var rlp = rlpFromBytes(bytes)
rlp.read(T) rlp.read(T)
template decode*(bytes: openarray[byte], T: type): T =
var bytesCopy = @bytes
decode(bytesCopy.toRange, T)
template decode*(bytes: seq[byte], T: type): untyped = template decode*(bytes: seq[byte], T: type): untyped =
decode(bytes.toRange, T) mixin read
var rlp = rlpFromBytes(bytes)
rlp.read(T)
proc append*(writer: var RlpWriter; rlp: Rlp) = proc append*(writer: var RlpWriter; rlp: Rlp) =
appendRawBytes(writer, rlp.rawData) appendRawBytes(writer, rlp.rawData)
@ -450,7 +430,7 @@ proc isPrintable(s: string): bool =
return true return true
proc inspectAux(self: var Rlp, depth: int, hexOutput: bool, output: var string) = proc inspectAux(self: var Rlp, depth: int, hexOutput: bool, output: var string) =
if not hasData(): if not self.hasData():
return return
template indent = template indent =
@ -461,7 +441,7 @@ proc inspectAux(self: var Rlp, depth: int, hexOutput: bool, output: var string)
if self.isSingleByte: if self.isSingleByte:
output.add "byte " output.add "byte "
output.add $bytes[position] output.add $self.bytes[self.position]
elif self.isBlob: elif self.isBlob:
let str = self.toString let str = self.toString
if str.isPrintable: if str.isPrintable:
@ -491,6 +471,5 @@ proc inspectAux(self: var Rlp, depth: int, hexOutput: bool, output: var string)
proc inspect*(self: Rlp, indent = 0, hexOutput = true): string = proc inspect*(self: Rlp, indent = 0, hexOutput = true): string =
var rlpCopy = self var rlpCopy = self
result = newStringOfCap(bytes.len) result = newStringOfCap(self.bytes.len)
inspectAux(rlpCopy, indent, hexOutput, result) inspectAux(rlpCopy, indent, hexOutput, result)

View File

@ -1,6 +0,0 @@
import stew/ranges
export ranges
type
Bytes* = seq[byte]
BytesRange* = Range[byte]

View File

@ -1,15 +1,11 @@
import import
macros, types, macros,
stew/ranges/[memranges, ptr_arith],
object_serialization, priv/defs object_serialization, priv/defs
export
memranges
type type
RlpWriter* = object RlpWriter* = object
pendingLists: seq[tuple[remainingItems, outBytes: int]] pendingLists: seq[tuple[remainingItems, outBytes: int]]
output: Bytes output: seq[byte]
IntLike* = concept x, y IntLike* = concept x, y
type T = type(x) type T = type(x)
@ -39,7 +35,7 @@ proc bytesNeeded(num: Integer): int =
inc result inc result
n = n shr 8 n = n shr 8
proc writeBigEndian(outStream: var Bytes, number: Integer, proc writeBigEndian(outStream: var seq[byte], number: Integer,
lastByteIdx: int, numberOfBytes: int) = lastByteIdx: int, numberOfBytes: int) =
mixin `and`, `shr` mixin `and`, `shr`
@ -48,12 +44,12 @@ proc writeBigEndian(outStream: var Bytes, number: Integer,
outStream[i] = byte(n and 0xff) outStream[i] = byte(n and 0xff)
n = n shr 8 n = n shr 8
proc writeBigEndian(outStream: var Bytes, number: Integer, proc writeBigEndian(outStream: var seq[byte], number: Integer,
numberOfBytes: int) {.inline.} = numberOfBytes: int) {.inline.} =
outStream.setLen(outStream.len + numberOfBytes) outStream.setLen(outStream.len + numberOfBytes)
outStream.writeBigEndian(number, outStream.len - 1, numberOfBytes) outStream.writeBigEndian(number, outStream.len - 1, numberOfBytes)
proc writeCount(bytes: var Bytes, count: int, baseMarker: byte) = proc writeCount(bytes: var seq[byte], count: int, baseMarker: byte) =
if count < THRESHOLD_LIST_LEN: if count < THRESHOLD_LIST_LEN:
bytes.add(baseMarker + byte(count)) bytes.add(baseMarker + byte(count))
else: else:
@ -65,19 +61,6 @@ proc writeCount(bytes: var Bytes, count: int, baseMarker: byte) =
bytes[origLen] = baseMarker + (THRESHOLD_LIST_LEN - 1) + byte(lenPrefixBytes) bytes[origLen] = baseMarker + (THRESHOLD_LIST_LEN - 1) + byte(lenPrefixBytes)
bytes.writeBigEndian(count, bytes.len - 1, lenPrefixBytes) bytes.writeBigEndian(count, bytes.len - 1, lenPrefixBytes)
proc add(outStream: var Bytes, newChunk: BytesRange) =
let prevLen = outStream.len
outStream.setLen(prevLen + newChunk.len)
# XXX: Use copyMem here
for i in 0 ..< newChunk.len:
outStream[prevLen + i] = newChunk[i]
{.this: self.}
{.experimental.}
using
self: var RlpWriter
proc initRlpWriter*: RlpWriter = proc initRlpWriter*: RlpWriter =
newSeq(result.pendingLists, 0) newSeq(result.pendingLists, 0)
newSeq(result.output, 0) newSeq(result.output, 0)
@ -86,88 +69,74 @@ proc decRet(n: var int, delta: int): int =
n -= delta n -= delta
return n return n
proc maybeClosePendingLists(self) = proc maybeClosePendingLists(self: var RlpWriter) =
while pendingLists.len > 0: while self.pendingLists.len > 0:
let lastListIdx = pendingLists.len - 1 let lastListIdx = self.pendingLists.len - 1
doAssert pendingLists[lastListIdx].remainingItems >= 1 doAssert self.pendingLists[lastListIdx].remainingItems >= 1
if decRet(pendingLists[lastListIdx].remainingItems, 1) == 0: if decRet(self.pendingLists[lastListIdx].remainingItems, 1) == 0:
# A list have been just finished. It was started in `startList`. # A list have been just finished. It was started in `startList`.
let listStartPos = pendingLists[lastListIdx].outBytes let listStartPos = self.pendingLists[lastListIdx].outBytes
pendingLists.setLen lastListIdx self.pendingLists.setLen lastListIdx
# How many bytes were written since the start? # How many bytes were written since the start?
let listLen = output.len - listStartPos let listLen = self.output.len - listStartPos
# Compute the number of bytes required to write down the list length # Compute the number of bytes required to write down the list length
let totalPrefixBytes = if listLen < int(THRESHOLD_LIST_LEN): 1 let totalPrefixBytes = if listLen < int(THRESHOLD_LIST_LEN): 1
else: int(listLen.bytesNeeded) + 1 else: int(listLen.bytesNeeded) + 1
# Shift the written data to make room for the prefix length # Shift the written data to make room for the prefix length
output.setLen(output.len + totalPrefixBytes) self.output.setLen(self.output.len + totalPrefixBytes)
let outputBaseAddr = output.baseAddr
moveMem(outputBaseAddr.shift(listStartPos + totalPrefixBytes), moveMem(addr self.output[listStartPos + totalPrefixBytes],
outputBaseAddr.shift(listStartPos), unsafeAddr self.output[listStartPos],
listLen) listLen)
# Write out the prefix length # Write out the prefix length
if listLen < THRESHOLD_LIST_LEN: if listLen < THRESHOLD_LIST_LEN:
output[listStartPos] = LIST_START_MARKER + byte(listLen) self.output[listStartPos] = LIST_START_MARKER + byte(listLen)
else: else:
let listLenBytes = totalPrefixBytes - 1 let listLenBytes = totalPrefixBytes - 1
output[listStartPos] = LEN_PREFIXED_LIST_MARKER + byte(listLenBytes) self.output[listStartPos] = LEN_PREFIXED_LIST_MARKER + byte(listLenBytes)
output.writeBigEndian(listLen, listStartPos + listLenBytes, listLenBytes) self.output.writeBigEndian(listLen, listStartPos + listLenBytes, listLenBytes)
else: else:
# The currently open list is not finished yet. Nothing to do. # The currently open list is not finished yet. Nothing to do.
return return
proc appendRawList(self; bytes: BytesRange) = proc appendRawList(self: var RlpWriter, bytes: openArray[byte]) =
output.writeCount(bytes.len, LIST_START_MARKER) self.output.writeCount(bytes.len, LIST_START_MARKER)
output.add(bytes) self.output.add(bytes)
maybeClosePendingLists() self.maybeClosePendingLists()
proc appendRawBytes*(self; bytes: BytesRange) = proc appendRawBytes*(self: var RlpWriter, bytes: openArray[byte]) =
output.add(bytes) self.output.add(bytes)
maybeClosePendingLists() self.maybeClosePendingLists()
proc startList*(self; listSize: int) = proc startList*(self: var RlpWriter, listSize: int) =
if listSize == 0: if listSize == 0:
appendRawList(BytesRange()) self.appendRawList([])
else: else:
pendingLists.add((listSize, output.len)) self.pendingLists.add((listSize, self.output.len))
template appendBlob(self; data, startMarker) =
mixin baseAddr
proc appendBlob(self: var RlpWriter, data: openArray[byte], startMarker: byte) =
if data.len == 1 and byte(data[0]) < BLOB_START_MARKER: if data.len == 1 and byte(data[0]) < BLOB_START_MARKER:
self.output.add byte(data[0]) self.output.add byte(data[0])
else: else:
self.output.writeCount(data.len, startMarker) self.output.writeCount(data.len, startMarker)
self.output.add data
let startPos = output.len self.maybeClosePendingLists()
self.output.setLen(startPos + data.len)
copyMem(shift(baseAddr(self.output), startPos),
baseAddr(data),
data.len)
maybeClosePendingLists() proc appendImpl(self: var RlpWriter, data: string) =
appendBlob(self, data.toOpenArrayByte(0, data.high), BLOB_START_MARKER)
proc appendImpl(self; data: string) = proc appendBlob(self: var RlpWriter, data: openarray[byte]) =
appendBlob(self, data, BLOB_START_MARKER) appendBlob(self, data, BLOB_START_MARKER)
proc appendBlob(self; data: openarray[byte]) = proc appendBlob(self: var RlpWriter, data: openarray[char]) =
appendBlob(self, data, BLOB_START_MARKER) appendBlob(self, data.toOpenArrayByte(0, data.high), BLOB_START_MARKER)
proc appendBlob(self; data: openarray[char]) = proc appendInt(self: var RlpWriter, i: Integer) =
appendBlob(self, data, BLOB_START_MARKER)
proc appendBytesRange(self; data: BytesRange) =
appendBlob(self, data, BLOB_START_MARKER)
proc appendImpl(self; data: MemRange) =
appendBlob(self, data, BLOB_START_MARKER)
proc appendInt(self; i: Integer) =
# this is created as a separate proc as an extra precaution against # this is created as a separate proc as an extra precaution against
# any overloading resolution problems when matching the IntLike concept. # any overloading resolution problems when matching the IntLike concept.
type IntType = type(i) type IntType = type(i)
@ -183,7 +152,7 @@ proc appendInt(self; i: Integer) =
self.maybeClosePendingLists() self.maybeClosePendingLists()
proc appendFloat(self; data: float64) = proc appendFloat(self: var RlpWriter, data: float64) =
# This is not covered in the RLP spec, but Geth uses Go's # This is not covered in the RLP spec, but Geth uses Go's
# `math.Float64bits`, which is defined here: # `math.Float64bits`, which is defined here:
# https://github.com/gopherjs/gopherjs/blob/master/compiler/natives/src/math/math.go # https://github.com/gopherjs/gopherjs/blob/master/compiler/natives/src/math/math.go
@ -191,16 +160,16 @@ proc appendFloat(self; data: float64) =
let uint64bits = (uint64(uintWords[1]) shl 32) or uint64(uintWords[0]) let uint64bits = (uint64(uintWords[1]) shl 32) or uint64(uintWords[0])
self.appendInt(uint64bits) self.appendInt(uint64bits)
template appendImpl(self; i: Integer) = template appendImpl(self: var RlpWriter, i: Integer) =
appendInt(self, i) appendInt(self, i)
template appendImpl(self; e: enum) = template appendImpl(self: var RlpWriter, e: enum) =
appendImpl(self, int(e)) appendImpl(self, int(e))
template appendImpl(self; b: bool) = template appendImpl(self: var RlpWriter, b: bool) =
appendImpl(self, int(b)) appendImpl(self, int(b))
proc appendImpl[T](self; listOrBlob: openarray[T]) = proc appendImpl[T](self: var RlpWriter, listOrBlob: openarray[T]) =
mixin append mixin append
# TODO: This append proc should be overloaded by `openarray[byte]` after # TODO: This append proc should be overloaded by `openarray[byte]` after
@ -212,7 +181,7 @@ proc appendImpl[T](self; listOrBlob: openarray[T]) =
for i in 0 ..< listOrBlob.len: for i in 0 ..< listOrBlob.len:
self.append listOrBlob[i] self.append listOrBlob[i]
proc appendRecordType*(self; obj: object|tuple, wrapInList = wrapObjsInList) = proc appendRecordType*(self: var RlpWriter, obj: object|tuple, wrapInList = wrapObjsInList) =
mixin enumerateRlpFields, append mixin enumerateRlpFields, append
if wrapInList: if wrapInList:
@ -226,15 +195,10 @@ proc appendRecordType*(self; obj: object|tuple, wrapInList = wrapObjsInList) =
enumerateRlpFields(obj, op) enumerateRlpFields(obj, op)
proc appendImpl(self; data: object) {.inline.} = proc appendImpl(self: var RlpWriter, data: object) {.inline.} =
# TODO: This append proc should be overloaded by `BytesRange` after
# nim bug #7416 is fixed.
when data is BytesRange:
self.appendBytesRange(data)
else:
self.appendRecordType(data) self.appendRecordType(data)
proc appendImpl(self; data: tuple) {.inline.} = proc appendImpl(self: var RlpWriter, data: tuple) {.inline.} =
self.appendRecordType(data) self.appendRecordType(data)
# We define a single `append` template with a pretty low specifity # We define a single `append` template with a pretty low specifity
@ -253,22 +217,22 @@ proc initRlpList*(listSize: int): RlpWriter =
startList(result, listSize) startList(result, listSize)
# TODO: This should return a lent value # TODO: This should return a lent value
proc finish*(self): Bytes = template finish*(self: RlpWriter): seq[byte] =
doAssert pendingLists.len == 0, "Insufficient number of elements written to a started list" doAssert self.pendingLists.len == 0, "Insufficient number of elements written to a started list"
result = output self.output
proc encode*[T](v: T): Bytes = proc encode*[T](v: T): seq[byte] =
mixin append mixin append
var writer = initRlpWriter() var writer = initRlpWriter()
writer.append(v) writer.append(v)
return writer.finish return writer.finish
proc encodeInt*(i: Integer): Bytes = proc encodeInt*(i: Integer): seq[byte] =
var writer = initRlpWriter() var writer = initRlpWriter()
writer.appendInt(i) writer.appendInt(i)
return writer.finish return writer.finish
macro encodeList*(args: varargs[untyped]): Bytes = macro encodeList*(args: varargs[untyped]): seq[byte] =
var var
listLen = args.len listLen = args.len
writer = genSym(nskVar, "rlpWriter") writer = genSym(nskVar, "rlpWriter")
@ -286,10 +250,9 @@ macro encodeList*(args: varargs[untyped]): Bytes =
when false: when false:
# XXX: Currently fails with a malformed AST error on the args.len expression # XXX: Currently fails with a malformed AST error on the args.len expression
template encodeList*(args: varargs[untyped]): BytesRange = template encodeList*(args: varargs[untyped]): seq[byte] =
mixin append mixin append
var writer = initRlpList(args.len) var writer = initRlpList(args.len)
for arg in args: for arg in args:
writer.append(arg) writer.append(arg)
writer.finish writer.finish

View File

@ -1,5 +1,5 @@
import import
stew/ranges, tables, sets, tables, sets,
eth/trie/db eth/trie/db
type type

View File

@ -1,4 +1,4 @@
import os, stew/ranges, eth/trie/[trie_defs, db_tracing] import os, eth/trie/[trie_defs, db_tracing]
import backend_defs import backend_defs
when defined(windows): when defined(windows):

View File

@ -1,4 +1,4 @@
import os, rocksdb, stew/ranges, eth/trie/[trie_defs, db_tracing] import os, rocksdb, eth/trie/[trie_defs, db_tracing]
import backend_defs import backend_defs
type type

View File

@ -1,5 +1,5 @@
import import
os, sqlite3, stew/ranges, stew/ranges/ptr_arith, eth/trie/[db_tracing, trie_defs], os, sqlite3, stew/ranges/ptr_arith, eth/trie/[db_tracing, trie_defs],
backend_defs backend_defs
type type

View File

@ -1,6 +1,7 @@
import import
sequtils, sequtils,
stew/ranges/[ptr_arith, bitranges], eth/rlp/types, trie_defs stew/ranges/ptr_arith, trie_defs,
./trie_bitseq
type type
TrieNodeKind* = enum TrieNodeKind* = enum
@ -8,31 +9,30 @@ type
BRANCH_TYPE = 1 BRANCH_TYPE = 1
LEAF_TYPE = 2 LEAF_TYPE = 2
TrieNodeKey* = BytesRange TrieNodeKey* = seq[byte]
TrieBitRange* = BitRange
TrieNode* = object TrieNode* = object
case kind*: TrieNodeKind case kind*: TrieNodeKind
of KV_TYPE: of KV_TYPE:
keyPath*: TrieBitRange keyPath*: TrieBitSeq
child*: TrieNodeKey child*: TrieNodeKey
of BRANCH_TYPE: of BRANCH_TYPE:
leftChild*: TrieNodeKey leftChild*: TrieNodeKey
rightChild*: TrieNodeKey rightChild*: TrieNodeKey
of LEAF_TYPE: of LEAF_TYPE:
value*: BytesRange value*: seq[byte]
InvalidNode* = object of CorruptedTrieDatabase InvalidNode* = object of CorruptedTrieDatabase
ValidationError* = object of CorruptedTrieDatabase ValidationError* = object of CorruptedTrieDatabase
# ---------------------------------------------- # ----------------------------------------------
template sliceToEnd*(r: TrieBitRange, index: int): TrieBitRange = template sliceToEnd*(r: TrieBitSeq, index: int): TrieBitSeq =
if r.len <= index: TrieBitRange() else: r[index .. ^1] if r.len <= index: TrieBitSeq() else: r[index .. ^1]
proc decodeToBinKeypath*(path: BytesRange): TrieBitRange = proc decodeToBinKeypath*(path: seq[byte]): TrieBitSeq =
## Decodes bytes into a sequence of 0s and 1s ## Decodes bytes into a sequence of 0s and 1s
## Used in decoding key path of a KV-NODE ## Used in decoding key path of a KV-NODE
var path = MutByteRange(path).bits var path = path.bits
if path[0]: if path[0]:
path = path[4..^1] path = path[4..^1]
@ -42,11 +42,11 @@ proc decodeToBinKeypath*(path: BytesRange): TrieBitRange =
bits = bits or path[3].int bits = bits or path[3].int
if path.len > 4: if path.len > 4:
result = path[4+((4 - bits) mod 4)..^1] path[4+((4 - bits) mod 4)..^1]
else: else:
result = BitRange() TrieBitSeq()
proc parseNode*(node: BytesRange): TrieNode = proc parseNode*(node: openArray[byte]): TrieNode =
# Input: a serialized node # Input: a serialized node
if node.len == 0: if node.len == 0:
@ -76,7 +76,7 @@ proc parseNode*(node: BytesRange): TrieNode =
# Output: node type, value # Output: node type, value
return TrieNode(kind: LEAF_TYPE, value: node[1..^1]) return TrieNode(kind: LEAF_TYPE, value: node[1..^1])
proc encodeKVNode*(keyPath: TrieBitRange, childHash: TrieNodeKey): Bytes = proc encodeKVNode*(keyPath: TrieBitSeq, childHash: TrieNodeKey): seq[byte] =
## Serializes a key/value node ## Serializes a key/value node
if keyPath.len == 0: if keyPath.len == 0:
raise newException(ValidationError, "Key path can not be empty") raise newException(ValidationError, "Key path can not be empty")
@ -110,13 +110,13 @@ proc encodeKVNode*(keyPath: TrieBitRange, childHash: TrieNodeKey): Bytes =
inc(nbits, 8) inc(nbits, 8)
copyMem(result[^32].addr, childHash.baseAddr, 32) copyMem(result[^32].addr, childHash.baseAddr, 32)
proc encodeKVNode*(keyPath: bool, childHash: TrieNodeKey): Bytes = proc encodeKVNode*(keyPath: bool, childHash: TrieNodeKey): seq[byte] =
result = newSeq[byte](34) result = newSeq[byte](34)
result[0] = KV_TYPE.byte result[0] = KV_TYPE.byte
result[1] = byte(16) or byte(keyPath) result[1] = byte(16) or byte(keyPath)
copyMem(result[^32].addr, childHash.baseAddr, 32) copyMem(result[^32].addr, childHash.baseAddr, 32)
proc encodeBranchNode*(leftChildHash, rightChildHash: TrieNodeKey): Bytes = proc encodeBranchNode*(leftChildHash, rightChildHash: TrieNodeKey): seq[byte] =
## Serializes a branch node ## Serializes a branch node
const const
BRANCH_TYPE_PREFIX = @[BRANCH_TYPE.byte] BRANCH_TYPE_PREFIX = @[BRANCH_TYPE.byte]
@ -126,7 +126,7 @@ proc encodeBranchNode*(leftChildHash, rightChildHash: TrieNodeKey): Bytes =
result = BRANCH_TYPE_PREFIX.concat(leftChildHash, rightChildHash) result = BRANCH_TYPE_PREFIX.concat(leftChildHash, rightChildHash)
proc encodeLeafNode*(value: BytesRange | Bytes): Bytes = proc encodeLeafNode*(value: openArray[byte]): seq[byte] =
## Serializes a leaf node ## Serializes a leaf node
const const
LEAF_TYPE_PREFIX = @[LEAF_TYPE.byte] LEAF_TYPE_PREFIX = @[LEAF_TYPE.byte]
@ -134,9 +134,9 @@ proc encodeLeafNode*(value: BytesRange | Bytes): Bytes =
if value.len == 0: if value.len == 0:
raise newException(ValidationError, "Value of leaf node can not be empty") raise newException(ValidationError, "Value of leaf node can not be empty")
result = LEAF_TYPE_PREFIX.concat(value) result = LEAF_TYPE_PREFIX.concat(@value)
proc getCommonPrefixLength*(a, b: TrieBitRange): int = proc getCommonPrefixLength*(a, b: TrieBitSeq): int =
let len = min(a.len, b.len) let len = min(a.len, b.len)
for i in 0..<len: for i in 0..<len:
if a[i] != b[i]: return i if a[i] != b[i]: return i

View File

@ -1,9 +1,9 @@
import import
stew/ranges/[typedranges, bitranges], eth/rlp/types, ./trie_bitseq,
trie_defs, db, binaries, trie_utils ./trie_defs, ./db, ./binaries, ./trie_utils
export export
types, trie_utils trie_utils
type type
DB = TrieDatabaseRef DB = TrieDatabaseRef
@ -14,18 +14,18 @@ type
NodeOverrideError* = object of CatchableError NodeOverrideError* = object of CatchableError
let const
zeroHash* = zeroBytesRange zeroHash* = default(seq[byte])
proc init*(x: typedesc[BinaryTrie], db: DB, proc init*(x: typedesc[BinaryTrie], db: DB,
rootHash: BytesContainer | KeccakHash = zeroHash): BinaryTrie = rootHash: openArray[byte]): BinaryTrie =
checkValidHashZ(rootHash) checkValidHashZ(rootHash)
result.db = db result.db = db
result.rootHash = toRange(rootHash) result.rootHash = @(rootHash)
proc getDB*(t: BinaryTrie): auto = t.db proc getDB*(t: BinaryTrie): auto = t.db
proc initBinaryTrie*(db: DB, rootHash: BytesContainer | KeccakHash): BinaryTrie = proc initBinaryTrie*(db: DB, rootHash: openArray[byte]): BinaryTrie =
init(BinaryTrie, db, rootHash) init(BinaryTrie, db, rootHash)
proc initBinaryTrie*(db: DB): BinaryTrie = proc initBinaryTrie*(db: DB): BinaryTrie =
@ -36,64 +36,64 @@ proc getRootHash*(self: BinaryTrie): TrieNodeKey {.inline.} =
template fetchNode(self: BinaryTrie, nodeHash: TrieNodeKey): TrieNode = template fetchNode(self: BinaryTrie, nodeHash: TrieNodeKey): TrieNode =
doAssert(nodeHash.len == 32) doAssert(nodeHash.len == 32)
parseNode self.db.get(nodeHash.toOpenArray).toRange parseNode self.db.get(nodeHash)
proc getAux(self: BinaryTrie, nodeHash: TrieNodeKey, keyPath: TrieBitRange): BytesRange = proc getAux(self: BinaryTrie, nodeHash: TrieNodeKey, keyPath: TrieBitSeq): seq[byte] =
# Empty trie # Empty trie
if isZeroHash(nodeHash): if isZeroHash(nodeHash):
return zeroBytesRange return
let node = self.fetchNode(nodeHash) let node = self.fetchNode(nodeHash)
# Key-value node descend # Key-value node descend
if node.kind == LEAF_TYPE: if node.kind == LEAF_TYPE:
if keyPath.len != 0: return zeroBytesRange if keyPath.len != 0: return
return node.value return node.value
elif node.kind == KV_TYPE: elif node.kind == KV_TYPE:
# keyPath too short # keyPath too short
if keyPath.len == 0: return zeroBytesRange if keyPath.len == 0: return
let sliceLen = min(node.keyPath.len, keyPath.len) let sliceLen = min(node.keyPath.len, keyPath.len)
if keyPath[0..<sliceLen] == node.keyPath: if keyPath[0..<sliceLen] == node.keyPath:
return self.getAux(node.child, keyPath.sliceToEnd(node.keyPath.len)) return self.getAux(node.child, keyPath.sliceToEnd(node.keyPath.len))
else: else:
return zeroBytesRange return
# Branch node descend # Branch node descend
elif node.kind == BRANCH_TYPE: elif node.kind == BRANCH_TYPE:
# keyPath too short # keyPath too short
if keyPath.len == 0: return zeroBytesRange if keyPath.len == 0: return
if keyPath[0]: # first bit == 1 if keyPath[0]: # first bit == 1
return self.getAux(node.rightChild, keyPath.sliceToEnd(1)) return self.getAux(node.rightChild, keyPath.sliceToEnd(1))
else: else:
return self.getAux(node.leftChild, keyPath.sliceToEnd(1)) return self.getAux(node.leftChild, keyPath.sliceToEnd(1))
proc get*(self: BinaryTrie, key: BytesContainer): BytesRange {.inline.} = proc get*(self: BinaryTrie, key: openArray[byte]): seq[byte] {.inline.} =
var keyBits = MutByteRange(key.toRange).bits var keyBits = key.bits
return self.getAux(self.rootHash, keyBits) return self.getAux(self.rootHash, keyBits)
proc hashAndSave*(self: BinaryTrie, node: BytesRange | Bytes): TrieNodeKey = proc hashAndSave*(self: BinaryTrie, node: openArray[byte]): TrieNodeKey =
result = keccakHash(node) result = @(keccakHash(node).data)
self.db.put(result.toOpenArray, node.toRange.toOpenArray) self.db.put(result, node)
template saveKV(self: BinaryTrie, keyPath: TrieBitRange | bool, child: BytesRange): untyped = template saveKV(self: BinaryTrie, keyPath: TrieBitSeq | bool, child: openArray[byte]): untyped =
self.hashAndsave(encodeKVNode(keyPath, child)) self.hashAndsave(encodeKVNode(keyPath, child))
template saveLeaf(self: BinaryTrie, value: BytesRange): untyped = template saveLeaf(self: BinaryTrie, value: openArray[byte]): untyped =
self.hashAndsave(encodeLeafNode(value)) self.hashAndsave(encodeLeafNode(value))
template saveBranch(self: BinaryTrie, L, R: BytesRange): untyped = template saveBranch(self: BinaryTrie, L, R: openArray[byte]): untyped =
self.hashAndsave(encodeBranchNode(L, R)) self.hashAndsave(encodeBranchNode(L, R))
proc setBranchNode(self: BinaryTrie, keyPath: TrieBitRange, node: TrieNode, proc setBranchNode(self: BinaryTrie, keyPath: TrieBitSeq, node: TrieNode,
value: BytesRange, deleteSubtrie = false): TrieNodeKey value: openArray[byte], deleteSubtrie = false): TrieNodeKey
proc setKVNode(self: BinaryTrie, keyPath: TrieBitRange, nodeHash: TrieNodeKey, proc setKVNode(self: BinaryTrie, keyPath: TrieBitSeq, nodeHash: TrieNodeKey,
node: TrieNode, value: BytesRange, deleteSubtrie = false): TrieNodeKey node: TrieNode, value: openArray[byte], deleteSubtrie = false): TrieNodeKey
const const
overrideErrorMsg = overrideErrorMsg =
"Fail to set the value because the prefix of it's key is the same as existing key" "Fail to set the value because the prefix of it's key is the same as existing key"
proc setAux(self: BinaryTrie, nodeHash: TrieNodeKey, keyPath: TrieBitRange, proc setAux(self: BinaryTrie, nodeHash: TrieNodeKey, keyPath: TrieBitSeq,
value: BytesRange, deleteSubtrie = false): TrieNodeKey = value: openArray[byte], deleteSubtrie = false): TrieNodeKey =
## If deleteSubtrie is set to True, what it will do is that it take in a keyPath ## If deleteSubtrie is set to True, what it will do is that it take in a keyPath
## and traverse til the end of keyPath, then delete the whole subtrie of that node. ## and traverse til the end of keyPath, then delete the whole subtrie of that node.
## Note: keyPath should be in binary array format, i.e., encoded by encode_to_bin() ## Note: keyPath should be in binary array format, i.e., encoded by encode_to_bin()
@ -131,15 +131,15 @@ proc setAux(self: BinaryTrie, nodeHash: TrieNodeKey, keyPath: TrieBitRange,
checkBadKeyPath() checkBadKeyPath()
return self.setBranchNode(keyPath, node, value, deleteSubtrie) return self.setBranchNode(keyPath, node, value, deleteSubtrie)
proc set*(self: var BinaryTrie, key, value: distinct BytesContainer) {.inline.} = proc set*(self: var BinaryTrie, key, value: openArray[byte]) {.inline.} =
## Sets the value at the given keyPath from the given node ## Sets the value at the given keyPath from the given node
## Key will be encoded into binary array format first. ## Key will be encoded into binary array format first.
var keyBits = bits MutByteRange(key.toRange) var keyBits = key.bits
self.rootHash = self.setAux(self.rootHash, keyBits, toRange(value)) self.rootHash = self.setAux(self.rootHash, keyBits, value)
proc setBranchNode(self: BinaryTrie, keyPath: TrieBitRange, node: TrieNode, proc setBranchNode(self: BinaryTrie, keyPath: TrieBitSeq, node: TrieNode,
value: BytesRange, deleteSubtrie = false): TrieNodeKey = value: openArray[byte], deleteSubtrie = false): TrieNodeKey =
# Which child node to update? Depends on first bit in keyPath # Which child node to update? Depends on first bit in keyPath
var newLeftChild, newRightChild: TrieNodeKey var newLeftChild, newRightChild: TrieNodeKey
@ -169,8 +169,8 @@ proc setBranchNode(self: BinaryTrie, keyPath: TrieBitRange, node: TrieNode,
else: else:
result = self.saveBranch(newLeftChild, newRightChild) result = self.saveBranch(newLeftChild, newRightChild)
proc setKVNode(self: BinaryTrie, keyPath: TrieBitRange, nodeHash: TrieNodeKey, proc setKVNode(self: BinaryTrie, keyPath: TrieBitSeq, nodeHash: TrieNodeKey,
node: TrieNode, value: BytesRange, deleteSubtrie = false): TrieNodeKey = node: TrieNode, value: openArray[byte], deleteSubtrie = false): TrieNodeKey =
# keyPath prefixes match # keyPath prefixes match
if deleteSubtrie: if deleteSubtrie:
if keyPath.len < node.keyPath.len and keyPath == node.keyPath[0..<keyPath.len]: if keyPath.len < node.keyPath.len and keyPath == node.keyPath[0..<keyPath.len]:
@ -251,35 +251,35 @@ proc setKVNode(self: BinaryTrie, keyPath: TrieBitRange, nodeHash: TrieNodeKey,
else: else:
return newSub return newSub
template exists*(self: BinaryTrie, key: BytesContainer): bool = template exists*(self: BinaryTrie, key: openArray[byte]): bool =
self.get(toRange(key)) != zeroBytesRange self.get(key) != []
proc delete*(self: var BinaryTrie, key: BytesContainer) {.inline.} = proc delete*(self: var BinaryTrie, key: openArray[byte]) {.inline.} =
## Equals to setting the value to zeroBytesRange ## Equals to setting the value to zeroBytesRange
var keyBits = bits MutByteRange(key.toRange) var keyBits = key.bits
self.rootHash = self.setAux(self.rootHash, keyBits, zeroBytesRange) self.rootHash = self.setAux(self.rootHash, keyBits, [])
proc deleteSubtrie*(self: var BinaryTrie, key: BytesContainer) {.inline.} = proc deleteSubtrie*(self: var BinaryTrie, key: openArray[byte]) {.inline.} =
## Given a key prefix, delete the whole subtrie that starts with the key prefix. ## Given a key prefix, delete the whole subtrie that starts with the key prefix.
## Key will be encoded into binary array format first. ## Key will be encoded into binary array format first.
## It will call `setAux` with `deleteSubtrie` set to true. ## It will call `setAux` with `deleteSubtrie` set to true.
var keyBits = bits MutByteRange(key.toRange) var keyBits = key.bits
self.rootHash = self.setAux(self.rootHash, keyBits, zeroBytesRange, true) self.rootHash = self.setAux(self.rootHash, keyBits, [], true)
# Convenience # Convenience
proc rootNode*(self: BinaryTrie): BytesRange {.inline.} = proc rootNode*(self: BinaryTrie): seq[byte] {.inline.} =
self.db.get(self.rootHash.toOpenArray).toRange self.db.get(self.rootHash)
proc rootNode*(self: var BinaryTrie, node: BytesContainer) {.inline.} = proc rootNode*(self: var BinaryTrie, node: openArray[byte]) {.inline.} =
self.rootHash = self.hashAndSave(toRange(node)) self.rootHash = self.hashAndSave(node)
# Dictionary API # Dictionary API
template `[]`*(self: BinaryTrie, key: BytesContainer): BytesRange = template `[]`*(self: BinaryTrie, key: seq[byte]): seq[byte] =
self.get(key) self.get(key)
template `[]=`*(self: var BinaryTrie, key, value: distinct BytesContainer) = template `[]=`*(self: var BinaryTrie, key, value: seq[byte]) =
self.set(key, value) self.set(key, value)
template contains*(self: BinaryTrie, key: BytesContainer): bool = template contains*(self: BinaryTrie, key: seq[byte]): bool =
self.exists(key) self.exists(key)

View File

@ -1,6 +1,5 @@
import import
eth/rlp/types, stew/ranges/bitranges, ./trie_defs, ./binary, ./binaries, ./db, ./trie_utils, ./trie_bitseq
trie_defs, binary, binaries, db, trie_utils
type type
DB = TrieDatabaseRef DB = TrieDatabaseRef
@ -8,10 +7,10 @@ type
# TODO: replace the usages of this with regular asserts # TODO: replace the usages of this with regular asserts
InvalidKeyError* = object of Defect InvalidKeyError* = object of Defect
template query(db: DB, nodeHash: TrieNodeKey): BytesRange = template query(db: DB, nodeHash: TrieNodeKey): seq[byte] =
db.get(nodeHash.toOpenArray).toRange db.get(nodeHash)
proc checkIfBranchExistImpl(db: DB; nodeHash: TrieNodeKey; keyPrefix: TrieBitRange): bool = proc checkIfBranchExistImpl(db: DB; nodeHash: TrieNodeKey; keyPrefix: TrieBitSeq): bool =
if nodeHash == zeroHash: if nodeHash == zeroHash:
return false return false
@ -37,14 +36,14 @@ proc checkIfBranchExistImpl(db: DB; nodeHash: TrieNodeKey; keyPrefix: TrieBitRan
else: else:
return checkIfBranchExistImpl(db, node.rightChild, keyPrefix.sliceToEnd(1)) return checkIfBranchExistImpl(db, node.rightChild, keyPrefix.sliceToEnd(1))
proc checkIfBranchExist*(db: DB; rootHash: BytesContainer | KeccakHash, keyPrefix: BytesContainer): bool = proc checkIfBranchExist*(db: DB; rootHash: TrieNodeKey, keyPrefix: openArray[byte]): bool =
## Given a key prefix, return whether this prefix is ## Given a key prefix, return whether this prefix is
## the prefix of an existing key in the trie. ## the prefix of an existing key in the trie.
checkValidHashZ(rootHash) checkValidHashZ(rootHash)
var keyPrefixBits = bits MutByteRange(keyPrefix.toRange) var keyPrefixBits = bits keyPrefix
checkIfBranchExistImpl(db, toRange(rootHash), keyPrefixBits) checkIfBranchExistImpl(db, rootHash, keyPrefixBits)
proc getBranchImpl(db: DB; nodeHash: TrieNodeKey, keyPath: TrieBitRange, output: var seq[BytesRange]) = proc getBranchImpl(db: DB; nodeHash: TrieNodeKey, keyPath: TrieBitSeq, output: var seq[seq[byte]]) =
if nodeHash == zeroHash: return if nodeHash == zeroHash: return
let nodeVal = db.query(nodeHash) let nodeVal = db.query(nodeHash)
@ -76,14 +75,14 @@ proc getBranchImpl(db: DB; nodeHash: TrieNodeKey, keyPath: TrieBitRange, output:
else: else:
getBranchImpl(db, node.rightChild, keyPath.sliceToEnd(1), output) getBranchImpl(db, node.rightChild, keyPath.sliceToEnd(1), output)
proc getBranch*(db: DB; rootHash: BytesContainer | KeccakHash; key: BytesContainer): seq[BytesRange] = proc getBranch*(db: DB; rootHash: seq[byte]; key: openArray[byte]): seq[seq[byte]] =
## Get a long-format Merkle branch ## Get a long-format Merkle branch
checkValidHashZ(rootHash) checkValidHashZ(rootHash)
result = @[] result = @[]
var keyBits = bits MutByteRange(key.toRange) var keyBits = bits key
getBranchImpl(db, toRange(rootHash), keyBits, result) getBranchImpl(db, rootHash, keyBits, result)
proc isValidBranch*(branch: seq[BytesRange], rootHash: BytesContainer | KeccakHash, key, value: BytesContainer): bool = proc isValidBranch*(branch: seq[seq[byte]], rootHash: seq[byte], key, value: openArray[byte]): bool =
checkValidHashZ(rootHash) checkValidHashZ(rootHash)
# branch must not be empty # branch must not be empty
doAssert(branch.len != 0) doAssert(branch.len != 0)
@ -92,18 +91,18 @@ proc isValidBranch*(branch: seq[BytesRange], rootHash: BytesContainer | KeccakHa
for node in branch: for node in branch:
doAssert(node.len != 0) doAssert(node.len != 0)
let nodeHash = keccakHash(node) let nodeHash = keccakHash(node)
db.put(nodeHash.toOpenArray, node.toOpenArray) db.put(nodeHash.data, node)
var trie = initBinaryTrie(db, rootHash) var trie = initBinaryTrie(db, rootHash)
result = trie.get(key) == toRange(value) result = trie.get(key) == value
proc getTrieNodesImpl(db: DB; nodeHash: TrieNodeKey, output: var seq[BytesRange]): bool = proc getTrieNodesImpl(db: DB; nodeHash: TrieNodeKey, output: var seq[seq[byte]]): bool =
## Get full trie of a given root node ## Get full trie of a given root node
if nodeHash.isZeroHash(): return false if nodeHash.isZeroHash(): return false
var nodeVal: BytesRange var nodeVal: seq[byte]
if nodeHash.toOpenArray in db: if nodeHash in db:
nodeVal = db.query(nodeHash) nodeVal = db.query(nodeHash)
else: else:
return false return false
@ -121,19 +120,19 @@ proc getTrieNodesImpl(db: DB; nodeHash: TrieNodeKey, output: var seq[BytesRange]
of LEAF_TYPE: of LEAF_TYPE:
output.add nodeVal output.add nodeVal
proc getTrieNodes*(db: DB; nodeHash: BytesContainer | KeccakHash): seq[BytesRange] = proc getTrieNodes*(db: DB; nodeHash: TrieNodeKey): seq[seq[byte]] =
checkValidHashZ(nodeHash) checkValidHashZ(nodeHash)
result = @[] result = @[]
discard getTrieNodesImpl(db, toRange(nodeHash), result) discard getTrieNodesImpl(db, nodeHash, result)
proc getWitnessImpl*(db: DB; nodeHash: TrieNodeKey; keyPath: TrieBitRange; output: var seq[BytesRange]) = proc getWitnessImpl*(db: DB; nodeHash: TrieNodeKey; keyPath: TrieBitSeq; output: var seq[seq[byte]]) =
if keyPath.len == 0: if keyPath.len == 0:
if not getTrieNodesImpl(db, nodeHash, output): return if not getTrieNodesImpl(db, nodeHash, output): return
if nodeHash.isZeroHash(): return if nodeHash.isZeroHash(): return
var nodeVal: BytesRange var nodeVal: seq[byte]
if nodeHash.toOpenArray in db: if nodeHash in db:
nodeVal = db.query(nodeHash) nodeVal = db.query(nodeHash)
else: else:
return return
@ -157,7 +156,7 @@ proc getWitnessImpl*(db: DB; nodeHash: TrieNodeKey; keyPath: TrieBitRange; outpu
else: else:
getWitnessImpl(db, node.rightChild, keyPath.sliceToEnd(1), output) getWitnessImpl(db, node.rightChild, keyPath.sliceToEnd(1), output)
proc getWitness*(db: DB; nodeHash: BytesContainer | KeccakHash; key: BytesContainer): seq[BytesRange] = proc getWitness*(db: DB; nodeHash: TrieNodeKey; key: openArray[byte]): seq[seq[byte]] =
## Get all witness given a keyPath prefix. ## Get all witness given a keyPath prefix.
## Include ## Include
## ##
@ -165,5 +164,5 @@ proc getWitness*(db: DB; nodeHash: BytesContainer | KeccakHash; key: BytesContai
## 2. witness in the subtrie of the last node in keyPath ## 2. witness in the subtrie of the last node in keyPath
checkValidHashZ(nodeHash) checkValidHashZ(nodeHash)
result = @[] result = @[]
var keyBits = bits MutByteRange(key.toRange) var keyBits = bits key
getWitnessImpl(db, toRange(nodeHash), keyBits, result) getWitnessImpl(db, nodeHash, keyBits, result)

View File

@ -1,30 +1,22 @@
import import
tables, hashes, sets, tables, hashes, sets,
nimcrypto/[hash, keccak], eth/rlp, nimcrypto/[hash, keccak],
trie_defs, db_tracing trie_defs, db_tracing
type type
MemDBRec = object MemDBRec = object
refCount: int refCount: int
value: Bytes value: seq[byte]
MemoryLayer* = ref object of RootObj MemoryLayer* = ref object of RootObj
records: Table[Bytes, MemDBRec] records: Table[seq[byte], MemDBRec]
deleted: HashSet[Bytes] deleted: HashSet[seq[byte]]
TrieDatabaseConcept* = concept DB
mixin put, del, get
put(var DB, KeccakHash, BytesRange)
del(var DB, KeccakHash)
get(DB, KeccakHash) is Bytes
contains(DB, KeccakHash) is bool
# XXX: poor's man vtref types # XXX: poor's man vtref types
PutProc = proc (db: RootRef, key, val: openarray[byte]) {. PutProc = proc (db: RootRef, key, val: openarray[byte]) {.
gcsafe, raises: [Defect, CatchableError] .} gcsafe, raises: [Defect, CatchableError] .}
GetProc = proc (db: RootRef, key: openarray[byte]): Bytes {. GetProc = proc (db: RootRef, key: openarray[byte]): seq[byte] {.
gcsafe, raises: [Defect, CatchableError] .} gcsafe, raises: [Defect, CatchableError] .}
## The result will be empty seq if not found ## The result will be empty seq if not found
@ -56,14 +48,14 @@ type
TransactionID* = distinct DbTransaction TransactionID* = distinct DbTransaction
proc put*(db: TrieDatabaseRef, key, val: openarray[byte]) {.gcsafe.} proc put*(db: TrieDatabaseRef, key, val: openarray[byte]) {.gcsafe.}
proc get*(db: TrieDatabaseRef, key: openarray[byte]): Bytes {.gcsafe.} proc get*(db: TrieDatabaseRef, key: openarray[byte]): seq[byte] {.gcsafe.}
proc del*(db: TrieDatabaseRef, key: openarray[byte]) {.gcsafe.} proc del*(db: TrieDatabaseRef, key: openarray[byte]) {.gcsafe.}
proc beginTransaction*(db: TrieDatabaseRef): DbTransaction {.gcsafe.} proc beginTransaction*(db: TrieDatabaseRef): DbTransaction {.gcsafe.}
proc keccak*(r: BytesRange): KeccakHash = proc keccak*(r: openArray[byte]): KeccakHash =
keccak256.digest r.toOpenArray keccak256.digest r
proc get*(db: MemoryLayer, key: openarray[byte]): Bytes = proc get*(db: MemoryLayer, key: openarray[byte]): seq[byte] =
result = db.records.getOrDefault(@key).value result = db.records.getOrDefault(@key).value
traceGet key, result traceGet key, result
@ -107,8 +99,8 @@ proc put*(db: MemoryLayer, key, val: openarray[byte]) =
proc newMemoryLayer: MemoryLayer = proc newMemoryLayer: MemoryLayer =
result.new result.new
result.records = initTable[Bytes, MemDBRec]() result.records = initTable[seq[byte], MemDBRec]()
result.deleted = initHashSet[Bytes]() result.deleted = initHashSet[seq[byte]]()
proc commit(memDb: MemoryLayer, db: TrieDatabaseRef, applyDeletes: bool = true) = proc commit(memDb: MemoryLayer, db: TrieDatabaseRef, applyDeletes: bool = true) =
if applyDeletes: if applyDeletes:
@ -136,7 +128,7 @@ proc totalRecordsInMemoryDB*(db: TrieDatabaseRef): int =
doAssert isMemoryDB(db) doAssert isMemoryDB(db)
return db.mostInnerTransaction.modifications.records.len return db.mostInnerTransaction.modifications.records.len
iterator pairsInMemoryDB*(db: TrieDatabaseRef): (Bytes, Bytes) = iterator pairsInMemoryDB*(db: TrieDatabaseRef): (seq[byte], seq[byte]) =
doAssert isMemoryDB(db) doAssert isMemoryDB(db)
for k, v in db.mostInnerTransaction.modifications.records: for k, v in db.mostInnerTransaction.modifications.records:
yield (k, v.value) yield (k, v.value)
@ -178,7 +170,7 @@ proc putImpl[T](db: RootRef, key, val: openarray[byte]) =
mixin put mixin put
put(T(db), key, val) put(T(db), key, val)
proc getImpl[T](db: RootRef, key: openarray[byte]): Bytes = proc getImpl[T](db: RootRef, key: openarray[byte]): seq[byte] =
mixin get mixin get
return get(T(db), key) return get(T(db), key)
@ -207,7 +199,7 @@ proc put*(db: TrieDatabaseRef, key, val: openarray[byte]) =
else: else:
db.putProc(db.obj, key, val) db.putProc(db.obj, key, val)
proc get*(db: TrieDatabaseRef, key: openarray[byte]): Bytes = proc get*(db: TrieDatabaseRef, key: openarray[byte]): seq[byte] =
# TODO: This is quite inefficient and it won't be necessary once # TODO: This is quite inefficient and it won't be necessary once
# https://github.com/nim-lang/Nim/issues/7457 is developed. # https://github.com/nim-lang/Nim/issues/7457 is developed.
let key = @key let key = @key

View File

@ -9,11 +9,11 @@ when db_tracing in ["on", "1"]:
template traceGet*(k, v) = template traceGet*(k, v) =
if dbTracingEnabled: if dbTracingEnabled:
echo "GET ", toHex(k), " = ", toHex(v) # rlpFromBytes(@v.toRange).inspect echo "GET ", toHex(k), " = ", toHex(v) # rlpFromBytes(v).inspect
template tracePut*(k, v) = template tracePut*(k, v) =
if dbTracingEnabled: if dbTracingEnabled:
echo "PUT ", toHex(k), " = ", toHex(v) # rlpFromBytes(@v.toRange).inspect echo "PUT ", toHex(k), " = ", toHex(v) # rlpFromBytes(v).inspect
template traceDel*(k) = template traceDel*(k) =
if dbTracingEnabled: if dbTracingEnabled:

View File

@ -1,7 +1,7 @@
import import
tables, tables,
nimcrypto/[keccak, hash, utils], stew/ranges/ptr_arith, eth/rlp, nimcrypto/[keccak, hash], eth/rlp,
trie_defs, nibbles, trie_utils as trieUtils, db trie_defs, nibbles, db
type type
TrieNodeKey = object TrieNodeKey = object
@ -17,28 +17,26 @@ type
SecureHexaryTrie* = distinct HexaryTrie SecureHexaryTrie* = distinct HexaryTrie
TrieNode = Rlp
template len(key: TrieNodeKey): int = template len(key: TrieNodeKey): int =
key.usedBytes.int key.usedBytes.int
proc keccak*(r: BytesRange): KeccakHash = proc keccak*(r: openArray[byte]): KeccakHash =
keccak256.digest r.toOpenArray keccak256.digest r
template asDbKey(k: TrieNodeKey): untyped = template asDbKey(k: TrieNodeKey): untyped =
doAssert k.usedBytes == 32 doAssert k.usedBytes == 32
k.hash.data k.hash.data
proc expectHash(r: Rlp): BytesRange = proc expectHash(r: Rlp): seq[byte] =
result = r.toBytes result = r.toBytes
if result.len != 32: if result.len != 32:
raise newException(RlpTypeMismatch, raise newException(RlpTypeMismatch,
"RLP expected to be a Keccak hash value, but has an incorrect length") "RLP expected to be a Keccak hash value, but has an incorrect length")
proc dbPut(db: DB, data: BytesRange): TrieNodeKey {.gcsafe.} proc dbPut(db: DB, data: openArray[byte]): TrieNodeKey {.gcsafe.}
template get(db: DB, key: Rlp): BytesRange = template get(db: DB, key: Rlp): seq[byte] =
db.get(key.expectHash.toOpenArray).toRange db.get(key.expectHash)
converter toTrieNodeKey(hash: KeccakHash): TrieNodeKey = converter toTrieNodeKey(hash: KeccakHash): TrieNodeKey =
result.hash = hash result.hash = hash
@ -54,7 +52,7 @@ template initSecureHexaryTrie*(db: DB, rootHash: KeccakHash, isPruning = true):
proc initHexaryTrie*(db: DB, isPruning = true): HexaryTrie = proc initHexaryTrie*(db: DB, isPruning = true): HexaryTrie =
result.db = db result.db = db
result.root = result.db.dbPut(emptyRlp.toRange) result.root = result.db.dbPut(emptyRlp)
result.isPruning = isPruning result.isPruning = isPruning
template initSecureHexaryTrie*(db: DB, isPruning = true): SecureHexaryTrie = template initSecureHexaryTrie*(db: DB, isPruning = true): SecureHexaryTrie =
@ -72,38 +70,32 @@ template prune(t: HexaryTrie, x: openArray[byte]) =
proc isPruning*(t: HexaryTrie): bool = proc isPruning*(t: HexaryTrie): bool =
t.isPruning t.isPruning
proc getLocalBytes(x: TrieNodeKey): BytesRange = proc getLocalBytes(x: TrieNodeKey): seq[byte] =
## This proc should be used on nodes using the optimization ## This proc should be used on nodes using the optimization
## of short values within the key. ## of short values within the key.
doAssert x.usedBytes < 32 doAssert x.usedBytes < 32
x.hash.data[0..<x.usedBytes]
when defined(rangesEnableUnsafeAPI): template keyToLocalBytes(db: DB, k: TrieNodeKey): seq[byte] =
result = unsafeRangeConstruction(x.data, x.usedBytes)
else:
var dataCopy = newSeq[byte](x.usedBytes)
copyMem(dataCopy.baseAddr, x.hash.data.baseAddr, x.usedBytes)
return dataCopy.toRange
template keyToLocalBytes(db: DB, k: TrieNodeKey): BytesRange =
if k.len < 32: k.getLocalBytes if k.len < 32: k.getLocalBytes
else: db.get(k.asDbKey).toRange else: db.get(k.asDbKey)
template extensionNodeKey(r: Rlp): auto = template extensionNodeKey(r: Rlp): auto =
hexPrefixDecode r.listElem(0).toBytes hexPrefixDecode r.listElem(0).toBytes
proc getAux(db: DB, nodeRlp: Rlp, path: NibblesRange): BytesRange {.gcsafe.} proc getAux(db: DB, nodeRlp: Rlp, path: NibblesSeq): seq[byte] {.gcsafe.}
proc getAuxByHash(db: DB, node: TrieNodeKey, path: NibblesRange): BytesRange = proc getAuxByHash(db: DB, node: TrieNodeKey, path: NibblesSeq): seq[byte] =
var nodeRlp = rlpFromBytes keyToLocalBytes(db, node) var nodeRlp = rlpFromBytes keyToLocalBytes(db, node)
return getAux(db, nodeRlp, path) return getAux(db, nodeRlp, path)
template getLookup(elem: untyped): untyped = template getLookup(elem: untyped): untyped =
if elem.isList: elem if elem.isList: elem
else: rlpFromBytes(get(db, toOpenArray(elem.expectHash)).toRange) else: rlpFromBytes(get(db, elem.expectHash))
proc getAux(db: DB, nodeRlp: Rlp, path: NibblesRange): BytesRange = proc getAux(db: DB, nodeRlp: Rlp, path: NibblesSeq): seq[byte] =
if not nodeRlp.hasData or nodeRlp.isEmpty: if not nodeRlp.hasData or nodeRlp.isEmpty:
return zeroBytesRange return
case nodeRlp.listLen case nodeRlp.listLen
of 2: of 2:
@ -118,13 +110,13 @@ proc getAux(db: DB, nodeRlp: Rlp, path: NibblesRange): BytesRange =
let nextLookup = value.getLookup let nextLookup = value.getLookup
return getAux(db, nextLookup, path.slice(sharedNibbles)) return getAux(db, nextLookup, path.slice(sharedNibbles))
return zeroBytesRange return
of 17: of 17:
if path.len == 0: if path.len == 0:
return nodeRlp.listElem(16).toBytes return nodeRlp.listElem(16).toBytes
var branch = nodeRlp.listElem(path[0].int) var branch = nodeRlp.listElem(path[0].int)
if branch.isEmpty: if branch.isEmpty:
return zeroBytesRange return
else: else:
let nextLookup = branch.getLookup let nextLookup = branch.getLookup
return getAux(db, nextLookup, path.slice(1)) return getAux(db, nextLookup, path.slice(1))
@ -132,10 +124,10 @@ proc getAux(db: DB, nodeRlp: Rlp, path: NibblesRange): BytesRange =
raise newException(CorruptedTrieDatabase, raise newException(CorruptedTrieDatabase,
"HexaryTrie node with an unexpected number of children") "HexaryTrie node with an unexpected number of children")
proc get*(self: HexaryTrie; key: BytesRange): BytesRange = proc get*(self: HexaryTrie; key: openArray[byte]): seq[byte] =
return getAuxByHash(self.db, self.root, initNibbleRange(key)) return getAuxByHash(self.db, self.root, initNibbleRange(key))
proc getKeysAux(db: DB, stack: var seq[tuple[nodeRlp: Rlp, path: NibblesRange]]): BytesRange = proc getKeysAux(db: DB, stack: var seq[tuple[nodeRlp: Rlp, path: NibblesSeq]]): seq[byte] =
while stack.len > 0: while stack.len > 0:
let (nodeRlp, path) = stack.pop() let (nodeRlp, path) = stack.pop()
if not nodeRlp.hasData or nodeRlp.isEmpty: if not nodeRlp.hasData or nodeRlp.isEmpty:
@ -172,15 +164,14 @@ proc getKeysAux(db: DB, stack: var seq[tuple[nodeRlp: Rlp, path: NibblesRange]])
raise newException(CorruptedTrieDatabase, raise newException(CorruptedTrieDatabase,
"HexaryTrie node with an unexpected number of children") "HexaryTrie node with an unexpected number of children")
iterator keys*(self: HexaryTrie): BytesRange = iterator keys*(self: HexaryTrie): seq[byte] =
var var
nodeRlp = rlpFromBytes keyToLocalBytes(self.db, self.root) nodeRlp = rlpFromBytes keyToLocalBytes(self.db, self.root)
path = newRange[byte](0) stack = @[(nodeRlp, initNibbleRange([]))]
stack = @[(nodeRlp, initNibbleRange(path))]
while stack.len > 0: while stack.len > 0:
yield getKeysAux(self.db, stack) yield getKeysAux(self.db, stack)
proc getValuesAux(db: DB, stack: var seq[Rlp]): BytesRange = proc getValuesAux(db: DB, stack: var seq[Rlp]): seq[byte] =
while stack.len > 0: while stack.len > 0:
let nodeRlp = stack.pop() let nodeRlp = stack.pop()
if not nodeRlp.hasData or nodeRlp.isEmpty: if not nodeRlp.hasData or nodeRlp.isEmpty:
@ -211,14 +202,14 @@ proc getValuesAux(db: DB, stack: var seq[Rlp]): BytesRange =
raise newException(CorruptedTrieDatabase, raise newException(CorruptedTrieDatabase,
"HexaryTrie node with an unexpected number of children") "HexaryTrie node with an unexpected number of children")
iterator values*(self: HexaryTrie): BytesRange = iterator values*(self: HexaryTrie): seq[byte] =
var var
nodeRlp = rlpFromBytes keyToLocalBytes(self.db, self.root) nodeRlp = rlpFromBytes keyToLocalBytes(self.db, self.root)
stack = @[nodeRlp] stack = @[nodeRlp]
while stack.len > 0: while stack.len > 0:
yield getValuesAux(self.db, stack) yield getValuesAux(self.db, stack)
proc getPairsAux(db: DB, stack: var seq[tuple[nodeRlp: Rlp, path: NibblesRange]]): (BytesRange, BytesRange) = proc getPairsAux(db: DB, stack: var seq[tuple[nodeRlp: Rlp, path: NibblesSeq]]): (seq[byte], seq[byte]) =
while stack.len > 0: while stack.len > 0:
let (nodeRlp, path) = stack.pop() let (nodeRlp, path) = stack.pop()
if not nodeRlp.hasData or nodeRlp.isEmpty: if not nodeRlp.hasData or nodeRlp.isEmpty:
@ -254,11 +245,10 @@ proc getPairsAux(db: DB, stack: var seq[tuple[nodeRlp: Rlp, path: NibblesRange]]
raise newException(CorruptedTrieDatabase, raise newException(CorruptedTrieDatabase,
"HexaryTrie node with an unexpected number of children") "HexaryTrie node with an unexpected number of children")
iterator pairs*(self: HexaryTrie): (BytesRange, BytesRange) = iterator pairs*(self: HexaryTrie): (seq[byte], seq[byte]) =
var var
nodeRlp = rlpFromBytes keyToLocalBytes(self.db, self.root) nodeRlp = rlpFromBytes keyToLocalBytes(self.db, self.root)
path = newRange[byte](0) stack = @[(nodeRlp, initNibbleRange([]))]
stack = @[(nodeRlp, initNibbleRange(path))]
while stack.len > 0: while stack.len > 0:
# perhaps a Nim bug #9778 # perhaps a Nim bug #9778
# cannot yield the helper proc directly # cannot yield the helper proc directly
@ -266,7 +256,7 @@ iterator pairs*(self: HexaryTrie): (BytesRange, BytesRange) =
let res = getPairsAux(self.db, stack) let res = getPairsAux(self.db, stack)
yield res yield res
iterator replicate*(self: HexaryTrie): (BytesRange, BytesRange) = iterator replicate*(self: HexaryTrie): (seq[byte], seq[byte]) =
# this iterator helps 'rebuild' the entire trie without # this iterator helps 'rebuild' the entire trie without
# going through a trie algorithm, but it will pull the entire # going through a trie algorithm, but it will pull the entire
# low level KV pairs. Thus the target db will only use put operations # low level KV pairs. Thus the target db will only use put operations
@ -274,19 +264,18 @@ iterator replicate*(self: HexaryTrie): (BytesRange, BytesRange) =
var var
localBytes = keyToLocalBytes(self.db, self.root) localBytes = keyToLocalBytes(self.db, self.root)
nodeRlp = rlpFromBytes localBytes nodeRlp = rlpFromBytes localBytes
path = newRange[byte](0) stack = @[(nodeRlp, initNibbleRange([]))]
stack = @[(nodeRlp, initNibbleRange(path))]
template pushOrYield(elem: untyped) = template pushOrYield(elem: untyped) =
if elem.isList: if elem.isList:
stack.add((elem, key)) stack.add((elem, key))
else: else:
let rlpBytes = get(self.db, toOpenArray(elem.expectHash)).toRange let rlpBytes = get(self.db, elem.expectHash)
let nextLookup = rlpFromBytes(rlpBytes) let nextLookup = rlpFromBytes(rlpBytes)
stack.add((nextLookup, key)) stack.add((nextLookup, key))
yield (elem.toBytes, rlpBytes) yield (elem.toBytes, rlpBytes)
yield (self.rootHash.toRange, localBytes) yield (@(self.rootHash.data), localBytes)
while stack.len > 0: while stack.len > 0:
let (nodeRlp, path) = stack.pop() let (nodeRlp, path) = stack.pop()
if not nodeRlp.hasData or nodeRlp.isEmpty: if not nodeRlp.hasData or nodeRlp.isEmpty:
@ -310,21 +299,21 @@ iterator replicate*(self: HexaryTrie): (BytesRange, BytesRange) =
raise newException(CorruptedTrieDatabase, raise newException(CorruptedTrieDatabase,
"HexaryTrie node with an unexpected number of children") "HexaryTrie node with an unexpected number of children")
proc getValues*(self: HexaryTrie): seq[BytesRange] = proc getValues*(self: HexaryTrie): seq[seq[byte]] =
result = @[] result = @[]
for v in self.values: for v in self.values:
result.add v result.add v
proc getKeys*(self: HexaryTrie): seq[BytesRange] = proc getKeys*(self: HexaryTrie): seq[seq[byte]] =
result = @[] result = @[]
for k in self.keys: for k in self.keys:
result.add k result.add k
template getNode(elem: untyped): untyped = template getNode(elem: untyped): untyped =
if elem.isList: elem.rawData if elem.isList: @(elem.rawData)
else: get(db, toOpenArray(elem.expectHash)).toRange else: get(db, elem.expectHash)
proc getBranchAux(db: DB, node: BytesRange, path: NibblesRange, output: var seq[BytesRange]) = proc getBranchAux(db: DB, node: openArray[byte], path: NibblesSeq, output: var seq[seq[byte]]) =
var nodeRlp = rlpFromBytes node var nodeRlp = rlpFromBytes node
if not nodeRlp.hasData or nodeRlp.isEmpty: return if not nodeRlp.hasData or nodeRlp.isEmpty: return
@ -349,21 +338,21 @@ proc getBranchAux(db: DB, node: BytesRange, path: NibblesRange, output: var seq[
raise newException(CorruptedTrieDatabase, raise newException(CorruptedTrieDatabase,
"HexaryTrie node with an unexpected number of children") "HexaryTrie node with an unexpected number of children")
proc getBranch*(self: HexaryTrie; key: BytesRange): seq[BytesRange] = proc getBranch*(self: HexaryTrie; key: openArray[byte]): seq[seq[byte]] =
result = @[] result = @[]
var node = keyToLocalBytes(self.db, self.root) var node = keyToLocalBytes(self.db, self.root)
result.add node result.add node
getBranchAux(self.db, node, initNibbleRange(key), result) getBranchAux(self.db, node, initNibbleRange(key), result)
proc dbDel(t: var HexaryTrie, data: BytesRange) = proc dbDel(t: var HexaryTrie, data: openArray[byte]) =
if data.len >= 32: t.prune(data.keccak.data) if data.len >= 32: t.prune(data.keccak.data)
proc dbPut(db: DB, data: BytesRange): TrieNodeKey = proc dbPut(db: DB, data: openArray[byte]): TrieNodeKey =
result.hash = data.keccak result.hash = data.keccak
result.usedBytes = 32 result.usedBytes = 32
put(db, result.asDbKey, data.toOpenArray) put(db, result.asDbKey, data)
proc appendAndSave(rlpWriter: var RlpWriter, data: BytesRange, db: DB) = proc appendAndSave(rlpWriter: var RlpWriter, data: openArray[byte], db: DB) =
if data.len >= 32: if data.len >= 32:
var nodeKey = dbPut(db, data) var nodeKey = dbPut(db, data)
rlpWriter.append(nodeKey.hash) rlpWriter.append(nodeKey.hash)
@ -373,7 +362,7 @@ proc appendAndSave(rlpWriter: var RlpWriter, data: BytesRange, db: DB) =
proc isTrieBranch(rlp: Rlp): bool = proc isTrieBranch(rlp: Rlp): bool =
rlp.isList and (var len = rlp.listLen; len == 2 or len == 17) rlp.isList and (var len = rlp.listLen; len == 2 or len == 17)
proc replaceValue(data: Rlp, key: NibblesRange, value: BytesRange): Bytes = proc replaceValue(data: Rlp, key: NibblesSeq, value: openArray[byte]): seq[byte] =
if data.isEmpty: if data.isEmpty:
let prefix = hexPrefixEncode(key, true) let prefix = hexPrefixEncode(key, true)
return encodeList(prefix, value) return encodeList(prefix, value)
@ -403,11 +392,6 @@ proc isTwoItemNode(self: HexaryTrie; r: Rlp): bool =
else: else:
return r.isList and r.listLen == 2 return r.isList and r.listLen == 2
proc isLeaf(r: Rlp): bool =
doAssert r.isList and r.listLen == 2
let b = r.listElem(0).toBytes()
return (b[0] and 0x20) != 0
proc findSingleChild(r: Rlp; childPos: var byte): Rlp = proc findSingleChild(r: Rlp; childPos: var byte): Rlp =
result = zeroBytesRlp result = zeroBytesRlp
var i: byte = 0 var i: byte = 0
@ -421,10 +405,10 @@ proc findSingleChild(r: Rlp; childPos: var byte): Rlp =
return zeroBytesRlp return zeroBytesRlp
inc i inc i
proc deleteAt(self: var HexaryTrie; origRlp: Rlp, key: NibblesRange): BytesRange {.gcsafe.} proc deleteAt(self: var HexaryTrie; origRlp: Rlp, key: NibblesSeq): seq[byte] {.gcsafe.}
proc deleteAux(self: var HexaryTrie; rlpWriter: var RlpWriter; proc deleteAux(self: var HexaryTrie; rlpWriter: var RlpWriter;
origRlp: Rlp; path: NibblesRange): bool = origRlp: Rlp; path: NibblesSeq): bool =
if origRlp.isEmpty: if origRlp.isEmpty:
return false return false
@ -439,16 +423,15 @@ proc deleteAux(self: var HexaryTrie; rlpWriter: var RlpWriter;
rlpWriter.appendAndSave(b, self.db) rlpWriter.appendAndSave(b, self.db)
return true return true
proc graft(self: var HexaryTrie; r: Rlp): Bytes = proc graft(self: var HexaryTrie; r: Rlp): seq[byte] =
doAssert r.isList and r.listLen == 2 doAssert r.isList and r.listLen == 2
var (origIsLeaf, origPath) = r.extensionNodeKey var (_, origPath) = r.extensionNodeKey
var value = r.listElem(1) var value = r.listElem(1)
var n: Rlp
if not value.isList: if not value.isList:
let nodeKey = value.expectHash let nodeKey = value.expectHash
var resolvedData = self.db.get(nodeKey.toOpenArray).toRange var resolvedData = self.db.get(nodeKey)
self.prune(nodeKey.toOpenArray) self.prune(nodeKey)
value = rlpFromBytes resolvedData value = rlpFromBytes resolvedData
doAssert value.listLen == 2 doAssert value.listLen == 2
@ -460,10 +443,10 @@ proc graft(self: var HexaryTrie; r: Rlp): Bytes =
return rlpWriter.finish return rlpWriter.finish
proc mergeAndGraft(self: var HexaryTrie; proc mergeAndGraft(self: var HexaryTrie;
soleChild: Rlp, childPos: byte): Bytes = soleChild: Rlp, childPos: byte): seq[byte] =
var output = initRlpList(2) var output = initRlpList(2)
if childPos == 16: if childPos == 16:
output.append hexPrefixEncode(zeroNibblesRange, true) output.append hexPrefixEncode(NibblesSeq(), true)
else: else:
doAssert(not soleChild.isEmpty) doAssert(not soleChild.isEmpty)
output.append int(hexPrefixEncodeByte(childPos)) output.append int(hexPrefixEncodeByte(childPos))
@ -471,20 +454,20 @@ proc mergeAndGraft(self: var HexaryTrie;
result = output.finish() result = output.finish()
if self.isTwoItemNode(soleChild): if self.isTwoItemNode(soleChild):
result = self.graft(rlpFromBytes(result.toRange)) result = self.graft(rlpFromBytes(result))
proc deleteAt(self: var HexaryTrie; proc deleteAt(self: var HexaryTrie;
origRlp: Rlp, key: NibblesRange): BytesRange = origRlp: Rlp, key: NibblesSeq): seq[byte] =
if origRlp.isEmpty: if origRlp.isEmpty:
return zeroBytesRange return
doAssert origRlp.isTrieBranch doAssert origRlp.isTrieBranch
let origBytes = origRlp.rawData let origBytes = @(origRlp.rawData)
if origRlp.listLen == 2: if origRlp.listLen == 2:
let (isLeaf, k) = origRlp.extensionNodeKey let (isLeaf, k) = origRlp.extensionNodeKey
if k == key and isLeaf: if k == key and isLeaf:
self.dbDel origBytes self.dbDel origBytes
return emptyRlp.toRange return emptyRlp
if key.startsWith(k): if key.startsWith(k):
var var
@ -493,22 +476,22 @@ proc deleteAt(self: var HexaryTrie;
value = origRlp.listElem(1) value = origRlp.listElem(1)
rlpWriter.append(path) rlpWriter.append(path)
if not self.deleteAux(rlpWriter, value, key.slice(k.len)): if not self.deleteAux(rlpWriter, value, key.slice(k.len)):
return zeroBytesRange return
self.dbDel origBytes self.dbDel origBytes
var finalBytes = rlpWriter.finish.toRange var finalBytes = rlpWriter.finish
var rlp = rlpFromBytes(finalBytes) var rlp = rlpFromBytes(finalBytes)
if self.isTwoItemNode(rlp.listElem(1)): if self.isTwoItemNode(rlp.listElem(1)):
return self.graft(rlp).toRange return self.graft(rlp)
return finalBytes return finalBytes
else: else:
return zeroBytesRange return
else: else:
if key.len == 0 and origRlp.listElem(16).isEmpty: if key.len == 0 and origRlp.listElem(16).isEmpty:
self.dbDel origBytes self.dbDel origBytes
var foundChildPos: byte var foundChildPos: byte
let singleChild = origRlp.findSingleChild(foundChildPos) let singleChild = origRlp.findSingleChild(foundChildPos)
if singleChild.hasData and foundChildPos != 16: if singleChild.hasData and foundChildPos != 16:
result = self.mergeAndGraft(singleChild, foundChildPos).toRange result = self.mergeAndGraft(singleChild, foundChildPos)
else: else:
var rlpRes = initRlpList(17) var rlpRes = initRlpList(17)
var iter = origRlp var iter = origRlp
@ -518,7 +501,7 @@ proc deleteAt(self: var HexaryTrie;
rlpRes.append iter rlpRes.append iter
iter.skipElem iter.skipElem
rlpRes.append "" rlpRes.append ""
return rlpRes.finish.toRange return rlpRes.finish
else: else:
var rlpWriter = initRlpList(17) var rlpWriter = initRlpList(17)
let keyHead = int(key[0]) let keyHead = int(key[0])
@ -527,20 +510,20 @@ proc deleteAt(self: var HexaryTrie;
for elem in items(origCopy): for elem in items(origCopy):
if i == keyHead: if i == keyHead:
if not self.deleteAux(rlpWriter, elem, key.slice(1)): if not self.deleteAux(rlpWriter, elem, key.slice(1)):
return zeroBytesRange return
else: else:
rlpWriter.append(elem) rlpWriter.append(elem)
inc i inc i
self.dbDel origBytes self.dbDel origBytes
result = rlpWriter.finish.toRange result = rlpWriter.finish
var resultRlp = rlpFromBytes(result) var resultRlp = rlpFromBytes(result)
var foundChildPos: byte var foundChildPos: byte
let singleChild = resultRlp.findSingleChild(foundChildPos) let singleChild = resultRlp.findSingleChild(foundChildPos)
if singleChild.hasData: if singleChild.hasData:
result = self.mergeAndGraft(singleChild, foundChildPos).toRange result = self.mergeAndGraft(singleChild, foundChildPos)
proc del*(self: var HexaryTrie; key: BytesRange) = proc del*(self: var HexaryTrie; key: openArray[byte]) =
var var
rootBytes = keyToLocalBytes(self.db, self.root) rootBytes = keyToLocalBytes(self.db, self.root)
rootRlp = rlpFromBytes rootBytes rootRlp = rlpFromBytes rootBytes
@ -552,16 +535,16 @@ proc del*(self: var HexaryTrie; key: BytesRange) =
self.root = self.db.dbPut(newRootBytes) self.root = self.db.dbPut(newRootBytes)
proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash, proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash,
key: NibblesRange, value: BytesRange, key: NibblesSeq, value: openArray[byte],
isInline = false): BytesRange {.gcsafe.} isInline = false): seq[byte] {.gcsafe.}
proc mergeAt(self: var HexaryTrie, rlp: Rlp, proc mergeAt(self: var HexaryTrie, rlp: Rlp,
key: NibblesRange, value: BytesRange, key: NibblesSeq, value: openArray[byte],
isInline = false): BytesRange = isInline = false): seq[byte] =
self.mergeAt(rlp, rlp.rawData.keccak, key, value, isInline) self.mergeAt(rlp, rlp.rawData.keccak, key, value, isInline)
proc mergeAtAux(self: var HexaryTrie, output: var RlpWriter, orig: Rlp, proc mergeAtAux(self: var HexaryTrie, output: var RlpWriter, orig: Rlp,
key: NibblesRange, value: BytesRange) = key: NibblesSeq, value: openArray[byte]) =
var resolved = orig var resolved = orig
var isRemovable = false var isRemovable = false
if not (orig.isList or orig.isEmpty): if not (orig.isList or orig.isEmpty):
@ -572,11 +555,11 @@ proc mergeAtAux(self: var HexaryTrie, output: var RlpWriter, orig: Rlp,
output.appendAndSave(b, self.db) output.appendAndSave(b, self.db)
proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash, proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash,
key: NibblesRange, value: BytesRange, key: NibblesSeq, value: openArray[byte],
isInline = false): BytesRange = isInline = false): seq[byte] =
template origWithNewValue: auto = template origWithNewValue: auto =
self.prune(origHash.data) self.prune(origHash.data)
replaceValue(orig, key, value).toRange replaceValue(orig, key, value)
if orig.isEmpty: if orig.isEmpty:
return origWithNewValue() return origWithNewValue()
@ -595,7 +578,7 @@ proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash,
var r = initRlpList(2) var r = initRlpList(2)
r.append orig.listElem(0) r.append orig.listElem(0)
self.mergeAtAux(r, origValue, key.slice(k.len), value) self.mergeAtAux(r, origValue, key.slice(k.len), value)
return r.finish.toRange return r.finish
if orig.rawData.len >= 32: if orig.rawData.len >= 32:
self.prune(origHash.data) self.prune(origHash.data)
@ -608,9 +591,9 @@ proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash,
var top = initRlpList(2) var top = initRlpList(2)
top.append hexPrefixEncode(k.slice(0, sharedNibbles), false) top.append hexPrefixEncode(k.slice(0, sharedNibbles), false)
top.appendAndSave(bottom.finish.toRange, self.db) top.appendAndSave(bottom.finish, self.db)
return self.mergeAt(rlpFromBytes(top.finish.toRange), key, value, true) return self.mergeAt(rlpFromBytes(top.finish), key, value, true)
else: else:
# Create a branch node # Create a branch node
var branches = initRlpList(17) var branches = initRlpList(17)
@ -626,7 +609,7 @@ proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash,
if byte(i) == n: if byte(i) == n:
if isLeaf or k.len > 1: if isLeaf or k.len > 1:
let childNode = encodeList(hexPrefixEncode(k.slice(1), isLeaf), let childNode = encodeList(hexPrefixEncode(k.slice(1), isLeaf),
origValue).toRange origValue)
branches.appendAndSave(childNode, self.db) branches.appendAndSave(childNode, self.db)
else: else:
branches.append origValue branches.append origValue
@ -634,7 +617,7 @@ proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash,
branches.append "" branches.append ""
branches.append "" branches.append ""
return self.mergeAt(rlpFromBytes(branches.finish.toRange), key, value, true) return self.mergeAt(rlpFromBytes(branches.finish), key, value, true)
else: else:
if key.len == 0: if key.len == 0:
return origWithNewValue() return origWithNewValue()
@ -654,12 +637,12 @@ proc mergeAt(self: var HexaryTrie, orig: Rlp, origHash: KeccakHash,
r.append(elem) r.append(elem)
inc i inc i
return r.finish.toRange return r.finish
proc put*(self: var HexaryTrie; key, value: BytesRange) = proc put*(self: var HexaryTrie; key, value: openArray[byte]) =
let root = self.root.hash let root = self.root.hash
var rootBytes = self.db.get(root.data).toRange var rootBytes = self.db.get(root.data)
doAssert rootBytes.len > 0 doAssert rootBytes.len > 0
let newRootBytes = self.mergeAt(rlpFromBytes(rootBytes), root, let newRootBytes = self.mergeAt(rlpFromBytes(rootBytes), root,
@ -669,23 +652,19 @@ proc put*(self: var HexaryTrie; key, value: BytesRange) =
self.root = self.db.dbPut(newRootBytes) self.root = self.db.dbPut(newRootBytes)
proc put*(self: var SecureHexaryTrie; key, value: BytesRange) = proc put*(self: var SecureHexaryTrie; key, value: openArray[byte]) =
let keyHash = @(key.keccak.data) put(HexaryTrie(self), key.keccak.data, value)
put(HexaryTrie(self), keyHash.toRange, value)
proc get*(self: SecureHexaryTrie; key: BytesRange): BytesRange = proc get*(self: SecureHexaryTrie; key: openArray[byte]): seq[byte] =
let keyHash = @(key.keccak.data) return get(HexaryTrie(self), key.keccak.data)
return get(HexaryTrie(self), keyHash.toRange)
proc del*(self: var SecureHexaryTrie; key: BytesRange) = proc del*(self: var SecureHexaryTrie; key: openArray[byte]) =
let keyHash = @(key.keccak.data) del(HexaryTrie(self), key.keccak.data)
del(HexaryTrie(self), keyHash.toRange)
proc rootHash*(self: SecureHexaryTrie): KeccakHash {.borrow.} proc rootHash*(self: SecureHexaryTrie): KeccakHash {.borrow.}
proc rootHashHex*(self: SecureHexaryTrie): string {.borrow.} proc rootHashHex*(self: SecureHexaryTrie): string {.borrow.}
proc isPruning*(self: SecureHexaryTrie): bool {.borrow.} proc isPruning*(self: SecureHexaryTrie): bool {.borrow.}
template contains*(self: HexaryTrie | SecureHexaryTrie; template contains*(self: HexaryTrie | SecureHexaryTrie;
key: BytesRange): bool = key: openArray[byte]): bool =
self.get(key).len > 0 self.get(key).len > 0

View File

@ -1,33 +1,26 @@
import
trie_defs
type type
NibblesRange* = object NibblesSeq* = object
bytes: ByteRange bytes: seq[byte]
ibegin, iend: int ibegin, iend: int
proc initNibbleRange*(bytes: ByteRange): NibblesRange = proc initNibbleRange*(bytes: openArray[byte]): NibblesSeq =
result.bytes = bytes result.bytes = @bytes
result.ibegin = 0 result.ibegin = 0
result.iend = bytes.len * 2 result.iend = bytes.len * 2
# can't be a const: https://github.com/status-im/nim-eth/issues/6 proc `{}`(r: NibblesSeq, pos: int): byte {.inline.} =
# we can't initialise it here, but since it's already zeroed memory, we don't need to
var zeroNibblesRange* {.threadvar.}: NibblesRange
proc `{}`(r: NibblesRange, pos: int): byte {.inline.} =
## This is a helper for a more raw access to the nibbles. ## This is a helper for a more raw access to the nibbles.
## It works with absolute positions. ## It works with absolute positions.
if pos > r.iend: raise newException(RangeError, "index out of range") if pos > r.iend: raise newException(RangeError, "index out of range")
return if (pos and 1) != 0: (r.bytes[pos div 2] and 0xf) return if (pos and 1) != 0: (r.bytes[pos div 2] and 0xf)
else: (r.bytes[pos div 2] shr 4) else: (r.bytes[pos div 2] shr 4)
template `[]`*(r: NibblesRange, i: int): byte = r{r.ibegin + i} template `[]`*(r: NibblesSeq, i: int): byte = r{r.ibegin + i}
proc len*(r: NibblesRange): int = proc len*(r: NibblesSeq): int =
r.iend - r.ibegin r.iend - r.ibegin
proc `==`*(lhs, rhs: NibblesRange): bool = proc `==`*(lhs, rhs: NibblesSeq): bool =
if lhs.len == rhs.len: if lhs.len == rhs.len:
for i in 0 ..< lhs.len: for i in 0 ..< lhs.len:
if lhs[i] != rhs[i]: if lhs[i] != rhs[i]:
@ -36,7 +29,7 @@ proc `==`*(lhs, rhs: NibblesRange): bool =
else: else:
return false return false
proc `$`*(r: NibblesRange): string = proc `$`*(r: NibblesSeq): string =
result = newStringOfCap(100) result = newStringOfCap(100)
for i in r.ibegin ..< r.iend: for i in r.ibegin ..< r.iend:
let n = int r{i} let n = int r{i}
@ -44,7 +37,7 @@ proc `$`*(r: NibblesRange): string =
else: char(ord('0') + n) else: char(ord('0') + n)
result.add c result.add c
proc slice*(r: NibblesRange, ibegin: int, iend = -1): NibblesRange = proc slice*(r: NibblesSeq, ibegin: int, iend = -1): NibblesSeq =
result.bytes = r.bytes result.bytes = r.bytes
result.ibegin = r.ibegin + ibegin result.ibegin = r.ibegin + ibegin
let e = if iend < 0: r.iend + iend + 1 let e = if iend < 0: r.iend + iend + 1
@ -69,11 +62,11 @@ template writeNibbles(r) {.dirty.} =
result[writeHead] = nextNibble shl 4 result[writeHead] = nextNibble shl 4
oddnessFlag = not oddnessFlag oddnessFlag = not oddnessFlag
proc hexPrefixEncode*(r: NibblesRange, isLeaf = false): Bytes = proc hexPrefixEncode*(r: NibblesSeq, isLeaf = false): seq[byte] =
writeFirstByte(r.len) writeFirstByte(r.len)
writeNibbles(r) writeNibbles(r)
proc hexPrefixEncode*(r1, r2: NibblesRange, isLeaf = false): Bytes = proc hexPrefixEncode*(r1, r2: NibblesSeq, isLeaf = false): seq[byte] =
writeFirstByte(r1.len + r2.len) writeFirstByte(r1.len + r2.len)
writeNibbles(r1) writeNibbles(r1)
writeNibbles(r2) writeNibbles(r2)
@ -82,16 +75,16 @@ proc hexPrefixEncodeByte*(val: byte, isLeaf = false): byte =
doAssert val < 16 doAssert val < 16
result = (((byte(isLeaf) * 2) + 1) shl 4) or val result = (((byte(isLeaf) * 2) + 1) shl 4) or val
proc sharedPrefixLen*(lhs, rhs: NibblesRange): int = proc sharedPrefixLen*(lhs, rhs: NibblesSeq): int =
result = 0 result = 0
while result < lhs.len and result < rhs.len: while result < lhs.len and result < rhs.len:
if lhs[result] != rhs[result]: break if lhs[result] != rhs[result]: break
inc result inc result
proc startsWith*(lhs, rhs: NibblesRange): bool = proc startsWith*(lhs, rhs: NibblesSeq): bool =
sharedPrefixLen(lhs, rhs) == rhs.len sharedPrefixLen(lhs, rhs) == rhs.len
proc hexPrefixDecode*(r: ByteRange): tuple[isLeaf: bool, nibbles: NibblesRange] = proc hexPrefixDecode*(r: openArray[byte]): tuple[isLeaf: bool, nibbles: NibblesSeq] =
result.nibbles = initNibbleRange(r) result.nibbles = initNibbleRange(r)
if r.len > 0: if r.len > 0:
result.isLeaf = (r[0] and 0x20) != 0 result.isLeaf = (r[0] and 0x20) != 0
@ -115,7 +108,7 @@ template putNibbles(bytes, src: untyped) =
template calcNeededBytes(len: int): int = template calcNeededBytes(len: int): int =
(len shr 1) + (len and 1) (len shr 1) + (len and 1)
proc `&`*(a, b: NibblesRange): NibblesRange = proc `&`*(a, b: NibblesSeq): NibblesSeq =
let let
len = a.len + b.len len = a.len + b.len
bytesNeeded = calcNeededBytes(len) bytesNeeded = calcNeededBytes(len)
@ -128,10 +121,10 @@ proc `&`*(a, b: NibblesRange): NibblesRange =
bytes.putNibbles(a) bytes.putNibbles(a)
bytes.putNibbles(b) bytes.putNibbles(b)
result = initNibbleRange(bytes.toRange) result = initNibbleRange(bytes)
result.iend = len result.iend = len
proc cloneAndReserveNibble*(a: NibblesRange): NibblesRange = proc cloneAndReserveNibble*(a: NibblesSeq): NibblesSeq =
let let
len = a.len + 1 len = a.len + 1
bytesNeeded = calcNeededBytes(len) bytesNeeded = calcNeededBytes(len)
@ -142,24 +135,15 @@ proc cloneAndReserveNibble*(a: NibblesRange): NibblesRange =
pos = 0 pos = 0
bytes.putNibbles(a) bytes.putNibbles(a)
result = initNibbleRange(bytes.toRange) result = initNibbleRange(bytes)
result.iend = len result.iend = len
proc replaceLastNibble*(a: var NibblesRange, b: byte) = proc replaceLastNibble*(a: var NibblesSeq, b: byte) =
var var
odd = (a.len and 1) == 0 odd = (a.len and 1) == 0
pos = (a.len shr 1) - odd.int pos = (a.len shr 1) - odd.int
putNibble(MutRange[byte](a.bytes), b) putNibble(a.bytes, b)
proc getBytes*(a: NibblesRange): ByteRange = proc getBytes*(a: NibblesSeq): seq[byte] =
a.bytes a.bytes
when false:
proc keyOf(r: ByteRange): NibblesRange =
let firstIdx = if r.len == 0: 0
elif (r[0] and 0x10) != 0: 1
else: 2
return initNibbleRange(s).slice(firstIdx)

View File

@ -1,9 +1,9 @@
import import
stew/ranges/[typedranges, bitranges], eth/rlp/types, ./trie_bitseq,
trie_defs, trie_utils, db, sparse_proofs ./trie_defs, ./trie_utils, ./db, ./sparse_proofs
export export
types, trie_utils, bitranges, trie_utils, trie_bitseq,
sparse_proofs.verifyProof sparse_proofs.verifyProof
type type
@ -11,11 +11,7 @@ type
SparseBinaryTrie* = object SparseBinaryTrie* = object
db: DB db: DB
rootHash: ByteRange rootHash: seq[byte]
proc `==`(a: ByteRange, b: KeccakHash): bool =
if a.len != b.data.len: return false
equalMem(a.baseAddr, b.data[0].unsafeAddr, a.len)
type type
# 256 * 2 div 8 # 256 * 2 div 8
@ -24,83 +20,83 @@ type
proc initDoubleHash(a, b: openArray[byte]): DoubleHash = proc initDoubleHash(a, b: openArray[byte]): DoubleHash =
doAssert(a.len == 32, $a.len) doAssert(a.len == 32, $a.len)
doAssert(b.len == 32, $b.len) doAssert(b.len == 32, $b.len)
copyMem(result[ 0].addr, a[0].unsafeAddr, 32) result[0..31] = a
copyMem(result[32].addr, b[0].unsafeAddr, 32) result[32..^1] = b
proc initDoubleHash(x: ByteRange): DoubleHash = proc initDoubleHash(x: openArray[byte]): DoubleHash =
initDoubleHash(x.toOpenArray, x.toOpenArray) initDoubleHash(x, x)
proc init*(x: typedesc[SparseBinaryTrie], db: DB): SparseBinaryTrie = proc init*(x: typedesc[SparseBinaryTrie], db: DB): SparseBinaryTrie =
result.db = db result.db = db
# Initialize an empty tree with one branch # Initialize an empty tree with one branch
var value = initDoubleHash(emptyNodeHashes[0]) var value = initDoubleHash(emptyNodeHashes[0].data)
result.rootHash = keccakHash(value) result.rootHash = @(keccakHash(value).data)
result.db.put(result.rootHash.toOpenArray, value) result.db.put(result.rootHash, value)
for i in 0..<treeHeight - 1: for i in 0..<treeHeight - 1:
value = initDoubleHash(emptyNodeHashes[i+1]) value = initDoubleHash(emptyNodeHashes[i+1].data)
result.db.put(emptyNodeHashes[i].toOpenArray, value) result.db.put(emptyNodeHashes[i].data, value)
result.db.put(emptyLeafNodeHash.data, zeroBytesRange.toOpenArray) result.db.put(emptyLeafNodeHash.data, [])
proc initSparseBinaryTrie*(db: DB): SparseBinaryTrie = proc initSparseBinaryTrie*(db: DB): SparseBinaryTrie =
init(SparseBinaryTrie, db) init(SparseBinaryTrie, db)
proc init*(x: typedesc[SparseBinaryTrie], db: DB, proc init*(x: typedesc[SparseBinaryTrie], db: DB,
rootHash: BytesContainer | KeccakHash): SparseBinaryTrie = rootHash: openArray[byte]): SparseBinaryTrie =
checkValidHashZ(rootHash) checkValidHashZ(rootHash)
result.db = db result.db = db
result.rootHash = rootHash result.rootHash = @rootHash
proc initSparseBinaryTrie*(db: DB, rootHash: BytesContainer | KeccakHash): SparseBinaryTrie = proc initSparseBinaryTrie*(db: DB, rootHash: openArray[byte]): SparseBinaryTrie =
init(SparseBinaryTrie, db, rootHash) init(SparseBinaryTrie, db, rootHash)
proc getDB*(t: SparseBinaryTrie): auto = t.db proc getDB*(t: SparseBinaryTrie): auto = t.db
proc getRootHash*(self: SparseBinaryTrie): ByteRange {.inline.} = proc getRootHash*(self: SparseBinaryTrie): seq[byte] {.inline.} =
self.rootHash self.rootHash
proc getAux(self: SparseBinaryTrie, path: BitRange, rootHash: ByteRange): ByteRange = proc getAux(self: SparseBinaryTrie, path: TrieBitSeq, rootHash: openArray[byte]): seq[byte] =
var nodeHash = rootHash var nodeHash = @rootHash
for targetBit in path: for targetBit in path:
let value = self.db.get(nodeHash.toOpenArray).toRange let value = self.db.get(nodeHash)
if value.len == 0: return zeroBytesRange if value.len == 0: return
if targetBit: nodeHash = value[32..^1] if targetBit: nodeHash = value[32..^1]
else: nodeHash = value[0..31] else: nodeHash = value[0..31]
if nodeHash.toOpenArray == emptyLeafNodeHash.data: if nodeHash == emptyLeafNodeHash.data:
result = zeroBytesRange result = @[]
else: else:
result = self.db.get(nodeHash.toOpenArray).toRange result = self.db.get(nodeHash)
proc get*(self: SparseBinaryTrie, key: BytesContainer): ByteRange = proc get*(self: SparseBinaryTrie, key: openArray[byte]): seq[byte] =
## gets a key from the tree. ## gets a key from the tree.
doAssert(key.len == pathByteLen) doAssert(key.len == pathByteLen)
let path = MutByteRange(key.toRange).bits let path = bits key
self.getAux(path, self.rootHash) self.getAux(path, self.rootHash)
proc get*(self: SparseBinaryTrie, key, rootHash: distinct BytesContainer): ByteRange = proc get*(self: SparseBinaryTrie, key, rootHash: openArray[byte]): seq[byte] =
## gets a key from the tree at a specific root. ## gets a key from the tree at a specific root.
doAssert(key.len == pathByteLen) doAssert(key.len == pathByteLen)
let path = MutByteRange(key.toRange).bits let path = bits key
self.getAux(path, rootHash.toRange) self.getAux(path, rootHash)
proc hashAndSave*(self: SparseBinaryTrie, node: ByteRange): ByteRange = proc hashAndSave*(self: SparseBinaryTrie, node: openArray[byte]): seq[byte] =
result = keccakHash(node) result = @(keccakHash(node).data)
self.db.put(result.toOpenArray, node.toOpenArray) self.db.put(result, node)
proc hashAndSave*(self: SparseBinaryTrie, a, b: ByteRange): ByteRange = proc hashAndSave*(self: SparseBinaryTrie, a, b: openArray[byte]): seq[byte] =
let value = initDoubleHash(a.toOpenArray, b.toOpenArray) let value = initDoubleHash(a, b)
result = keccakHash(value) result = @(keccakHash(value).data)
self.db.put(result.toOpenArray, value) self.db.put(result, value)
proc setAux(self: var SparseBinaryTrie, value: ByteRange, proc setAux(self: var SparseBinaryTrie, value: openArray[byte],
path: BitRange, depth: int, nodeHash: ByteRange): ByteRange = path: TrieBitSeq, depth: int, nodeHash: openArray[byte]): seq[byte] =
if depth == treeHeight: if depth == treeHeight:
result = self.hashAndSave(value) result = self.hashAndSave(value)
else: else:
let let
node = self.db.get(nodeHash.toOpenArray).toRange node = self.db.get(nodeHash)
leftNode = node[0..31] leftNode = node[0..31]
rightNode = node[32..^1] rightNode = node[32..^1]
if path[depth]: if path[depth]:
@ -108,75 +104,75 @@ proc setAux(self: var SparseBinaryTrie, value: ByteRange,
else: else:
result = self.hashAndSave(self.setAux(value, path, depth+1, leftNode), rightNode) result = self.hashAndSave(self.setAux(value, path, depth+1, leftNode), rightNode)
proc set*(self: var SparseBinaryTrie, key, value: distinct BytesContainer) = proc set*(self: var SparseBinaryTrie, key, value: openArray[byte]) =
## sets a new value for a key in the tree, returns the new root, ## sets a new value for a key in the tree, returns the new root,
## and sets the new current root of the tree. ## and sets the new current root of the tree.
doAssert(key.len == pathByteLen) doAssert(key.len == pathByteLen)
let path = MutByteRange(key.toRange).bits let path = bits key
self.rootHash = self.setAux(value.toRange, path, 0, self.rootHash) self.rootHash = self.setAux(value, path, 0, self.rootHash)
proc set*(self: var SparseBinaryTrie, key, value, rootHash: distinct BytesContainer): ByteRange = proc set*(self: var SparseBinaryTrie, key, value, rootHash: openArray[byte]): seq[byte] =
## sets a new value for a key in the tree at a specific root, ## sets a new value for a key in the tree at a specific root,
## and returns the new root. ## and returns the new root.
doAssert(key.len == pathByteLen) doAssert(key.len == pathByteLen)
let path = MutByteRange(key.toRange).bits let path = bits key
self.setAux(value.toRange, path, 0, rootHash.toRange) self.setAux(value, path, 0, rootHash)
template exists*(self: SparseBinaryTrie, key: BytesContainer): bool = template exists*(self: SparseBinaryTrie, key: openArray[byte]): bool =
self.get(toRange(key)) != zeroBytesRange self.get(key) != []
proc del*(self: var SparseBinaryTrie, key: BytesContainer) = proc del*(self: var SparseBinaryTrie, key: openArray[byte]) =
## Equals to setting the value to zeroBytesRange ## Equals to setting the value to zeroBytesRange
doAssert(key.len == pathByteLen) doAssert(key.len == pathByteLen)
self.set(key, zeroBytesRange) self.set(key, [])
# Dictionary API # Dictionary API
template `[]`*(self: SparseBinaryTrie, key: BytesContainer): ByteRange = template `[]`*(self: SparseBinaryTrie, key: openArray[byte]): seq[byte] =
self.get(key) self.get(key)
template `[]=`*(self: var SparseBinaryTrie, key, value: distinct BytesContainer) = template `[]=`*(self: var SparseBinaryTrie, key, value: openArray[byte]) =
self.set(key, value) self.set(key, value)
template contains*(self: SparseBinaryTrie, key: BytesContainer): bool = template contains*(self: SparseBinaryTrie, key: openArray[byte]): bool =
self.exists(key) self.exists(key)
proc proveAux(self: SparseBinaryTrie, key, rootHash: ByteRange, output: var seq[ByteRange]): bool = proc proveAux(self: SparseBinaryTrie, key, rootHash: openArray[byte], output: var seq[seq[byte]]): bool =
doAssert(key.len == pathByteLen) doAssert(key.len == pathByteLen)
var currVal = self.db.get(rootHash.toOpenArray).toRange var currVal = self.db.get(rootHash)
if currVal.len == 0: return false if currVal.len == 0: return false
let path = MutByteRange(key).bits let path = bits key
for i, bit in path: for i, bit in path:
if bit: if bit:
# right side # right side
output[i] = currVal[0..31] output[i] = currVal[0..31]
currVal = self.db.get(currVal[32..^1].toOpenArray).toRange currVal = self.db.get(currVal[32..^1])
if currVal.len == 0: return false if currVal.len == 0: return false
else: else:
output[i] = currVal[32..^1] output[i] = currVal[32..^1]
currVal = self.db.get(currVal[0..31].toOpenArray).toRange currVal = self.db.get(currVal[0..31])
if currVal.len == 0: return false if currVal.len == 0: return false
result = true result = true
# prove generates a Merkle proof for a key. # prove generates a Merkle proof for a key.
proc prove*(self: SparseBinaryTrie, key: BytesContainer): seq[ByteRange] = proc prove*(self: SparseBinaryTrie, key: openArray[byte]): seq[seq[byte]] =
result = newSeq[ByteRange](treeHeight) result = newSeq[seq[byte]](treeHeight)
if not self.proveAux(key.toRange, self.rootHash, result): if not self.proveAux(key, self.rootHash, result):
result = @[] result = @[]
# prove generates a Merkle proof for a key, at a specific root. # prove generates a Merkle proof for a key, at a specific root.
proc prove*(self: SparseBinaryTrie, key, rootHash: distinct BytesContainer): seq[ByteRange] = proc prove*(self: SparseBinaryTrie, key, rootHash: openArray[byte]): seq[seq[byte]] =
result = newSeq[ByteRange](treeHeight) result = newSeq[seq[byte]](treeHeight)
if not self.proveAux(key.toRange, rootHash.toRange, result): if not self.proveAux(key, rootHash, result):
result = @[] result = @[]
# proveCompact generates a compacted Merkle proof for a key. # proveCompact generates a compacted Merkle proof for a key.
proc proveCompact*(self: SparseBinaryTrie, key: BytesContainer): seq[ByteRange] = proc proveCompact*(self: SparseBinaryTrie, key: openArray[byte]): seq[seq[byte]] =
var temp = self.prove(key) var temp = self.prove(key)
temp.compactProof temp.compactProof
# proveCompact generates a compacted Merkle proof for a key, at a specific root. # proveCompact generates a compacted Merkle proof for a key, at a specific root.
proc proveCompact*(self: SparseBinaryTrie, key, rootHash: distinct BytesContainer): seq[ByteRange] = proc proveCompact*(self: SparseBinaryTrie, key, rootHash: openArray[byte]): seq[seq[byte]] =
var temp = self.prove(key, rootHash) var temp = self.prove(key, rootHash)
temp.compactProof temp.compactProof

View File

@ -1,26 +1,25 @@
import import
stew/ranges/[typedranges, bitranges], ./trie_bitseq, ./trie_defs, /trie_utils
trie_defs, trie_utils
const const
treeHeight* = 160 treeHeight* = 160
pathByteLen* = treeHeight div 8 pathByteLen* = treeHeight div 8
emptyLeafNodeHash* = blankStringHash emptyLeafNodeHash* = blankStringHash
proc makeInitialEmptyTreeHash(H: static[int]): array[H, ByteRange] = proc makeInitialEmptyTreeHash(H: static[int]): array[H, KeccakHash] =
result[^1] = @(emptyLeafNodeHash.data).toRange result[^1] = emptyLeafNodeHash
for i in countdown(H-1, 1): for i in countdown(H-1, 1):
result[i - 1] = keccakHash(result[i], result[i]) result[i - 1] = keccakHash(result[i].data, result[i].data)
# cannot yet turn this into compile time constant # cannot yet turn this into compile time constant
let emptyNodeHashes* = makeInitialEmptyTreeHash(treeHeight) let emptyNodeHashes* = makeInitialEmptyTreeHash(treeHeight)
# VerifyProof verifies a Merkle proof. # VerifyProof verifies a Merkle proof.
proc verifyProofAux*(proof: seq[ByteRange], root, key, value: ByteRange): bool = proc verifyProofAux*(proof: seq[seq[byte]], root, key, value: openArray[byte]): bool =
doAssert(root.len == 32) doAssert(root.len == 32)
doAssert(key.len == pathByteLen) doAssert(key.len == pathByteLen)
var var
path = MutByteRange(key).bits path = bits key
curHash = keccakHash(value) curHash = keccakHash(value)
if proof.len != treeHeight: return false if proof.len != treeHeight: return false
@ -30,57 +29,58 @@ proc verifyProofAux*(proof: seq[ByteRange], root, key, value: ByteRange): bool =
if node.len != 32: return false if node.len != 32: return false
if path[i]: # right if path[i]: # right
# reuse curHash without more alloc # reuse curHash without more alloc
curHash.keccakHash(node, curHash) curHash.data.keccakHash(node, curHash.data)
else: else:
curHash.keccakHash(curHash, node) curHash.data.keccakHash(curHash.data, node)
result = curHash == root result = curHash.data == root
template verifyProof*(proof: seq[ByteRange], root, key, value: distinct BytesContainer): bool = template verifyProof*(proof: seq[seq[byte]], root, key, value: openArray[byte]): bool =
verifyProofAux(proof, root.toRange, key.toRange, value.toRange) verifyProofAux(proof, root, key, value)
proc count(b: BitRange, val: bool): int = proc count(b: TrieBitSeq, val: bool): int =
for c in b: for c in b:
if c == val: inc result if c == val: inc result
# CompactProof compacts a proof, to reduce its size. # CompactProof compacts a proof, to reduce its size.
proc compactProof*(proof: seq[ByteRange]): seq[ByteRange] = proc compactProof*(proof: seq[seq[byte]]): seq[seq[byte]] =
if proof.len != treeHeight: return if proof.len != treeHeight: return
var var
data = newRange[byte](pathByteLen) data = newSeq[byte](pathByteLen)
bits = MutByteRange(data).bits bits = bits data
result = @[] result = @[]
result.add data result.add @[]
for i in 0 ..< treeHeight: for i in 0 ..< treeHeight:
var node = proof[i] var node = proof[i]
if node == emptyNodeHashes[i]: if node == emptyNodeHashes[i].data:
bits[i] = true bits[i] = true
else: else:
result.add node result.add node
result[0] = bits.toBytes
# decompactProof decompacts a proof, so that it can be used for VerifyProof. # decompactProof decompacts a proof, so that it can be used for VerifyProof.
proc decompactProof*(proof: seq[ByteRange]): seq[ByteRange] = proc decompactProof*(proof: seq[seq[byte]]): seq[seq[byte]] =
if proof.len == 0: return if proof.len == 0: return
if proof[0].len != pathByteLen: return if proof[0].len != pathByteLen: return
var bits = MutByteRange(proof[0]).bits let bits = bits proof[0]
if proof.len != bits.count(false) + 1: return if proof.len != bits.count(false) + 1: return
result = newSeq[ByteRange](treeHeight) result = newSeq[seq[byte]](treeHeight)
var pos = 1 # skip bits var pos = 1 # skip bits
for i in 0 ..< treeHeight: for i in 0 ..< treeHeight:
if bits[i]: if bits[i]:
result[i] = emptyNodeHashes[i] result[i] = @(emptyNodeHashes[i].data)
else: else:
result[i] = proof[pos] result[i] = proof[pos]
inc pos inc pos
# verifyCompactProof verifies a compacted Merkle proof. # verifyCompactProof verifies a compacted Merkle proof.
proc verifyCompactProofAux*(proof: seq[ByteRange], root, key, value: ByteRange): bool = proc verifyCompactProofAux*(proof: seq[seq[byte]], root, key, value: openArray[byte]): bool =
var decompactedProof = decompactProof(proof) var decompactedProof = decompactProof(proof)
if decompactedProof.len == 0: return false if decompactedProof.len == 0: return false
verifyProofAux(decompactedProof, root, key, value) verifyProofAux(decompactedProof, root, key, value)
template verifyCompactProof*(proof: seq[ByteRange], root, key, value: distinct BytesContainer): bool = template verifyCompactProof*(proof: seq[seq[byte]], root, key, value: openArray[byte]): bool =
verifyCompactProofAux(proof, root.toRange, key.toRange, value.toRange) verifyCompactProofAux(proof, root, key, value)

129
eth/trie/trie_bitseq.nim Normal file
View File

@ -0,0 +1,129 @@
import
stew/bitops2
type
TrieBitSeq* = object
## Bit sequence as used in ethereum tries
data: seq[byte]
start: int
mLen: int ## Length in bits
template `@`(s, idx: untyped): untyped =
(when idx is BackwardsIndex: s.len - int(idx) else: int(idx))
proc bits*(a: seq[byte], start, len: int): TrieBitSeq =
doAssert start <= len
doAssert len <= 8 * a.len
TrieBitSeq(data: a, start: start, mLen: len)
template bits*(a: seq[byte]): TrieBitSeq =
bits(a, 0, a.len * 8)
template bits*(a: seq[byte], len: int): TrieBitSeq =
bits(a, 0, len)
template bits*(a: openArray[byte], start, len: int): TrieBitSeq =
bits(@a, start, len)
template bits*(a: openArray[byte]): TrieBitSeq =
bits(@a, 0, a.len * 8)
template bits*(a: openArray[byte], len: int): TrieBitSeq =
bits(@a, 0, len)
template bits*(x: TrieBitSeq): TrieBitSeq = x
proc len*(r: TrieBitSeq): int = r.mLen
iterator enumerateBits(x: TrieBitSeq): (int, bool) =
var p = x.start
var i = 0
let e = x.len
while i != e:
yield (i, getBitBE(x.data, p))
inc p
inc i
iterator items*(x: TrieBitSeq): bool =
for _, v in enumerateBits(x): yield v
iterator pairs*(x: TrieBitSeq): (int, bool) =
for i, v in enumerateBits(x): yield (i, v)
proc `[]`*(x: TrieBitSeq, idx: int): bool =
doAssert idx < x.len
let p = x.start + idx
result = getBitBE(x.data, p)
proc sliceNormalized(x: TrieBitSeq, ibegin, iend: int): TrieBitSeq =
doAssert ibegin >= 0 and
ibegin < x.len and
iend < x.len and
iend + 1 >= ibegin # the +1 here allows the result to be
# an empty range
result.data = x.data
result.start = x.start + ibegin
result.mLen = iend - ibegin + 1
proc `[]`*(r: TrieBitSeq, s: HSlice): TrieBitSeq =
sliceNormalized(r, r @ s.a, r @ s.b)
proc `==`*(a, b: TrieBitSeq): bool =
if a.len != b.len: return false
for i in 0 ..< a.len:
if a[i] != b[i]: return false
true
proc `[]=`*(r: var TrieBitSeq, idx: Natural, val: bool) =
doAssert idx < r.len
let absIdx = r.start + idx
changeBitBE(r.data, absIdx, val)
proc pushFront*(x: var TrieBitSeq, val: bool) =
doAssert x.start > 0
dec x.start
x[0] = val
inc x.mLen
template neededBytes(nBits: int): int =
(nBits shr 3) + ord((nBits and 0b111) != 0)
static:
doAssert neededBytes(2) == 1
doAssert neededBytes(8) == 1
doAssert neededBytes(9) == 2
proc `&`*(a, b: TrieBitSeq): TrieBitSeq =
let totalLen = a.len + b.len
var bytes = newSeq[byte](totalLen.neededBytes)
result = bits(bytes, 0, totalLen)
for i in 0 ..< a.len: result.data.changeBitBE(i, a[i])
for i in 0 ..< b.len: result.data.changeBitBE(i + a.len, b[i])
proc `$`*(r: TrieBitSeq): string =
result = newStringOfCap(r.len)
for bit in r:
result.add(if bit: '1' else: '0')
proc fromBits*(T: type, r: TrieBitSeq, offset, num: Natural): T =
doAssert(num <= sizeof(T) * 8)
# XXX: Nim has a bug that a typedesc parameter cannot be used
# in a type coercion, so we must define an alias here:
type TT = T
for i in 0 ..< num:
result = (result shl 1) or TT(r[offset + i])
proc parse*(T: type TrieBitSeq, s: string): TrieBitSeq =
var bytes = newSeq[byte](s.len.neededBytes)
for i, c in s:
case c
of '0': discard
of '1': setBitBE(bytes, i)
else: doAssert false
result = bits(bytes, 0, s.len)
proc toBytes*(r: TrieBitSeq): seq[byte] =
r.data[(r.start div 8)..<((r.mLen - r.start + 7) div 8)]

View File

@ -1,12 +1,8 @@
import import
eth/rlp, stew/ranges/typedranges, nimcrypto/hash eth/rlp, nimcrypto/hash
export
typedranges, Bytes
type type
KeccakHash* = MDigest[256] KeccakHash* = MDigest[256]
BytesContainer* = ByteRange | Bytes | string
TrieError* = object of CatchableError TrieError* = object of CatchableError
# A common base type of all Trie errors. # A common base type of all Trie errors.
@ -20,10 +16,6 @@ type
# operate if its database has been tampered with. A swift crash # operate if its database has been tampered with. A swift crash
# will be a more appropriate response. # will be a more appropriate response.
# can't be a const: https://github.com/status-im/nim-eth/issues/6
# we can't initialise it here, but since it's already zeroed memory, we don't need to
var zeroBytesRange* {.threadvar.}: ByteRange
const const
blankStringHash* = "c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470".toDigest blankStringHash* = "c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470".toDigest
emptyRlp* = @[128.byte] emptyRlp* = @[128.byte]
@ -34,10 +26,3 @@ proc read*(rlp: var Rlp, T: typedesc[MDigest]): T {.inline.} =
proc append*(rlpWriter: var RlpWriter, a: MDigest) {.inline.} = proc append*(rlpWriter: var RlpWriter, a: MDigest) {.inline.} =
rlpWriter.append(a.data) rlpWriter.append(a.data)
proc unnecessary_OpenArrayToRange*(key: openarray[byte]): ByteRange =
## XXX: The name of this proc is intentionally long, because it
## performs a memory allocation and data copying that may be eliminated
## in the future. Avoid renaming it to something similar as `toRange`, so
## it can remain searchable in the code.
toRange(@key)

View File

@ -1,42 +1,22 @@
import import
stew/byteutils, stew/byteutils,
stew/ranges/[typedranges, ptr_arith], nimcrypto/[hash, keccak], nimcrypto/[hash, keccak],
trie_defs, binaries trie_defs
proc toTrieNodeKey*(hash: KeccakHash): TrieNodeKey =
result = newRange[byte](32)
copyMem(result.baseAddr, hash.data.baseAddr, 32)
template checkValidHashZ*(x: untyped) = template checkValidHashZ*(x: untyped) =
when x.type isnot KeccakHash: when x.type isnot KeccakHash:
doAssert(x.len == 32 or x.len == 0) doAssert(x.len == 32 or x.len == 0)
template isZeroHash*(x: ByteRange): bool = template isZeroHash*(x: openArray[byte]): bool =
x.len == 0 x.len == 0
template toRange*(hash: KeccakHash): ByteRange =
toTrieNodeKey(hash)
proc toRange*(str: string): ByteRange =
var s = newSeq[byte](str.len)
if str.len > 0:
copyMem(s[0].addr, str[0].unsafeAddr, str.len)
result = toRange(s)
proc hashFromHex*(bits: static[int], input: string): MDigest[bits] = proc hashFromHex*(bits: static[int], input: string): MDigest[bits] =
MDigest(data: hexToByteArray[bits div 8](input)) MDigest(data: hexToByteArray[bits div 8](input))
template hashFromHex*(s: static[string]): untyped = hashFromHex(s.len * 4, s) template hashFromHex*(s: static[string]): untyped = hashFromHex(s.len * 4, s)
proc keccakHash*(input: openArray[byte]): ByteRange = proc keccakHash*(input: openArray[byte]): KeccakHash =
var s = newSeq[byte](32) keccak256.digest(input)
var ctx: keccak256
ctx.init()
if input.len > 0:
ctx.update(input[0].unsafeAddr, uint(input.len))
ctx.finish s
ctx.clear()
result = toRange(s)
proc keccakHash*(dest: var openArray[byte], a, b: openArray[byte]) = proc keccakHash*(dest: var openArray[byte], a, b: openArray[byte]) =
var ctx: keccak256 var ctx: keccak256
@ -48,16 +28,7 @@ proc keccakHash*(dest: var openArray[byte], a, b: openArray[byte]) =
ctx.finish dest ctx.finish dest
ctx.clear() ctx.clear()
proc keccakHash*(a, b: openArray[byte]): ByteRange = proc keccakHash*(a, b: openArray[byte]): KeccakHash =
var s = newSeq[byte](32) var s: array[32, byte]
keccakHash(s, a, b) keccakHash(s, a, b)
result = toRange(s) KeccakHash(data: s)
template keccakHash*(input: ByteRange): ByteRange =
keccakHash(input.toOpenArray)
template keccakHash*(a, b: ByteRange): ByteRange =
keccakHash(a.toOpenArray, b.toOpenArray)
template keccakHash*(dest: var ByteRange, a, b: ByteRange) =
keccakHash(dest.toOpenArray, a.toOpenArray, b.toOpenArray)

View File

@ -24,8 +24,8 @@ proc generate() =
# valid data for a Ping packet # valid data for a Ping packet
block: block:
let payload = rlp.encode((4, fromAddr, toAddr, expiration())).toRange let payload = rlp.encode((4, fromAddr, toAddr, expiration()))
let encodedData = @[1.byte] & payload.toSeq() let encodedData = @[1.byte] & payload
debug "Ping", data=byteutils.toHex(encodedData) debug "Ping", data=byteutils.toHex(encodedData)
encodedData.toFile(inputsDir & "ping") encodedData.toFile(inputsDir & "ping")
@ -33,8 +33,8 @@ proc generate() =
# valid data for a Pong packet # valid data for a Pong packet
block: block:
let token = keccak256.digest(@[0]) let token = keccak256.digest(@[0])
let payload = rlp.encode((toAddr, token , expiration())).toRange let payload = rlp.encode((toAddr, token , expiration()))
let encodedData = @[2.byte] & payload.toSeq() let encodedData = @[2.byte] & payload
debug "Pong", data=byteutils.toHex(encodedData) debug "Pong", data=byteutils.toHex(encodedData)
encodedData.toFile(inputsDir & "pong") encodedData.toFile(inputsDir & "pong")
@ -43,7 +43,7 @@ proc generate() =
block: block:
var data: array[64, byte] var data: array[64, byte]
data[32 .. ^1] = peerKey.toPublicKey().tryGet().toNodeId().toByteArrayBE() data[32 .. ^1] = peerKey.toPublicKey().tryGet().toNodeId().toByteArrayBE()
let payload = rlp.encode((data, expiration())).toRange let payload = rlp.encode((data, expiration()))
let encodedData = @[3.byte] & payload.toSeq() let encodedData = @[3.byte] & payload.toSeq()
debug "FindNode", data=byteutils.toHex(encodedData) debug "FindNode", data=byteutils.toHex(encodedData)
@ -65,7 +65,7 @@ proc generate() =
nodes.add((n1Addr.ip, n1Addr.udpPort, n1Addr.tcpPort, n1Key.toPublicKey().tryGet())) nodes.add((n1Addr.ip, n1Addr.udpPort, n1Addr.tcpPort, n1Key.toPublicKey().tryGet()))
nodes.add((n2Addr.ip, n2Addr.udpPort, n2Addr.tcpPort, n2Key.toPublicKey().tryGet())) nodes.add((n2Addr.ip, n2Addr.udpPort, n2Addr.tcpPort, n2Key.toPublicKey().tryGet()))
let payload = rlp.encode((nodes, expiration())).toRange let payload = rlp.encode((nodes, expiration()))
let encodedData = @[4.byte] & payload.toSeq() let encodedData = @[4.byte] & payload.toSeq()
debug "Neighbours", data=byteutils.toHex(encodedData) debug "Neighbours", data=byteutils.toHex(encodedData)

View File

@ -19,7 +19,7 @@ These are the mandatory `test` block and the optional `init` block.
Example usage: Example usage:
```nim ```nim
test: test:
var rlp = rlpFromBytes(@payload.toRange) var rlp = rlpFromBytes(payload)
discard rlp.inspect() discard rlp.inspect()
``` ```
@ -30,7 +30,7 @@ E.g.:
```nim ```nim
test: test:
try: try:
var rlp = rlpFromBytes(@payload.toRange) var rlp = rlpFromBytes(payload)
discard rlp.inspect() discard rlp.inspect()
except RlpError as e: except RlpError as e:
debug "Inspect failed", err = e.msg debug "Inspect failed", err = e.msg

View File

@ -2,7 +2,7 @@ import chronicles, eth/rlp, ../fuzztest
test: test:
try: try:
var rlp = rlpFromBytes(@payload.toRange) var rlp = rlpFromBytes(payload)
discard rlp.inspect() discard rlp.inspect()
except RlpError as e: except RlpError as e:
debug "Inspect failed", err = e.msg debug "Inspect failed", err = e.msg

View File

@ -52,7 +52,7 @@ proc packData*(payload: openArray[byte], pk: PrivateKey): seq[byte] =
template sourceDir*: string = currentSourcePath.rsplit(DirSep, 1)[0] template sourceDir*: string = currentSourcePath.rsplit(DirSep, 1)[0]
proc recvMsgMock*(msg: openArray[byte]): tuple[msgId: int, msgData: Rlp] = proc recvMsgMock*(msg: openArray[byte]): tuple[msgId: int, msgData: Rlp] =
var rlp = rlpFromBytes(@msg.toRange) var rlp = rlpFromBytes(msg)
let msgId = rlp.read(int32) let msgId = rlp.read(int32)
return (msgId.int, rlp) return (msgId.int, rlp)

View File

@ -1,7 +1,7 @@
import import
net, unittest, options, net, unittest, options,
nimcrypto/utils, nimcrypto/utils,
eth/p2p/enode, eth/p2p/discoveryv5/enr, eth/keys, eth/rlp eth/p2p/enode, eth/p2p/discoveryv5/enr, eth/keys
suite "ENR": suite "ENR":
test "Serialization": test "Serialization":

View File

@ -49,7 +49,7 @@ procSuite "Waku Mail Client":
let decoded = decode(response.envelope.data, symKey = some(symKey)) let decoded = decode(response.envelope.data, symKey = some(symKey))
require decoded.isSome() require decoded.isSome()
var rlp = rlpFromBytes(decoded.get().payload.toRange) var rlp = rlpFromBytes(decoded.get().payload)
let output = rlp.read(MailRequest) let output = rlp.read(MailRequest)
check: check:
output.lower == lower output.lower == lower
@ -92,7 +92,7 @@ procSuite "Waku Mail Client":
var envelopes: seq[Envelope] var envelopes: seq[Envelope]
traceAsyncErrors peer.p2pMessage(envelopes) traceAsyncErrors peer.p2pMessage(envelopes)
var cursor: Bytes var cursor: seq[byte]
count = count - 1 count = count - 1
if count == 0: if count == 0:
cursor = @[] cursor = @[]

View File

@ -1,13 +1,15 @@
{.used.}
import import
math, unittest, strutils, math, unittest, strutils, stew/byteutils,
eth/rlp, util/json_testing eth/rlp
proc q(s: string): string = "\"" & s & "\"" proc q(s: string): string = "\"" & s & "\""
proc i(s: string): string = s.replace(" ").replace("\n") proc i(s: string): string = s.replace(" ").replace("\n")
proc inspectMatch(r: Rlp, s: string): bool = r.inspect.i == s.i proc inspectMatch(r: Rlp, s: string): bool = r.inspect.i == s.i
test "empty bytes are not a proper RLP": test "empty bytes are not a proper RLP":
var rlp = rlpFromBytes Bytes(@[]).toRange var rlp = rlpFromBytes seq[byte](@[])
check: check:
not rlp.hasData not rlp.hasData
@ -52,7 +54,7 @@ test "encode/decode object":
var writer = initRlpWriter() var writer = initRlpWriter()
writer.append(input) writer.append(input)
let bytes = writer.finish() let bytes = writer.finish()
var rlp = rlpFromBytes(bytes.toRange) var rlp = rlpFromBytes(bytes)
var output = rlp.read(MyObj) var output = rlp.read(MyObj)
check: check:
@ -66,10 +68,10 @@ test "encode and decode lists":
var var
bytes = writer.finish bytes = writer.finish
rlp = rlpFromBytes bytes.toRange rlp = rlpFromBytes bytes
check: check:
bytes.hexRepr == "d183666f6fc8836261728362617ac31e2832" bytes.toHex == "d183666f6fc8836261728362617ac31e2832"
rlp.inspectMatch """ rlp.inspectMatch """
{ {
"foo" "foo"
@ -89,7 +91,7 @@ test "encode and decode lists":
"Lorem ipsum dolor sit amet", "Lorem ipsum dolor sit amet",
"Donec ligula tortor, egestas eu est vitae") "Donec ligula tortor, egestas eu est vitae")
rlp = rlpFromBytes bytes.toRange rlp = rlpFromBytes bytes
check: check:
rlp.listLen == 3 rlp.listLen == 3
rlp.listElem(0).toInt(int) == 6000 rlp.listElem(0).toInt(int) == 6000
@ -97,7 +99,7 @@ test "encode and decode lists":
rlp.listElem(2).toString == "Donec ligula tortor, egestas eu est vitae" rlp.listElem(2).toString == "Donec ligula tortor, egestas eu est vitae"
# test creating RLPs from other RLPs # test creating RLPs from other RLPs
var list = rlpFromBytes encodeList(rlp.listELem(1), rlp.listELem(0)).toRange var list = rlpFromBytes encodeList(rlp.listELem(1), rlp.listELem(0))
# test that iteration with enterList/skipElem works as expected # test that iteration with enterList/skipElem works as expected
doAssert list.enterList # We already know that we are working with a list doAssert list.enterList # We already know that we are working with a list
@ -117,11 +119,11 @@ test "toBytes":
let tok = rlp.listElem(1).toBytes() let tok = rlp.listElem(1).toBytes()
check: check:
tok.len == 32 tok.len == 32
tok.hexRepr == "40ef02798f211da2e8173d37f255be908871ae65060dbb2f77fb29c0421447f4" tok.toHex == "40ef02798f211da2e8173d37f255be908871ae65060dbb2f77fb29c0421447f4"
test "nested lists": test "nested lists":
let listBytes = encode([[1, 2, 3], [5, 6, 7]]) let listBytes = encode([[1, 2, 3], [5, 6, 7]])
let listRlp = rlpFromBytes listBytes.toRange let listRlp = rlpFromBytes listBytes
let sublistRlp0 = listRlp.listElem(0) let sublistRlp0 = listRlp.listElem(0)
let sublistRlp1 = listRlp.listElem(1) let sublistRlp1 = listRlp.listElem(1)
check sublistRlp0.listElem(0).toInt(int) == 1 check sublistRlp0.listElem(0).toInt(int) == 1
@ -133,12 +135,12 @@ test "nested lists":
test "encoding length": test "encoding length":
let listBytes = encode([1,2,3,4,5]) let listBytes = encode([1,2,3,4,5])
let listRlp = rlpFromBytes listBytes.toRange let listRlp = rlpFromBytes listBytes
check listRlp.listLen == 5 check listRlp.listLen == 5
let emptyListBytes = encode "" let emptyListBytes = encode ""
check emptyListBytes.len == 1 check emptyListBytes.len == 1
let emptyListRlp = rlpFromBytes emptyListBytes.toRange let emptyListRlp = rlpFromBytes emptyListBytes
check emptyListRlp.blobLen == 0 check emptyListRlp.blobLen == 0
test "basic decoding": test "basic decoding":
@ -159,21 +161,21 @@ test "encode byte arrays":
var b2 = [byte(6), 8, 12, 123] var b2 = [byte(6), 8, 12, 123]
var b3 = @[byte(122), 56, 65, 12] var b3 = @[byte(122), 56, 65, 12]
let rlp = rlpFromBytes(encode((b1, b2, b3)).toRange) let rlp = rlpFromBytes(encode((b1, b2, b3)))
check: check:
rlp.listLen == 3 rlp.listLen == 3
rlp.listElem(0).toBytes().toSeq() == @b1 rlp.listElem(0).toBytes() == b1
rlp.listElem(1).toBytes().toSeq() == @b2 rlp.listElem(1).toBytes() == b2
rlp.listElem(2).toBytes().toSeq() == @b3 rlp.listElem(2).toBytes() == b3
# The first byte here is the length of the datum (132 - 128 => 4) # The first byte here is the length of the datum (132 - 128 => 4)
$(rlp.listElem(1).rawData) == "R[132, 6, 8, 12, 123]" $(rlp.listElem(1).rawData) == "[132, 6, 8, 12, 123]"
test "empty byte arrays": test "empty byte arrays":
var var
rlp = rlpFromBytes rlp.encode("").toRange rlp = rlpFromBytes rlp.encode("")
b = rlp.toBytes b = rlp.toBytes
check $b == "R[]" check $b == "@[]"
test "encode/decode floats": test "encode/decode floats":
for f in [high(float64), low(float64), 0.1, 122.23, for f in [high(float64), low(float64), 0.1, 122.23,
@ -202,7 +204,7 @@ test "invalid enum":
writer.append(2) writer.append(2)
writer.append(-1) writer.append(-1)
let bytes = writer.finish() let bytes = writer.finish()
var rlp = rlpFromBytes(bytes.toRange) var rlp = rlpFromBytes(bytes)
expect RlpTypeMismatch: expect RlpTypeMismatch:
discard rlp.read(MyEnum) discard rlp.read(MyEnum)
rlp.skipElem() rlp.skipElem()

View File

@ -1,7 +1,7 @@
{.used.} {.used.}
import import
unittest, times, eth/rlp, util/json_testing unittest, times, eth/rlp, stew/byteutils
type type
Transaction = object Transaction = object
@ -46,7 +46,7 @@ test "encoding and decoding an object":
f: Foo(x: 5'u64, y: "hocus pocus", z: @[100, 200, 300])) f: Foo(x: 5'u64, y: "hocus pocus", z: @[100, 200, 300]))
var bytes = encode(originalBar) var bytes = encode(originalBar)
var r = rlpFromBytes(bytes.toRange) var r = rlpFromBytes(bytes)
var restoredBar = r.read(Bar) var restoredBar = r.read(Bar)
check: check:
@ -57,7 +57,7 @@ test "encoding and decoding an object":
var t2 = bytes.decode(Transaction) var t2 = bytes.decode(Transaction)
check: check:
bytes.hexRepr == "cd85416c69636583426f628203e8" # verifies that Alice comes first bytes.toHex == "cd85416c69636583426f628203e8" # verifies that Alice comes first
t2.time == default(Time) t2.time == default(Time)
t2.sender == "Alice" t2.sender == "Alice"
t2.receiver == "Bob" t2.receiver == "Bob"
@ -66,7 +66,7 @@ test "encoding and decoding an object":
test "custom field serialization": test "custom field serialization":
var origVal = CustomSerialized(customFoo: Foo(x: 10'u64, y: "y", z: @[]), ignored: 5) var origVal = CustomSerialized(customFoo: Foo(x: 10'u64, y: "y", z: @[]), ignored: 5)
var bytes = encode(origVal) var bytes = encode(origVal)
var r = rlpFromBytes(bytes.toRange) var r = rlpFromBytes(bytes)
var restored = r.read(CustomSerialized) var restored = r.read(CustomSerialized)
check: check:
@ -79,4 +79,3 @@ test "RLP fields count":
Bar.rlpFieldsCount == 2 Bar.rlpFieldsCount == 2
Foo.rlpFieldsCount == 3 Foo.rlpFieldsCount == 3
Transaction.rlpFieldsCount == 3 Transaction.rlpFieldsCount == 3

View File

@ -1,5 +1,5 @@
import import
json, strutils, unittest, eth/rlp json, stew/byteutils, unittest, eth/rlp
proc append(output: var RlpWriter, js: JsonNode) = proc append(output: var RlpWriter, js: JsonNode) =
case js.kind case js.kind
@ -14,11 +14,6 @@ proc append(output: var RlpWriter, js: JsonNode) =
of JArray: of JArray:
output.append js.elems output.append js.elems
proc hexRepr*(bytes: BytesRange|Bytes): string =
result = newStringOfCap(bytes.len * 2)
for byte in bytes:
result.add(toHex(int(byte), 2).toLowerAscii)
proc `==`(lhs: JsonNode, rhs: string): bool = proc `==`(lhs: JsonNode, rhs: string): bool =
lhs.kind == JString and lhs.str == rhs lhs.kind == JString and lhs.str == rhs
@ -58,7 +53,7 @@ proc runTests*(filename: string) =
var outRlp = initRlpWriter() var outRlp = initRlpWriter()
outRlp.append input outRlp.append input
let let
actual = outRlp.finish.hexRepr actual = outRlp.finish.toHex
expected = output.str expected = output.str
check actual == expected check actual == expected

View File

@ -6,7 +6,7 @@ import
test_examples, test_examples,
test_hexary_trie, test_hexary_trie,
test_json_suite, test_json_suite,
test_nibbles,
test_sparse_binary_trie, test_sparse_binary_trie,
test_storage_backends, test_storage_backends,
test_transaction_db test_transaction_db,
test_trie_bitseq

View File

@ -2,8 +2,8 @@
import import
unittest, random, unittest, random,
eth/trie/[trie_defs, db, binary], eth/trie/[db, binary],
./testutils ./testutils, stew/byteutils
suite "binary trie": suite "binary trie":
@ -19,7 +19,7 @@ suite "binary trie":
for i, c in kv_pairs: for i, c in kv_pairs:
trie.set(c.key, c.value) trie.set(c.key, c.value)
let x = trie.get(c.key) let x = trie.get(c.key)
let y = toRange(c.value) let y = c.value
check y == x check y == x
check result == zeroHash or trie.getRootHash() == result check result == zeroHash or trie.getRootHash() == result
@ -54,27 +54,27 @@ suite "binary trie":
let will_raise_error = data[4] let will_raise_error = data[4]
# First test case, delete subtrie of a kv node # First test case, delete subtrie of a kv node
trie.set(kv1[0], kv1[1]) trie.set(kv1[0].toBytes, kv1[1].toBytes)
trie.set(kv2[0], kv2[1]) trie.set(kv2[0].toBytes, kv2[1].toBytes)
check trie.get(kv1[0]) == toRange(kv1[1]) check trie.get(kv1[0].toBytes) == kv1[1].toBytes
check trie.get(kv2[0]) == toRange(kv2[1]) check trie.get(kv2[0].toBytes) == kv2[1].toBytes
if will_delete: if will_delete:
trie.deleteSubtrie(key_to_be_deleted) trie.deleteSubtrie(key_to_be_deleted.toBytes)
check trie.get(kv1[0]) == zeroBytesRange check trie.get(kv1[0].toBytes) == []
check trie.get(kv2[0]) == zeroBytesRange check trie.get(kv2[0].toBytes) == []
check trie.getRootHash() == zeroHash check trie.getRootHash() == zeroHash
else: else:
if will_raise_error: if will_raise_error:
try: try:
trie.deleteSubtrie(key_to_be_deleted) trie.deleteSubtrie(key_to_be_deleted.toBytes)
except NodeOverrideError: except NodeOverrideError:
discard discard
else: else:
let root_hash_before_delete = trie.getRootHash() let root_hash_before_delete = trie.getRootHash()
trie.deleteSubtrie(key_to_be_deleted) trie.deleteSubtrie(key_to_be_deleted.toBytes)
check trie.get(kv1[0]) == toRange(kv1[1]) check trie.get(kv1[0].toBytes) == toBytes(kv1[1])
check trie.get(kv2[0]) == toRange(kv2[1]) check trie.get(kv2[0].toBytes) == toBytes(kv2[1])
check trie.getRootHash() == root_hash_before_delete check trie.getRootHash() == root_hash_before_delete
const invalidKeyData = [ const invalidKeyData = [
@ -90,22 +90,22 @@ suite "binary trie":
var db = newMemoryDB() var db = newMemoryDB()
var trie = initBinaryTrie(db) var trie = initBinaryTrie(db)
trie.set("\x12\x34\x56\x78", "78") trie.set("\x12\x34\x56\x78".toBytes, "78".toBytes)
trie.set("\x12\x34\x56\x79", "79") trie.set("\x12\x34\x56\x79".toBytes, "79".toBytes)
let invalidKey = data[0] let invalidKey = data[0]
let if_error = data[1] let if_error = data[1]
check trie.get(invalidKey) == zeroBytesRange check trie.get(invalidKey.toBytes) == []
if if_error: if if_error:
try: try:
trie.delete(invalidKey) trie.delete(invalidKey.toBytes)
except NodeOverrideError: except NodeOverrideError:
discard discard
else: else:
let previous_root_hash = trie.getRootHash() let previous_root_hash = trie.getRootHash()
trie.delete(invalidKey) trie.delete(invalidKey.toBytes)
check previous_root_hash == trie.getRootHash() check previous_root_hash == trie.getRootHash()
test "update value": test "update value":
@ -114,13 +114,13 @@ suite "binary trie":
var db = newMemoryDB() var db = newMemoryDB()
var trie = initBinaryTrie(db) var trie = initBinaryTrie(db)
for key in keys: for key in keys:
trie.set(key, "old") trie.set(key.toBytes, "old".toBytes)
var current_root = trie.getRootHash() var current_root = trie.getRootHash()
for i in vals: for i in vals:
trie.set(keys[i], "old") trie.set(keys[i].toBytes, "old".toBytes)
check current_root == trie.getRootHash() check current_root == trie.getRootHash()
trie.set(keys[i], "new") trie.set(keys[i].toBytes, "new".toBytes)
check current_root != trie.getRootHash() check current_root != trie.getRootHash()
check trie.get(keys[i]) == toRange("new") check trie.get(keys[i].toBytes) == toBytes("new")
current_root = trie.getRootHash() current_root = trie.getRootHash()

View File

@ -2,11 +2,11 @@
import import
unittest, strutils, unittest, strutils,
stew/ranges/bitranges, eth/rlp/types, nimcrypto/[keccak, hash], nimcrypto/[keccak, hash],
eth/trie/[binaries, trie_utils], eth/trie/[binaries, trie_bitseq],
./testutils ./testutils, stew/byteutils
proc parseBitVector(x: string): BitRange = proc parseBitVector(x: string): TrieBitSeq =
result = genBitVec(x.len) result = genBitVec(x.len)
for i, c in x: for i, c in x:
result[i] = (c == '1') result[i] = (c == '1')
@ -53,7 +53,7 @@ suite "binaries utils":
test "node parsing": test "node parsing":
for c in parseNodeData: for c in parseNodeData:
let input = toRange(c[0]) let input = toBytes(c[0])
let node = c[1] let node = c[1]
let kind = TrieNodeKind(node[0]) let kind = TrieNodeKind(node[0])
let raiseError = node[3] let raiseError = node[3]
@ -69,12 +69,12 @@ suite "binaries utils":
case res.kind case res.kind
of KV_TYPE: of KV_TYPE:
check(res.keyPath == parseBitVector(node[1])) check(res.keyPath == parseBitVector(node[1]))
check(res.child == toRange(node[2])) check(res.child == toBytes(node[2]))
of BRANCH_TYPE: of BRANCH_TYPE:
check(res.leftChild == toRange(node[2])) check(res.leftChild == toBytes(node[2]))
check(res.rightChild == toRange(node[2])) check(res.rightChild == toBytes(node[2]))
of LEAF_TYPE: of LEAF_TYPE:
check(res.value == toRange(node[2])) check(res.value == toBytes(node[2]))
const const
kvData = [ kvData = [
@ -88,7 +88,7 @@ suite "binaries utils":
test "kv node encoding": test "kv node encoding":
for c in kvData: for c in kvData:
let keyPath = parseBitVector(c[0]) let keyPath = parseBitVector(c[0])
let node = toRange(c[1]) let node = toBytes(c[1])
let output = toBytes(c[2]) let output = toBytes(c[2])
let raiseError = c[3] let raiseError = c[3]
@ -110,8 +110,8 @@ suite "binaries utils":
test "branch node encode": test "branch node encode":
for c in branchData: for c in branchData:
let left = toRange(c[0]) let left = toBytes(c[0])
let right = toRange(c[1]) let right = toBytes(c[1])
let output = toBytes(c[2]) let output = toBytes(c[2])
let raiseError = c[3] let raiseError = c[3]
@ -132,34 +132,34 @@ suite "binaries utils":
let raiseError = c[2] let raiseError = c[2]
if raiseError: if raiseError:
expect(ValidationError): expect(ValidationError):
check toBytes(c[1]) == encodeLeafNode(toRange(c[0])) check toBytes(c[1]) == encodeLeafNode(toBytes(c[0]))
else: else:
check toBytes(c[1]) == encodeLeafNode(toRange(c[0])) check toBytes(c[1]) == encodeLeafNode(toBytes(c[0]))
test "random kv encoding": test "random kv encoding":
let lengths = randList(int, randGen(1, 999), randGen(100, 100), unique = false) let lengths = randList(int, randGen(1, 999), randGen(100, 100), unique = false)
for len in lengths: for len in lengths:
var k = len var k = len
var bitvec = genBitVec(len) var bitvec = genBitVec(len)
var nodeHash = keccak256.digest(cast[ptr byte](k.addr), uint(sizeof(int))).toRange var nodeHash = keccak256.digest(cast[ptr byte](k.addr), uint(sizeof(int)))
var kvnode = encodeKVNode(bitvec, nodeHash).toRange var kvnode = encodeKVNode(bitvec, @(nodeHash.data))
# first byte if KV_TYPE # first byte if KV_TYPE
# in the middle are 1..n bits of binary-encoded-keypath # in the middle are 1..n bits of binary-encoded-keypath
# last 32 bytes are hash # last 32 bytes are hash
var keyPath = decodeToBinKeypath(kvnode[1..^33]) var keyPath = decodeToBinKeypath(kvnode[1..^33])
check kvnode[0].ord == KV_TYPE.ord check kvnode[0].ord == KV_TYPE.ord
check keyPath == bitvec check keyPath == bitvec
check kvnode[^32..^1] == nodeHash check kvnode[^32..^1] == nodeHash.data
test "optimized single bit keypath kvnode encoding": test "optimized single bit keypath kvnode encoding":
var k = 1 var k = 1
var nodeHash = keccak256.digest(cast[ptr byte](k.addr), uint(sizeof(int))).toRange var nodeHash = keccak256.digest(cast[ptr byte](k.addr), uint(sizeof(int)))
var bitvec = genBitVec(1) var bitvec = genBitVec(1)
bitvec[0] = false bitvec[0] = false
var kvnode = encodeKVNode(bitvec, nodeHash).toRange var kvnode = encodeKVNode(bitvec, @(nodeHash.data))
var kp = decodeToBinKeypath(kvnode[1..^33]) var kp = decodeToBinKeypath(kvnode[1..^33])
var okv = encodeKVNode(false, nodeHash).toRange var okv = encodeKVNode(false, @(nodeHash.data))
check okv == kvnode check okv == kvnode
var okp = decodeToBinKeypath(kvnode[1..^33]) var okp = decodeToBinKeypath(kvnode[1..^33])
check okp == kp check okp == kp
@ -167,10 +167,10 @@ suite "binaries utils":
check okp == bitvec check okp == bitvec
bitvec[0] = true bitvec[0] = true
kvnode = encodeKVNode(bitvec, nodeHash).toRange kvnode = encodeKVNode(bitvec, @(nodeHash.data))
kp = decodeToBinKeypath(kvnode[1..^33]) kp = decodeToBinKeypath(kvnode[1..^33])
okv = encodeKVNode(true, nodeHash).toRange okv = encodeKVNode(true, @(nodeHash.data))
check okv == kvnode check okv == kvnode
okp = decodeToBinKeypath(kvnode[1..^33]) okp = decodeToBinKeypath(kvnode[1..^33])
check okp == kp check okp == kp

View File

@ -1,7 +1,7 @@
{.used.} {.used.}
import import
sets, unittest, strutils, sets, sets, unittest, strutils, stew/byteutils,
eth/trie/[db, binary, branches] eth/trie/[db, binary, branches]
suite "branches utils": suite "branches utils":
@ -10,9 +10,9 @@ suite "branches utils":
var db = newMemoryDB() var db = newMemoryDB()
var trie = initBinaryTrie(db) var trie = initBinaryTrie(db)
trie.set("\x12\x34\x56\x78\x9a", "9a") trie.set("\x12\x34\x56\x78\x9a".toBytes, "9a".toBytes)
trie.set("\x12\x34\x56\x78\x9b", "9b") trie.set("\x12\x34\x56\x78\x9b".toBytes, "9b".toBytes)
trie.set("\x12\x34\x56\xff", "ff") trie.set("\x12\x34\x56\xff".toBytes, "ff".toBytes)
trie trie
const branchExistData = [ const branchExistData = [
@ -28,7 +28,7 @@ suite "branches utils":
var trie = testTrie() var trie = testTrie()
var db = trie.getDB() var db = trie.getDB()
for c in branchExistData: for c in branchExistData:
let keyPrefix = c[0].toRange let keyPrefix = c[0].toBytes
let if_exist = c[1] let if_exist = c[1]
check checkIfBranchExist(db, trie.getRootHash(), keyPrefix) == if_exist check checkIfBranchExist(db, trie.getRootHash(), keyPrefix) == if_exist
@ -45,7 +45,7 @@ suite "branches utils":
var trie = testTrie() var trie = testTrie()
var db = trie.getDB() var db = trie.getDB()
for c in branchData: for c in branchData:
let key = c[0].toRange let key = c[0].toBytes
let keyValid = c[1] let keyValid = c[1]
if keyValid: if keyValid:
@ -83,15 +83,15 @@ suite "branches utils":
(repeat('0', 32), @[]) (repeat('0', 32), @[])
] ]
proc toRanges(x: seq[string]): seq[BytesRange] = proc toRanges(x: seq[string]): seq[seq[byte]] =
result = newSeq[BytesRange](x.len) result = newSeq[seq[byte]](x.len)
for i, c in x: result[i] = toRange(c) for i, c in x: result[i] = toBytes(c)
test "get trie nodes": test "get trie nodes":
var trie = testTrie() var trie = testTrie()
var db = trie.getDB() var db = trie.getDB()
for c in trieNodesData: for c in trieNodesData:
let root = c[0].toRange() let root = c[0].toBytes()
let nodes = toRanges(c[1]) let nodes = toRanges(c[1])
check toHashSet(nodes) == toHashSet(getTrieNodes(db, root)) check toHashSet(nodes) == toHashSet(getTrieNodes(db, root))
@ -135,7 +135,7 @@ suite "branches utils":
var trie = testTrie() var trie = testTrie()
var db = trie.getDB() var db = trie.getDB()
for c in witnessData: for c in witnessData:
let key = c[0].toRange let key = c[0].toBytes
let nodes = toRanges(c[1]) let nodes = toRanges(c[1])
if nodes.len != 0: if nodes.len != 0:

View File

@ -1,9 +1,9 @@
{.used.} {.used.}
import import
unittest, unittest, stew/byteutils,
nimcrypto/[keccak, hash], nimcrypto/[keccak, hash],
eth/trie/[trie_defs, db, binary, binaries, trie_utils, branches] eth/trie/[db, binary, binaries, trie_utils, branches]
suite "examples": suite "examples":
@ -11,84 +11,84 @@ suite "examples":
var trie = initBinaryTrie(db) var trie = initBinaryTrie(db)
test "basic set/get": test "basic set/get":
trie.set("key1", "value1") trie.set("key1".toBytes(), "value1".toBytes())
trie.set("key2", "value2") trie.set("key2".toBytes(), "value2".toBytes())
check trie.get("key1") == "value1".toRange check trie.get("key1".toBytes) == "value1".toBytes
check trie.get("key2") == "value2".toRange check trie.get("key2".toBytes) == "value2".toBytes
test "check branch exists": test "check branch exists":
check checkIfBranchExist(db, trie.getRootHash(), "key") == true check checkIfBranchExist(db, trie.getRootHash(), "key".toBytes) == true
check checkIfBranchExist(db, trie.getRootHash(), "key1") == true check checkIfBranchExist(db, trie.getRootHash(), "key1".toBytes) == true
check checkIfBranchExist(db, trie.getRootHash(), "ken") == false check checkIfBranchExist(db, trie.getRootHash(), "ken".toBytes) == false
check checkIfBranchExist(db, trie.getRootHash(), "key123") == false check checkIfBranchExist(db, trie.getRootHash(), "key123".toBytes) == false
test "branches utils": test "branches utils":
var branchA = getBranch(db, trie.getRootHash(), "key1") var branchA = getBranch(db, trie.getRootHash(), "key1".toBytes)
# ==> [A, B, C1, D1] # ==> [A, B, C1, D1]
check branchA.len == 4 check branchA.len == 4
var branchB = getBranch(db, trie.getRootHash(), "key2") var branchB = getBranch(db, trie.getRootHash(), "key2".toBytes)
# ==> [A, B, C2, D2] # ==> [A, B, C2, D2]
check branchB.len == 4 check branchB.len == 4
check isValidBranch(branchA, trie.getRootHash(), "key1", "value1") == true check isValidBranch(branchA, trie.getRootHash(), "key1".toBytes, "value1".toBytes) == true
check isValidBranch(branchA, trie.getRootHash(), "key5", "") == true check isValidBranch(branchA, trie.getRootHash(), "key5".toBytes, "".toBytes) == true
expect InvalidNode: expect InvalidNode:
check isValidBranch(branchB, trie.getRootHash(), "key1", "value1") check isValidBranch(branchB, trie.getRootHash(), "key1".toBytes, "value1".toBytes)
var x = getBranch(db, trie.getRootHash(), "key") var x = getBranch(db, trie.getRootHash(), "key".toBytes)
# ==> [A] # ==> [A]
check x.len == 1 check x.len == 1
expect InvalidKeyError: expect InvalidKeyError:
x = getBranch(db, trie.getRootHash(), "key123") # InvalidKeyError x = getBranch(db, trie.getRootHash(), "key123".toBytes) # InvalidKeyError
x = getBranch(db, trie.getRootHash(), "key5") # there is still branch for non-exist key x = getBranch(db, trie.getRootHash(), "key5".toBytes) # there is still branch for non-exist key
# ==> [A] # ==> [A]
check x.len == 1 check x.len == 1
test "getWitness": test "getWitness":
var branch = getWitness(db, trie.getRootHash(), "key1") var branch = getWitness(db, trie.getRootHash(), "key1".toBytes)
# equivalent to `getBranch(db, trie.getRootHash(), "key1")` # equivalent to `getBranch(db, trie.getRootHash(), "key1")`
# ==> [A, B, C1, D1] # ==> [A, B, C1, D1]
check branch.len == 4 check branch.len == 4
branch = getWitness(db, trie.getRootHash(), "key") branch = getWitness(db, trie.getRootHash(), "key".toBytes)
# this will include additional nodes of "key2" # this will include additional nodes of "key2"
# ==> [A, B, C1, D1, C2, D2] # ==> [A, B, C1, D1, C2, D2]
check branch.len == 6 check branch.len == 6
branch = getWitness(db, trie.getRootHash(), "") branch = getWitness(db, trie.getRootHash(), "".toBytes)
# this will return the whole trie # this will return the whole trie
# ==> [A, B, C1, D1, C2, D2] # ==> [A, B, C1, D1, C2, D2]
check branch.len == 6 check branch.len == 6
let beforeDeleteLen = db.totalRecordsInMemoryDB let beforeDeleteLen = db.totalRecordsInMemoryDB
test "verify intermediate entries existence": test "verify intermediate entries existence":
var branchs = getWitness(db, trie.getRootHash, zeroBytesRange) var branchs = getWitness(db, trie.getRootHash, [])
# set operation create new intermediate entries # set operation create new intermediate entries
check branchs.len < beforeDeleteLen check branchs.len < beforeDeleteLen
var node = branchs[1] var node = branchs[1]
let nodeHash = keccak256.digest(node.baseAddr, uint(node.len)) let nodeHash = keccak256.digest(node)
var nodes = getTrieNodes(db, nodeHash) var nodes = getTrieNodes(db, @(nodeHash.data))
check nodes.len == branchs.len - 1 check nodes.len == branchs.len - 1
test "delete sub trie": test "delete sub trie":
# delete all subtrie with key prefixes "key" # delete all subtrie with key prefixes "key"
trie.deleteSubtrie("key") trie.deleteSubtrie("key".toBytes)
check trie.get("key1") == zeroBytesRange check trie.get("key1".toBytes) == []
check trie.get("key2") == zeroBytesRange check trie.get("key2".toBytes) == []
test "prove the lie": test "prove the lie":
# `delete` and `deleteSubtrie` not actually delete the nodes # `delete` and `deleteSubtrie` not actually delete the nodes
check db.totalRecordsInMemoryDB == beforeDeleteLen check db.totalRecordsInMemoryDB == beforeDeleteLen
var branchs = getWitness(db, trie.getRootHash, zeroBytesRange) var branchs = getWitness(db, trie.getRootHash, [])
check branchs.len == 0 check branchs.len == 0
test "dictionary syntax API": test "dictionary syntax API":
# dictionary syntax API # dictionary syntax API
trie["moon"] = "sun" trie["moon".toBytes] = "sun".toBytes
check "moon" in trie check "moon".toBytes in trie
check trie["moon"] == "sun".toRange check trie["moon".toBytes] == "sun".toBytes

View File

@ -1,26 +1,17 @@
{.used.} {.used.}
import import
unittest, sequtils, os, unittest, sequtils, os, stew/byteutils,
stew/ranges/typedranges, eth/trie/[hexary, db, trie_defs], nimcrypto/utils, eth/trie/[hexary, db, trie_defs], nimcrypto/utils,
./testutils, algorithm, eth/rlp/types as rlpTypes, random ./testutils, algorithm, random
from strutils import split from strutils import split
template put(t: HexaryTrie|SecureHexaryTrie, key, val: string) =
t.put(key.toBytesRange, val.toBytesRange)
template del(t: HexaryTrie|SecureHexaryTrie, key) =
t.del(key.toBytesRange)
template get(t: HexaryTrie|SecureHexaryTrie, key): auto =
t.get(key.toBytesRange)
suite "hexary trie": suite "hexary trie":
setup: setup:
var var
db = newMemoryDB() db = newMemoryDB()
tr = initHexaryTrie(db) tr {.used.} = initHexaryTrie(db)
test "ref-counted keys crash": test "ref-counted keys crash":
proc addKey(intKey: int) = proc addKey(intKey: int) =
@ -28,12 +19,12 @@ suite "hexary trie":
key[19] = byte(intKey) key[19] = byte(intKey)
var data = newSeqWith(29, 1.byte) var data = newSeqWith(29, 1.byte)
var k = key.toRange var k = key
let v = tr.get(k) let v = tr.get(k)
doAssert(v.len == 0) doAssert(v.len == 0)
tr.put(k, toRange(data)) tr.put(k, data)
addKey(166) addKey(166)
addKey(193) addKey(193)
@ -78,12 +69,12 @@ suite "hexary trie":
key = fromHex(parts[0]) key = fromHex(parts[0])
val = fromHex(parts[1]) val = fromHex(parts[1])
SecureHexaryTrie(tr).put(key.toRange, val.toRange) SecureHexaryTrie(tr).put(key, val)
check tr.rootHashHex == "D7F8974FB5AC78D9AC099B9AD5018BEDC2CE0A72DAD1827A1709DA30580F0544" check tr.rootHashHex == "D7F8974FB5AC78D9AC099B9AD5018BEDC2CE0A72DAD1827A1709DA30580F0544"
# lexicographic comparison # lexicographic comparison
proc lexComp(a, b: BytesRange): bool = proc lexComp(a, b: seq[byte]): bool =
var var
x = 0 x = 0
y = 0 y = 0
@ -98,7 +89,7 @@ suite "hexary trie":
result = y != ylen result = y != ylen
proc cmp(a, b: BytesRange): int = proc cmp(a, b: seq[byte]): int =
if a == b: return 0 if a == b: return 0
if a.lexComp(b): return 1 if a.lexComp(b): return 1
return -1 return -1
@ -109,17 +100,17 @@ suite "hexary trie":
memdb = newMemoryDB() memdb = newMemoryDB()
trie = initHexaryTrie(memdb) trie = initHexaryTrie(memdb)
keys = [ keys = [
"key".toBytesRange, "key".toBytes,
"abc".toBytesRange, "abc".toBytes,
"hola".toBytesRange, "hola".toBytes,
"bubble".toBytesRange "bubble".toBytes
] ]
vals = [ vals = [
"hello".toBytesRange, "hello".toBytes,
"world".toBytesRange, "world".toBytes,
"block".toBytesRange, "block".toBytes,
"chain".toBytesRange "chain".toBytes
] ]
for i in 0 ..< keys.len: for i in 0 ..< keys.len:
@ -157,11 +148,11 @@ suite "hexary trie":
var var
memdb = newMemoryDB() memdb = newMemoryDB()
trie = initHexaryTrie(memdb) trie = initHexaryTrie(memdb)
keys = randList(BytesRange, randGen(5, 32), randGen(10)) keys = randList(seq[byte], randGen(5, 32), randGen(10))
vals = randList(BytesRange, randGen(5, 7), randGen(10)) vals = randList(seq[byte], randGen(5, 7), randGen(10))
keys2 = randList(BytesRange, randGen(5, 30), randGen(15)) keys2 = randList(seq[byte], randGen(5, 30), randGen(15))
vals2 = randList(BytesRange, randGen(5, 7), randGen(15)) vals2 = randList(seq[byte], randGen(5, 7), randGen(15))
for i in 0 ..< keys.len: for i in 0 ..< keys.len:
trie.put(keys[i], vals[i]) trie.put(keys[i], vals[i])
@ -226,11 +217,11 @@ suite "hexary trie":
var var
memdb = newMemoryDB() memdb = newMemoryDB()
nonPruningTrie = initHexaryTrie(memdb, false) nonPruningTrie = initHexaryTrie(memdb, false)
keys = randList(BytesRange, randGen(5, 77), randGen(30)) keys = randList(seq[byte], randGen(5, 77), randGen(30))
vals = randList(BytesRange, randGen(1, 57), randGen(30)) vals = randList(seq[byte], randGen(1, 57), randGen(30))
moreKeys = randList(BytesRange, randGen(5, 33), randGen(45)) moreKeys = randList(seq[byte], randGen(5, 33), randGen(45))
moreVals = randList(BytesRange, randGen(1, 47), randGen(45)) moreVals = randList(seq[byte], randGen(1, 47), randGen(45))
for i in 0 ..< keys.len: for i in 0 ..< keys.len:
nonPruningTrie.put(keys[i], vals[i]) nonPruningTrie.put(keys[i], vals[i])
@ -265,8 +256,8 @@ suite "hexary trie":
test "elaborate non-pruning test": test "elaborate non-pruning test":
type type
History = object History = object
keys: seq[BytesRange] keys: seq[seq[byte]]
values: seq[BytesRange] values: seq[seq[byte]]
rootHash: KeccakHash rootHash: KeccakHash
const const
@ -277,14 +268,14 @@ suite "hexary trie":
var var
memdb = newMemoryDB() memdb = newMemoryDB()
nonPruningTrie = initHexaryTrie(memdb, false) nonPruningTrie = initHexaryTrie(memdb, false)
keys = randList(BytesRange, randGen(3, 33), randGen(listLength)) keys = randList(seq[byte], randGen(3, 33), randGen(listLength))
values = randList(BytesRange, randGen(5, 77), randGen(listLength)) values = randList(seq[byte], randGen(5, 77), randGen(listLength))
historyList = newSeq[History](listLength) historyList = newSeq[History](listLength)
ok = true ok = true
for i, k in keys: for i, k in keys:
historyList[i].keys = newSeq[BytesRange](i + 1) historyList[i].keys = newSeq[seq[byte]](i + 1)
historyList[i].values = newSeq[BytesRange](i + 1) historyList[i].values = newSeq[seq[byte]](i + 1)
for x in 0 ..< i + 1: for x in 0 ..< i + 1:
historyList[i].keys[x] = keys[x] historyList[i].keys[x] = keys[x]
historyList[i].values[x] = values[x] historyList[i].values[x] = values[x]
@ -296,7 +287,7 @@ suite "hexary trie":
for h in historyList: for h in historyList:
var var
trie = initHexaryTrie(memdb, h.rootHash) trie = initHexaryTrie(memdb, h.rootHash)
pKeys: seq[BytesRange] = @[] pKeys: seq[seq[byte]] = @[]
pValues = trie.getValues() pValues = trie.getValues()
for k in trie.keys: for k in trie.keys:
@ -318,7 +309,7 @@ suite "hexary trie":
echo "ITERATION: ", iteration echo "ITERATION: ", iteration
break break
proc isValidBranch(branch: seq[BytesRange], rootHash: KeccakHash, key, value: BytesRange): bool = proc isValidBranch(branch: seq[seq[byte]], rootHash: KeccakHash, key, value: seq[byte]): bool =
# branch must not be empty # branch must not be empty
doAssert(branch.len != 0) doAssert(branch.len != 0)
@ -326,17 +317,17 @@ suite "hexary trie":
for node in branch: for node in branch:
doAssert(node.len != 0) doAssert(node.len != 0)
let nodeHash = hexary.keccak(node) let nodeHash = hexary.keccak(node)
db.put(nodeHash.data, node.toOpenArray) db.put(nodeHash.data, node)
var trie = initHexaryTrie(db, rootHash) var trie = initHexaryTrie(db, rootHash)
result = trie.get(key) == toRange(value) result = trie.get(key) == value
test "get branch with pruning trie": test "get branch with pruning trie":
var var
memdb = newMemoryDB() memdb = newMemoryDB()
trie = initHexaryTrie(memdb) trie = initHexaryTrie(memdb)
keys = randList(BytesRange, randGen(5, 77), randGen(30)) keys = randList(seq[byte], randGen(5, 77), randGen(30))
vals = randList(BytesRange, randGen(1, 57), randGen(30)) vals = randList(seq[byte], randGen(1, 57), randGen(30))
for i in 0 ..< keys.len: for i in 0 ..< keys.len:
trie.put(keys[i], vals[i]) trie.put(keys[i], vals[i])
@ -352,8 +343,8 @@ suite "hexary trie":
var var
memdb = newMemoryDB() memdb = newMemoryDB()
nonPruningTrie = initHexaryTrie(memdb, false) nonPruningTrie = initHexaryTrie(memdb, false)
keys = randList(BytesRange, randGen(5, 77), randGen(numKeyVal)) keys = randList(seq[byte], randGen(5, 77), randGen(numKeyVal))
vals = randList(BytesRange, randGen(1, 57), randGen(numKeyVal)) vals = randList(seq[byte], randGen(1, 57), randGen(numKeyVal))
roots = newSeq[KeccakHash](numKeyVal) roots = newSeq[KeccakHash](numKeyVal)
for i in 0 ..< keys.len: for i in 0 ..< keys.len:
@ -388,9 +379,9 @@ suite "hexary trie":
pruningTrie = initHexaryTrie(memdb, isPruning = true) pruningTrie = initHexaryTrie(memdb, isPruning = true)
let let
keys = randList(BytesRange, randGen(5, 77), randGen(numKeyVal)) keys = randList(seq[byte], randGen(5, 77), randGen(numKeyVal))
vals = randList(BytesRange, randGen(1, 57), randGen(numKeyVal)) vals = randList(seq[byte], randGen(1, 57), randGen(numKeyVal))
newVals = randList(BytesRange, randGen(1, 63), randGen(numKeyVal)) newVals = randList(seq[byte], randGen(1, 63), randGen(numKeyVal))
var tx1 = memdb.beginTransaction() var tx1 = memdb.beginTransaction()
for i in 0 ..< numKeyVal: for i in 0 ..< numKeyVal:
@ -426,15 +417,15 @@ suite "hexary trie":
pruningTrie = initHexaryTrie(memdb, isPruning = true) pruningTrie = initHexaryTrie(memdb, isPruning = true)
let let
keys = randList(BytesRange, randGen(5, 77), randGen(numKeyVal)) keys = randList(seq[byte], randGen(5, 77), randGen(numKeyVal))
vals = randList(BytesRange, randGen(1, 57), randGen(numKeyVal)) vals = randList(seq[byte], randGen(1, 57), randGen(numKeyVal))
for i in 0 ..< numKeyVal: for i in 0 ..< numKeyVal:
pruningTrie.put(keys[i], vals[i]) pruningTrie.put(keys[i], vals[i])
let rootHash = pruningTrie.rootHash let rootHash = pruningTrie.rootHash
for k, v in pruningTrie.replicate: for k, v in pruningTrie.replicate:
repdb.put(k.toOpenArray, v.toOpenArray) repdb.put(k, v)
var trie = initHexaryTrie(repdb, rootHash, isPruning = true) var trie = initHexaryTrie(repdb, rootHash, isPruning = true)
var numPairs = 0 var numPairs = 0

View File

@ -2,15 +2,14 @@
import import
os, json, tables, strutils, algorithm, os, json, tables, strutils, algorithm,
eth/rlp/types, eth/trie/[db, hexary],
eth/trie/[trie_defs, db, hexary], stew/byteutils
./testutils
type type
TestOp = object TestOp = object
idx: int idx: int
key: BytesRange key: seq[byte]
value: BytesRange value: seq[byte]
proc cmp(lhs, rhs: TestOp): int = cmp(lhs.idx, rhs.idx) proc cmp(lhs, rhs: TestOp): int = cmp(lhs.idx, rhs.idx)
proc `<=`(lhs, rhs: TestOp): bool = lhs.idx <= rhs.idx proc `<=`(lhs, rhs: TestOp): bool = lhs.idx <= rhs.idx
@ -76,12 +75,12 @@ proc runTests*(filename: string) =
case v.kind case v.kind
of JString: of JString:
inputs.add(TestOp(idx: inputs.len, inputs.add(TestOp(idx: inputs.len,
key: k.str.toBytesRange, key: k.str.toBytes,
value: v.str.toBytesRange)) value: v.str.toBytes))
of JNull: of JNull:
inputs.add(TestOp(idx: inputs.len, inputs.add(TestOp(idx: inputs.len,
key: k.str.toBytesRange, key: k.str.toBytes,
value: zeroBytesRange)) value: @[]))
else: invalidTest() else: invalidTest()
else: invalidTest() else: invalidTest()
@ -91,12 +90,12 @@ proc runTests*(filename: string) =
case v.kind case v.kind
of JString: of JString:
inputs.add(TestOp(idx: inputs.len, inputs.add(TestOp(idx: inputs.len,
key: k.toBytesRange, key: k.toBytes,
value: v.str.toBytesRange)) value: v.str.toBytes))
of JNull: of JNull:
inputs.add(TestOp(idx: inputs.len, inputs.add(TestOp(idx: inputs.len,
key: k.toBytesRange, key: k.toBytes,
value: zeroBytesRange)) value: @[]))
else: invalidTest() else: invalidTest()
else: invalidTest() else: invalidTest()

View File

@ -1,11 +0,0 @@
{.used.}
import
unittest,
eth/trie/nibbles
suite "nibbles":
test "zeroNibblesRange":
# https://github.com/status-im/nim-eth/issues/6
check zeroNibblesRange.len == 0

View File

@ -1,8 +1,8 @@
{.used.} {.used.}
import import
unittest, random, unittest, random, stew/byteutils,
eth/trie/[trie_defs, db, sparse_binary, sparse_proofs], eth/trie/[db, sparse_binary, sparse_proofs],
./testutils ./testutils
suite "sparse binary trie": suite "sparse binary trie":
@ -21,14 +21,14 @@ suite "sparse binary trie":
test "basic get": test "basic get":
for c in kv_pairs: for c in kv_pairs:
let x = trie.get(c.key) let x = trie.get(c.key)
let y = toRange(c.value) let y = c.value
check x == y check x == y
trie.del(c.key) trie.del(c.key)
for c in kv_pairs: for c in kv_pairs:
check trie.exists(c.key) == false check trie.exists(c.key) == false
check trie.getRootHash() == keccakHash(emptyNodeHashes[0].toOpenArray, emptyNodeHashes[0].toOpenArray).toRange check trie.getRootHash() == keccakHash(emptyNodeHashes[0].data, emptyNodeHashes[0].data).data
test "single update set": test "single update set":
random.shuffle(kv_pairs) random.shuffle(kv_pairs)
@ -42,11 +42,11 @@ suite "sparse binary trie":
test "single update get": test "single update get":
for i in numbers: for i in numbers:
# If new value is the same as current value, skip the update # If new value is the same as current value, skip the update
if toRange($i) == trie.get(kv_pairs[i].key): if toBytes($i) == trie.get(kv_pairs[i].key):
continue continue
# Update # Update
trie.set(kv_pairs[i].key, $i) trie.set(kv_pairs[i].key, toBytes($i))
check trie.get(kv_pairs[i].key) == toRange($i) check trie.get(kv_pairs[i].key) == toBytes($i)
check trie.getRootHash() != prior_to_update_root check trie.getRootHash() != prior_to_update_root
# Un-update # Un-update
@ -56,7 +56,7 @@ suite "sparse binary trie":
test "batch update with different update order": test "batch update with different update order":
# First batch update # First batch update
for i in numbers: for i in numbers:
trie.set(kv_pairs[i].key, $i) trie.set(kv_pairs[i].key, toBytes($i))
let batch_updated_root = trie.getRootHash() let batch_updated_root = trie.getRootHash()
@ -70,14 +70,14 @@ suite "sparse binary trie":
# Second batch update # Second batch update
random.shuffle(numbers) random.shuffle(numbers)
for i in numbers: for i in numbers:
trie.set(kv_pairs[i].key, $i) trie.set(kv_pairs[i].key, toBytes($i))
check trie.getRootHash() == batch_updated_root check trie.getRootHash() == batch_updated_root
test "dictionary API": test "dictionary API":
trie[kv_pairs[0].key] = kv_pairs[0].value trie[kv_pairs[0].key] = kv_pairs[0].value
let x = trie[kv_pairs[0].key] let x = trie[kv_pairs[0].key]
let y = toRange(kv_pairs[0].value) let y = kv_pairs[0].value
check x == y check x == y
check kv_pairs[0].key in trie check kv_pairs[0].key in trie
@ -85,10 +85,10 @@ suite "sparse binary trie":
db = newMemoryDB() db = newMemoryDB()
trie = initSparseBinaryTrie(db) trie = initSparseBinaryTrie(db)
let let
testKey = toRange(kv_pairs[0].key) testKey = kv_pairs[0].key
testValue = toRange(kv_pairs[0].value) testValue = kv_pairs[0].value
testKey2 = toRange(kv_pairs[1].key) testKey2 = kv_pairs[1].key
testValue2 = toRange(kv_pairs[1].value) testValue2 = kv_pairs[1].value
trie.set(testKey, testValue) trie.set(testKey, testValue)
var root = trie.getRootHash() var root = trie.getRootHash()
@ -102,11 +102,11 @@ suite "sparse binary trie":
value = trie.get(testKey, root) value = trie.get(testKey, root)
check value == testValue check value == testValue
proc makeBadProof(size: int, width = 32): seq[BytesRange] = proc makeBadProof(size: int, width = 32): seq[seq[byte]] =
let badProofStr = randList(string, randGen(width, width), randGen(size, size)) let badProofStr = randList(seq[byte], randGen(width, width), randGen(size, size))
result = newSeq[BytesRange](size) result = newSeq[seq[byte]](size)
for i in 0 ..< result.len: for i in 0 ..< result.len:
result[i] = toRange(badProofStr[i]) result[i] = badProofStr[i]
test "proofs": test "proofs":
const const
@ -115,9 +115,9 @@ suite "sparse binary trie":
let let
testKey = kv_pairs[0].key testKey = kv_pairs[0].key
badKey = kv_pairs[1].key badKey = kv_pairs[1].key
testValue = "testValue" testValue = "testValue".toBytes
testValue2 = "testValue2" testValue2 = "testValue2".toBytes
badValue = "badValue" badValue = "badValue".toBytes
badProof = makeBadProof(MaxBadProof) badProof = makeBadProof(MaxBadProof)
trie[testKey] = testValue trie[testKey] = testValue
@ -131,7 +131,7 @@ suite "sparse binary trie":
let let
testKey2 = kv_pairs[2].key testKey2 = kv_pairs[2].key
testKey3 = kv_pairs[3].key testKey3 = kv_pairs[3].key
defaultValue = zeroBytesRange defaultValue = default(seq[byte])
trie.set(testKey2, testValue) trie.set(testKey2, testValue)
proof = trie.prove(testKey) proof = trie.prove(testKey)
@ -169,7 +169,7 @@ suite "sparse binary trie":
check compactProof(badProof2).len == 0 check compactProof(badProof2).len == 0
check compactProof(badProof3).len == 0 check compactProof(badProof3).len == 0
check decompactProof(badProof3).len == 0 check decompactProof(badProof3).len == 0
var zeroProof: seq[BytesRange] var zeroProof: seq[seq[byte]]
check decompactProof(zeroProof).len == 0 check decompactProof(zeroProof).len == 0
proof = trie.proveCompact(testKey2) proof = trie.proveCompact(testKey2)
@ -202,23 +202,23 @@ suite "sparse binary trie":
test "examples": test "examples":
let let
key1 = "01234567890123456789" key1 = "01234567890123456789".toBytes
key2 = "abcdefghijklmnopqrst" key2 = "abcdefghijklmnopqrst".toBytes
trie.set(key1, "value1") trie.set(key1, "value1".toBytes)
trie.set(key2, "value2") trie.set(key2, "value2".toBytes)
check trie.get(key1) == "value1".toRange check trie.get(key1) == "value1".toBytes
check trie.get(key2) == "value2".toRange check trie.get(key2) == "value2".toBytes
trie.del(key1) trie.del(key1)
check trie.get(key1) == zeroBytesRange check trie.get(key1) == []
trie.del(key2) trie.del(key2)
check trie[key2] == zeroBytesRange check trie[key2] == []
let let
value1 = "hello world" value1 = "hello world".toBytes
badValue = "bad value" badValue = "bad value".toBytes
trie[key1] = value1 trie[key1] = value1
var proof = trie.prove(key1) var proof = trie.prove(key1)

View File

@ -2,8 +2,7 @@
import import
unittest, unittest,
eth/trie/[db, trie_defs], ./testutils, eth/trie/[db], ./testutils
eth/rlp/types as rlpTypes
suite "transaction db": suite "transaction db":
setup: setup:
@ -11,10 +10,10 @@ suite "transaction db":
listLength = 30 listLength = 30
var var
keysA = randList(Bytes, randGen(3, 33), randGen(listLength)) keysA = randList(seq[byte], randGen(3, 33), randGen(listLength))
valuesA = randList(Bytes, randGen(5, 77), randGen(listLength)) valuesA = randList(seq[byte], randGen(5, 77), randGen(listLength))
keysB = randList(Bytes, randGen(3, 33), randGen(listLength)) keysB = randList(seq[byte], randGen(3, 33), randGen(listLength))
valuesB = randList(Bytes, randGen(5, 77), randGen(listLength)) valuesB = randList(seq[byte], randGen(5, 77), randGen(listLength))
proc populateA(db: TrieDatabaseRef) = proc populateA(db: TrieDatabaseRef) =
for i in 0 ..< listLength: for i in 0 ..< listLength:

View File

@ -0,0 +1,84 @@
{.used.}
import
random, unittest,
eth/trie/trie_bitseq
proc randomBytes(n: int): seq[byte] =
result = newSeq[byte](n)
for i in 0 ..< result.len:
result[i] = byte(rand(256))
suite "trie bitseq":
test "basic":
var a = @[byte 0b10101010, 0b11110000, 0b00001111, 0b01010101]
var bSeq = @[byte 0b10101010, 0b00000000, 0b00000000, 0b11111111]
var b = bits(bSeq, 8)
var cSeq = @[byte 0b11110000, 0b00001111, 0b00000000, 0b00000000]
var c = bits(cSeq, 16)
var dSeq = @[byte 0b00001111, 0b00000000, 0b00000000, 0b00000000]
var d = bits(dSeq, 8)
var eSeq = @[byte 0b01010101, 0b00000000, 0b00000000, 0b00000000]
var e = bits(eSeq, 8)
var m = a.bits
var n = m[0..7]
check n == b
check n.len == 8
check b.len == 8
check c == m[8..23]
check $(d) == "00001111"
check $(e) == "01010101"
var f = int.fromBits(e, 0, 4)
check f == 0b0101
let k = n & d
check(k.len == n.len + d.len)
check($k == $n & $d)
var asciiSeq = @[byte('A'),byte('S'),byte('C'),byte('I'),byte('I')]
let asciiBits = bits(asciiSeq)
check $asciiBits == "0100000101010011010000110100100101001001"
test "concat operator":
randomize(5000)
for i in 0..<256:
var xSeq = randomBytes(rand(i))
var ySeq = randomBytes(rand(i))
let x = xSeq.bits
let y = ySeq.bits
var z = x & y
check z.len == x.len + y.len
check($z == $x & $y)
test "get set bits":
randomize(1000)
for i in 0..<256:
# produce random vector
var xSeq = randomBytes(i)
var ySeq = randomBytes(i)
var x = xSeq.bits
var y = ySeq.bits
for idx, bit in x:
y[idx] = bit
check x == y
test "constructor with start":
var a = @[byte 0b10101010, 0b11110000, 0b00001111, 0b01010101]
var b = a.bits(1, 8)
check b.len == 8
check b[0] == false
check $b == "01010101"
b[0] = true
check $b == "11010101"
check b[0] == true
b.pushFront(false)
check b[0] == false
check $b == "011010101"

View File

@ -1,6 +1,6 @@
import import
random, sets, eth/trie/trie_utils as ethUtils, random, sets,
eth/rlp/types as rlpTypes, stew/ranges/bitranges, eth/trie/trie_bitseq,
nimcrypto/[utils, sysrand] nimcrypto/[utils, sysrand]
type type
@ -8,8 +8,8 @@ type
minVal, maxVal: T minVal, maxVal: T
KVPair* = ref object KVPair* = ref object
key*: string key*: seq[byte]
value*: string value*: seq[byte]
proc randGen*[T](minVal, maxVal: T): RandGen[T] = proc randGen*[T](minVal, maxVal: T): RandGen[T] =
doAssert(minVal <= maxVal) doAssert(minVal <= maxVal)
@ -28,11 +28,11 @@ proc randString*(len: int): string =
for i in 0..<len: for i in 0..<len:
result[i] = rand(255).char result[i] = rand(255).char
proc randBytes*(len: int): Bytes = proc randBytes*(len: int): seq[byte] =
result = newSeq[byte](len) result = newSeq[byte](len)
discard randomBytes(result[0].addr, len) discard randomBytes(result[0].addr, len)
proc toBytesRange*(str: string): BytesRange = proc toBytesRange*(str: string): seq[byte] =
var s: seq[byte] var s: seq[byte]
if str[0] == '0' and str[1] == 'x': if str[0] == '0' and str[1] == 'x':
s = fromHex(str.substr(2)) s = fromHex(str.substr(2))
@ -40,16 +40,16 @@ proc toBytesRange*(str: string): BytesRange =
s = newSeq[byte](str.len) s = newSeq[byte](str.len)
for i in 0 ..< str.len: for i in 0 ..< str.len:
s[i] = byte(str[i]) s[i] = byte(str[i])
result = s.toRange result = s
proc randPrimitives*[T](val: int): T = proc randPrimitives*[T](val: int): T =
when T is string: when T is string:
randString(val) randString(val)
elif T is int: elif T is int:
result = val result = val
elif T is BytesRange: elif T is string:
result = randString(val).toRange result = randString(val)
elif T is Bytes: elif T is seq[byte]:
result = randBytes(val) result = randBytes(val)
proc randList*(T: typedesc, strGen, listGen: RandGen, unique: bool = true): seq[T] = proc randList*(T: typedesc, strGen, listGen: RandGen, unique: bool = true): seq[T] =
@ -71,19 +71,14 @@ proc randList*(T: typedesc, strGen, listGen: RandGen, unique: bool = true): seq[
proc randKVPair*(keySize = 32): seq[KVPair] = proc randKVPair*(keySize = 32): seq[KVPair] =
const listLen = 100 const listLen = 100
let keys = randList(string, randGen(keySize, keySize), randGen(listLen, listLen)) let keys = randList(seq[byte], randGen(keySize, keySize), randGen(listLen, listLen))
let vals = randList(string, randGen(1, 100), randGen(listLen, listLen)) let vals = randList(seq[byte], randGen(1, 100), randGen(listLen, listLen))
result = newSeq[KVPair](listLen) result = newSeq[KVPair](listLen)
for i in 0..<listLen: for i in 0..<listLen:
result[i] = KVPair(key: keys[i], value: vals[i]) result[i] = KVPair(key: keys[i], value: vals[i])
proc toBytes*(str: string): Bytes = proc genBitVec*(len: int): TrieBitSeq =
result = newSeq[byte](str.len)
for i in 0..<str.len:
result[i] = byte(str[i])
proc genBitVec*(len: int): BitRange =
let k = ((len + 7) and (not 7)) shr 3 let k = ((len + 7) and (not 7)) shr 3
var s = newSeq[byte](k) var s = newSeq[byte](k)
result = bits(s, len) result = bits(s, len)