diff --git a/evm/Cargo.toml b/evm/Cargo.toml index 09a47ee2..848dff15 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -16,6 +16,7 @@ hex-literal = "0.3.4" itertools = "0.10.3" keccak-hash = "0.9.0" log = "0.4.14" +num = "0.4.0" maybe_rayon = { path = "../maybe_rayon" } once_cell = "1.13.0" pest = "2.1.3" diff --git a/evm/spec/tries.tex b/evm/spec/tries.tex index fed78f40..7ec0fcce 100644 --- a/evm/spec/tries.tex +++ b/evm/spec/tries.tex @@ -6,11 +6,21 @@ Within our zkEVM's kernel memory, \begin{enumerate} \item An empty node is encoded as $(\texttt{MPT\_NODE\_EMPTY})$. - \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, c_1, \dots, c_{16}, v)$, where each $c_i$ is a pointer to a child node, and $v$ is a leaf payload. - \item An extension node is encoded as $(\texttt{MPT\_NODE\_EXTENSION}, k, c)$, $k$ is a 2-tuple $(\texttt{packed\_nibbles}, \texttt{num\_nibbles})$, and $c$ is a pointer to a child node. - \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, v)$, where $k$ is a 2-tuple as above, and $v$ is a leaf payload. + \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, c_1, \dots, c_{16}, \abs{v}, v)$, where each $c_i$ is a pointer to a child node, and $v$ is a value of length $\abs{v}$.\footnote{If a branch node has no associated value, then $\abs{v} = 0$ and $v = ()$.} + \item An extension node is encoded as $(\texttt{MPT\_NODE\_EXTENSION}, k, c)$, where $k$ represents the part of the key associated with this extension, and is encoded as a 2-tuple $(\texttt{packed\_nibbles}, \texttt{num\_nibbles})$. $c$ is a pointer to a child node. + \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, \abs{v}, v)$, where $k$ is a 2-tuple as above, and $v$ is a leaf payload of length $\abs{v}$. \item A digest node is encoded as $(\texttt{MPT\_NODE\_DIGEST}, d)$, where $d$ is a Keccak256 digest. \end{enumerate} \subsection{Prover input format} + +The initial state of each trie is given by the prover as a nondeterministic input tape. This tape has a similar format: +\begin{enumerate} + \item An empty node is encoded as $(\texttt{MPT\_NODE\_EMPTY})$. + \item A branch node is encoded as $(\texttt{MPT\_NODE\_BRANCH}, \abs{v}, v, c_1, \dots, c_{16})$, where $\abs{v}$ is the length of the value, and $v$ is the value itself. Each $c_i$ is the encoding of a child node. + \item An extension node is encoded as $(\texttt{MPT\_NODE\_EXTENSION}, k, c)$, where $k$ represents the part of the key associated with this extension, and is encoded as a 2-tuple $(\texttt{packed\_nibbles}, \texttt{num\_nibbles})$. $c$ is the encoding of a child node. + \item A leaf node is encoded as $(\texttt{MPT\_NODE\_LEAF}, k, \abs{v}, v)$, where $k$ is a 2-tuple as above, and $v$ is a leaf payload of length $\abs{v}$. + \item A digest node is encoded as $(\texttt{MPT\_NODE\_DIGEST}, d)$, where $d$ is a Keccak256 digest. +\end{enumerate} +Nodes are thus given in depth-first order, leading to natural recursive methods for encoding and decoding this format. 
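As an illustration of the depth-first prover input format above, here is a minimal Rust sketch of a recursive decoder for such a tape. It is a sketch only: the `Tape` and `Node` types, the numeric values of the `MPT_NODE_*` tags, and the use of `u64` in place of the kernel's 256-bit memory words are assumptions made for readability, not the kernel's actual representation.

type Word = u64; // stand-in for the kernel's 256-bit memory cells

const MPT_NODE_EMPTY: Word = 0; // tag values are assumed for this sketch
const MPT_NODE_BRANCH: Word = 1;
const MPT_NODE_EXTENSION: Word = 2;
const MPT_NODE_LEAF: Word = 3;
const MPT_NODE_DIGEST: Word = 4;

enum Node {
    Empty,
    Branch { value: Vec<Word>, children: Vec<Node> },
    Extension { packed_nibbles: Word, num_nibbles: Word, child: Box<Node> },
    Leaf { packed_nibbles: Word, num_nibbles: Word, value: Vec<Word> },
    Digest(Word), // Keccak256 digest; a single word in this simplified model
}

struct Tape { words: Vec<Word>, pos: usize }

impl Tape {
    fn next(&mut self) -> Word {
        let w = self.words[self.pos];
        self.pos += 1;
        w
    }

    /// Decode one node, consuming its children recursively (depth-first).
    fn read_node(&mut self) -> Node {
        match self.next() {
            MPT_NODE_EMPTY => Node::Empty,
            MPT_NODE_BRANCH => {
                // (MPT_NODE_BRANCH, |v|, v, c_1, ..., c_16)
                let len = self.next() as usize;
                let value = (0..len).map(|_| self.next()).collect();
                let children = (0..16).map(|_| self.read_node()).collect();
                Node::Branch { value, children }
            }
            MPT_NODE_EXTENSION => {
                // (MPT_NODE_EXTENSION, k, c) with k = (packed_nibbles, num_nibbles)
                let (packed_nibbles, num_nibbles) = (self.next(), self.next());
                let child = Box::new(self.read_node());
                Node::Extension { packed_nibbles, num_nibbles, child }
            }
            MPT_NODE_LEAF => {
                // (MPT_NODE_LEAF, k, |v|, v)
                let (packed_nibbles, num_nibbles) = (self.next(), self.next());
                let len = self.next() as usize;
                let value = (0..len).map(|_| self.next()).collect();
                Node::Leaf { packed_nibbles, num_nibbles, value }
            }
            MPT_NODE_DIGEST => Node::Digest(self.next()),
            tag => panic!("invalid node tag: {tag}"),
        }
    }
}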
diff --git a/evm/spec/zkevm.pdf b/evm/spec/zkevm.pdf index aff46eda..184ba36b 100644 Binary files a/evm/spec/zkevm.pdf and b/evm/spec/zkevm.pdf differ diff --git a/evm/spec/zkevm.tex b/evm/spec/zkevm.tex index f87f02f3..65766986 100644 --- a/evm/spec/zkevm.tex +++ b/evm/spec/zkevm.tex @@ -28,7 +28,8 @@ \let\subsectionautorefname\sectionautorefname \let\subsubsectionautorefname\sectionautorefname -% \floor{...} and \ceil{...} +% \abs{...}, \floor{...} and \ceil{...} +\DeclarePairedDelimiter\abs{\lvert}{\rvert} \DeclarePairedDelimiter\ceil{\lceil}{\rceil} \DeclarePairedDelimiter\floor{\lfloor}{\rfloor} diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 0c2516e5..26840c5f 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -65,7 +65,7 @@ impl, const D: usize> AllStark { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Table { Cpu = 0, Keccak = 1, @@ -185,12 +185,12 @@ mod tests { use plonky2::field::types::{Field, PrimeField64}; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData}; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::util::timing::TimingTree; use rand::{thread_rng, Rng}; - use crate::all_stark::AllStark; + use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::aggregator::KERNEL; @@ -203,8 +203,10 @@ mod tests { use crate::memory::NUM_CHANNELS; use crate::proof::{AllProof, PublicValues}; use crate::prover::prove_with_traces; + use crate::recursive_verifier::tests::recursively_verify_all_proof; use crate::recursive_verifier::{ - add_virtual_all_proof, set_all_proof_target, verify_proof_circuit, + add_virtual_recursive_all_proof, all_verifier_data_recursive_stark_proof, + set_recursive_all_proof_target, RecursiveAllProof, }; use crate::stark::Stark; use crate::util::{limb_from_bits_le, trace_rows_to_poly_values}; @@ -232,7 +234,7 @@ mod tests { ) -> Vec> { keccak_memory_stark.generate_trace( vec![], - 1 << config.fri_config.cap_height, + config.fri_config.num_cap_elements(), &mut TimingTree::default(), ) } @@ -359,6 +361,7 @@ mod tests { let row: &mut cpu::columns::CpuColumnsView = cpu_trace_rows[clock].borrow_mut(); row.clock = F::from_canonical_usize(clock); + dbg!(channel, row.mem_channels.len()); let channel = &mut row.mem_channels[channel]; channel.used = F::ONE; channel.is_read = memory_trace[memory::columns::IS_READ].values[i]; @@ -754,34 +757,42 @@ mod tests { let (all_stark, proof) = get_proof(&config)?; verify_proof(all_stark.clone(), proof.clone(), &config)?; - recursive_proof(all_stark, proof, &config, true) + recursive_proof(all_stark, proof, &config) } fn recursive_proof( inner_all_stark: AllStark, inner_proof: AllProof, inner_config: &StarkConfig, - print_gate_counts: bool, ) -> Result<()> { let circuit_config = CircuitConfig::standard_recursion_config(); + let recursive_all_proof = recursively_verify_all_proof( + &inner_all_stark, + &inner_proof, + inner_config, + &circuit_config, + )?; + + let verifier_data: [VerifierCircuitData; NUM_TABLES] = + all_verifier_data_recursive_stark_proof( + &inner_all_stark, + inner_proof.degree_bits(inner_config), + inner_config, + &circuit_config, + ); + let circuit_config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(circuit_config); let mut pw = 
PartialWitness::new(); - let degree_bits = inner_proof.degree_bits(inner_config); - let nums_ctl_zs = inner_proof.nums_ctl_zs(); - let pt = add_virtual_all_proof( + let recursive_all_proof_target = + add_virtual_recursive_all_proof(&mut builder, &verifier_data); + set_recursive_all_proof_target(&mut pw, &recursive_all_proof_target, &recursive_all_proof); + RecursiveAllProof::verify_circuit( &mut builder, - &inner_all_stark, + recursive_all_proof_target, + &verifier_data, + inner_all_stark.cross_table_lookups, inner_config, - °ree_bits, - &nums_ctl_zs, ); - set_all_proof_target(&mut pw, &pt, &inner_proof, builder.zero()); - - verify_proof_circuit::(&mut builder, inner_all_stark, pt, inner_config); - - if print_gate_counts { - builder.print_gate_counts(0); - } let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/evm/src/arithmetic/add.rs b/evm/src/arithmetic/add.rs index e87566b6..d2520fb9 100644 --- a/evm/src/arithmetic/add.rs +++ b/evm/src/arithmetic/add.rs @@ -165,6 +165,8 @@ mod tests { use crate::arithmetic::columns::NUM_ARITH_COLUMNS; use crate::constraint_consumer::ConstraintConsumer; + const N_RND_TESTS: usize = 1000; + // TODO: Should be able to refactor this test to apply to all operations. #[test] fn generate_eval_consistency_not_add() { @@ -177,14 +179,14 @@ mod tests { // if all values are garbage. lv[IS_ADD] = F::ZERO; - let mut constrant_consumer = ConstraintConsumer::new( + let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], GoldilocksField::ONE, GoldilocksField::ONE, GoldilocksField::ONE, ); - eval_packed_generic(&lv, &mut constrant_consumer); - for &acc in &constrant_consumer.constraint_accs { + eval_packed_generic(&lv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } } @@ -198,23 +200,26 @@ mod tests { // set `IS_ADD == 1` and ensure all constraints are satisfied. 
lv[IS_ADD] = F::ONE; - // set inputs to random values - for (&ai, bi) in ADD_INPUT_0.iter().zip(ADD_INPUT_1) { - lv[ai] = F::from_canonical_u16(rng.gen()); - lv[bi] = F::from_canonical_u16(rng.gen()); - } - generate(&mut lv); + for _ in 0..N_RND_TESTS { + // set inputs to random values + for (&ai, bi) in ADD_INPUT_0.iter().zip(ADD_INPUT_1) { + lv[ai] = F::from_canonical_u16(rng.gen()); + lv[bi] = F::from_canonical_u16(rng.gen()); + } - let mut constrant_consumer = ConstraintConsumer::new( - vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], - GoldilocksField::ONE, - GoldilocksField::ONE, - GoldilocksField::ONE, - ); - eval_packed_generic(&lv, &mut constrant_consumer); - for &acc in &constrant_consumer.constraint_accs { - assert_eq!(acc, GoldilocksField::ZERO); + generate(&mut lv); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed_generic(&lv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } } } } diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 58b8afff..fc168cee 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -9,6 +9,7 @@ use plonky2::hash::hash_types::RichField; use crate::arithmetic::add; use crate::arithmetic::columns; use crate::arithmetic::compare; +use crate::arithmetic::modular; use crate::arithmetic::mul; use crate::arithmetic::sub; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; @@ -50,6 +51,12 @@ impl ArithmeticStark { compare::generate(local_values, columns::IS_LT); } else if local_values[columns::IS_GT].is_one() { compare::generate(local_values, columns::IS_GT); + } else if local_values[columns::IS_ADDMOD].is_one() { + modular::generate(local_values, columns::IS_ADDMOD); + } else if local_values[columns::IS_MULMOD].is_one() { + modular::generate(local_values, columns::IS_MULMOD); + } else if local_values[columns::IS_MOD].is_one() { + modular::generate(local_values, columns::IS_MOD); } else { todo!("the requested operation has not yet been implemented"); } @@ -72,6 +79,7 @@ impl, const D: usize> Stark for ArithmeticSta sub::eval_packed_generic(lv, yield_constr); mul::eval_packed_generic(lv, yield_constr); compare::eval_packed_generic(lv, yield_constr); + modular::eval_packed_generic(lv, yield_constr); } fn eval_ext_circuit( @@ -85,6 +93,7 @@ impl, const D: usize> Stark for ArithmeticSta sub::eval_ext_circuit(builder, lv, yield_constr); mul::eval_ext_circuit(builder, lv, yield_constr); compare::eval_ext_circuit(builder, lv, yield_constr); + modular::eval_ext_circuit(builder, lv, yield_constr); } fn constraint_degree(&self) -> usize { diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index 7b44adc1..ca8ba549 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -44,7 +44,7 @@ pub(crate) const ALL_OPERATIONS: [usize; 16] = [ /// used by any arithmetic circuit, depending on which one is active /// this cycle. Can be increased as needed as other operations are /// implemented. 
-const NUM_SHARED_COLS: usize = 64; +const NUM_SHARED_COLS: usize = 144; // only need 64 for add, sub, and mul const fn shared_col(i: usize) -> usize { assert!(i < NUM_SHARED_COLS); @@ -64,7 +64,10 @@ const fn gen_input_cols(start: usize) -> [usize; N] { const GENERAL_INPUT_0: [usize; N_LIMBS] = gen_input_cols::(0); const GENERAL_INPUT_1: [usize; N_LIMBS] = gen_input_cols::(N_LIMBS); const GENERAL_INPUT_2: [usize; N_LIMBS] = gen_input_cols::(2 * N_LIMBS); -const AUX_INPUT_0: [usize; N_LIMBS] = gen_input_cols::(3 * N_LIMBS); +const GENERAL_INPUT_3: [usize; N_LIMBS] = gen_input_cols::(3 * N_LIMBS); +const AUX_INPUT_0: [usize; 2 * N_LIMBS] = gen_input_cols::<{ 2 * N_LIMBS }>(4 * N_LIMBS); +const AUX_INPUT_1: [usize; 2 * N_LIMBS] = gen_input_cols::<{ 2 * N_LIMBS }>(6 * N_LIMBS); +const AUX_INPUT_2: [usize; N_LIMBS] = gen_input_cols::(8 * N_LIMBS); pub(crate) const ADD_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; pub(crate) const ADD_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; @@ -77,11 +80,21 @@ pub(crate) const SUB_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; pub(crate) const MUL_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; pub(crate) const MUL_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; pub(crate) const MUL_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; -pub(crate) const MUL_AUX_INPUT: [usize; N_LIMBS] = AUX_INPUT_0; +pub(crate) const MUL_AUX_INPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; pub(crate) const CMP_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; pub(crate) const CMP_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2[0]; -pub(crate) const CMP_AUX_INPUT: [usize; N_LIMBS] = AUX_INPUT_0; +pub(crate) const CMP_AUX_INPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; + +pub(crate) const MODULAR_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; +pub(crate) const MODULAR_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; +pub(crate) const MODULAR_MODULUS: [usize; N_LIMBS] = GENERAL_INPUT_2; +pub(crate) const MODULAR_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; +pub(crate) const MODULAR_QUO_INPUT: [usize; 2 * N_LIMBS] = AUX_INPUT_0; +// NB: Last value is not used in AUX, it is used in IS_ZERO +pub(crate) const MODULAR_AUX_INPUT: [usize; 2 * N_LIMBS] = AUX_INPUT_1; +pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1[2 * N_LIMBS - 1]; +pub(crate) const MODULAR_OUT_AUX_RED: [usize; N_LIMBS] = AUX_INPUT_2; pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS; diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index 69fbda09..a6f59446 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -1,5 +1,6 @@ mod add; mod compare; +mod modular; mod mul; mod sub; mod utils; diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs new file mode 100644 index 00000000..1fd31bb1 --- /dev/null +++ b/evm/src/arithmetic/modular.rs @@ -0,0 +1,593 @@ +//! Support for the EVM modular instructions ADDMOD, MULMOD and MOD. +//! +//! This crate verifies an EVM modular instruction, which takes three +//! 256-bit inputs A, B and M, and produces a 256-bit output C satisfying +//! +//! C = operation(A, B) (mod M). +//! +//! where operation can be addition, multiplication, or just return +//! the first argument (for MOD). Inputs A, B and M, and output C, +//! are given as arrays of 16-bit limbs. For example, if the limbs of +//! A are a[0]...a[15], then +//! +//! A = \sum_{i=0}^15 a[i] β^i, +//! +//! where β = 2^16 = 2^LIMB_BITS. To verify that A, B, M and C satisfy +//! the equation we proceed as follows. Define +//! +//! 
a(x) = \sum_{i=0}^15 a[i] x^i +//! +//! (so A = a(β)) and similarly for b(x), m(x) and c(x). Then +//! operation(A,B) = C (mod M) if and only if the polynomial +//! +//! operation(a(x), b(x)) - c(x) - m(x) * q(x) +//! +//! is zero when evaluated at x = β, i.e. it is divisible by (x - β). +//! Thus exists a polynomial s such that +//! +//! operation(a(x), b(x)) - c(x) - m(x) * q(x) - (x - β) * s(x) == 0 +//! +//! if and only if operation(A,B) = C (mod M). In the code below, this +//! "constraint polynomial" is constructed in the variable +//! `constr_poly`. It must be identically zero for the modular +//! operation to be verified, or, equivalently, each of its +//! coefficients must be zero. The variable names of the constituent +//! polynomials are (writing N for N_LIMBS=16): +//! +//! a(x) = \sum_{i=0}^{N-1} input0[i] * β^i +//! b(x) = \sum_{i=0}^{N-1} input1[i] * β^i +//! c(x) = \sum_{i=0}^{N-1} output[i] * β^i +//! m(x) = \sum_{i=0}^{N-1} modulus[i] * β^i +//! q(x) = \sum_{i=0}^{2N-1} quot[i] * β^i +//! s(x) = \sum_i^{2N-2} aux[i] * β^i +//! +//! Because A, B, M and C are 256-bit numbers, the degrees of a, b, m +//! and c are (at most) N-1 = 15. If m = 1, then Q would be A*B which +//! can be up to 2^512 - ε, so deg(q) can be up to 2*N-1 = 31. Note +//! that, although for arbitrary m and q we might have deg(m*q) = 3*N-2, +//! because the magnitude of M*Q must match that of operation(A,B), we +//! always have deg(m*q) <= 2*N-1. Finally, in order for all the degrees +//! to match, we have deg(s) <= 2*N-2 = 30. +//! +//! -*- +//! +//! To verify that the output is reduced, that is, output < modulus, +//! the prover supplies the value `out_aux_red` which must satisfy +//! +//! output - modulus = out_aux_red + 2^256 +//! +//! and these values are passed to the "less than" operation. +//! +//! -*- +//! +//! The EVM defines division by zero as zero. We handle this as +//! follows: +//! +//! The prover supplies a binary value `mod_is_zero` which is one if +//! the modulus is zero and zero otherwise. This is verified, then +//! added to the modulus (this can't overflow, as modulus[0] was +//! range-checked and mod_is_zero is 0 or 1). The rest of the +//! calculation proceeds as if modulus was actually 1; this correctly +//! verifies that the output is zero, as required by the standard. +//! To summarise: +//! +//! - mod_is_zero is 0 or 1 +//! - if mod_is_zero is 1, then +//! - given modulus is 0 +//! - updated modulus is 1, which forces the correct output of 0 +//! - if mod_is_zero is 0, then +//! - given modulus can be 0 or non-zero +//! - updated modulus is same as given +//! - if modulus is non-zero, correct output is obtained +//! - if modulus is 0, then the test output < modulus, checking that +//! the output is reduced, will fail, because output is non-negative. + +use num::{BigUint, Zero}; +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use super::columns; +use crate::arithmetic::columns::*; +use crate::arithmetic::compare::{eval_ext_circuit_lt, eval_packed_generic_lt}; +use crate::arithmetic::utils::*; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::range_check_error; + +/// Convert the base-2^16 representation of a number into a BigUint. 
+/// +/// Given `N` unsigned 16-bit values in `limbs`, return the BigUint +/// +/// \sum_{i=0}^{N-1} limbs[i] * β^i. +/// +fn columns_to_biguint(limbs: &[i64; N]) -> BigUint { + const BASE: i64 = 1i64 << LIMB_BITS; + + // Although the input type is i64, the values must always be in + // [0, 2^16 + ε) because of the caller's range check on the inputs + // (the ε allows us to convert calculated output, which can be + // bigger than 2^16). + debug_assert!(limbs.iter().all(|&x| x >= 0)); + + let mut limbs_u32 = Vec::with_capacity(N / 2 + 1); + let mut cy = 0i64; // cy is necessary to handle ε > 0 + for i in 0..(N / 2) { + let t = cy + limbs[2 * i] + BASE * limbs[2 * i + 1]; + limbs_u32.push(t as u32); + cy = t >> 32; + } + if N & 1 != 0 { + // If N is odd we need to add the last limb on its own + let t = cy + limbs[N - 1]; + limbs_u32.push(t as u32); + cy = t >> 32; + } + limbs_u32.push(cy as u32); + + BigUint::from_slice(&limbs_u32) +} + +/// Convert a BigUint into a base-2^16 representation. +/// +/// Given a BigUint `num`, return an array of `N` unsigned 16-bit +/// values, say `limbs`, such that +/// +/// num = \sum_{i=0}^{N-1} limbs[i] * β^i. +/// +/// Note that `N` must be at least ceil(log2(num)/16) in order to be +/// big enough to hold `num`. +fn biguint_to_columns(num: &BigUint) -> [i64; N] { + assert!(num.bits() <= 16 * N as u64); + let mut output = [0i64; N]; + for (i, limb) in num.iter_u32_digits().enumerate() { + output[2 * i] = limb as u16 as i64; + output[2 * i + 1] = (limb >> LIMB_BITS) as i64; + } + output +} + +/// Generate the output and auxiliary values for given `operation`. +/// +/// NB: `operation` can set the higher order elements in its result to +/// zero if they are not used. +fn generate_modular_op( + lv: &mut [F; NUM_ARITH_COLUMNS], + operation: fn([i64; N_LIMBS], [i64; N_LIMBS]) -> [i64; 2 * N_LIMBS - 1], +) { + // Inputs are all range-checked in [0, 2^16), so the "as i64" + // conversion is safe. + let input0_limbs = MODULAR_INPUT_0.map(|c| F::to_canonical_u64(&lv[c]) as i64); + let input1_limbs = MODULAR_INPUT_1.map(|c| F::to_canonical_u64(&lv[c]) as i64); + let mut modulus_limbs = MODULAR_MODULUS.map(|c| F::to_canonical_u64(&lv[c]) as i64); + + // The use of BigUints is just to avoid having to implement + // modular reduction. + let mut modulus = columns_to_biguint(&modulus_limbs); + + // constr_poly is initialised to the calculated input, and is + // used as such for the BigUint reduction; later, other values are + // added/subtracted, which is where its meaning as the "constraint + // polynomial" comes in. + let mut constr_poly = [0i64; 2 * N_LIMBS]; + constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&operation(input0_limbs, input1_limbs)); + + if modulus.is_zero() { + modulus += 1u32; + modulus_limbs[0] += 1i64; + lv[MODULAR_MOD_IS_ZERO] = F::ONE; + } else { + lv[MODULAR_MOD_IS_ZERO] = F::ZERO; + } + + let input = columns_to_biguint(&constr_poly); + + // modulus != 0 here, because, if the given modulus was zero, then + // we added 1 to it above. 
+ let output = &input % &modulus; + let output_limbs = biguint_to_columns::(&output); + let quot = (&input - &output) / &modulus; // exact division + let quot_limbs = biguint_to_columns::<{ 2 * N_LIMBS }>("); + + // two_exp_256 == 2^256 + let mut two_exp_256 = BigUint::zero(); + two_exp_256.set_bit(256, true); + // output < modulus here, so the proof requires (output - modulus) % 2^256: + let out_aux_red = biguint_to_columns::(&(two_exp_256 + output - modulus)); + + // constr_poly is the array of coefficients of the polynomial + // + // operation(a(x), b(x)) - c(x) - s(x)*m(x). + // + pol_sub_assign(&mut constr_poly, &output_limbs); + let prod = pol_mul_wide2(quot_limbs, modulus_limbs); + pol_sub_assign(&mut constr_poly, &prod[0..2 * N_LIMBS]); + + // Higher order terms of the product must be zero for valid quot and modulus: + debug_assert!(&prod[2 * N_LIMBS..].iter().all(|&x| x == 0i64)); + + // constr_poly must be zero when evaluated at x = β := + // 2^LIMB_BITS, hence it's divisible by (x - β). `aux_limbs` is + // the result of removing that root. + let aux_limbs = pol_remove_root_2exp::(constr_poly); + + for deg in 0..N_LIMBS { + lv[MODULAR_OUTPUT[deg]] = F::from_canonical_i64(output_limbs[deg]); + lv[MODULAR_OUT_AUX_RED[deg]] = F::from_canonical_i64(out_aux_red[deg]); + lv[MODULAR_QUO_INPUT[deg]] = F::from_canonical_i64(quot_limbs[deg]); + lv[MODULAR_QUO_INPUT[deg + N_LIMBS]] = F::from_canonical_i64(quot_limbs[deg + N_LIMBS]); + lv[MODULAR_AUX_INPUT[deg]] = F::from_noncanonical_i64(aux_limbs[deg]); + // Don't overwrite MODULAR_MOD_IS_ZERO, which is at the last + // index of MODULAR_AUX_INPUT + if deg < N_LIMBS - 1 { + lv[MODULAR_AUX_INPUT[deg + N_LIMBS]] = + F::from_noncanonical_i64(aux_limbs[deg + N_LIMBS]); + } + } +} + +/// Generate the output and auxiliary values for modular operations. +/// +/// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`. +pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], filter: usize) { + match filter { + columns::IS_ADDMOD => generate_modular_op(lv, pol_add), + columns::IS_MULMOD => generate_modular_op(lv, pol_mul_wide), + columns::IS_MOD => generate_modular_op(lv, |a, _| pol_extend(a)), + _ => panic!("generate modular operation called with unknown opcode"), + } +} + +/// Build the part of the constraint polynomial that's common to all +/// modular operations, and perform the common verifications. +/// +/// Specifically, with the notation above, build the polynomial +/// +/// c(x) + q(x) * m(x) + (x - β) * s(x) +/// +/// and check consistency when m = 0, and that c is reduced. +#[allow(clippy::needless_range_loop)] +fn modular_constr_poly( + lv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer
<P>
, + filter: P, +) -> [P; 2 * N_LIMBS] { + range_check_error!(MODULAR_INPUT_0, 16); + range_check_error!(MODULAR_INPUT_1, 16); + range_check_error!(MODULAR_MODULUS, 16); + range_check_error!(MODULAR_QUO_INPUT, 16); + range_check_error!(MODULAR_AUX_INPUT, 20, signed); + range_check_error!(MODULAR_OUTPUT, 16); + + let mut modulus = MODULAR_MODULUS.map(|c| lv[c]); + let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; + + // Check that mod_is_zero is zero or one + yield_constr.constraint(filter * (mod_is_zero * mod_is_zero - mod_is_zero)); + + // Check that mod_is_zero is zero if modulus is not zero (they + // could both be zero) + let limb_sum = modulus.into_iter().sum::
<P>
(); + yield_constr.constraint(filter * limb_sum * mod_is_zero); + + // See the file documentation for why this suffices to handle + // modulus = 0. + modulus[0] += mod_is_zero; + + let output = MODULAR_OUTPUT.map(|c| lv[c]); + + // Verify that the output is reduced, i.e. output < modulus. + let out_aux_red = MODULAR_OUT_AUX_RED.map(|c| lv[c]); + let is_less_than = P::ONES; + eval_packed_generic_lt( + yield_constr, + filter, + output, + modulus, + out_aux_red, + is_less_than, + ); + + // prod = q(x) * m(x) + let quot = MODULAR_QUO_INPUT.map(|c| lv[c]); + let prod = pol_mul_wide2(quot, modulus); + // higher order terms must be zero + for &x in prod[2 * N_LIMBS..].iter() { + yield_constr.constraint(filter * x); + } + + // constr_poly = c(x) + q(x) * m(x) + let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); + pol_add_assign(&mut constr_poly, &output); + + // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) + let aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); + pol_add_assign(&mut constr_poly, &pol_adjoin_root(aux, base)); + + constr_poly +} + +/// Add constraints for modular operations. +pub(crate) fn eval_packed_generic( + lv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer
<P>
, +) { + // NB: The CTL code guarantees that filter is 0 or 1, i.e. that + // only one of the operations below is "live". + let filter = lv[columns::IS_ADDMOD] + lv[columns::IS_MULMOD] + lv[columns::IS_MOD]; + + // constr_poly has 2*N_LIMBS limbs + let constr_poly = modular_constr_poly(lv, yield_constr, filter); + + let input0 = MODULAR_INPUT_0.map(|c| lv[c]); + let input1 = MODULAR_INPUT_1.map(|c| lv[c]); + + let add_input = pol_add(input0, input1); + let mul_input = pol_mul_wide(input0, input1); + let mod_input = pol_extend(input0); + + for (input, &filter) in [ + (&add_input, &lv[columns::IS_ADDMOD]), + (&mul_input, &lv[columns::IS_MULMOD]), + (&mod_input, &lv[columns::IS_MOD]), + ] { + // Need constr_poly_copy to be the first argument to + // pol_sub_assign, since it is the longer of the two + // arguments. + let mut constr_poly_copy = constr_poly; + pol_sub_assign(&mut constr_poly_copy, input); + + // At this point constr_poly_copy holds the coefficients of + // the polynomial + // + // operation(a(x), b(x)) - c(x) - q(x) * m(x) - (x - β) * s(x) + // + // where operation is add, mul or |a,b|->a. The modular + // operation is valid if and only if all of those coefficients + // are zero. + for &c in constr_poly_copy.iter() { + yield_constr.constraint(filter * c); + } + } +} + +fn modular_constr_poly_ext_circuit, const D: usize>( + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + builder: &mut CircuitBuilder, + yield_constr: &mut RecursiveConstraintConsumer, + filter: ExtensionTarget, +) -> [ExtensionTarget; 2 * N_LIMBS] { + let mut modulus = MODULAR_MODULUS.map(|c| lv[c]); + let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; + + let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero); + let t = builder.mul_extension(filter, t); + yield_constr.constraint(builder, t); + + let limb_sum = builder.add_many_extension(modulus); + let t = builder.mul_extension(limb_sum, mod_is_zero); + let t = builder.mul_extension(filter, t); + yield_constr.constraint(builder, t); + + modulus[0] = builder.add_extension(modulus[0], mod_is_zero); + + let output = MODULAR_OUTPUT.map(|c| lv[c]); + let out_aux_red = MODULAR_OUT_AUX_RED.map(|c| lv[c]); + let is_less_than = builder.one_extension(); + eval_ext_circuit_lt( + builder, + yield_constr, + filter, + output, + modulus, + out_aux_red, + is_less_than, + ); + + let quot = MODULAR_QUO_INPUT.map(|c| lv[c]); + let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); + for &x in prod[2 * N_LIMBS..].iter() { + let t = builder.mul_extension(filter, x); + yield_constr.constraint(builder, t); + } + + let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); + pol_add_assign_ext_circuit(builder, &mut constr_poly, &output); + + let aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + let base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << LIMB_BITS)); + let t = pol_adjoin_root_ext_circuit(builder, aux, base); + pol_add_assign_ext_circuit(builder, &mut constr_poly, &t); + + constr_poly +} + +pub(crate) fn eval_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let filter = builder.add_many_extension([ + lv[columns::IS_ADDMOD], + lv[columns::IS_MULMOD], + lv[columns::IS_MOD], + ]); + + let constr_poly = modular_constr_poly_ext_circuit(lv, builder, yield_constr, filter); + + let input0 = MODULAR_INPUT_0.map(|c| lv[c]); + let input1 = MODULAR_INPUT_1.map(|c| lv[c]); + + let add_input = pol_add_ext_circuit(builder, 
input0, input1); + let mul_input = pol_mul_wide_ext_circuit(builder, input0, input1); + let mod_input = pol_extend_ext_circuit(builder, input0); + + for (input, &filter) in [ + (&add_input, &lv[columns::IS_ADDMOD]), + (&mul_input, &lv[columns::IS_MULMOD]), + (&mod_input, &lv[columns::IS_MOD]), + ] { + let mut constr_poly_copy = constr_poly; + pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input); + for &c in constr_poly_copy.iter() { + let t = builder.mul_extension(filter, c); + yield_constr.constraint(builder, t); + } + } +} + +#[cfg(test)] +mod tests { + use itertools::izip; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::Field; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; + + use super::*; + use crate::arithmetic::columns::NUM_ARITH_COLUMNS; + use crate::constraint_consumer::ConstraintConsumer; + + const N_RND_TESTS: usize = 1000; + + // TODO: Should be able to refactor this test to apply to all operations. + #[test] + fn generate_eval_consistency_not_modular() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); + + // if `IS_ADDMOD == 0`, then the constraints should be met even + // if all values are garbage. + lv[IS_ADDMOD] = F::ZERO; + lv[IS_MULMOD] = F::ZERO; + lv[IS_MOD] = F::ZERO; + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed_generic(&lv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + + #[test] + fn generate_eval_consistency() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); + + for op_filter in [IS_ADDMOD, IS_MOD, IS_MULMOD] { + // Reset operation columns, then select one + lv[IS_ADDMOD] = F::ZERO; + lv[IS_MULMOD] = F::ZERO; + lv[IS_MOD] = F::ZERO; + lv[op_filter] = F::ONE; + + for i in 0..N_RND_TESTS { + // set inputs to random values + for (&ai, &bi, &mi) in izip!( + MODULAR_INPUT_0.iter(), + MODULAR_INPUT_1.iter(), + MODULAR_MODULUS.iter() + ) { + lv[ai] = F::from_canonical_u16(rng.gen()); + lv[bi] = F::from_canonical_u16(rng.gen()); + lv[mi] = F::from_canonical_u16(rng.gen()); + } + + // For the second half of the tests, set the top 16 - + // start digits of the modulus to zero so it is much + // smaller than the inputs. 
+ if i > N_RND_TESTS / 2 { + // 1 <= start < N_LIMBS + let start = (rng.gen::() % (N_LIMBS - 1)) + 1; + for &mi in &MODULAR_MODULUS[start..N_LIMBS] { + lv[mi] = F::ZERO; + } + } + + generate(&mut lv, op_filter); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed_generic(&lv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + } + } + + #[test] + fn zero_modulus() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); + + for op_filter in [IS_ADDMOD, IS_MOD, IS_MULMOD] { + // Reset operation columns, then select one + lv[IS_ADDMOD] = F::ZERO; + lv[IS_MULMOD] = F::ZERO; + lv[IS_MOD] = F::ZERO; + lv[op_filter] = F::ONE; + + for _i in 0..N_RND_TESTS { + // set inputs to random values and the modulus to zero; + // the output is defined to be zero when modulus is zero. + for (&ai, &bi, &mi) in izip!( + MODULAR_INPUT_0.iter(), + MODULAR_INPUT_1.iter(), + MODULAR_MODULUS.iter() + ) { + lv[ai] = F::from_canonical_u16(rng.gen()); + lv[bi] = F::from_canonical_u16(rng.gen()); + lv[mi] = F::ZERO; + } + + generate(&mut lv, op_filter); + + // check that the correct output was generated + assert!(MODULAR_OUTPUT.iter().all(|&oi| lv[oi] == F::ZERO)); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed_generic(&lv, &mut constraint_consumer); + assert!(constraint_consumer + .constraint_accs + .iter() + .all(|&acc| acc == F::ZERO)); + + // Corrupt one output limb by setting it to a non-zero value + let random_oi = MODULAR_OUTPUT[rng.gen::() % N_LIMBS]; + lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); + + eval_packed_generic(&lv, &mut constraint_consumer); + + // Check that at least one of the constraints was non-zero + assert!(constraint_consumer + .constraint_accs + .iter() + .any(|&acc| acc != F::ZERO)); + } + } + } +} diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index 72270517..9d6638f1 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -15,7 +15,7 @@ //! and similarly for b(x) and c(x). Then A*B = C (mod 2^256) if and only //! if there exist polynomials q and m such that //! -//! a(x)*b(x) - c(x) - m(x)*x^16 - (x - β)*q(x) == 0. +//! a(x)*b(x) - c(x) - m(x)*x^16 - (β - x)*q(x) == 0. //! //! Because A, B and C are 256-bit numbers, the degrees of a, b and c //! are (at most) 15. Thus deg(a*b) <= 30, so deg(m) <= 14 and deg(q) @@ -24,7 +24,7 @@ //! them evaluating at β gives a factor of β^16 = 2^256 which is 0. //! //! Hence, to verify the equality, we don't need m(x) at all, and we -//! only need to know q(x) up to degree 14 (so that (x-β)*q(x) has +//! only need to know q(x) up to degree 14 (so that (β - x)*q(x) has //! degree 15). On the other hand, the coefficients of q(x) can be as //! large as 16*(β-2) or 20 bits. 
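To make the divisibility argument in the module comment above concrete, here is a small self-contained check (a toy, not the production code) with N = 2 limbs and β = 2^16, i.e. multiplication modulo 2^32. It builds the column-wise product p(x) = a(x)*b(x), carry-propagates to obtain the output limbs c, derives q by the exact-division recurrence quoted above, and asserts that the low-degree coefficients of a(x)*b(x) - c(x) - (β - x)*q(x) all vanish. The input values are arbitrary.

fn main() {
    const N: usize = 2;
    const BETA: i64 = 1 << 16;

    // A = 0xdead_beef and B = 0x1234_5678 as little-endian 16-bit limbs.
    let a: [i64; N] = [0xbeef, 0xdead];
    let b: [i64; N] = [0x5678, 0x1234];

    // p(x) = a(x) * b(x), computed column-wise; only the low N columns matter here.
    let mut p = [0i64; N];
    for deg in 0..N {
        for i in 0..=deg {
            p[deg] += a[i] * b[deg - i];
        }
    }

    // Carry propagation: c holds the limbs of A*B mod 2^32.
    let mut c = [0i64; N];
    let mut cy = 0i64;
    for i in 0..N {
        let t = p[i] + cy;
        c[i] = t % BETA;
        cy = t / BETA;
    }

    // q_0 = (p_0 - c_0) / β and q_i = (p_i - c_i + q_{i-1}) / β; the divisions are exact.
    let mut q = [0i64; N];
    q[0] = (p[0] - c[0]) / BETA;
    for i in 1..N {
        q[i] = (p[i] - c[i] + q[i - 1]) / BETA;
    }

    // Coefficients 0..N-1 of a(x)*b(x) - c(x) - (β - x)*q(x) must all be zero.
    assert_eq!(p[0] - c[0] - BETA * q[0], 0);
    for i in 1..N {
        assert_eq!(p[i] - c[i] - (BETA * q[i] - q[i - 1]), 0);
    }
    println!("polynomial identity holds for the toy inputs");
}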
@@ -35,6 +35,7 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::columns::*; +use crate::arithmetic::utils::{pol_mul_lo, pol_sub_assign}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; @@ -48,26 +49,17 @@ pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { let mut aux_in_limbs = [0u64; N_LIMBS]; let mut output_limbs = [0u64; N_LIMBS]; - let mut unreduced_prod = [0u64; N_LIMBS]; - // Column-wise pen-and-paper long multiplication on 16-bit limbs. - // We have heaps of space at the top of each limb, so by - // calculating column-wise (instead of the usual row-wise) we - // avoid a bunch of carry propagation handling (at the expense of - // slightly worse cache coherency), and it makes it easy to - // calculate the coefficients of a(x)*b(x) (in unreduced_prod). + // First calculate the coefficients of a(x)*b(x) (in unreduced_prod), + // then do carry propagation to obtain C = c(β) = a(β)*b(β). let mut cy = 0u64; + let mut unreduced_prod = pol_mul_lo(input0_limbs, input1_limbs); for col in 0..N_LIMBS { - for i in 0..=col { - // Invariant: i + j = col - let j = col - i; - let ai_x_bj = input0_limbs[i] * input1_limbs[j]; - unreduced_prod[col] += ai_x_bj; - } let t = unreduced_prod[col] + cy; cy = t >> LIMB_BITS; output_limbs[col] = t & MASK; } + // In principle, the last cy could be dropped because this is // multiplication modulo 2^256. However, we need it below for // aux_in_limbs to handle the fact that unreduced_prod will @@ -76,23 +68,22 @@ pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { for (&c, output_limb) in MUL_OUTPUT.iter().zip(output_limbs) { lv[c] = F::from_canonical_u64(output_limb); } - for deg in 0..N_LIMBS { - // deg'th element <- a*b - c - unreduced_prod[deg] -= output_limbs[deg]; - } + pol_sub_assign(&mut unreduced_prod, &output_limbs); // unreduced_prod is the coefficients of the polynomial a(x)*b(x) - c(x). - // This must be zero when evaluated at x = B = 2^LIMB_BITS, hence it's - // divisible by (B - x). If we write unreduced_prod as + // This must be zero when evaluated at x = β = 2^LIMB_BITS, hence it's + // divisible by (β - x). If we write unreduced_prod as // - // a(x)*b(x) - c(x) = \sum_{i=0}^n p_i x^i - // = (B - x) \sum_{i=0}^{n-1} q_i x^i + // a(x)*b(x) - c(x) = \sum_{i=0}^n p_i x^i + terms of degree > n + // = (β - x) \sum_{i=0}^{n-1} q_i x^i + terms of degree > n // // then by comparing coefficients it is easy to see that // - // q_0 = p_0 / B and q_i = (p_i + q_{i-1}) / B + // q_0 = p_0 / β and q_i = (p_i + q_{i-1}) / β // - // for 0 < i < n-1 (and the divisions are exact). + // for 0 < i < n-1 (and the divisions are exact). Because we're + // only calculating the result modulo 2^256, we can ignore the + // terms of degree > n = 15. aux_in_limbs[0] = unreduced_prod[0] >> LIMB_BITS; for deg in 1..N_LIMBS - 1 { aux_in_limbs[deg] = (unreduced_prod[deg] + aux_in_limbs[deg - 1]) >> LIMB_BITS; @@ -122,14 +113,10 @@ pub fn eval_packed_generic( // Constraint poly holds the coefficients of the polynomial that // must be identically zero for this multiplication to be - // verified. It is initialised to the /negative/ of the claimed - // output. - let mut constr_poly = [P::ZEROS; N_LIMBS]; - - assert_eq!(constr_poly.len(), N_LIMBS); - - // After this loop constr_poly holds the coefficients of the - // polynomial A(x)B(x) - C(x), where A, B and C are the polynomials + // verified. 
+ // + // These two lines set constr_poly to the polynomial A(x)B(x) - C(x), + // where A, B and C are the polynomials // // A(x) = \sum_i input0_limbs[i] * 2^LIMB_BITS // B(x) = \sum_i input1_limbs[i] * 2^LIMB_BITS @@ -139,14 +126,8 @@ pub fn eval_packed_generic( // // Q(x) = \sum_i aux_limbs[i] * 2^LIMB_BITS // - for col in 0..N_LIMBS { - // Invariant: i + j = col - for i in 0..=col { - let j = col - i; - constr_poly[col] += input0_limbs[i] * input1_limbs[j]; - } - constr_poly[col] -= output_limbs[col]; - } + let mut constr_poly = pol_mul_lo(input0_limbs, input1_limbs); + pol_sub_assign(&mut constr_poly, &output_limbs); // This subtracts (2^LIMB_BITS - x) * Q(x) from constr_poly. let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); @@ -156,7 +137,7 @@ pub fn eval_packed_generic( } // At this point constr_poly holds the coefficients of the - // polynomial A(x)B(x) - C(x) - (x - 2^LIMB_BITS)*Q(x). The + // polynomial A(x)B(x) - C(x) - (2^LIMB_BITS - x)*Q(x). The // multiplication is valid if and only if all of those // coefficients are zero. for &c in &constr_poly { @@ -189,12 +170,20 @@ pub fn eval_ext_circuit, const D: usize>( } let base = F::from_canonical_u64(1 << LIMB_BITS); - let t = builder.mul_const_extension(base, aux_in_limbs[0]); - constr_poly[0] = builder.sub_extension(constr_poly[0], t); + let one = builder.one_extension(); + // constr_poly[0] = constr_poly[0] - base * aux_in_limbs[0] + constr_poly[0] = + builder.arithmetic_extension(F::ONE, -base, constr_poly[0], one, aux_in_limbs[0]); for deg in 1..N_LIMBS { - let t0 = builder.mul_const_extension(base, aux_in_limbs[deg]); - let t1 = builder.sub_extension(t0, aux_in_limbs[deg - 1]); - constr_poly[deg] = builder.sub_extension(constr_poly[deg], t1); + // constr_poly[deg] -= (base*aux_in_limbs[deg] - aux_in_limbs[deg-1]) + let t = builder.arithmetic_extension( + base, + F::NEG_ONE, + aux_in_limbs[deg], + one, + aux_in_limbs[deg - 1], + ); + constr_poly[deg] = builder.sub_extension(constr_poly[deg], t); } for &c in &constr_poly { @@ -214,6 +203,8 @@ mod tests { use crate::arithmetic::columns::NUM_ARITH_COLUMNS; use crate::constraint_consumer::ConstraintConsumer; + const N_RND_TESTS: usize = 1000; + // TODO: Should be able to refactor this test to apply to all operations. #[test] fn generate_eval_consistency_not_mul() { @@ -226,14 +217,14 @@ mod tests { // if all values are garbage. lv[IS_MUL] = F::ZERO; - let mut constrant_consumer = ConstraintConsumer::new( + let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], GoldilocksField::ONE, GoldilocksField::ONE, GoldilocksField::ONE, ); - eval_packed_generic(&lv, &mut constrant_consumer); - for &acc in &constrant_consumer.constraint_accs { + eval_packed_generic(&lv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } } @@ -247,23 +238,26 @@ mod tests { // set `IS_MUL == 1` and ensure all constraints are satisfied. 
lv[IS_MUL] = F::ONE; - // set inputs to random values - for (&ai, bi) in MUL_INPUT_0.iter().zip(MUL_INPUT_1) { - lv[ai] = F::from_canonical_u16(rng.gen()); - lv[bi] = F::from_canonical_u16(rng.gen()); - } - generate(&mut lv); + for _i in 0..N_RND_TESTS { + // set inputs to random values + for (&ai, bi) in MUL_INPUT_0.iter().zip(MUL_INPUT_1) { + lv[ai] = F::from_canonical_u16(rng.gen()); + lv[bi] = F::from_canonical_u16(rng.gen()); + } - let mut constrant_consumer = ConstraintConsumer::new( - vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], - GoldilocksField::ONE, - GoldilocksField::ONE, - GoldilocksField::ONE, - ); - eval_packed_generic(&lv, &mut constrant_consumer); - for &acc in &constrant_consumer.constraint_accs { - assert_eq!(acc, GoldilocksField::ZERO); + generate(&mut lv); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed_generic(&lv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } } } } diff --git a/evm/src/arithmetic/sub.rs b/evm/src/arithmetic/sub.rs index c632eb94..25834406 100644 --- a/evm/src/arithmetic/sub.rs +++ b/evm/src/arithmetic/sub.rs @@ -96,6 +96,8 @@ mod tests { use crate::arithmetic::columns::NUM_ARITH_COLUMNS; use crate::constraint_consumer::ConstraintConsumer; + const N_RND_TESTS: usize = 1000; + // TODO: Should be able to refactor this test to apply to all operations. #[test] fn generate_eval_consistency_not_sub() { @@ -108,14 +110,14 @@ mod tests { // if all values are garbage. lv[IS_SUB] = F::ZERO; - let mut constrant_consumer = ConstraintConsumer::new( + let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], GoldilocksField::ONE, GoldilocksField::ONE, GoldilocksField::ONE, ); - eval_packed_generic(&lv, &mut constrant_consumer); - for &acc in &constrant_consumer.constraint_accs { + eval_packed_generic(&lv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } } @@ -129,23 +131,26 @@ mod tests { // set `IS_SUB == 1` and ensure all constraints are satisfied. 
lv[IS_SUB] = F::ONE; - // set inputs to random values - for (&ai, bi) in SUB_INPUT_0.iter().zip(SUB_INPUT_1) { - lv[ai] = F::from_canonical_u16(rng.gen()); - lv[bi] = F::from_canonical_u16(rng.gen()); - } - generate(&mut lv); + for _ in 0..N_RND_TESTS { + // set inputs to random values + for (&ai, bi) in SUB_INPUT_0.iter().zip(SUB_INPUT_1) { + lv[ai] = F::from_canonical_u16(rng.gen()); + lv[bi] = F::from_canonical_u16(rng.gen()); + } - let mut constrant_consumer = ConstraintConsumer::new( - vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], - GoldilocksField::ONE, - GoldilocksField::ONE, - GoldilocksField::ONE, - ); - eval_packed_generic(&lv, &mut constrant_consumer); - for &acc in &constrant_consumer.constraint_accs { - assert_eq!(acc, GoldilocksField::ZERO); + generate(&mut lv); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed_generic(&lv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } } } } diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index c50481f3..b5356a78 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ -1,14 +1,28 @@ +use std::ops::{Add, AddAssign, Mul, Neg, Shr, Sub, SubAssign}; + use log::error; +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::arithmetic::columns::N_LIMBS; /// Emit an error message regarding unchecked range assumptions. /// Assumes the values in `cols` are `[cols[0], cols[0] + 1, ..., /// cols[0] + cols.len() - 1]`. -pub(crate) fn _range_check_error(file: &str, line: u32, cols: &[usize]) { +pub(crate) fn _range_check_error( + file: &str, + line: u32, + cols: &[usize], + signedness: &str, +) { error!( - "{}:{}: arithmetic unit skipped {}-bit range-checks on columns {}--{}: not yet implemented", + "{}:{}: arithmetic unit skipped {}-bit {} range-checks on columns {}--{}: not yet implemented", line, file, RC_BITS, + signedness, cols[0], cols[0] + cols.len() - 1 ); @@ -17,9 +31,297 @@ pub(crate) fn _range_check_error(file: &str, line: u32, cols #[macro_export] macro_rules! range_check_error { ($cols:ident, $rc_bits:expr) => { - $crate::arithmetic::utils::_range_check_error::<$rc_bits>(file!(), line!(), &$cols); + $crate::arithmetic::utils::_range_check_error::<$rc_bits>( + file!(), + line!(), + &$cols, + "unsigned", + ); + }; + ($cols:ident, $rc_bits:expr, signed) => { + $crate::arithmetic::utils::_range_check_error::<$rc_bits>( + file!(), + line!(), + &$cols, + "signed", + ); }; ([$cols:ident], $rc_bits:expr) => { - $crate::arithmetic::utils::_range_check_error::<$rc_bits>(file!(), line!(), &[$cols]); + $crate::arithmetic::utils::_range_check_error::<$rc_bits>( + file!(), + line!(), + &[$cols], + "unsigned", + ); }; } + +/// Return an array of `N` zeros of type T. +pub(crate) fn pol_zero() -> [T; N] +where + T: Copy + Default, +{ + // TODO: This should really be T::zero() from num::Zero, because + // default() doesn't guarantee to initialise to zero (though in + // our case it always does). However I couldn't work out how to do + // that without touching half of the entire crate because it + // involves replacing Field::is_zero() with num::Zero::is_zero() + // which is used everywhere. 
Hence Default::default() it is. + [T::default(); N] +} + +/// a(x) += b(x), but must have deg(a) >= deg(b). +pub(crate) fn pol_add_assign(a: &mut [T], b: &[T]) +where + T: AddAssign + Copy + Default, +{ + debug_assert!(a.len() >= b.len(), "expected {} >= {}", a.len(), b.len()); + for (a_item, b_item) in a.iter_mut().zip(b) { + *a_item += *b_item; + } +} + +pub(crate) fn pol_add_assign_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: &mut [ExtensionTarget], + b: &[ExtensionTarget], +) { + debug_assert!(a.len() >= b.len(), "expected {} >= {}", a.len(), b.len()); + for (a_item, b_item) in a.iter_mut().zip(b) { + *a_item = builder.add_extension(*a_item, *b_item); + } +} + +/// Return a(x) + b(x); returned array is bigger than necessary to +/// make the interface consistent with `pol_mul_wide`. +pub(crate) fn pol_add(a: [T; N_LIMBS], b: [T; N_LIMBS]) -> [T; 2 * N_LIMBS - 1] +where + T: Add + Copy + Default, +{ + let mut sum = pol_zero(); + for i in 0..N_LIMBS { + sum[i] = a[i] + b[i]; + } + sum +} + +pub(crate) fn pol_add_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: [ExtensionTarget; N_LIMBS], + b: [ExtensionTarget; N_LIMBS], +) -> [ExtensionTarget; 2 * N_LIMBS - 1] { + let zero = builder.zero_extension(); + let mut sum = [zero; 2 * N_LIMBS - 1]; + for i in 0..N_LIMBS { + sum[i] = builder.add_extension(a[i], b[i]); + } + sum +} + +/// a(x) -= b(x), but must have deg(a) >= deg(b). +pub(crate) fn pol_sub_assign(a: &mut [T], b: &[T]) +where + T: SubAssign + Copy, +{ + debug_assert!(a.len() >= b.len(), "expected {} >= {}", a.len(), b.len()); + for (a_item, b_item) in a.iter_mut().zip(b) { + *a_item -= *b_item; + } +} + +pub(crate) fn pol_sub_assign_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: &mut [ExtensionTarget], + b: &[ExtensionTarget], +) { + debug_assert!(a.len() >= b.len(), "expected {} >= {}", a.len(), b.len()); + for (a_item, b_item) in a.iter_mut().zip(b) { + *a_item = builder.sub_extension(*a_item, *b_item); + } +} + +/// Given polynomials a(x) and b(x), return a(x)*b(x). +/// +/// NB: The caller is responsible for ensuring that no undesired +/// overflow occurs during the calculation of the coefficients of the +/// product. +pub(crate) fn pol_mul_wide(a: [T; N_LIMBS], b: [T; N_LIMBS]) -> [T; 2 * N_LIMBS - 1] +where + T: AddAssign + Copy + Mul + Default, +{ + let mut res = [T::default(); 2 * N_LIMBS - 1]; + for (i, &ai) in a.iter().enumerate() { + for (j, &bj) in b.iter().enumerate() { + res[i + j] += ai * bj; + } + } + res +} + +pub(crate) fn pol_mul_wide_ext_circuit< + F: RichField + Extendable, + const D: usize, + const M: usize, + const N: usize, + const P: usize, +>( + builder: &mut CircuitBuilder, + a: [ExtensionTarget; M], + b: [ExtensionTarget; N], +) -> [ExtensionTarget; P] { + let zero = builder.zero_extension(); + let mut res = [zero; P]; + for (i, &ai) in a.iter().enumerate() { + for (j, &bj) in b.iter().enumerate() { + res[i + j] = builder.mul_add_extension(ai, bj, res[i + j]); + } + } + res +} + +/// As for `pol_mul_wide` but the first argument has 2N elements and +/// hence the result has 3N-1. 
+pub(crate) fn pol_mul_wide2(a: [T; 2 * N_LIMBS], b: [T; N_LIMBS]) -> [T; 3 * N_LIMBS - 1] +where + T: AddAssign + Copy + Mul + Default, +{ + let mut res = [T::default(); 3 * N_LIMBS - 1]; + for (i, &ai) in a.iter().enumerate() { + for (j, &bj) in b.iter().enumerate() { + res[i + j] += ai * bj; + } + } + res +} + +pub(crate) fn pol_mul_wide2_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: [ExtensionTarget; 2 * N_LIMBS], + b: [ExtensionTarget; N_LIMBS], +) -> [ExtensionTarget; 3 * N_LIMBS - 1] { + let zero = builder.zero_extension(); + let mut res = [zero; 3 * N_LIMBS - 1]; + for (i, &ai) in a.iter().enumerate() { + for (j, &bj) in b.iter().enumerate() { + res[i + j] = builder.mul_add_extension(ai, bj, res[i + j]); + } + } + res +} + +/// Given a(x) and b(x), return a(x)*b(x) mod 2^256. +pub(crate) fn pol_mul_lo(a: [T; N], b: [T; N]) -> [T; N] +where + T: AddAssign + Copy + Default + Mul, +{ + let mut res = pol_zero(); + for deg in 0..N { + // Invariant: i + j = deg + for i in 0..=deg { + let j = deg - i; + res[deg] += a[i] * b[j]; + } + } + res +} + +/// Adjoin M - N zeros to a, returning [a[0], a[1], ..., a[N-1], 0, 0, ..., 0]. +pub(crate) fn pol_extend(a: [T; N]) -> [T; M] +where + T: Copy + Default, +{ + assert_eq!(M, 2 * N - 1); + + let mut zero_extend = pol_zero(); + zero_extend[..N].copy_from_slice(&a); + zero_extend +} + +pub(crate) fn pol_extend_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: [ExtensionTarget; N_LIMBS], +) -> [ExtensionTarget; 2 * N_LIMBS - 1] { + let zero = builder.zero_extension(); + let mut zero_extend = [zero; 2 * N_LIMBS - 1]; + + zero_extend[..N_LIMBS].copy_from_slice(&a); + zero_extend +} + +/// Given polynomial a(x) = \sum_{i=0}^{2N-2} a[i] x^i and an element +/// `root`, return b = (x - root) * a(x). +/// +/// NB: Ignores element a[2 * N_LIMBS - 1], treating it as if it's 0. +pub(crate) fn pol_adjoin_root(a: [T; 2 * N_LIMBS], root: U) -> [T; 2 * N_LIMBS] +where + T: Add + Copy + Default + Mul + Sub, + U: Copy + Mul + Neg, +{ + // \sum_i res[i] x^i = (x - root) \sum_i a[i] x^i. Comparing + // coefficients, res[0] = -root*a[0] and + // res[i] = a[i-1] - root * a[i] + + let mut res = [T::default(); 2 * N_LIMBS]; + res[0] = -root * a[0]; + for deg in 1..(2 * N_LIMBS - 1) { + res[deg] = a[deg - 1] - (root * a[deg]); + } + // NB: We assume that a[2 * N_LIMBS - 1] = 0, so the last + // iteration has no "* root" term. + res[2 * N_LIMBS - 1] = a[2 * N_LIMBS - 2]; + res +} + +pub(crate) fn pol_adjoin_root_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: [ExtensionTarget; 2 * N_LIMBS], + root: ExtensionTarget, +) -> [ExtensionTarget; 2 * N_LIMBS] { + let zero = builder.zero_extension(); + let mut res = [zero; 2 * N_LIMBS]; + // res[deg] = NEG_ONE * root * a[0] + ZERO * zero + res[0] = builder.arithmetic_extension(F::NEG_ONE, F::ZERO, root, a[0], zero); + for deg in 1..(2 * N_LIMBS - 1) { + // res[deg] = NEG_ONE * root * a[deg] + ONE * a[deg - 1] + res[deg] = builder.arithmetic_extension(F::NEG_ONE, F::ONE, root, a[deg], a[deg - 1]); + } + // NB: We assumes that a[2 * N_LIMBS - 1] = 0, so the last + // iteration has no "* root" term. + res[2 * N_LIMBS - 1] = a[2 * N_LIMBS - 2]; + res +} + +/// Given polynomial a(x) = \sum_{i=0}^{2N-1} a[i] x^i and a root of `a` +/// of the form 2^EXP, return q(x) satisfying a(x) = (x - root) * q(x). +/// +/// NB: We do not verify that a(2^EXP) = 0; if this doesn't hold the +/// result is basically junk. 
+/// +/// NB: The result could be returned in 2*N-1 elements, but we return +/// 2*N and set the last element to zero since the calling code +/// happens to require a result zero-extended to 2*N elements. +pub(crate) fn pol_remove_root_2exp(a: [T; 2 * N_LIMBS]) -> [T; 2 * N_LIMBS] +where + T: Copy + Default + Neg + Shr + Sub, +{ + // By assumption β := 2^EXP is a root of `a`, i.e. (x - β) divides + // `a`; if we write + // + // a(x) = \sum_{i=0}^{2N-1} a[i] x^i + // = (x - β) \sum_{i=0}^{2N-2} q[i] x^i + // + // then by comparing coefficients it is easy to see that + // + // q[0] = -a[0] / β and q[i] = (q[i-1] - a[i]) / β + // + // for 0 < i <= 2N-1 (and the divisions are exact). + + let mut q = [T::default(); 2 * N_LIMBS]; + q[0] = -(a[0] >> EXP); + + // NB: Last element of q is deliberately left equal to zero. + for deg in 1..2 * N_LIMBS - 1 { + q[deg] = (q[deg - 1] - a[deg]) >> EXP; + } + q +} diff --git a/evm/src/cpu/kernel/asm/memory/core.asm b/evm/src/cpu/kernel/asm/memory/core.asm index 12665e55..c2c19811 100644 --- a/evm/src/cpu/kernel/asm/memory/core.asm +++ b/evm/src/cpu/kernel/asm/memory/core.asm @@ -363,7 +363,7 @@ // Load a single value from kernel general memory. %macro mload_kernel_general_2(offset) PUSH $offset - %mload_kernel(@SEGMENT_KERNEL_GENERAL) + %mload_kernel(@SEGMENT_KERNEL_GENERAL_2) // stack: value %endmacro diff --git a/evm/src/cpu/kernel/asm/memory/packing.asm b/evm/src/cpu/kernel/asm/memory/packing.asm index 3021c640..c8b4c468 100644 --- a/evm/src/cpu/kernel/asm/memory/packing.asm +++ b/evm/src/cpu/kernel/asm/memory/packing.asm @@ -1,10 +1,47 @@ // Methods for encoding integers as bytes in memory, as well as the reverse, // decoding bytes as integers. All big-endian. +// Given a pointer to some bytes in memory, pack them into a word. Assumes 0 < len <= 32. 
+// Pre stack: addr: 3, len, retdest +// Post stack: packed_value +// NOTE: addr: 3 denotes a (context, segment, virtual) tuple global mload_packing: - // stack: context, segment, offset, len, retdest - PANIC // TODO - // stack: value + // stack: addr: 3, len, retdest + DUP3 DUP3 DUP3 MLOAD_GENERAL DUP5 %eq_const(1) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(1) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(2) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(2) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(3) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(3) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(4) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(4) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(5) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(5) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(6) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(6) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(7) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(7) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(8) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(8) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(9) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(9) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(10) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(10) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(11) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(11) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(12) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(12) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(13) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(13) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(14) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(14) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(15) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(15) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(16) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(16) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(17) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(17) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(18) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(18) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(19) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(19) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(20) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(20) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(21) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(21) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(22) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(22) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(23) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(23) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(24) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(24) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(25) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(25) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(26) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(26) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(27) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(27) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(28) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(28) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(29) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(29) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(30) 
%jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(30) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(31) %jumpi(mload_packing_return) %shl_const(8) + DUP4 %add_const(31) DUP4 DUP4 MLOAD_GENERAL ADD +mload_packing_return: + %stack (packed_value, addr: 3, len, retdest) -> (retdest, packed_value) + JUMP // Pre stack: context, segment, offset, value, len, retdest // Post stack: offset' diff --git a/evm/src/cpu/kernel/asm/mpt/hash.asm b/evm/src/cpu/kernel/asm/mpt/hash.asm index 840fb429..abd436fe 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash.asm @@ -64,7 +64,13 @@ maybe_hash_node: JUMP pack_small_rlp: // stack: result_ptr, result_len, retdest - PANIC // TODO: Return packed RLP + %stack (result_ptr, result_len) + -> (0, @SEGMENT_RLP_RAW, result_ptr, result_len, + after_packed_small_rlp, result_len) + %jump(mload_packing) +after_packed_small_rlp: + %stack (result, result_len, retdest) -> (retdest, result, result_len) + JUMP // RLP encode the given trie node, and return an (pointer, length) pair // indicating where the data lives within @SEGMENT_RLP_RAW. @@ -107,36 +113,6 @@ global encode_node_hash: %stack (hash, encode_value, retdest) -> (retdest, hash, 32) JUMP -// Part of the encode_node_branch function. Encodes the i'th child. -// Stores the result in SEGMENT_KERNEL_GENERAL[i], and its length in -// SEGMENT_KERNEL_GENERAL_2[i]. -%macro encode_child(i) - // stack: node_payload_ptr, encode_value, retdest - PUSH %%after_encode - DUP3 DUP3 - // stack: node_payload_ptr, encode_value, %%after_encode, node_payload_ptr, encode_value, retdest - %add_const($i) %mload_trie_data - // stack: child_i_ptr, encode_value, %%after_encode, node_payload_ptr, encode_value, retdest - %jump(encode_or_hash_node) -%%after_encode: - // stack: result, result_len, node_payload_ptr, encode_value, retdest - %mstore_kernel_general($i) - %mstore_kernel_general_2($i) - // stack: node_payload_ptr, encode_value, retdest -%endmacro - -// Part of the encode_node_branch function. Appends the i'th child's RLP. -%macro append_child(i) - // stack: rlp_pos, node_payload_ptr, encode_value, retdest - %mload_kernel_general($i) // load result_i - %mload_kernel_general_2($i) // load result_i_len - %stack (result, result_len, rlp_pos, node_payload_ptr, encode_value, retdest) - -> (rlp_pos, result, result_len, %%after_unpacking, node_payload_ptr, encode_value, retdest) - %jump(mstore_unpacking_rlp) -%%after_unpacking: - // stack: rlp_pos', node_payload_ptr, encode_value, retdest -%endmacro - encode_node_branch: // stack: node_type, node_payload_ptr, encode_value, retdest POP @@ -186,11 +162,83 @@ encode_node_branch_prepend_prefix: %stack (start_pos, rlp_len, retdest) -> (retdest, start_pos, rlp_len) JUMP +// Part of the encode_node_branch function. Encodes the i'th child. +// Stores the result in SEGMENT_KERNEL_GENERAL[i], and its length in +// SEGMENT_KERNEL_GENERAL_2[i]. +%macro encode_child(i) + // stack: node_payload_ptr, encode_value, retdest + PUSH %%after_encode + DUP3 DUP3 + // stack: node_payload_ptr, encode_value, %%after_encode, node_payload_ptr, encode_value, retdest + %add_const($i) %mload_trie_data + // stack: child_i_ptr, encode_value, %%after_encode, node_payload_ptr, encode_value, retdest + %jump(encode_or_hash_node) +%%after_encode: + // stack: result, result_len, node_payload_ptr, encode_value, retdest + %mstore_kernel_general($i) + %mstore_kernel_general_2($i) + // stack: node_payload_ptr, encode_value, retdest +%endmacro + +// Part of the encode_node_branch function. 
Appends the i'th child's RLP. +%macro append_child(i) + // stack: rlp_pos, node_payload_ptr, encode_value, retdest + %mload_kernel_general($i) // load result + %mload_kernel_general_2($i) // load result_len + // stack: result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest + // If result_len != 32, result is raw RLP, with an appropriate RLP prefix already. + DUP1 %sub_const(32) %jumpi(%%unpack) + // Otherwise, result is a hash, and we need to add the prefix 0x80 + 32 = 160. + // stack: result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest + PUSH 160 + DUP4 // rlp_pos + %mstore_rlp + SWAP2 %increment SWAP2 // rlp_pos += 1 +%%unpack: + %stack (result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest) + -> (rlp_pos, result, result_len, %%after_unpacking, node_payload_ptr, encode_value, retdest) + %jump(mstore_unpacking_rlp) +%%after_unpacking: + // stack: rlp_pos', node_payload_ptr, encode_value, retdest +%endmacro + encode_node_extension: // stack: node_type, node_payload_ptr, encode_value, retdest - POP - // stack: node_payload_ptr, encode_value, retdest - PANIC // TODO + %stack (node_type, node_payload_ptr, encode_value) + -> (node_payload_ptr, encode_value, encode_node_extension_after_encode_child, node_payload_ptr) + %add_const(2) %mload_trie_data + // stack: child_ptr, encode_value, encode_node_extension_after_encode_child, node_payload_ptr, retdest + %jump(encode_or_hash_node) +encode_node_extension_after_encode_child: + // stack: result, result_len, node_payload_ptr, retdest + PUSH encode_node_extension_after_hex_prefix // retdest + PUSH 0 // terminated + // stack: terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest + DUP5 %add_const(1) %mload_trie_data // Load the packed_nibbles field, which is at index 1. + // stack: packed_nibbles, terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest + DUP6 %mload_trie_data // Load the num_nibbles field, which is at index 0. + // stack: num_nibbles, packed_nibbles, terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest + PUSH 9 // We start at 9 to leave room to prepend the largest possible RLP list header. + // stack: rlp_start, num_nibbles, packed_nibbles, terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest + %jump(hex_prefix_rlp) +encode_node_extension_after_hex_prefix: + // stack: rlp_pos, result, result_len, node_payload_ptr, retdest + // If result_len != 32, result is raw RLP, with an appropriate RLP prefix already. + DUP3 %sub_const(32) %jumpi(encode_node_extension_unpack) + // Otherwise, result is a hash, and we need to add the prefix 0x80 + 32 = 160. 
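+    // Write that prefix byte into the RLP buffer at rlp_pos, then advance rlp_pos past it
+    // before unpacking the 32 hash bytes.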
+ PUSH 160 + DUP2 // rlp_pos + %mstore_rlp + %increment // rlp_pos += 1 +encode_node_extension_unpack: + %stack (rlp_pos, result, result_len, node_payload_ptr) + -> (rlp_pos, result, result_len, encode_node_extension_after_unpacking) + %jump(mstore_unpacking_rlp) +encode_node_extension_after_unpacking: + // stack: rlp_end_pos, retdest + %prepend_rlp_list_prefix + %stack (rlp_start_pos, rlp_len, retdest) -> (retdest, rlp_start_pos, rlp_len) + JUMP encode_node_leaf: // stack: node_type, node_payload_ptr, encode_value, retdest diff --git a/evm/src/cpu/kernel/asm/mpt/load.asm b/evm/src/cpu/kernel/asm/mpt/load.asm index 2e218ab4..f072f202 100644 --- a/evm/src/cpu/kernel/asm/mpt/load.asm +++ b/evm/src/cpu/kernel/asm/mpt/load.asm @@ -70,6 +70,9 @@ load_mpt_branch: // stack: node_type, retdest POP // stack: retdest + // Save the offset of our 16 child pointers so we can write them later. + // Then advance out current trie pointer beyond them, so we can load the + // value and have it placed after our child pointers. %get_trie_data_size // stack: ptr_children, retdest DUP1 %add_const(16) @@ -78,24 +81,20 @@ load_mpt_branch: // stack: ptr_children, retdest %load_leaf_value - // Save the current trie_data_size (which now points to the end of the leaf) - // for later, then have it point to the start of our 16 child pointers. - %get_trie_data_size - // stack: ptr_end_of_leaf, ptr_children, retdest - SWAP1 - %set_trie_data_size - // stack: ptr_end_of_leaf, retdest - // Load the 16 children. %rep 16 %load_mpt_and_return_root_ptr - // stack: child_ptr, ptr_end_of_leaf, retdest - %append_to_trie_data - // stack: ptr_end_of_leaf, retdest + // stack: child_ptr, ptr_next_child, retdest + DUP2 + // stack: ptr_next_child, child_ptr, ptr_next_child, retdest + %mstore_trie_data + // stack: ptr_next_child, retdest + %increment + // stack: ptr_next_child, retdest %endrep - %set_trie_data_size - // stack: retdest + // stack: ptr_next_child, retdest + POP JUMP load_mpt_extension: diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs index 4a79a5b6..6b31a523 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -4,11 +4,26 @@ use ethereum_types::{BigEndianHash, H256, U256}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; +use crate::cpu::kernel::tests::mpt::extension_to_leaf; use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; use crate::generation::TrieInputs; +// TODO: Test with short leaf. Might need to be a storage trie. + #[test] -fn mpt_hash() -> Result<()> { +fn mpt_hash_empty() -> Result<()> { + let trie_inputs = TrieInputs { + state_trie: Default::default(), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + test_state_trie(trie_inputs) +} + +#[test] +fn mpt_hash_leaf() -> Result<()> { let account = AccountRlp { nonce: U256::from(1111), balance: U256::from(2222), @@ -17,8 +32,6 @@ fn mpt_hash() -> Result<()> { }; let account_rlp = rlp::encode(&account); - // TODO: Try this more "advanced" trie. 
- // let state_trie = state_trie_ext_to_account_leaf(account_rlp.to_vec()); let state_trie = PartialTrie::Leaf { nibbles: Nibbles { count: 3, @@ -26,7 +39,6 @@ fn mpt_hash() -> Result<()> { }, value: account_rlp.to_vec(), }; - let state_trie_hash = state_trie.calc_hash(); let trie_inputs = TrieInputs { state_trie, @@ -35,10 +47,70 @@ fn mpt_hash() -> Result<()> { storage_tries: vec![], }; + test_state_trie(trie_inputs) +} + +#[test] +fn mpt_hash_extension_to_leaf() -> Result<()> { + let account = AccountRlp { + nonce: U256::from(1111), + balance: U256::from(2222), + storage_root: H256::from_uint(&U256::from(3333)), + code_hash: H256::from_uint(&U256::from(4444)), + }; + let account_rlp = rlp::encode(&account); + + let state_trie = extension_to_leaf(account_rlp.to_vec()); + + let trie_inputs = TrieInputs { + state_trie, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + test_state_trie(trie_inputs) +} + +#[test] +fn mpt_hash_branch_to_leaf() -> Result<()> { + let account = AccountRlp { + nonce: U256::from(1111), + balance: U256::from(2222), + storage_root: H256::from_uint(&U256::from(3333)), + code_hash: H256::from_uint(&U256::from(4444)), + }; + let account_rlp = rlp::encode(&account); + + let leaf = PartialTrie::Leaf { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + value: account_rlp.to_vec(), + }; + let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); + children[0] = Box::new(leaf); + let state_trie = PartialTrie::Branch { + children, + value: vec![], + }; + + let trie_inputs = TrieInputs { + state_trie, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + test_state_trie(trie_inputs) +} + +fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> { let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; - let initial_stack = vec![0xdeadbeefu32.into()]; + let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); interpreter.run()?; @@ -49,9 +121,15 @@ fn mpt_hash() -> Result<()> { interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; - assert_eq!(interpreter.stack().len(), 1); + assert_eq!( + interpreter.stack().len(), + 1, + "Expected 1 item on stack, found {:?}", + interpreter.stack() + ); let hash = H256::from_uint(&interpreter.stack()[0]); - assert_eq!(hash, state_trie_hash); + let expected_state_trie_hash = trie_inputs.state_trie.calc_hash(); + assert_eq!(hash, expected_state_trie_hash); Ok(()) } diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index 19640387..3af39e30 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -5,7 +5,7 @@ use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::state_trie_ext_to_account_leaf; +use crate::cpu::kernel::tests::mpt::extension_to_leaf; use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; use crate::generation::TrieInputs; @@ -20,7 +20,7 @@ fn load_all_mpts() -> Result<()> { let account_rlp = rlp::encode(&account); let trie_inputs = TrieInputs 
{ - state_trie: state_trie_ext_to_account_leaf(account_rlp.to_vec()), + state_trie: extension_to_leaf(account_rlp.to_vec()), transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], diff --git a/evm/src/cpu/kernel/tests/mpt/mod.rs b/evm/src/cpu/kernel/tests/mpt/mod.rs index 8308962a..55a56653 100644 --- a/evm/src/cpu/kernel/tests/mpt/mod.rs +++ b/evm/src/cpu/kernel/tests/mpt/mod.rs @@ -6,7 +6,7 @@ mod load; mod read; /// A `PartialTrie` where an extension node leads to a leaf node containing an account. -pub(crate) fn state_trie_ext_to_account_leaf(value: Vec) -> PartialTrie { +pub(crate) fn extension_to_leaf(value: Vec) -> PartialTrie { PartialTrie::Extension { nibbles: Nibbles { count: 3, diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs index c539eef8..c45a6b60 100644 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -4,7 +4,7 @@ use ethereum_types::{BigEndianHash, H256, U256}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::state_trie_ext_to_account_leaf; +use crate::cpu::kernel::tests::mpt::extension_to_leaf; use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; use crate::generation::TrieInputs; @@ -19,7 +19,7 @@ fn mpt_read() -> Result<()> { let account_rlp = rlp::encode(&account); let trie_inputs = TrieInputs { - state_trie: state_trie_ext_to_account_leaf(account_rlp.to_vec()), + state_trie: extension_to_leaf(account_rlp.to_vec()), transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], diff --git a/evm/src/cpu/kernel/tests/packing.rs b/evm/src/cpu/kernel/tests/packing.rs index dcfdd69b..71f66e6d 100644 --- a/evm/src/cpu/kernel/tests/packing.rs +++ b/evm/src/cpu/kernel/tests/packing.rs @@ -1,9 +1,70 @@ use anyhow::Result; +use ethereum_types::U256; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::memory::segments::Segment; +#[test] +fn test_mload_packing_1_byte() -> Result<()> { + let mstore_unpacking = KERNEL.global_labels["mload_packing"]; + + let retdest = 0xDEADBEEFu32.into(); + let len = 1.into(); + let offset = 2.into(); + let segment = (Segment::RlpRaw as u32).into(); + let context = 0.into(); + let initial_stack = vec![retdest, len, offset, segment, context]; + + let mut interpreter = Interpreter::new_with_kernel(mstore_unpacking, initial_stack); + interpreter.set_rlp_memory(vec![0, 0, 0xAB]); + + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![0xAB.into()]); + + Ok(()) +} + +#[test] +fn test_mload_packing_3_bytes() -> Result<()> { + let mstore_unpacking = KERNEL.global_labels["mload_packing"]; + + let retdest = 0xDEADBEEFu32.into(); + let len = 3.into(); + let offset = 2.into(); + let segment = (Segment::RlpRaw as u32).into(); + let context = 0.into(); + let initial_stack = vec![retdest, len, offset, segment, context]; + + let mut interpreter = Interpreter::new_with_kernel(mstore_unpacking, initial_stack); + interpreter.set_rlp_memory(vec![0, 0, 0xAB, 0xCD, 0xEF]); + + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![0xABCDEF.into()]); + + Ok(()) +} + +#[test] +fn test_mload_packing_32_bytes() -> Result<()> { + let mstore_unpacking = KERNEL.global_labels["mload_packing"]; + + let retdest = 0xDEADBEEFu32.into(); + let len = 32.into(); + let offset = 0.into(); + let segment = 
(Segment::RlpRaw as u32).into(); + let context = 0.into(); + let initial_stack = vec![retdest, len, offset, segment, context]; + + let mut interpreter = Interpreter::new_with_kernel(mstore_unpacking, initial_stack); + interpreter.set_rlp_memory(vec![0xFF; 32]); + + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![U256::MAX]); + + Ok(()) +} + #[test] fn test_mstore_unpacking() -> Result<()> { let mstore_unpacking = KERNEL.global_labels["mstore_unpacking"]; diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 83f2083d..a1fd3ce7 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -191,6 +191,15 @@ impl CrossTableLookup { default, } } + + pub(crate) fn num_ctl_zs(ctls: &[Self], table: Table, num_challenges: usize) -> usize { + let mut num_ctls = 0; + for ctl in ctls { + let all_tables = std::iter::once(&ctl.looked_table).chain(&ctl.looking_tables); + num_ctls += all_tables.filter(|twc| twc.table == table).count(); + } + num_ctls * num_challenges + } } /// Cross-table lookup data for one table. @@ -450,24 +459,24 @@ pub struct CtlCheckVarsTarget<'a, F: Field, const D: usize> { } impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { - pub(crate) fn from_proofs( - proofs: &[StarkProofTarget; NUM_TABLES], + pub(crate) fn from_proof( + table: Table, + proof: &StarkProofTarget, cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, - num_permutation_zs: &[usize; NUM_TABLES], - ) -> [Vec; NUM_TABLES] { - let mut ctl_zs = proofs - .iter() - .zip(num_permutation_zs) - .map(|(p, &num_perms)| { - let openings = &p.openings; - let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_perms); - let ctl_zs_next = openings.permutation_ctl_zs_next.iter().skip(num_perms); - ctl_zs.zip(ctl_zs_next) - }) - .collect::>(); + num_permutation_zs: usize, + ) -> Vec { + let mut ctl_zs = { + let openings = &proof.openings; + let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_permutation_zs); + let ctl_zs_next = openings + .permutation_ctl_zs_next + .iter() + .skip(num_permutation_zs); + ctl_zs.zip(ctl_zs_next) + }; - let mut ctl_vars_per_table = [0; NUM_TABLES].map(|_| vec![]); + let mut ctl_vars = vec![]; for CrossTableLookup { looking_tables, looked_table, @@ -475,28 +484,33 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { } in cross_table_lookups { for &challenges in &ctl_challenges.challenges { - for table in looking_tables { - let (looking_z, looking_z_next) = ctl_zs[table.table as usize].next().unwrap(); - ctl_vars_per_table[table.table as usize].push(Self { - local_z: *looking_z, - next_z: *looking_z_next, - challenges, - columns: &table.columns, - filter_column: &table.filter_column, - }); + for looking_table in looking_tables { + if looking_table.table == table { + let (looking_z, looking_z_next) = ctl_zs.next().unwrap(); + ctl_vars.push(Self { + local_z: *looking_z, + next_z: *looking_z_next, + challenges, + columns: &looking_table.columns, + filter_column: &looking_table.filter_column, + }); + } } - let (looked_z, looked_z_next) = ctl_zs[looked_table.table as usize].next().unwrap(); - ctl_vars_per_table[looked_table.table as usize].push(Self { - local_z: *looked_z, - next_z: *looked_z_next, - challenges, - columns: &looked_table.columns, - filter_column: &looked_table.filter_column, - }); + if looked_table.table == table { + let (looked_z, looked_z_next) = ctl_zs.next().unwrap(); + ctl_vars.push(Self { + local_z: *looked_z, + next_z: *looked_z_next, + challenges, + 
columns: &looked_table.columns, + filter_column: &looked_table.filter_column, + }); + } } } - ctl_vars_per_table + assert!(ctl_zs.next().is_none()); + ctl_vars } } @@ -568,18 +582,12 @@ pub(crate) fn verify_cross_table_lookups< const D: usize, >( cross_table_lookups: Vec>, - proofs: &[StarkProof; NUM_TABLES], + ctl_zs_lasts: [Vec; NUM_TABLES], + degrees_bits: [usize; NUM_TABLES], challenges: GrandProductChallengeSet, config: &StarkConfig, ) -> Result<()> { - let degrees_bits = proofs - .iter() - .map(|p| p.recover_degree_bits(config)) - .collect::>(); - let mut ctl_zs_openings = proofs - .iter() - .map(|p| p.openings.ctl_zs_last.iter()) - .collect::>(); + let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); for ( i, CrossTableLookup { @@ -626,18 +634,12 @@ pub(crate) fn verify_cross_table_lookups_circuit< >( builder: &mut CircuitBuilder, cross_table_lookups: Vec>, - proofs: &[StarkProofTarget; NUM_TABLES], + ctl_zs_lasts: [Vec; NUM_TABLES], + degrees_bits: [usize; NUM_TABLES], challenges: GrandProductChallengeSet, inner_config: &StarkConfig, ) { - let degrees_bits = proofs - .iter() - .map(|p| p.recover_degree_bits(inner_config)) - .collect::>(); - let mut ctl_zs_openings = proofs - .iter() - .map(|p| p.openings.ctl_zs_last.iter()) - .collect::>(); + let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); for ( i, CrossTableLookup { diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index be932c5c..75f434d7 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -116,7 +116,7 @@ pub(crate) fn generate_traces, const D: usize>( let keccak_trace = all_stark.keccak_stark.generate_trace(keccak_inputs, timing); let keccak_memory_trace = all_stark.keccak_memory_stark.generate_trace( keccak_memory_inputs, - 1 << config.fri_config.cap_height, + config.fri_config.num_cap_elements(), timing, ); let logic_trace = all_stark.logic_stark.generate_trace(logic_ops, timing); diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index d02e2d59..f6bc630d 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -69,12 +69,17 @@ pub(crate) fn mpt_prover_inputs( PartialTrie::Empty => {} PartialTrie::Hash(h) => prover_inputs.push(U256::from_big_endian(h.as_bytes())), PartialTrie::Branch { children, value } => { + if value.is_empty() { + // There's no value, so length=0. 
+ prover_inputs.push(U256::zero()); + } else { + let leaf = parse_leaf(value); + prover_inputs.push(leaf.len().into()); + prover_inputs.extend(leaf); + } for child in children { mpt_prover_inputs(child, prover_inputs, parse_leaf); } - let leaf = parse_leaf(value); - prover_inputs.push(leaf.len().into()); - prover_inputs.extend(leaf); } PartialTrie::Extension { nibbles, child } => { prover_inputs.push(nibbles.count.into()); diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 6545a1af..ede7c466 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -5,11 +5,11 @@ use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; -use crate::all_stark::AllStark; +use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::permutation::{ - get_grand_product_challenge_set, get_grand_product_challenge_set_target, - get_n_grand_product_challenge_sets, get_n_grand_product_challenge_sets_target, + get_grand_product_challenge_set, get_n_grand_product_challenge_sets, + get_n_grand_product_challenge_sets_target, }; use crate::proof::*; @@ -36,6 +36,7 @@ impl, C: GenericConfig, const D: usize> A AllProofChallenges { stark_challenges: std::array::from_fn(|i| { + challenger.compact(); self.stark_proofs[i].get_challenges( &mut challenger, num_permutation_zs[i] > 0, @@ -46,40 +47,40 @@ impl, C: GenericConfig, const D: usize> A ctl_challenges, } } -} -impl AllProofTarget { - pub(crate) fn get_challenges, C: GenericConfig>( + #[allow(unused)] // TODO: should be used soon + pub(crate) fn get_challenger_states( &self, - builder: &mut CircuitBuilder, all_stark: &AllStark, config: &StarkConfig, - ) -> AllProofChallengesTarget - where - C::Hasher: AlgebraicHasher, - { - let mut challenger = RecursiveChallenger::::new(builder); + ) -> AllChallengerState { + let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { challenger.observe_cap(&proof.trace_cap); } + // TODO: Observe public values. + let ctl_challenges = - get_grand_product_challenge_set_target(builder, &mut challenger, config.num_challenges); + get_grand_product_challenge_set(&mut challenger, config.num_challenges); let num_permutation_zs = all_stark.nums_permutation_zs(config); let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); - AllProofChallengesTarget { - stark_challenges: std::array::from_fn(|i| { - self.stark_proofs[i].get_challenges::( - builder, - &mut challenger, - num_permutation_zs[i] > 0, - num_permutation_batch_sizes[i], - config, - ) - }), + let mut challenger_states = vec![challenger.compact()]; + for i in 0..NUM_TABLES { + self.stark_proofs[i].get_challenges( + &mut challenger, + num_permutation_zs[i] > 0, + num_permutation_batch_sizes[i], + config, + ); + challenger_states.push(challenger.compact()); + } + + AllChallengerState { + states: challenger_states.try_into().unwrap(), ctl_challenges, } } diff --git a/evm/src/permutation.rs b/evm/src/permutation.rs index 0bb8ab1d..b081c309 100644 --- a/evm/src/permutation.rs +++ b/evm/src/permutation.rs @@ -1,5 +1,7 @@ //! Permutation arguments. +use std::fmt::Debug; + use itertools::Itertools; use maybe_rayon::*; use plonky2::field::batch_util::batch_multiply_inplace; @@ -42,14 +44,14 @@ impl PermutationPair { } /// A single instance of a permutation check protocol. 
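+/// (One `PermutationPair` of columns combined with one `GrandProductChallenge`; before
+/// batching, each pair gives rise to `num_challenges` such instances.)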
-pub(crate) struct PermutationInstance<'a, T: Copy> { +pub(crate) struct PermutationInstance<'a, T: Copy + Eq + PartialEq + Debug> { pub(crate) pair: &'a PermutationPair, pub(crate) challenge: GrandProductChallenge, } /// Randomness for a single instance of a permutation check protocol. -#[derive(Copy, Clone)] -pub(crate) struct GrandProductChallenge { +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub(crate) struct GrandProductChallenge { /// Randomness used to combine multiple columns into one. pub(crate) beta: T, /// Random offset that's added to the beta-reduced column values. @@ -92,8 +94,8 @@ impl GrandProductChallenge { } /// Like `PermutationChallenge`, but with `num_challenges` copies to boost soundness. -#[derive(Clone)] -pub(crate) struct GrandProductChallengeSet { +#[derive(Clone, Eq, PartialEq, Debug)] +pub(crate) struct GrandProductChallengeSet { pub(crate) challenges: Vec>, } @@ -261,7 +263,7 @@ pub(crate) fn get_n_grand_product_challenge_sets_target< /// Before batching, each permutation pair leads to `num_challenges` permutation arguments, so we /// start with the cartesian product of `permutation_pairs` and `0..num_challenges`. Then we /// chunk these arguments based on our batch size. -pub(crate) fn get_permutation_batches<'a, T: Copy>( +pub(crate) fn get_permutation_batches<'a, T: Copy + Eq + PartialEq + Debug>( permutation_pairs: &'a [PermutationPair], permutation_challenge_sets: &[GrandProductChallengeSet], num_challenges: usize, diff --git a/evm/src/proof.rs b/evm/src/proof.rs index de00abfc..4cd03a65 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -8,6 +8,7 @@ use plonky2::fri::structure::{ FriOpeningBatch, FriOpeningBatchTarget, FriOpenings, FriOpeningsTarget, }; use plonky2::hash::hash_types::{MerkleCapTarget, RichField}; +use plonky2::hash::hashing::SPONGE_WIDTH; use plonky2::hash::merkle_tree::MerkleCap; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; @@ -28,10 +29,6 @@ impl, C: GenericConfig, const D: usize> A pub fn degree_bits(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { std::array::from_fn(|i| self.stark_proofs[i].recover_degree_bits(config)) } - - pub fn nums_ctl_zs(&self) -> [usize; NUM_TABLES] { - std::array::from_fn(|i| self.stark_proofs[i].openings.ctl_zs_last.len()) - } } pub(crate) struct AllProofChallenges, const D: usize> { @@ -39,6 +36,14 @@ pub(crate) struct AllProofChallenges, const D: usiz pub ctl_challenges: GrandProductChallengeSet, } +#[allow(unused)] // TODO: should be used soon +pub(crate) struct AllChallengerState, const D: usize> { + /// Sponge state of the challenger before starting each proof, + /// along with the final state after all proofs are done. This final state isn't strictly needed. + pub states: [[F; SPONGE_WIDTH]; NUM_TABLES + 1], + pub ctl_challenges: GrandProductChallengeSet, +} + pub struct AllProofTarget { pub stark_proofs: [StarkProofTarget; NUM_TABLES], pub public_values: PublicValuesTarget, @@ -94,11 +99,6 @@ pub struct BlockMetadataTarget { pub block_base_fee: Target, } -pub(crate) struct AllProofChallengesTarget { - pub stark_challenges: [StarkProofChallengesTarget; NUM_TABLES], - pub ctl_challenges: GrandProductChallengeSet, -} - #[derive(Debug, Clone)] pub struct StarkProof, C: GenericConfig, const D: usize> { /// Merkle cap of LDEs of trace values. 
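As an aside on the `GrandProductChallenge` fields documented in permutation.rs above (`beta` combines the columns, `gamma` is a random offset), here is a minimal standalone sketch of that reduction, assuming plonky2's `Field` trait; the crate's own `combine` method generalizes this over base and extension field elements:

    use plonky2::field::types::Field;

    // Collapse one row of column values into a single field element.
    fn combine_row<F: Field>(beta: F, gamma: F, row: &[F]) -> F {
        // Horner evaluation of row[0] + row[1]*beta + row[2]*beta^2 + ..., plus the offset gamma.
        row.iter().rev().fold(F::ZERO, |acc, &t| acc * beta + t) + gamma
    }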
@@ -123,6 +123,10 @@ impl, C: GenericConfig, const D: usize> S let lde_bits = config.fri_config.cap_height + initial_merkle_proof.siblings.len(); lde_bits - config.fri_config.rate_bits } + + pub fn num_ctl_zs(&self) -> usize { + self.openings.ctl_zs_last.len() + } } pub struct StarkProofTarget { diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 7fe57631..20e8c628 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -201,6 +201,8 @@ where "FRI total reduction arity is too large.", ); + challenger.compact(); + // Permutation arguments. let permutation_challenges = stark.uses_permutation_args().then(|| { get_n_grand_product_challenge_sets( diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 35041a48..a16063c4 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -1,29 +1,44 @@ +use std::fmt::Debug; + +use anyhow::{ensure, Result}; use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::fri::witness_util::set_fri_proof_target; -use plonky2::hash::hash_types::RichField; +use plonky2::hash::hash_types::{HashOut, RichField}; +use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::iop::witness::Witness; use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData, VerifierCircuitTarget}; +use plonky2::plonk::config::Hasher; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; +use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2::util::reducing::ReducingFactorTarget; use plonky2::with_context; +use crate::all_stark::NUM_TABLES; use crate::config::StarkConfig; use crate::constraint_consumer::RecursiveConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; -use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CtlCheckVarsTarget}; +use crate::cross_table_lookup::{ + verify_cross_table_lookups, verify_cross_table_lookups_circuit, CrossTableLookup, + CtlCheckVarsTarget, +}; use crate::keccak::keccak_stark::KeccakStark; use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; -use crate::permutation::PermutationCheckDataTarget; +use crate::permutation::{ + get_grand_product_challenge_set, get_grand_product_challenge_set_target, GrandProductChallenge, + GrandProductChallengeSet, PermutationCheckDataTarget, +}; use crate::proof::{ - AllProof, AllProofChallengesTarget, AllProofTarget, BlockMetadata, BlockMetadataTarget, - PublicValues, PublicValuesTarget, StarkOpeningSetTarget, StarkProof, - StarkProofChallengesTarget, StarkProofTarget, TrieRoots, TrieRootsTarget, + AllProof, AllProofTarget, BlockMetadata, BlockMetadataTarget, PublicValues, PublicValuesTarget, + StarkOpeningSetTarget, StarkProof, StarkProofChallengesTarget, StarkProofTarget, TrieRoots, + TrieRootsTarget, }; use crate::stark::Stark; use crate::util::h160_limbs; @@ -34,118 +49,342 @@ use crate::{ util::h256_limbs, }; -pub fn verify_proof_circuit< +/// Table-wise recursive proofs of an `AllProof`. 
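+/// Each table's STARK proof is wrapped in its own recursive plonky2 proof; the trace caps,
+/// CTL challenges, challenger states and `ctl_zs_last` openings are exposed as public inputs
+/// so that cross-table consistency can still be checked when these proofs are verified.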
+pub struct RecursiveAllProof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + pub recursive_proofs: [ProofWithPublicInputs; NUM_TABLES], +} + +pub struct RecursiveAllProofTargetWithData { + pub recursive_proofs: [ProofWithPublicInputsTarget; NUM_TABLES], + pub verifier_data: [VerifierCircuitTarget; NUM_TABLES], +} + +struct PublicInputs { + trace_cap: Vec>, + ctl_zs_last: Vec, + ctl_challenges: GrandProductChallengeSet, + challenger_state_before: [T; SPONGE_WIDTH], + challenger_state_after: [T; SPONGE_WIDTH], +} + +/// Similar to the unstable `Iterator::next_chunk`. Could be replaced with that when it's stable. +fn next_chunk(iter: &mut impl Iterator) -> [T; N] { + (0..N) + .flat_map(|_| iter.next()) + .collect_vec() + .try_into() + .expect("Not enough elements") +} + +impl PublicInputs { + fn from_vec(v: &[T], config: &StarkConfig) -> Self { + let mut iter = v.iter().copied(); + let trace_cap = (0..1 << config.fri_config.cap_height) + .map(|_| next_chunk::<_, 4>(&mut iter).to_vec()) + .collect(); + let ctl_challenges = GrandProductChallengeSet { + challenges: (0..config.num_challenges) + .map(|_| GrandProductChallenge { + beta: iter.next().unwrap(), + gamma: iter.next().unwrap(), + }) + .collect(), + }; + let challenger_state_before = next_chunk(&mut iter); + let challenger_state_after = next_chunk(&mut iter); + let ctl_zs_last = iter.collect(); + + Self { + trace_cap, + ctl_zs_last, + ctl_challenges, + challenger_state_before, + challenger_state_after, + } + } +} + +impl, C: GenericConfig, const D: usize> + RecursiveAllProof +{ + /// Verify every recursive proof. + pub fn verify( + self, + verifier_data: &[VerifierCircuitData; NUM_TABLES], + cross_table_lookups: Vec>, + inner_config: &StarkConfig, + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { + let pis: [_; NUM_TABLES] = std::array::from_fn(|i| { + PublicInputs::from_vec(&self.recursive_proofs[i].public_inputs, inner_config) + }); + + let mut challenger = Challenger::::new(); + for pi in &pis { + for h in &pi.trace_cap { + challenger.observe_elements(h); + } + } + let ctl_challenges = + get_grand_product_challenge_set(&mut challenger, inner_config.num_challenges); + // Check that the correct CTL challenges are used in every proof. + for pi in &pis { + ensure!(ctl_challenges == pi.ctl_challenges); + } + + let state = challenger.compact(); + ensure!(state == pis[0].challenger_state_before); + // Check that the challenger state is consistent between proofs. + for i in 1..NUM_TABLES { + ensure!(pis[i].challenger_state_before == pis[i - 1].challenger_state_after); + } + + // Verify the CTL checks. + let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits); + verify_cross_table_lookups::( + cross_table_lookups, + pis.map(|p| p.ctl_zs_last), + degrees_bits, + ctl_challenges, + inner_config, + )?; + + // Verify the proofs. + for (proof, verifier_data) in self.recursive_proofs.into_iter().zip(verifier_data) { + verifier_data.verify(proof)?; + } + Ok(()) + } + + /// Recursively verify every recursive proof. 
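+    /// This mirrors `verify` above: it connects every proof's CTL challenges to freshly
+    /// derived ones, chains each proof's challenger state to the previous proof's, runs the
+    /// cross-table lookup check, and finally verifies each inner proof against its
+    /// (constant) verifier data.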
+ pub fn verify_circuit( + builder: &mut CircuitBuilder, + recursive_all_proof_target: RecursiveAllProofTargetWithData, + verifier_data: &[VerifierCircuitData; NUM_TABLES], + cross_table_lookups: Vec>, + inner_config: &StarkConfig, + ) where + [(); C::Hasher::HASH_SIZE]:, + >::Hasher: AlgebraicHasher, + { + let RecursiveAllProofTargetWithData { + recursive_proofs, + verifier_data: verifier_data_target, + } = recursive_all_proof_target; + let pis: [_; NUM_TABLES] = std::array::from_fn(|i| { + PublicInputs::from_vec(&recursive_proofs[i].public_inputs, inner_config) + }); + + let mut challenger = RecursiveChallenger::::new(builder); + for pi in &pis { + for h in &pi.trace_cap { + challenger.observe_elements(h); + } + } + let ctl_challenges = get_grand_product_challenge_set_target( + builder, + &mut challenger, + inner_config.num_challenges, + ); + // Check that the correct CTL challenges are used in every proof. + for pi in &pis { + for i in 0..inner_config.num_challenges { + builder.connect( + ctl_challenges.challenges[i].beta, + pi.ctl_challenges.challenges[i].beta, + ); + builder.connect( + ctl_challenges.challenges[i].gamma, + pi.ctl_challenges.challenges[i].gamma, + ); + } + } + + let state = challenger.compact(builder); + for k in 0..SPONGE_WIDTH { + builder.connect(state[k], pis[0].challenger_state_before[k]); + } + // Check that the challenger state is consistent between proofs. + for i in 1..NUM_TABLES { + for k in 0..SPONGE_WIDTH { + builder.connect( + pis[i].challenger_state_before[k], + pis[i - 1].challenger_state_after[k], + ); + } + } + + // Verify the CTL checks. + let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits); + verify_cross_table_lookups_circuit::( + builder, + cross_table_lookups, + pis.map(|p| p.ctl_zs_last), + degrees_bits, + ctl_challenges, + inner_config, + ); + for (i, (recursive_proof, verifier_data_target)) in recursive_proofs + .into_iter() + .zip(verifier_data_target) + .enumerate() + { + builder.verify_proof( + recursive_proof, + &verifier_data_target, + &verifier_data[i].common, + ); + } + } +} + +/// Returns the verifier data for the recursive Stark circuit. 
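+/// The circuit's public inputs are, in order: the trace cap, the CTL challenges, the
+/// challenger state before and after this proof, and the `ctl_zs_last` openings, matching
+/// the layout that `PublicInputs::from_vec` expects.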
+fn verifier_data_recursive_stark_proof< + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + const D: usize, +>( + table: Table, + stark: S, + degree_bits: usize, + cross_table_lookups: &[CrossTableLookup], + inner_config: &StarkConfig, + circuit_config: &CircuitConfig, +) -> VerifierCircuitData +where + [(); S::COLUMNS]:, + [(); C::Hasher::HASH_SIZE]:, + C::Hasher: AlgebraicHasher, +{ + let mut builder = CircuitBuilder::::new(circuit_config.clone()); + + let num_permutation_zs = stark.num_permutation_batches(inner_config); + let num_permutation_batch_size = stark.permutation_batch_size(); + let num_ctl_zs = + CrossTableLookup::num_ctl_zs(cross_table_lookups, table, inner_config.num_challenges); + let proof_target = + add_virtual_stark_proof(&mut builder, &stark, inner_config, degree_bits, num_ctl_zs); + builder.register_public_inputs( + &proof_target + .trace_cap + .0 + .iter() + .flat_map(|h| h.elements) + .collect::>(), + ); + + let ctl_challenges_target = GrandProductChallengeSet { + challenges: (0..inner_config.num_challenges) + .map(|_| GrandProductChallenge { + beta: builder.add_virtual_public_input(), + gamma: builder.add_virtual_public_input(), + }) + .collect(), + }; + + let ctl_vars = CtlCheckVarsTarget::from_proof( + table, + &proof_target, + cross_table_lookups, + &ctl_challenges_target, + num_permutation_zs, + ); + + let challenger_state = std::array::from_fn(|_| builder.add_virtual_public_input()); + let mut challenger = RecursiveChallenger::::from_state(challenger_state); + let challenges = proof_target.get_challenges::( + &mut builder, + &mut challenger, + num_permutation_zs > 0, + num_permutation_batch_size, + inner_config, + ); + let challenger_state = challenger.compact(&mut builder); + builder.register_public_inputs(&challenger_state); + + builder.register_public_inputs(&proof_target.openings.ctl_zs_last); + + verify_stark_proof_with_challenges_circuit::( + &mut builder, + &stark, + &proof_target, + &challenges, + &ctl_vars, + inner_config, + ); + + builder.build_verifier::() +} + +/// Returns the recursive Stark circuit verifier data for every Stark in `AllStark`. 
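+/// The returned array is indexed in `Table` order: Cpu, Keccak, KeccakMemory, Logic, Memory.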
+pub fn all_verifier_data_recursive_stark_proof< F: RichField + Extendable, C: GenericConfig, const D: usize, >( - builder: &mut CircuitBuilder, - all_stark: AllStark, - all_proof: AllProofTarget, + all_stark: &AllStark, + degree_bits: [usize; NUM_TABLES], inner_config: &StarkConfig, -) where + circuit_config: &CircuitConfig, +) -> [VerifierCircuitData; NUM_TABLES] +where [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakMemoryStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, + [(); C::Hasher::HASH_SIZE]:, C::Hasher: AlgebraicHasher, { - let AllProofChallengesTarget { - stark_challenges, - ctl_challenges, - } = all_proof.get_challenges::(builder, &all_stark, inner_config); - - let nums_permutation_zs = all_stark.nums_permutation_zs(inner_config); - - let AllStark { - cpu_stark, - keccak_stark, - keccak_memory_stark, - logic_stark, - memory_stark, - cross_table_lookups, - } = all_stark; - - let ctl_vars_per_table = CtlCheckVarsTarget::from_proofs( - &all_proof.stark_proofs, - &cross_table_lookups, - &ctl_challenges, - &nums_permutation_zs, - ); - - with_context!( - builder, - "verify CPU proof", - verify_stark_proof_with_challenges_circuit::( - builder, - cpu_stark, - &all_proof.stark_proofs[Table::Cpu as usize], - &stark_challenges[Table::Cpu as usize], - &ctl_vars_per_table[Table::Cpu as usize], + [ + verifier_data_recursive_stark_proof( + Table::Cpu, + all_stark.cpu_stark, + degree_bits[Table::Cpu as usize], + &all_stark.cross_table_lookups, inner_config, - ) - ); - with_context!( - builder, - "verify Keccak proof", - verify_stark_proof_with_challenges_circuit::( - builder, - keccak_stark, - &all_proof.stark_proofs[Table::Keccak as usize], - &stark_challenges[Table::Keccak as usize], - &ctl_vars_per_table[Table::Keccak as usize], + circuit_config, + ), + verifier_data_recursive_stark_proof( + Table::Keccak, + all_stark.keccak_stark, + degree_bits[Table::Keccak as usize], + &all_stark.cross_table_lookups, inner_config, - ) - ); - with_context!( - builder, - "verify Keccak memory proof", - verify_stark_proof_with_challenges_circuit::( - builder, - keccak_memory_stark, - &all_proof.stark_proofs[Table::KeccakMemory as usize], - &stark_challenges[Table::KeccakMemory as usize], - &ctl_vars_per_table[Table::KeccakMemory as usize], + circuit_config, + ), + verifier_data_recursive_stark_proof( + Table::KeccakMemory, + all_stark.keccak_memory_stark, + degree_bits[Table::KeccakMemory as usize], + &all_stark.cross_table_lookups, inner_config, - ) - ); - with_context!( - builder, - "verify logic proof", - verify_stark_proof_with_challenges_circuit::( - builder, - logic_stark, - &all_proof.stark_proofs[Table::Logic as usize], - &stark_challenges[Table::Logic as usize], - &ctl_vars_per_table[Table::Logic as usize], + circuit_config, + ), + verifier_data_recursive_stark_proof( + Table::Logic, + all_stark.logic_stark, + degree_bits[Table::Logic as usize], + &all_stark.cross_table_lookups, inner_config, - ) - ); - with_context!( - builder, - "verify memory proof", - verify_stark_proof_with_challenges_circuit::( - builder, - memory_stark, - &all_proof.stark_proofs[Table::Memory as usize], - &stark_challenges[Table::Memory as usize], - &ctl_vars_per_table[Table::Memory as usize], + circuit_config, + ), + verifier_data_recursive_stark_proof( + Table::Memory, + all_stark.memory_stark, + degree_bits[Table::Memory as usize], + &all_stark.cross_table_lookups, inner_config, - ) - ); - - with_context!( - builder, - "verify cross-table lookups", - 
verify_cross_table_lookups_circuit::( - builder, - cross_table_lookups, - &all_proof.stark_proofs, - ctl_challenges, - inner_config, - ) - ); + circuit_config, + ), + ] } /// Recursively verifies an inner proof. @@ -156,7 +395,7 @@ fn verify_stark_proof_with_challenges_circuit< const D: usize, >( builder: &mut CircuitBuilder, - stark: S, + stark: &S, proof: &StarkProofTarget, challenges: &StarkProofChallengesTarget, ctl_vars: &[CtlCheckVarsTarget], @@ -212,7 +451,7 @@ fn verify_stark_proof_with_challenges_circuit< "evaluate vanishing polynomial", eval_vanishing_poly_circuit::( builder, - &stark, + stark, inner_config, vars, permutation_data, @@ -286,35 +525,35 @@ pub fn add_virtual_all_proof, const D: usize>( let stark_proofs = [ add_virtual_stark_proof( builder, - all_stark.cpu_stark, + &all_stark.cpu_stark, config, degree_bits[Table::Cpu as usize], nums_ctl_zs[Table::Cpu as usize], ), add_virtual_stark_proof( builder, - all_stark.keccak_stark, + &all_stark.keccak_stark, config, degree_bits[Table::Keccak as usize], nums_ctl_zs[Table::Keccak as usize], ), add_virtual_stark_proof( builder, - all_stark.keccak_memory_stark, + &all_stark.keccak_memory_stark, config, degree_bits[Table::KeccakMemory as usize], nums_ctl_zs[Table::KeccakMemory as usize], ), add_virtual_stark_proof( builder, - all_stark.logic_stark, + &all_stark.logic_stark, config, degree_bits[Table::Logic as usize], nums_ctl_zs[Table::Logic as usize], ), add_virtual_stark_proof( builder, - all_stark.memory_stark, + &all_stark.memory_stark, config, degree_bits[Table::Memory as usize], nums_ctl_zs[Table::Memory as usize], @@ -328,6 +567,33 @@ pub fn add_virtual_all_proof, const D: usize>( } } +/// Returns `RecursiveAllProofTargetWithData` where the proofs targets are virtual and the +/// verifier data targets are constants. 
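+/// Since the verifier data is baked in as circuit constants (via `constant_merkle_cap`),
+/// any circuit built from these targets is tied to those specific inner verifier circuits.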
+pub fn add_virtual_recursive_all_proof, H, C, const D: usize>( + builder: &mut CircuitBuilder, + verifier_data: &[VerifierCircuitData; NUM_TABLES], +) -> RecursiveAllProofTargetWithData +where + H: Hasher>, + C: GenericConfig, +{ + let recursive_proofs = std::array::from_fn(|i| { + let verifier_data = &verifier_data[i]; + builder.add_virtual_proof_with_pis(&verifier_data.common) + }); + let verifier_data = std::array::from_fn(|i| { + let verifier_data = &verifier_data[i]; + VerifierCircuitTarget { + constants_sigmas_cap: builder + .constant_merkle_cap(&verifier_data.verifier_only.constants_sigmas_cap), + } + }); + RecursiveAllProofTargetWithData { + recursive_proofs, + verifier_data, + } +} + pub fn add_virtual_public_values, const D: usize>( builder: &mut CircuitBuilder, ) -> PublicValuesTarget { @@ -377,7 +643,7 @@ pub fn add_virtual_block_metadata, const D: usize>( pub fn add_virtual_stark_proof, S: Stark, const D: usize>( builder: &mut CircuitBuilder, - stark: S, + stark: &S, config: &StarkConfig, degree_bits: usize, num_ctl_zs: usize, @@ -397,14 +663,14 @@ pub fn add_virtual_stark_proof, S: Stark, con trace_cap: builder.add_virtual_cap(cap_height), permutation_ctl_zs_cap: permutation_zs_cap, quotient_polys_cap: builder.add_virtual_cap(cap_height), - openings: add_stark_opening_set::(builder, stark, num_ctl_zs, config), + openings: add_virtual_stark_opening_set::(builder, stark, num_ctl_zs, config), opening_proof: builder.add_virtual_fri_proof(&num_leaves_per_oracle, &fri_params), } } -fn add_stark_opening_set, S: Stark, const D: usize>( +fn add_virtual_stark_opening_set, S: Stark, const D: usize>( builder: &mut CircuitBuilder, - stark: S, + stark: &S, num_ctl_zs: usize, config: &StarkConfig, ) -> StarkOpeningSetTarget { @@ -422,6 +688,22 @@ fn add_stark_opening_set, S: Stark, const D: } } +pub fn set_recursive_all_proof_target, W, const D: usize>( + witness: &mut W, + recursive_all_proof_target: &RecursiveAllProofTargetWithData, + all_proof: &RecursiveAllProof, +) where + F: RichField + Extendable, + C::Hasher: AlgebraicHasher, + W: Witness, +{ + for i in 0..NUM_TABLES { + witness.set_proof_with_pis_target( + &recursive_all_proof_target.recursive_proofs[i], + &all_proof.recursive_proofs[i], + ); + } +} pub fn set_all_proof_target, W, const D: usize>( witness: &mut W, all_proof_target: &AllProofTarget, @@ -556,3 +838,219 @@ pub fn set_block_metadata_target( F::from_canonical_u64(block_metadata.block_base_fee.as_u64()), ); } + +#[cfg(test)] +pub(crate) mod tests { + use anyhow::Result; + use plonky2::field::extension::Extendable; + use plonky2::hash::hash_types::RichField; + use plonky2::hash::hashing::SPONGE_WIDTH; + use plonky2::iop::challenger::RecursiveChallenger; + use plonky2::iop::witness::{PartialWitness, Witness}; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData}; + use plonky2::plonk::config::Hasher; + use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; + use plonky2::plonk::proof::ProofWithPublicInputs; + + use crate::all_stark::{AllStark, Table}; + use crate::config::StarkConfig; + use crate::cpu::cpu_stark::CpuStark; + use crate::cross_table_lookup::{CrossTableLookup, CtlCheckVarsTarget}; + use crate::keccak::keccak_stark::KeccakStark; + use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; + use crate::logic::LogicStark; + use crate::memory::memory_stark::MemoryStark; + use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet}; + use 
crate::proof::{AllChallengerState, AllProof, StarkProof}; + use crate::recursive_verifier::{ + add_virtual_stark_proof, set_stark_proof_target, + verify_stark_proof_with_challenges_circuit, RecursiveAllProof, + }; + use crate::stark::Stark; + + /// Recursively verify a Stark proof. + /// Outputs the recursive proof and the associated verifier data. + fn recursively_verify_stark_proof< + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + const D: usize, + >( + table: Table, + stark: S, + proof: &StarkProof, + cross_table_lookups: &[CrossTableLookup], + ctl_challenges: &GrandProductChallengeSet, + challenger_state_before_vals: [F; SPONGE_WIDTH], + inner_config: &StarkConfig, + circuit_config: &CircuitConfig, + ) -> Result<(ProofWithPublicInputs, VerifierCircuitData)> + where + [(); S::COLUMNS]:, + [(); C::Hasher::HASH_SIZE]:, + C::Hasher: AlgebraicHasher, + { + let mut builder = CircuitBuilder::::new(circuit_config.clone()); + let mut pw = PartialWitness::new(); + + let num_permutation_zs = stark.num_permutation_batches(inner_config); + let num_permutation_batch_size = stark.permutation_batch_size(); + let proof_target = add_virtual_stark_proof( + &mut builder, + &stark, + inner_config, + proof.recover_degree_bits(inner_config), + proof.num_ctl_zs(), + ); + set_stark_proof_target(&mut pw, &proof_target, proof, builder.zero()); + builder.register_public_inputs( + &proof_target + .trace_cap + .0 + .iter() + .flat_map(|h| h.elements) + .collect::>(), + ); + + let ctl_challenges_target = GrandProductChallengeSet { + challenges: (0..inner_config.num_challenges) + .map(|_| GrandProductChallenge { + beta: builder.add_virtual_public_input(), + gamma: builder.add_virtual_public_input(), + }) + .collect(), + }; + for i in 0..inner_config.num_challenges { + pw.set_target( + ctl_challenges_target.challenges[i].beta, + ctl_challenges.challenges[i].beta, + ); + pw.set_target( + ctl_challenges_target.challenges[i].gamma, + ctl_challenges.challenges[i].gamma, + ); + } + + let ctl_vars = CtlCheckVarsTarget::from_proof( + table, + &proof_target, + cross_table_lookups, + &ctl_challenges_target, + num_permutation_zs, + ); + + let challenger_state_before = std::array::from_fn(|_| builder.add_virtual_public_input()); + pw.set_target_arr(challenger_state_before, challenger_state_before_vals); + let mut challenger = + RecursiveChallenger::::from_state(challenger_state_before); + let challenges = proof_target.get_challenges::( + &mut builder, + &mut challenger, + num_permutation_zs > 0, + num_permutation_batch_size, + inner_config, + ); + let challenger_state_after = challenger.compact(&mut builder); + builder.register_public_inputs(&challenger_state_after); + + builder.register_public_inputs(&proof_target.openings.ctl_zs_last); + + verify_stark_proof_with_challenges_circuit::( + &mut builder, + &stark, + &proof_target, + &challenges, + &ctl_vars, + inner_config, + ); + + let data = builder.build::(); + Ok((data.prove(pw)?, data.verifier_data())) + } + + /// Recursively verify every Stark proof in an `AllProof`. 
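+    /// One recursive circuit is built per table; proof `i` resumes the Fiat-Shamir challenger
+    /// from `states[i]`, so the per-table transcripts chain together just as in the monolithic
+    /// verifier.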
+ pub fn recursively_verify_all_proof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + >( + all_stark: &AllStark, + all_proof: &AllProof, + inner_config: &StarkConfig, + circuit_config: &CircuitConfig, + ) -> Result> + where + [(); CpuStark::::COLUMNS]:, + [(); KeccakStark::::COLUMNS]:, + [(); KeccakMemoryStark::::COLUMNS]:, + [(); LogicStark::::COLUMNS]:, + [(); MemoryStark::::COLUMNS]:, + [(); C::Hasher::HASH_SIZE]:, + C::Hasher: AlgebraicHasher, + { + let AllChallengerState { + states, + ctl_challenges, + } = all_proof.get_challenger_states(all_stark, inner_config); + Ok(RecursiveAllProof { + recursive_proofs: [ + recursively_verify_stark_proof( + Table::Cpu, + all_stark.cpu_stark, + &all_proof.stark_proofs[Table::Cpu as usize], + &all_stark.cross_table_lookups, + &ctl_challenges, + states[0], + inner_config, + circuit_config, + )? + .0, + recursively_verify_stark_proof( + Table::Keccak, + all_stark.keccak_stark, + &all_proof.stark_proofs[Table::Keccak as usize], + &all_stark.cross_table_lookups, + &ctl_challenges, + states[1], + inner_config, + circuit_config, + )? + .0, + recursively_verify_stark_proof( + Table::KeccakMemory, + all_stark.keccak_memory_stark, + &all_proof.stark_proofs[Table::KeccakMemory as usize], + &all_stark.cross_table_lookups, + &ctl_challenges, + states[2], + inner_config, + circuit_config, + )? + .0, + recursively_verify_stark_proof( + Table::Logic, + all_stark.logic_stark, + &all_proof.stark_proofs[Table::Logic as usize], + &all_stark.cross_table_lookups, + &ctl_challenges, + states[3], + inner_config, + circuit_config, + )? + .0, + recursively_verify_stark_proof( + Table::Memory, + all_stark.memory_stark, + &all_proof.stark_proofs[Table::Memory as usize], + &all_stark.cross_table_lookups, + &ctl_challenges, + states[4], + inner_config, + circuit_config, + )? + .0, + ], + }) + } +} diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 53ac3c7c..0bfbc3d4 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -95,9 +95,12 @@ where config, )?; - verify_cross_table_lookups( + let degrees_bits = + std::array::from_fn(|i| all_proof.stark_proofs[i].recover_degree_bits(config)); + verify_cross_table_lookups::( cross_table_lookups, - &all_proof.stark_proofs, + all_proof.stark_proofs.map(|p| p.openings.ctl_zs_last), + degrees_bits, ctl_challenges, config, ) diff --git a/field/src/goldilocks_field.rs b/field/src/goldilocks_field.rs index c1bb60b0..e036ab9b 100644 --- a/field/src/goldilocks_field.rs +++ b/field/src/goldilocks_field.rs @@ -130,6 +130,18 @@ impl Field64 for GoldilocksField { Self(n) } + #[inline] + fn from_noncanonical_i64(n: i64) -> Self { + Self::from_canonical_u64(if n < 0 { + // If n < 0, then this is guaranteed to overflow since + // both arguments have their high bit set, so the result + // is in the canonical range. + Self::ORDER.wrapping_add(n as u64) + } else { + n as u64 + }) + } + #[inline] unsafe fn add_canonical_u64(&self, rhs: u64) -> Self { let (res_wrapped, carry) = self.0.overflowing_add(rhs); diff --git a/field/src/types.rs b/field/src/types.rs index 7130b7f5..b112fde2 100644 --- a/field/src/types.rs +++ b/field/src/types.rs @@ -490,6 +490,18 @@ pub trait Field64: Field { // TODO: Move to `Field`. fn from_noncanonical_u64(n: u64) -> Self; + /// Returns `n` as an element of this field. + // TODO: Move to `Field`. + fn from_noncanonical_i64(n: i64) -> Self; + + /// Returns `n` as an element of this field. Assumes that `0 <= n < Self::ORDER`. + // TODO: Move to `Field`. + // TODO: Should probably be unsafe. 
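+    // NB: `n as u64` reinterprets the two's-complement bits, so a negative `n` would wrap to a
+    // huge value; that is why the contract above requires `0 <= n`, and why arbitrary i64
+    // values should go through `from_noncanonical_i64` instead.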
+ #[inline] + fn from_canonical_i64(n: i64) -> Self { + Self::from_canonical_u64(n as u64) + } + #[inline] // TODO: Move to `Field`. fn add_one(&self) -> Self { diff --git a/plonky2/src/fri/mod.rs b/plonky2/src/fri/mod.rs index 38286312..9c44b53b 100644 --- a/plonky2/src/fri/mod.rs +++ b/plonky2/src/fri/mod.rs @@ -46,6 +46,10 @@ impl FriConfig { reduction_arity_bits, } } + + pub fn num_cap_elements(&self) -> usize { + 1 << self.cap_height + } } /// FRI parameters, including generated parameters which are specific to an instance size, in diff --git a/plonky2/src/iop/challenger.rs b/plonky2/src/iop/challenger.rs index 97d21197..c601ae0f 100644 --- a/plonky2/src/iop/challenger.rs +++ b/plonky2/src/iop/challenger.rs @@ -146,6 +146,14 @@ impl> Challenger { self.output_buffer .extend_from_slice(&self.sponge_state[0..SPONGE_RATE]); } + + pub fn compact(&mut self) -> [F; SPONGE_WIDTH] { + if !self.input_buffer.is_empty() { + self.duplexing(); + } + self.output_buffer.clear(); + self.sponge_state + } } impl> Default for Challenger { @@ -176,6 +184,14 @@ impl, H: AlgebraicHasher, const D: usize> } } + pub fn from_state(sponge_state: [Target; SPONGE_WIDTH]) -> Self { + RecursiveChallenger { + sponge_state, + input_buffer: vec![], + output_buffer: vec![], + } + } + pub(crate) fn observe_element(&mut self, target: Target) { // Any buffered outputs are now invalid, since they wouldn't reflect this input. self.output_buffer.clear(); @@ -183,7 +199,7 @@ impl, H: AlgebraicHasher, const D: usize> self.input_buffer.push(target); } - pub(crate) fn observe_elements(&mut self, targets: &[Target]) { + pub fn observe_elements(&mut self, targets: &[Target]) { for &target in targets { self.observe_element(target); } @@ -272,6 +288,12 @@ impl, H: AlgebraicHasher, const D: usize> self.input_buffer.clear(); } + + pub fn compact(&mut self, builder: &mut CircuitBuilder) -> [Target; SPONGE_WIDTH] { + self.absorb_buffered_inputs(builder); + self.output_buffer.clear(); + self.sponge_state + } } #[cfg(test)] diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 1e4067df..521e35c5 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -23,8 +23,9 @@ use crate::gates::gate::{CurrentSlot, Gate, GateInstance, GateRef}; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::selectors::selector_polynomials; -use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; +use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_proofs::MerkleProofTarget; +use crate::hash::merkle_tree::MerkleCap; use crate::iop::ext_target::ExtensionTarget; use crate::iop::generator::{ ConstantGenerator, CopyGenerator, RandomValueGenerator, SimpleGenerator, WitnessGenerator, @@ -208,6 +209,13 @@ impl, const D: usize> CircuitBuilder { b } + /// Add a virtual target and register it as a public input. + pub fn add_virtual_public_input(&mut self) -> Target { + let t = self.add_virtual_target(); + self.register_public_input(t); + t + } + /// Adds a gate to the circuit, and returns its index. 
pub fn add_gate>(&mut self, gate_type: G, mut constants: Vec) -> usize { self.check_gate_compatibility(&gate_type); @@ -365,6 +373,19 @@ impl, const D: usize> CircuitBuilder { } } + pub fn constant_hash(&mut self, h: HashOut) -> HashOutTarget { + HashOutTarget { + elements: h.elements.map(|x| self.constant(x)), + } + } + + pub fn constant_merkle_cap>>( + &mut self, + cap: &MerkleCap, + ) -> MerkleCapTarget { + MerkleCapTarget(cap.0.iter().map(|h| self.constant_hash(*h)).collect()) + } + /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns /// its constant value. Otherwise, returns `None`. pub fn target_as_constant(&self, target: Target) -> Option { @@ -839,15 +860,8 @@ impl, const D: usize> CircuitBuilder { [(); C::Hasher::HASH_SIZE]:, { // TODO: Can skip parts of this. - let CircuitData { - prover_only, - common, - .. - } = self.build(); - ProverCircuitData { - prover_only, - common, - } + let circuit_data = self.build(); + circuit_data.prover_data() } /// Builds a "verifier circuit", with data needed to verify proofs but not generate them. @@ -856,14 +870,7 @@ impl, const D: usize> CircuitBuilder { [(); C::Hasher::HASH_SIZE]:, { // TODO: Can skip parts of this. - let CircuitData { - verifier_only, - common, - .. - } = self.build(); - VerifierCircuitData { - verifier_only, - common, - } + let circuit_data = self.build(); + circuit_data.verifier_data() } } diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index c22eae3e..851485cd 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -140,6 +140,30 @@ impl, C: GenericConfig, const D: usize> { compressed_proof_with_pis.verify(&self.verifier_only, &self.common) } + + pub fn verifier_data(self) -> VerifierCircuitData { + let CircuitData { + verifier_only, + common, + .. + } = self; + VerifierCircuitData { + verifier_only, + common, + } + } + + pub fn prover_data(self) -> ProverCircuitData { + let CircuitData { + prover_only, + common, + .. + } = self; + ProverCircuitData { + prover_only, + common, + } + } } /// Circuit data required by the prover. This may be thought of as a proving key, although it