diff --git a/evm/Cargo.toml b/evm/Cargo.toml index 848dff15..7c179318 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -7,14 +7,14 @@ edition = "2021" [dependencies] plonky2 = { path = "../plonky2", default-features = false, features = ["rand", "timing"] } plonky2_util = { path = "../util" } -eth_trie_utils = "0.1.0" +eth_trie_utils = "0.2.1" anyhow = "1.0.40" env_logger = "0.9.0" ethereum-types = "0.14.0" hex = { version = "0.4.3", optional = true } hex-literal = "0.3.4" itertools = "0.10.3" -keccak-hash = "0.9.0" +keccak-hash = "0.10.0" log = "0.4.14" num = "0.4.0" maybe_rayon = { path = "../maybe_rayon" } diff --git a/evm/spec/tries.tex b/evm/spec/mpts.tex similarity index 53% rename from evm/spec/tries.tex rename to evm/spec/mpts.tex index 7ec0fcce..49d1d328 100644 --- a/evm/spec/tries.tex +++ b/evm/spec/mpts.tex @@ -6,21 +6,21 @@ Withour our zkEVM's kernel memory, \begin{enumerate} \item An empty node is encoded as $(\texttt{MPT\_NODE\_EMPTY})$. - \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, c_1, \dots, c_{16}, \abs{v}, v)$, where each $c_i$ is a pointer to a child node, and $v$ is a value of length $\abs{v}$.\footnote{If a branch node has no associated value, then $\abs{v} = 0$ and $v = ()$.} + \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, c_1, \dots, c_{16}, v)$, where each $c_i$ is a pointer to a child node, and $v$ is a pointer to a value. If a branch node has no associated value, then $v = 0$, i.e. the null pointer. \item An extension node is encoded as $(\texttt{MPT\_NODE\_EXTENSION}, k, c)$, $k$ represents the part of the key associated with this extension, and is encoded as a 2-tuple $(\texttt{packed\_nibbles}, \texttt{num\_nibbles})$. $c$ is a pointer to a child node. - \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, \abs{v}, v)$, where $k$ is a 2-tuple as above, and $v$ is a leaf payload. - \item A digest node is encoded as $(\texttt{MPT\_NODE\_DIGEST}, d)$, where $d$ is a Keccak256 digest. 
+ \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, v)$, where $k$ is a 2-tuple as above, and $v$ is a pointer to a value. + \item A digest node is encoded as $(\texttt{MPT\_NODE\_HASH}, d)$, where $d$ is a Keccak256 digest. \end{enumerate} \subsection{Prover input format} -The initial state of each trie is given by the prover as a nondeterministic input tape. This tape has a similar format: +The initial state of each trie is given by the prover as a nondeterministic input tape. This tape has a slightly different format: \begin{enumerate} \item An empty node is encoded as $(\texttt{MPT\_NODE\_EMPTY})$. - \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, \abs{v}, v, c_1, \dots, c_{16})$, where $\abs{v}$ is the length of the value, and $v$ is the value itself. Each $c_i$ is the encoding of a child node. + \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, v_?, c_1, \dots, c_{16})$. Here $v_?$ consists of a flag indicating whether a value is present,\todo{In the current implementation, we use a length prefix rather than a is-present prefix, but we plan to change that.} followed by the actual value payload if one is present. Each $c_i$ is the encoding of a child node. \item An extension node is encoded as $(\texttt{MPT\_NODE\_EXTENSION}, k, c)$, $k$ represents the part of the key associated with this extension, and is encoded as a 2-tuple $(\texttt{packed\_nibbles}, \texttt{num\_nibbles})$. $c$ is a pointer to a child node. - \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, \abs{v}, v)$, where $k$ is a 2-tuple as above, and $v$ is a leaf payload. - \item A digest node is encoded as $(\texttt{MPT\_NODE\_DIGEST}, d)$, where $d$ is a Keccak256 digest. + \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, v)$, where $k$ is a 2-tuple as above, and $v$ is a value payload. + \item A digest node is encoded as $(\texttt{MPT\_NODE\_HASH}, d)$, where $d$ is a Keccak256 digest. 
\end{enumerate} -Nodes are thus given in depth-first order, leading to natural recursive methods for encoding and decoding this format. +Nodes are thus given in depth-first order, enabling natural recursive methods for encoding and decoding this format. diff --git a/evm/spec/zkevm.pdf b/evm/spec/zkevm.pdf index 184ba36b..f181eba6 100644 Binary files a/evm/spec/zkevm.pdf and b/evm/spec/zkevm.pdf differ diff --git a/evm/spec/zkevm.tex b/evm/spec/zkevm.tex index 65766986..2927e7a5 100644 --- a/evm/spec/zkevm.tex +++ b/evm/spec/zkevm.tex @@ -51,7 +51,7 @@ \input{introduction} \input{framework} \input{tables} -\input{tries} +\input{mpts} \input{instructions} \bibliography{bibliography}{} diff --git a/evm/src/arithmetic/add.rs b/evm/src/arithmetic/add.rs index d2520fb9..1bf798cc 100644 --- a/evm/src/arithmetic/add.rs +++ b/evm/src/arithmetic/add.rs @@ -6,6 +6,7 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::columns::*; +use crate::arithmetic::utils::read_value_u64_limbs; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; @@ -94,15 +95,12 @@ where } pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { - let input0_limbs = ADD_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1_limbs = ADD_INPUT_1.map(|c| lv[c].to_canonical_u64()); + let input0 = read_value_u64_limbs(lv, ADD_INPUT_0); + let input1 = read_value_u64_limbs(lv, ADD_INPUT_1); // Input and output have 16-bit limbs - let (output_limbs, _) = u256_add_cc(input0_limbs, input1_limbs); - - for (&c, output_limb) in ADD_OUTPUT.iter().zip(output_limbs) { - lv[c] = F::from_canonical_u64(output_limb); - } + let (output_limbs, _) = u256_add_cc(input0, input1); + lv[ADD_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_u64(c))); } pub fn eval_packed_generic( @@ -114,15 +112,20 @@ pub fn eval_packed_generic( range_check_error!(ADD_OUTPUT, 16); let is_add = lv[IS_ADD]; - let 
input0_limbs = ADD_INPUT_0.iter().map(|&c| lv[c]); - let input1_limbs = ADD_INPUT_1.iter().map(|&c| lv[c]); - let output_limbs = ADD_OUTPUT.iter().map(|&c| lv[c]); + let input0_limbs = &lv[ADD_INPUT_0]; + let input1_limbs = &lv[ADD_INPUT_1]; + let output_limbs = &lv[ADD_OUTPUT]; // This computed output is not yet reduced; i.e. some limbs may be // more than 16 bits. - let output_computed = input0_limbs.zip(input1_limbs).map(|(a, b)| a + b); + let output_computed = input0_limbs.iter().zip(input1_limbs).map(|(&a, &b)| a + b); - eval_packed_generic_are_equal(yield_constr, is_add, output_computed, output_limbs); + eval_packed_generic_are_equal( + yield_constr, + is_add, + output_computed, + output_limbs.iter().copied(), + ); } #[allow(clippy::needless_collect)] @@ -132,17 +135,18 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_add = lv[IS_ADD]; - let input0_limbs = ADD_INPUT_0.iter().map(|&c| lv[c]); - let input1_limbs = ADD_INPUT_1.iter().map(|&c| lv[c]); - let output_limbs = ADD_OUTPUT.iter().map(|&c| lv[c]); + let input0_limbs = &lv[ADD_INPUT_0]; + let input1_limbs = &lv[ADD_INPUT_1]; + let output_limbs = &lv[ADD_OUTPUT]; // Since `map` is lazy and the closure passed to it borrows // `builder`, we can't then borrow builder again below in the call // to `eval_ext_circuit_are_equal`. The solution is to force // evaluation with `collect`. 
let output_computed = input0_limbs + .iter() .zip(input1_limbs) - .map(|(a, b)| builder.add_extension(a, b)) + .map(|(&a, &b)| builder.add_extension(a, b)) .collect::>>(); eval_ext_circuit_are_equal( @@ -150,7 +154,7 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr, is_add, output_computed.into_iter(), - output_limbs, + output_limbs.iter().copied(), ); } @@ -203,7 +207,7 @@ mod tests { for _ in 0..N_RND_TESTS { // set inputs to random values - for (&ai, bi) in ADD_INPUT_0.iter().zip(ADD_INPUT_1) { + for (ai, bi) in ADD_INPUT_0.zip(ADD_INPUT_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index ca8ba549..ee73f223 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -1,5 +1,7 @@ //! Arithmetic unit +use std::ops::Range; + pub const LIMB_BITS: usize = 16; const EVM_REGISTER_BITS: usize = 256; @@ -44,57 +46,42 @@ pub(crate) const ALL_OPERATIONS: [usize; 16] = [ /// used by any arithmetic circuit, depending on which one is active /// this cycle. Can be increased as needed as other operations are /// implemented. 
-const NUM_SHARED_COLS: usize = 144; // only need 64 for add, sub, and mul +const NUM_SHARED_COLS: usize = 9 * N_LIMBS; // only need 64 for add, sub, and mul -const fn shared_col(i: usize) -> usize { - assert!(i < NUM_SHARED_COLS); - START_SHARED_COLS + i -} +const GENERAL_INPUT_0: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; +const GENERAL_INPUT_1: Range = GENERAL_INPUT_0.end..GENERAL_INPUT_0.end + N_LIMBS; +const GENERAL_INPUT_2: Range = GENERAL_INPUT_1.end..GENERAL_INPUT_1.end + N_LIMBS; +const GENERAL_INPUT_3: Range = GENERAL_INPUT_2.end..GENERAL_INPUT_2.end + N_LIMBS; +const AUX_INPUT_0: Range = GENERAL_INPUT_3.end..GENERAL_INPUT_3.end + 2 * N_LIMBS; +const AUX_INPUT_1: Range = AUX_INPUT_0.end..AUX_INPUT_0.end + 2 * N_LIMBS; +const AUX_INPUT_2: Range = AUX_INPUT_1.end..AUX_INPUT_1.end + N_LIMBS; -const fn gen_input_cols(start: usize) -> [usize; N] { - let mut cols = [0usize; N]; - let mut i = 0; - while i < N { - cols[i] = shared_col(start + i); - i += 1; - } - cols -} +pub(crate) const ADD_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const ADD_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const ADD_OUTPUT: Range = GENERAL_INPUT_2; -const GENERAL_INPUT_0: [usize; N_LIMBS] = gen_input_cols::(0); -const GENERAL_INPUT_1: [usize; N_LIMBS] = gen_input_cols::(N_LIMBS); -const GENERAL_INPUT_2: [usize; N_LIMBS] = gen_input_cols::(2 * N_LIMBS); -const GENERAL_INPUT_3: [usize; N_LIMBS] = gen_input_cols::(3 * N_LIMBS); -const AUX_INPUT_0: [usize; 2 * N_LIMBS] = gen_input_cols::<{ 2 * N_LIMBS }>(4 * N_LIMBS); -const AUX_INPUT_1: [usize; 2 * N_LIMBS] = gen_input_cols::<{ 2 * N_LIMBS }>(6 * N_LIMBS); -const AUX_INPUT_2: [usize; N_LIMBS] = gen_input_cols::(8 * N_LIMBS); +pub(crate) const SUB_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const SUB_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const SUB_OUTPUT: Range = GENERAL_INPUT_2; -pub(crate) const ADD_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const ADD_INPUT_1: [usize; N_LIMBS] = 
GENERAL_INPUT_1; -pub(crate) const ADD_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; +pub(crate) const MUL_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const MUL_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const MUL_OUTPUT: Range = GENERAL_INPUT_2; +pub(crate) const MUL_AUX_INPUT: Range = GENERAL_INPUT_3; -pub(crate) const SUB_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const SUB_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; -pub(crate) const SUB_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; +pub(crate) const CMP_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const CMP_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2.start; +pub(crate) const CMP_AUX_INPUT: Range = GENERAL_INPUT_3; -pub(crate) const MUL_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const MUL_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; -pub(crate) const MUL_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; -pub(crate) const MUL_AUX_INPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; - -pub(crate) const CMP_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const CMP_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; -pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2[0]; -pub(crate) const CMP_AUX_INPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; - -pub(crate) const MODULAR_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const MODULAR_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; -pub(crate) const MODULAR_MODULUS: [usize; N_LIMBS] = GENERAL_INPUT_2; -pub(crate) const MODULAR_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; -pub(crate) const MODULAR_QUO_INPUT: [usize; 2 * N_LIMBS] = AUX_INPUT_0; -// NB: Last value is not used in AUX, it is used in IS_ZERO -pub(crate) const MODULAR_AUX_INPUT: [usize; 2 * N_LIMBS] = AUX_INPUT_1; -pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1[2 * N_LIMBS - 1]; -pub(crate) const MODULAR_OUT_AUX_RED: [usize; N_LIMBS] = AUX_INPUT_2; +pub(crate) const MODULAR_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const 
MODULAR_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const MODULAR_MODULUS: Range = GENERAL_INPUT_2; +pub(crate) const MODULAR_OUTPUT: Range = GENERAL_INPUT_3; +pub(crate) const MODULAR_QUO_INPUT: Range = AUX_INPUT_0; +// NB: Last value is not used in AUX, it is used in MOD_IS_ZERO +pub(crate) const MODULAR_AUX_INPUT: Range = AUX_INPUT_1; +pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1.end - 1; +pub(crate) const MODULAR_OUT_AUX_RED: Range = AUX_INPUT_2; pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS; diff --git a/evm/src/arithmetic/compare.rs b/evm/src/arithmetic/compare.rs index a6566db5..55dc5764 100644 --- a/evm/src/arithmetic/compare.rs +++ b/evm/src/arithmetic/compare.rs @@ -22,12 +22,13 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::add::{eval_ext_circuit_are_equal, eval_packed_generic_are_equal}; use crate::arithmetic::columns::*; use crate::arithmetic::sub::u256_sub_br; +use crate::arithmetic::utils::read_value_u64_limbs; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], op: usize) { - let input0 = CMP_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1 = CMP_INPUT_1.map(|c| lv[c].to_canonical_u64()); + let input0 = read_value_u64_limbs(lv, CMP_INPUT_0); + let input1 = read_value_u64_limbs(lv, CMP_INPUT_1); let (diff, br) = match op { // input0 - input1 == diff + br*2^256 @@ -39,9 +40,7 @@ pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], op: usize) _ => panic!("op code not a comparison"), }; - for (&c, diff_limb) in CMP_AUX_INPUT.iter().zip(diff) { - lv[c] = F::from_canonical_u64(diff_limb); - } + lv[CMP_AUX_INPUT].copy_from_slice(&diff.map(|c| F::from_canonical_u64(c))); lv[CMP_OUTPUT] = F::from_canonical_u64(br); } @@ -56,15 +55,17 @@ fn eval_packed_generic_check_is_one_bit( pub(crate) fn eval_packed_generic_lt( yield_constr: &mut ConstraintConsumer

, is_op: P, - input0: [P; N_LIMBS], - input1: [P; N_LIMBS], - aux: [P; N_LIMBS], + input0: &[P], + input1: &[P], + aux: &[P], output: P, ) { + debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS); + // Verify (input0 < input1) == output by providing aux such that // input0 - input1 == aux + output*2^256. - let lhs_limbs = input0.iter().zip(input1).map(|(&a, b)| a - b); - let cy = eval_packed_generic_are_equal(yield_constr, is_op, aux.into_iter(), lhs_limbs); + let lhs_limbs = input0.iter().zip(input1).map(|(&a, &b)| a - b); + let cy = eval_packed_generic_are_equal(yield_constr, is_op, aux.iter().copied(), lhs_limbs); // We don't need to check that cy is 0 or 1, since output has // already been checked to be 0 or 1. yield_constr.constraint(is_op * (cy - output)); @@ -81,9 +82,9 @@ pub fn eval_packed_generic( let is_lt = lv[IS_LT]; let is_gt = lv[IS_GT]; - let input0 = CMP_INPUT_0.map(|c| lv[c]); - let input1 = CMP_INPUT_1.map(|c| lv[c]); - let aux = CMP_AUX_INPUT.map(|c| lv[c]); + let input0 = &lv[CMP_INPUT_0]; + let input1 = &lv[CMP_INPUT_1]; + let aux = &lv[CMP_AUX_INPUT]; let output = lv[CMP_OUTPUT]; let is_cmp = is_lt + is_gt; @@ -109,11 +110,13 @@ pub(crate) fn eval_ext_circuit_lt, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, yield_constr: &mut RecursiveConstraintConsumer, is_op: ExtensionTarget, - input0: [ExtensionTarget; N_LIMBS], - input1: [ExtensionTarget; N_LIMBS], - aux: [ExtensionTarget; N_LIMBS], + input0: &[ExtensionTarget], + input1: &[ExtensionTarget], + aux: &[ExtensionTarget], output: ExtensionTarget, ) { + debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS); + // Since `map` is lazy and the closure passed to it borrows // `builder`, we can't then borrow builder again below in the call // to `eval_ext_circuit_are_equal`. 
The solution is to force @@ -121,14 +124,14 @@ pub(crate) fn eval_ext_circuit_lt, const D: usize>( let lhs_limbs = input0 .iter() .zip(input1) - .map(|(&a, b)| builder.sub_extension(a, b)) + .map(|(&a, &b)| builder.sub_extension(a, b)) .collect::>>(); let cy = eval_ext_circuit_are_equal( builder, yield_constr, is_op, - aux.into_iter(), + aux.iter().copied(), lhs_limbs.into_iter(), ); let good_output = builder.sub_extension(cy, output); @@ -144,9 +147,9 @@ pub fn eval_ext_circuit, const D: usize>( let is_lt = lv[IS_LT]; let is_gt = lv[IS_GT]; - let input0 = CMP_INPUT_0.map(|c| lv[c]); - let input1 = CMP_INPUT_1.map(|c| lv[c]); - let aux = CMP_AUX_INPUT.map(|c| lv[c]); + let input0 = &lv[CMP_INPUT_0]; + let input1 = &lv[CMP_INPUT_1]; + let aux = &lv[CMP_AUX_INPUT]; let output = lv[CMP_OUTPUT]; let is_cmp = builder.add_extension(is_lt, is_gt); @@ -210,7 +213,7 @@ mod tests { lv[other_op] = F::ZERO; // set inputs to random values - for (&ai, bi) in CMP_INPUT_0.iter().zip(CMP_INPUT_1) { + for (ai, bi) in CMP_INPUT_0.zip(CMP_INPUT_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 1fd31bb1..53051cda 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -18,12 +18,13 @@ //! a(x) = \sum_{i=0}^15 a[i] x^i //! //! (so A = a(β)) and similarly for b(x), m(x) and c(x). Then -//! operation(A,B) = C (mod M) if and only if the polynomial +//! operation(A,B) = C (mod M) if and only if there exists q such that +//! the polynomial //! //! operation(a(x), b(x)) - c(x) - m(x) * q(x) //! -//! is zero when evaluated at x = β, i.e. it is divisible by (x - β). -//! Thus exists a polynomial s such that +//! is zero when evaluated at x = β, i.e. it is divisible by (x - β); +//! equivalently, there exists a polynomial s such that //! //! operation(a(x), b(x)) - c(x) - m(x) * q(x) - (x - β) * s(x) == 0 //! @@ -34,12 +35,12 @@ //! 
coefficients must be zero. The variable names of the constituent //! polynomials are (writing N for N_LIMBS=16): //! -//! a(x) = \sum_{i=0}^{N-1} input0[i] * β^i -//! b(x) = \sum_{i=0}^{N-1} input1[i] * β^i -//! c(x) = \sum_{i=0}^{N-1} output[i] * β^i -//! m(x) = \sum_{i=0}^{N-1} modulus[i] * β^i -//! q(x) = \sum_{i=0}^{2N-1} quot[i] * β^i -//! s(x) = \sum_i^{2N-2} aux[i] * β^i +//! a(x) = \sum_{i=0}^{N-1} input0[i] * x^i +//! b(x) = \sum_{i=0}^{N-1} input1[i] * x^i +//! c(x) = \sum_{i=0}^{N-1} output[i] * x^i +//! m(x) = \sum_{i=0}^{N-1} modulus[i] * x^i +//! q(x) = \sum_{i=0}^{2N-1} quot[i] * x^i +//! s(x) = \sum_i^{2N-2} aux[i] * x^i //! //! Because A, B, M and C are 256-bit numbers, the degrees of a, b, m //! and c are (at most) N-1 = 15. If m = 1, then Q would be A*B which @@ -159,9 +160,9 @@ fn generate_modular_op( ) { // Inputs are all range-checked in [0, 2^16), so the "as i64" // conversion is safe. - let input0_limbs = MODULAR_INPUT_0.map(|c| F::to_canonical_u64(&lv[c]) as i64); - let input1_limbs = MODULAR_INPUT_1.map(|c| F::to_canonical_u64(&lv[c]) as i64); - let mut modulus_limbs = MODULAR_MODULUS.map(|c| F::to_canonical_u64(&lv[c]) as i64); + let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0); + let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1); + let mut modulus_limbs = read_value_i64_limbs(lv, MODULAR_MODULUS); // The use of BigUints is just to avoid having to implement // modular reduction. 
@@ -174,12 +175,11 @@ fn generate_modular_op( let mut constr_poly = [0i64; 2 * N_LIMBS]; constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&operation(input0_limbs, input1_limbs)); + let mut mod_is_zero = F::ZERO; if modulus.is_zero() { modulus += 1u32; modulus_limbs[0] += 1i64; - lv[MODULAR_MOD_IS_ZERO] = F::ONE; - } else { - lv[MODULAR_MOD_IS_ZERO] = F::ZERO; + mod_is_zero = F::ONE; } let input = columns_to_biguint(&constr_poly); @@ -211,21 +211,13 @@ fn generate_modular_op( // constr_poly must be zero when evaluated at x = β := // 2^LIMB_BITS, hence it's divisible by (x - β). `aux_limbs` is // the result of removing that root. - let aux_limbs = pol_remove_root_2exp::(constr_poly); + let aux_limbs = pol_remove_root_2exp::(constr_poly); - for deg in 0..N_LIMBS { - lv[MODULAR_OUTPUT[deg]] = F::from_canonical_i64(output_limbs[deg]); - lv[MODULAR_OUT_AUX_RED[deg]] = F::from_canonical_i64(out_aux_red[deg]); - lv[MODULAR_QUO_INPUT[deg]] = F::from_canonical_i64(quot_limbs[deg]); - lv[MODULAR_QUO_INPUT[deg + N_LIMBS]] = F::from_canonical_i64(quot_limbs[deg + N_LIMBS]); - lv[MODULAR_AUX_INPUT[deg]] = F::from_noncanonical_i64(aux_limbs[deg]); - // Don't overwrite MODULAR_MOD_IS_ZERO, which is at the last - // index of MODULAR_AUX_INPUT - if deg < N_LIMBS - 1 { - lv[MODULAR_AUX_INPUT[deg + N_LIMBS]] = - F::from_noncanonical_i64(aux_limbs[deg + N_LIMBS]); - } - } + lv[MODULAR_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); + lv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(|c| F::from_canonical_i64(c))); + lv[MODULAR_QUO_INPUT].copy_from_slice("_limbs.map(|c| F::from_canonical_i64(c))); + lv[MODULAR_AUX_INPUT].copy_from_slice(&aux_limbs.map(|c| F::from_noncanonical_i64(c))); + lv[MODULAR_MOD_IS_ZERO] = mod_is_zero; } /// Generate the output and auxiliary values for modular operations. 
@@ -261,7 +253,7 @@ fn modular_constr_poly( range_check_error!(MODULAR_AUX_INPUT, 20, signed); range_check_error!(MODULAR_OUTPUT, 16); - let mut modulus = MODULAR_MODULUS.map(|c| lv[c]); + let mut modulus = read_value::(lv, MODULAR_MODULUS); let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; // Check that mod_is_zero is zero or one @@ -276,22 +268,22 @@ fn modular_constr_poly( // modulus = 0. modulus[0] += mod_is_zero; - let output = MODULAR_OUTPUT.map(|c| lv[c]); + let output = &lv[MODULAR_OUTPUT]; // Verify that the output is reduced, i.e. output < modulus. - let out_aux_red = MODULAR_OUT_AUX_RED.map(|c| lv[c]); + let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; let is_less_than = P::ONES; eval_packed_generic_lt( yield_constr, filter, output, - modulus, + &modulus, out_aux_red, is_less_than, ); // prod = q(x) * m(x) - let quot = MODULAR_QUO_INPUT.map(|c| lv[c]); + let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); let prod = pol_mul_wide2(quot, modulus); // higher order terms must be zero for &x in prod[2 * N_LIMBS..].iter() { @@ -300,10 +292,11 @@ fn modular_constr_poly( // constr_poly = c(x) + q(x) * m(x) let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); - pol_add_assign(&mut constr_poly, &output); + pol_add_assign(&mut constr_poly, output); // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) - let aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); + aux[2 * N_LIMBS - 1] = P::ZEROS; // zero out the MOD_IS_ZERO flag let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); pol_add_assign(&mut constr_poly, &pol_adjoin_root(aux, base)); @@ -322,8 +315,8 @@ pub(crate) fn eval_packed_generic( // constr_poly has 2*N_LIMBS limbs let constr_poly = modular_constr_poly(lv, yield_constr, filter); - let input0 = MODULAR_INPUT_0.map(|c| lv[c]); - let input1 = MODULAR_INPUT_1.map(|c| lv[c]); + let input0 = read_value(lv, MODULAR_INPUT_0); + let input1 = read_value(lv, 
MODULAR_INPUT_1); let add_input = pol_add(input0, input1); let mul_input = pol_mul_wide(input0, input1); @@ -360,7 +353,7 @@ fn modular_constr_poly_ext_circuit, const D: usize> yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, ) -> [ExtensionTarget; 2 * N_LIMBS] { - let mut modulus = MODULAR_MODULUS.map(|c| lv[c]); + let mut modulus = read_value::(lv, MODULAR_MODULUS); let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero); @@ -374,20 +367,20 @@ fn modular_constr_poly_ext_circuit, const D: usize> modulus[0] = builder.add_extension(modulus[0], mod_is_zero); - let output = MODULAR_OUTPUT.map(|c| lv[c]); - let out_aux_red = MODULAR_OUT_AUX_RED.map(|c| lv[c]); + let output = &lv[MODULAR_OUTPUT]; + let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; let is_less_than = builder.one_extension(); eval_ext_circuit_lt( builder, yield_constr, filter, output, - modulus, + &modulus, out_aux_red, is_less_than, ); - let quot = MODULAR_QUO_INPUT.map(|c| lv[c]); + let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); for &x in prod[2 * N_LIMBS..].iter() { let t = builder.mul_extension(filter, x); @@ -395,9 +388,10 @@ fn modular_constr_poly_ext_circuit, const D: usize> } let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); - pol_add_assign_ext_circuit(builder, &mut constr_poly, &output); + pol_add_assign_ext_circuit(builder, &mut constr_poly, output); - let aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); + aux[2 * N_LIMBS - 1] = builder.zero_extension(); let base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << LIMB_BITS)); let t = pol_adjoin_root_ext_circuit(builder, aux, base); pol_add_assign_ext_circuit(builder, &mut constr_poly, &t); @@ -418,8 +412,8 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let constr_poly 
= modular_constr_poly_ext_circuit(lv, builder, yield_constr, filter); - let input0 = MODULAR_INPUT_0.map(|c| lv[c]); - let input1 = MODULAR_INPUT_1.map(|c| lv[c]); + let input0 = read_value(lv, MODULAR_INPUT_0); + let input1 = read_value(lv, MODULAR_INPUT_1); let add_input = pol_add_ext_circuit(builder, input0, input1); let mul_input = pol_mul_wide_ext_circuit(builder, input0, input1); @@ -495,11 +489,7 @@ mod tests { for i in 0..N_RND_TESTS { // set inputs to random values - for (&ai, &bi, &mi) in izip!( - MODULAR_INPUT_0.iter(), - MODULAR_INPUT_1.iter(), - MODULAR_MODULUS.iter() - ) { + for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); lv[mi] = F::from_canonical_u16(rng.gen()); @@ -511,7 +501,7 @@ mod tests { if i > N_RND_TESTS / 2 { // 1 <= start < N_LIMBS let start = (rng.gen::() % (N_LIMBS - 1)) + 1; - for &mi in &MODULAR_MODULUS[start..N_LIMBS] { + for mi in MODULAR_MODULUS.skip(start) { lv[mi] = F::ZERO; } } @@ -549,11 +539,7 @@ mod tests { for _i in 0..N_RND_TESTS { // set inputs to random values and the modulus to zero; // the output is defined to be zero when modulus is zero. 
- for (&ai, &bi, &mi) in izip!( - MODULAR_INPUT_0.iter(), - MODULAR_INPUT_1.iter(), - MODULAR_MODULUS.iter() - ) { + for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); lv[mi] = F::ZERO; @@ -562,7 +548,7 @@ mod tests { generate(&mut lv, op_filter); // check that the correct output was generated - assert!(MODULAR_OUTPUT.iter().all(|&oi| lv[oi] == F::ZERO)); + assert!(lv[MODULAR_OUTPUT].iter().all(|&c| c == F::ZERO)); let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], @@ -577,7 +563,7 @@ mod tests { .all(|&acc| acc == F::ZERO)); // Corrupt one output limb by setting it to a non-zero value - let random_oi = MODULAR_OUTPUT[rng.gen::() % N_LIMBS]; + let random_oi = MODULAR_OUTPUT.start + rng.gen::() % N_LIMBS; lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); eval_packed_generic(&lv, &mut constraint_consumer); diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index 9d6638f1..7dda18e2 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -3,30 +3,57 @@ //! This crate verifies an EVM MUL instruction, which takes two //! 256-bit inputs A and B, and produces a 256-bit output C satisfying //! -//! C = A*B (mod 2^256). +//! C = A*B (mod 2^256), //! -//! Inputs A and B, and output C, are given as arrays of 16-bit +//! i.e. C is the lower half of the usual long multiplication +//! A*B. Inputs A and B, and output C, are given as arrays of 16-bit //! limbs. For example, if the limbs of A are a[0]...a[15], then //! //! A = \sum_{i=0}^15 a[i] β^i, //! -//! where β = 2^16. To verify that A, B and C satisfy the equation we -//! proceed as follows. Define a(x) = \sum_{i=0}^15 a[i] x^i (so A = a(β)) -//! and similarly for b(x) and c(x). Then A*B = C (mod 2^256) if and only -//! if there exist polynomials q and m such that +//! 
where β = 2^16 = 2^LIMB_BITS. To verify that A, B and C satisfy +//! the equation we proceed as follows. Define //! -//! a(x)*b(x) - c(x) - m(x)*x^16 - (β - x)*q(x) == 0. +//! a(x) = \sum_{i=0}^15 a[i] x^i +//! +//! (so A = a(β)) and similarly for b(x) and c(x). Then A*B = C (mod +//! 2^256) if and only if there exists q such that the polynomial +//! +//! a(x) * b(x) - c(x) - x^16 * q(x) +//! +//! is zero when evaluated at x = β, i.e. it is divisible by (x - β); +//! equivalently, there exists a polynomial s (representing the +//! carries from the long multiplication) such that +//! +//! a(x) * b(x) - c(x) - x^16 * q(x) - (x - β) * s(x) == 0 +//! +//! As we only need the lower half of the product, we can omit q(x) +//! since it is multiplied by the modulus β^16 = 2^256. Thus we only +//! need to verify +//! +//! a(x) * b(x) - c(x) - (x - β) * s(x) == 0 +//! +//! In the code below, this "constraint polynomial" is constructed in +//! the variable `constr_poly`. It must be identically zero for the +//! multiplication operation to be verified, or, equivalently, each of +//! its coefficients must be zero. The variable names of the +//! constituent polynomials are (writing N for N_LIMBS=16): +//! +//! a(x) = \sum_{i=0}^{N-1} input0[i] * x^i +//! b(x) = \sum_{i=0}^{N-1} input1[i] * x^i +//! c(x) = \sum_{i=0}^{N-1} output[i] * x^i +//! s(x) = \sum_i^{2N-3} aux[i] * x^i //! //! Because A, B and C are 256-bit numbers, the degrees of a, b and c -//! are (at most) 15. Thus deg(a*b) <= 30, so deg(m) <= 14 and deg(q) -//! <= 29. However, the fact that we're verifying the equality modulo -//! 2^256 means that we can ignore terms of degree >= 16, since for -//! them evaluating at β gives a factor of β^16 = 2^256 which is 0. +//! are (at most) 15. Thus deg(a*b) <= 30 and deg(s) <= 29; however, +//! as we're only verifying the lower half of A*B, we only need to +//! know s(x) up to degree 14 (so that (x - β)*s(x) has degree 15). On +//! 
the other hand, the coefficients of s(x) can be as large as +//! 16*(β-2) or 20 bits. //! -//! Hence, to verify the equality, we don't need m(x) at all, and we -//! only need to know q(x) up to degree 14 (so that (β - x)*q(x) has -//! degree 15). On the other hand, the coefficients of q(x) can be as -//! large as 16*(β-2) or 20 bits. +//! Note that, unlike for the general modular multiplication (see the +//! file `modular.rs`), we don't need to check that output is reduced, +//! since any value of output is less than β^16 and is hence reduced. use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; @@ -35,65 +62,41 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::columns::*; -use crate::arithmetic::utils::{pol_mul_lo, pol_sub_assign}; +use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { - let input0_limbs = MUL_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1_limbs = MUL_INPUT_1.map(|c| lv[c].to_canonical_u64()); + let input0 = read_value_i64_limbs(lv, MUL_INPUT_0); + let input1 = read_value_i64_limbs(lv, MUL_INPUT_1); - const MASK: u64 = (1u64 << LIMB_BITS) - 1u64; + const MASK: i64 = (1i64 << LIMB_BITS) - 1i64; // Input and output have 16-bit limbs - let mut aux_in_limbs = [0u64; N_LIMBS]; - let mut output_limbs = [0u64; N_LIMBS]; + let mut output_limbs = [0i64; N_LIMBS]; // Column-wise pen-and-paper long multiplication on 16-bit limbs. // First calculate the coefficients of a(x)*b(x) (in unreduced_prod), // then do carry propagation to obtain C = c(β) = a(β)*b(β). 
- let mut cy = 0u64; - let mut unreduced_prod = pol_mul_lo(input0_limbs, input1_limbs); + let mut cy = 0i64; + let mut unreduced_prod = pol_mul_lo(input0, input1); for col in 0..N_LIMBS { let t = unreduced_prod[col] + cy; cy = t >> LIMB_BITS; output_limbs[col] = t & MASK; } - // In principle, the last cy could be dropped because this is // multiplication modulo 2^256. However, we need it below for - // aux_in_limbs to handle the fact that unreduced_prod will - // inevitably contain a one digit's worth that is > 2^256. + // aux_limbs to handle the fact that unreduced_prod will + // inevitably contain one digit's worth that is > 2^256. - for (&c, output_limb) in MUL_OUTPUT.iter().zip(output_limbs) { - lv[c] = F::from_canonical_u64(output_limb); - } + lv[MUL_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); pol_sub_assign(&mut unreduced_prod, &output_limbs); - // unreduced_prod is the coefficients of the polynomial a(x)*b(x) - c(x). - // This must be zero when evaluated at x = β = 2^LIMB_BITS, hence it's - // divisible by (β - x). If we write unreduced_prod as - // - // a(x)*b(x) - c(x) = \sum_{i=0}^n p_i x^i + terms of degree > n - // = (β - x) \sum_{i=0}^{n-1} q_i x^i + terms of degree > n - // - // then by comparing coefficients it is easy to see that - // - // q_0 = p_0 / β and q_i = (p_i + q_{i-1}) / β - // - // for 0 < i < n-1 (and the divisions are exact). Because we're - // only calculating the result modulo 2^256, we can ignore the - // terms of degree > n = 15. 
- aux_in_limbs[0] = unreduced_prod[0] >> LIMB_BITS; - for deg in 1..N_LIMBS - 1 { - aux_in_limbs[deg] = (unreduced_prod[deg] + aux_in_limbs[deg - 1]) >> LIMB_BITS; - } - aux_in_limbs[N_LIMBS - 1] = cy; + let mut aux_limbs = pol_remove_root_2exp::(unreduced_prod); + aux_limbs[N_LIMBS - 1] = -cy; - for deg in 0..N_LIMBS { - let c = MUL_AUX_INPUT[deg]; - lv[c] = F::from_canonical_u64(aux_in_limbs[deg]); - } + lv[MUL_AUX_INPUT].copy_from_slice(&aux_limbs.map(|c| F::from_noncanonical_i64(c))); } pub fn eval_packed_generic( @@ -106,38 +109,35 @@ pub fn eval_packed_generic( range_check_error!(MUL_AUX_INPUT, 20); let is_mul = lv[IS_MUL]; - let input0_limbs = MUL_INPUT_0.map(|c| lv[c]); - let input1_limbs = MUL_INPUT_1.map(|c| lv[c]); - let output_limbs = MUL_OUTPUT.map(|c| lv[c]); - let aux_limbs = MUL_AUX_INPUT.map(|c| lv[c]); + let input0_limbs = read_value::(lv, MUL_INPUT_0); + let input1_limbs = read_value::(lv, MUL_INPUT_1); + let output_limbs = read_value::(lv, MUL_OUTPUT); + let aux_limbs = read_value::(lv, MUL_AUX_INPUT); // Constraint poly holds the coefficients of the polynomial that // must be identically zero for this multiplication to be // verified. 
// - // These two lines set constr_poly to the polynomial A(x)B(x) - C(x), - // where A, B and C are the polynomials + // These two lines set constr_poly to the polynomial a(x)b(x) - c(x), + // where a, b and c are the polynomials // - // A(x) = \sum_i input0_limbs[i] * 2^LIMB_BITS - // B(x) = \sum_i input1_limbs[i] * 2^LIMB_BITS - // C(x) = \sum_i output_limbs[i] * 2^LIMB_BITS + // a(x) = \sum_i input0_limbs[i] * x^i + // b(x) = \sum_i input1_limbs[i] * x^i + // c(x) = \sum_i output_limbs[i] * x^i // - // This polynomial should equal (2^LIMB_BITS - x) * Q(x) where Q is + // This polynomial should equal (x - β)*s(x) where s is // - // Q(x) = \sum_i aux_limbs[i] * 2^LIMB_BITS + // s(x) = \sum_i aux_limbs[i] * x^i // let mut constr_poly = pol_mul_lo(input0_limbs, input1_limbs); pol_sub_assign(&mut constr_poly, &output_limbs); - // This subtracts (2^LIMB_BITS - x) * Q(x) from constr_poly. + // This subtracts (x - β) * s(x) from constr_poly. let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); - constr_poly[0] -= base * aux_limbs[0]; - for deg in 1..N_LIMBS { - constr_poly[deg] -= (base * aux_limbs[deg]) - aux_limbs[deg - 1]; - } + pol_sub_assign(&mut constr_poly, &pol_adjoin_root(aux_limbs, base)); // At this point constr_poly holds the coefficients of the - // polynomial A(x)B(x) - C(x) - (2^LIMB_BITS - x)*Q(x). The + // polynomial a(x)b(x) - c(x) - (x - β)*s(x). The // multiplication is valid if and only if all of those // coefficients are zero. 
for &c in &constr_poly { @@ -151,40 +151,17 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_mul = lv[IS_MUL]; - let input0_limbs = MUL_INPUT_0.map(|c| lv[c]); - let input1_limbs = MUL_INPUT_1.map(|c| lv[c]); - let output_limbs = MUL_OUTPUT.map(|c| lv[c]); - let aux_in_limbs = MUL_AUX_INPUT.map(|c| lv[c]); + let input0_limbs = read_value::(lv, MUL_INPUT_0); + let input1_limbs = read_value::(lv, MUL_INPUT_1); + let output_limbs = read_value::(lv, MUL_OUTPUT); + let aux_limbs = read_value::(lv, MUL_AUX_INPUT); - let zero = builder.zero_extension(); - let mut constr_poly = [zero; N_LIMBS]; // pointless init + let mut constr_poly = pol_mul_lo_ext_circuit(builder, input0_limbs, input1_limbs); + pol_sub_assign_ext_circuit(builder, &mut constr_poly, &output_limbs); - // Invariant: i + j = deg - for col in 0..N_LIMBS { - let mut acc = zero; - for i in 0..=col { - let j = col - i; - acc = builder.mul_add_extension(input0_limbs[i], input1_limbs[j], acc); - } - constr_poly[col] = builder.sub_extension(acc, output_limbs[col]); - } - - let base = F::from_canonical_u64(1 << LIMB_BITS); - let one = builder.one_extension(); - // constr_poly[0] = constr_poly[0] - base * aux_in_limbs[0] - constr_poly[0] = - builder.arithmetic_extension(F::ONE, -base, constr_poly[0], one, aux_in_limbs[0]); - for deg in 1..N_LIMBS { - // constr_poly[deg] -= (base*aux_in_limbs[deg] - aux_in_limbs[deg-1]) - let t = builder.arithmetic_extension( - base, - F::NEG_ONE, - aux_in_limbs[deg], - one, - aux_in_limbs[deg - 1], - ); - constr_poly[deg] = builder.sub_extension(constr_poly[deg], t); - } + let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS)); + let rhs = pol_adjoin_root_ext_circuit(builder, aux_limbs, base); + pol_sub_assign_ext_circuit(builder, &mut constr_poly, &rhs); for &c in &constr_poly { let filter = builder.mul_extension(is_mul, c); @@ -241,7 +218,7 @@ mod tests { for _i in 0..N_RND_TESTS { // set inputs 
to random values - for (&ai, bi) in MUL_INPUT_0.iter().zip(MUL_INPUT_1) { + for (ai, bi) in MUL_INPUT_0.zip(MUL_INPUT_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/sub.rs b/evm/src/arithmetic/sub.rs index 25834406..f8377651 100644 --- a/evm/src/arithmetic/sub.rs +++ b/evm/src/arithmetic/sub.rs @@ -6,6 +6,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::add::{eval_ext_circuit_are_equal, eval_packed_generic_are_equal}; use crate::arithmetic::columns::*; +use crate::arithmetic::utils::read_value_u64_limbs; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; @@ -28,14 +29,12 @@ pub(crate) fn u256_sub_br(input0: [u64; N_LIMBS], input1: [u64; N_LIMBS]) -> ([u } pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { - let input0_limbs = SUB_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1_limbs = SUB_INPUT_1.map(|c| lv[c].to_canonical_u64()); + let input0 = read_value_u64_limbs(lv, SUB_INPUT_0); + let input1 = read_value_u64_limbs(lv, SUB_INPUT_1); - let (output_limbs, _) = u256_sub_br(input0_limbs, input1_limbs); + let (output_limbs, _) = u256_sub_br(input0, input1); - for (&c, output_limb) in SUB_OUTPUT.iter().zip(output_limbs) { - lv[c] = F::from_canonical_u64(output_limb); - } + lv[SUB_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_u64(c))); } pub fn eval_packed_generic( @@ -47,13 +46,18 @@ pub fn eval_packed_generic( range_check_error!(SUB_OUTPUT, 16); let is_sub = lv[IS_SUB]; - let input0_limbs = SUB_INPUT_0.iter().map(|&c| lv[c]); - let input1_limbs = SUB_INPUT_1.iter().map(|&c| lv[c]); - let output_limbs = SUB_OUTPUT.iter().map(|&c| lv[c]); + let input0_limbs = &lv[SUB_INPUT_0]; + let input1_limbs = &lv[SUB_INPUT_1]; + let output_limbs = &lv[SUB_OUTPUT]; - let output_computed = input0_limbs.zip(input1_limbs).map(|(a, b)| a - b); + let output_computed = 
input0_limbs.iter().zip(input1_limbs).map(|(&a, &b)| a - b); - eval_packed_generic_are_equal(yield_constr, is_sub, output_limbs, output_computed); + eval_packed_generic_are_equal( + yield_constr, + is_sub, + output_limbs.iter().copied(), + output_computed, + ); } #[allow(clippy::needless_collect)] @@ -63,24 +67,25 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_sub = lv[IS_SUB]; - let input0_limbs = SUB_INPUT_0.iter().map(|&c| lv[c]); - let input1_limbs = SUB_INPUT_1.iter().map(|&c| lv[c]); - let output_limbs = SUB_OUTPUT.iter().map(|&c| lv[c]); + let input0_limbs = &lv[SUB_INPUT_0]; + let input1_limbs = &lv[SUB_INPUT_1]; + let output_limbs = &lv[SUB_OUTPUT]; // Since `map` is lazy and the closure passed to it borrows // `builder`, we can't then borrow builder again below in the call // to `eval_ext_circuit_are_equal`. The solution is to force // evaluation with `collect`. let output_computed = input0_limbs + .iter() .zip(input1_limbs) - .map(|(a, b)| builder.sub_extension(a, b)) + .map(|(&a, &b)| builder.sub_extension(a, b)) .collect::>>(); eval_ext_circuit_are_equal( builder, yield_constr, is_sub, - output_limbs, + output_limbs.iter().copied(), output_computed.into_iter(), ); } @@ -134,7 +139,7 @@ mod tests { for _ in 0..N_RND_TESTS { // set inputs to random values - for (&ai, bi) in SUB_INPUT_0.iter().zip(SUB_INPUT_1) { + for (ai, bi) in SUB_INPUT_0.zip(SUB_INPUT_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index b5356a78..871a9646 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ -1,4 +1,4 @@ -use std::ops::{Add, AddAssign, Mul, Neg, Shr, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Mul, Neg, Range, Shr, Sub, SubAssign}; use log::error; use plonky2::field::extension::Extendable; @@ -6,7 +6,7 @@ use plonky2::hash::hash_types::RichField; use 
plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; -use crate::arithmetic::columns::N_LIMBS; +use crate::arithmetic::columns::{NUM_ARITH_COLUMNS, N_LIMBS}; /// Emit an error message regarding unchecked range assumptions. /// Assumes the values in `cols` are `[cols[0], cols[0] + 1, ..., @@ -14,7 +14,7 @@ use crate::arithmetic::columns::N_LIMBS; pub(crate) fn _range_check_error( file: &str, line: u32, - cols: &[usize], + cols: Range, signedness: &str, ) { error!( @@ -23,8 +23,8 @@ pub(crate) fn _range_check_error( file, RC_BITS, signedness, - cols[0], - cols[0] + cols.len() - 1 + cols.start, + cols.end - 1, ); } @@ -34,7 +34,7 @@ macro_rules! range_check_error { $crate::arithmetic::utils::_range_check_error::<$rc_bits>( file!(), line!(), - &$cols, + $cols, "unsigned", ); }; @@ -42,7 +42,7 @@ macro_rules! range_check_error { $crate::arithmetic::utils::_range_check_error::<$rc_bits>( file!(), line!(), - &$cols, + $cols, "signed", ); }; @@ -225,6 +225,22 @@ where res } +pub(crate) fn pol_mul_lo_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: [ExtensionTarget; N_LIMBS], + b: [ExtensionTarget; N_LIMBS], +) -> [ExtensionTarget; N_LIMBS] { + let zero = builder.zero_extension(); + let mut res = [zero; N_LIMBS]; + for deg in 0..N_LIMBS { + for i in 0..=deg { + let j = deg - i; + res[deg] = builder.mul_add_extension(a[i], b[j], res[deg]); + } + } + res +} + /// Adjoin M - N zeros to a, returning [a[0], a[1], ..., a[N-1], 0, 0, ..., 0]. pub(crate) fn pol_extend(a: [T; N]) -> [T; M] where @@ -248,11 +264,9 @@ pub(crate) fn pol_extend_ext_circuit, const D: usiz zero_extend } -/// Given polynomial a(x) = \sum_{i=0}^{2N-2} a[i] x^i and an element +/// Given polynomial a(x) = \sum_{i=0}^{N-2} a[i] x^i and an element /// `root`, return b = (x - root) * a(x). -/// -/// NB: Ignores element a[2 * N_LIMBS - 1], treating it as if it's 0. 
-pub(crate) fn pol_adjoin_root(a: [T; 2 * N_LIMBS], root: U) -> [T; 2 * N_LIMBS] +pub(crate) fn pol_adjoin_root(a: [T; N], root: U) -> [T; N] where T: Add + Copy + Default + Mul + Sub, U: Copy + Mul + Neg, @@ -261,67 +275,96 @@ where // coefficients, res[0] = -root*a[0] and // res[i] = a[i-1] - root * a[i] - let mut res = [T::default(); 2 * N_LIMBS]; + let mut res = [T::default(); N]; res[0] = -root * a[0]; - for deg in 1..(2 * N_LIMBS - 1) { + for deg in 1..N { res[deg] = a[deg - 1] - (root * a[deg]); } - // NB: We assume that a[2 * N_LIMBS - 1] = 0, so the last - // iteration has no "* root" term. - res[2 * N_LIMBS - 1] = a[2 * N_LIMBS - 2]; res } -pub(crate) fn pol_adjoin_root_ext_circuit, const D: usize>( +pub(crate) fn pol_adjoin_root_ext_circuit< + F: RichField + Extendable, + const D: usize, + const N: usize, +>( builder: &mut CircuitBuilder, - a: [ExtensionTarget; 2 * N_LIMBS], + a: [ExtensionTarget; N], root: ExtensionTarget, -) -> [ExtensionTarget; 2 * N_LIMBS] { +) -> [ExtensionTarget; N] { let zero = builder.zero_extension(); - let mut res = [zero; 2 * N_LIMBS]; + let mut res = [zero; N]; // res[deg] = NEG_ONE * root * a[0] + ZERO * zero res[0] = builder.arithmetic_extension(F::NEG_ONE, F::ZERO, root, a[0], zero); - for deg in 1..(2 * N_LIMBS - 1) { + for deg in 1..N { // res[deg] = NEG_ONE * root * a[deg] + ONE * a[deg - 1] res[deg] = builder.arithmetic_extension(F::NEG_ONE, F::ONE, root, a[deg], a[deg - 1]); } - // NB: We assumes that a[2 * N_LIMBS - 1] = 0, so the last - // iteration has no "* root" term. - res[2 * N_LIMBS - 1] = a[2 * N_LIMBS - 2]; res } -/// Given polynomial a(x) = \sum_{i=0}^{2N-1} a[i] x^i and a root of `a` +/// Given polynomial a(x) = \sum_{i=0}^{N-1} a[i] x^i and a root of `a` /// of the form 2^EXP, return q(x) satisfying a(x) = (x - root) * q(x). /// /// NB: We do not verify that a(2^EXP) = 0; if this doesn't hold the /// result is basically junk. 
/// -/// NB: The result could be returned in 2*N-1 elements, but we return -/// 2*N and set the last element to zero since the calling code -/// happens to require a result zero-extended to 2*N elements. -pub(crate) fn pol_remove_root_2exp(a: [T; 2 * N_LIMBS]) -> [T; 2 * N_LIMBS] +/// NB: The result could be returned in N-1 elements, but we return +/// N and set the last element to zero since the calling code +/// happens to require a result zero-extended to N elements. +pub(crate) fn pol_remove_root_2exp(a: [T; N]) -> [T; N] where T: Copy + Default + Neg + Shr + Sub, { // By assumption β := 2^EXP is a root of `a`, i.e. (x - β) divides // `a`; if we write // - // a(x) = \sum_{i=0}^{2N-1} a[i] x^i - // = (x - β) \sum_{i=0}^{2N-2} q[i] x^i + // a(x) = \sum_{i=0}^{N-1} a[i] x^i + // = (x - β) \sum_{i=0}^{N-2} q[i] x^i // // then by comparing coefficients it is easy to see that // // q[0] = -a[0] / β and q[i] = (q[i-1] - a[i]) / β // - // for 0 < i <= 2N-1 (and the divisions are exact). + // for 0 < i <= N-2 (and the divisions are exact). - let mut q = [T::default(); 2 * N_LIMBS]; + let mut q = [T::default(); N]; q[0] = -(a[0] >> EXP); // NB: Last element of q is deliberately left equal to zero. - for deg in 1..2 * N_LIMBS - 1 { + for deg in 1..N - 1 { q[deg] = (q[deg - 1] - a[deg]) >> EXP; } q } + +/// Read the range `value_idxs` of values from `lv` into an array of +/// length `N`. Panics if the length of the range is not `N`. +pub(crate) fn read_value( + lv: &[T; NUM_ARITH_COLUMNS], + value_idxs: Range, +) -> [T; N] { + lv[value_idxs].try_into().unwrap() +} + +/// Read the range `value_idxs` of values from `lv` into an array of +/// length `N`, interpreting the values as `u64`s. Panics if the +/// length of the range is not `N`.
+pub(crate) fn read_value_u64_limbs( + lv: &[F; NUM_ARITH_COLUMNS], + value_idxs: Range, +) -> [u64; N] { + let limbs: [_; N] = lv[value_idxs].try_into().unwrap(); + limbs.map(|c| F::to_canonical_u64(&c)) +} + +/// Read the range `value_idxs` of values from `lv` into an array of +/// length `N`, interpreting the values as `i64`s. Panics if the +/// length of the range is not `N`. +pub(crate) fn read_value_i64_limbs( + lv: &[F; NUM_ARITH_COLUMNS], + value_idxs: Range, +) -> [i64; N] { + let limbs: [_; N] = lv[value_idxs].try_into().unwrap(); + limbs.map(|c| F::to_canonical_u64(&c) as i64) +} diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index e0cb2952..04d4d0f2 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -19,6 +19,9 @@ pub struct OpsColumnsView { pub mulmod: T, pub exp: T, pub signextend: T, + pub addfp254: T, + pub mulfp254: T, + pub subfp254: T, pub lt: T, pub gt: T, pub slt: T, diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index 3856726c..c7b7c6bb 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -9,7 +9,7 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; // TODO: This list is incomplete. 
-const NATIVE_INSTRUCTIONS: [usize; 25] = [ +const NATIVE_INSTRUCTIONS: [usize; 28] = [ COL_MAP.op.add, COL_MAP.op.mul, COL_MAP.op.sub, @@ -20,6 +20,9 @@ const NATIVE_INSTRUCTIONS: [usize; 25] = [ COL_MAP.op.addmod, COL_MAP.op.mulmod, COL_MAP.op.signextend, + COL_MAP.op.addfp254, + COL_MAP.op.mulfp254, + COL_MAP.op.subfp254, COL_MAP.op.lt, COL_MAP.op.gt, COL_MAP.op.slt, diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index b11ff9f5..7b34cc4f 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -11,7 +11,7 @@ use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::{ - bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, simple_logic, stack, + bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, modfp254, simple_logic, stack, stack_bounds, syscalls, }; use crate::cross_table_lookup::Column; @@ -150,6 +150,7 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark Kernel { include_str!("asm/memory/metadata.asm"), include_str!("asm/memory/packing.asm"), include_str!("asm/memory/txn_fields.asm"), + include_str!("asm/mpt/delete.asm"), include_str!("asm/mpt/hash.asm"), include_str!("asm/mpt/hash_trie_specific.asm"), include_str!("asm/mpt/hex_prefix.asm"), + include_str!("asm/mpt/insert.asm"), + include_str!("asm/mpt/insert_extension.asm"), + include_str!("asm/mpt/insert_leaf.asm"), + include_str!("asm/mpt/insert_trie_specific.asm"), include_str!("asm/mpt/load.asm"), + include_str!("asm/mpt/load_trie_specific.asm"), include_str!("asm/mpt/read.asm"), include_str!("asm/mpt/storage_read.asm"), include_str!("asm/mpt/storage_write.asm"), include_str!("asm/mpt/util.asm"), - include_str!("asm/mpt/write.asm"), include_str!("asm/ripemd/box.asm"), include_str!("asm/ripemd/compression.asm"), include_str!("asm/ripemd/constants.asm"), diff --git 
a/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm b/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm index 931a6a7b..5891807c 100644 --- a/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm +++ b/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm @@ -22,7 +22,7 @@ count_zeros_loop: // stack: zeros', i, retdest SWAP1 // stack: i, zeros', retdest - %add_const(1) + %increment // stack: i', zeros', retdest %jump(count_zeros_loop) diff --git a/evm/src/cpu/kernel/asm/core/util.asm b/evm/src/cpu/kernel/asm/core/util.asm index 4ceaec3b..dfacf1a2 100644 --- a/evm/src/cpu/kernel/asm/core/util.asm +++ b/evm/src/cpu/kernel/asm/core/util.asm @@ -14,7 +14,7 @@ %macro next_context_id // stack: (empty) %mload_global_metadata(@GLOBAL_METADATA_LARGEST_CONTEXT) - %add_const(1) + %increment // stack: new_ctx DUP1 %mstore_global_metadata(@GLOBAL_METADATA_LARGEST_CONTEXT) diff --git a/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm b/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm index 96e177ff..a1c2ff3c 100644 --- a/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm +++ b/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm @@ -132,7 +132,7 @@ pubkey_to_addr: // stack: PKx, PKy, retdest PUSH 0 // stack: 0, PKx, PKy, retdest - MSTORE // TODO: switch to kernel memory (like `%mstore_current(@SEGMENT_KERNEL_GENERAL)`). + MSTORE // TODO: switch to kernel memory (like `%mstore_kernel(@SEGMENT_KERNEL_GENERAL)`). // stack: PKy, retdest PUSH 0x20 // stack: 0x20, PKy, retdest diff --git a/evm/src/cpu/kernel/asm/memory/core.asm b/evm/src/cpu/kernel/asm/memory/core.asm index c2c19811..2b4d2b68 100644 --- a/evm/src/cpu/kernel/asm/memory/core.asm +++ b/evm/src/cpu/kernel/asm/memory/core.asm @@ -55,6 +55,19 @@ // stack: (empty) %endmacro +// Store a single value from the given segment of kernel (context 0) memory. 
+%macro mstore_kernel(segment, offset) + // stack: value + PUSH $offset + // stack: offset, value + PUSH $segment + // stack: segment, offset, value + PUSH 0 // kernel has context 0 + // stack: context, segment, offset, value + MSTORE_GENERAL + // stack: (empty) +%endmacro + // Load from the kernel a big-endian u32, consisting of 4 bytes (c_3, c_2, c_1, c_0) %macro mload_kernel_u32(segment) // stack: offset @@ -64,7 +77,7 @@ %shl_const(8) // stack: c_3 << 8, offset DUP2 - %add_const(1) + %increment %mload_kernel($segment) OR // stack: (c_3 << 8) | c_2, offset @@ -91,7 +104,7 @@ %mload_kernel($segment) // stack: c0 , offset DUP2 - %add_const(1) + %increment %mload_kernel($segment) %shl_const(8) OR @@ -208,7 +221,7 @@ // stack: c_2, c_1, c_0, offset DUP4 // stack: offset, c_2, c_1, c_0, offset - %add_const(1) + %increment %mstore_kernel($segment) // stack: c_1, c_0, offset DUP3 diff --git a/evm/src/cpu/kernel/asm/memory/memcpy.asm b/evm/src/cpu/kernel/asm/memory/memcpy.asm index 3feca35d..dd0569e7 100644 --- a/evm/src/cpu/kernel/asm/memory/memcpy.asm +++ b/evm/src/cpu/kernel/asm/memory/memcpy.asm @@ -28,15 +28,15 @@ global memcpy: // Increment dst_addr. SWAP2 - %add_const(1) + %increment SWAP2 // Increment src_addr. SWAP5 - %add_const(1) + %increment SWAP5 // Decrement count. SWAP6 - %sub_const(1) + %decrement SWAP6 // Continue the loop. diff --git a/evm/src/cpu/kernel/asm/memory/packing.asm b/evm/src/cpu/kernel/asm/memory/packing.asm index c8b4c468..f12c7b17 100644 --- a/evm/src/cpu/kernel/asm/memory/packing.asm +++ b/evm/src/cpu/kernel/asm/memory/packing.asm @@ -71,9 +71,9 @@ mstore_unpacking_loop: // stack: i, context, segment, offset, value, len, retdest // Increment offset. - SWAP3 %add_const(1) SWAP3 + SWAP3 %increment SWAP3 // Increment i. 
- %add_const(1) + %increment %jump(mstore_unpacking_loop) diff --git a/evm/src/cpu/kernel/asm/mpt/delete.asm b/evm/src/cpu/kernel/asm/mpt/delete.asm new file mode 100644 index 00000000..3e0b8afe --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/delete.asm @@ -0,0 +1,6 @@ +// Return a copy of the given node with the given key deleted. +// +// Pre stack: node_ptr, num_nibbles, key, retdest +// Post stack: updated_node_ptr +global mpt_delete: + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/hash.asm b/evm/src/cpu/kernel/asm/mpt/hash.asm index abd436fe..9fe0edef 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash.asm @@ -47,7 +47,29 @@ mpt_hash_hash_rlp_after_unpacking: // Pre stack: node_ptr, encode_value, retdest // Post stack: result, result_len global encode_or_hash_node: - %stack (node_ptr, encode_value) -> (node_ptr, encode_value, maybe_hash_node) + // stack: node_ptr, encode_value, retdest + DUP1 %mload_trie_data + + // Check if we're dealing with a concrete node, i.e. not a hash node. + // stack: node_type, node_ptr, encode_value, retdest + DUP1 + PUSH @MPT_NODE_HASH + SUB + %jumpi(encode_or_hash_concrete_node) + + // If we got here, node_type == @MPT_NODE_HASH. + // Load the hash and return (hash, 32). + // stack: node_type, node_ptr, encode_value, retdest + POP + // stack: node_ptr, encode_value, retdest + %increment // Skip over node type prefix + // stack: hash_ptr, encode_value, retdest + %mload_trie_data + // stack: hash, encode_value, retdest + %stack (hash, encode_value, retdest) -> (retdest, hash, 32) + JUMP +encode_or_hash_concrete_node: + %stack (node_type, node_ptr, encode_value) -> (node_type, node_ptr, encode_value, maybe_hash_node) %jump(encode_node) maybe_hash_node: // stack: result_ptr, result_len, retdest @@ -75,22 +97,22 @@ after_packed_small_rlp: // RLP encode the given trie node, and return an (pointer, length) pair // indicating where the data lives within @SEGMENT_RLP_RAW. 
// -// Pre stack: node_ptr, encode_value, retdest +// Pre stack: node_type, node_ptr, encode_value, retdest // Post stack: result_ptr, result_len -global encode_node: - // stack: node_ptr, encode_value, retdest - DUP1 %mload_trie_data +encode_node: // stack: node_type, node_ptr, encode_value, retdest // Increment node_ptr, so it points to the node payload instead of its type. - SWAP1 %add_const(1) SWAP1 + SWAP1 %increment SWAP1 // stack: node_type, node_payload_ptr, encode_value, retdest DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(encode_node_empty) - DUP1 %eq_const(@MPT_NODE_HASH) %jumpi(encode_node_hash) DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(encode_node_branch) DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(encode_node_extension) DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(encode_node_leaf) - PANIC // Invalid node type? Shouldn't get here. + + // If we got here, node_type is either @MPT_NODE_HASH, which should have + // been handled earlier in encode_or_hash_node, or something invalid. + PANIC global encode_node_empty: // stack: node_type, node_payload_ptr, encode_value, retdest @@ -105,27 +127,27 @@ global encode_node_empty: %stack (retdest) -> (retdest, 0, 1) JUMP -global encode_node_hash: - // stack: node_type, node_payload_ptr, encode_value, retdest - POP - // stack: node_payload_ptr, encode_value, retdest - %mload_trie_data - %stack (hash, encode_value, retdest) -> (retdest, hash, 32) - JUMP - encode_node_branch: // stack: node_type, node_payload_ptr, encode_value, retdest POP // stack: node_payload_ptr, encode_value, retdest + // Get the next unused offset within the encoded child buffers. + // Then immediately increment the next unused offset by 16, so any + // recursive calls will use nonoverlapping offsets. + %mload_global_metadata(@TRIE_ENCODED_CHILD_SIZE) + DUP1 %add_const(16) + %mstore_global_metadata(@TRIE_ENCODED_CHILD_SIZE) + // stack: base_offset, node_payload_ptr, encode_value, retdest + // We will call encode_or_hash_node on each child. 
For the i'th child, we - // will store the result in SEGMENT_KERNEL_GENERAL[i], and its length in - // SEGMENT_KERNEL_GENERAL_2[i]. + // will store the result in SEGMENT_TRIE_ENCODED_CHILD[base + i], and its length in + // SEGMENT_TRIE_ENCODED_CHILD_LEN[base + i]. %encode_child(0) %encode_child(1) %encode_child(2) %encode_child(3) %encode_child(4) %encode_child(5) %encode_child(6) %encode_child(7) %encode_child(8) %encode_child(9) %encode_child(10) %encode_child(11) %encode_child(12) %encode_child(13) %encode_child(14) %encode_child(15) - // stack: node_payload_ptr, encode_value, retdest + // stack: base_offset, node_payload_ptr, encode_value, retdest // Now, append each child to our RLP tape. PUSH 9 // rlp_pos; we start at 9 to leave room to prepend a list prefix @@ -133,25 +155,28 @@ encode_node_branch: %append_child(4) %append_child(5) %append_child(6) %append_child(7) %append_child(8) %append_child(9) %append_child(10) %append_child(11) %append_child(12) %append_child(13) %append_child(14) %append_child(15) + // stack: rlp_pos', base_offset, node_payload_ptr, encode_value, retdest + + // We no longer need base_offset. + SWAP1 + POP // stack: rlp_pos', node_payload_ptr, encode_value, retdest SWAP1 %add_const(16) - // stack: value_len_ptr, rlp_pos', encode_value, retdest - DUP1 %mload_trie_data - // stack: value_len, value_len_ptr, rlp_pos', encode_value, retdest - %jumpi(encode_node_branch_with_value) + // stack: value_ptr_ptr, rlp_pos', encode_value, retdest + %mload_trie_data + // stack: value_ptr, rlp_pos', encode_value, retdest + DUP1 %jumpi(encode_node_branch_with_value) // No value; append the empty string (0x80). 
- // stack: value_len_ptr, rlp_pos', encode_value, retdest - %stack (value_len_ptr, rlp_pos, encode_value) -> (rlp_pos, 0x80, rlp_pos) + // stack: value_ptr, rlp_pos', encode_value, retdest + %stack (value_ptr, rlp_pos, encode_value) -> (rlp_pos, 0x80, rlp_pos) %mstore_rlp // stack: rlp_pos', retdest %increment // stack: rlp_pos'', retdest %jump(encode_node_branch_prepend_prefix) encode_node_branch_with_value: - // stack: value_len_ptr, rlp_pos', encode_value, retdest - %increment // stack: value_ptr, rlp_pos', encode_value, retdest %stack (value_ptr, rlp_pos, encode_value) -> (encode_value, rlp_pos, value_ptr, encode_node_branch_prepend_prefix) @@ -163,43 +188,44 @@ encode_node_branch_prepend_prefix: JUMP // Part of the encode_node_branch function. Encodes the i'th child. -// Stores the result in SEGMENT_KERNEL_GENERAL[i], and its length in -// SEGMENT_KERNEL_GENERAL_2[i]. +// Stores the result in SEGMENT_TRIE_ENCODED_CHILD[base + i], and its length in +// SEGMENT_TRIE_ENCODED_CHILD_LEN[base + i]. 
%macro encode_child(i) - // stack: node_payload_ptr, encode_value, retdest + // stack: base_offset, node_payload_ptr, encode_value, retdest PUSH %%after_encode - DUP3 DUP3 - // stack: node_payload_ptr, encode_value, %%after_encode, node_payload_ptr, encode_value, retdest + DUP4 DUP4 + // stack: node_payload_ptr, encode_value, %%after_encode, base_offset, node_payload_ptr, encode_value, retdest %add_const($i) %mload_trie_data - // stack: child_i_ptr, encode_value, %%after_encode, node_payload_ptr, encode_value, retdest + // stack: child_i_ptr, encode_value, %%after_encode, base_offset, node_payload_ptr, encode_value, retdest %jump(encode_or_hash_node) %%after_encode: - // stack: result, result_len, node_payload_ptr, encode_value, retdest - %mstore_kernel_general($i) - %mstore_kernel_general_2($i) - // stack: node_payload_ptr, encode_value, retdest + // stack: result, result_len, base_offset, node_payload_ptr, encode_value, retdest + DUP3 %add_const($i) %mstore_kernel(@SEGMENT_TRIE_ENCODED_CHILD) + // stack: result_len, base_offset, node_payload_ptr, encode_value, retdest + DUP2 %add_const($i) %mstore_kernel(@SEGMENT_TRIE_ENCODED_CHILD_LEN) + // stack: base_offset, node_payload_ptr, encode_value, retdest %endmacro // Part of the encode_node_branch function. Appends the i'th child's RLP. %macro append_child(i) - // stack: rlp_pos, node_payload_ptr, encode_value, retdest - %mload_kernel_general($i) // load result - %mload_kernel_general_2($i) // load result_len - // stack: result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest + // stack: rlp_pos, base_offset, node_payload_ptr, encode_value, retdest + DUP2 %add_const($i) %mload_kernel(@SEGMENT_TRIE_ENCODED_CHILD) // load result + DUP3 %add_const($i) %mload_kernel(@SEGMENT_TRIE_ENCODED_CHILD_LEN) // load result_len + // stack: result_len, result, rlp_pos, base_offset, node_payload_ptr, encode_value, retdest // If result_len != 32, result is raw RLP, with an appropriate RLP prefix already. 
DUP1 %sub_const(32) %jumpi(%%unpack) // Otherwise, result is a hash, and we need to add the prefix 0x80 + 32 = 160. - // stack: result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest + // stack: result_len, result, rlp_pos, base_offset, node_payload_ptr, encode_value, retdest PUSH 160 DUP4 // rlp_pos %mstore_rlp SWAP2 %increment SWAP2 // rlp_pos += 1 %%unpack: - %stack (result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest) - -> (rlp_pos, result, result_len, %%after_unpacking, node_payload_ptr, encode_value, retdest) + %stack (result_len, result, rlp_pos, base_offset, node_payload_ptr, encode_value, retdest) + -> (rlp_pos, result, result_len, %%after_unpacking, base_offset, node_payload_ptr, encode_value, retdest) %jump(mstore_unpacking_rlp) %%after_unpacking: - // stack: rlp_pos', node_payload_ptr, encode_value, retdest + // stack: rlp_pos', base_offset, node_payload_ptr, encode_value, retdest %endmacro encode_node_extension: @@ -214,7 +240,7 @@ encode_node_extension_after_encode_child: PUSH encode_node_extension_after_hex_prefix // retdest PUSH 0 // terminated // stack: terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest - DUP5 %add_const(1) %mload_trie_data // Load the packed_nibbles field, which is at index 1. + DUP5 %increment %mload_trie_data // Load the packed_nibbles field, which is at index 1. // stack: packed_nibbles, terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest DUP6 %mload_trie_data // Load the num_nibbles field, which is at index 0. 
// stack: num_nibbles, packed_nibbles, terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest @@ -247,7 +273,7 @@ encode_node_leaf: PUSH encode_node_leaf_after_hex_prefix // retdest PUSH 1 // terminated // stack: terminated, encode_node_leaf_after_hex_prefix, node_payload_ptr, encode_value, retdest - DUP3 %add_const(1) %mload_trie_data // Load the packed_nibbles field, which is at index 1. + DUP3 %increment %mload_trie_data // Load the packed_nibbles field, which is at index 1. // stack: packed_nibbles, terminated, encode_node_leaf_after_hex_prefix, node_payload_ptr, encode_value, retdest DUP4 %mload_trie_data // Load the num_nibbles field, which is at index 0. // stack: num_nibbles, packed_nibbles, terminated, encode_node_leaf_after_hex_prefix, node_payload_ptr, encode_value, retdest @@ -257,7 +283,9 @@ encode_node_leaf: encode_node_leaf_after_hex_prefix: // stack: rlp_pos, node_payload_ptr, encode_value, retdest SWAP1 - %add_const(3) // The value starts at index 3, after num_nibbles, packed_nibbles, and value_len. + %add_const(2) // The value pointer starts at index 3, after num_nibbles and packed_nibbles. + // stack: value_ptr_ptr, rlp_pos, encode_value, retdest + %mload_trie_data // stack: value_ptr, rlp_pos, encode_value, retdest %stack (value_ptr, rlp_pos, encode_value, retdest) -> (encode_value, rlp_pos, value_ptr, encode_node_leaf_after_encode_value, retdest) diff --git a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm index 80763deb..4f9b58b4 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm @@ -39,7 +39,7 @@ global mpt_hash_receipt_trie: %%after: %endmacro -encode_account: +global encode_account: // stack: rlp_pos, value_ptr, retdest // First, we compute the length of the RLP data we're about to write. 
// The nonce and balance fields are variable-length, so we need to load them @@ -48,7 +48,7 @@ encode_account: DUP2 %mload_trie_data // nonce = value[0] %rlp_scalar_len // stack: nonce_rlp_len, rlp_pos, value_ptr, retdest - DUP3 %add_const(1) %mload_trie_data // balance = value[1] + DUP3 %increment %mload_trie_data // balance = value[1] %rlp_scalar_len // stack: balance_rlp_len, nonce_rlp_len, rlp_pos, value_ptr, retdest PUSH 66 // storage_root and code_hash fields each take 1 + 32 bytes @@ -68,12 +68,17 @@ encode_account: // stack: nonce, rlp_pos_3, value_ptr, retdest SWAP1 %encode_rlp_scalar // stack: rlp_pos_4, value_ptr, retdest - DUP2 %add_const(1) %mload_trie_data // balance = value[1] + DUP2 %increment %mload_trie_data // balance = value[1] // stack: balance, rlp_pos_4, value_ptr, retdest SWAP1 %encode_rlp_scalar // stack: rlp_pos_5, value_ptr, retdest - DUP2 %add_const(2) %mload_trie_data // storage_root = value[2] - // stack: storage_root, rlp_pos_5, value_ptr, retdest + PUSH encode_account_after_hash_storage_trie + PUSH encode_storage_value + DUP4 %add_const(2) %mload_trie_data // storage_root_ptr = value[2] + // stack: storage_root_ptr, encode_storage_value, encode_account_after_hash_storage_trie, rlp_pos_5, value_ptr, retdest + %jump(mpt_hash) +encode_account_after_hash_storage_trie: + // stack: storage_root_digest, rlp_pos_5, value_ptr, retdest SWAP1 %encode_rlp_256 // stack: rlp_pos_6, value_ptr, retdest SWAP1 %add_const(3) %mload_trie_data // code_hash = value[3] @@ -88,3 +93,6 @@ encode_txn: encode_receipt: PANIC // TODO + +encode_storage_value: + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm b/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm index 72ac18cc..b7a3073b 100644 --- a/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm +++ b/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm @@ -15,7 +15,7 @@ global hex_prefix_rlp: // Compute the length of the hex-prefix string, in bytes: // hp_len = num_nibbles / 2 + 1 = i + 1 - DUP1 %add_const(1) + DUP1 
%increment // stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest // Write the RLP header. @@ -35,7 +35,7 @@ rlp_header_medium: %mstore_rlp // rlp_pos += 1 - SWAP2 %add_const(1) SWAP2 + SWAP2 %increment SWAP2 %jump(start_loop) @@ -49,7 +49,7 @@ rlp_header_large: %mstore_rlp DUP1 // value = hp_len - DUP4 %add_const(1) // offset = rlp_pos + 1 + DUP4 %increment // offset = rlp_pos + 1 %mstore_rlp // rlp_pos += 2 @@ -74,7 +74,7 @@ loop: %mstore_rlp // stack: i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest - %sub_const(1) + %decrement SWAP4 %shr_const(8) SWAP4 // packed_nibbles >>= 8 %jump(loop) diff --git a/evm/src/cpu/kernel/asm/mpt/insert.asm b/evm/src/cpu/kernel/asm/mpt/insert.asm new file mode 100644 index 00000000..2830d376 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/insert.asm @@ -0,0 +1,107 @@ +// Return a copy of the given node, with the given key set to the given value. +// +// Pre stack: node_ptr, num_nibbles, key, value_ptr, retdest +// Post stack: updated_node_ptr +global mpt_insert: + // stack: node_ptr, num_nibbles, key, value_ptr, retdest + DUP1 %mload_trie_data + // stack: node_type, node_ptr, num_nibbles, key, value_ptr, retdest + // Increment node_ptr, so it points to the node payload instead of its type. + SWAP1 %increment SWAP1 + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(mpt_insert_empty) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(mpt_insert_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(mpt_insert_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(mpt_insert_leaf) + + // There's still the MPT_NODE_HASH case, but if we hit a hash node, + // it means the prover failed to provide necessary Merkle data, so panic. 
+ PANIC + +mpt_insert_empty: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + %pop2 + // stack: num_nibbles, key, value_ptr, retdest + // We will append a new leaf node to our MPT tape and return a pointer to it. + %get_trie_data_size + // stack: leaf_ptr, num_nibbles, key, value_ptr, retdest + PUSH @MPT_NODE_LEAF %append_to_trie_data + // stack: leaf_ptr, num_nibbles, key, value_ptr, retdest + SWAP1 %append_to_trie_data + // stack: leaf_ptr, key, value_ptr, retdest + SWAP1 %append_to_trie_data + // stack: leaf_ptr, value_ptr, retdest + SWAP1 %append_to_trie_data + // stack: leaf_ptr, retdest + SWAP1 + JUMP + +mpt_insert_branch: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + %get_trie_data_size + // stack: updated_branch_ptr, node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + SWAP1 + %append_to_trie_data + // stack: updated_branch_ptr, node_payload_ptr, num_nibbles, key, value_ptr, retdest + SWAP1 + // stack: node_payload_ptr, updated_branch_ptr, num_nibbles, key, value_ptr, retdest + + // Copy the original node's data to our updated node. + DUP1 %mload_trie_data %append_to_trie_data // Copy child[0] + DUP1 %add_const(1) %mload_trie_data %append_to_trie_data // ... 
+ DUP1 %add_const(2) %mload_trie_data %append_to_trie_data + DUP1 %add_const(3) %mload_trie_data %append_to_trie_data + DUP1 %add_const(4) %mload_trie_data %append_to_trie_data + DUP1 %add_const(5) %mload_trie_data %append_to_trie_data + DUP1 %add_const(6) %mload_trie_data %append_to_trie_data + DUP1 %add_const(7) %mload_trie_data %append_to_trie_data + DUP1 %add_const(8) %mload_trie_data %append_to_trie_data + DUP1 %add_const(9) %mload_trie_data %append_to_trie_data + DUP1 %add_const(10) %mload_trie_data %append_to_trie_data + DUP1 %add_const(11) %mload_trie_data %append_to_trie_data + DUP1 %add_const(12) %mload_trie_data %append_to_trie_data + DUP1 %add_const(13) %mload_trie_data %append_to_trie_data + DUP1 %add_const(14) %mload_trie_data %append_to_trie_data + DUP1 %add_const(15) %mload_trie_data %append_to_trie_data // Copy child[15] + %add_const(16) %mload_trie_data %append_to_trie_data // Copy value_ptr + + // At this point, we branch based on whether the key terminates with this branch node. + // stack: updated_branch_ptr, num_nibbles, key, value_ptr, retdest + DUP2 %jumpi(mpt_insert_branch_nonterminal) + + // The key terminates here, so the value will be placed right in our (updated) branch node. + // stack: updated_branch_ptr, num_nibbles, key, value_ptr, retdest + SWAP3 + // stack: value_ptr, num_nibbles, key, updated_branch_ptr, retdest + DUP4 %add_const(17) + // stack: updated_branch_value_ptr_ptr, value_ptr, num_nibbles, key, updated_branch_ptr, retdest + %mstore_trie_data + // stack: num_nibbles, key, updated_branch_ptr, retdest + %pop2 + // stack: updated_branch_ptr, retdest + SWAP1 + JUMP + +mpt_insert_branch_nonterminal: + // The key continues, so we split off the first (most significant) nibble, + // and recursively insert into the child associated with that nibble. 
+ // stack: updated_branch_ptr, num_nibbles, key, value_ptr, retdest + %stack (updated_branch_ptr, num_nibbles, key) -> (num_nibbles, key, updated_branch_ptr) + %split_first_nibble + // stack: first_nibble, num_nibbles, key, updated_branch_ptr, value_ptr, retdest + DUP4 %increment ADD + // stack: child_ptr_ptr, num_nibbles, key, updated_branch_ptr, value_ptr, retdest + %stack (child_ptr_ptr, num_nibbles, key, updated_branch_ptr, value_ptr) + -> (child_ptr_ptr, num_nibbles, key, value_ptr, + mpt_insert_branch_nonterminal_after_recursion, + child_ptr_ptr, updated_branch_ptr) + %mload_trie_data // Deref child_ptr_ptr, giving child_ptr + %jump(mpt_insert) + +mpt_insert_branch_nonterminal_after_recursion: + // stack: updated_child_ptr, child_ptr_ptr, updated_branch_ptr, retdest + SWAP1 %mstore_trie_data // Store the pointer to the updated child. + // stack: updated_branch_ptr, retdest + SWAP1 + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/insert_extension.asm b/evm/src/cpu/kernel/asm/mpt/insert_extension.asm new file mode 100644 index 00000000..3ead805b --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/insert_extension.asm @@ -0,0 +1,201 @@ +/* +Insert into an extension node. +The high-level logic can be expressed with the following pseudocode: + +common_len, common_key, node_len, node_key, insert_len, insert_key = + split_common_prefix(node_len, node_key, insert_len, insert_key) + +if node_len == 0: + new_node = insert(node_child, insert_len, insert_key, insert_value) +else: + new_node = [MPT_TYPE_BRANCH] + [0] * 17 + + // Process the node's child. + if node_len > 1: + // The node key continues with multiple nibbles left, so we can't place + // node_child directly in the branch, but need an extension for it. 
+ node_key_first, node_len, node_key = split_first_nibble(node_len, node_key) + new_node[node_key_first + 1] = [MPT_TYPE_EXTENSION, node_len, node_key, node_child] + else: + // The remaining node_key is a single nibble, so we can place node_child directly in the branch. + new_node[node_key + 1] = node_child + + // Process the inserted entry. + if insert_len > 0: + // The insert key continues. Add a leaf node for it. + insert_key_first, insert_len, insert_key = split_first_nibble(insert_len, insert_key) + new_node[insert_key_first + 1] = [MPT_TYPE_LEAF, insert_len, insert_key, insert_value] + else: + new_node[17] = insert_value + +if common_len > 0: + return [MPT_TYPE_EXTENSION, common_len, common_key, new_node] +else: + return new_node +*/ + +global mpt_insert_extension: + // stack: node_type, node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + POP + // stack: node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + + // We start by loading the extension node's three fields: node_len, node_key, node_child_ptr + DUP1 %add_const(2) %mload_trie_data + // stack: node_child_ptr, node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + %stack (node_child_ptr, node_payload_ptr, insert_len, insert_key) + -> (node_payload_ptr, insert_len, insert_key, node_child_ptr) + // stack: node_payload_ptr, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP1 %increment %mload_trie_data + // stack: node_key, node_payload_ptr, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + SWAP1 %mload_trie_data + // stack: node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + + // Next, we split off any key prefix which is common to the node's key and the inserted key. 
+ %split_common_prefix + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + + // Now we branch based on whether the node key continues beyond the common prefix. + DUP3 %jumpi(node_key_continues) + + // The node key does not continue. In this case we recurse. Pseudocode: + // new_node = insert(node_child, insert_len, insert_key, insert_value) + // and then proceed to maybe_add_extension_for_common_key. + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + PUSH maybe_add_extension_for_common_key + DUP9 // insert_value_ptr + DUP8 // insert_key + DUP8 // insert_len + DUP11 // node_child_ptr + %jump(mpt_insert) + +node_key_continues: + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // Allocate new_node, a branch node which is initially empty + // Pseudocode: new_node = [MPT_TYPE_BRANCH] + [0] * 17 + %get_trie_data_size // pointer to the branch node we're about to create + PUSH @MPT_NODE_BRANCH %append_to_trie_data + %rep 17 + PUSH 0 %append_to_trie_data + %endrep + +process_node_child: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // We want to check if node_len > 1. We already know node_len > 0 since we're in node_key_continues, + // so it suffices to check 1 - node_len != 0 + DUP4 // node_len + PUSH 1 SUB + %jumpi(node_key_continues_multiple_nibbles) + + // If we got here, node_len = 1. 
+ // Pseudocode: new_node[node_key + 1] = node_child + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP8 // node_child_ptr + DUP2 // new_node_ptr + %increment + DUP7 // node_key + ADD + %mstore_trie_data + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + %jump(process_inserted_entry) + +node_key_continues_multiple_nibbles: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // Pseudocode: node_key_first, node_len, node_key = split_first_nibble(node_len, node_key) + // To minimize stack manipulation, we won't actually mutate the node_len, node_key variables in our stack. + // Instead we will duplicate them, and leave the old ones alone; they won't be used. + DUP5 DUP5 + // stack: node_len, node_key, new_node_ptr, ... + %split_first_nibble + // stack: node_key_first, node_len, node_key, new_node_ptr, ... + + // Pseudocode: new_node[node_key_first + 1] = [MPT_TYPE_EXTENSION, node_len, node_key, node_child] + %get_trie_data_size // pointer to the extension node we're about to create + // stack: ext_node_ptr, node_key_first, node_len, node_key, new_node_ptr, ... + PUSH @MPT_NODE_EXTENSION %append_to_trie_data + // stack: ext_node_ptr, node_key_first, node_len, node_key, new_node_ptr, ... + SWAP2 %append_to_trie_data // Append node_len + // stack: node_key_first, ext_node_ptr, node_key, new_node_ptr, ... + SWAP2 %append_to_trie_data // Append node_key + // stack: ext_node_ptr, node_key_first, new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP10 %append_to_trie_data // Append node_child_ptr + + SWAP1 + // stack: node_key_first, ext_node_ptr, new_node_ptr, ... 
+ DUP3 // new_node_ptr + ADD + %increment + // stack: new_node_ptr + node_key_first + 1, ext_node_ptr, new_node_ptr, ... + %mstore_trie_data + %jump(process_inserted_entry) + +process_inserted_entry: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP6 // insert_len + %jumpi(insert_key_continues) + + // If we got here, insert_len = 0, so we store the inserted value directly in our new branch node. + // Pseudocode: new_node[17] = insert_value + DUP9 // insert_value_ptr + DUP2 // new_node_ptr + %add_const(17) + %mstore_trie_data + %jump(maybe_add_extension_for_common_key) + +insert_key_continues: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // Pseudocode: insert_key_first, insert_len, insert_key = split_first_nibble(insert_len, insert_key) + // To minimize stack manipulation, we won't actually mutate the node_len, node_key variables in our stack. + // Instead we will duplicate them, and leave the old ones alone; they won't be used. + DUP7 DUP7 + // stack: insert_len, insert_key, new_node_ptr, ... + %split_first_nibble + // stack: insert_key_first, insert_len, insert_key, new_node_ptr, ... + + // Pseudocode: new_node[insert_key_first + 1] = [MPT_TYPE_LEAF, insert_len, insert_key, insert_value] + %get_trie_data_size // pointer to the leaf node we're about to create + // stack: leaf_node_ptr, insert_key_first, insert_len, insert_key, new_node_ptr, ... + PUSH @MPT_NODE_LEAF %append_to_trie_data + // stack: leaf_node_ptr, insert_key_first, insert_len, insert_key, new_node_ptr, ... + SWAP2 %append_to_trie_data // Append insert_len + // stack: insert_key_first, leaf_node_ptr, insert_key, new_node_ptr, ... 
+ SWAP2 %append_to_trie_data // Append insert_key + // stack: leaf_node_ptr, insert_key_first, new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP11 %append_to_trie_data // Append insert_value_ptr + + SWAP1 + // stack: insert_key_first, leaf_node_ptr, new_node_ptr, ... + DUP3 // new_node_ptr + ADD + %increment + // stack: new_node_ptr + insert_key_first + 1, leaf_node_ptr, new_node_ptr, ... + %mstore_trie_data + %jump(maybe_add_extension_for_common_key) + +maybe_add_extension_for_common_key: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // If common_len > 0, we need to add an extension node. + DUP2 %jumpi(add_extension_for_common_key) + // Otherwise, we simply return new_node_ptr. + SWAP8 + %pop8 + // stack: new_node_ptr, retdest + SWAP1 + JUMP + +add_extension_for_common_key: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // Pseudocode: return [MPT_TYPE_EXTENSION, common_len, common_key, new_node] + %get_trie_data_size // pointer to the extension node we're about to create + // stack: extension_ptr, new_node_ptr, common_len, common_key, ... + PUSH @MPT_NODE_EXTENSION %append_to_trie_data + SWAP2 %append_to_trie_data // Append common_len to our node + // stack: new_node_ptr, extension_ptr, common_key, ... + SWAP2 %append_to_trie_data // Append common_key to our node + // stack: extension_ptr, new_node_ptr, ... 
+ SWAP1 %append_to_trie_data // Append new_node_ptr to our node + // stack: extension_ptr, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + SWAP6 + %pop6 + // stack: extension_ptr, retdest + SWAP1 + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm b/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm new file mode 100644 index 00000000..6afe2f14 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm @@ -0,0 +1,193 @@ +/* +Insert into a leaf node. +The high-level logic can be expressed with the following pseudocode: + +if node_len == insert_len && node_key == insert_key: + return Leaf[node_key, insert_value] + +common_len, common_key, node_len, node_key, insert_len, insert_key = + split_common_prefix(node_len, node_key, insert_len, insert_key) + +branch = [MPT_TYPE_BRANCH] + [0] * 17 + +// Process the node's entry. +if node_len > 0: + node_key_first, node_len, node_key = split_first_nibble(node_len, node_key) + branch[node_key_first + 1] = [MPT_TYPE_LEAF, node_len, node_key, node_value] +else: + branch[17] = node_value + +// Process the inserted entry. +if insert_len > 0: + insert_key_first, insert_len, insert_key = split_first_nibble(insert_len, insert_key) + branch[insert_key_first + 1] = [MPT_TYPE_LEAF, insert_len, insert_key, insert_value] +else: + branch[17] = insert_value + +// Add an extension node if there is a common prefix. 
+if common_len > 0: + return [MPT_TYPE_EXTENSION, common_len, common_key, branch] +else: + return branch +*/ + +global mpt_insert_leaf: + // stack: node_type, node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + POP + // stack: node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + %stack (node_payload_ptr, insert_len, insert_key) -> (insert_len, insert_key, node_payload_ptr) + // stack: insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + DUP3 %increment %mload_trie_data + // stack: node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + DUP4 %mload_trie_data + // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + + // If the keys match, i.e. node_len == insert_len && node_key == insert_key, + // then we're simply replacing the leaf node's value. Since this is a common + // case, it's best to detect it early. Calling %split_common_prefix could be + // expensive as leaf keys tend to be long. + DUP1 DUP4 EQ // node_len == insert_len + DUP3 DUP6 EQ // node_key == insert_key + MUL // Cheaper than AND + // stack: keys_match, node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + %jumpi(keys_match) + + // Replace node_payload_ptr with node_value, which is node_payload[2]. + // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + SWAP4 + %add_const(2) + %mload_trie_data + SWAP4 + // stack: node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + + // Split off any common prefix between the node key and the inserted key. + %split_common_prefix + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + + // For the remaining cases, we will need a new branch node since the two keys diverge. + // We may also need an extension node above it (if common_len > 0); we will handle that later. 
+ // For now, we allocate the branch node, initially with no children or value. + %get_trie_data_size // pointer to the branch node we're about to create + PUSH @MPT_NODE_BRANCH %append_to_trie_data + %rep 17 + PUSH 0 %append_to_trie_data + %endrep + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + + // Now, we branch based on whether each key continues beyond the common + // prefix, starting with the node key. + +process_node_entry: + DUP4 // node_len + %jumpi(node_key_continues) + + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // branch[17] = node_value_ptr + DUP8 // node_value_ptr + DUP2 // branch_ptr + %add_const(17) + %mstore_trie_data + +process_inserted_entry: + DUP6 // insert_len + %jumpi(insert_key_continues) + + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // branch[17] = insert_value_ptr + DUP9 // insert_value_ptr + DUP2 // branch_ptr + %add_const(17) + %mstore_trie_data + +maybe_add_extension_for_common_key: + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // If common_len > 0, we need to add an extension node. + DUP2 %jumpi(add_extension_for_common_key) + // Otherwise, we simply return branch_ptr. + SWAP8 + %pop8 + // stack: branch_ptr, retdest + SWAP1 + JUMP + +add_extension_for_common_key: + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // Pseudocode: return [MPT_TYPE_EXTENSION, common_len, common_key, branch] + %get_trie_data_size // pointer to the extension node we're about to create + // stack: extension_ptr, branch_ptr, common_len, common_key, ... 
+ PUSH @MPT_NODE_EXTENSION %append_to_trie_data + SWAP2 %append_to_trie_data // Append common_len to our node + // stack: branch_ptr, extension_ptr, common_key, ... + SWAP2 %append_to_trie_data // Append common_key to our node + // stack: extension_ptr, branch_ptr, ... + SWAP1 %append_to_trie_data // Append branch_ptr to our node + // stack: extension_ptr, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + SWAP6 + %pop6 + // stack: extension_ptr, retdest + SWAP1 + JUMP + +node_key_continues: + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // branch[node_key_first + 1] = Leaf[node_len, node_key, node_value] + // To minimize stack manipulation, we won't actually mutate the node_len, node_key variables in our stack. + // Instead we will duplicate them, and leave the old ones alone; they won't be used. + DUP5 DUP5 + // stack: node_len, node_key, branch_ptr, ... + %split_first_nibble + // stack: node_key_first, node_len, node_key, branch_ptr, ... + %get_trie_data_size // pointer to the leaf node we're about to create + // stack: leaf_ptr, node_key_first, node_len, node_key, branch_ptr, ... + SWAP1 + DUP5 // branch_ptr + %increment // Skip over node type field + ADD // Add node_key_first + %mstore_trie_data + // stack: node_len, node_key, branch_ptr, ... 
+ PUSH @MPT_NODE_LEAF %append_to_trie_data + %append_to_trie_data // Append node_len to our leaf node + %append_to_trie_data // Append node_key to our leaf node + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + DUP8 %append_to_trie_data // Append node_value_ptr to our leaf node + %jump(process_inserted_entry) + +insert_key_continues: + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // branch[insert_key_first + 1] = Leaf[insert_len, insert_key, insert_value] + // To minimize stack manipulation, we won't actually mutate the insert_len, insert_key variables in our stack. + // Instead we will duplicate them, and leave the old ones alone; they won't be used. + DUP7 DUP7 + // stack: insert_len, insert_key, branch_ptr, ... + %split_first_nibble + // stack: insert_key_first, insert_len, insert_key, branch_ptr, ... + %get_trie_data_size // pointer to the leaf node we're about to create + // stack: leaf_ptr, insert_key_first, insert_len, insert_key, branch_ptr, ... + SWAP1 + DUP5 // branch_ptr + %increment // Skip over node type field + ADD // Add insert_key_first + %mstore_trie_data + // stack: insert_len, insert_key, branch_ptr, ... 
+ PUSH @MPT_NODE_LEAF %append_to_trie_data + %append_to_trie_data // Append insert_len to our leaf node + %append_to_trie_data // Append insert_key to our leaf node + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + DUP9 %append_to_trie_data // Append insert_value_ptr to our leaf node + %jump(maybe_add_extension_for_common_key) + +keys_match: + // The keys match exactly, so we simply create a new leaf node with the new value. + // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + %stack (node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr) + -> (node_len, node_key, insert_value_ptr) + // stack: common_len, common_key, insert_value_ptr, retdest + %get_trie_data_size // pointer to the leaf node we're about to create + // stack: updated_leaf_ptr, common_len, common_key, insert_value_ptr, retdest + PUSH @MPT_NODE_LEAF %append_to_trie_data + SWAP1 %append_to_trie_data // Append common_len to our leaf node + SWAP1 %append_to_trie_data // Append common_key to our leaf node + SWAP1 %append_to_trie_data // Append insert_value_ptr to our leaf node + // stack: updated_leaf_ptr, retdest + SWAP1 + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm new file mode 100644 index 00000000..4c03d96c --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm @@ -0,0 +1,14 @@ +// Insertion logic specific to a particular trie. + +// Mutate the state trie, inserting the given key-value pair.
+global mpt_insert_state_trie: + // stack: num_nibbles, key, value_ptr, retdest + %stack (num_nibbles, key, value_ptr) + -> (num_nibbles, key, value_ptr, mpt_insert_state_trie_save) + %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + // stack: state_root_ptr, num_nibbles, key, value_ptr, mpt_insert_state_trie_save, retdest + %jump(mpt_insert) +mpt_insert_state_trie_save: + // stack: updated_node_ptr, retdest + %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/load.asm b/evm/src/cpu/kernel/asm/mpt/load.asm index f072f202..d787074b 100644 --- a/evm/src/cpu/kernel/asm/mpt/load.asm +++ b/evm/src/cpu/kernel/asm/mpt/load.asm @@ -1,6 +1,3 @@ -// TODO: Receipt trie leaves are variable-length, so we need to be careful not -// to permit buffer over-reads. - // Load all partial trie data from prover inputs. global load_all_mpts: // stack: retdest @@ -9,49 +6,20 @@ global load_all_mpts: PUSH 1 %set_trie_data_size - %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) - %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) - %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) + %load_mpt(mpt_load_state_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + %load_mpt(mpt_load_txn_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) + %load_mpt(mpt_load_receipt_trie_value) %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) - PROVER_INPUT(mpt) - // stack: num_storage_tries, retdest - DUP1 %mstore_global_metadata(@GLOBAL_METADATA_NUM_STORAGE_TRIES) - // stack: num_storage_tries, retdest - PUSH 0 // i = 0 - // stack: i, num_storage_tries, retdest -storage_trie_loop: - DUP2 DUP2 EQ - // stack: i == num_storage_tries, i, num_storage_tries, retdest - %jumpi(storage_trie_loop_end) - // stack: i, num_storage_tries, retdest - PROVER_INPUT(mpt) - // stack: storage_trie_addr, i, 
num_storage_tries, retdest - DUP2 - // stack: i, storage_trie_addr, i, num_storage_tries, retdest - %mstore_kernel(@SEGMENT_STORAGE_TRIE_ADDRS) - // stack: i, num_storage_tries, retdest - %load_mpt_and_return_root_ptr - // stack: root_ptr, i, num_storage_tries, retdest - DUP2 - // stack: i, root_ptr, i, num_storage_tries, retdest - %mstore_kernel(@SEGMENT_STORAGE_TRIE_PTRS) - // stack: i, num_storage_tries, retdest - %jump(storage_trie_loop) -storage_trie_loop_end: - // stack: i, num_storage_tries, retdest - %pop2 // stack: retdest JUMP // Load an MPT from prover inputs. -// Pre stack: retdest -// Post stack: (empty) -load_mpt: - // stack: retdest +// Pre stack: load_value, retdest +// Post stack: node_ptr +global load_mpt: + // stack: load_value, retdest PROVER_INPUT(mpt) - // stack: node_type, retdest - DUP1 %append_to_trie_data - // stack: node_type, retdest + // stack: node_type, load_value, retdest DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(load_mpt_empty) DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(load_mpt_branch) @@ -61,121 +29,145 @@ load_mpt: PANIC // Invalid node type load_mpt_empty: - // stack: node_type, retdest - POP - // stack: retdest + // TRIE_DATA[0] = 0, and an empty node has type 0, so we can simply return the null pointer. + %stack (node_type, load_value, retdest) -> (retdest, 0) JUMP load_mpt_branch: - // stack: node_type, retdest - POP - // stack: retdest + // stack: node_type, load_value, retdest + %get_trie_data_size + // stack: node_ptr, node_type, load_value, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, load_value, retdest // Save the offset of our 16 child pointers so we can write them later. - // Then advance out current trie pointer beyond them, so we can load the + // Then advance our current trie pointer beyond them, so we can load the // value and have it placed after our child pointers. 
%get_trie_data_size - // stack: ptr_children, retdest - DUP1 %add_const(16) - // stack: ptr_leaf, ptr_children, retdest - %set_trie_data_size - // stack: ptr_children, retdest - %load_leaf_value + // stack: children_ptr, node_ptr, load_value, retdest + DUP1 %add_const(17) // Skip over 16 children plus the value pointer + // stack: end_of_branch_ptr, children_ptr, node_ptr, load_value, retdest + DUP1 %set_trie_data_size + // Now the top of the stack points to where the branch node will end and the + // value will begin, if there is a value. But we need to ask the prover if a + // value is present, and point to null if not. + // stack: end_of_branch_ptr, children_ptr, node_ptr, load_value, retdest + PROVER_INPUT(mpt) + // stack: is_value_present, end_of_branch_ptr, children_ptr, node_ptr, load_value, retdest + %jumpi(load_mpt_branch_value_present) + // There is no value present, so value_ptr = null. + %stack (end_of_branch_ptr) -> (0) + // stack: value_ptr, children_ptr, node_ptr, load_value, retdest + %jump(load_mpt_branch_after_load_value) +load_mpt_branch_value_present: + // stack: value_ptr, children_ptr, node_ptr, load_value, retdest + PUSH load_mpt_branch_after_load_value + DUP5 // load_value + JUMP +load_mpt_branch_after_load_value: + // stack: value_ptr, children_ptr, node_ptr, load_value, retdest + SWAP1 + // stack: children_ptr, value_ptr, node_ptr, load_value, retdest // Load the 16 children. 
%rep 16 - %load_mpt_and_return_root_ptr - // stack: child_ptr, ptr_next_child, retdest + DUP4 // load_value + %load_mpt + // stack: child_ptr, next_child_ptr_ptr, value_ptr, node_ptr, load_value, retdest DUP2 - // stack: ptr_next_child, child_ptr, ptr_next_child, retdest + // stack: next_child_ptr_ptr, child_ptr, next_child_ptr_ptr, value_ptr, node_ptr, load_value, retdest %mstore_trie_data - // stack: ptr_next_child, retdest + // stack: next_child_ptr_ptr, value_ptr, node_ptr, load_value, retdest %increment - // stack: ptr_next_child, retdest + // stack: next_child_ptr_ptr, value_ptr, node_ptr, load_value, retdest %endrep - // stack: ptr_next_child, retdest - POP + // stack: value_ptr_ptr, value_ptr, node_ptr, load_value, retdest + %mstore_trie_data + %stack (node_ptr, load_value, retdest) -> (retdest, node_ptr) JUMP load_mpt_extension: - // stack: node_type, retdest - POP - // stack: retdest + // stack: node_type, load_value, retdest + %get_trie_data_size + // stack: node_ptr, node_type, load_value, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, load_value, retdest PROVER_INPUT(mpt) // read num_nibbles %append_to_trie_data PROVER_INPUT(mpt) // read packed_nibbles %append_to_trie_data - // stack: retdest + // stack: node_ptr, load_value, retdest - // Let i be the current trie data size. We still need to expand this node by - // one element, appending our child pointer. Thus our child node will start - // at i + 1. So we will set our child pointer to i + 1. %get_trie_data_size - %add_const(1) - %append_to_trie_data - // stack: retdest - - %load_mpt - // stack: retdest + // stack: child_ptr_ptr, node_ptr, load_value, retdest + // Increment trie_data_size, to leave room for child_ptr_ptr, before we load our child. 
+ DUP1 %increment %set_trie_data_size + %stack (child_ptr_ptr, node_ptr, load_value, retdest) + -> (load_value, load_mpt_extension_after_load_mpt, + child_ptr_ptr, retdest, node_ptr) + %jump(load_mpt) +load_mpt_extension_after_load_mpt: + // stack: child_ptr, child_ptr_ptr, retdest, node_ptr + SWAP1 %mstore_trie_data + // stack: retdest, node_ptr JUMP load_mpt_leaf: - // stack: node_type, retdest - POP - // stack: retdest + // stack: node_type, load_value, retdest + %get_trie_data_size + // stack: node_ptr, node_type, load_value, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, load_value, retdest PROVER_INPUT(mpt) // read num_nibbles %append_to_trie_data PROVER_INPUT(mpt) // read packed_nibbles %append_to_trie_data - // stack: retdest - %load_leaf_value - // stack: retdest + // stack: node_ptr, load_value, retdest + // We save value_ptr_ptr = get_trie_data_size, then increment trie_data_size + // to skip over the slot for value_ptr_ptr. We will write to value_ptr_ptr + // after the load_value call. 
+ %get_trie_data_size + // stack: value_ptr_ptr, node_ptr, load_value, retdest + DUP1 %increment + // stack: value_ptr, value_ptr_ptr, node_ptr, load_value, retdest + DUP1 %set_trie_data_size + // stack: value_ptr, value_ptr_ptr, node_ptr, load_value, retdest + %stack (value_ptr, value_ptr_ptr, node_ptr, load_value, retdest) + -> (load_value, load_mpt_leaf_after_load_value, + value_ptr_ptr, value_ptr, retdest, node_ptr) + JUMP +load_mpt_leaf_after_load_value: + // stack: value_ptr_ptr, value_ptr, retdest, node_ptr + %mstore_trie_data + // stack: retdest, node_ptr JUMP load_mpt_digest: - // stack: node_type, retdest - POP - // stack: retdest + // stack: node_type, load_value, retdest + %get_trie_data_size + // stack: node_ptr, node_type, load_value, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, load_value, retdest PROVER_INPUT(mpt) // read digest %append_to_trie_data - // stack: retdest + %stack (node_ptr, load_value, retdest) -> (retdest, node_ptr) JUMP // Convenience macro to call load_mpt and return where we left off. +// Pre stack: load_value +// Post stack: node_ptr %macro load_mpt - PUSH %%after + %stack (load_value) -> (load_value, %%after) %jump(load_mpt) %%after: %endmacro -%macro load_mpt_and_return_root_ptr - // stack: (empty) - %get_trie_data_size - // stack: ptr - %load_mpt - // stack: ptr -%endmacro - -// Load a leaf from prover input, and append it to trie data. -%macro load_leaf_value - // stack: (empty) - PROVER_INPUT(mpt) - // stack: leaf_len - DUP1 %append_to_trie_data - // stack: leaf_len -%%loop: - DUP1 ISZERO - // stack: leaf_len == 0, leaf_len - %jumpi(%%finish) - // stack: leaf_len - PROVER_INPUT(mpt) - // stack: leaf_part, leaf_len - %append_to_trie_data - // stack: leaf_len - %sub_const(1) - // stack: leaf_len' - %jump(%%loop) -%%finish: - POP - // stack: (empty) +// Convenience macro to call load_mpt and return where we left off. 
+// Pre stack: (empty) +// Post stack: node_ptr +%macro load_mpt(load_value) + PUSH %%after + PUSH $load_value + %jump(load_mpt) +%%after: %endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/load_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/load_trie_specific.asm new file mode 100644 index 00000000..b93b36e4 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/load_trie_specific.asm @@ -0,0 +1,40 @@ +global mpt_load_state_trie_value: + // stack: retdest + + // Load and append the nonce and balance. + PROVER_INPUT(mpt) %append_to_trie_data + PROVER_INPUT(mpt) %append_to_trie_data + + // Now increment the trie data size by 2, to leave room for our storage trie + // pointer and code hash fields, before calling load_mpt which will append + // our storage trie data. + %get_trie_data_size + // stack: storage_trie_ptr_ptr, retdest + DUP1 %add_const(2) + // stack: storage_trie_ptr, storage_trie_ptr_ptr, retdest + %set_trie_data_size + // stack: storage_trie_ptr_ptr, retdest + + %load_mpt(mpt_load_storage_trie_value) + // stack: storage_trie_ptr, storage_trie_ptr_ptr, retdest + DUP2 %mstore_trie_data + // stack: storage_trie_ptr_ptr, retdest + %increment + // stack: code_hash_ptr, retdest + PROVER_INPUT(mpt) + // stack: code_hash, code_hash_ptr, retdest + SWAP1 %mstore_trie_data + // stack: retdest + JUMP + +global mpt_load_txn_trie_value: + // stack: retdest + PANIC // TODO + +global mpt_load_receipt_trie_value: + // stack: retdest + PANIC // TODO + +global mpt_load_storage_trie_value: + // stack: retdest + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/read.asm b/evm/src/cpu/kernel/asm/mpt/read.asm index f952f49a..d375bedc 100644 --- a/evm/src/cpu/kernel/asm/mpt/read.asm +++ b/evm/src/cpu/kernel/asm/mpt/read.asm @@ -1,6 +1,6 @@ // Given an address, return a pointer to the associated account data, which // consists of four words (nonce, balance, storage_root, code_hash), in the -// state trie. Returns 0 if the address is not found. +// state trie. 
Returns null if the address is not found. global mpt_read_state_trie: // stack: addr, retdest // The key is the hash of the address. Since KECCAK_GENERAL takes input from @@ -24,14 +24,14 @@ mpt_read_state_trie_after_mstore: // - the key, as a U256 // - the number of nibbles in the key (should start at 64) // -// This function returns a pointer to the leaf, or 0 if the key is not found. +// This function returns a pointer to the value, or 0 if the key is not found. global mpt_read: // stack: node_ptr, num_nibbles, key, retdest DUP1 %mload_trie_data // stack: node_type, node_ptr, num_nibbles, key, retdest // Increment node_ptr, so it points to the node payload instead of its type. - SWAP1 %add_const(1) SWAP1 + SWAP1 %increment SWAP1 // stack: node_type, node_payload_ptr, num_nibbles, key, retdest DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(mpt_read_empty) @@ -39,7 +39,7 @@ global mpt_read: DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(mpt_read_extension) DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(mpt_read_leaf) - // There's still the MPT_NODE_HASH case, but if we hit a digest node, + // There's still the MPT_NODE_HASH case, but if we hit a hash node, // it means the prover failed to provide necessary Merkle data, so panic. PANIC @@ -75,15 +75,8 @@ mpt_read_branch_end_of_key: %stack (node_payload_ptr, num_nibbles, key, retdest) -> (node_payload_ptr, retdest) // stack: node_payload_ptr, retdest %add_const(16) // skip over the 16 child nodes - // stack: value_len_ptr, retdest - DUP1 %mload_trie_data - // stack: value_len, value_len_ptr, retdest - %jumpi(mpt_read_branch_found_value) - // This branch node contains no value, so return null. 
- %stack (value_len_ptr, retdest) -> (retdest, 0) -mpt_read_branch_found_value: - // stack: value_len_ptr, retdest - %increment + // stack: value_ptr_ptr, retdest + %mload_trie_data // stack: value_ptr, retdest SWAP1 JUMP @@ -103,7 +96,7 @@ mpt_read_extension: %mul_const(4) SHR // key_part = key >> (future_nibbles * 4) DUP1 // stack: key_part, key_part, future_nibbles, key, node_payload_ptr, retdest - DUP5 %add_const(1) %mload_trie_data + DUP5 %increment %mload_trie_data // stack: node_key, key_part, key_part, future_nibbles, key, node_payload_ptr, retdest EQ // does the first part of our key match the node's key? %jumpi(mpt_read_extension_found) @@ -131,7 +124,7 @@ mpt_read_leaf: // stack: node_payload_ptr, num_nibbles, key, retdest DUP1 %mload_trie_data // stack: node_num_nibbles, node_payload_ptr, num_nibbles, key, retdest - DUP2 %add_const(1) %mload_trie_data + DUP2 %increment %mload_trie_data // stack: node_key, node_num_nibbles, node_payload_ptr, num_nibbles, key, retdest SWAP3 // stack: num_nibbles, node_num_nibbles, node_payload_ptr, node_key, key, retdest @@ -147,7 +140,9 @@ mpt_read_leaf: JUMP mpt_read_leaf_found: // stack: node_payload_ptr, retdest - %add_const(3) // The value is located after num_nibbles, the key, and the value length. + %add_const(2) // The value pointer is located after num_nibbles and the key. 
+ // stack: value_ptr_ptr, retdest + %mload_trie_data // stack: value_ptr, retdest SWAP1 JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/util.asm b/evm/src/cpu/kernel/asm/mpt/util.asm index 0e0006d3..0faa72f4 100644 --- a/evm/src/cpu/kernel/asm/mpt/util.asm +++ b/evm/src/cpu/kernel/asm/mpt/util.asm @@ -28,7 +28,7 @@ %get_trie_data_size // stack: trie_data_size, value DUP1 - %add_const(1) + %increment // stack: trie_data_size', trie_data_size, value %set_trie_data_size // stack: trie_data_size, value @@ -45,7 +45,7 @@ // return (first_nibble, num_nibbles, key) %macro split_first_nibble // stack: num_nibbles, key - %sub_const(1) // num_nibbles -= 1 + %decrement // num_nibbles -= 1 // stack: num_nibbles, key DUP2 // stack: key, num_nibbles, key @@ -72,3 +72,96 @@ POP // stack: first_nibble, num_nibbles, key %endmacro + +// Split off the common prefix among two key parts. +// +// Pre stack: len_1, key_1, len_2, key_2 +// Post stack: len_common, key_common, len_1, key_1, len_2, key_2 +// +// Roughly equivalent to +// def split_common_prefix(len_1, key_1, len_2, key_2): +// bits_1 = len_1 * 4 +// bits_2 = len_2 * 4 +// len_common = 0 +// key_common = 0 +// while True: +// if bits_1 * bits_2 == 0: +// break +// first_nib_1 = (key_1 >> (bits_1 - 4)) & 0xF +// first_nib_2 = (key_2 >> (bits_2 - 4)) & 0xF +// if first_nib_1 != first_nib_2: +// break +// len_common += 1 +// key_common = key_common * 16 + first_nib_1 +// bits_1 -= 4 +// bits_2 -= 4 +// key_1 -= (first_nib_1 << bits_1) +// key_2 -= (first_nib_2 << bits_2) +// len_1 = bits_1 // 4 +// len_2 = bits_2 // 4 +// return (len_common, key_common, len_1, key_1, len_2, key_2) +%macro split_common_prefix + // stack: len_1, key_1, len_2, key_2 + %mul_const(4) + SWAP2 %mul_const(4) SWAP2 + // stack: bits_1, key_1, bits_2, key_2 + PUSH 0 + PUSH 0 + +%%loop: + // stack: len_common, key_common, bits_1, key_1, bits_2, key_2 + + // if bits_1 * bits_2 == 0: break + DUP3 DUP6 MUL ISZERO %jumpi(%%return) + + // first_nib_2 = (key_2 >> 
(bits_2 - 4)) & 0xF + DUP6 DUP6 %sub_const(4) SHR %and_const(0xF) + // first_nib_1 = (key_1 >> (bits_1 - 4)) & 0xF + DUP5 DUP5 %sub_const(4) SHR %and_const(0xF) + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // if first_nib_1 != first_nib_2: break + DUP2 DUP2 SUB %jumpi(%%return_with_first_nibs) + + // len_common += 1 + SWAP2 %increment SWAP2 + + // key_common = key_common * 16 + first_nib_1 + SWAP3 + %mul_const(16) + DUP4 ADD + SWAP3 + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // bits_1 -= 4 + SWAP4 %sub_const(4) SWAP4 + // bits_2 -= 4 + SWAP6 %sub_const(4) SWAP6 + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // key_1 -= (first_nib_1 << bits_1) + DUP5 SHL + // stack: first_nib_1 << bits_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + DUP6 SUB + // stack: key_1, first_nib_2, len_common, key_common, bits_1, key_1_old, bits_2, key_2 + SWAP5 POP + // stack: first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // key_2 -= (first_nib_2 << bits_2) + DUP6 SHL + // stack: first_nib_2 << bits_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + DUP7 SUB + // stack: key_2, len_common, key_common, bits_1, key_1, bits_2, key_2_old + SWAP6 POP + // stack: len_common, key_common, bits_1, key_1, bits_2, key_2 + + %jump(%%loop) +%%return_with_first_nibs: + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + %pop2 +%%return: + // stack: len_common, key_common, bits_1, key_1, bits_2, key_2 + SWAP2 %div_const(4) SWAP2 // bits_1 -> len_1 (in nibbles) + SWAP4 %div_const(4) SWAP4 // bits_2 -> len_2 (in nibbles) + // stack: len_common, key_common, len_1, key_1, len_2, key_2 +%endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/write.asm b/evm/src/cpu/kernel/asm/mpt/write.asm deleted file mode 100644 index 5b59d016..00000000 --- 
a/evm/src/cpu/kernel/asm/mpt/write.asm +++ /dev/null @@ -1,3 +0,0 @@ -global mpt_write: - // stack: node_ptr, num_nibbles, key, retdest - // TODO diff --git a/evm/src/cpu/kernel/asm/ripemd/memory.asm b/evm/src/cpu/kernel/asm/ripemd/memory.asm index 5d0266bd..e3b7cbe6 100644 --- a/evm/src/cpu/kernel/asm/ripemd/memory.asm +++ b/evm/src/cpu/kernel/asm/ripemd/memory.asm @@ -44,7 +44,7 @@ store_input_stack: // stack: offset, byte, rem, length, REM_INP %mstore_kernel_general // stack: rem, length, REM_INP - %sub_const(1) + %decrement DUP1 // stack: rem - 1, rem - 1, length, REM_INP %jumpi(store_input_stack) @@ -66,10 +66,10 @@ store_input: // stack: offset, byte, rem , ADDR , length %mstore_kernel_general // stack: rem , ADDR , length - %sub_const(1) + %decrement // stack: rem-1, ADDR , length SWAP3 - %add_const(1) + %increment SWAP3 // stack: rem-1, ADDR+1, length DUP1 @@ -90,12 +90,12 @@ global buffer_update: // stack: get, set, get , set , times , retdest %mupdate_kernel_general // stack: get , set , times , retdest - %add_const(1) + %increment SWAP1 - %add_const(1) + %increment SWAP1 SWAP2 - %sub_const(1) + %decrement SWAP2 // stack: get+1, set+1, times-1, retdest DUP3 @@ -112,7 +112,7 @@ global buffer_update: // stack: offset = N-i, 0, i %mstore_kernel_general // stack: i - %sub_const(1) + %decrement DUP1 // stack: i-1, i-1 %jumpi($label) diff --git a/evm/src/cpu/kernel/asm/rlp/decode.asm b/evm/src/cpu/kernel/asm/rlp/decode.asm index 5749aee7..9842bfbd 100644 --- a/evm/src/cpu/kernel/asm/rlp/decode.asm +++ b/evm/src/cpu/kernel/asm/rlp/decode.asm @@ -14,7 +14,7 @@ global decode_rlp_string_len: // stack: pos, retdest DUP1 - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) // stack: first_byte, pos, retdest DUP1 %gt_const(0xb7) @@ -36,7 +36,7 @@ decode_rlp_string_len_medium: %sub_const(0x80) // stack: len, pos, retdest SWAP1 - %add_const(1) + %increment // stack: pos', len, retdest %stack (pos, len, retdest) -> (retdest, pos, len) JUMP @@ -47,7 +47,7 
@@ decode_rlp_string_len_large: %sub_const(0xb7) // stack: len_of_len, pos, retdest SWAP1 - %add_const(1) + %increment // stack: pos', len_of_len, retdest %jump(decode_int_given_len) @@ -89,10 +89,10 @@ global decode_rlp_scalar: global decode_rlp_list_len: // stack: pos, retdest DUP1 - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) // stack: first_byte, pos, retdest SWAP1 - %add_const(1) // increment pos + %increment // increment pos SWAP1 // stack: first_byte, pos', retdest // If first_byte is >= 0xf8, it's a > 55 byte list, and @@ -151,13 +151,13 @@ decode_int_given_len_loop: // stack: acc << 8, pos, end_pos, retdest DUP2 // stack: pos, acc << 8, pos, end_pos, retdest - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) // stack: byte, acc << 8, pos, end_pos, retdest ADD // stack: acc', pos, end_pos, retdest // Increment pos. SWAP1 - %add_const(1) + %increment SWAP1 // stack: acc', pos', end_pos, retdest %jump(decode_int_given_len_loop) diff --git a/evm/src/cpu/kernel/asm/rlp/encode.asm b/evm/src/cpu/kernel/asm/rlp/encode.asm index 851ad3cf..dada98b0 100644 --- a/evm/src/cpu/kernel/asm/rlp/encode.asm +++ b/evm/src/cpu/kernel/asm/rlp/encode.asm @@ -14,7 +14,7 @@ global encode_rlp_scalar: // stack: pos, scalar, pos, retdest %mstore_rlp // stack: pos, retdest - %add_const(1) + %increment // stack: pos', retdest SWAP1 JUMP @@ -76,7 +76,7 @@ encode_rlp_fixed: %mstore_rlp // stack: len, pos, string, retdest SWAP1 - %add_const(1) // increment pos + %increment // increment pos // stack: pos, len, string, retdest %stack (pos, len, string) -> (pos, string, len, encode_rlp_fixed_finish) // stack: context, segment, pos, string, len, encode_rlp_fixed_finish, retdest @@ -159,7 +159,7 @@ global encode_rlp_list_prefix: // stack: pos, prefix, pos, retdest %mstore_rlp // stack: pos, retdest - %add_const(1) + %increment SWAP1 JUMP encode_rlp_list_prefix_large: @@ -172,7 +172,7 @@ encode_rlp_list_prefix_large: DUP3 // pos %mstore_rlp // stack: 
len_of_len, pos, payload_len, retdest - SWAP1 %add_const(1) + SWAP1 %increment // stack: pos', len_of_len, payload_len, retdest %stack (pos, len_of_len, payload_len) -> (pos, payload_len, len_of_len, @@ -231,7 +231,7 @@ prepend_rlp_list_prefix_big: SUB // stack: start_pos, len_of_len, payload_len, end_pos, retdest DUP2 %add_const(0xf7) DUP2 %mstore_rlp // rlp[start_pos] = 0xf7 + len_of_len - DUP1 %add_const(1) // start_len_pos = start_pos + 1 + DUP1 %increment // start_len_pos = start_pos + 1 %stack (start_len_pos, start_pos, len_of_len, payload_len, end_pos, retdest) -> (start_len_pos, payload_len, len_of_len, prepend_rlp_list_prefix_big_done_writing_len, @@ -269,7 +269,7 @@ prepend_rlp_list_prefix_big_done_writing_len: // stack: scalar %num_bytes // stack: scalar_bytes - %add_const(1) // Account for the length prefix. + %increment // Account for the length prefix. // stack: rlp_len %%finish: %endmacro diff --git a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm index 189edd1d..2d71e65a 100644 --- a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm +++ b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm @@ -23,9 +23,9 @@ read_rlp_to_memory_loop: // stack: byte, pos, len, retdest DUP2 // stack: pos, byte, pos, len, retdest - %mstore_current(@SEGMENT_RLP_RAW) + %mstore_kernel(@SEGMENT_RLP_RAW) // stack: pos, len, retdest - %add_const(1) + %increment // stack: pos', len, retdest %jump(read_rlp_to_memory_loop) diff --git a/evm/src/cpu/kernel/asm/transactions/router.asm b/evm/src/cpu/kernel/asm/transactions/router.asm index 974fed99..3f4ebe37 100644 --- a/evm/src/cpu/kernel/asm/transactions/router.asm +++ b/evm/src/cpu/kernel/asm/transactions/router.asm @@ -18,14 +18,14 @@ read_txn_from_memory: // first byte >= 0xc0, so there is no overlap. 
PUSH 0 - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) %eq_const(1) // stack: first_byte == 1, retdest %jumpi(process_type_1_txn) // stack: retdest PUSH 0 - %mload_current(@SEGMENT_RLP_RAW) + %mload_kernel(@SEGMENT_RLP_RAW) %eq_const(2) // stack: first_byte == 2, retdest %jumpi(process_type_2_txn) diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index bba2a2c1..02a2c807 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -44,6 +44,12 @@ %endrep %endmacro +%macro pop8 + %rep 8 + POP + %endrep +%endmacro + %macro and_const(c) // stack: input, ... PUSH $c diff --git a/evm/src/cpu/kernel/constants/context_metadata.rs b/evm/src/cpu/kernel/constants/context_metadata.rs index 17945d98..a2c460fc 100644 --- a/evm/src/cpu/kernel/constants/context_metadata.rs +++ b/evm/src/cpu/kernel/constants/context_metadata.rs @@ -21,7 +21,7 @@ pub(crate) enum ContextMetadata { /// prohibited. Static = 8, /// Pointer to the initial version of the state trie, at the creation of this context. Used when - /// we need to revert a context. See also `StorageTrieCheckpointPointers`. + /// we need to revert a context. StateTrieCheckpointPointer = 9, } diff --git a/evm/src/cpu/kernel/constants/global_metadata.rs b/evm/src/cpu/kernel/constants/global_metadata.rs index f3f34e7a..1fa62efe 100644 --- a/evm/src/cpu/kernel/constants/global_metadata.rs +++ b/evm/src/cpu/kernel/constants/global_metadata.rs @@ -18,9 +18,6 @@ pub(crate) enum GlobalMetadata { TransactionTrieRoot = 5, /// A pointer to the root of the receipt trie within the `TrieData` buffer. ReceiptTrieRoot = 6, - /// The number of storage tries involved in these transactions. I.e. the number of values in - /// `StorageTrieAddresses`, `StorageTriePointers` and `StorageTrieCheckpointPointers`. - NumStorageTries = 7, // The root digests of each Merkle trie before these transactions. 
StateTrieRootDigestBefore = 8, @@ -31,6 +28,10 @@ pub(crate) enum GlobalMetadata { StateTrieRootDigestAfter = 11, TransactionTrieRootDigestAfter = 12, ReceiptTrieRootDigestAfter = 13, + + /// The sizes of the `TrieEncodedChild` and `TrieEncodedChildLen` buffers. In other words, the + /// next available offset in these buffers. + TrieEncodedChildSize = 14, } impl GlobalMetadata { @@ -45,13 +46,13 @@ impl GlobalMetadata { Self::StateTrieRoot, Self::TransactionTrieRoot, Self::ReceiptTrieRoot, - Self::NumStorageTries, Self::StateTrieRootDigestBefore, Self::TransactionTrieRootDigestBefore, Self::ReceiptTrieRootDigestBefore, Self::StateTrieRootDigestAfter, Self::TransactionTrieRootDigestAfter, Self::ReceiptTrieRootDigestAfter, + Self::TrieEncodedChildSize, ] } @@ -65,7 +66,6 @@ impl GlobalMetadata { GlobalMetadata::StateTrieRoot => "GLOBAL_METADATA_STATE_TRIE_ROOT", GlobalMetadata::TransactionTrieRoot => "GLOBAL_METADATA_TXN_TRIE_ROOT", GlobalMetadata::ReceiptTrieRoot => "GLOBAL_METADATA_RECEIPT_TRIE_ROOT", - GlobalMetadata::NumStorageTries => "GLOBAL_METADATA_NUM_STORAGE_TRIES", GlobalMetadata::StateTrieRootDigestBefore => "GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE", GlobalMetadata::TransactionTrieRootDigestBefore => { "GLOBAL_METADATA_TXN_TRIE_DIGEST_BEFORE" @@ -80,6 +80,7 @@ GlobalMetadata::ReceiptTrieRootDigestAfter => { "GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_AFTER" } + GlobalMetadata::TrieEncodedChildSize => "GLOBAL_METADATA_TRIE_ENCODED_CHILD_SIZE", } } } diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 589ba6b3..bca6d095 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -1,3 +1,5 @@ +//! An EVM interpreter for testing and debugging purposes. 
+ use std::collections::HashMap; use anyhow::{anyhow, bail, ensure}; @@ -75,6 +77,7 @@ pub struct Interpreter<'a> { pub(crate) generation_state: GenerationState, prover_inputs_map: &'a HashMap, pub(crate) halt_offsets: Vec, + pub(crate) debug_offsets: Vec, running: bool, } @@ -128,6 +131,7 @@ impl<'a> Interpreter<'a> { prover_inputs_map: prover_inputs, context: 0, halt_offsets: vec![DEFAULT_HALT_OFFSET], + debug_offsets: vec![], running: false, } } @@ -168,10 +172,19 @@ impl<'a> Interpreter<'a> { self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize].get(field as usize) } + pub(crate) fn set_global_metadata_field(&mut self, field: GlobalMetadata, value: U256) { + self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize] + .set(field as usize, value) + } + pub(crate) fn get_trie_data(&self) -> &[U256] { &self.memory.context_memory[0].segments[Segment::TrieData as usize].content } + pub(crate) fn get_trie_data_mut(&mut self) -> &mut Vec { + &mut self.memory.context_memory[0].segments[Segment::TrieData as usize].content + } + pub(crate) fn get_rlp_memory(&self) -> Vec { self.memory.context_memory[0].segments[Segment::RlpRaw as usize] .content @@ -205,7 +218,7 @@ impl<'a> Interpreter<'a> { self.push(if x { U256::one() } else { U256::zero() }); } - fn pop(&mut self) -> U256 { + pub(crate) fn pop(&mut self) -> U256 { self.stack_mut().pop().expect("Pop on empty stack.") } @@ -225,6 +238,9 @@ impl<'a> Interpreter<'a> { 0x09 => self.run_mulmod(), // "MULMOD", 0x0a => self.run_exp(), // "EXP", 0x0b => todo!(), // "SIGNEXTEND", + 0x0c => todo!(), // "ADDFP254", + 0x0d => todo!(), // "MULFP254", + 0x0e => todo!(), // "SUBFP254", 0x10 => self.run_lt(), // "LT", 0x11 => self.run_gt(), // "GT", 0x12 => todo!(), // "SLT", @@ -274,7 +290,7 @@ impl<'a> Interpreter<'a> { 0x55 => todo!(), // "SSTORE", 0x56 => self.run_jump(), // "JUMP", 0x57 => self.run_jumpi(), // "JUMPI", - 0x58 => todo!(), // "GETPC", + 0x58 => self.run_pc(), // "PC", 0x59 => 
self.run_msize(), // "MSIZE", 0x5a => todo!(), // "GAS", 0x5b => self.run_jumpdest(), // "JUMPDEST", @@ -309,9 +325,24 @@ impl<'a> Interpreter<'a> { 0xff => todo!(), // "SELFDESTRUCT", _ => bail!("Unrecognized opcode {}.", opcode), }; + + if self.debug_offsets.contains(&self.offset) { + println!("At {}, stack={:?}", self.offset_name(), self.stack()); + } + Ok(()) } + /// Get a string representation of the current offset for debugging purposes. + fn offset_name(&self) -> String { + // TODO: Not sure we should use KERNEL? Interpreter is more general in other places. + let label = KERNEL + .global_labels + .iter() + .find_map(|(k, v)| (*v == self.offset).then(|| k.clone())); + label.unwrap_or_else(|| self.offset.to_string()) + } + fn run_stop(&mut self) { self.running = false; } @@ -467,6 +498,7 @@ impl<'a> Interpreter<'a> { let bytes = (offset..offset + size) .map(|i| self.memory.mload_general(context, segment, i).byte(0)) .collect::>(); + println!("Hashing {:?}", &bytes); let hash = keccak(bytes); self.push(U256::from_big_endian(hash.as_bytes())); } @@ -535,6 +567,10 @@ impl<'a> Interpreter<'a> { } } + fn run_pc(&mut self) { + self.push((self.offset - 1).into()); + } + fn run_msize(&mut self) { let num_bytes = self.memory.context_memory[self.context].segments [Segment::MainMemory as usize] @@ -600,7 +636,13 @@ impl<'a> Interpreter<'a> { let segment = Segment::all()[self.pop().as_usize()]; let offset = self.pop().as_usize(); let value = self.pop(); - assert!(value.bits() <= segment.bit_range()); + assert!( + value.bits() <= segment.bit_range(), + "Value {} exceeds {:?} range of {} bits", + value, + segment, + segment.bit_range() + ); self.memory.mstore_general(context, segment, offset, value); } } diff --git a/evm/src/cpu/kernel/opcodes.rs b/evm/src/cpu/kernel/opcodes.rs index 2325c53a..20601267 100644 --- a/evm/src/cpu/kernel/opcodes.rs +++ b/evm/src/cpu/kernel/opcodes.rs @@ -2,7 +2,7 @@ pub(crate) fn get_push_opcode(n: u8) -> u8 { assert!(n > 0); assert!(n <= 32); 
- 0x60 + (n as u8 - 1) + 0x60 + n - 1 } /// The opcode of a standard instruction (not a `PUSH`). @@ -20,6 +20,9 @@ pub(crate) fn get_opcode(mnemonic: &str) -> u8 { "MULMOD" => 0x09, "EXP" => 0x0a, "SIGNEXTEND" => 0x0b, + "ADDFP254" => 0x0c, + "MULFP254" => 0x0d, + "SUBFP254" => 0x0e, "LT" => 0x10, "GT" => 0x11, "SLT" => 0x12, diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs index 6b31a523..19c38e91 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -1,11 +1,12 @@ use anyhow::Result; -use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; -use ethereum_types::{BigEndianHash, H256, U256}; +use eth_trie_utils::partial_trie::PartialTrie; +use ethereum_types::{BigEndianHash, H256}; +use super::nibbles; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::extension_to_leaf; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; +use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1_rlp, test_account_2_rlp}; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::TrieInputs; // TODO: Test with short leaf. Might need to be a storage trie. 
@@ -23,74 +24,70 @@ fn mpt_hash_empty() -> Result<()> { } #[test] -fn mpt_hash_leaf() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), +fn mpt_hash_empty_branch() -> Result<()> { + let children = std::array::from_fn(|_| PartialTrie::Empty.into()); + let state_trie = PartialTrie::Branch { + children, + value: vec![], }; - let account_rlp = rlp::encode(&account); - - let state_trie = PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, - value: account_rlp.to_vec(), - }; - let trie_inputs = TrieInputs { state_trie, transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], }; + test_state_trie(trie_inputs) +} +#[test] +fn mpt_hash_hash() -> Result<()> { + let hash = H256::random(); + let trie_inputs = TrieInputs { + state_trie: PartialTrie::Hash(hash), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + test_state_trie(trie_inputs) +} + +#[test] +fn mpt_hash_leaf() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: nibbles(0xABC), + value: test_account_1_rlp(), + }; + let trie_inputs = TrieInputs { + state_trie, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; test_state_trie(trie_inputs) } #[test] fn mpt_hash_extension_to_leaf() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - }; - let account_rlp = rlp::encode(&account); - - let state_trie = extension_to_leaf(account_rlp.to_vec()); - + let state_trie = extension_to_leaf(test_account_1_rlp()); let trie_inputs = TrieInputs { state_trie, transactions_trie: Default::default(), receipts_trie: Default::default(), 
storage_tries: vec![], }; - test_state_trie(trie_inputs) } #[test] fn mpt_hash_branch_to_leaf() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - }; - let account_rlp = rlp::encode(&account); - let leaf = PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, - value: account_rlp.to_vec(), - }; - let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); - children[0] = Box::new(leaf); + nibbles: nibbles(0xABC), + value: test_account_2_rlp(), + } + .into(); + let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); + children[3] = leaf; let state_trie = PartialTrie::Branch { children, value: vec![], diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs new file mode 100644 index 00000000..3a52948d --- /dev/null +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -0,0 +1,208 @@ +use anyhow::Result; +use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use ethereum_types::{BigEndianHash, H256}; + +use super::nibbles; +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::cpu::kernel::tests::mpt::{test_account_1_rlp, test_account_2}; +use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; +use crate::generation::TrieInputs; + +#[test] +fn mpt_insert_empty() -> Result<()> { + test_state_trie(Default::default(), nibbles(0xABC), test_account_2()) +} + +#[test] +fn mpt_insert_leaf_identical_keys() -> Result<()> { + let key = nibbles(0xABC); + let state_trie = PartialTrie::Leaf { + nibbles: key, + value: test_account_1_rlp(), + }; + test_state_trie(state_trie, key, test_account_2()) +} + +#[test] +fn mpt_insert_leaf_nonoverlapping_keys() -> Result<()> { + let state_trie = 
PartialTrie::Leaf { + nibbles: nibbles(0xABC), + value: test_account_1_rlp(), + }; + test_state_trie(state_trie, nibbles(0x123), test_account_2()) +} + +#[test] +fn mpt_insert_leaf_overlapping_keys() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: nibbles(0xABC), + value: test_account_1_rlp(), + }; + test_state_trie(state_trie, nibbles(0xADE), test_account_2()) +} + +#[test] +fn mpt_insert_leaf_insert_key_extends_leaf_key() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: nibbles(0xABC), + value: test_account_1_rlp(), + }; + test_state_trie(state_trie, nibbles(0xABCDE), test_account_2()) +} + +#[test] +fn mpt_insert_leaf_leaf_key_extends_insert_key() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: nibbles(0xABCDE), + value: test_account_1_rlp(), + }; + test_state_trie(state_trie, nibbles(0xABC), test_account_2()) +} + +#[test] +fn mpt_insert_branch_replacing_empty_child() -> Result<()> { + let children = std::array::from_fn(|_| PartialTrie::Empty.into()); + let state_trie = PartialTrie::Branch { + children, + value: vec![], + }; + + test_state_trie(state_trie, nibbles(0xABC), test_account_2()) +} + +#[test] +// TODO: Not a valid test because branches state trie cannot have branch values. +// We should change it to use a different trie. +#[ignore] +fn mpt_insert_extension_nonoverlapping_keys() -> Result<()> { + // Existing keys are 0xABC, 0xABCDEF; inserted key is 0x12345. + let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); + children[0xD] = PartialTrie::Leaf { + nibbles: nibbles(0xEF), + value: test_account_1_rlp(), + } + .into(); + let state_trie = PartialTrie::Extension { + nibbles: nibbles(0xABC), + child: PartialTrie::Branch { + children, + value: test_account_1_rlp(), + } + .into(), + }; + test_state_trie(state_trie, nibbles(0x12345), test_account_2()) +} + +#[test] +// TODO: Not a valid test because branches state trie cannot have branch values. 
+// We should change it to use a different trie. +#[ignore] +fn mpt_insert_extension_insert_key_extends_node_key() -> Result<()> { + // Existing keys are 0xA, 0xABCD; inserted key is 0xABCDEF. + let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); + children[0xB] = PartialTrie::Leaf { + nibbles: nibbles(0xCD), + value: test_account_1_rlp(), + } + .into(); + let state_trie = PartialTrie::Extension { + nibbles: nibbles(0xA), + child: PartialTrie::Branch { + children, + value: test_account_1_rlp(), + } + .into(), + }; + test_state_trie(state_trie, nibbles(0xABCDEF), test_account_2()) +} + +#[test] +fn mpt_insert_branch_to_leaf_same_key() -> Result<()> { + let leaf = PartialTrie::Leaf { + nibbles: nibbles(0xBCD), + value: test_account_1_rlp(), + } + .into(); + let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); + children[0xA] = leaf; + let state_trie = PartialTrie::Branch { + children, + value: vec![], + }; + + test_state_trie(state_trie, nibbles(0xABCD), test_account_2()) +} + +/// Note: The account's storage_root is ignored, as we can't insert a new storage_root without the +/// accompanying trie data. An empty trie's storage_root is used instead. 
+fn test_state_trie(state_trie: PartialTrie, k: Nibbles, mut account: AccountRlp) -> Result<()> { + account.storage_root = PartialTrie::Empty.calc_hash(); + + let trie_inputs = TrieInputs { + state_trie: state_trie.clone(), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; + let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + // Next, execute mpt_insert_state_trie. + interpreter.offset = mpt_insert_state_trie; + let trie_data = interpreter.get_trie_data_mut(); + if trie_data.is_empty() { + // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. + // Since we don't explicitly set it to 0, we need to do so here. + trie_data.push(0.into()); + } + let value_ptr = trie_data.len(); + trie_data.push(account.nonce); + trie_data.push(account.balance); + // In memory, storage_root gets interpreted as a pointer to a storage trie, + // so we have to ensure the pointer is valid. It's easiest to set it to 0, + // which works as an empty node, since trie_data[0] = 0 = MPT_TYPE_EMPTY. 
+ trie_data.push(H256::zero().into_uint()); + trie_data.push(account.code_hash.into_uint()); + let trie_data_len = trie_data.len().into(); + interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); + interpreter.push(0xDEADBEEFu32.into()); + interpreter.push(value_ptr.into()); // value_ptr + interpreter.push(k.packed); // key + interpreter.push(k.count.into()); // num_nibbles + + interpreter.run()?; + assert_eq!( + interpreter.stack().len(), + 0, + "Expected empty stack after insert, found {:?}", + interpreter.stack() + ); + + // Now, execute mpt_hash_state_trie. + interpreter.offset = mpt_hash_state_trie; + interpreter.push(0xDEADBEEFu32.into()); + interpreter.run()?; + + assert_eq!( + interpreter.stack().len(), + 1, + "Expected 1 item on stack after hashing, found {:?}", + interpreter.stack() + ); + let hash = H256::from_uint(&interpreter.stack()[0]); + + let updated_trie = state_trie.insert(k, rlp::encode(&account).to_vec()); + let expected_state_trie_hash = updated_trie.calc_hash(); + assert_eq!(hash, expected_state_trie_hash); + + Ok(()) +} diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index 3af39e30..78129a1c 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -1,26 +1,19 @@ use anyhow::Result; +use eth_trie_utils::partial_trie::PartialTrie; use ethereum_types::{BigEndianHash, H256, U256}; -use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::extension_to_leaf; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; +use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1, test_account_1_rlp}; +use crate::cpu::kernel::{aggregator::KERNEL, tests::mpt::nibbles}; +use 
crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::TrieInputs; #[test] -fn load_all_mpts() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - }; - let account_rlp = rlp::encode(&account); - +fn load_all_mpts_empty() -> Result<()> { let trie_inputs = TrieInputs { - state_trie: extension_to_leaf(account_rlp.to_vec()), + state_trie: Default::default(), transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], @@ -28,13 +21,194 @@ fn load_all_mpts() -> Result<()> { let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - let initial_stack = vec![0xdeadbeefu32.into()]; + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + assert_eq!(interpreter.get_trie_data(), vec![]); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), + 0.into() + ); + + Ok(()) +} + +#[test] +fn load_all_mpts_leaf() -> Result<()> { + let trie_inputs = TrieInputs { + state_trie: PartialTrie::Leaf { + nibbles: nibbles(0xABC), + value: test_account_1_rlp(), + }, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + 
interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + let type_leaf = U256::from(PartialTrieType::Leaf as u32); + assert_eq!( + interpreter.get_trie_data(), + vec![ + 0.into(), + type_leaf, + 3.into(), + 0xABC.into(), + 5.into(), // value ptr + test_account_1().nonce, + test_account_1().balance, + 9.into(), // pointer to storage trie root + test_account_1().code_hash.into_uint(), + // These last two elements encode the storage trie, which is a hash node. + (PartialTrieType::Hash as u32).into(), + test_account_1().storage_root.into_uint(), + ] + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), + 0.into() + ); + + Ok(()) +} + +#[test] +fn load_all_mpts_hash() -> Result<()> { + let hash = H256::random(); + let trie_inputs = TrieInputs { + state_trie: PartialTrie::Hash(hash), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + let type_hash = U256::from(PartialTrieType::Hash as u32); + assert_eq!( + interpreter.get_trie_data(), + vec![0.into(), type_hash, hash.into_uint(),] + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), + 0.into() + ); + + Ok(()) +} + +#[test] +fn load_all_mpts_empty_branch() -> Result<()> { + let children = 
std::array::from_fn(|_| PartialTrie::Empty.into()); + let state_trie = PartialTrie::Branch { + children, + value: vec![], + }; + let trie_inputs = TrieInputs { + state_trie, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + let type_branch = U256::from(PartialTrieType::Branch as u32); + assert_eq!( + interpreter.get_trie_data(), + vec![ + 0.into(), // First address is unused, so that 0 can be treated as a null pointer. + type_branch, + 0.into(), // child 0 + 0.into(), // ... + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), // child 16 + 0.into(), // value_ptr + ] + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), + 0.into() + ); + + Ok(()) +} + +#[test] +fn load_all_mpts_ext_to_leaf() -> Result<()> { + let trie_inputs = TrieInputs { + state_trie: extension_to_leaf(test_account_1_rlp()), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); - let type_empty = 
U256::from(PartialTrieType::Empty as u32); let type_extension = U256::from(PartialTrieType::Extension as u32); let type_leaf = U256::from(PartialTrieType::Leaf as u32); assert_eq!( @@ -48,20 +222,16 @@ fn load_all_mpts() -> Result<()> { type_leaf, 3.into(), // 3 nibbles 0xDEF.into(), // key part - 4.into(), // value length - account.nonce, - account.balance, - account.storage_root.into_uint(), - account.code_hash.into_uint(), - type_empty, // txn trie - type_empty, // receipt trie + 9.into(), // value pointer + test_account_1().nonce, + test_account_1().balance, + 13.into(), // pointer to storage trie root + test_account_1().code_hash.into_uint(), + // These last two elements encode the storage trie, which is a hash node. + (PartialTrieType::Hash as u32).into(), + test_account_1().storage_root.into_uint(), ] ); - assert_eq!( - interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), - trie_inputs.storage_tries.len().into() - ); - Ok(()) } diff --git a/evm/src/cpu/kernel/tests/mpt/mod.rs b/evm/src/cpu/kernel/tests/mpt/mod.rs index 55a56653..2c7999df 100644 --- a/evm/src/cpu/kernel/tests/mpt/mod.rs +++ b/evm/src/cpu/kernel/tests/mpt/mod.rs @@ -1,23 +1,62 @@ use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use ethereum_types::{BigEndianHash, H256, U256}; + +use crate::generation::mpt::AccountRlp; mod hash; mod hex_prefix; +mod insert; mod load; mod read; +/// Helper function to reduce code duplication. +/// Note that this preserves all nibbles (eg. `0x123` is not interpreted as `0x0123`). 
+pub(crate) fn nibbles>(v: T) -> Nibbles { + let packed = v.into(); + + Nibbles { + count: Nibbles::get_num_nibbles_in_key(&packed), + packed, + } +} + +pub(crate) fn test_account_1() -> AccountRlp { + AccountRlp { + nonce: U256::from(1111), + balance: U256::from(2222), + storage_root: H256::from_uint(&U256::from(3333)), + code_hash: H256::from_uint(&U256::from(4444)), + } +} + +pub(crate) fn test_account_1_rlp() -> Vec { + rlp::encode(&test_account_1()).to_vec() +} + +pub(crate) fn test_account_2() -> AccountRlp { + AccountRlp { + nonce: U256::from(5555), + balance: U256::from(6666), + storage_root: H256::from_uint(&U256::from(7777)), + code_hash: H256::from_uint(&U256::from(8888)), + } +} + +pub(crate) fn test_account_2_rlp() -> Vec { + rlp::encode(&test_account_2()).to_vec() +} + /// A `PartialTrie` where an extension node leads to a leaf node containing an account. pub(crate) fn extension_to_leaf(value: Vec) -> PartialTrie { PartialTrie::Extension { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, - child: Box::new(PartialTrie::Leaf { + nibbles: nibbles(0xABC), + child: PartialTrie::Leaf { nibbles: Nibbles { count: 3, packed: 0xDEF.into(), }, value, - }), + } + .into(), } } diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs index c45a6b60..d8808e24 100644 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -1,25 +1,17 @@ use anyhow::Result; -use ethereum_types::{BigEndianHash, H256, U256}; +use ethereum_types::BigEndianHash; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::extension_to_leaf; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; +use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1, test_account_1_rlp}; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; use 
crate::generation::TrieInputs; #[test] fn mpt_read() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - }; - let account_rlp = rlp::encode(&account); - let trie_inputs = TrieInputs { - state_trie: extension_to_leaf(account_rlp.to_vec()), + state_trie: extension_to_leaf(test_account_1_rlp()), transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], @@ -45,10 +37,11 @@ fn mpt_read() -> Result<()> { assert_eq!(interpreter.stack().len(), 1); let result_ptr = interpreter.stack()[0].as_usize(); let result = &interpreter.get_trie_data()[result_ptr..][..4]; - assert_eq!(result[0], account.nonce); - assert_eq!(result[1], account.balance); - assert_eq!(result[2], account.storage_root.into_uint()); - assert_eq!(result[3], account.code_hash.into_uint()); + assert_eq!(result[0], test_account_1().nonce); + assert_eq!(result[1], test_account_1().balance); + // result[2] is the storage root pointer. We won't check that it matches a + // particular address, since that seems like over-specifying. + assert_eq!(result[3], test_account_1().code_hash.into_uint()); Ok(()) } diff --git a/evm/src/cpu/kernel/tests/ripemd.rs b/evm/src/cpu/kernel/tests/ripemd.rs index 6123c336..305548ec 100644 --- a/evm/src/cpu/kernel/tests/ripemd.rs +++ b/evm/src/cpu/kernel/tests/ripemd.rs @@ -46,7 +46,7 @@ fn test_ripemd_reference() -> Result<()> { let kernel = combined_kernel(); let initial_offset = kernel.global_labels["ripemd_stack"]; - let initial_stack: Vec = input.iter().map(|&x| U256::from(x as u32)).rev().collect(); + let initial_stack: Vec = input.iter().map(|&x| U256::from(x)).rev().collect(); let final_stack: Vec = run_with_kernel(&kernel, initial_offset, initial_stack)? 
.stack() .to_vec(); diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index bde06585..fda5db80 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -7,6 +7,7 @@ mod dup_swap; mod jumps; pub mod kernel; pub(crate) mod membus; +mod modfp254; mod simple_logic; mod stack; mod stack_bounds; diff --git a/evm/src/cpu/modfp254.rs b/evm/src/cpu/modfp254.rs new file mode 100644 index 00000000..defbf862 --- /dev/null +++ b/evm/src/cpu/modfp254.rs @@ -0,0 +1,53 @@ +use itertools::izip; +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::columns::CpuColumnsView; + +// Python: +// >>> P = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +// >>> "[" + ", ".join(hex((P >> n) % 2**32) for n in range(0, 256, 32)) + "]" +const P_LIMBS: [u32; 8] = [ + 0xd87cfd47, 0x3c208c16, 0x6871ca8d, 0x97816a91, 0x8181585d, 0xb85045b6, 0xe131a029, 0x30644e72, +]; + +pub fn eval_packed( + lv: &CpuColumnsView

, + yield_constr: &mut ConstraintConsumer

, +) { + let filter = lv.is_cpu_cycle * (lv.op.addfp254 + lv.op.mulfp254 + lv.op.subfp254); + + // We want to use all the same logic as the usual mod operations, but without needing to read + // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's + // where the modulus goes in the generalized operations). + let channel_val = lv.mem_channels[2].value; + for (channel_limb, p_limb) in izip!(channel_val, P_LIMBS) { + let p_limb = P::Scalar::from_canonical_u32(p_limb); + yield_constr.constraint(filter * (channel_limb - p_limb)); + } +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let filter = { + let flag_sum = builder.add_many_extension([lv.op.addfp254, lv.op.mulfp254, lv.op.subfp254]); + builder.mul_extension(lv.is_cpu_cycle, flag_sum) + }; + + // We want to use all the same logic as the usual mod operations, but without needing to read + // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's + // where the modulus goes in the generalized operations). 
+ let channel_val = lv.mem_channels[2].value; + for (channel_limb, p_limb) in izip!(channel_val, P_LIMBS) { + let p_limb = F::from_canonical_u32(p_limb); + let constr = builder.arithmetic_extension(F::ONE, -p_limb, filter, channel_limb, filter); + yield_constr.constraint(builder, constr); + } +} diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 9bc08091..c72688ed 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -52,6 +52,9 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { mulmod: BASIC_TERNARY_OP, exp: None, // TODO signextend: BASIC_BINARY_OP, + addfp254: BASIC_BINARY_OP, + mulfp254: BASIC_BINARY_OP, + subfp254: BASIC_BINARY_OP, lt: BASIC_BINARY_OP, gt: BASIC_BINARY_OP, slt: BASIC_BINARY_OP, diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index f6bc630d..8ceb195a 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -1,5 +1,8 @@ -use eth_trie_utils::partial_trie::PartialTrie; +use std::collections::HashMap; + +use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; use ethereum_types::{BigEndianHash, H256, U256}; +use keccak_hash::keccak; use rlp_derive::{RlpDecodable, RlpEncodable}; use crate::cpu::kernel::constants::trie_type::PartialTrieType; @@ -23,15 +26,18 @@ pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec Vec { let mut prover_inputs = vec![]; - mpt_prover_inputs(&trie_inputs.state_trie, &mut prover_inputs, &|rlp| { - let account: AccountRlp = rlp::decode(rlp).expect("Decoding failed"); - vec![ - account.nonce, - account.balance, - account.storage_root.into_uint(), - account.code_hash.into_uint(), - ] - }); + let storage_tries_by_state_key = trie_inputs + .storage_tries + .iter() + .map(|(address, storage_trie)| (Nibbles::from(keccak(address)), storage_trie)) + .collect(); + + mpt_prover_inputs_state_trie( + &trie_inputs.state_trie, + empty_nibbles(), + &mut prover_inputs, + &storage_tries_by_state_key, + ); 
mpt_prover_inputs(&trie_inputs.transactions_trie, &mut prover_inputs, &|rlp| { rlp::decode_list(rlp) @@ -42,14 +48,6 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { vec![] }); - prover_inputs.push(trie_inputs.storage_tries.len().into()); - for (addr, storage_trie) in &trie_inputs.storage_tries { - prover_inputs.push(addr.0.as_ref().into()); - mpt_prover_inputs(storage_trie, &mut prover_inputs, &|leaf_be| { - vec![U256::from_big_endian(leaf_be)] - }); - } - prover_inputs } @@ -60,7 +58,7 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { pub(crate) fn mpt_prover_inputs( trie: &PartialTrie, prover_inputs: &mut Vec, - parse_leaf: &F, + parse_value: &F, ) where F: Fn(&[u8]) -> Vec, { @@ -70,28 +68,108 @@ pub(crate) fn mpt_prover_inputs( PartialTrie::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), PartialTrie::Branch { children, value } => { if value.is_empty() { - // There's no value, so length=0. + // There's no value, so value_len = 0. 
prover_inputs.push(U256::zero()); } else { - let leaf = parse_leaf(value); - prover_inputs.push(leaf.len().into()); - prover_inputs.extend(leaf); + let parsed_value = parse_value(value); + prover_inputs.push(parsed_value.len().into()); + prover_inputs.extend(parsed_value); } for child in children { - mpt_prover_inputs(child, prover_inputs, parse_leaf); + mpt_prover_inputs(child, prover_inputs, parse_value); } } PartialTrie::Extension { nibbles, child } => { prover_inputs.push(nibbles.count.into()); prover_inputs.push(nibbles.packed); - mpt_prover_inputs(child, prover_inputs, parse_leaf); + mpt_prover_inputs(child, prover_inputs, parse_value); } PartialTrie::Leaf { nibbles, value } => { prover_inputs.push(nibbles.count.into()); prover_inputs.push(nibbles.packed); - let leaf = parse_leaf(value); - prover_inputs.push(leaf.len().into()); + let leaf = parse_value(value); prover_inputs.extend(leaf); } } } + +/// Like `mpt_prover_inputs`, but for the state trie, which is a bit unique since each value +/// leads to a storage trie which we recursively traverse. +pub(crate) fn mpt_prover_inputs_state_trie( + trie: &PartialTrie, + key: Nibbles, + prover_inputs: &mut Vec, + storage_tries_by_state_key: &HashMap, +) { + prover_inputs.push((PartialTrieType::of(trie) as u32).into()); + match trie { + PartialTrie::Empty => {} + PartialTrie::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), + PartialTrie::Branch { children, value } => { + assert!(value.is_empty(), "State trie should not have branch values"); + // There's no value, so value_len = 0. 
+ prover_inputs.push(U256::zero()); + + for (i, child) in children.iter().enumerate() { + let extended_key = key.merge(&Nibbles { + count: 1, + packed: i.into(), + }); + mpt_prover_inputs_state_trie( + child, + extended_key, + prover_inputs, + storage_tries_by_state_key, + ); + } + } + PartialTrie::Extension { nibbles, child } => { + prover_inputs.push(nibbles.count.into()); + prover_inputs.push(nibbles.packed); + let extended_key = key.merge(nibbles); + mpt_prover_inputs_state_trie( + child, + extended_key, + prover_inputs, + storage_tries_by_state_key, + ); + } + PartialTrie::Leaf { nibbles, value } => { + let account: AccountRlp = rlp::decode(value).expect("Decoding failed"); + let AccountRlp { + nonce, + balance, + storage_root, + code_hash, + } = account; + + let storage_hash_only = PartialTrie::Hash(storage_root); + let storage_trie: &PartialTrie = storage_tries_by_state_key + .get(&key) + .copied() + .unwrap_or(&storage_hash_only); + + assert_eq!(storage_trie.calc_hash(), storage_root, + "In TrieInputs, an account's storage_root didn't match the associated storage trie hash"); + + prover_inputs.push(nibbles.count.into()); + prover_inputs.push(nibbles.packed); + prover_inputs.push(nonce); + prover_inputs.push(balance); + mpt_prover_inputs(storage_trie, prover_inputs, &parse_storage_value); + prover_inputs.push(code_hash.into_uint()); + } + } +} + +fn parse_storage_value(value_rlp: &[u8]) -> Vec { + let value: U256 = rlp::decode(value_rlp).expect("Decoding failed"); + vec![value] +} + +fn empty_nibbles() -> Nibbles { + Nibbles { + count: 0, + packed: U256::zero(), + } +} diff --git a/evm/src/memory/segments.rs b/evm/src/memory/segments.rs index 44390a9b..b8ba904f 100644 --- a/evm/src/memory/segments.rs +++ b/evm/src/memory/segments.rs @@ -29,19 +29,14 @@ pub(crate) enum Segment { /// Contains all trie data. Tries are stored as immutable, copy-on-write trees, so this is an /// append-only buffer. It is owned by the kernel, so it only lives on context 0. 
TrieData = 12, - /// The account address associated with the `i`th storage trie. Only lives on context 0. - StorageTrieAddresses = 13, - /// A pointer to the `i`th storage trie within the `TrieData` buffer. Only lives on context 0. - StorageTriePointers = 14, - /// Like `StorageTriePointers`, except that these pointers correspond to the version of each - /// trie at the creation of a given context. This lets us easily revert a context by replacing - /// `StorageTriePointers` with `StorageTrieCheckpointPointers`. - /// See also `StateTrieCheckpointPointer`. - StorageTrieCheckpointPointers = 15, + /// A buffer used to store the encodings of a branch node's children. + TrieEncodedChild = 13, + /// A buffer used to store the lengths of the encodings of a branch node's children. + TrieEncodedChildLen = 14, } impl Segment { - pub(crate) const COUNT: usize = 16; + pub(crate) const COUNT: usize = 15; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -58,9 +53,8 @@ impl Segment { Self::TxnData, Self::RlpRaw, Self::TrieData, - Self::StorageTrieAddresses, - Self::StorageTriePointers, - Self::StorageTrieCheckpointPointers, + Self::TrieEncodedChild, + Self::TrieEncodedChildLen, ] } @@ -80,9 +74,8 @@ impl Segment { Segment::TxnData => "SEGMENT_TXN_DATA", Segment::RlpRaw => "SEGMENT_RLP_RAW", Segment::TrieData => "SEGMENT_TRIE_DATA", - Segment::StorageTrieAddresses => "SEGMENT_STORAGE_TRIE_ADDRS", - Segment::StorageTriePointers => "SEGMENT_STORAGE_TRIE_PTRS", - Segment::StorageTrieCheckpointPointers => "SEGMENT_STORAGE_TRIE_CHECKPOINT_PTRS", + Segment::TrieEncodedChild => "SEGMENT_TRIE_ENCODED_CHILD", + Segment::TrieEncodedChildLen => "SEGMENT_TRIE_ENCODED_CHILD_LEN", } } @@ -102,9 +95,8 @@ impl Segment { Segment::TxnData => 256, Segment::RlpRaw => 8, Segment::TrieData => 256, - Segment::StorageTrieAddresses => 160, - Segment::StorageTriePointers => 32, - Segment::StorageTrieCheckpointPointers => 32, + Segment::TrieEncodedChild => 256, + Segment::TrieEncodedChildLen => 6, } } } 
diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 4f08b513..bc64bb57 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -146,7 +146,7 @@ impl, C: GenericConfig, const D: usize> } // Verify the CTL checks. - let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits); + let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups::( cross_table_lookups, pis.map(|p| p.ctl_zs_last), @@ -221,7 +221,7 @@ impl, C: GenericConfig, const D: usize> } // Verify the CTL checks. - let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits); + let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups_circuit::( builder, cross_table_lookups, diff --git a/field/src/types.rs b/field/src/types.rs index b112fde2..545f90c5 100644 --- a/field/src/types.rs +++ b/field/src/types.rs @@ -455,7 +455,7 @@ pub trait PrimeField: Field { let mut x = w * *self; let mut b = x * w; - let mut v = Self::TWO_ADICITY as usize; + let mut v = Self::TWO_ADICITY; while !b.is_one() { let mut k = 0usize; diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index 2c38faa7..f4379e7a 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -191,7 +191,7 @@ fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { info!( "Initial proof degree {} = 2^{}", cd.degree(), - cd.degree_bits + cd.degree_bits() ); // Recursively verify the proof @@ -200,7 +200,7 @@ fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { info!( "Single recursion proof degree {} = 2^{}", cd.degree(), - cd.degree_bits + cd.degree_bits() ); // Add a second layer of recursion to shrink the proof size further @@ -209,7 +209,7 @@ fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { info!( "Double recursion proof 
degree {} = 2^{}", cd.degree(), - cd.degree_bits + cd.degree_bits() ); test_serialization(proof, vd, cd)?; diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index 23caeac1..23c401b8 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -443,7 +443,7 @@ impl, const D: usize> CircuitBuilder { let mut current = base; let mut product = self.one_extension(); - for j in 0..bits_u64(exponent as u64) { + for j in 0..bits_u64(exponent) { if j != 0 { current = self.square_extension(current); } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 757dbdb1..83587f2e 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -818,7 +818,6 @@ impl, const D: usize> CircuitBuilder { let common = CommonCircuitData { config: self.config, fri_params, - degree_bits, gates, selectors_info, quotient_degree_factor, diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 0eb5c5e5..5143e730 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -296,8 +296,6 @@ pub struct CommonCircuitData< pub(crate) fri_params: FriParams, - pub degree_bits: usize, - /// The types of gates used in this circuit, along with their prefixes. 
pub(crate) gates: Vec>, @@ -325,16 +323,20 @@ pub struct CommonCircuitData< impl, C: GenericConfig, const D: usize> CommonCircuitData { + pub const fn degree_bits(&self) -> usize { + self.fri_params.degree_bits + } + pub fn degree(&self) -> usize { - 1 << self.degree_bits + 1 << self.degree_bits() } pub fn lde_size(&self) -> usize { - 1 << (self.degree_bits + self.config.fri_config.rate_bits) + self.fri_params.lde_size() } pub fn lde_generator(&self) -> F { - F::primitive_root_of_unity(self.degree_bits + self.config.fri_config.rate_bits) + F::primitive_root_of_unity(self.degree_bits() + self.config.fri_config.rate_bits) } pub fn constraint_degree(&self) -> usize { @@ -377,7 +379,7 @@ impl, C: GenericConfig, const D: usize> }; // The Z polynomials are also opened at g * zeta. - let g = F::Extension::primitive_root_of_unity(self.degree_bits); + let g = F::Extension::primitive_root_of_unity(self.degree_bits()); let zeta_next = g * zeta; let zeta_next_batch = FriBatchInfo { point: zeta_next, @@ -403,7 +405,7 @@ impl, C: GenericConfig, const D: usize> }; // The Z polynomials are also opened at g * zeta. 
- let g = F::primitive_root_of_unity(self.degree_bits); + let g = F::primitive_root_of_unity(self.degree_bits()); let zeta_next = builder.mul_const_extension(g, zeta); let zeta_next_batch = FriBatchInfoTarget { point: zeta_next, diff --git a/plonky2/src/plonk/get_challenges.rs b/plonky2/src/plonk/get_challenges.rs index d716c251..f497380f 100644 --- a/plonky2/src/plonk/get_challenges.rs +++ b/plonky2/src/plonk/get_challenges.rs @@ -62,7 +62,7 @@ fn get_challenges, C: GenericConfig, cons commit_phase_merkle_caps, final_poly, pow_witness, - common_data.degree_bits, + common_data.degree_bits(), &config.fri_config, ), }) @@ -181,7 +181,7 @@ impl, C: GenericConfig, const D: usize> &self.proof.openings.to_fri_openings(), *fri_alpha, ); - let log_n = common_data.degree_bits + common_data.config.fri_config.rate_bits; + let log_n = common_data.degree_bits() + common_data.config.fri_config.rate_bits; // Simulate the proof verification and collect the inferred elements. // The content of the loop is basically the same as the `fri_verifier_query_round` function. for &(mut x_index) in fri_query_indices { diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index 826e5ab4..8476a2d9 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -172,9 +172,9 @@ where // To avoid leaking witness data, we want to ensure that our opening locations, `zeta` and // `g * zeta`, are not in our subgroup `H`. It suffices to check `zeta` only, since // `(g * zeta)^n = zeta^n`, where `n` is the order of `g`. - let g = F::Extension::primitive_root_of_unity(common_data.degree_bits); + let g = F::Extension::primitive_root_of_unity(common_data.degree_bits()); ensure!( - zeta.exp_power_of_2(common_data.degree_bits) != F::Extension::ONE, + zeta.exp_power_of_2(common_data.degree_bits()) != F::Extension::ONE, "Opening point is in the subgroup." 
); @@ -342,10 +342,10 @@ fn compute_quotient_polys< // steps away since we work on an LDE of degree `max_filtered_constraint_degree`. let next_step = 1 << quotient_degree_bits; - let points = F::two_adic_subgroup(common_data.degree_bits + quotient_degree_bits); + let points = F::two_adic_subgroup(common_data.degree_bits() + quotient_degree_bits); let lde_size = points.len(); - let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits, quotient_degree_bits); + let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits(), quotient_degree_bits); let points_batches = points.par_chunks(BATCH_SIZE); let num_batches = ceil_div_usize(points.len(), BATCH_SIZE); diff --git a/plonky2/src/plonk/recursive_verifier.rs b/plonky2/src/plonk/recursive_verifier.rs index 7d901236..bb9076be 100644 --- a/plonky2/src/plonk/recursive_verifier.rs +++ b/plonky2/src/plonk/recursive_verifier.rs @@ -71,7 +71,7 @@ impl, const D: usize> CircuitBuilder { let partial_products = &proof.openings.partial_products; let zeta_pow_deg = - self.exp_power_of_2_extension(challenges.plonk_zeta, inner_common_data.degree_bits); + self.exp_power_of_2_extension(challenges.plonk_zeta, inner_common_data.degree_bits()); let vanishing_polys_zeta = with_context!( self, "evaluate the vanishing polynomial at our challenge point, zeta.", @@ -228,17 +228,17 @@ mod tests { // Start with a degree 2^14 proof let (proof, vd, cd) = dummy_proof::(&config, 16_000)?; - assert_eq!(cd.degree_bits, 14); + assert_eq!(cd.degree_bits(), 14); // Shrink it to 2^13. let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &config, Some(13), false, false)?; - assert_eq!(cd.degree_bits, 13); + assert_eq!(cd.degree_bits(), 13); // Shrink it to 2^12. let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &config, None, true, true)?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits(), 12); test_serialization(&proof, &vd, &cd)?; @@ -260,11 +260,11 @@ mod tests { // An initial dummy proof. 
let (proof, vd, cd) = dummy_proof::(&standard_config, 4_000)?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits(), 12); // A standard recursive proof. let (proof, vd, cd) = recursive_proof(proof, vd, cd, &standard_config, None, false, false)?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits(), 12); // A high-rate recursive proof, designed to be verifiable with fewer routed wires. let high_rate_config = CircuitConfig { @@ -278,7 +278,7 @@ mod tests { }; let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &high_rate_config, None, true, true)?; - assert_eq!(cd.degree_bits, 12); + assert_eq!(cd.degree_bits(), 12); // A final proof, optimized for size. let final_config = CircuitConfig { @@ -294,7 +294,7 @@ mod tests { }; let (proof, vd, cd) = recursive_proof::(proof, vd, cd, &final_config, None, true, true)?; - assert_eq!(cd.degree_bits, 12, "final proof too large"); + assert_eq!(cd.degree_bits(), 12, "final proof too large"); test_serialization(&proof, &vd, &cd)?; diff --git a/plonky2/src/plonk/verifier.rs b/plonky2/src/plonk/verifier.rs index 78db1b2f..37ddfffa 100644 --- a/plonky2/src/plonk/verifier.rs +++ b/plonky2/src/plonk/verifier.rs @@ -82,7 +82,7 @@ where let quotient_polys_zeta = &proof.openings.quotient_polys; let zeta_pow_deg = challenges .plonk_zeta - .exp_power_of_2(common_data.degree_bits); + .exp_power_of_2(common_data.degree_bits()); let z_h_zeta = zeta_pow_deg - F::Extension::ONE; // `quotient_polys_zeta` holds `num_challenges * quotient_degree_factor` evaluations. // Each chunk of `quotient_degree_factor` holds the evaluations of `t_0(zeta),...,t_{quotient_degree_factor-1}(zeta)`