mirror of https://github.com/status-im/nim-eth.git
parent
5885f638e4
commit
059d319c16
|
@ -0,0 +1,238 @@
|
||||||
|
# proof verification
|
||||||
|
# Copyright (c) 2022 Status Research & Development GmbH
|
||||||
|
# Licensed and distributed under either of
|
||||||
|
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||||
|
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||||
|
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
|
{.push raises: [Defect].}
|
||||||
|
|
||||||
|
import
|
||||||
|
std/[tables, options, sequtils],
|
||||||
|
stew/results,
|
||||||
|
nimcrypto/[keccak, hash],
|
||||||
|
".."/rlp,
|
||||||
|
"."/[trie_defs, nibbles, db]
|
||||||
|
|
||||||
|
type
|
||||||
|
NextNodeKind = enum
|
||||||
|
EmptyValue, HashNode, ValueNode
|
||||||
|
|
||||||
|
NextNodeResult = object
|
||||||
|
case kind: NextNodeKind
|
||||||
|
of EmptyValue:
|
||||||
|
discard
|
||||||
|
of HashNode:
|
||||||
|
nextNodeHash: seq[byte]
|
||||||
|
restOfTheKey: NibblesSeq
|
||||||
|
of ValueNode:
|
||||||
|
value: seq[byte]
|
||||||
|
|
||||||
|
MptProofVerificationKind* = enum
|
||||||
|
ValidProof, InvalidProof, MissingKey
|
||||||
|
|
||||||
|
MptProofVerificationResult* = object
|
||||||
|
case kind*: MptProofVerificationKind
|
||||||
|
of MissingKey:
|
||||||
|
discard
|
||||||
|
of InvalidProof:
|
||||||
|
errorMsg*: string
|
||||||
|
of ValidProof:
|
||||||
|
value*: seq[byte]
|
||||||
|
|
||||||
|
func missingKey(): MptProofVerificationResult =
|
||||||
|
return MptProofVerificationResult(kind: MissingKey)
|
||||||
|
|
||||||
|
func invalidProof(msg: string): MptProofVerificationResult =
|
||||||
|
return MptProofVerificationResult(kind: InvalidProof, errorMsg: msg)
|
||||||
|
|
||||||
|
func validProof(value: seq[byte]): MptProofVerificationResult =
|
||||||
|
return MptProofVerificationResult(kind: ValidProof, value: value)
|
||||||
|
|
||||||
|
func isValid*(res: MptProofVerificationResult): bool =
|
||||||
|
return res.kind == ValidProof
|
||||||
|
|
||||||
|
func isMissing*(res: MptProofVerificationResult): bool =
|
||||||
|
return res.kind == MissingKey
|
||||||
|
|
||||||
|
proc getListLen(rlp: Rlp): Result[int, string] =
|
||||||
|
try:
|
||||||
|
return ok(rlp.listLen)
|
||||||
|
except RlpError as e:
|
||||||
|
return err(e.msg)
|
||||||
|
|
||||||
|
proc getListElem(rlp: Rlp, idx: int): Result[Rlp, string] =
|
||||||
|
if not rlp.isList:
|
||||||
|
return err("rlp element is not a list")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ok(rlp.listElem(idx))
|
||||||
|
except RlpError as e:
|
||||||
|
return err(e.msg)
|
||||||
|
|
||||||
|
proc blobBytes(rlp: Rlp): Result[seq[byte], string] =
|
||||||
|
try:
|
||||||
|
return ok(rlp.toBytes)
|
||||||
|
except RlpError as e:
|
||||||
|
return err(e.msg)
|
||||||
|
|
||||||
|
func rawBytesSeq(b: openArray[byte]): seq[byte] =
|
||||||
|
toSeq(b)
|
||||||
|
|
||||||
|
proc getRawRlpBytes(rlp: Rlp): Result[seq[byte], string] =
|
||||||
|
try :
|
||||||
|
return ok(rawBytesSeq(rlp.rawData))
|
||||||
|
except RlpError as e:
|
||||||
|
return err(e.msg)
|
||||||
|
|
||||||
|
proc getNextNode(nodeRlp: Rlp, key: NibblesSeq): Result[NextNodeResult, string] =
|
||||||
|
var currNode = nodeRlp
|
||||||
|
var restKey = key
|
||||||
|
|
||||||
|
template handleNextRef(nextRef: Rlp, keyLen: int) =
|
||||||
|
if not nextRef.hasData:
|
||||||
|
return err("invalid reference")
|
||||||
|
|
||||||
|
if nextRef.isList:
|
||||||
|
let rawBytes = ? nextRef.getRawRlpBytes()
|
||||||
|
if len(rawBytes) > 32:
|
||||||
|
return err("Embedded node longer than 32 bytes")
|
||||||
|
else:
|
||||||
|
currNode = nextRef
|
||||||
|
restKey = restKey.slice(keyLen)
|
||||||
|
else:
|
||||||
|
let nodeBytes = ? nextRef.blobBytes()
|
||||||
|
if len(nodeBytes) == 32:
|
||||||
|
return ok(
|
||||||
|
NextNodeResult(
|
||||||
|
kind: HashNode,
|
||||||
|
nextNodeHash: nodeBytes,
|
||||||
|
restOfTheKey: restKey.slice(keyLen)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif len(nodeBytes) == 0:
|
||||||
|
return ok(NextNodeResult(kind: EmptyValue))
|
||||||
|
else:
|
||||||
|
return err("reference rlp blob should have 0 or 32 bytes")
|
||||||
|
|
||||||
|
while true:
|
||||||
|
let listLen = ? currNode.getListLen()
|
||||||
|
|
||||||
|
case listLen
|
||||||
|
of 2:
|
||||||
|
let
|
||||||
|
firstElem = ? currNode.getListElem(0)
|
||||||
|
blobBytes = ? firstElem.blobBytes()
|
||||||
|
|
||||||
|
let (isLeaf, k) = hexPrefixDecode blobBytes
|
||||||
|
|
||||||
|
# Paths have diverged, return empty result
|
||||||
|
if len(restKey) < len(k) or k != restKey.slice(0, len(k)):
|
||||||
|
return ok(NextNodeResult(kind: EmptyValue))
|
||||||
|
|
||||||
|
let nextRef = ? currNode.getListElem(1)
|
||||||
|
|
||||||
|
if isLeaf:
|
||||||
|
let blobBytes = ? nextRef.blobBytes()
|
||||||
|
return ok(NextNodeResult(kind: ValueNode, value: blobBytes))
|
||||||
|
|
||||||
|
handleNextRef(nextRef, len(k))
|
||||||
|
of 17:
|
||||||
|
if len(restKey) == 0:
|
||||||
|
let value = ? currNode.getListElem(16)
|
||||||
|
|
||||||
|
if not value.hasData():
|
||||||
|
return err("expected branch terminator")
|
||||||
|
|
||||||
|
if value.isList():
|
||||||
|
return err("branch value cannot be list")
|
||||||
|
|
||||||
|
if value.isEmpty():
|
||||||
|
return ok(NextNodeResult(kind: EmptyValue))
|
||||||
|
else:
|
||||||
|
let bytes = ? value.blobBytes()
|
||||||
|
return ok(NextNodeResult(kind: ValueNode, value: bytes))
|
||||||
|
else:
|
||||||
|
let nextRef = ? currNode.getListElem(restKey[0].int)
|
||||||
|
|
||||||
|
handleNextRef(nextRef, 1)
|
||||||
|
else:
|
||||||
|
return err("Invalid list node ")
|
||||||
|
|
||||||
|
proc verifyProof(
|
||||||
|
db: TrieDatabaseRef,
|
||||||
|
rootHash: seq[byte],
|
||||||
|
key: seq[byte]): Result[Option[seq[byte]], string] =
|
||||||
|
var currentKey = initNibbleRange(key)
|
||||||
|
|
||||||
|
var currentHash = rootHash
|
||||||
|
|
||||||
|
while true:
|
||||||
|
let node = db.get(currentHash)
|
||||||
|
|
||||||
|
if len(node) == 0:
|
||||||
|
return err("missing expected node")
|
||||||
|
|
||||||
|
let next = ? getNextNode(rlpFromBytes(node), currentKey)
|
||||||
|
|
||||||
|
case next.kind
|
||||||
|
of EmptyValue:
|
||||||
|
return ok(none(seq[byte]))
|
||||||
|
of ValueNode:
|
||||||
|
return ok(some(next.value))
|
||||||
|
of HashNode:
|
||||||
|
currentKey = next.restOfTheKey
|
||||||
|
currentHash = next.nextNodeHash
|
||||||
|
|
||||||
|
proc verifyMptProof*(
|
||||||
|
branch: seq[seq[byte]],
|
||||||
|
rootHash: KeccakHash,
|
||||||
|
key: seq[byte],
|
||||||
|
value: seq[byte]): MptProofVerificationResult =
|
||||||
|
## Verifies provided proof of inclusion (trie branch) against provided trie
|
||||||
|
## root hash.
|
||||||
|
## Distinguishes 3 possible results:
|
||||||
|
## - proof is valid but key is not part of the trie
|
||||||
|
## - proof is invalid
|
||||||
|
## - proof is valid
|
||||||
|
## In case of valid proof, value is extracted from the leaf node and compared
|
||||||
|
## against provided value
|
||||||
|
##
|
||||||
|
## Main difference between this function and hexary.isValidBranch() is that
|
||||||
|
## this function is meant for dealing with input from untrusted sources so:
|
||||||
|
## - it should not have hidden assertion
|
||||||
|
## - it should not have surprising exceptions
|
||||||
|
## - it parses mpt nodes more strictly
|
||||||
|
##
|
||||||
|
## hexary.isValidBranch() is implemented via hexary trie `get` method which
|
||||||
|
## may contain some checks important for integrity of the trie therefore is
|
||||||
|
## is not really safe when receiving input from untrusted source.
|
||||||
|
|
||||||
|
if len(branch) == 0:
|
||||||
|
return invalidProof("empty branch")
|
||||||
|
|
||||||
|
var db = newMemoryDB()
|
||||||
|
for node in branch:
|
||||||
|
if len(node) == 0:
|
||||||
|
return invalidProof("empty mpt node in proof")
|
||||||
|
let nodeHash = keccakHash(node)
|
||||||
|
db.put(nodeHash.data, node)
|
||||||
|
|
||||||
|
let
|
||||||
|
hashBytes: seq[byte] = toSeq(rootHash.data)
|
||||||
|
proofVerificationResult = verifyProof(db, hashBytes, key)
|
||||||
|
|
||||||
|
if proofVerificationResult.isErr:
|
||||||
|
return invalidProof(proofVerificationResult.error)
|
||||||
|
|
||||||
|
let maybeProofValue = proofVerificationResult.get()
|
||||||
|
|
||||||
|
if maybeProofValue.isNone():
|
||||||
|
return missingKey()
|
||||||
|
|
||||||
|
let proofValue = maybeProofValue.unsafeGet()
|
||||||
|
|
||||||
|
if proofValue == value:
|
||||||
|
return validProof(proofValue)
|
||||||
|
else:
|
||||||
|
return invalidProof("proof does not contain expected value")
|
|
@ -7,4 +7,5 @@ import
|
||||||
./test_json_suite,
|
./test_json_suite,
|
||||||
./test_sparse_binary_trie,
|
./test_sparse_binary_trie,
|
||||||
./test_transaction_db,
|
./test_transaction_db,
|
||||||
./test_trie_bitseq
|
./test_trie_bitseq,
|
||||||
|
./test_hexary_proof
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
# proof verification
|
||||||
|
# Copyright (c) 2022 Status Research & Development GmbH
|
||||||
|
# Licensed and distributed under either of
|
||||||
|
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||||
|
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||||
|
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
|
{.used.}
|
||||||
|
|
||||||
|
{.push raises: [Defect].}
|
||||||
|
|
||||||
|
import
|
||||||
|
unittest2,
|
||||||
|
stint,
|
||||||
|
std/sequtils,
|
||||||
|
nimcrypto/hash,
|
||||||
|
../../eth/trie/[hexary, db, trie_defs, hexary_proof_verification]
|
||||||
|
|
||||||
|
proc getKeyBytes(i: int): seq[byte] =
|
||||||
|
let hash = keccakHash(u256(i).toBytesBE())
|
||||||
|
return toSeq(hash.data)
|
||||||
|
|
||||||
|
suite "MPT trie proof verification":
|
||||||
|
test "Validate proof for existing value":
|
||||||
|
let numValues = 1000
|
||||||
|
var db = newMemoryDB()
|
||||||
|
var trie = initHexaryTrie(db)
|
||||||
|
|
||||||
|
for i in 1..numValues:
|
||||||
|
let bytes = getKeyBytes(i)
|
||||||
|
|
||||||
|
trie.put(bytes, bytes)
|
||||||
|
|
||||||
|
for i in 1..numValues:
|
||||||
|
let
|
||||||
|
kv = getKeyBytes(i)
|
||||||
|
proof = trie.getBranch(kv)
|
||||||
|
root = trie.rootHash()
|
||||||
|
res = verifyMptProof(proof, root, kv, kv)
|
||||||
|
|
||||||
|
check:
|
||||||
|
res.isValid()
|
||||||
|
res.value == kv
|
||||||
|
|
||||||
|
test "Validate proof for non-existing value":
|
||||||
|
let numValues = 1000
|
||||||
|
var db = newMemoryDB()
|
||||||
|
var trie = initHexaryTrie(db)
|
||||||
|
|
||||||
|
for i in 1..numValues:
|
||||||
|
let bytes = getKeyBytes(i)
|
||||||
|
trie.put(bytes, bytes)
|
||||||
|
|
||||||
|
let
|
||||||
|
nonExistingKey = toSeq(keccakHash(toBytesBE(u256(numValues + 1))).data)
|
||||||
|
proof = trie.getBranch(nonExistingKey)
|
||||||
|
root = trie.rootHash()
|
||||||
|
res = verifyMptProof(proof, root, nonExistingKey, nonExistingKey)
|
||||||
|
|
||||||
|
check:
|
||||||
|
res.isMissing()
|
Loading…
Reference in New Issue