diff --git a/evm/Cargo.toml b/evm/Cargo.toml index bcbc5cd9..7c179318 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -14,7 +14,7 @@ ethereum-types = "0.14.0" hex = { version = "0.4.3", optional = true } hex-literal = "0.3.4" itertools = "0.10.3" -keccak-hash = "0.9.0" +keccak-hash = "0.10.0" log = "0.4.14" num = "0.4.0" maybe_rayon = { path = "../maybe_rayon" } diff --git a/evm/spec/tries.tex b/evm/spec/mpts.tex similarity index 53% rename from evm/spec/tries.tex rename to evm/spec/mpts.tex index 7ec0fcce..49d1d328 100644 --- a/evm/spec/tries.tex +++ b/evm/spec/mpts.tex @@ -6,21 +6,21 @@ Withour our zkEVM's kernel memory, \begin{enumerate} \item An empty node is encoded as $(\texttt{MPT\_NODE\_EMPTY})$. - \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, c_1, \dots, c_{16}, \abs{v}, v)$, where each $c_i$ is a pointer to a child node, and $v$ is a value of length $\abs{v}$.\footnote{If a branch node has no associated value, then $\abs{v} = 0$ and $v = ()$.} + \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, c_1, \dots, c_{16}, v)$, where each $c_i$ is a pointer to a child node, and $v$ is a pointer to a value. If a branch node has no associated value, then $v = 0$, i.e. the null pointer. \item An extension node is encoded as $(\texttt{MPT\_NODE\_EXTENSION}, k, c)$, $k$ represents the part of the key associated with this extension, and is encoded as a 2-tuple $(\texttt{packed\_nibbles}, \texttt{num\_nibbles})$. $c$ is a pointer to a child node. - \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, \abs{v}, v)$, where $k$ is a 2-tuple as above, and $v$ is a leaf payload. - \item A digest node is encoded as $(\texttt{MPT\_NODE\_DIGEST}, d)$, where $d$ is a Keccak256 digest. + \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, v)$, where $k$ is a 2-tuple as above, and $v$ is a pointer to a value. + \item A digest node is encoded as $(\texttt{MPT\_NODE\_HASH}, d)$, where $d$ is a Keccak256 digest. \end{enumerate} \subsection{Prover input format} -The initial state of each trie is given by the prover as a nondeterministic input tape. This tape has a similar format: +The initial state of each trie is given by the prover as a nondeterministic input tape. This tape has a slightly different format: \begin{enumerate} \item An empty node is encoded as $(\texttt{MPT\_NODE\_EMPTY})$. - \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, \abs{v}, v, c_1, \dots, c_{16})$, where $\abs{v}$ is the length of the value, and $v$ is the value itself. Each $c_i$ is the encoding of a child node. + \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, v_?, c_1, \dots, c_{16})$. Here $v_?$ consists of a flag indicating whether a value is present,\todo{In the current implementation, we use a length prefix rather than a is-present prefix, but we plan to change that.} followed by the actual value payload if one is present. Each $c_i$ is the encoding of a child node. \item An extension node is encoded as $(\texttt{MPT\_NODE\_EXTENSION}, k, c)$, $k$ represents the part of the key associated with this extension, and is encoded as a 2-tuple $(\texttt{packed\_nibbles}, \texttt{num\_nibbles})$. $c$ is a pointer to a child node. - \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, \abs{v}, v)$, where $k$ is a 2-tuple as above, and $v$ is a leaf payload. - \item A digest node is encoded as $(\texttt{MPT\_NODE\_DIGEST}, d)$, where $d$ is a Keccak256 digest. + \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, v)$, where $k$ is a 2-tuple as above, and $v$ is a value payload. + \item A digest node is encoded as $(\texttt{MPT\_NODE\_HASH}, d)$, where $d$ is a Keccak256 digest. \end{enumerate} -Nodes are thus given in depth-first order, leading to natural recursive methods for encoding and decoding this format. +Nodes are thus given in depth-first order, enabling natural recursive methods for encoding and decoding this format. diff --git a/evm/spec/zkevm.pdf b/evm/spec/zkevm.pdf index 184ba36b..f181eba6 100644 Binary files a/evm/spec/zkevm.pdf and b/evm/spec/zkevm.pdf differ diff --git a/evm/spec/zkevm.tex b/evm/spec/zkevm.tex index 65766986..2927e7a5 100644 --- a/evm/spec/zkevm.tex +++ b/evm/spec/zkevm.tex @@ -51,7 +51,7 @@ \input{introduction} \input{framework} \input{tables} -\input{tries} +\input{mpts} \input{instructions} \bibliography{bibliography}{} diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 15b11cd6..843db031 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -42,15 +42,16 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/memory/metadata.asm"), include_str!("asm/memory/packing.asm"), include_str!("asm/memory/txn_fields.asm"), - include_str!("asm/mpt/delete.asm"), - include_str!("asm/mpt/hash.asm"), - include_str!("asm/mpt/hash_trie_specific.asm"), + include_str!("asm/mpt/delete/delete.asm"), + include_str!("asm/mpt/hash/hash.asm"), + include_str!("asm/mpt/hash/hash_trie_specific.asm"), include_str!("asm/mpt/hex_prefix.asm"), - include_str!("asm/mpt/insert.asm"), - include_str!("asm/mpt/insert_extension.asm"), - include_str!("asm/mpt/insert_leaf.asm"), - include_str!("asm/mpt/insert_trie_specific.asm"), - include_str!("asm/mpt/load.asm"), + include_str!("asm/mpt/insert/insert.asm"), + include_str!("asm/mpt/insert/insert_extension.asm"), + include_str!("asm/mpt/insert/insert_leaf.asm"), + include_str!("asm/mpt/insert/insert_trie_specific.asm"), + include_str!("asm/mpt/load/load.asm"), + include_str!("asm/mpt/load/load_trie_specific.asm"), include_str!("asm/mpt/read.asm"), include_str!("asm/mpt/storage_read.asm"), include_str!("asm/mpt/storage_write.asm"), diff --git a/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm b/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm index 96e177ff..a1c2ff3c 100644 --- a/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm +++ b/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm @@ -132,7 +132,7 @@ pubkey_to_addr: // stack: PKx, PKy, retdest PUSH 0 // stack: 0, PKx, PKy, retdest - MSTORE // TODO: switch to kernel memory (like `%mstore_current(@SEGMENT_KERNEL_GENERAL)`). + MSTORE // TODO: switch to kernel memory (like `%mstore_kernel(@SEGMENT_KERNEL_GENERAL)`). // stack: PKy, retdest PUSH 0x20 // stack: 0x20, PKy, retdest diff --git a/evm/src/cpu/kernel/asm/mpt/delete.asm b/evm/src/cpu/kernel/asm/mpt/delete/delete.asm similarity index 100% rename from evm/src/cpu/kernel/asm/mpt/delete.asm rename to evm/src/cpu/kernel/asm/mpt/delete/delete.asm diff --git a/evm/src/cpu/kernel/asm/mpt/hash.asm b/evm/src/cpu/kernel/asm/mpt/hash/hash.asm similarity index 91% rename from evm/src/cpu/kernel/asm/mpt/hash.asm rename to evm/src/cpu/kernel/asm/mpt/hash/hash.asm index 8342d650..9fe0edef 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash/hash.asm @@ -47,7 +47,29 @@ mpt_hash_hash_rlp_after_unpacking: // Pre stack: node_ptr, encode_value, retdest // Post stack: result, result_len global encode_or_hash_node: - %stack (node_ptr, encode_value) -> (node_ptr, encode_value, maybe_hash_node) + // stack: node_ptr, encode_value, retdest + DUP1 %mload_trie_data + + // Check if we're dealing with a concrete node, i.e. not a hash node. + // stack: node_type, node_ptr, encode_value, retdest + DUP1 + PUSH @MPT_NODE_HASH + SUB + %jumpi(encode_or_hash_concrete_node) + + // If we got here, node_type == @MPT_NODE_HASH. + // Load the hash and return (hash, 32). + // stack: node_type, node_ptr, encode_value, retdest + POP + // stack: node_ptr, encode_value, retdest + %increment // Skip over node type prefix + // stack: hash_ptr, encode_value, retdest + %mload_trie_data + // stack: hash, encode_value, retdest + %stack (hash, encode_value, retdest) -> (retdest, hash, 32) + JUMP +encode_or_hash_concrete_node: + %stack (node_type, node_ptr, encode_value) -> (node_type, node_ptr, encode_value, maybe_hash_node) %jump(encode_node) maybe_hash_node: // stack: result_ptr, result_len, retdest @@ -75,22 +97,22 @@ after_packed_small_rlp: // RLP encode the given trie node, and return an (pointer, length) pair // indicating where the data lives within @SEGMENT_RLP_RAW. // -// Pre stack: node_ptr, encode_value, retdest +// Pre stack: node_type, node_ptr, encode_value, retdest // Post stack: result_ptr, result_len -global encode_node: - // stack: node_ptr, encode_value, retdest - DUP1 %mload_trie_data +encode_node: // stack: node_type, node_ptr, encode_value, retdest // Increment node_ptr, so it points to the node payload instead of its type. SWAP1 %increment SWAP1 // stack: node_type, node_payload_ptr, encode_value, retdest DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(encode_node_empty) - DUP1 %eq_const(@MPT_NODE_HASH) %jumpi(encode_node_hash) DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(encode_node_branch) DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(encode_node_extension) DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(encode_node_leaf) - PANIC // Invalid node type? Shouldn't get here. + + // If we got here, node_type is either @MPT_NODE_HASH, which should have + // been handled earlier in encode_or_hash_node, or something invalid. + PANIC global encode_node_empty: // stack: node_type, node_payload_ptr, encode_value, retdest @@ -105,14 +127,6 @@ global encode_node_empty: %stack (retdest) -> (retdest, 0, 1) JUMP -global encode_node_hash: - // stack: node_type, node_payload_ptr, encode_value, retdest - POP - // stack: node_payload_ptr, encode_value, retdest - %mload_trie_data - %stack (hash, encode_value, retdest) -> (retdest, hash, 32) - JUMP - encode_node_branch: // stack: node_type, node_payload_ptr, encode_value, retdest POP @@ -152,21 +166,17 @@ encode_node_branch: %add_const(16) // stack: value_ptr_ptr, rlp_pos', encode_value, retdest %mload_trie_data - // stack: value_len_ptr, rlp_pos', encode_value, retdest - DUP1 %mload_trie_data - // stack: value_len, value_len_ptr, rlp_pos', encode_value, retdest - %jumpi(encode_node_branch_with_value) + // stack: value_ptr, rlp_pos', encode_value, retdest + DUP1 %jumpi(encode_node_branch_with_value) // No value; append the empty string (0x80). - // stack: value_len_ptr, rlp_pos', encode_value, retdest - %stack (value_len_ptr, rlp_pos, encode_value) -> (rlp_pos, 0x80, rlp_pos) + // stack: value_ptr, rlp_pos', encode_value, retdest + %stack (value_ptr, rlp_pos, encode_value) -> (rlp_pos, 0x80, rlp_pos) %mstore_rlp // stack: rlp_pos', retdest %increment // stack: rlp_pos'', retdest %jump(encode_node_branch_prepend_prefix) encode_node_branch_with_value: - // stack: value_len_ptr, rlp_pos', encode_value, retdest - %increment // stack: value_ptr, rlp_pos', encode_value, retdest %stack (value_ptr, rlp_pos, encode_value) -> (encode_value, rlp_pos, value_ptr, encode_node_branch_prepend_prefix) @@ -276,7 +286,6 @@ encode_node_leaf_after_hex_prefix: %add_const(2) // The value pointer starts at index 3, after num_nibbles and packed_nibbles. // stack: value_ptr_ptr, rlp_pos, encode_value, retdest %mload_trie_data - %increment // skip over length prefix // stack: value_ptr, rlp_pos, encode_value, retdest %stack (value_ptr, rlp_pos, encode_value, retdest) -> (encode_value, rlp_pos, value_ptr, encode_node_leaf_after_encode_value, retdest) diff --git a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm similarity index 86% rename from evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm rename to evm/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm index bf2c46f0..4f9b58b4 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm @@ -72,8 +72,13 @@ global encode_account: // stack: balance, rlp_pos_4, value_ptr, retdest SWAP1 %encode_rlp_scalar // stack: rlp_pos_5, value_ptr, retdest - DUP2 %add_const(2) %mload_trie_data // storage_root = value[2] - // stack: storage_root, rlp_pos_5, value_ptr, retdest + PUSH encode_account_after_hash_storage_trie + PUSH encode_storage_value + DUP4 %add_const(2) %mload_trie_data // storage_root_ptr = value[2] + // stack: storage_root_ptr, encode_storage_value, encode_account_after_hash_storage_trie, rlp_pos_5, value_ptr, retdest + %jump(mpt_hash) +encode_account_after_hash_storage_trie: + // stack: storage_root_digest, rlp_pos_5, value_ptr, retdest SWAP1 %encode_rlp_256 // stack: rlp_pos_6, value_ptr, retdest SWAP1 %add_const(3) %mload_trie_data // code_hash = value[3] @@ -88,3 +93,6 @@ encode_txn: encode_receipt: PANIC // TODO + +encode_storage_value: + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/insert.asm b/evm/src/cpu/kernel/asm/mpt/insert/insert.asm similarity index 100% rename from evm/src/cpu/kernel/asm/mpt/insert.asm rename to evm/src/cpu/kernel/asm/mpt/insert/insert.asm diff --git a/evm/src/cpu/kernel/asm/mpt/insert_extension.asm b/evm/src/cpu/kernel/asm/mpt/insert/insert_extension.asm similarity index 100% rename from evm/src/cpu/kernel/asm/mpt/insert_extension.asm rename to evm/src/cpu/kernel/asm/mpt/insert/insert_extension.asm diff --git a/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm b/evm/src/cpu/kernel/asm/mpt/insert/insert_leaf.asm similarity index 100% rename from evm/src/cpu/kernel/asm/mpt/insert_leaf.asm rename to evm/src/cpu/kernel/asm/mpt/insert/insert_leaf.asm diff --git a/evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm similarity index 100% rename from evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm rename to evm/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm diff --git a/evm/src/cpu/kernel/asm/mpt/load.asm b/evm/src/cpu/kernel/asm/mpt/load.asm deleted file mode 100644 index 73f58b95..00000000 --- a/evm/src/cpu/kernel/asm/mpt/load.asm +++ /dev/null @@ -1,206 +0,0 @@ -// TODO: Receipt trie leaves are variable-length, so we need to be careful not -// to permit buffer over-reads. - -// Load all partial trie data from prover inputs. -global load_all_mpts: - // stack: retdest - // First set @GLOBAL_METADATA_TRIE_DATA_SIZE = 1. - // We don't want it to start at 0, as we use 0 as a null pointer. - PUSH 1 - %set_trie_data_size - - %load_mpt %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) - %load_mpt %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) - %load_mpt %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) - - PROVER_INPUT(mpt) - // stack: num_storage_tries, retdest - DUP1 %mstore_global_metadata(@GLOBAL_METADATA_NUM_STORAGE_TRIES) - // stack: num_storage_tries, retdest - PUSH 0 // i = 0 - // stack: i, num_storage_tries, retdest -storage_trie_loop: - DUP2 DUP2 EQ - // stack: i == num_storage_tries, i, num_storage_tries, retdest - %jumpi(storage_trie_loop_end) - // stack: i, num_storage_tries, retdest - PROVER_INPUT(mpt) - // stack: storage_trie_addr, i, num_storage_tries, retdest - DUP2 - // stack: i, storage_trie_addr, i, num_storage_tries, retdest - %mstore_kernel(@SEGMENT_STORAGE_TRIE_ADDRS) - // stack: i, num_storage_tries, retdest - %load_mpt - // stack: root_ptr, i, num_storage_tries, retdest - DUP2 - // stack: i, root_ptr, i, num_storage_tries, retdest - %mstore_kernel(@SEGMENT_STORAGE_TRIE_PTRS) - // stack: i, num_storage_tries, retdest - %jump(storage_trie_loop) -storage_trie_loop_end: - // stack: i, num_storage_tries, retdest - %pop2 - // stack: retdest - JUMP - -// Load an MPT from prover inputs. -// Pre stack: retdest -// Post stack: node_ptr -load_mpt: - // stack: retdest - PROVER_INPUT(mpt) - // stack: node_type, retdest - - DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(load_mpt_empty) - DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(load_mpt_branch) - DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(load_mpt_extension) - DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(load_mpt_leaf) - DUP1 %eq_const(@MPT_NODE_HASH) %jumpi(load_mpt_digest) - PANIC // Invalid node type - -load_mpt_empty: - // TRIE_DATA[0] = 0, and an empty node has type 0, so we can simply return the null pointer. - %stack (node_type, retdest) -> (retdest, 0) - JUMP - -load_mpt_branch: - // stack: node_type, retdest - %get_trie_data_size - // stack: node_ptr, node_type, retdest - SWAP1 %append_to_trie_data - // stack: node_ptr, retdest - // Save the offset of our 16 child pointers so we can write them later. - // Then advance out current trie pointer beyond them, so we can load the - // value and have it placed after our child pointers. - %get_trie_data_size - // stack: children_ptr, node_ptr, retdest - DUP1 %add_const(17) // Skip over 16 children plus the value pointer - // stack: value_ptr, children_ptr, node_ptr, retdest - %set_trie_data_size - // stack: children_ptr, node_ptr, retdest - %load_value - // stack: children_ptr, value_ptr, node_ptr, retdest - SWAP1 - - // Load the 16 children. - %rep 16 - %load_mpt - // stack: child_ptr, next_child_ptr_ptr, value_ptr, node_ptr, retdest - DUP2 - // stack: next_child_ptr_ptr, child_ptr, next_child_ptr_ptr, value_ptr, node_ptr, retdest - %mstore_trie_data - // stack: next_child_ptr_ptr, value_ptr, node_ptr, retdest - %increment - // stack: next_child_ptr_ptr, value_ptr, node_ptr, retdest - %endrep - - // stack: value_ptr_ptr, value_ptr, node_ptr, retdest - %mstore_trie_data - // stack: node_ptr, retdest - SWAP1 - JUMP - -load_mpt_extension: - // stack: node_type, retdest - %get_trie_data_size - // stack: node_ptr, node_type, retdest - SWAP1 %append_to_trie_data - // stack: node_ptr, retdest - PROVER_INPUT(mpt) // read num_nibbles - %append_to_trie_data - PROVER_INPUT(mpt) // read packed_nibbles - %append_to_trie_data - // stack: node_ptr, retdest - - %get_trie_data_size - // stack: child_ptr_ptr, node_ptr, retdest - // Increment trie_data_size, to leave room for child_ptr_ptr, before we load our child. - DUP1 %increment %set_trie_data_size - // stack: child_ptr_ptr, node_ptr, retdest - - %load_mpt - // stack: child_ptr, child_ptr_ptr, node_ptr, retdest - SWAP1 - %mstore_trie_data - // stack: node_ptr, retdest - SWAP1 - JUMP - -load_mpt_leaf: - // stack: node_type, retdest - %get_trie_data_size - // stack: node_ptr, node_type, retdest - SWAP1 %append_to_trie_data - // stack: node_ptr, retdest - PROVER_INPUT(mpt) // read num_nibbles - %append_to_trie_data - PROVER_INPUT(mpt) // read packed_nibbles - %append_to_trie_data - // stack: node_ptr, retdest - // We save value_ptr_ptr = get_trie_data_size, then increment trie_data_size - // to skip over the slot for value_ptr. We will write value_ptr after the - // load_value call. - %get_trie_data_size - // stack: value_ptr_ptr, node_ptr, retdest - DUP1 %increment %set_trie_data_size - // stack: value_ptr_ptr, node_ptr, retdest - %load_value - // stack: value_ptr, value_ptr_ptr, node_ptr, retdest - SWAP1 %mstore_trie_data - // stack: node_ptr, retdest - SWAP1 - JUMP - -load_mpt_digest: - // stack: node_type, retdest - %get_trie_data_size - // stack: node_ptr, node_type, retdest - SWAP1 %append_to_trie_data - // stack: node_ptr, retdest - PROVER_INPUT(mpt) // read digest - %append_to_trie_data - // stack: node_ptr, retdest - SWAP1 - JUMP - -// Convenience macro to call load_mpt and return where we left off. -%macro load_mpt - PUSH %%after - %jump(load_mpt) -%%after: -%endmacro - -// Load a leaf from prover input, append it to trie data, and return a pointer to it. -%macro load_value - // stack: (empty) - PROVER_INPUT(mpt) - // stack: value_len - DUP1 ISZERO - %jumpi(%%return_null) - // stack: value_len - %get_trie_data_size - SWAP1 - // stack: value_len, value_ptr - DUP1 %append_to_trie_data - // stack: value_len, value_ptr -%%loop: - DUP1 ISZERO - // stack: value_len == 0, value_len, value_ptr - %jumpi(%%finish_loop) - // stack: value_len, value_ptr - PROVER_INPUT(mpt) - // stack: leaf_part, value_len, value_ptr - %append_to_trie_data - // stack: value_len, value_ptr - %decrement - // stack: value_len', value_ptr - %jump(%%loop) -%%finish_loop: - // stack: value_len, value_ptr - POP - // stack: value_ptr - %jump(%%end) -%%return_null: - %stack (value_len) -> (0) -%%end: -%endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/load/load.asm b/evm/src/cpu/kernel/asm/mpt/load/load.asm new file mode 100644 index 00000000..d787074b --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/load/load.asm @@ -0,0 +1,173 @@ +// Load all partial trie data from prover inputs. +global load_all_mpts: + // stack: retdest + // First set @GLOBAL_METADATA_TRIE_DATA_SIZE = 1. + // We don't want it to start at 0, as we use 0 as a null pointer. + PUSH 1 + %set_trie_data_size + + %load_mpt(mpt_load_state_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + %load_mpt(mpt_load_txn_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) + %load_mpt(mpt_load_receipt_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) + + // stack: retdest + JUMP + +// Load an MPT from prover inputs. +// Pre stack: load_value, retdest +// Post stack: node_ptr +global load_mpt: + // stack: load_value, retdest + PROVER_INPUT(mpt) + // stack: node_type, load_value, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(load_mpt_empty) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(load_mpt_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(load_mpt_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(load_mpt_leaf) + DUP1 %eq_const(@MPT_NODE_HASH) %jumpi(load_mpt_digest) + PANIC // Invalid node type + +load_mpt_empty: + // TRIE_DATA[0] = 0, and an empty node has type 0, so we can simply return the null pointer. + %stack (node_type, load_value, retdest) -> (retdest, 0) + JUMP + +load_mpt_branch: + // stack: node_type, load_value, retdest + %get_trie_data_size + // stack: node_ptr, node_type, load_value, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, load_value, retdest + // Save the offset of our 16 child pointers so we can write them later. + // Then advance our current trie pointer beyond them, so we can load the + // value and have it placed after our child pointers. + %get_trie_data_size + // stack: children_ptr, node_ptr, load_value, retdest + DUP1 %add_const(17) // Skip over 16 children plus the value pointer + // stack: end_of_branch_ptr, children_ptr, node_ptr, load_value, retdest + DUP1 %set_trie_data_size + // Now the top of the stack points to where the branch node will end and the + // value will begin, if there is a value. But we need to ask the prover if a + // value is present, and point to null if not. + // stack: end_of_branch_ptr, children_ptr, node_ptr, load_value, retdest + PROVER_INPUT(mpt) + // stack: is_value_present, end_of_branch_ptr, children_ptr, node_ptr, load_value, retdest + %jumpi(load_mpt_branch_value_present) + // There is no value present, so value_ptr = null. + %stack (end_of_branch_ptr) -> (0) + // stack: value_ptr, children_ptr, node_ptr, load_value, retdest + %jump(load_mpt_branch_after_load_value) +load_mpt_branch_value_present: + // stack: value_ptr, children_ptr, node_ptr, load_value, retdest + PUSH load_mpt_branch_after_load_value + DUP5 // load_value + JUMP +load_mpt_branch_after_load_value: + // stack: value_ptr, children_ptr, node_ptr, load_value, retdest + SWAP1 + // stack: children_ptr, value_ptr, node_ptr, load_value, retdest + + // Load the 16 children. + %rep 16 + DUP4 // load_value + %load_mpt + // stack: child_ptr, next_child_ptr_ptr, value_ptr, node_ptr, load_value, retdest + DUP2 + // stack: next_child_ptr_ptr, child_ptr, next_child_ptr_ptr, value_ptr, node_ptr, load_value, retdest + %mstore_trie_data + // stack: next_child_ptr_ptr, value_ptr, node_ptr, load_value, retdest + %increment + // stack: next_child_ptr_ptr, value_ptr, node_ptr, load_value, retdest + %endrep + + // stack: value_ptr_ptr, value_ptr, node_ptr, load_value, retdest + %mstore_trie_data + %stack (node_ptr, load_value, retdest) -> (retdest, node_ptr) + JUMP + +load_mpt_extension: + // stack: node_type, load_value, retdest + %get_trie_data_size + // stack: node_ptr, node_type, load_value, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, load_value, retdest + PROVER_INPUT(mpt) // read num_nibbles + %append_to_trie_data + PROVER_INPUT(mpt) // read packed_nibbles + %append_to_trie_data + // stack: node_ptr, load_value, retdest + + %get_trie_data_size + // stack: child_ptr_ptr, node_ptr, load_value, retdest + // Increment trie_data_size, to leave room for child_ptr_ptr, before we load our child. + DUP1 %increment %set_trie_data_size + %stack (child_ptr_ptr, node_ptr, load_value, retdest) + -> (load_value, load_mpt_extension_after_load_mpt, + child_ptr_ptr, retdest, node_ptr) + %jump(load_mpt) +load_mpt_extension_after_load_mpt: + // stack: child_ptr, child_ptr_ptr, retdest, node_ptr + SWAP1 %mstore_trie_data + // stack: retdest, node_ptr + JUMP + +load_mpt_leaf: + // stack: node_type, load_value, retdest + %get_trie_data_size + // stack: node_ptr, node_type, load_value, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, load_value, retdest + PROVER_INPUT(mpt) // read num_nibbles + %append_to_trie_data + PROVER_INPUT(mpt) // read packed_nibbles + %append_to_trie_data + // stack: node_ptr, load_value, retdest + // We save value_ptr_ptr = get_trie_data_size, then increment trie_data_size + // to skip over the slot for value_ptr_ptr. We will write to value_ptr_ptr + // after the load_value call. + %get_trie_data_size + // stack: value_ptr_ptr, node_ptr, load_value, retdest + DUP1 %increment + // stack: value_ptr, value_ptr_ptr, node_ptr, load_value, retdest + DUP1 %set_trie_data_size + // stack: value_ptr, value_ptr_ptr, node_ptr, load_value, retdest + %stack (value_ptr, value_ptr_ptr, node_ptr, load_value, retdest) + -> (load_value, load_mpt_leaf_after_load_value, + value_ptr_ptr, value_ptr, retdest, node_ptr) + JUMP +load_mpt_leaf_after_load_value: + // stack: value_ptr_ptr, value_ptr, retdest, node_ptr + %mstore_trie_data + // stack: retdest, node_ptr + JUMP + +load_mpt_digest: + // stack: node_type, load_value, retdest + %get_trie_data_size + // stack: node_ptr, node_type, load_value, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, load_value, retdest + PROVER_INPUT(mpt) // read digest + %append_to_trie_data + %stack (node_ptr, load_value, retdest) -> (retdest, node_ptr) + JUMP + +// Convenience macro to call load_mpt and return where we left off. +// Pre stack: load_value +// Post stack: node_ptr +%macro load_mpt + %stack (load_value) -> (load_value, %%after) + %jump(load_mpt) +%%after: +%endmacro + +// Convenience macro to call load_mpt and return where we left off. +// Pre stack: (empty) +// Post stack: node_ptr +%macro load_mpt(load_value) + PUSH %%after + PUSH $load_value + %jump(load_mpt) +%%after: +%endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/load/load_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/load/load_trie_specific.asm new file mode 100644 index 00000000..b93b36e4 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/load/load_trie_specific.asm @@ -0,0 +1,40 @@ +global mpt_load_state_trie_value: + // stack: retdest + + // Load and append the nonce and balance. + PROVER_INPUT(mpt) %append_to_trie_data + PROVER_INPUT(mpt) %append_to_trie_data + + // Now increment the trie data size by 2, to leave room for our storage trie + // pointer and code hash fields, before calling load_mpt which will append + // our storage trie data. + %get_trie_data_size + // stack: storage_trie_ptr_ptr, retdest + DUP1 %add_const(2) + // stack: storage_trie_ptr, storage_trie_ptr_ptr, retdest + %set_trie_data_size + // stack: storage_trie_ptr_ptr, retdest + + %load_mpt(mpt_load_storage_trie_value) + // stack: storage_trie_ptr, storage_trie_ptr_ptr, retdest + DUP2 %mstore_trie_data + // stack: storage_trie_ptr_ptr, retdest + %increment + // stack: code_hash_ptr, retdest + PROVER_INPUT(mpt) + // stack: code_hash, code_hash_ptr, retdest + SWAP1 %mstore_trie_data + // stack: retdest + JUMP + +global mpt_load_txn_trie_value: + // stack: retdest + PANIC // TODO + +global mpt_load_receipt_trie_value: + // stack: retdest + PANIC // TODO + +global mpt_load_storage_trie_value: + // stack: retdest + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/read.asm b/evm/src/cpu/kernel/asm/mpt/read.asm index dae97336..d375bedc 100644 --- a/evm/src/cpu/kernel/asm/mpt/read.asm +++ b/evm/src/cpu/kernel/asm/mpt/read.asm @@ -1,6 +1,6 @@ -// Given an address, return a pointer to the associated (length-prefixed) -// account data, which consists of four words (nonce, balance, storage_root, -// code_hash), in the state trie. Returns 0 if the address is not found. +// Given an address, return a pointer to the associated account data, which +// consists of four words (nonce, balance, storage_root, code_hash), in the +// state trie. Returns null if the address is not found. global mpt_read_state_trie: // stack: addr, retdest // The key is the hash of the address. Since KECCAK_GENERAL takes input from @@ -24,7 +24,7 @@ mpt_read_state_trie_after_mstore: // - the key, as a U256 // - the number of nibbles in the key (should start at 64) // -// This function returns a pointer to the length-prefixed leaf, or 0 if the key is not found. +// This function returns a pointer to the value, or 0 if the key is not found. global mpt_read: // stack: node_ptr, num_nibbles, key, retdest DUP1 @@ -77,15 +77,6 @@ mpt_read_branch_end_of_key: %add_const(16) // skip over the 16 child nodes // stack: value_ptr_ptr, retdest %mload_trie_data - // stack: value_len_ptr, retdest - DUP1 %mload_trie_data - // stack: value_len, value_len_ptr, retdest - %jumpi(mpt_read_branch_found_value) - // This branch node contains no value, so return null. - %stack (value_len_ptr, retdest) -> (retdest, 0) -mpt_read_branch_found_value: - // stack: value_len_ptr, retdest - %increment // stack: value_ptr, retdest SWAP1 JUMP diff --git a/evm/src/cpu/kernel/asm/rlp/decode.asm b/evm/src/cpu/kernel/asm/rlp/decode.asm index 182354c4..9842bfbd 100644 --- a/evm/src/cpu/kernel/asm/rlp/decode.asm +++ b/evm/src/cpu/kernel/asm/rlp/decode.asm @@ -14,7 +14,7 @@ global decode_rlp_string_len: // stack: pos, retdest DUP1 - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) // stack: first_byte, pos, retdest DUP1 %gt_const(0xb7) @@ -89,7 +89,7 @@ global decode_rlp_scalar: global decode_rlp_list_len: // stack: pos, retdest DUP1 - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) // stack: first_byte, pos, retdest SWAP1 %increment // increment pos @@ -151,7 +151,7 @@ decode_int_given_len_loop: // stack: acc << 8, pos, end_pos, retdest DUP2 // stack: pos, acc << 8, pos, end_pos, retdest - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) // stack: byte, acc << 8, pos, end_pos, retdest ADD // stack: acc', pos, end_pos, retdest diff --git a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm index 5d8cbd17..2d71e65a 100644 --- a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm +++ b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm @@ -23,7 +23,7 @@ read_rlp_to_memory_loop: // stack: byte, pos, len, retdest DUP2 // stack: pos, byte, pos, len, retdest - %mstore_current(@SEGMENT_RLP_RAW) + %mstore_kernel(@SEGMENT_RLP_RAW) // stack: pos, len, retdest %increment // stack: pos', len, retdest diff --git a/evm/src/cpu/kernel/asm/transactions/router.asm b/evm/src/cpu/kernel/asm/transactions/router.asm index 974fed99..3f4ebe37 100644 --- a/evm/src/cpu/kernel/asm/transactions/router.asm +++ b/evm/src/cpu/kernel/asm/transactions/router.asm @@ -18,14 +18,14 @@ read_txn_from_memory: // first byte >= 0xc0, so there is no overlap. PUSH 0 - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) %eq_const(1) // stack: first_byte == 1, retdest %jumpi(process_type_1_txn) // stack: retdest PUSH 0 - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) %eq_const(2) // stack: first_byte == 2, retdest %jumpi(process_type_2_txn) diff --git a/evm/src/cpu/kernel/constants/context_metadata.rs b/evm/src/cpu/kernel/constants/context_metadata.rs index 17945d98..a2c460fc 100644 --- a/evm/src/cpu/kernel/constants/context_metadata.rs +++ b/evm/src/cpu/kernel/constants/context_metadata.rs @@ -21,7 +21,7 @@ pub(crate) enum ContextMetadata { /// prohibited. Static = 8, /// Pointer to the initial version of the state trie, at the creation of this context. Used when - /// we need to revert a context. See also `StorageTrieCheckpointPointers`. + /// we need to revert a context. StateTrieCheckpointPointer = 9, } diff --git a/evm/src/cpu/kernel/constants/global_metadata.rs b/evm/src/cpu/kernel/constants/global_metadata.rs index 295cdfd5..1fa62efe 100644 --- a/evm/src/cpu/kernel/constants/global_metadata.rs +++ b/evm/src/cpu/kernel/constants/global_metadata.rs @@ -18,9 +18,6 @@ pub(crate) enum GlobalMetadata { TransactionTrieRoot = 5, /// A pointer to the root of the receipt trie within the `TrieData` buffer. ReceiptTrieRoot = 6, - /// The number of storage tries involved in these transactions. I.e. the number of values in - /// `StorageTrieAddresses`, `StorageTriePointers` and `StorageTrieCheckpointPointers`. - NumStorageTries = 7, // The root digests of each Merkle trie before these transactions. StateTrieRootDigestBefore = 8, @@ -38,7 +35,7 @@ pub(crate) enum GlobalMetadata { } impl GlobalMetadata { - pub(crate) const COUNT: usize = 15; + pub(crate) const COUNT: usize = 14; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -49,7 +46,6 @@ impl GlobalMetadata { Self::StateTrieRoot, Self::TransactionTrieRoot, Self::ReceiptTrieRoot, - Self::NumStorageTries, Self::StateTrieRootDigestBefore, Self::TransactionTrieRootDigestBefore, Self::ReceiptTrieRootDigestBefore, @@ -70,7 +66,6 @@ impl GlobalMetadata { GlobalMetadata::StateTrieRoot => "GLOBAL_METADATA_STATE_TRIE_ROOT", GlobalMetadata::TransactionTrieRoot => "GLOBAL_METADATA_TXN_TRIE_ROOT", GlobalMetadata::ReceiptTrieRoot => "GLOBAL_METADATA_RECEIPT_TRIE_ROOT", - GlobalMetadata::NumStorageTries => "GLOBAL_METADATA_NUM_STORAGE_TRIES", GlobalMetadata::StateTrieRootDigestBefore => "GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE", GlobalMetadata::TransactionTrieRootDigestBefore => { "GLOBAL_METADATA_TXN_TRIE_DIGEST_BEFORE" diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs index de519797..19c38e91 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -1,12 +1,12 @@ use anyhow::Result; use eth_trie_utils::partial_trie::PartialTrie; -use ethereum_types::{BigEndianHash, H256, U256}; +use ethereum_types::{BigEndianHash, H256}; use super::nibbles; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::extension_to_leaf; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; +use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1_rlp, test_account_2_rlp}; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::TrieInputs; // TODO: Test with short leaf. Might need to be a storage trie. @@ -24,73 +24,69 @@ fn mpt_hash_empty() -> Result<()> { } #[test] -fn mpt_hash_leaf() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), +fn mpt_hash_empty_branch() -> Result<()> { + let children = std::array::from_fn(|_| PartialTrie::Empty.into()); + let state_trie = PartialTrie::Branch { + children, + value: vec![], }; - let account_rlp = rlp::encode(&account); - - let state_trie = PartialTrie::Leaf { - nibbles: nibbles(0xABC), - value: account_rlp.to_vec(), - }; - let trie_inputs = TrieInputs { state_trie, transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], }; + test_state_trie(trie_inputs) +} +#[test] +fn mpt_hash_hash() -> Result<()> { + let hash = H256::random(); + let trie_inputs = TrieInputs { + state_trie: PartialTrie::Hash(hash), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + test_state_trie(trie_inputs) +} + +#[test] +fn mpt_hash_leaf() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: nibbles(0xABC), + value: test_account_1_rlp(), + }; + let trie_inputs = TrieInputs { + state_trie, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; test_state_trie(trie_inputs) } #[test] fn mpt_hash_extension_to_leaf() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - }; - let account_rlp = rlp::encode(&account); - - let state_trie = extension_to_leaf(account_rlp.to_vec()); - + let state_trie = extension_to_leaf(test_account_1_rlp()); let trie_inputs = TrieInputs { state_trie, transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], }; - test_state_trie(trie_inputs) } #[test] fn mpt_hash_branch_to_leaf() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - }; - let account_rlp = rlp::encode(&account); - let leaf = PartialTrie::Leaf { nibbles: nibbles(0xABC), - value: account_rlp.to_vec(), + value: test_account_2_rlp(), } .into(); let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); - children[5] = PartialTrie::Branch { - children: children.clone(), - value: vec![], - } - .into(); children[3] = leaf; let state_trie = PartialTrie::Branch { children, diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 469ad1e4..3a52948d 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -6,13 +6,13 @@ use super::nibbles; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::{test_account_1_rlp, test_account_2_rlp}; +use crate::cpu::kernel::tests::mpt::{test_account_1_rlp, test_account_2}; use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; use crate::generation::TrieInputs; #[test] fn mpt_insert_empty() -> Result<()> { - test_state_trie(Default::default(), nibbles(0xABC), test_account_2_rlp()) + test_state_trie(Default::default(), nibbles(0xABC), test_account_2()) } #[test] @@ -22,7 +22,7 @@ fn mpt_insert_leaf_identical_keys() -> Result<()> { nibbles: key, value: test_account_1_rlp(), }; - test_state_trie(state_trie, key, test_account_2_rlp()) + test_state_trie(state_trie, key, test_account_2()) } #[test] @@ -31,7 +31,7 @@ fn mpt_insert_leaf_nonoverlapping_keys() -> Result<()> { nibbles: nibbles(0xABC), value: test_account_1_rlp(), }; - test_state_trie(state_trie, nibbles(0x123), test_account_2_rlp()) + test_state_trie(state_trie, nibbles(0x123), test_account_2()) } #[test] @@ -40,7 +40,7 @@ fn mpt_insert_leaf_overlapping_keys() -> Result<()> { nibbles: nibbles(0xABC), value: test_account_1_rlp(), }; - test_state_trie(state_trie, nibbles(0xADE), test_account_2_rlp()) + test_state_trie(state_trie, nibbles(0xADE), test_account_2()) } #[test] @@ -49,7 +49,7 @@ fn mpt_insert_leaf_insert_key_extends_leaf_key() -> Result<()> { nibbles: nibbles(0xABC), value: test_account_1_rlp(), }; - test_state_trie(state_trie, nibbles(0xABCDE), test_account_2_rlp()) + test_state_trie(state_trie, nibbles(0xABCDE), test_account_2()) } #[test] @@ -58,7 +58,7 @@ fn mpt_insert_leaf_leaf_key_extends_insert_key() -> Result<()> { nibbles: nibbles(0xABCDE), value: test_account_1_rlp(), }; - test_state_trie(state_trie, nibbles(0xABC), test_account_2_rlp()) + test_state_trie(state_trie, nibbles(0xABC), test_account_2()) } #[test] @@ -69,10 +69,13 @@ fn mpt_insert_branch_replacing_empty_child() -> Result<()> { value: vec![], }; - test_state_trie(state_trie, nibbles(0xABC), test_account_2_rlp()) + test_state_trie(state_trie, nibbles(0xABC), test_account_2()) } #[test] +// TODO: Not a valid test because branches state trie cannot have branch values. +// We should change it to use a different trie. +#[ignore] fn mpt_insert_extension_nonoverlapping_keys() -> Result<()> { // Existing keys are 0xABC, 0xABCDEF; inserted key is 0x12345. let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); @@ -89,10 +92,13 @@ fn mpt_insert_extension_nonoverlapping_keys() -> Result<()> { } .into(), }; - test_state_trie(state_trie, nibbles(0x12345), test_account_2_rlp()) + test_state_trie(state_trie, nibbles(0x12345), test_account_2()) } #[test] +// TODO: Not a valid test because branches state trie cannot have branch values. +// We should change it to use a different trie. +#[ignore] fn mpt_insert_extension_insert_key_extends_node_key() -> Result<()> { // Existing keys are 0xA, 0xABCD; inserted key is 0xABCDEF. let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); @@ -109,7 +115,7 @@ fn mpt_insert_extension_insert_key_extends_node_key() -> Result<()> { } .into(), }; - test_state_trie(state_trie, nibbles(0xABCDEF), test_account_2_rlp()) + test_state_trie(state_trie, nibbles(0xABCDEF), test_account_2()) } #[test] @@ -126,10 +132,14 @@ fn mpt_insert_branch_to_leaf_same_key() -> Result<()> { value: vec![], }; - test_state_trie(state_trie, nibbles(0xABCD), test_account_2_rlp()) + test_state_trie(state_trie, nibbles(0xABCD), test_account_2()) } -fn test_state_trie(state_trie: PartialTrie, k: Nibbles, v: Vec) -> Result<()> { +/// Note: The account's storage_root is ignored, as we can't insert a new storage_root without the +/// accompanying trie data. An empty trie's storage_root is used instead. +fn test_state_trie(state_trie: PartialTrie, k: Nibbles, mut account: AccountRlp) -> Result<()> { + account.storage_root = PartialTrie::Empty.calc_hash(); + let trie_inputs = TrieInputs { state_trie: state_trie.clone(), transactions_trie: Default::default(), @@ -155,10 +165,13 @@ fn test_state_trie(state_trie: PartialTrie, k: Nibbles, v: Vec) -> Result<() trie_data.push(0.into()); } let value_ptr = trie_data.len(); - let account: AccountRlp = rlp::decode(&v).expect("Decoding failed"); - let account_data = account.to_vec(); - trie_data.push(account_data.len().into()); - trie_data.extend(account_data); + trie_data.push(account.nonce); + trie_data.push(account.balance); + // In memory, storage_root gets interpreted as a pointer to a storage trie, + // so we have to ensure the pointer is valid. It's easiest to set it to 0, + // which works as an empty node, since trie_data[0] = 0 = MPT_TYPE_EMPTY. + trie_data.push(H256::zero().into_uint()); + trie_data.push(account.code_hash.into_uint()); let trie_data_len = trie_data.len().into(); interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); interpreter.push(0xDEADBEEFu32.into()); @@ -187,7 +200,7 @@ fn test_state_trie(state_trie: PartialTrie, k: Nibbles, v: Vec) -> Result<() ); let hash = H256::from_uint(&interpreter.stack()[0]); - let updated_trie = state_trie.insert(k, v); + let updated_trie = state_trie.insert(k, rlp::encode(&account).to_vec()); let expected_state_trie_hash = updated_trie.calc_hash(); assert_eq!(hash, expected_state_trie_hash); diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index 0572458d..78129a1c 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -1,6 +1,6 @@ use anyhow::Result; use eth_trie_utils::partial_trie::PartialTrie; -use ethereum_types::{BigEndianHash, U256}; +use ethereum_types::{BigEndianHash, H256, U256}; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::constants::trie_type::PartialTrieType; @@ -42,11 +42,6 @@ fn load_all_mpts_empty() -> Result<()> { 0.into() ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), - trie_inputs.storage_tries.len().into() - ); - Ok(()) } @@ -79,11 +74,13 @@ fn load_all_mpts_leaf() -> Result<()> { 3.into(), 0xABC.into(), 5.into(), // value ptr - 4.into(), // value length test_account_1().nonce, test_account_1().balance, - test_account_1().storage_root.into_uint(), + 9.into(), // pointer to storage trie root test_account_1().code_hash.into_uint(), + // These last two elements encode the storage trie, which is a hash node. + (PartialTrieType::Hash as u32).into(), + test_account_1().storage_root.into_uint(), ] ); @@ -96,9 +93,40 @@ fn load_all_mpts_leaf() -> Result<()> { 0.into() ); + Ok(()) +} + +#[test] +fn load_all_mpts_hash() -> Result<()> { + let hash = H256::random(); + let trie_inputs = TrieInputs { + state_trie: PartialTrie::Hash(hash), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + let type_hash = U256::from(PartialTrieType::Hash as u32); assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), - trie_inputs.storage_tries.len().into() + interpreter.get_trie_data(), + vec![0.into(), type_hash, hash.into_uint(),] + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), + 0.into() ); Ok(()) @@ -161,11 +189,6 @@ fn load_all_mpts_empty_branch() -> Result<()> { 0.into() ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), - trie_inputs.storage_tries.len().into() - ); - Ok(()) } @@ -200,18 +223,15 @@ fn load_all_mpts_ext_to_leaf() -> Result<()> { 3.into(), // 3 nibbles 0xDEF.into(), // key part 9.into(), // value pointer - 4.into(), // value length test_account_1().nonce, test_account_1().balance, - test_account_1().storage_root.into_uint(), + 13.into(), // pointer to storage trie root test_account_1().code_hash.into_uint(), + // These last two elements encode the storage trie, which is a hash node. + (PartialTrieType::Hash as u32).into(), + test_account_1().storage_root.into_uint(), ] ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), - trie_inputs.storage_tries.len().into() - ); - Ok(()) } diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs index 06d89ff6..d8808e24 100644 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -1,25 +1,17 @@ use anyhow::Result; -use ethereum_types::{BigEndianHash, H256, U256}; +use ethereum_types::BigEndianHash; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::extension_to_leaf; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; +use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1, test_account_1_rlp}; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::TrieInputs; #[test] fn mpt_read() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - }; - let account_rlp = rlp::encode(&account); - let trie_inputs = TrieInputs { - state_trie: extension_to_leaf(account_rlp.to_vec()), + state_trie: extension_to_leaf(test_account_1_rlp()), transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], @@ -44,12 +36,12 @@ fn mpt_read() -> Result<()> { assert_eq!(interpreter.stack().len(), 1); let result_ptr = interpreter.stack()[0].as_usize(); - let result = &interpreter.get_trie_data()[result_ptr..][..5]; - assert_eq!(result[0], 4.into()); - assert_eq!(result[1], account.nonce); - assert_eq!(result[2], account.balance); - assert_eq!(result[3], account.storage_root.into_uint()); - assert_eq!(result[4], account.code_hash.into_uint()); + let result = &interpreter.get_trie_data()[result_ptr..][..4]; + assert_eq!(result[0], test_account_1().nonce); + assert_eq!(result[1], test_account_1().balance); + // result[2] is the storage root pointer. We won't check that it matches a + // particular address, since that seems like over-specifying. + assert_eq!(result[3], test_account_1().code_hash.into_uint()); Ok(()) } diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index e35364c6..8ceb195a 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -1,5 +1,8 @@ -use eth_trie_utils::partial_trie::PartialTrie; +use std::collections::HashMap; + +use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; use ethereum_types::{BigEndianHash, H256, U256}; +use keccak_hash::keccak; use rlp_derive::{RlpDecodable, RlpEncodable}; use crate::cpu::kernel::constants::trie_type::PartialTrieType; @@ -13,17 +16,6 @@ pub(crate) struct AccountRlp { pub(crate) code_hash: H256, } -impl AccountRlp { - pub(crate) fn to_vec(&self) -> Vec { - vec![ - self.nonce, - self.balance, - self.storage_root.into_uint(), - self.code_hash.into_uint(), - ] - } -} - pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec { let mut inputs = all_mpt_prover_inputs(trie_inputs); inputs.reverse(); @@ -34,10 +26,18 @@ pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec Vec { let mut prover_inputs = vec![]; - mpt_prover_inputs(&trie_inputs.state_trie, &mut prover_inputs, &|rlp| { - let account: AccountRlp = rlp::decode(rlp).expect("Decoding failed"); - account.to_vec() - }); + let storage_tries_by_state_key = trie_inputs + .storage_tries + .iter() + .map(|(address, storage_trie)| (Nibbles::from(keccak(address)), storage_trie)) + .collect(); + + mpt_prover_inputs_state_trie( + &trie_inputs.state_trie, + empty_nibbles(), + &mut prover_inputs, + &storage_tries_by_state_key, + ); mpt_prover_inputs(&trie_inputs.transactions_trie, &mut prover_inputs, &|rlp| { rlp::decode_list(rlp) @@ -48,14 +48,6 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { vec![] }); - prover_inputs.push(trie_inputs.storage_tries.len().into()); - for (addr, storage_trie) in &trie_inputs.storage_tries { - prover_inputs.push(addr.0.as_ref().into()); - mpt_prover_inputs(storage_trie, &mut prover_inputs, &|leaf_be| { - vec![U256::from_big_endian(leaf_be)] - }); - } - prover_inputs } @@ -66,7 +58,7 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { pub(crate) fn mpt_prover_inputs( trie: &PartialTrie, prover_inputs: &mut Vec, - parse_leaf: &F, + parse_value: &F, ) where F: Fn(&[u8]) -> Vec, { @@ -76,28 +68,108 @@ pub(crate) fn mpt_prover_inputs( PartialTrie::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), PartialTrie::Branch { children, value } => { if value.is_empty() { - // There's no value, so length=0. + // There's no value, so value_len = 0. prover_inputs.push(U256::zero()); } else { - let leaf = parse_leaf(value); - prover_inputs.push(leaf.len().into()); - prover_inputs.extend(leaf); + let parsed_value = parse_value(value); + prover_inputs.push(parsed_value.len().into()); + prover_inputs.extend(parsed_value); } for child in children { - mpt_prover_inputs(child, prover_inputs, parse_leaf); + mpt_prover_inputs(child, prover_inputs, parse_value); } } PartialTrie::Extension { nibbles, child } => { prover_inputs.push(nibbles.count.into()); prover_inputs.push(nibbles.packed); - mpt_prover_inputs(child, prover_inputs, parse_leaf); + mpt_prover_inputs(child, prover_inputs, parse_value); } PartialTrie::Leaf { nibbles, value } => { prover_inputs.push(nibbles.count.into()); prover_inputs.push(nibbles.packed); - let leaf = parse_leaf(value); - prover_inputs.push(leaf.len().into()); + let leaf = parse_value(value); prover_inputs.extend(leaf); } } } + +/// Like `mpt_prover_inputs`, but for the state trie, which is a bit unique since each value +/// leads to a storage trie which we recursively traverse. +pub(crate) fn mpt_prover_inputs_state_trie( + trie: &PartialTrie, + key: Nibbles, + prover_inputs: &mut Vec, + storage_tries_by_state_key: &HashMap, +) { + prover_inputs.push((PartialTrieType::of(trie) as u32).into()); + match trie { + PartialTrie::Empty => {} + PartialTrie::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), + PartialTrie::Branch { children, value } => { + assert!(value.is_empty(), "State trie should not have branch values"); + // There's no value, so value_len = 0. + prover_inputs.push(U256::zero()); + + for (i, child) in children.iter().enumerate() { + let extended_key = key.merge(&Nibbles { + count: 1, + packed: i.into(), + }); + mpt_prover_inputs_state_trie( + child, + extended_key, + prover_inputs, + storage_tries_by_state_key, + ); + } + } + PartialTrie::Extension { nibbles, child } => { + prover_inputs.push(nibbles.count.into()); + prover_inputs.push(nibbles.packed); + let extended_key = key.merge(nibbles); + mpt_prover_inputs_state_trie( + child, + extended_key, + prover_inputs, + storage_tries_by_state_key, + ); + } + PartialTrie::Leaf { nibbles, value } => { + let account: AccountRlp = rlp::decode(value).expect("Decoding failed"); + let AccountRlp { + nonce, + balance, + storage_root, + code_hash, + } = account; + + let storage_hash_only = PartialTrie::Hash(storage_root); + let storage_trie: &PartialTrie = storage_tries_by_state_key + .get(&key) + .copied() + .unwrap_or(&storage_hash_only); + + assert_eq!(storage_trie.calc_hash(), storage_root, + "In TrieInputs, an account's storage_root didn't match the associated storage trie hash"); + + prover_inputs.push(nibbles.count.into()); + prover_inputs.push(nibbles.packed); + prover_inputs.push(nonce); + prover_inputs.push(balance); + mpt_prover_inputs(storage_trie, prover_inputs, &parse_storage_value); + prover_inputs.push(code_hash.into_uint()); + } + } +} + +fn parse_storage_value(value_rlp: &[u8]) -> Vec { + let value: U256 = rlp::decode(value_rlp).expect("Decoding failed"); + vec![value] +} + +fn empty_nibbles() -> Nibbles { + Nibbles { + count: 0, + packed: U256::zero(), + } +} diff --git a/evm/src/memory/segments.rs b/evm/src/memory/segments.rs index b6254900..b8ba904f 100644 --- a/evm/src/memory/segments.rs +++ b/evm/src/memory/segments.rs @@ -29,23 +29,14 @@ pub(crate) enum Segment { /// Contains all trie data. Tries are stored as immutable, copy-on-write trees, so this is an /// append-only buffer. It is owned by the kernel, so it only lives on context 0. TrieData = 12, - /// The account address associated with the `i`th storage trie. Only lives on context 0. - StorageTrieAddresses = 13, - /// A pointer to the `i`th storage trie within the `TrieData` buffer. Only lives on context 0. - StorageTriePointers = 14, - /// Like `StorageTriePointers`, except that these pointers correspond to the version of each - /// trie at the creation of a given context. This lets us easily revert a context by replacing - /// `StorageTriePointers` with `StorageTrieCheckpointPointers`. - /// See also `StateTrieCheckpointPointer`. - StorageTrieCheckpointPointers = 15, /// A buffer used to store the encodings of a branch node's children. - TrieEncodedChild = 16, + TrieEncodedChild = 13, /// A buffer used to store the lengths of the encodings of a branch node's children. - TrieEncodedChildLen = 17, + TrieEncodedChildLen = 14, } impl Segment { - pub(crate) const COUNT: usize = 18; + pub(crate) const COUNT: usize = 15; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -62,9 +53,6 @@ impl Segment { Self::TxnData, Self::RlpRaw, Self::TrieData, - Self::StorageTrieAddresses, - Self::StorageTriePointers, - Self::StorageTrieCheckpointPointers, Self::TrieEncodedChild, Self::TrieEncodedChildLen, ] @@ -86,9 +74,6 @@ impl Segment { Segment::TxnData => "SEGMENT_TXN_DATA", Segment::RlpRaw => "SEGMENT_RLP_RAW", Segment::TrieData => "SEGMENT_TRIE_DATA", - Segment::StorageTrieAddresses => "SEGMENT_STORAGE_TRIE_ADDRS", - Segment::StorageTriePointers => "SEGMENT_STORAGE_TRIE_PTRS", - Segment::StorageTrieCheckpointPointers => "SEGMENT_STORAGE_TRIE_CHECKPOINT_PTRS", Segment::TrieEncodedChild => "SEGMENT_TRIE_ENCODED_CHILD", Segment::TrieEncodedChildLen => "SEGMENT_TRIE_ENCODED_CHILD_LEN", } @@ -110,9 +95,6 @@ impl Segment { Segment::TxnData => 256, Segment::RlpRaw => 8, Segment::TrieData => 256, - Segment::StorageTrieAddresses => 160, - Segment::StorageTriePointers => 32, - Segment::StorageTrieCheckpointPointers => 32, Segment::TrieEncodedChild => 256, Segment::TrieEncodedChildLen => 6, } diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index a16063c4..bc64bb57 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -146,7 +146,7 @@ impl, C: GenericConfig, const D: usize> } // Verify the CTL checks. - let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits); + let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups::( cross_table_lookups, pis.map(|p| p.ctl_zs_last), @@ -221,7 +221,7 @@ impl, C: GenericConfig, const D: usize> } // Verify the CTL checks. - let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits); + let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups_circuit::( builder, cross_table_lookups, @@ -586,6 +586,7 @@ where VerifierCircuitTarget { constants_sigmas_cap: builder .constant_merkle_cap(&verifier_data.verifier_only.constants_sigmas_cap), + circuit_digest: builder.add_virtual_hash(), } }); RecursiveAllProofTargetWithData { diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index 8073c9dc..f4379e7a 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -117,11 +117,9 @@ where let inner_data = VerifierCircuitTarget { constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), + circuit_digest: builder.add_virtual_hash(), }; - pw.set_cap_target( - &inner_data.constants_sigmas_cap, - &inner_vd.constants_sigmas_cap, - ); + pw.set_verifier_data_target(&inner_data, inner_vd); builder.verify_proof(pt, &inner_data, inner_cd); builder.print_gate_counts(0); @@ -151,6 +149,7 @@ where /// Test serialization and print some size info. fn test_serialization, C: GenericConfig, const D: usize>( proof: &ProofWithPublicInputs, + vd: &VerifierOnlyCircuitData, cd: &CommonCircuitData, ) -> Result<()> where @@ -162,8 +161,10 @@ where assert_eq!(proof, &proof_from_bytes); let now = std::time::Instant::now(); - let compressed_proof = proof.clone().compress(cd)?; - let decompressed_compressed_proof = compressed_proof.clone().decompress(cd)?; + let compressed_proof = proof.clone().compress(&vd.circuit_digest, cd)?; + let decompressed_compressed_proof = compressed_proof + .clone() + .decompress(&vd.circuit_digest, cd)?; info!("{:.4}s to compress proof", now.elapsed().as_secs_f64()); assert_eq!(proof, &decompressed_compressed_proof); @@ -190,7 +191,7 @@ fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { info!( "Initial proof degree {} = 2^{}", cd.degree(), - cd.degree_bits + cd.degree_bits() ); // Recursively verify the proof @@ -199,19 +200,19 @@ fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { info!( "Single recursion proof degree {} = 2^{}", cd.degree(), - cd.degree_bits + cd.degree_bits() ); // Add a second layer of recursion to shrink the proof size further let outer = recursive_proof::(&middle, config, None)?; - let (proof, _, cd) = &outer; + let (proof, vd, cd) = &outer; info!( "Double recursion proof degree {} = 2^{}", cd.degree(), - cd.degree_bits + cd.degree_bits() ); - test_serialization(proof, cd)?; + test_serialization(proof, vd, cd)?; Ok(()) } diff --git a/plonky2/src/fri/mod.rs b/plonky2/src/fri/mod.rs index 87c4c2aa..9c44b53b 100644 --- a/plonky2/src/fri/mod.rs +++ b/plonky2/src/fri/mod.rs @@ -54,7 +54,7 @@ impl FriConfig { /// FRI parameters, including generated parameters which are specific to an instance size, in /// contrast to `FriConfig` which is user-specified and independent of instance size. -#[derive(Debug)] +#[derive(Debug, Eq, PartialEq)] pub struct FriParams { /// User-specified FRI configuration. pub config: FriConfig, diff --git a/plonky2/src/gates/selectors.rs b/plonky2/src/gates/selectors.rs index fff5d967..6cea86a7 100644 --- a/plonky2/src/gates/selectors.rs +++ b/plonky2/src/gates/selectors.rs @@ -9,7 +9,7 @@ use crate::hash::hash_types::RichField; /// Placeholder value to indicate that a gate doesn't use a selector polynomial. pub(crate) const UNUSED_SELECTOR: usize = u32::MAX as usize; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub(crate) struct SelectorsInfo { pub(crate) selector_indices: Vec, pub(crate) groups: Vec>, diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index e7f21241..9a3cb662 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -13,6 +13,7 @@ use crate::hash::merkle_tree::MerkleCap; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; +use crate::plonk::circuit_data::{VerifierCircuitTarget, VerifierOnlyCircuitData}; use crate::plonk::config::{AlgebraicHasher, GenericConfig}; use crate::plonk::proof::{Proof, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget}; @@ -197,6 +198,18 @@ pub trait Witness { } } + fn set_verifier_data_target, const D: usize>( + &mut self, + vdt: &VerifierCircuitTarget, + vd: &VerifierOnlyCircuitData, + ) where + F: RichField + Extendable, + C::Hasher: AlgebraicHasher, + { + self.set_cap_target(&vdt.constants_sigmas_cap, &vd.constants_sigmas_cap); + self.set_hash_target(vdt.circuit_digest, vd.circuit_digest); + } + fn set_wire(&mut self, wire: Wire, value: F) { self.set_target(Target::Wire(wire), value) } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index bfa012da..83587f2e 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -271,6 +271,10 @@ impl, const D: usize> CircuitBuilder { ); } + pub fn add_gate_to_gate_set(&mut self, gate: GateRef) { + self.gates.insert(gate); + } + pub fn connect_extension(&mut self, src: ExtensionTarget, dst: ExtensionTarget) { for i in 0..D { self.connect(src.0[i], dst.0[i]); @@ -751,11 +755,6 @@ impl, const D: usize> CircuitBuilder { Some(&fft_root_table), ); - let constants_sigmas_cap = constants_sigmas_commitment.merkle_tree.cap.clone(); - let verifier_only = VerifierOnlyCircuitData { - constants_sigmas_cap: constants_sigmas_cap.clone(), - }; - // Map between gates where not all generators are used and the gate's number of used generators. let incomplete_gates = self .current_slots @@ -796,17 +795,6 @@ impl, const D: usize> CircuitBuilder { indices.shrink_to_fit(); } - let prover_only = ProverOnlyCircuitData { - generators: self.generators, - generator_indices_by_watches, - constants_sigmas_commitment, - sigmas: transpose_poly_values(sigma_vecs), - subgroup, - public_inputs: self.public_inputs, - representative_map: forest.parents, - fft_root_table: Some(fft_root_table), - }; - let num_gate_constraints = gates .iter() .map(|gate| gate.0.num_constraints()) @@ -816,6 +804,7 @@ impl, const D: usize> CircuitBuilder { let num_partial_products = num_partial_products(self.config.num_routed_wires, quotient_degree_factor); + let constants_sigmas_cap = constants_sigmas_commitment.merkle_tree.cap.clone(); // TODO: This should also include an encoding of gate constraints. let circuit_digest_parts = [ constants_sigmas_cap.flatten(), @@ -829,7 +818,6 @@ impl, const D: usize> CircuitBuilder { let common = CommonCircuitData { config: self.config, fri_params, - degree_bits, gates, selectors_info, quotient_degree_factor, @@ -838,6 +826,22 @@ impl, const D: usize> CircuitBuilder { num_public_inputs, k_is, num_partial_products, + }; + + let prover_only = ProverOnlyCircuitData { + generators: self.generators, + generator_indices_by_watches, + constants_sigmas_commitment, + sigmas: transpose_poly_values(sigma_vecs), + subgroup, + public_inputs: self.public_inputs, + representative_map: forest.parents, + fft_root_table: Some(fft_root_table), + circuit_digest, + }; + + let verifier_only = VerifierOnlyCircuitData { + constants_sigmas_cap, circuit_digest, }; diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 7e69ef31..5143e730 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -15,7 +15,7 @@ use crate::fri::structure::{ use crate::fri::{FriConfig, FriParams}; use crate::gates::gate::GateRef; use crate::gates::selectors::SelectorsInfo; -use crate::hash::hash_types::{MerkleCapTarget, RichField}; +use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::ext_target::ExtensionTarget; use crate::iop::generator::WitnessGenerator; @@ -29,7 +29,7 @@ use crate::plonk::prover::prove; use crate::plonk::verifier::verify; use crate::util::timing::TimingTree; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct CircuitConfig { pub num_wires: usize, pub num_routed_wires: usize, @@ -141,6 +141,23 @@ impl, C: GenericConfig, const D: usize> compressed_proof_with_pis.verify(&self.verifier_only, &self.common) } + pub fn compress( + &self, + proof: ProofWithPublicInputs, + ) -> Result> { + proof.compress(&self.verifier_only.circuit_digest, &self.common) + } + + pub fn decompress( + &self, + proof: CompressedProofWithPublicInputs, + ) -> Result> + where + [(); C::Hasher::HASH_SIZE]:, + { + proof.decompress(&self.verifier_only.circuit_digest, &self.common) + } + pub fn verifier_data(self) -> VerifierCircuitData { let CircuitData { verifier_only, @@ -253,6 +270,9 @@ pub struct ProverOnlyCircuitData< pub representative_map: Vec, /// Pre-computed roots for faster FFT. pub fft_root_table: Option>, + /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to + /// seed Fiat-Shamir. + pub circuit_digest: <>::Hasher as Hasher>::Hash, } /// Circuit data required by the verifier, but not the prover. @@ -260,10 +280,13 @@ pub struct ProverOnlyCircuitData< pub struct VerifierOnlyCircuitData, const D: usize> { /// A commitment to each constant polynomial and each permutation polynomial. pub constants_sigmas_cap: MerkleCap, + /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to + /// seed Fiat-Shamir. + pub circuit_digest: <>::Hasher as Hasher>::Hash, } /// Circuit data required by both the prover and the verifier. -#[derive(Debug)] +#[derive(Debug, Eq, PartialEq)] pub struct CommonCircuitData< F: RichField + Extendable, C: GenericConfig, @@ -273,10 +296,8 @@ pub struct CommonCircuitData< pub(crate) fri_params: FriParams, - pub degree_bits: usize, - /// The types of gates used in this circuit, along with their prefixes. - pub(crate) gates: Vec>, + pub(crate) gates: Vec>, /// Information on the circuit's selector polynomials. pub(crate) selectors_info: SelectorsInfo, @@ -293,29 +314,29 @@ pub struct CommonCircuitData< pub(crate) num_public_inputs: usize, /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. - pub(crate) k_is: Vec, + pub(crate) k_is: Vec, /// The number of partial products needed to compute the `Z` polynomials. pub(crate) num_partial_products: usize, - - /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to - /// seed Fiat-Shamir. - pub(crate) circuit_digest: <>::Hasher as Hasher>::Hash, } impl, C: GenericConfig, const D: usize> CommonCircuitData { + pub const fn degree_bits(&self) -> usize { + self.fri_params.degree_bits + } + pub fn degree(&self) -> usize { - 1 << self.degree_bits + 1 << self.degree_bits() } pub fn lde_size(&self) -> usize { - 1 << (self.degree_bits + self.config.fri_config.rate_bits) + self.fri_params.lde_size() } pub fn lde_generator(&self) -> F { - F::primitive_root_of_unity(self.degree_bits + self.config.fri_config.rate_bits) + F::primitive_root_of_unity(self.degree_bits() + self.config.fri_config.rate_bits) } pub fn constraint_degree(&self) -> usize { @@ -358,7 +379,7 @@ impl, C: GenericConfig, const D: usize> }; // The Z polynomials are also opened at g * zeta. - let g = F::Extension::primitive_root_of_unity(self.degree_bits); + let g = F::Extension::primitive_root_of_unity(self.degree_bits()); let zeta_next = g * zeta; let zeta_next_batch = FriBatchInfo { point: zeta_next, @@ -384,7 +405,7 @@ impl, C: GenericConfig, const D: usize> }; // The Z polynomials are also opened at g * zeta. - let g = F::primitive_root_of_unity(self.degree_bits); + let g = F::primitive_root_of_unity(self.degree_bits()); let zeta_next = builder.mul_const_extension(g, zeta); let zeta_next_batch = FriBatchInfoTarget { point: zeta_next, @@ -476,4 +497,7 @@ impl, C: GenericConfig, const D: usize> pub struct VerifierCircuitTarget { /// A commitment to each constant polynomial and each permutation polynomial. pub constants_sigmas_cap: MerkleCapTarget, + /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to + /// seed Fiat-Shamir. + pub circuit_digest: HashOutTarget, } diff --git a/plonky2/src/plonk/conditional_recursive_verifier.rs b/plonky2/src/plonk/conditional_recursive_verifier.rs index 8d5b2e88..2c406904 100644 --- a/plonky2/src/plonk/conditional_recursive_verifier.rs +++ b/plonky2/src/plonk/conditional_recursive_verifier.rs @@ -1,29 +1,84 @@ +use anyhow::{ensure, Result}; use itertools::Itertools; use plonky2_field::extension::Extendable; +use plonky2_util::ceil_div_usize; use crate::fri::proof::{ FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, FriQueryStepTarget, }; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; +use crate::gates::noop::NoopGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::circuit_data::{CommonCircuitData, VerifierCircuitTarget}; -use crate::plonk::config::{AlgebraicHasher, GenericConfig}; -use crate::plonk::proof::{OpeningSetTarget, ProofTarget, ProofWithPublicInputsTarget}; +use crate::plonk::circuit_data::{ + CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, +}; +use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; +use crate::plonk::proof::{ + OpeningSetTarget, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget, +}; use crate::with_context; +/// Generate a proof having a given `CommonCircuitData`. +#[allow(unused)] // TODO: should be used soon. +pub(crate) fn dummy_proof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + common_data: &CommonCircuitData, +) -> Result<( + ProofWithPublicInputs, + VerifierOnlyCircuitData, +)> +where + [(); C::Hasher::HASH_SIZE]:, +{ + let config = common_data.config.clone(); + + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + ensure!( + !common_data.config.zero_knowledge, + "Degree calculation can be off if zero-knowledge is on." + ); + let degree = common_data.degree(); + // Number of `NoopGate`s to add to get a circuit of size `degree` in the end. + // Need to account for public input hashing, a `PublicInputGate` and a `ConstantGate`. + let num_noop_gate = degree - ceil_div_usize(common_data.num_public_inputs, 8) - 2; + + for _ in 0..num_noop_gate { + builder.add_gate(NoopGate, vec![]); + } + for gate in &common_data.gates { + builder.add_gate_to_gate_set(gate.clone()); + } + for _ in 0..common_data.num_public_inputs { + let t = builder.add_virtual_public_input(); + pw.set_target(t, F::ZERO); + } + + let data = builder.build::(); + assert_eq!(&data.common, common_data); + let proof = data.prove(pw)?; + + Ok((proof, data.verifier_only)) +} + impl, const D: usize> CircuitBuilder { /// Verify `proof0` if `condition` else verify `proof1`. /// `proof0` and `proof1` are assumed to use the same `CommonCircuitData`. pub fn conditionally_verify_proof>( &mut self, condition: BoolTarget, - proof_with_pis0: ProofWithPublicInputsTarget, + proof_with_pis0: &ProofWithPublicInputsTarget, inner_verifier_data0: &VerifierCircuitTarget, - proof_with_pis1: ProofWithPublicInputsTarget, + proof_with_pis1: &ProofWithPublicInputsTarget, inner_verifier_data1: &VerifierCircuitTarget, inner_common_data: &CommonCircuitData, ) where @@ -79,18 +134,52 @@ impl, const D: usize> CircuitBuilder { let selected_verifier_data = VerifierCircuitTarget { constants_sigmas_cap: self.select_cap( condition, - inner_verifier_data0.constants_sigmas_cap.clone(), - inner_verifier_data1.constants_sigmas_cap.clone(), + &inner_verifier_data0.constants_sigmas_cap, + &inner_verifier_data1.constants_sigmas_cap, + ), + circuit_digest: self.select_hash( + condition, + inner_verifier_data0.circuit_digest, + inner_verifier_data1.circuit_digest, ), }; self.verify_proof(selected_proof, &selected_verifier_data, inner_common_data); } - fn select_vec(&mut self, b: BoolTarget, v0: Vec, v1: Vec) -> Vec { - v0.into_iter() + /// Conditionally verify a proof with a new generated dummy proof. + pub fn conditionally_verify_proof_or_dummy>( + &mut self, + condition: BoolTarget, + proof_with_pis: &ProofWithPublicInputsTarget, + inner_verifier_data: &VerifierCircuitTarget, + inner_common_data: &CommonCircuitData, + ) -> (ProofWithPublicInputsTarget, VerifierCircuitTarget) + where + C::Hasher: AlgebraicHasher, + { + let dummy_proof = self.add_virtual_proof_with_pis(inner_common_data); + let dummy_verifier_data = VerifierCircuitTarget { + constants_sigmas_cap: self + .add_virtual_cap(inner_common_data.config.fri_config.cap_height), + circuit_digest: self.add_virtual_hash(), + }; + self.conditionally_verify_proof( + condition, + proof_with_pis, + inner_verifier_data, + &dummy_proof, + &dummy_verifier_data, + inner_common_data, + ); + + (dummy_proof, dummy_verifier_data) + } + + fn select_vec(&mut self, b: BoolTarget, v0: &[Target], v1: &[Target]) -> Vec { + v0.iter() .zip_eq(v1) - .map(|(t0, t1)| self.select(b, t0, t1)) + .map(|(t0, t1)| self.select(b, *t0, *t1)) .collect() } @@ -108,15 +197,15 @@ impl, const D: usize> CircuitBuilder { fn select_cap( &mut self, b: BoolTarget, - cap0: MerkleCapTarget, - cap1: MerkleCapTarget, + cap0: &MerkleCapTarget, + cap1: &MerkleCapTarget, ) -> MerkleCapTarget { assert_eq!(cap0.0.len(), cap1.0.len()); MerkleCapTarget( cap0.0 - .into_iter() - .zip_eq(cap1.0) - .map(|(h0, h1)| self.select_hash(b, h0, h1)) + .iter() + .zip_eq(&cap1.0) + .map(|(h0, h1)| self.select_hash(b, *h0, *h1)) .collect(), ) } @@ -124,10 +213,10 @@ impl, const D: usize> CircuitBuilder { fn select_vec_cap( &mut self, b: BoolTarget, - v0: Vec, - v1: Vec, + v0: &[MerkleCapTarget], + v1: &[MerkleCapTarget], ) -> Vec { - v0.into_iter() + v0.iter() .zip_eq(v1) .map(|(c0, c1)| self.select_cap(b, c0, c1)) .collect() @@ -136,53 +225,53 @@ impl, const D: usize> CircuitBuilder { fn select_opening_set( &mut self, b: BoolTarget, - os0: OpeningSetTarget, - os1: OpeningSetTarget, + os0: &OpeningSetTarget, + os1: &OpeningSetTarget, ) -> OpeningSetTarget { OpeningSetTarget { - constants: self.select_vec_ext(b, os0.constants, os1.constants), - plonk_sigmas: self.select_vec_ext(b, os0.plonk_sigmas, os1.plonk_sigmas), - wires: self.select_vec_ext(b, os0.wires, os1.wires), - plonk_zs: self.select_vec_ext(b, os0.plonk_zs, os1.plonk_zs), - plonk_zs_next: self.select_vec_ext(b, os0.plonk_zs_next, os1.plonk_zs_next), - partial_products: self.select_vec_ext(b, os0.partial_products, os1.partial_products), - quotient_polys: self.select_vec_ext(b, os0.quotient_polys, os1.quotient_polys), + constants: self.select_vec_ext(b, &os0.constants, &os1.constants), + plonk_sigmas: self.select_vec_ext(b, &os0.plonk_sigmas, &os1.plonk_sigmas), + wires: self.select_vec_ext(b, &os0.wires, &os1.wires), + plonk_zs: self.select_vec_ext(b, &os0.plonk_zs, &os1.plonk_zs), + plonk_zs_next: self.select_vec_ext(b, &os0.plonk_zs_next, &os1.plonk_zs_next), + partial_products: self.select_vec_ext(b, &os0.partial_products, &os1.partial_products), + quotient_polys: self.select_vec_ext(b, &os0.quotient_polys, &os1.quotient_polys), } } fn select_vec_ext( &mut self, b: BoolTarget, - v0: Vec>, - v1: Vec>, + v0: &[ExtensionTarget], + v1: &[ExtensionTarget], ) -> Vec> { - v0.into_iter() + v0.iter() .zip_eq(v1) - .map(|(e0, e1)| self.select_ext(b, e0, e1)) + .map(|(e0, e1)| self.select_ext(b, *e0, *e1)) .collect() } fn select_opening_proof( &mut self, b: BoolTarget, - proof0: FriProofTarget, - proof1: FriProofTarget, + proof0: &FriProofTarget, + proof1: &FriProofTarget, ) -> FriProofTarget { FriProofTarget { commit_phase_merkle_caps: self.select_vec_cap( b, - proof0.commit_phase_merkle_caps, - proof1.commit_phase_merkle_caps, + &proof0.commit_phase_merkle_caps, + &proof1.commit_phase_merkle_caps, ), query_round_proofs: self.select_vec_query_round( b, - proof0.query_round_proofs, - proof1.query_round_proofs, + &proof0.query_round_proofs, + &proof1.query_round_proofs, ), final_poly: PolynomialCoeffsExtTarget(self.select_vec_ext( b, - proof0.final_poly.0, - proof1.final_poly.0, + &proof0.final_poly.0, + &proof1.final_poly.0, )), pow_witness: self.select(b, proof0.pow_witness, proof1.pow_witness), } @@ -191,26 +280,26 @@ impl, const D: usize> CircuitBuilder { fn select_query_round( &mut self, b: BoolTarget, - qr0: FriQueryRoundTarget, - qr1: FriQueryRoundTarget, + qr0: &FriQueryRoundTarget, + qr1: &FriQueryRoundTarget, ) -> FriQueryRoundTarget { FriQueryRoundTarget { initial_trees_proof: self.select_initial_tree_proof( b, - qr0.initial_trees_proof, - qr1.initial_trees_proof, + &qr0.initial_trees_proof, + &qr1.initial_trees_proof, ), - steps: self.select_vec_query_step(b, qr0.steps, qr1.steps), + steps: self.select_vec_query_step(b, &qr0.steps, &qr1.steps), } } fn select_vec_query_round( &mut self, b: BoolTarget, - v0: Vec>, - v1: Vec>, + v0: &[FriQueryRoundTarget], + v1: &[FriQueryRoundTarget], ) -> Vec> { - v0.into_iter() + v0.iter() .zip_eq(v1) .map(|(qr0, qr1)| self.select_query_round(b, qr0, qr1)) .collect() @@ -219,14 +308,14 @@ impl, const D: usize> CircuitBuilder { fn select_initial_tree_proof( &mut self, b: BoolTarget, - proof0: FriInitialTreeProofTarget, - proof1: FriInitialTreeProofTarget, + proof0: &FriInitialTreeProofTarget, + proof1: &FriInitialTreeProofTarget, ) -> FriInitialTreeProofTarget { FriInitialTreeProofTarget { evals_proofs: proof0 .evals_proofs - .into_iter() - .zip_eq(proof1.evals_proofs) + .iter() + .zip_eq(&proof1.evals_proofs) .map(|((v0, p0), (v1, p1))| { ( self.select_vec(b, v0, v1), @@ -240,15 +329,15 @@ impl, const D: usize> CircuitBuilder { fn select_merkle_proof( &mut self, b: BoolTarget, - proof0: MerkleProofTarget, - proof1: MerkleProofTarget, + proof0: &MerkleProofTarget, + proof1: &MerkleProofTarget, ) -> MerkleProofTarget { MerkleProofTarget { siblings: proof0 .siblings - .into_iter() - .zip_eq(proof1.siblings) - .map(|(h0, h1)| self.select_hash(b, h0, h1)) + .iter() + .zip_eq(&proof1.siblings) + .map(|(h0, h1)| self.select_hash(b, *h0, *h1)) .collect(), } } @@ -256,22 +345,22 @@ impl, const D: usize> CircuitBuilder { fn select_query_step( &mut self, b: BoolTarget, - qs0: FriQueryStepTarget, - qs1: FriQueryStepTarget, + qs0: &FriQueryStepTarget, + qs1: &FriQueryStepTarget, ) -> FriQueryStepTarget { FriQueryStepTarget { - evals: self.select_vec_ext(b, qs0.evals, qs1.evals), - merkle_proof: self.select_merkle_proof(b, qs0.merkle_proof, qs1.merkle_proof), + evals: self.select_vec_ext(b, &qs0.evals, &qs1.evals), + merkle_proof: self.select_merkle_proof(b, &qs0.merkle_proof, &qs1.merkle_proof), } } fn select_vec_query_step( &mut self, b: BoolTarget, - v0: Vec>, - v1: Vec>, + v0: &[FriQueryStepTarget], + v1: &[FriQueryStepTarget], ) -> Vec> { - v0.into_iter() + v0.iter() .zip_eq(v1) .map(|(qs0, qs1)| self.select_query_step(b, qs0, qs1)) .collect() @@ -297,6 +386,7 @@ mod tests { type F = >::F; let config = CircuitConfig::standard_recursion_config(); + // Generate proof. let mut builder = CircuitBuilder::::new(config.clone()); let mut pw = PartialWitness::new(); let t = builder.add_virtual_target(); @@ -310,33 +400,32 @@ mod tests { let proof = data.prove(pw)?; data.verify(proof.clone())?; + // Generate dummy proof with the same `CommonCircuitData`. + let (dummy_proof, dummy_data) = dummy_proof(&data.common)?; + + // Conditionally verify the two proofs. let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::new(); let pt = builder.add_virtual_proof_with_pis(&data.common); pw.set_proof_with_pis_target(&pt, &proof); let dummy_pt = builder.add_virtual_proof_with_pis(&data.common); - pw.set_proof_with_pis_target(&dummy_pt, &proof); - + pw.set_proof_with_pis_target(&dummy_pt, &dummy_proof); let inner_data = VerifierCircuitTarget { constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), + circuit_digest: builder.add_virtual_hash(), }; - pw.set_cap_target( - &inner_data.constants_sigmas_cap, - &data.verifier_only.constants_sigmas_cap, - ); + pw.set_verifier_data_target(&inner_data, &data.verifier_only); let dummy_inner_data = VerifierCircuitTarget { constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), + circuit_digest: builder.add_virtual_hash(), }; - pw.set_cap_target( - &dummy_inner_data.constants_sigmas_cap, - &data.verifier_only.constants_sigmas_cap, - ); + pw.set_verifier_data_target(&dummy_inner_data, &dummy_data); let b = builder.constant_bool(F::rand().0 % 2 == 0); builder.conditionally_verify_proof( b, - pt, + &pt, &inner_data, - dummy_pt, + &dummy_pt, &dummy_inner_data, &data.common, ); diff --git a/plonky2/src/plonk/get_challenges.rs b/plonky2/src/plonk/get_challenges.rs index a8ca52e5..f497380f 100644 --- a/plonky2/src/plonk/get_challenges.rs +++ b/plonky2/src/plonk/get_challenges.rs @@ -29,6 +29,7 @@ fn get_challenges, C: GenericConfig, cons commit_phase_merkle_caps: &[MerkleCap], final_poly: &PolynomialCoeffs, pow_witness: F, + circuit_digest: &<>::Hasher as Hasher>::Hash, common_data: &CommonCircuitData, ) -> anyhow::Result> { let config = &common_data.config; @@ -37,7 +38,7 @@ fn get_challenges, C: GenericConfig, cons let mut challenger = Challenger::::new(); // Observe the instance. - challenger.observe_hash::(common_data.circuit_digest); + challenger.observe_hash::(*circuit_digest); challenger.observe_hash::(public_inputs_hash); challenger.observe_cap(wires_cap); @@ -61,7 +62,7 @@ fn get_challenges, C: GenericConfig, cons commit_phase_merkle_caps, final_poly, pow_witness, - common_data.degree_bits, + common_data.degree_bits(), &config.fri_config, ), }) @@ -72,10 +73,11 @@ impl, C: GenericConfig, const D: usize> { pub(crate) fn fri_query_indices( &self, + circuit_digest: &<>::Hasher as Hasher>::Hash, common_data: &CommonCircuitData, ) -> anyhow::Result> { Ok(self - .get_challenges(self.get_public_inputs_hash(), common_data)? + .get_challenges(self.get_public_inputs_hash(), circuit_digest, common_data)? .fri_challenges .fri_query_indices) } @@ -84,6 +86,7 @@ impl, C: GenericConfig, const D: usize> pub(crate) fn get_challenges( &self, public_inputs_hash: <>::InnerHasher as Hasher>::Hash, + circuit_digest: &<>::Hasher as Hasher>::Hash, common_data: &CommonCircuitData, ) -> anyhow::Result> { let Proof { @@ -109,6 +112,7 @@ impl, C: GenericConfig, const D: usize> commit_phase_merkle_caps, final_poly, *pow_witness, + circuit_digest, common_data, ) } @@ -121,6 +125,7 @@ impl, C: GenericConfig, const D: usize> pub(crate) fn get_challenges( &self, public_inputs_hash: <>::InnerHasher as Hasher>::Hash, + circuit_digest: &<>::Hasher as Hasher>::Hash, common_data: &CommonCircuitData, ) -> anyhow::Result> { let CompressedProof { @@ -146,6 +151,7 @@ impl, C: GenericConfig, const D: usize> commit_phase_merkle_caps, final_poly, *pow_witness, + circuit_digest, common_data, ) } @@ -175,7 +181,7 @@ impl, C: GenericConfig, const D: usize> &self.proof.openings.to_fri_openings(), *fri_alpha, ); - let log_n = common_data.degree_bits + common_data.config.fri_config.rate_bits; + let log_n = common_data.degree_bits() + common_data.config.fri_config.rate_bits; // Simulate the proof verification and collect the inferred elements. // The content of the loop is basically the same as the `fri_verifier_query_round` function. for &(mut x_index) in fri_query_indices { @@ -237,6 +243,7 @@ impl, const D: usize> CircuitBuilder { commit_phase_merkle_caps: &[MerkleCapTarget], final_poly: &PolynomialCoeffsExtTarget, pow_witness: Target, + inner_circuit_digest: HashOutTarget, inner_common_data: &CommonCircuitData, ) -> ProofChallengesTarget where @@ -248,9 +255,7 @@ impl, const D: usize> CircuitBuilder { let mut challenger = RecursiveChallenger::::new(self); // Observe the instance. - let digest = - HashOutTarget::from_vec(self.constants(&inner_common_data.circuit_digest.elements)); - challenger.observe_hash(&digest); + challenger.observe_hash(&inner_circuit_digest); challenger.observe_hash(&public_inputs_hash); challenger.observe_cap(wires_cap); @@ -286,6 +291,7 @@ impl ProofWithPublicInputsTarget { &self, builder: &mut CircuitBuilder, public_inputs_hash: HashOutTarget, + inner_circuit_digest: HashOutTarget, inner_common_data: &CommonCircuitData, ) -> ProofChallengesTarget where @@ -314,6 +320,7 @@ impl ProofWithPublicInputsTarget { commit_phase_merkle_caps, final_poly, *pow_witness, + inner_circuit_digest, inner_common_data, ) } diff --git a/plonky2/src/plonk/proof.rs b/plonky2/src/plonk/proof.rs index 922a24bb..2ec26c75 100644 --- a/plonky2/src/plonk/proof.rs +++ b/plonky2/src/plonk/proof.rs @@ -81,9 +81,10 @@ impl, C: GenericConfig, const D: usize> { pub fn compress( self, + circuit_digest: &<>::Hasher as Hasher>::Hash, common_data: &CommonCircuitData, ) -> anyhow::Result> { - let indices = self.fri_query_indices(common_data)?; + let indices = self.fri_query_indices(circuit_digest, common_data)?; let compressed_proof = self.proof.compress(&indices, &common_data.fri_params); Ok(CompressedProofWithPublicInputs { public_inputs: self.public_inputs, @@ -176,12 +177,14 @@ impl, C: GenericConfig, const D: usize> { pub fn decompress( self, + circuit_digest: &<>::Hasher as Hasher>::Hash, common_data: &CommonCircuitData, ) -> anyhow::Result> where [(); C::Hasher::HASH_SIZE]:, { - let challenges = self.get_challenges(self.get_public_inputs_hash(), common_data)?; + let challenges = + self.get_challenges(self.get_public_inputs_hash(), circuit_digest, common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); let decompressed_proof = self.proof @@ -205,7 +208,11 @@ impl, C: GenericConfig, const D: usize> "Number of public inputs doesn't match circuit data." ); let public_inputs_hash = self.get_public_inputs_hash(); - let challenges = self.get_challenges(public_inputs_hash, common_data)?; + let challenges = self.get_challenges( + public_inputs_hash, + &verifier_data.circuit_digest, + common_data, + )?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); let decompressed_proof = self.proof @@ -418,8 +425,8 @@ mod tests { verify(proof.clone(), &data.verifier_only, &data.common)?; // Verify that `decompress ∘ compress = identity`. - let compressed_proof = proof.clone().compress(&data.common)?; - let decompressed_compressed_proof = compressed_proof.clone().decompress(&data.common)?; + let compressed_proof = data.compress(proof.clone())?; + let decompressed_compressed_proof = data.decompress(compressed_proof.clone())?; assert_eq!(proof, decompressed_compressed_proof); verify(proof, &data.verifier_only, &data.common)?; diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index 3e81942b..8476a2d9 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -81,7 +81,7 @@ where let mut challenger = Challenger::::new(); // Observe the instance. - challenger.observe_hash::(common_data.circuit_digest); + challenger.observe_hash::(prover_data.circuit_digest); challenger.observe_hash::(public_inputs_hash); challenger.observe_cap(&wires_commitment.merkle_tree.cap); @@ -172,9 +172,9 @@ where // To avoid leaking witness data, we want to ensure that our opening locations, `zeta` and // `g * zeta`, are not in our subgroup `H`. It suffices to check `zeta` only, since // `(g * zeta)^n = zeta^n`, where `n` is the order of `g`. - let g = F::Extension::primitive_root_of_unity(common_data.degree_bits); + let g = F::Extension::primitive_root_of_unity(common_data.degree_bits()); ensure!( - zeta.exp_power_of_2(common_data.degree_bits) != F::Extension::ONE, + zeta.exp_power_of_2(common_data.degree_bits()) != F::Extension::ONE, "Opening point is in the subgroup." ); @@ -342,10 +342,10 @@ fn compute_quotient_polys< // steps away since we work on an LDE of degree `max_filtered_constraint_degree`. let next_step = 1 << quotient_degree_bits; - let points = F::two_adic_subgroup(common_data.degree_bits + quotient_degree_bits); + let points = F::two_adic_subgroup(common_data.degree_bits() + quotient_degree_bits); let lde_size = points.len(); - let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits, quotient_degree_bits); + let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits(), quotient_degree_bits); let points_batches = points.par_chunks(BATCH_SIZE); let num_batches = ceil_div_usize(points.len(), BATCH_SIZE); diff --git a/plonky2/src/plonk/recursive_verifier.rs b/plonky2/src/plonk/recursive_verifier.rs index ecce34e1..bb9076be 100644 --- a/plonky2/src/plonk/recursive_verifier.rs +++ b/plonky2/src/plonk/recursive_verifier.rs @@ -29,7 +29,12 @@ impl, const D: usize> CircuitBuilder { ); let public_inputs_hash = self.hash_n_to_hash_no_pad::(proof_with_pis.public_inputs.clone()); - let challenges = proof_with_pis.get_challenges(self, public_inputs_hash, inner_common_data); + let challenges = proof_with_pis.get_challenges( + self, + public_inputs_hash, + inner_verifier_data.circuit_digest, + inner_common_data, + ); self.verify_proof_with_challenges( proof_with_pis.proof, @@ -66,7 +71,7 @@ impl, const D: usize> CircuitBuilder { let partial_products = &proof.openings.partial_products; let zeta_pow_deg = - self.exp_power_of_2_extension(challenges.plonk_zeta, inner_common_data.degree_bits); + self.exp_power_of_2_extension(challenges.plonk_zeta, inner_common_data.degree_bits()); let vanishing_polys_zeta = with_context!( self, "evaluate the vanishing polynomial at our challenge point, zeta.", @@ -205,9 +210,9 @@ mod tests { let config = CircuitConfig::standard_recursion_zk_config(); let (proof, vd, cd) = dummy_proof::(&config, 4_000)?; - let (proof, _vd, cd) = + let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &config, None, true, true)?; - test_serialization(&proof, &cd)?; + test_serialization(&proof, &vd, &cd)?; Ok(()) } @@ -223,19 +228,19 @@ mod tests { // Start with a degree 2^14 proof let (proof, vd, cd) = dummy_proof::(&config, 16_000)?; - assert_eq!(cd.degree_bits, 14); + assert_eq!(cd.degree_bits(), 14); // Shrink it to 2^13. let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &config, Some(13), false, false)?; - assert_eq!(cd.degree_bits, 13); + assert_eq!(cd.degree_bits(), 13); // Shrink it to 2^12. - let (proof, _vd, cd) = + let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &config, None, true, true)?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits(), 12); - test_serialization(&proof, &cd)?; + test_serialization(&proof, &vd, &cd)?; Ok(()) } @@ -255,11 +260,11 @@ mod tests { // An initial dummy proof. let (proof, vd, cd) = dummy_proof::(&standard_config, 4_000)?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits(), 12); // A standard recursive proof. let (proof, vd, cd) = recursive_proof(proof, vd, cd, &standard_config, None, false, false)?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits(), 12); // A high-rate recursive proof, designed to be verifiable with fewer routed wires. let high_rate_config = CircuitConfig { @@ -273,7 +278,7 @@ mod tests { }; let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &high_rate_config, None, true, true)?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits(), 12); // A final proof, optimized for size. let final_config = CircuitConfig { @@ -287,11 +292,11 @@ mod tests { }, ..high_rate_config }; - let (proof, _vd, cd) = + let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &final_config, None, true, true)?; - assert_eq!(cd.degree_bits, 12, "final proof too large"); + assert_eq!(cd.degree_bits(), 12, "final proof too large"); - test_serialization(&proof, &cd)?; + test_serialization(&proof, &vd, &cd)?; Ok(()) } @@ -309,11 +314,11 @@ mod tests { let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &config, None, false, false)?; - test_serialization(&proof, &cd)?; + test_serialization(&proof, &vd, &cd)?; - let (proof, _vd, cd) = + let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &config, None, false, false)?; - test_serialization(&proof, &cd)?; + test_serialization(&proof, &vd, &cd)?; Ok(()) } @@ -372,11 +377,13 @@ mod tests { let inner_data = VerifierCircuitTarget { constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), + circuit_digest: builder.add_virtual_hash(), }; pw.set_cap_target( &inner_data.constants_sigmas_cap, &inner_vd.constants_sigmas_cap, ); + pw.set_hash_target(inner_data.circuit_digest, inner_vd.circuit_digest); builder.verify_proof(pt, &inner_data, &inner_cd); @@ -414,6 +421,7 @@ mod tests { const D: usize, >( proof: &ProofWithPublicInputs, + vd: &VerifierOnlyCircuitData, cd: &CommonCircuitData, ) -> Result<()> where @@ -425,8 +433,10 @@ mod tests { assert_eq!(proof, &proof_from_bytes); let now = std::time::Instant::now(); - let compressed_proof = proof.clone().compress(cd)?; - let decompressed_compressed_proof = compressed_proof.clone().decompress(cd)?; + let compressed_proof = proof.clone().compress(&vd.circuit_digest, cd)?; + let decompressed_compressed_proof = compressed_proof + .clone() + .decompress(&vd.circuit_digest, cd)?; info!("{:.4}s to compress proof", now.elapsed().as_secs_f64()); assert_eq!(proof, &decompressed_compressed_proof); diff --git a/plonky2/src/plonk/verifier.rs b/plonky2/src/plonk/verifier.rs index 6a4f3790..37ddfffa 100644 --- a/plonky2/src/plonk/verifier.rs +++ b/plonky2/src/plonk/verifier.rs @@ -23,7 +23,11 @@ where validate_proof_with_pis_shape(&proof_with_pis, common_data)?; let public_inputs_hash = proof_with_pis.get_public_inputs_hash(); - let challenges = proof_with_pis.get_challenges(public_inputs_hash, common_data)?; + let challenges = proof_with_pis.get_challenges( + public_inputs_hash, + &verifier_data.circuit_digest, + common_data, + )?; verify_with_challenges( proof_with_pis.proof, @@ -78,7 +82,7 @@ where let quotient_polys_zeta = &proof.openings.quotient_polys; let zeta_pow_deg = challenges .plonk_zeta - .exp_power_of_2(common_data.degree_bits); + .exp_power_of_2(common_data.degree_bits()); let z_h_zeta = zeta_pow_deg - F::Extension::ONE; // `quotient_polys_zeta` holds `num_challenges * quotient_degree_factor` evaluations. // Each chunk of `quotient_degree_factor` holds the evaluations of `t_0(zeta),...,t_{quotient_degree_factor-1}(zeta)`