diff --git a/evm/src/arithmetic/add.rs b/evm/src/arithmetic/add.rs index b09307b0..4e9de4b3 100644 --- a/evm/src/arithmetic/add.rs +++ b/evm/src/arithmetic/add.rs @@ -35,6 +35,7 @@ pub(crate) fn eval_packed_generic_are_equal( is_op: P, larger: I, smaller: J, + is_two_row_op: bool, ) -> P where P: PackedField, @@ -47,7 +48,11 @@ where for (a, b) in larger.zip(smaller) { // t should be either 0 or 2^LIMB_BITS let t = cy + a - b; - yield_constr.constraint(is_op * t * (overflow - t)); + if is_two_row_op { + yield_constr.constraint_transition(is_op * t * (overflow - t)); + } else { + yield_constr.constraint(is_op * t * (overflow - t)); + } // cy <-- 0 or 1 // NB: this is multiplication by a constant, so doesn't // increase the degree of the constraint. @@ -62,6 +67,7 @@ pub(crate) fn eval_ext_circuit_are_equal( is_op: ExtensionTarget, larger: I, smaller: J, + is_two_row_op: bool, ) -> ExtensionTarget where F: RichField + Extendable, @@ -87,7 +93,11 @@ where let t2 = builder.mul_extension(t, t1); let filtered_limb_constraint = builder.mul_extension(is_op, t2); - yield_constr.constraint(builder, filtered_limb_constraint); + if is_two_row_op { + yield_constr.constraint_transition(builder, filtered_limb_constraint); + } else { + yield_constr.constraint(builder, filtered_limb_constraint); + } cy = builder.mul_const_extension(overflow_inv, t); } @@ -125,6 +135,7 @@ pub fn eval_packed_generic( is_add, output_computed, output_limbs.iter().copied(), + false, ); } @@ -155,6 +166,7 @@ pub fn eval_ext_circuit, const D: usize>( is_add, output_computed.into_iter(), output_limbs.iter().copied(), + false, ); } diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 5d835e77..5790ae66 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -17,7 +17,11 @@ pub struct ArithmeticStark { } impl ArithmeticStark { - pub fn generate(&self, local_values: &mut [F; columns::NUM_ARITH_COLUMNS]) { + pub fn 
generate( + &self, + local_values: &mut [F; columns::NUM_ARITH_COLUMNS], + next_values: &mut [F; columns::NUM_ARITH_COLUMNS], + ) { // Check that at most one operation column is "one" and that the // rest are "zero". assert_eq!( @@ -47,17 +51,17 @@ impl ArithmeticStark { } else if local_values[columns::IS_GT].is_one() { compare::generate(local_values, columns::IS_GT); } else if local_values[columns::IS_ADDMOD].is_one() { - modular::generate(local_values, columns::IS_ADDMOD); + modular::generate(local_values, next_values, columns::IS_ADDMOD); } else if local_values[columns::IS_SUBMOD].is_one() { - modular::generate(local_values, columns::IS_SUBMOD); + modular::generate(local_values, next_values, columns::IS_SUBMOD); } else if local_values[columns::IS_MULMOD].is_one() { - modular::generate(local_values, columns::IS_MULMOD); + modular::generate(local_values, next_values, columns::IS_MULMOD); } else if local_values[columns::IS_MOD].is_one() { - modular::generate(local_values, columns::IS_MOD); + modular::generate(local_values, next_values, columns::IS_MOD); } else if local_values[columns::IS_DIV].is_one() { - modular::generate(local_values, columns::IS_DIV); + modular::generate(local_values, next_values, columns::IS_DIV); } else { - todo!("the requested operation has not yet been implemented"); + panic!("the requested operation should not be handled by the arithmetic table"); } } } @@ -74,11 +78,12 @@ impl, const D: usize> Stark for ArithmeticSta P: PackedField, { let lv = vars.local_values; + let nv = vars.next_values; add::eval_packed_generic(lv, yield_constr); sub::eval_packed_generic(lv, yield_constr); mul::eval_packed_generic(lv, yield_constr); compare::eval_packed_generic(lv, yield_constr); - modular::eval_packed_generic(lv, yield_constr); + modular::eval_packed_generic(lv, nv, yield_constr); } fn eval_ext_circuit( @@ -88,11 +93,12 @@ impl, const D: usize> Stark for ArithmeticSta yield_constr: &mut RecursiveConstraintConsumer, ) { let lv = vars.local_values; + 
let nv = vars.next_values; add::eval_ext_circuit(builder, lv, yield_constr); sub::eval_ext_circuit(builder, lv, yield_constr); mul::eval_ext_circuit(builder, lv, yield_constr); compare::eval_ext_circuit(builder, lv, yield_constr); - modular::eval_ext_circuit(builder, lv, yield_constr); + modular::eval_ext_circuit(builder, lv, nv, yield_constr); } fn constraint_degree(&self) -> usize { diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index 923fbc73..779be2ee 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -12,7 +12,11 @@ const fn n_limbs() -> usize { if EVM_REGISTER_BITS % LIMB_BITS != 0 { panic!("limb size must divide EVM register size"); } - EVM_REGISTER_BITS / LIMB_BITS + let n = EVM_REGISTER_BITS / LIMB_BITS; + if n % 2 == 1 { + panic!("number of limbs must be even"); + } + n } /// Number of LIMB_BITS limbs that are in on EVM register-sized number. @@ -40,43 +44,66 @@ pub(crate) const ALL_OPERATIONS: [usize; 12] = [ /// Within the Arithmetic Unit, there are shared columns which can be /// used by any arithmetic circuit, depending on which one is active -/// this cycle. Can be increased as needed as other operations are -/// implemented. -const NUM_SHARED_COLS: usize = 9 * N_LIMBS; // only need 64 for add, sub, and mul +/// this cycle. +/// +/// Modular arithmetic takes 9 * N_LIMBS columns which is split across +/// two rows, the first with 5 * N_LIMBS columns and the second with +/// 4 * N_LIMBS columns. (There are hence N_LIMBS "wasted columns" in +/// the second row.) 
+const NUM_SHARED_COLS: usize = 5 * N_LIMBS; const GENERAL_INPUT_0: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; const GENERAL_INPUT_1: Range = GENERAL_INPUT_0.end..GENERAL_INPUT_0.end + N_LIMBS; const GENERAL_INPUT_2: Range = GENERAL_INPUT_1.end..GENERAL_INPUT_1.end + N_LIMBS; const GENERAL_INPUT_3: Range = GENERAL_INPUT_2.end..GENERAL_INPUT_2.end + N_LIMBS; -const AUX_INPUT_0: Range = GENERAL_INPUT_3.end..GENERAL_INPUT_3.end + 2 * N_LIMBS; -const AUX_INPUT_1: Range = AUX_INPUT_0.end..AUX_INPUT_0.end + 2 * N_LIMBS; +const AUX_INPUT_0_LO: Range = GENERAL_INPUT_3.end..GENERAL_INPUT_3.end + N_LIMBS; + +// The auxiliary input columns overlap the general input columns +// because they correspond to the values in the second row for modular +// operations. +const AUX_INPUT_0_HI: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; +const AUX_INPUT_1: Range = AUX_INPUT_0_HI.end..AUX_INPUT_0_HI.end + 2 * N_LIMBS; +// These auxiliary input columns are awkwardly split across two rows, +// with the first half after the general input columns and the second +// half after the auxiliary input columns. 
const AUX_INPUT_2: Range = AUX_INPUT_1.end..AUX_INPUT_1.end + N_LIMBS; +// ADD takes 3 * N_LIMBS = 48 columns pub(crate) const ADD_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const ADD_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const ADD_OUTPUT: Range = GENERAL_INPUT_2; +// SUB takes 3 * N_LIMBS = 48 columns pub(crate) const SUB_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const SUB_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const SUB_OUTPUT: Range = GENERAL_INPUT_2; +// MUL takes 4 * N_LIMBS = 64 columns pub(crate) const MUL_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const MUL_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const MUL_OUTPUT: Range = GENERAL_INPUT_2; pub(crate) const MUL_AUX_INPUT: Range = GENERAL_INPUT_3; +// LT and GT take 4 * N_LIMBS = 64 columns pub(crate) const CMP_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const CMP_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2.start; pub(crate) const CMP_AUX_INPUT: Range = GENERAL_INPUT_3; +// MULMOD takes 4 * N_LIMBS + 2 * 2*N_LIMBS + N_LIMBS = 144 columns +// but split over two rows of 80 columns and 64 columns. +// +// ADDMOD, SUBMOD, MOD and DIV are currently implemented in terms of +// the general modular code, so they also take 144 columns (also split +// over two rows). 
pub(crate) const MODULAR_INPUT_0: Range = GENERAL_INPUT_0; pub(crate) const MODULAR_INPUT_1: Range = GENERAL_INPUT_1; pub(crate) const MODULAR_MODULUS: Range = GENERAL_INPUT_2; pub(crate) const MODULAR_OUTPUT: Range = GENERAL_INPUT_3; -pub(crate) const MODULAR_QUO_INPUT: Range = AUX_INPUT_0; +pub(crate) const MODULAR_QUO_INPUT_LO: Range = AUX_INPUT_0_LO; // NB: Last value is not used in AUX, it is used in MOD_IS_ZERO -pub(crate) const MODULAR_AUX_INPUT: Range = AUX_INPUT_1; +pub(crate) const MODULAR_QUO_INPUT_HI: Range = AUX_INPUT_0_HI; +pub(crate) const MODULAR_AUX_INPUT: Range = AUX_INPUT_1.start..AUX_INPUT_1.end - 1; pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1.end - 1; pub(crate) const MODULAR_OUT_AUX_RED: Range = AUX_INPUT_2; @@ -85,6 +112,6 @@ pub(crate) const DIV_NUMERATOR: Range = MODULAR_INPUT_0; #[allow(unused)] // TODO: Will be used when hooking into the CPU pub(crate) const DIV_DENOMINATOR: Range = MODULAR_MODULUS; #[allow(unused)] // TODO: Will be used when hooking into the CPU -pub(crate) const DIV_OUTPUT: Range = MODULAR_QUO_INPUT.start..MODULAR_QUO_INPUT.start + 16; +pub(crate) const DIV_OUTPUT: Range = MODULAR_QUO_INPUT_LO; pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS; diff --git a/evm/src/arithmetic/compare.rs b/evm/src/arithmetic/compare.rs index 7a360430..780053ce 100644 --- a/evm/src/arithmetic/compare.rs +++ b/evm/src/arithmetic/compare.rs @@ -57,16 +57,27 @@ pub(crate) fn eval_packed_generic_lt( input1: &[P], aux: &[P], output: P, + is_two_row_op: bool, ) { debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS); // Verify (input0 < input1) == output by providing aux such that // input0 - input1 == aux + output*2^256. 
let lhs_limbs = input0.iter().zip(input1).map(|(&a, &b)| a - b); - let cy = eval_packed_generic_are_equal(yield_constr, is_op, aux.iter().copied(), lhs_limbs); + let cy = eval_packed_generic_are_equal( + yield_constr, + is_op, + aux.iter().copied(), + lhs_limbs, + is_two_row_op, + ); // We don't need to check that cy is 0 or 1, since output has // already been checked to be 0 or 1. - yield_constr.constraint(is_op * (cy - output)); + if is_two_row_op { + yield_constr.constraint_transition(is_op * (cy - output)); + } else { + yield_constr.constraint(is_op * (cy - output)); + } } pub fn eval_packed_generic( @@ -88,8 +99,8 @@ pub fn eval_packed_generic( let is_cmp = is_lt + is_gt; eval_packed_generic_check_is_one_bit(yield_constr, is_cmp, output); - eval_packed_generic_lt(yield_constr, is_lt, input0, input1, aux, output); - eval_packed_generic_lt(yield_constr, is_gt, input1, input0, aux, output); + eval_packed_generic_lt(yield_constr, is_lt, input0, input1, aux, output, false); + eval_packed_generic_lt(yield_constr, is_gt, input1, input0, aux, output, false); } fn eval_ext_circuit_check_is_one_bit, const D: usize>( @@ -112,6 +123,7 @@ pub(crate) fn eval_ext_circuit_lt, const D: usize>( input1: &[ExtensionTarget], aux: &[ExtensionTarget], output: ExtensionTarget, + is_two_row_op: bool, ) { debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS); @@ -131,10 +143,11 @@ pub(crate) fn eval_ext_circuit_lt, const D: usize>( is_op, aux.iter().copied(), lhs_limbs.into_iter(), + is_two_row_op, ); let good_output = builder.sub_extension(cy, output); let filter = builder.mul_extension(is_op, good_output); - yield_constr.constraint(builder, filter); + yield_constr.constraint_transition(builder, filter); } pub fn eval_ext_circuit, const D: usize>( @@ -153,8 +166,26 @@ pub fn eval_ext_circuit, const D: usize>( let is_cmp = builder.add_extension(is_lt, is_gt); eval_ext_circuit_check_is_one_bit(builder, yield_constr, is_cmp, output); - 
eval_ext_circuit_lt(builder, yield_constr, is_lt, input0, input1, aux, output); - eval_ext_circuit_lt(builder, yield_constr, is_gt, input1, input0, aux, output); + eval_ext_circuit_lt( + builder, + yield_constr, + is_lt, + input0, + input1, + aux, + output, + false, + ); + eval_ext_circuit_lt( + builder, + yield_constr, + is_gt, + input1, + input0, + aux, + output, + false, + ); } #[cfg(test)] diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 09c3996e..46f6c0fa 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -86,6 +86,27 @@ //! //! In the case of DIV, we do something similar, except that we "replace" //! the modulus with "2^256" to force the quotient to be zero. +//! +//! -*- +//! +//! NB: The implementation uses 9 * N_LIMBS = 144 columns because of +//! the requirements of the general purpose MULMOD; since ADDMOD, +//! SUBMOD, MOD and DIV are currently implemented in terms of the +//! general modular code, they also take 144 columns. Possible +//! improvements: +//! +//! - We could reduce the number of columns to 112 for ADDMOD, SUBMOD, +//! etc. if they were implemented separately, so they don't pay the +//! full cost of the general MULMOD. +//! +//! - All these operations could have alternative forms where the +//! output was not guaranteed to be reduced, which is often sufficient +//! in practice, and which would save a further 16 columns. +//! +//! - If the modulus is known in advance (such as for elliptic curve +//! arithmetic), specialised handling of MULMOD in that case would +//! only require 96 columns, or 80 if the output doesn't need to be +//! reduced. use num::bigint::Sign; use num::{BigInt, One, Zero}; @@ -171,11 +192,13 @@ fn bigint_to_columns(num: &BigInt) -> [i64; N] { /// zero if they are not used. 
fn generate_modular_op( lv: &mut [F; NUM_ARITH_COLUMNS], + nv: &mut [F; NUM_ARITH_COLUMNS], filter: usize, operation: fn([i64; N_LIMBS], [i64; N_LIMBS]) -> [i64; 2 * N_LIMBS - 1], ) { // Inputs are all range-checked in [0, 2^16), so the "as i64" // conversion is safe. + let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0); let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1); let mut modulus_limbs = read_value_i64_limbs(lv, MODULAR_MODULUS); @@ -246,21 +269,38 @@ fn generate_modular_op( let aux_limbs = pol_remove_root_2exp::(constr_poly); lv[MODULAR_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); - lv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(|c| F::from_canonical_i64(c))); - lv[MODULAR_QUO_INPUT].copy_from_slice(&quot_limbs.map(|c| F::from_noncanonical_i64(c))); - lv[MODULAR_AUX_INPUT].copy_from_slice(&aux_limbs.map(|c| F::from_noncanonical_i64(c))); - lv[MODULAR_MOD_IS_ZERO] = mod_is_zero; + + // Copy lo and hi halves of quot_limbs into their respective registers + for (i, &lo) in MODULAR_QUO_INPUT_LO.zip(&quot_limbs[..N_LIMBS]) { + lv[i] = F::from_noncanonical_i64(lo); + } + for (i, &hi) in MODULAR_QUO_INPUT_HI.zip(&quot_limbs[N_LIMBS..]) { + nv[i] = F::from_noncanonical_i64(hi); + } + + for (i, &c) in MODULAR_AUX_INPUT.zip(&aux_limbs[..2 * N_LIMBS - 1]) { + nv[i] = F::from_noncanonical_i64(c); + } + + nv[MODULAR_MOD_IS_ZERO] = mod_is_zero; + nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(|c| F::from_canonical_i64(c))); } /// Generate the output and auxiliary values for modular operations. /// /// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`. 
-pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], filter: usize) { +pub(crate) fn generate( + lv: &mut [F; NUM_ARITH_COLUMNS], + nv: &mut [F; NUM_ARITH_COLUMNS], + filter: usize, +) { match filter { - columns::IS_ADDMOD => generate_modular_op(lv, filter, pol_add), - columns::IS_SUBMOD => generate_modular_op(lv, filter, pol_sub), - columns::IS_MULMOD => generate_modular_op(lv, filter, pol_mul_wide), - columns::IS_MOD | columns::IS_DIV => generate_modular_op(lv, filter, |a, _| pol_extend(a)), + columns::IS_ADDMOD => generate_modular_op(lv, nv, filter, pol_add), + columns::IS_SUBMOD => generate_modular_op(lv, nv, filter, pol_sub), + columns::IS_MULMOD => generate_modular_op(lv, nv, filter, pol_mul_wide), + columns::IS_MOD | columns::IS_DIV => { + generate_modular_op(lv, nv, filter, |a, _| pol_extend(a)) + } _ => panic!("generate modular operation called with unknown opcode"), } } @@ -275,26 +315,28 @@ pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], filter: us /// and check consistency when m = 0, and that c is reduced. fn modular_constr_poly( lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer
<P>
, filter: P, ) -> [P; 2 * N_LIMBS] { range_check_error!(MODULAR_INPUT_0, 16); range_check_error!(MODULAR_INPUT_1, 16); range_check_error!(MODULAR_MODULUS, 16); - range_check_error!(MODULAR_QUO_INPUT, 16); + range_check_error!(MODULAR_QUO_INPUT_LO, 16); + range_check_error!(MODULAR_QUO_INPUT_HI, 16); range_check_error!(MODULAR_AUX_INPUT, 20, signed); range_check_error!(MODULAR_OUTPUT, 16); let mut modulus = read_value::(lv, MODULAR_MODULUS); - let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; + let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; // Check that mod_is_zero is zero or one - yield_constr.constraint(filter * (mod_is_zero * mod_is_zero - mod_is_zero)); + yield_constr.constraint_transition(filter * (mod_is_zero * mod_is_zero - mod_is_zero)); // Check that mod_is_zero is zero if modulus is not zero (they // could both be zero) let limb_sum = modulus.into_iter().sum::
<P>
(); - yield_constr.constraint(filter * limb_sum * mod_is_zero); + yield_constr.constraint_transition(filter * limb_sum * mod_is_zero); // See the file documentation for why this suffices to handle // modulus = 0. @@ -308,8 +350,8 @@ fn modular_constr_poly( output[0] += mod_is_zero * lv[IS_DIV]; // Verify that the output is reduced, i.e. output < modulus. - let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; - // this sets is_less_than to 1 unless we get mod_is_zero when + let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; + // This sets is_less_than to 1 unless we get mod_is_zero when // doing a DIV; in that case, we need is_less_than=0, since the // function checks // @@ -317,6 +359,8 @@ fn modular_constr_poly( // // and we were given output = out_aux_red let is_less_than = P::ONES - mod_is_zero * lv[IS_DIV]; + // NB: output and modulus in lv while out_aux_red and is_less_than + // (via mod_is_zero) depend on nv. eval_packed_generic_lt( yield_constr, filter, @@ -324,16 +368,23 @@ fn modular_constr_poly( &modulus, out_aux_red, is_less_than, + true, ); // restore output[0] output[0] -= mod_is_zero * lv[IS_DIV]; // prod = q(x) * m(x) - let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); + let quot = { + let mut quot = [P::default(); 2 * N_LIMBS]; + quot[..N_LIMBS].copy_from_slice(&lv[MODULAR_QUO_INPUT_LO]); + quot[N_LIMBS..].copy_from_slice(&nv[MODULAR_QUO_INPUT_HI]); + quot + }; + let prod = pol_mul_wide2(quot, modulus); // higher order terms must be zero for &x in prod[2 * N_LIMBS..].iter() { - yield_constr.constraint(filter * x); + yield_constr.constraint_transition(filter * x); } // constr_poly = c(x) + q(x) * m(x) @@ -341,8 +392,11 @@ fn modular_constr_poly( pol_add_assign(&mut constr_poly, &output); // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) - let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); - aux[2 * N_LIMBS - 1] = P::ZEROS; // zero out the MOD_IS_ZERO flag + let mut aux = [P::ZEROS; 2 * N_LIMBS]; + for (i, j) in 
MODULAR_AUX_INPUT.enumerate() { + aux[i] = nv[j]; + } + let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); pol_add_assign(&mut constr_poly, &pol_adjoin_root(aux, base)); @@ -352,6 +406,7 @@ fn modular_constr_poly( /// Add constraints for modular operations. pub(crate) fn eval_packed_generic( lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer
<P>
, ) { // NB: The CTL code guarantees that filter is 0 or 1, i.e. that @@ -362,8 +417,12 @@ pub(crate) fn eval_packed_generic( + lv[columns::IS_SUBMOD] + lv[columns::IS_DIV]; + // Ensure that this operation is not the last row of the table; + // needed because we access the next row of the table in nv. + yield_constr.constraint_last_row(filter); + // constr_poly has 2*N_LIMBS limbs - let constr_poly = modular_constr_poly(lv, yield_constr, filter); + let constr_poly = modular_constr_poly(lv, nv, yield_constr, filter); let input0 = read_value(lv, MODULAR_INPUT_0); let input1 = read_value(lv, MODULAR_INPUT_1); @@ -394,35 +453,36 @@ pub(crate) fn eval_packed_generic( // operation is valid if and only if all of those coefficients // are zero. for &c in constr_poly_copy.iter() { - yield_constr.constraint(filter * c); + yield_constr.constraint_transition(filter * c); } } } fn modular_constr_poly_ext_circuit, const D: usize>( lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], builder: &mut CircuitBuilder, yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, ) -> [ExtensionTarget; 2 * N_LIMBS] { let mut modulus = read_value::(lv, MODULAR_MODULUS); - let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; + let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero); let t = builder.mul_extension(filter, t); - yield_constr.constraint(builder, t); + yield_constr.constraint_transition(builder, t); let limb_sum = builder.add_many_extension(modulus); let t = builder.mul_extension(limb_sum, mod_is_zero); let t = builder.mul_extension(filter, t); - yield_constr.constraint(builder, t); + yield_constr.constraint_transition(builder, t); modulus[0] = builder.add_extension(modulus[0], mod_is_zero); let mut output = read_value::(lv, MODULAR_OUTPUT); output[0] = builder.mul_add_extension(mod_is_zero, lv[IS_DIV], output[0]); - let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; + let out_aux_red = 
&nv[MODULAR_OUT_AUX_RED]; let one = builder.one_extension(); let is_less_than = builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one); @@ -435,22 +495,33 @@ fn modular_constr_poly_ext_circuit, const D: usize> &modulus, out_aux_red, is_less_than, + true, ); output[0] = builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], output[0]); + let quot = { + let zero = builder.zero_extension(); + let mut quot = [zero; 2 * N_LIMBS]; + quot[..N_LIMBS].copy_from_slice(&lv[MODULAR_QUO_INPUT_LO]); + quot[N_LIMBS..].copy_from_slice(&nv[MODULAR_QUO_INPUT_HI]); + quot + }; - let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); for &x in prod[2 * N_LIMBS..].iter() { let t = builder.mul_extension(filter, x); - yield_constr.constraint(builder, t); + yield_constr.constraint_transition(builder, t); } let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); pol_add_assign_ext_circuit(builder, &mut constr_poly, &output); - let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); - aux[2 * N_LIMBS - 1] = builder.zero_extension(); + let zero = builder.zero_extension(); + let mut aux = [zero; 2 * N_LIMBS]; + for (i, j) in MODULAR_AUX_INPUT.enumerate() { + aux[i] = nv[j]; + } + let base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << LIMB_BITS)); let t = pol_adjoin_root_ext_circuit(builder, aux, base); pol_add_assign_ext_circuit(builder, &mut constr_poly, &t); @@ -461,6 +532,7 @@ fn modular_constr_poly_ext_circuit, const D: usize> pub(crate) fn eval_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { let filter = builder.add_many_extension([ @@ -471,8 +543,9 @@ pub(crate) fn eval_ext_circuit, const D: usize>( lv[columns::IS_DIV], ]); - let constr_poly = 
modular_constr_poly_ext_circuit(lv, builder, yield_constr, filter); + yield_constr.constraint_last_row(builder, filter); + let constr_poly = modular_constr_poly_ext_circuit(lv, nv, builder, yield_constr, filter); let input0 = read_value(lv, MODULAR_INPUT_0); let input1 = read_value(lv, MODULAR_INPUT_1); @@ -492,7 +565,7 @@ pub(crate) fn eval_ext_circuit, const D: usize>( pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input); for &c in constr_poly_copy.iter() { let t = builder.mul_extension(filter, c); - yield_constr.constraint(builder, t); + yield_constr.constraint_transition(builder, t); } } } @@ -518,6 +591,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); // if `IS_ADDMOD == 0`, then the constraints should be met even // if all values are garbage. @@ -533,7 +607,7 @@ mod tests { GoldilocksField::ONE, GoldilocksField::ONE, ); - eval_packed_generic(&lv, &mut constraint_consumer); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } @@ -545,6 +619,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); for op_filter in [IS_ADDMOD, IS_DIV, IS_SUBMOD, IS_MOD, IS_MULMOD] { // Reset operation columns, then select one @@ -563,9 +638,9 @@ mod tests { lv[mi] = F::from_canonical_u16(rng.gen()); } - // For the second half of the tests, set the top 16 - - // start digits of the modulus to zero so it is much - // smaller than the inputs. + // For the second half of the tests, set the top + // 16-start digits of the modulus to zero so it is + // much smaller than the inputs. 
if i > N_RND_TESTS / 2 { // 1 <= start < N_LIMBS let start = (rng.gen::() % (N_LIMBS - 1)) + 1; @@ -574,15 +649,15 @@ mod tests { } } - generate(&mut lv, op_filter); + generate(&mut lv, &mut nv, op_filter); let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], GoldilocksField::ONE, - GoldilocksField::ONE, - GoldilocksField::ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, ); - eval_packed_generic(&lv, &mut constraint_consumer); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } @@ -596,6 +671,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_DIV, IS_MOD, IS_MULMOD] { // Reset operation columns, then select one @@ -609,13 +685,14 @@ mod tests { for _i in 0..N_RND_TESTS { // set inputs to random values and the modulus to zero; // the output is defined to be zero when modulus is zero. 
+ for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); lv[mi] = F::ZERO; } - generate(&mut lv, op_filter); + generate(&mut lv, &mut nv, op_filter); // check that the correct output was generated if op_filter == IS_DIV { @@ -627,24 +704,25 @@ mod tests { let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], GoldilocksField::ONE, - GoldilocksField::ONE, - GoldilocksField::ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, ); - eval_packed_generic(&lv, &mut constraint_consumer); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); assert!(constraint_consumer .constraint_accs .iter() .all(|&acc| acc == F::ZERO)); // Corrupt one output limb by setting it to a non-zero value - let random_oi = if op_filter == IS_DIV { - DIV_OUTPUT.start + rng.gen::() % N_LIMBS + if op_filter == IS_DIV { + let random_oi = DIV_OUTPUT.start + rng.gen::() % N_LIMBS; + lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); } else { - MODULAR_OUTPUT.start + rng.gen::() % N_LIMBS + let random_oi = MODULAR_OUTPUT.start + rng.gen::() % N_LIMBS; + lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); }; - lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); - eval_packed_generic(&lv, &mut constraint_consumer); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); // Check that at least one of the constraints was non-zero assert!(constraint_consumer diff --git a/evm/src/arithmetic/sub.rs b/evm/src/arithmetic/sub.rs index d589f323..13f6e8d5 100644 --- a/evm/src/arithmetic/sub.rs +++ b/evm/src/arithmetic/sub.rs @@ -57,6 +57,7 @@ pub fn eval_packed_generic( is_sub, output_limbs.iter().copied(), output_computed, + false, ); } @@ -87,6 +88,7 @@ pub fn eval_ext_circuit, const D: usize>( is_sub, output_limbs.iter().copied(), output_computed.into_iter(), + false, ); }