diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index b7168f85..e5f631e8 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -104,10 +104,7 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { fn ctl_arithmetic() -> CrossTableLookup { CrossTableLookup::new( - vec![ - cpu_stark::ctl_arithmetic_base_rows(), - cpu_stark::ctl_arithmetic_shift_rows(), - ], + vec![cpu_stark::ctl_arithmetic_base_rows()], arithmetic_stark::ctl_arithmetic_rows(), ) } diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index f38aab9d..3d281c86 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -12,6 +12,7 @@ use plonky2::util::transpose; use static_assertions::const_assert; use super::columns::NUM_ARITH_COLUMNS; +use super::shift; use crate::all_stark::Table; use crate::arithmetic::columns::{RANGE_COUNTER, RC_FREQUENCIES, SHARED_COLS}; use crate::arithmetic::{addcy, byte, columns, divmod, modular, mul, Operation}; @@ -208,6 +209,7 @@ impl, const D: usize> Stark for ArithmeticSta divmod::eval_packed(lv, nv, yield_constr); modular::eval_packed(lv, nv, yield_constr); byte::eval_packed(lv, yield_constr); + shift::eval_packed_generic(lv, nv, yield_constr); } fn eval_ext_circuit( @@ -237,6 +239,7 @@ impl, const D: usize> Stark for ArithmeticSta divmod::eval_ext_circuit(builder, lv, nv, yield_constr); modular::eval_ext_circuit(builder, lv, nv, yield_constr); byte::eval_ext_circuit(builder, lv, yield_constr); + shift::eval_ext_circuit(builder, lv, nv, yield_constr); } fn constraint_degree(&self) -> usize { diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index 36eb983e..df2d1247 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -101,7 +101,7 @@ pub(crate) const MODULAR_OUT_AUX_RED: Range = AUX_REGISTER_0; pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_REGISTER_1.start; pub(crate) const MODULAR_AUX_INPUT_LO: Range = 
AUX_REGISTER_1.start + 1..AUX_REGISTER_1.end; pub(crate) const MODULAR_AUX_INPUT_HI: Range = AUX_REGISTER_2; -// Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV] +// Must be set to MOD_IS_ZERO for DIV and SHR operations i.e. MOD_IS_ZERO * (lv[IS_DIV] + lv[IS_SHR]). pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end; /// The counter column (used for the range check) starts from 0 and increments. diff --git a/evm/src/arithmetic/divmod.rs b/evm/src/arithmetic/divmod.rs index 258c131f..e143ded6 100644 --- a/evm/src/arithmetic/divmod.rs +++ b/evm/src/arithmetic/divmod.rs @@ -15,24 +15,19 @@ use crate::arithmetic::modular::{ use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -/// Generate the output and auxiliary values for modular operations. -pub(crate) fn generate( +/// Generates the output and auxiliary values for modular operations, +/// assuming the input, modular and output limbs are already set. 
+pub(crate) fn generate_divmod( lv: &mut [F], nv: &mut [F], filter: usize, - input0: U256, - input1: U256, - result: U256, + input_limbs_range: Range, + modulus_range: Range, ) { - debug_assert!(lv.len() == NUM_ARITH_COLUMNS); - - u256_to_array(&mut lv[INPUT_REGISTER_0], input0); - u256_to_array(&mut lv[INPUT_REGISTER_1], input1); - u256_to_array(&mut lv[OUTPUT_REGISTER], result); - - let input_limbs = read_value_i64_limbs::(lv, INPUT_REGISTER_0); + let input_limbs = read_value_i64_limbs::(lv, input_limbs_range); let pol_input = pol_extend(input_limbs); - let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, INPUT_REGISTER_1); + let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, modulus_range); + debug_assert!( &quo_input[N_LIMBS..].iter().all(|&x| x == F::ZERO), "expected top half of quo_input to be zero" @@ -62,16 +57,35 @@ pub(crate) fn generate( ); lv[AUX_INPUT_REGISTER_0].copy_from_slice(&quo_input[..N_LIMBS]); } - _ => panic!("expected filter to be IS_DIV or IS_MOD but it was {filter}"), + _ => panic!("expected filter to be IS_DIV, IS_SHR or IS_MOD but it was {filter}"), }; } +/// Generate the output and auxiliary values for modular operations. +pub(crate) fn generate( + lv: &mut [F], + nv: &mut [F], + filter: usize, + input0: U256, + input1: U256, + result: U256, +) { + debug_assert!(lv.len() == NUM_ARITH_COLUMNS); + + u256_to_array(&mut lv[INPUT_REGISTER_0], input0); + u256_to_array(&mut lv[INPUT_REGISTER_1], input1); + u256_to_array(&mut lv[OUTPUT_REGISTER], result); + + generate_divmod(lv, nv, filter, INPUT_REGISTER_0, INPUT_REGISTER_1); +} /// Verify that num = quo * den + rem and 0 <= rem < den. -fn eval_packed_divmod_helper( +pub(crate) fn eval_packed_divmod_helper( lv: &[P; NUM_ARITH_COLUMNS], nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, filter: P, + num_range: Range, + den_range: Range, quo_range: Range, rem_range: Range, ) { @@ -80,8 +94,8 @@ fn eval_packed_divmod_helper( yield_constr.constraint_last_row(filter); - let num = &lv[INPUT_REGISTER_0]; - let den = read_value(lv, INPUT_REGISTER_1); + let num = &lv[num_range]; + let den = read_value(lv, den_range); let quo = { let mut quo = [P::ZEROS; 2 * N_LIMBS]; quo[..N_LIMBS].copy_from_slice(&lv[quo_range]); @@ -104,14 +118,13 @@ pub(crate) fn eval_packed( nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, ) { - // Constrain IS_SHR independently, so that it doesn't impact the - // constraints when combining the flag with IS_DIV. - yield_constr.constraint_last_row(lv[IS_SHR]); eval_packed_divmod_helper( lv, nv, yield_constr, - lv[IS_DIV] + lv[IS_SHR], + lv[IS_DIV], + INPUT_REGISTER_0, + INPUT_REGISTER_1, OUTPUT_REGISTER, AUX_INPUT_REGISTER_0, ); @@ -120,24 +133,28 @@ pub(crate) fn eval_packed( nv, yield_constr, lv[IS_MOD], + INPUT_REGISTER_0, + INPUT_REGISTER_1, AUX_INPUT_REGISTER_0, OUTPUT_REGISTER, ); } -fn eval_ext_circuit_divmod_helper, const D: usize>( +pub(crate) fn eval_ext_circuit_divmod_helper, const D: usize>( builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, + num_range: Range, + den_range: Range, quo_range: Range, rem_range: Range, ) { yield_constr.constraint_last_row(builder, filter); - let num = &lv[INPUT_REGISTER_0]; - let den = read_value(lv, INPUT_REGISTER_1); + let num = &lv[num_range]; + let den = read_value(lv, den_range); let quo = { let zero = builder.zero_extension(); let mut quo = [zero; 2 * N_LIMBS]; @@ -164,14 +181,14 @@ pub(crate) fn eval_ext_circuit, const D: usize>( nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { - yield_constr.constraint_last_row(builder, lv[IS_SHR]); - let div_shr_flag = builder.add_extension(lv[IS_DIV], lv[IS_SHR]); eval_ext_circuit_divmod_helper( builder, lv, nv, yield_constr, - div_shr_flag, + lv[IS_DIV], + INPUT_REGISTER_0, + INPUT_REGISTER_1, OUTPUT_REGISTER, AUX_INPUT_REGISTER_0, ); @@ -181,6 +198,8 @@ pub(crate) fn eval_ext_circuit, const D: usize>( nv, yield_constr, lv[IS_MOD], + INPUT_REGISTER_0, + INPUT_REGISTER_1, AUX_INPUT_REGISTER_0, OUTPUT_REGISTER, ); @@ -214,7 +233,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } - // Deactivate the SHR flag so that a DIV operation is not triggered. 
+ // Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here. lv[IS_SHR] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( @@ -247,6 +266,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + // Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here. lv[IS_SHR] = F::ZERO; lv[op_filter] = F::ONE; @@ -308,6 +328,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + // Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here. lv[IS_SHR] = F::ZERO; lv[op_filter] = F::ONE; diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index bd6d56e8..7763e98a 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -9,6 +9,7 @@ mod byte; mod divmod; mod modular; mod mul; +mod shift; mod utils; pub mod arithmetic_stark; @@ -35,15 +36,29 @@ impl BinaryOperator { pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 { match self { BinaryOperator::Add => input0.overflowing_add(input1).0, - BinaryOperator::Mul | BinaryOperator::Shl => input0.overflowing_mul(input1).0, + BinaryOperator::Mul => input0.overflowing_mul(input1).0, + BinaryOperator::Shl => { + if input0 < U256::from(256usize) { + input1 << input0 + } else { + U256::zero() + } + } BinaryOperator::Sub => input0.overflowing_sub(input1).0, - BinaryOperator::Div | BinaryOperator::Shr => { + BinaryOperator::Div => { if input1.is_zero() { U256::zero() } else { input0 / input1 } } + BinaryOperator::Shr => { + if input0 < U256::from(256usize) { + input1 >> input0 + } else { + U256::zero() + } + } BinaryOperator::Mod => { if input1.is_zero() { U256::zero() @@ -238,15 +253,25 @@ fn binary_op_to_rows( addcy::generate(&mut row, op.row_filter(), input0, input1); (row, None) } - BinaryOperator::Mul | BinaryOperator::Shl => { + BinaryOperator::Mul => { mul::generate(&mut row, input0, input1); (row, None) } - BinaryOperator::Div | BinaryOperator::Mod | BinaryOperator::Shr => { + BinaryOperator::Shl => { + let mut nv = 
vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + shift::generate(&mut row, &mut nv, true, input0, input1, result); + (row, None) + } + BinaryOperator::Div | BinaryOperator::Mod => { let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result); (row, Some(nv)) } + BinaryOperator::Shr => { + let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + shift::generate(&mut row, &mut nv, false, input0, input1, result); + (row, Some(nv)) + } BinaryOperator::AddFp254 | BinaryOperator::MulFp254 | BinaryOperator::SubFp254 => { ternary_op_to_rows::(op.row_filter(), input0, input1, BN_BASE, result) } diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 4e540cb6..4e6e21a6 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -239,7 +239,7 @@ pub(crate) fn generate_modular_op( let mut mod_is_zero = F::ZERO; if modulus.is_zero() { - if filter == columns::IS_DIV { + if filter == columns::IS_DIV || filter == columns::IS_SHR { // set modulus = 2^256; the condition above means we know // it's zero at this point, so we can just set bit 256. modulus.set_bit(256, true); @@ -330,7 +330,7 @@ pub(crate) fn generate_modular_op( nv[MODULAR_MOD_IS_ZERO] = mod_is_zero; nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(F::from_canonical_i64)); - nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * lv[IS_DIV]; + nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]); ( output_limbs.map(F::from_canonical_i64), @@ -392,14 +392,14 @@ pub(crate) fn check_reduced( // Verify that the output is reduced, i.e. output < modulus. 
let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; // This sets is_less_than to 1 unless we get mod_is_zero when - // doing a DIV; in that case, we need is_less_than=0, since + // doing a DIV or SHR; in that case, we need is_less_than=0, since // eval_packed_generic_addcy checks // // modulus + out_aux_red == output + is_less_than*2^256 // // and we are given output = out_aux_red when modulus is zero. let mut is_less_than = [P::ZEROS; N_LIMBS]; - is_less_than[0] = P::ONES - mod_is_zero * lv[IS_DIV]; + is_less_than[0] = P::ONES - mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]); // NB: output and modulus in lv while out_aux_red and // is_less_than (via mod_is_zero) depend on nv, hence the // 'is_two_row_op' argument is set to 'true'. @@ -448,13 +448,15 @@ pub(crate) fn modular_constr_poly( // modulus = 0. modulus[0] += mod_is_zero; - // Is 1 iff the operation is DIV and the denominator is zero. + // Is 1 iff the operation is DIV or SHR and the denominator is zero. let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; - yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero)); + yield_constr.constraint_transition( + filter * (mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]) - div_denom_is_zero), + ); // Needed to compensate for adding mod_is_zero to modulus above, // since the call eval_packed_generic_addcy() below subtracts modulus - // to verify in the case of a DIV. + // to verify in the case of a DIV or SHR. 
output[0] += div_denom_is_zero; check_reduced(lv, nv, yield_constr, filter, output, modulus, mod_is_zero); @@ -635,7 +637,8 @@ pub(crate) fn modular_constr_poly_ext_circuit, cons modulus[0] = builder.add_extension(modulus[0], mod_is_zero); let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; - let t = builder.mul_sub_extension(mod_is_zero, lv[IS_DIV], div_denom_is_zero); + let div_shr_filter = builder.add_extension(lv[IS_DIV], lv[IS_SHR]); + let t = builder.mul_sub_extension(mod_is_zero, div_shr_filter, div_denom_is_zero); let t = builder.mul_extension(filter, t); yield_constr.constraint_transition(builder, t); output[0] = builder.add_extension(output[0], div_denom_is_zero); @@ -645,7 +648,7 @@ pub(crate) fn modular_constr_poly_ext_circuit, cons let zero = builder.zero_extension(); let mut is_less_than = [zero; N_LIMBS]; is_less_than[0] = - builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one); + builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, div_shr_filter, one); eval_ext_circuit_addcy( builder, @@ -834,6 +837,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[IS_DIV] = F::ZERO; lv[IS_MOD] = F::ZERO; @@ -867,6 +871,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[IS_DIV] = F::ZERO; lv[IS_MOD] = F::ZERO; lv[op_filter] = F::ONE; @@ -926,6 +931,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[IS_DIV] = F::ZERO; lv[IS_MOD] = F::ZERO; lv[op_filter] = F::ONE; diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index efb4d822..c09c39d8 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -67,16 +67,8 @@ use crate::arithmetic::columns::*; use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { - // TODO: It would probably be clearer/cleaner to read the U256 - 
// into an [i64;N] and then copy that to the lv table. - u256_to_array(&mut lv[INPUT_REGISTER_0], left_in); - u256_to_array(&mut lv[INPUT_REGISTER_1], right_in); - u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero()); - - let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_0); - let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_1); - +/// Given the two limbs of `left_in` and `right_in`, computes `left_in * right_in`. +pub(crate) fn generate_mul(lv: &mut [F], left_in: [i64; 16], right_in: [i64; 16]) { const MASK: i64 = (1i64 << LIMB_BITS) - 1i64; // Input and output have 16-bit limbs @@ -86,7 +78,7 @@ pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { // First calculate the coefficients of a(x)*b(x) (in unreduced_prod), // then do carry propagation to obtain C = c(β) = a(β)*b(β). let mut cy = 0i64; - let mut unreduced_prod = pol_mul_lo(input0, input1); + let mut unreduced_prod = pol_mul_lo(left_in, right_in); for col in 0..N_LIMBS { let t = unreduced_prod[col] + cy; cy = t >> LIMB_BITS; @@ -115,17 +107,30 @@ pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { .copy_from_slice(&aux_limbs.map(|c| F::from_canonical_u16((c >> 16) as u16))); } -pub fn eval_packed_generic( +pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { + // TODO: It would probably be clearer/cleaner to read the U256 + // into an [i64;N] and then copy that to the lv table. + u256_to_array(&mut lv[INPUT_REGISTER_0], left_in); + u256_to_array(&mut lv[INPUT_REGISTER_1], right_in); + u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero()); + + let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_0); + let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_1); + + generate_mul(lv, input0, input1); +} + +pub(crate) fn eval_packed_generic_mul( lv: &[P; NUM_ARITH_COLUMNS], + filter: P, + left_in_limbs: [P; 16], + right_in_limbs: [P; 16], yield_constr: &mut ConstraintConsumer

, ) { - let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); - - let is_mul = lv[IS_MUL] + lv[IS_SHL]; - let input0_limbs = read_value::(lv, INPUT_REGISTER_0); - let input1_limbs = read_value::(lv, INPUT_REGISTER_1); let output_limbs = read_value::(lv, OUTPUT_REGISTER); + let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); + let aux_limbs = { // MUL_AUX_INPUT was offset by 2^20 in generation, so we undo // that here @@ -153,7 +158,7 @@ pub fn eval_packed_generic( // // s(x) = \sum_i aux_limbs[i] * x^i // - let mut constr_poly = pol_mul_lo(input0_limbs, input1_limbs); + let mut constr_poly = pol_mul_lo(left_in_limbs, right_in_limbs); pol_sub_assign(&mut constr_poly, &output_limbs); // This subtracts (x - β) * s(x) from constr_poly. @@ -164,18 +169,29 @@ pub fn eval_packed_generic( // multiplication is valid if and only if all of those // coefficients are zero. for &c in &constr_poly { - yield_constr.constraint(is_mul * c); + yield_constr.constraint(filter * c); } } -pub fn eval_ext_circuit, const D: usize>( - builder: &mut CircuitBuilder, - lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], - yield_constr: &mut RecursiveConstraintConsumer, +pub fn eval_packed_generic( + lv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, ) { - let is_mul = builder.add_extension(lv[IS_MUL], lv[IS_SHL]); + let is_mul = lv[IS_MUL]; let input0_limbs = read_value::(lv, INPUT_REGISTER_0); let input1_limbs = read_value::(lv, INPUT_REGISTER_1); + + eval_packed_generic_mul(lv, is_mul, input0_limbs, input1_limbs, yield_constr); +} + +pub(crate) fn eval_ext_mul_circuit, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + filter: ExtensionTarget, + left_in_limbs: [ExtensionTarget; 16], + right_in_limbs: [ExtensionTarget; 16], + yield_constr: &mut RecursiveConstraintConsumer, +) { let output_limbs = read_value::(lv, OUTPUT_REGISTER); let aux_limbs = { @@ -192,7 +208,7 @@ pub fn eval_ext_circuit, const D: usize>( aux_limbs }; - let mut constr_poly = pol_mul_lo_ext_circuit(builder, input0_limbs, input1_limbs); + let mut constr_poly = pol_mul_lo_ext_circuit(builder, left_in_limbs, right_in_limbs); pol_sub_assign_ext_circuit(builder, &mut constr_poly, &output_limbs); let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS)); @@ -200,11 +216,30 @@ pub fn eval_ext_circuit, const D: usize>( pol_sub_assign_ext_circuit(builder, &mut constr_poly, &rhs); for &c in &constr_poly { - let filter = builder.mul_extension(is_mul, c); + let filter = builder.mul_extension(filter, c); yield_constr.constraint(builder, filter); } } +pub fn eval_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_mul = lv[IS_MUL]; + let input0_limbs = read_value::(lv, INPUT_REGISTER_0); + let input1_limbs = read_value::(lv, INPUT_REGISTER_1); + + eval_ext_mul_circuit( + builder, + lv, + is_mul, + input0_limbs, + input1_limbs, + yield_constr, + ); +} + #[cfg(test)] mod tests { use plonky2::field::goldilocks_field::GoldilocksField; @@ -229,8 +264,6 @@ mod tests { // if `IS_MUL == 0`, then the constraints should be met even // if all values are garbage. 
lv[IS_MUL] = F::ZERO; - // Deactivate the SHL flag so that a MUL operation is not triggered. - lv[IS_SHL] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], diff --git a/evm/src/arithmetic/shift.rs b/evm/src/arithmetic/shift.rs new file mode 100644 index 00000000..6600c01e --- /dev/null +++ b/evm/src/arithmetic/shift.rs @@ -0,0 +1,338 @@ +//! Support for the EVM SHL and SHR instructions. +//! +//! This module verifies an EVM shift instruction, which takes two +//! 256-bit inputs S and A, and produces a 256-bit output C satisfying +//! +//! C = A << S (mod 2^256) for SHL or +//! C = A >> S (mod 2^256) for SHR. +//! +//! This computation is carried out by providing a third input +//! B = 1 << S (mod 2^256) +//! and then computing: +//! C = A * B (mod 2^256) for SHL or +//! C = A / B (mod 2^256) for SHR +//! +//! Inputs A, S, and B, and output C, are given as arrays of 16-bit +//! limbs. For example, if the limbs of A are a[0]...a[15], then +//! +//! A = \sum_{i=0}^{15} a[i] β^i, +//! +//! where β = 2^16 = 2^LIMB_BITS. To verify that A, S, B and C satisfy +//! the equations, we proceed similarly to MUL for SHL and to DIV for SHR. + +use ethereum_types::U256; +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::PrimeField64; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use super::{divmod, mul}; +use crate::arithmetic::columns::*; +use crate::arithmetic::utils::*; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +/// Generates a shift operation (either SHL or SHR). +/// The inputs are stored in the form `(shift, input, 1 << shift)`. +/// NB: if `shift >= 256`, then the third register holds 0. +/// We leverage the functions in mul.rs and divmod.rs to carry out +/// the computation. 
+pub fn generate( + lv: &mut [F], + nv: &mut [F], + is_shl: bool, + shift: U256, + input: U256, + result: U256, +) { + // We use the multiplication logic to generate SHL + // TODO: It would probably be clearer/cleaner to read the U256 + // into an [i64;N] and then copy that to the lv table. + // The first input is the shift we need to apply. + u256_to_array(&mut lv[INPUT_REGISTER_0], shift); + // The second register holds the input which needs shifting. + u256_to_array(&mut lv[INPUT_REGISTER_1], input); + u256_to_array(&mut lv[OUTPUT_REGISTER], result); + // If `shift >= 256`, the shifted displacement is set to 0. + // Compute 1 << shift and store it in the third input register. + let shifted_displacement = if shift > U256::from(255u64) { + U256::zero() + } else { + U256::one() << shift + }; + + u256_to_array(&mut lv[INPUT_REGISTER_2], shifted_displacement); + + let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_1); // input + let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_2); // 1 << shift + + if is_shl { + // We generate the multiplication input0 * input1 using mul.rs. + mul::generate_mul(lv, input0, input1); + } else { + // If the operation is SHR, we compute: `input / shifted_displacement` if `shifted_displacement != 0`; + // otherwise, the output is 0. We use the logic in divmod.rs to achieve that. + divmod::generate_divmod(lv, nv, IS_SHR, INPUT_REGISTER_1, INPUT_REGISTER_2); + } +} + +/// Evaluates the constraints for an SHL opcode. +/// The logic is the same as the one for MUL. The only difference is that +/// the inputs are in `INPUT_REGISTER_1` and `INPUT_REGISTER_2` instead of +/// `INPUT_REGISTER_0` and `INPUT_REGISTER_1`. +fn eval_packed_shl( + lv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let is_shl = lv[IS_SHL]; + let input0_limbs = read_value::(lv, INPUT_REGISTER_1); + let shifted_limbs = read_value::(lv, INPUT_REGISTER_2); + + mul::eval_packed_generic_mul(lv, is_shl, input0_limbs, shifted_limbs, yield_constr); +} + +/// Evaluates the constraints for an SHR opcode. +/// The logic is the same as the one for DIV. The only difference is that +/// the inputs are in `INPUT_REGISTER_1` and `INPUT_REGISTER_2` instead of +/// `INPUT_REGISTER_0` and `INPUT_REGISTER_1`. +fn eval_packed_shr( + lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let quo_range = OUTPUT_REGISTER; + let rem_range = AUX_INPUT_REGISTER_0; + let filter = lv[IS_SHR]; + + divmod::eval_packed_divmod_helper( + lv, + nv, + yield_constr, + filter, + INPUT_REGISTER_1, + INPUT_REGISTER_2, + quo_range, + rem_range, + ); +} + +pub fn eval_packed_generic( + lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + eval_packed_shl(lv, yield_constr); + eval_packed_shr(lv, nv, yield_constr); +} + +fn eval_ext_circuit_shl, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_shl = lv[IS_SHL]; + let input0_limbs = read_value::(lv, INPUT_REGISTER_1); + let shifted_limbs = read_value::(lv, INPUT_REGISTER_2); + + mul::eval_ext_mul_circuit( + builder, + lv, + is_shl, + input0_limbs, + shifted_limbs, + yield_constr, + ); +} + +fn eval_ext_circuit_shr, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let filter = lv[IS_SHR]; + let quo_range = OUTPUT_REGISTER; + let rem_range = AUX_INPUT_REGISTER_0; + + divmod::eval_ext_circuit_divmod_helper( + builder, + lv, + nv, + yield_constr, + filter, + INPUT_REGISTER_1, + INPUT_REGISTER_2, + quo_range, + rem_range, + ); +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + eval_ext_circuit_shl(builder, lv, yield_constr); + eval_ext_circuit_shr(builder, lv, nv, yield_constr); +} + +#[cfg(test)] +mod tests { + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::{Field, Sample}; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; + + use super::*; + use crate::arithmetic::columns::NUM_ARITH_COLUMNS; + use crate::constraint_consumer::ConstraintConsumer; + + const N_RND_TESTS: usize = 1000; + + // TODO: Should be able to refactor this test to apply to all operations. 
+ #[test] + fn generate_eval_consistency_not_shift() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + + // if `IS_SHL == 0` and `IS_SHR == 0`, then the constraints should be met even + // if all values are garbage. + lv[IS_SHL] = F::ZERO; + lv[IS_SHR] = F::ZERO; + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + + fn generate_eval_consistency_shift(is_shl: bool) { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + + // set `IS_SHL == 1` or `IS_SHR == 1` and ensure all constraints are satisfied. + if is_shl { + lv[IS_SHL] = F::ONE; + lv[IS_SHR] = F::ZERO; + } else { + // Set `IS_DIV` to 0 in this case, since we're using the logic of DIV for SHR. 
+ lv[IS_DIV] = F::ZERO; + lv[IS_SHL] = F::ZERO; + lv[IS_SHR] = F::ONE; + } + + for _i in 0..N_RND_TESTS { + let shift = U256::from(rng.gen::()); + + let mut full_input = U256::from(0); + // set inputs to random values + for ai in INPUT_REGISTER_1 { + lv[ai] = F::from_canonical_u16(rng.gen()); + full_input = + U256::from(lv[ai].to_canonical_u64()) + full_input * U256::from(1 << 16); + } + + let output = if is_shl { + full_input << shift + } else { + full_input >> shift + }; + + generate(&mut lv, &mut nv, is_shl, shift, full_input, output); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ZERO, + ); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + } + + #[test] + fn generate_eval_consistency_shl() { + generate_eval_consistency_shift(true); + } + + #[test] + fn generate_eval_consistency_shr() { + generate_eval_consistency_shift(false); + } + + fn generate_eval_consistency_shift_over_256(is_shl: bool) { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + + // set `IS_SHL == 1` or `IS_SHR == 1` and ensure all constraints are satisfied. + if is_shl { + lv[IS_SHL] = F::ONE; + lv[IS_SHR] = F::ZERO; + } else { + // Set `IS_DIV` to 0 in this case, since we're using the logic of DIV for SHR. 
+ lv[IS_DIV] = F::ZERO; + lv[IS_SHL] = F::ZERO; + lv[IS_SHR] = F::ONE; + } + + for _i in 0..N_RND_TESTS { + let mut shift = U256::from(rng.gen::()); + while shift > U256::MAX - 256 { + shift = U256::from(rng.gen::()); + } + shift += U256::from(256); + + let mut full_input = U256::from(0); + // set inputs to random values + for ai in INPUT_REGISTER_1 { + lv[ai] = F::from_canonical_u16(rng.gen()); + full_input = + U256::from(lv[ai].to_canonical_u64()) + full_input * U256::from(1 << 16); + } + + let output = 0.into(); + generate(&mut lv, &mut nv, is_shl, shift, full_input, output); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ZERO, + ); + eval_packed_generic(&lv, &nv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + } + + #[test] + fn generate_eval_consistency_shl_over_256() { + generate_eval_consistency_shift_over_256(true); + } + + #[test] + fn generate_eval_consistency_shr_over_256() { + generate_eval_consistency_shift_over_256(false); + } +} diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index f23ff308..82ca5452 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -63,19 +63,10 @@ fn ctl_data_binops() -> Vec> { /// one output of a ternary operation. By default, ternary operations use /// the first three memory channels, and the last one for the result (binary /// operations do not use the third inputs). -/// -/// Shift operations are different, as they are simulated with `MUL` or `DIV` -/// on the arithmetic side. We first convert the shift into the multiplicand -/// (in case of `SHL`) or the divisor (in case of `SHR`), making the first memory -/// channel not directly usable. 
We overcome this by adding an offset of 1 in -/// case of shift operations, which will skip the first memory channel and use the -/// next three as ternary inputs. Because both `MUL` and `DIV` are binary operations, -/// the last memory channel used for the inputs will be safely ignored. -fn ctl_data_ternops(is_shift: bool) -> Vec> { - let offset = is_shift as usize; - let mut res = Column::singles(COL_MAP.mem_channels[offset].value).collect_vec(); - res.extend(Column::singles(COL_MAP.mem_channels[offset + 1].value)); - res.extend(Column::singles(COL_MAP.mem_channels[offset + 2].value)); +fn ctl_data_ternops() -> Vec> { + let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec(); + res.extend(Column::singles(COL_MAP.mem_channels[1].value)); + res.extend(Column::singles(COL_MAP.mem_channels[2].value)); res.extend(Column::singles( COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, )); @@ -96,7 +87,7 @@ pub fn ctl_filter_logic() -> Column { pub fn ctl_arithmetic_base_rows() -> TableWithColumns { // Instead of taking single columns, we reconstruct the entire opcode value directly. let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; - columns.extend(ctl_data_ternops(false)); + columns.extend(ctl_data_ternops()); // Create the CPU Table whose columns are those with the three // inputs and one output of the ternary operations listed in `ops` // (also `ops` is used as the operation filter). The list of @@ -109,22 +100,11 @@ pub fn ctl_arithmetic_base_rows() -> TableWithColumns { COL_MAP.op.binary_op, COL_MAP.op.fp254_op, COL_MAP.op.ternary_op, + COL_MAP.op.shift, ])), ) } -pub fn ctl_arithmetic_shift_rows() -> TableWithColumns { - // Instead of taking single columns, we reconstruct the entire opcode value directly. 
- let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; - columns.extend(ctl_data_ternops(true)); - // Create the CPU Table whose columns are those with the three - // inputs and one output of the ternary operations listed in `ops` - // (also `ops` is used as the operation filter). The list of - // operations includes binary operations which will simply ignore - // the third input. - TableWithColumns::new(Table::Cpu, columns, Some(Column::single(COL_MAP.op.shift))) -} - pub fn ctl_data_byte_packing() -> Vec> { ctl_data_keccak_sponge() } diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 0620069f..568fe4b1 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -499,18 +499,12 @@ fn append_shift( channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); } - // Convert the shift, and log the corresponding arithmetic operation. - let input0 = if input0 > U256::from(255u64) { - U256::zero() - } else { - U256::one() << input0 - }; let operator = if is_shl { BinaryOperator::Shl } else { BinaryOperator::Shr }; - let operation = arithmetic::Operation::binary(operator, input1, input0); + let operation = arithmetic::Operation::binary(operator, input0, input1); state.traces.push_arithmetic(operation); state.traces.push_memory(log_in0);