diff --git a/eth/trie/hexary_proof_verification.nim b/eth/trie/hexary_proof_verification.nim new file mode 100644 index 0000000..f94c8e2 --- /dev/null +++ b/eth/trie/hexary_proof_verification.nim @@ -0,0 +1,238 @@ +# proof verification +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} + +import + std/[tables, options, sequtils], + stew/results, + nimcrypto/[keccak, hash], + ".."/rlp, + "."/[trie_defs, nibbles, db] + +type + NextNodeKind = enum + EmptyValue, HashNode, ValueNode + + NextNodeResult = object + case kind: NextNodeKind + of EmptyValue: + discard + of HashNode: + nextNodeHash: seq[byte] + restOfTheKey: NibblesSeq + of ValueNode: + value: seq[byte] + + MptProofVerificationKind* = enum + ValidProof, InvalidProof, MissingKey + + MptProofVerificationResult* = object + case kind*: MptProofVerificationKind + of MissingKey: + discard + of InvalidProof: + errorMsg*: string + of ValidProof: + value*: seq[byte] + +func missingKey(): MptProofVerificationResult = + return MptProofVerificationResult(kind: MissingKey) + +func invalidProof(msg: string): MptProofVerificationResult = + return MptProofVerificationResult(kind: InvalidProof, errorMsg: msg) + +func validProof(value: seq[byte]): MptProofVerificationResult = + return MptProofVerificationResult(kind: ValidProof, value: value) + +func isValid*(res: MptProofVerificationResult): bool = + return res.kind == ValidProof + +func isMissing*(res: MptProofVerificationResult): bool = + return res.kind == MissingKey + +proc getListLen(rlp: Rlp): Result[int, string] = + try: + return ok(rlp.listLen) + except RlpError as e: + return err(e.msg) + +proc getListElem(rlp: Rlp, idx: int): Result[Rlp, string] = + if not rlp.isList: + return err("rlp element is not a list") + + try: + return ok(rlp.listElem(idx)) + except RlpError as e: + return err(e.msg) + +proc blobBytes(rlp: Rlp): Result[seq[byte], string] = + try: + return ok(rlp.toBytes) + except RlpError as e: + return err(e.msg) + +func rawBytesSeq(b: openArray[byte]): seq[byte] = + toSeq(b) + +proc getRawRlpBytes(rlp: Rlp): Result[seq[byte], string] = + try : + return ok(rawBytesSeq(rlp.rawData)) + except RlpError as e: + return err(e.msg) + +proc getNextNode(nodeRlp: Rlp, key: NibblesSeq): Result[NextNodeResult, string] = + var currNode = nodeRlp + var restKey = key + + template handleNextRef(nextRef: Rlp, keyLen: int) = + if not nextRef.hasData: + return err("invalid reference") + + if nextRef.isList: + let rawBytes = ? nextRef.getRawRlpBytes() + if len(rawBytes) > 32: + return err("Embedded node longer than 32 bytes") + else: + currNode = nextRef + restKey = restKey.slice(keyLen) + else: + let nodeBytes = ? nextRef.blobBytes() + if len(nodeBytes) == 32: + return ok( + NextNodeResult( + kind: HashNode, + nextNodeHash: nodeBytes, + restOfTheKey: restKey.slice(keyLen) + ) + ) + elif len(nodeBytes) == 0: + return ok(NextNodeResult(kind: EmptyValue)) + else: + return err("reference rlp blob should have 0 or 32 bytes") + + while true: + let listLen = ? currNode.getListLen() + + case listLen + of 2: + let + firstElem = ? currNode.getListElem(0) + blobBytes = ? firstElem.blobBytes() + + let (isLeaf, k) = hexPrefixDecode blobBytes + + # Paths have diverged, return empty result + if len(restKey) < len(k) or k != restKey.slice(0, len(k)): + return ok(NextNodeResult(kind: EmptyValue)) + + let nextRef = ? currNode.getListElem(1) + + if isLeaf: + let blobBytes = ? nextRef.blobBytes() + return ok(NextNodeResult(kind: ValueNode, value: blobBytes)) + + handleNextRef(nextRef, len(k)) + of 17: + if len(restKey) == 0: + let value = ? currNode.getListElem(16) + + if not value.hasData(): + return err("expected branch terminator") + + if value.isList(): + return err("branch value cannot be list") + + if value.isEmpty(): + return ok(NextNodeResult(kind: EmptyValue)) + else: + let bytes = ? value.blobBytes() + return ok(NextNodeResult(kind: ValueNode, value: bytes)) + else: + let nextRef = ? currNode.getListElem(restKey[0].int) + + handleNextRef(nextRef, 1) + else: + return err("Invalid list node ") + +proc verifyProof( + db: TrieDatabaseRef, + rootHash: seq[byte], + key: seq[byte]): Result[Option[seq[byte]], string] = + var currentKey = initNibbleRange(key) + + var currentHash = rootHash + + while true: + let node = db.get(currentHash) + + if len(node) == 0: + return err("missing expected node") + + let next = ? getNextNode(rlpFromBytes(node), currentKey) + + case next.kind + of EmptyValue: + return ok(none(seq[byte])) + of ValueNode: + return ok(some(next.value)) + of HashNode: + currentKey = next.restOfTheKey + currentHash = next.nextNodeHash + +proc verifyMptProof*( + branch: seq[seq[byte]], + rootHash: KeccakHash, + key: seq[byte], + value: seq[byte]): MptProofVerificationResult = + ## Verifies provided proof of inclusion (trie branch) against provided trie + ## root hash. + ## Distinguishes 3 possible results: + ## - proof is valid but key is not part of the trie + ## - proof is invalid + ## - proof is valid + ## In case of valid proof, value is extracted from the leaf node and compared + ## against provided value + ## + ## Main difference between this function and hexary.isValidBranch() is that + ## this function is meant for dealing with input from untrusted sources so: + ## - it should not have hidden assertion + ## - it should not have surprising exceptions + ## - it parses mpt nodes more strictly + ## + ## hexary.isValidBranch() is implemented via hexary trie `get` method which + ## may contain some checks important for integrity of the trie therefore is + ## is not really safe when receiving input from untrusted source. + + if len(branch) == 0: + return invalidProof("empty branch") + + var db = newMemoryDB() + for node in branch: + if len(node) == 0: + return invalidProof("empty mpt node in proof") + let nodeHash = keccakHash(node) + db.put(nodeHash.data, node) + + let + hashBytes: seq[byte] = toSeq(rootHash.data) + proofVerificationResult = verifyProof(db, hashBytes, key) + + if proofVerificationResult.isErr: + return invalidProof(proofVerificationResult.error) + + let maybeProofValue = proofVerificationResult.get() + + if maybeProofValue.isNone(): + return missingKey() + + let proofValue = maybeProofValue.unsafeGet() + + if proofValue == value: + return validProof(proofValue) + else: + return invalidProof("proof does not contain expected value") diff --git a/tests/trie/all_tests.nim b/tests/trie/all_tests.nim index e217fb6..f11cbde 100644 --- a/tests/trie/all_tests.nim +++ b/tests/trie/all_tests.nim @@ -7,4 +7,5 @@ import ./test_json_suite, ./test_sparse_binary_trie, ./test_transaction_db, - ./test_trie_bitseq + ./test_trie_bitseq, + ./test_hexary_proof diff --git a/tests/trie/test_hexary_proof.nim b/tests/trie/test_hexary_proof.nim new file mode 100644 index 0000000..091590e --- /dev/null +++ b/tests/trie/test_hexary_proof.nim @@ -0,0 +1,61 @@ +# proof verification +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.used.} + +{.push raises: [Defect].} + +import + unittest2, + stint, + std/sequtils, + nimcrypto/hash, + ../../eth/trie/[hexary, db, trie_defs, hexary_proof_verification] + +proc getKeyBytes(i: int): seq[byte] = + let hash = keccakHash(u256(i).toBytesBE()) + return toSeq(hash.data) + +suite "MPT trie proof verification": + test "Validate proof for existing value": + let numValues = 1000 + var db = newMemoryDB() + var trie = initHexaryTrie(db) + + for i in 1..numValues: + let bytes = getKeyBytes(i) + + trie.put(bytes, bytes) + + for i in 1..numValues: + let + kv = getKeyBytes(i) + proof = trie.getBranch(kv) + root = trie.rootHash() + res = verifyMptProof(proof, root, kv, kv) + + check: + res.isValid() + res.value == kv + + test "Validate proof for non-existing value": + let numValues = 1000 + var db = newMemoryDB() + var trie = initHexaryTrie(db) + + for i in 1..numValues: + let bytes = getKeyBytes(i) + trie.put(bytes, bytes) + + let + nonExistingKey = toSeq(keccakHash(toBytesBE(u256(numValues + 1))).data) + proof = trie.getBranch(nonExistingKey) + root = trie.rootHash() + res = verifyMptProof(proof, root, nonExistingKey, nonExistingKey) + + check: + res.isMissing()