From 40866e775aa707b8572571917b4eb35e2b40481b Mon Sep 17 00:00:00 2001 From: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com> Date: Fri, 10 Feb 2023 23:07:57 +1100 Subject: [PATCH] Refactor arithmetic operation traits (#876) * Use U256s in `generate(...)` interfaces; fix reduction bug modular. * Refactor `Operation` trait. * Rename file. * Rename `add_cc` things to `addcy`. * Clippy. * Simplify generation of less-than and greater-than. * Add some comparison tests. * Use `PrimeField64` instead of `RichField` where possible. * Connect `SUBMOD` operation to witness generator. * Add clippy exception. * Add missing verification of range counter column. * Fix generation of RANGE_COUNTER column. * Address William's PR comments. --- evm/src/arithmetic/{addcc.rs => addcy.rs} | 93 +++++------- evm/src/arithmetic/arithmetic_stark.rs | 117 +++++++++------ evm/src/arithmetic/columns.rs | 2 + evm/src/arithmetic/mod.rs | 131 +++++++++++++---- evm/src/arithmetic/modular.rs | 141 ++++++++++-------- evm/src/arithmetic/mul.rs | 14 +- evm/src/arithmetic/operations.rs | 166 ---------------------- evm/src/arithmetic/utils.rs | 42 ++++-- evm/src/cpu/columns/ops.rs | 3 +- evm/src/cpu/stack.rs | 1 + evm/src/witness/transition.rs | 4 + 11 files changed, 346 insertions(+), 368 deletions(-) rename evm/src/arithmetic/{addcc.rs => addcy.rs} (82%) delete mode 100644 evm/src/arithmetic/operations.rs diff --git a/evm/src/arithmetic/addcc.rs b/evm/src/arithmetic/addcy.rs similarity index 82% rename from evm/src/arithmetic/addcc.rs rename to evm/src/arithmetic/addcy.rs index ed173ec7..32fa4a9e 100644 --- a/evm/src/arithmetic/addcc.rs +++ b/evm/src/arithmetic/addcy.rs @@ -14,51 +14,19 @@ //! GT: X > Z, inputs X, Z, output CY, auxiliary output Y //! LT: Z < X, inputs Z, X, output CY, auxiliary output Y -use itertools::{izip, Itertools}; +use ethereum_types::U256; +use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; +use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::arithmetic::columns::*; -use crate::arithmetic::utils::read_value_u64_limbs; +use crate::arithmetic::utils::u256_to_array; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -fn u256_add_cc(input0: [u64; N_LIMBS], input1: [u64; N_LIMBS]) -> ([u64; N_LIMBS], u64) { - // Input and output have 16-bit limbs - let mut output = [0u64; N_LIMBS]; - - const MASK: u64 = (1u64 << LIMB_BITS) - 1u64; - let mut cy = 0u64; - for (i, a, b) in izip!(0.., input0, input1) { - let s = a + b + cy; - cy = s >> LIMB_BITS; - assert!(cy <= 1u64, "input limbs were larger than 16 bits"); - output[i] = s & MASK; - } - (output, cy) -} - -fn u256_sub_br(input0: [u64; N_LIMBS], input1: [u64; N_LIMBS]) -> ([u64; N_LIMBS], u64) { - const LIMB_BOUNDARY: u64 = 1 << LIMB_BITS; - const MASK: u64 = LIMB_BOUNDARY - 1u64; - - let mut output = [0u64; N_LIMBS]; - let mut br = 0u64; - for (i, a, b) in izip!(0.., input0, input1) { - let d = LIMB_BOUNDARY + a - b - br; - // if a < b, then d < 2^16 so br = 1 - // if a >= b, then d >= 2^16 so br = 0 - br = 1u64 - (d >> LIMB_BITS); - assert!(br <= 1u64, "input limbs were larger than 16 bits"); - output[i] = d & MASK; - } - - (output, br) -} - /// Generate row for ADD, SUB, GT and LT operations. /// /// A row consists of four values, GENERAL_REGISTER_[012] and @@ -69,27 +37,35 @@ fn u256_sub_br(input0: [u64; N_LIMBS], input1: [u64; N_LIMBS]) -> ([u64; N_LIMBS /// SUB: REGISTER_2 - REGISTER_0, output in REGISTER_1, ignore REGISTER_BIT /// GT: REGISTER_0 > REGISTER_2, output in REGISTER_BIT, auxiliary output in REGISTER_1 /// LT: REGISTER_2 < REGISTER_0, output in REGISTER_BIT, auxiliary output in REGISTER_1 -pub(crate) fn generate(lv: &mut [F], filter: usize) { +pub(crate) fn generate( + lv: &mut [F], + filter: usize, + left_in: U256, + right_in: U256, +) { + // Swap left_in and right_in for LT + let (left_in, right_in) = if filter == IS_LT { + (right_in, left_in) + } else { + (left_in, right_in) + }; + match filter { IS_ADD => { - let x = read_value_u64_limbs(lv, GENERAL_REGISTER_0); - let y = read_value_u64_limbs(lv, GENERAL_REGISTER_1); - // x + y == z + cy*2^256 - let (z, cy) = u256_add_cc(x, y); - - lv[GENERAL_REGISTER_2].copy_from_slice(&z.map(F::from_canonical_u64)); - lv[GENERAL_REGISTER_BIT] = F::from_canonical_u64(cy); + let (result, cy) = left_in.overflowing_add(right_in); + u256_to_array(&mut lv[GENERAL_REGISTER_0], left_in); // x + u256_to_array(&mut lv[GENERAL_REGISTER_1], right_in); // y + u256_to_array(&mut lv[GENERAL_REGISTER_2], result); // z + lv[GENERAL_REGISTER_BIT] = F::from_bool(cy); } IS_SUB | IS_GT | IS_LT => { - let x = read_value_u64_limbs(lv, GENERAL_REGISTER_0); - let z = read_value_u64_limbs(lv, GENERAL_REGISTER_2); - // y == z - x + cy*2^256 - let (y, cy) = u256_sub_br(z, x); - - lv[GENERAL_REGISTER_1].copy_from_slice(&y.map(F::from_canonical_u64)); - lv[GENERAL_REGISTER_BIT] = F::from_canonical_u64(cy); + let (diff, cy) = right_in.overflowing_sub(left_in); + u256_to_array(&mut lv[GENERAL_REGISTER_0], left_in); // x + u256_to_array(&mut lv[GENERAL_REGISTER_2], right_in); // z + u256_to_array(&mut lv[GENERAL_REGISTER_1], diff); // y + lv[GENERAL_REGISTER_BIT] = F::from_bool(cy); } _ => panic!("unexpected operation filter"), }; @@ -144,7 +120,7 @@ const GOLDILOCKS_INVERSE_65536: u64 = 18446462594437939201; /// is true if `(x_n + y_n)*2^(16*n) == cy_{n-1}*2^(16*n) + /// z_n*2^(16*n) + cy_n*2^(16*n)` (again, this is `t` on line 127ff) /// with the last `cy_n` checked against the `given_cy` given as input. -pub(crate) fn eval_packed_generic_add_cc( +pub(crate) fn eval_packed_generic_addcy( yield_constr: &mut ConstraintConsumer

, filter: P, x: &[P], @@ -202,11 +178,11 @@ pub fn eval_packed_generic( eval_packed_generic_check_is_one_bit(yield_constr, op_filter, cy); // x + y = z + cy*2^256 - eval_packed_generic_add_cc(yield_constr, op_filter, x, y, z, cy, false); + eval_packed_generic_addcy(yield_constr, op_filter, x, y, z, cy, false); } #[allow(clippy::needless_collect)] -pub(crate) fn eval_ext_circuit_add_cc, const D: usize>( +pub(crate) fn eval_ext_circuit_addcy, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, @@ -272,7 +248,7 @@ pub fn eval_ext_circuit, const D: usize>( let op_filter = builder.add_many_extension([is_add, is_sub, is_lt, is_gt]); eval_ext_circuit_check_is_one_bit(builder, yield_constr, op_filter, cy); - eval_ext_circuit_add_cc(builder, yield_constr, op_filter, x, y, z, cy, false); + eval_ext_circuit_addcy(builder, yield_constr, op_filter, x, y, z, cy, false); } #[cfg(test)] @@ -328,7 +304,7 @@ mod tests { .map(|_| F::from_canonical_u16(rng.gen::())); // set operation filter and ensure all constraints are - // satisfied. we have to explicitly set the other + // satisfied. We have to explicitly set the other // operation filters to zero since all are treated by // the call. lv[IS_ADD] = F::ZERO; @@ -337,7 +313,10 @@ mod tests { lv[IS_GT] = F::ZERO; lv[op_filter] = F::ONE; - generate(&mut lv, op_filter); + let left_in = U256::from(rng.gen::<[u8; 32]>()); + let right_in = U256::from(rng.gen::<[u8; 32]>()); + + generate(&mut lv, op_filter, left_in, right_in); let mut constrant_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 9af1a9c3..0a89f3c6 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -4,11 +4,12 @@ use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::util::transpose; -use crate::arithmetic::operations::Operation; -use crate::arithmetic::{addcc, columns, modular, mul}; +use crate::arithmetic::{addcy, columns, modular, mul, Operation}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::permutation::PermutationPair; @@ -33,6 +34,9 @@ impl ArithmeticStark { for i in 0..RANGE_MAX { cols[columns::RANGE_COUNTER][i] = F::from_canonical_usize(i); } + for i in RANGE_MAX..n_rows { + cols[columns::RANGE_COUNTER][i] = F::from_canonical_usize(RANGE_MAX - 1); + } // For each column c in cols, generate the range-check // permutations and put them in the corresponding range-check @@ -44,7 +48,8 @@ impl ArithmeticStark { } } - pub fn generate(&self, operations: Vec<&dyn Operation>) -> Vec> { + #[allow(unused)] + pub(crate) fn generate(&self, operations: Vec) -> Vec> { // The number of rows reserved is the smallest value that's // guaranteed to avoid a reallocation: The only ops that use // two rows are the modular operations and DIV, so the only @@ -96,14 +101,25 @@ impl, const D: usize> Stark for ArithmeticSta let lv = vars.local_values; let nv = vars.next_values; + // Check the range column: First value must be 0, last row + // must be 2^16-1, and intermediate rows must increment by 0 + // or 1. + let rc1 = lv[columns::RANGE_COUNTER]; + let rc2 = nv[columns::RANGE_COUNTER]; + yield_constr.constraint_first_row(rc1); + let incr = rc2 - rc1; + yield_constr.constraint_transition(incr * incr - incr); + let range_max = P::Scalar::from_canonical_u64((RANGE_MAX - 1) as u64); + yield_constr.constraint_last_row(rc1 - range_max); + mul::eval_packed_generic(lv, yield_constr); - addcc::eval_packed_generic(lv, yield_constr); + addcy::eval_packed_generic(lv, yield_constr); modular::eval_packed_generic(lv, nv, yield_constr); } fn eval_ext_circuit( &self, - builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + builder: &mut CircuitBuilder, vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { @@ -114,8 +130,20 @@ impl, const D: usize> Stark for ArithmeticSta let lv = vars.local_values; let nv = vars.next_values; + + let rc1 = lv[columns::RANGE_COUNTER]; + let rc2 = nv[columns::RANGE_COUNTER]; + yield_constr.constraint_first_row(builder, rc1); + let incr = builder.sub_extension(rc2, rc1); + let t = builder.mul_sub_extension(incr, incr, incr); + yield_constr.constraint_transition(builder, t); + let range_max = + builder.constant_extension(F::Extension::from_canonical_usize(RANGE_MAX - 1)); + let t = builder.sub_extension(rc1, range_max); + yield_constr.constraint_last_row(builder, t); + mul::eval_ext_circuit(builder, lv, yield_constr); - addcc::eval_ext_circuit(builder, lv, yield_constr); + addcy::eval_ext_circuit(builder, lv, yield_constr); modular::eval_ext_circuit(builder, lv, nv, yield_constr); } @@ -148,7 +176,7 @@ mod tests { use rand_chacha::ChaCha8Rng; use super::{columns, ArithmeticStark}; - use crate::arithmetic::operations::*; + use crate::arithmetic::*; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; #[test] @@ -189,34 +217,37 @@ mod tests { }; // 123 + 456 == 579 - let add = SimpleBinaryOp::new(columns::IS_ADD, U256::from(123), U256::from(456)); + let add = Operation::binary(BinaryOperator::Add, U256::from(123), U256::from(456)); // (123 * 456) % 1007 == 703 - let mulmod = ModularBinaryOp::new( - columns::IS_MULMOD, + let mulmod = Operation::ternary( + TernaryOperator::MulMod, U256::from(123), U256::from(456), U256::from(1007), ); - // (123 - 456) % 1007 == 674 - let submod = ModularBinaryOp::new( - columns::IS_SUBMOD, - U256::from(123), - U256::from(456), + // (1234 + 567) % 1007 == 794 + let addmod = Operation::ternary( + TernaryOperator::AddMod, + U256::from(1234), + U256::from(567), U256::from(1007), ); // 123 * 456 == 56088 - let mul = SimpleBinaryOp::new(columns::IS_MUL, U256::from(123), U256::from(456)); - // 128 % 13 == 11 - let modop = ModOp { - input: U256::from(128), - modulus: U256::from(13), - }; + let mul = Operation::binary(BinaryOperator::Mul, U256::from(123), U256::from(456)); // 128 / 13 == 9 - let div = DivOp { - numerator: U256::from(128), - denominator: U256::from(13), - }; - let ops: Vec<&dyn Operation> = vec![&add, &mulmod, &submod, &mul, &div, &modop]; + let div = Operation::binary(BinaryOperator::Div, U256::from(128), U256::from(13)); + + // 128 < 13 == 0 + let lt1 = Operation::binary(BinaryOperator::Lt, U256::from(128), U256::from(13)); + // 13 < 128 == 1 + let lt2 = Operation::binary(BinaryOperator::Lt, U256::from(13), U256::from(128)); + // 128 < 128 == 0 + let lt3 = Operation::binary(BinaryOperator::Lt, U256::from(128), U256::from(128)); + + // 128 % 13 == 11 + let modop = Operation::binary(BinaryOperator::Mod, U256::from(128), U256::from(13)); + + let ops: Vec = vec![add, mulmod, addmod, mul, modop, lt1, lt2, lt3, div]; let pols = stark.generate(ops); @@ -228,15 +259,21 @@ mod tests { && pols.iter().all(|v| v.len() == super::RANGE_MAX) ); + // Wrap the single value GENERAL_REGISTER_BIT in a Range. + let cmp_range = columns::GENERAL_REGISTER_BIT..columns::GENERAL_REGISTER_BIT + 1; + // Each operation has a single word answer that we can check let expected_output = [ // Row (some ops take two rows), col, expected - (0, columns::GENERAL_REGISTER_2, 579), // ADD_OUTPUT - (1, columns::MODULAR_OUTPUT, 703), - (3, columns::MODULAR_OUTPUT, 674), - (5, columns::MUL_OUTPUT, 56088), - (6, columns::MODULAR_OUTPUT, 11), - (8, columns::DIV_OUTPUT, 9), + (0, &columns::GENERAL_REGISTER_2, 579), // ADD_OUTPUT + (1, &columns::MODULAR_OUTPUT, 703), + (3, &columns::MODULAR_OUTPUT, 794), + (5, &columns::MUL_OUTPUT, 56088), + (6, &columns::MODULAR_OUTPUT, 11), + (8, &cmp_range, 0), + (9, &cmp_range, 1), + (10, &cmp_range, 0), + (11, &columns::DIV_OUTPUT, 9), ]; for (row, col, expected) in expected_output { @@ -269,18 +306,14 @@ mod tests { let ops = (0..super::RANGE_MAX) .map(|_| { - SimpleBinaryOp::new( - columns::IS_MUL, + Operation::binary( + BinaryOperator::Mul, U256::from(rng.gen::<[u8; 32]>()), U256::from(rng.gen::<[u8; 32]>()), ) }) .collect::>(); - // TODO: This is clearly not the right way to build this - // vector; I can't work out how to do it using the map above - // though, with or without Boxes. - let ops = ops.iter().map(|o| o as &dyn Operation).collect(); let pols = stark.generate(ops); // Trace should always have NUM_ARITH_COLUMNS columns and @@ -293,8 +326,8 @@ mod tests { let ops = (0..super::RANGE_MAX) .map(|_| { - ModularBinaryOp::new( - columns::IS_MULMOD, + Operation::ternary( + TernaryOperator::MulMod, U256::from(rng.gen::<[u8; 32]>()), U256::from(rng.gen::<[u8; 32]>()), U256::from(rng.gen::<[u8; 32]>()), @@ -302,10 +335,6 @@ mod tests { }) .collect::>(); - // TODO: This is clearly not the right way to build this - // vector; I can't work out how to do it using the map above - // though, with or without Boxes. - let ops = ops.iter().map(|o| o as &dyn Operation).collect(); let pols = stark.generate(ops); // Trace should always have NUM_ARITH_COLUMNS columns and diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index e05a4070..952a8ed5 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -101,7 +101,9 @@ pub(crate) const MODULAR_AUX_INPUT_HI: Range = AUX_REGISTER_2; // Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV] pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end; +#[allow(unused)] // TODO: Will be used when hooking into the CPU pub(crate) const DIV_NUMERATOR: Range = MODULAR_INPUT_0; +#[allow(unused)] // TODO: Will be used when hooking into the CPU pub(crate) const DIV_DENOMINATOR: Range = MODULAR_MODULUS; #[allow(unused)] // TODO: Will be used when hooking into the CPU pub(crate) const DIV_OUTPUT: Range = diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index 60c8a2f8..6ba8ed12 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -1,10 +1,9 @@ -use std::str::FromStr; - use ethereum_types::U256; +use plonky2::field::types::PrimeField64; use crate::util::{addmod, mulmod, submod}; -mod addcc; +mod addcy; mod modular; mod mul; mod utils; @@ -12,8 +11,6 @@ mod utils; pub mod arithmetic_stark; pub(crate) mod columns; -pub mod operations; - #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum BinaryOperator { Add, @@ -48,31 +45,36 @@ impl BinaryOperator { input0 % input1 } } - BinaryOperator::Lt => { - if input0 < input1 { - U256::one() - } else { - U256::zero() - } - } - BinaryOperator::Gt => { - if input0 > input1 { - U256::one() - } else { - U256::zero() - } - } - BinaryOperator::AddFp254 => addmod(input0, input1, bn_base_order()), - BinaryOperator::MulFp254 => mulmod(input0, input1, bn_base_order()), - BinaryOperator::SubFp254 => submod(input0, input1, bn_base_order()), + BinaryOperator::Lt => U256::from((input0 < input1) as u8), + BinaryOperator::Gt => U256::from((input0 > input1) as u8), + BinaryOperator::AddFp254 => addmod(input0, input1, BN_BASE_ORDER), + BinaryOperator::MulFp254 => mulmod(input0, input1, BN_BASE_ORDER), + BinaryOperator::SubFp254 => submod(input0, input1, BN_BASE_ORDER), + } + } + + pub(crate) fn row_filter(&self) -> usize { + match self { + BinaryOperator::Add => columns::IS_ADD, + BinaryOperator::Mul => columns::IS_MUL, + BinaryOperator::Sub => columns::IS_SUB, + BinaryOperator::Div => columns::IS_DIV, + BinaryOperator::Mod => columns::IS_MOD, + BinaryOperator::Lt => columns::IS_LT, + BinaryOperator::Gt => columns::IS_GT, + BinaryOperator::AddFp254 => columns::IS_ADDMOD, + BinaryOperator::MulFp254 => columns::IS_MULMOD, + BinaryOperator::SubFp254 => columns::IS_SUBMOD, } } } +#[allow(clippy::enum_variant_names)] #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum TernaryOperator { AddMod, MulMod, + SubMod, } impl TernaryOperator { @@ -80,6 +82,15 @@ impl TernaryOperator { match self { TernaryOperator::AddMod => addmod(input0, input1, input2), TernaryOperator::MulMod => mulmod(input0, input1, input2), + TernaryOperator::SubMod => submod(input0, input1, input2), + } + } + + pub(crate) fn row_filter(&self) -> usize { + match self { + TernaryOperator::AddMod => columns::IS_ADDMOD, + TernaryOperator::MulMod => columns::IS_MULMOD, + TernaryOperator::SubMod => columns::IS_SUBMOD, } } } @@ -135,8 +146,80 @@ impl Operation { Operation::TernaryOperation { result, .. } => *result, } } + + /// Convert operation into one or two rows of the trace. + /// + /// Morally these types should be [F; NUM_ARITH_COLUMNS], but we + /// use vectors because that's what utils::transpose (who consumes + /// the result of this function as part of the range check code) + /// expects. + fn to_rows(&self) -> (Vec, Option>) { + match *self { + Operation::BinaryOperation { + operator, + input0, + input1, + result, + } => binary_op_to_rows(operator, input0, input1, result), + Operation::TernaryOperation { + operator, + input0, + input1, + input2, + result, + } => ternary_op_to_rows(operator.row_filter(), input0, input1, input2, result), + } + } } -fn bn_base_order() -> U256 { - U256::from_str("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47").unwrap() +fn ternary_op_to_rows( + row_filter: usize, + input0: U256, + input1: U256, + input2: U256, + _result: U256, +) -> (Vec, Option>) { + let mut row1 = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + let mut row2 = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + + row1[row_filter] = F::ONE; + + modular::generate(&mut row1, &mut row2, row_filter, input0, input1, input2); + + (row1, Some(row2)) } + +fn binary_op_to_rows( + op: BinaryOperator, + input0: U256, + input1: U256, + result: U256, +) -> (Vec, Option>) { + let mut row = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + row[op.row_filter()] = F::ONE; + + match op { + BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Lt | BinaryOperator::Gt => { + addcy::generate(&mut row, op.row_filter(), input0, input1); + (row, None) + } + BinaryOperator::Mul => { + mul::generate(&mut row, input0, input1); + (row, None) + } + BinaryOperator::Div | BinaryOperator::Mod => { + ternary_op_to_rows::(op.row_filter(), input0, U256::zero(), input1, result) + } + BinaryOperator::AddFp254 | BinaryOperator::MulFp254 | BinaryOperator::SubFp254 => { + ternary_op_to_rows::(op.row_filter(), input0, input1, BN_BASE_ORDER, result) + } + } +} + +/// Order of the BN254 base field. +const BN_BASE_ORDER: U256 = U256([ + 4332616871279656263, + 10917124144477883021, + 13281191951274694749, + 3486998266802970665, +]); diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 799a4ad4..99eccacb 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -108,17 +108,18 @@ //! only require 96 columns, or 80 if the output doesn't need to be //! reduced. +use ethereum_types::U256; use num::bigint::Sign; use num::{BigInt, One, Zero}; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; +use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use super::columns; -use crate::arithmetic::addcc::{eval_ext_circuit_add_cc, eval_packed_generic_add_cc}; +use crate::arithmetic::addcy::{eval_ext_circuit_addcy, eval_packed_generic_addcy}; use crate::arithmetic::columns::*; use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; @@ -189,7 +190,7 @@ fn bigint_to_columns(num: &BigInt) -> [i64; N] { /// /// NB: `operation` can set the higher order elements in its result to /// zero if they are not used. -fn generate_modular_op( +fn generate_modular_op( lv: &mut [F], nv: &mut [F], filter: usize, @@ -213,6 +214,13 @@ fn generate_modular_op( let mut constr_poly = [0i64; 2 * N_LIMBS]; constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&operation(input0_limbs, input1_limbs)); + // two_exp_256 == 2^256 + let two_exp_256 = { + let mut t = BigInt::zero(); + t.set_bit(256, true); + t + }; + let mut mod_is_zero = F::ZERO; if modulus.is_zero() { if filter == columns::IS_DIV { @@ -242,8 +250,8 @@ fn generate_modular_op( let quot = (&input - &output) / &modulus; // exact division; can be -ve let quot_limbs = bigint_to_columns::<{ 2 * N_LIMBS }>("); - // output < modulus here, so the proof requires (modulus - output). - let out_aux_red = bigint_to_columns::(&(modulus - output)); + // output < modulus here; the proof requires (output - modulus) % 2^256: + let out_aux_red = bigint_to_columns::(&(two_exp_256 - modulus + output)); // constr_poly is the array of coefficients of the polynomial // @@ -283,8 +291,20 @@ fn generate_modular_op( /// Generate the output and auxiliary values for modular operations. /// /// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`. -pub(crate) fn generate(lv: &mut [F], nv: &mut [F], filter: usize) { +pub(crate) fn generate( + lv: &mut [F], + nv: &mut [F], + filter: usize, + input0: U256, + input1: U256, + modulus: U256, +) { debug_assert!(lv.len() == NUM_ARITH_COLUMNS && nv.len() == NUM_ARITH_COLUMNS); + + u256_to_array(&mut lv[MODULAR_INPUT_0], input0); + u256_to_array(&mut lv[MODULAR_INPUT_1], input1); + u256_to_array(&mut lv[MODULAR_MODULUS], modulus); + match filter { columns::IS_ADDMOD => generate_modular_op(lv, nv, filter, pol_add), columns::IS_SUBMOD => generate_modular_op(lv, nv, filter, pol_sub), @@ -332,30 +352,30 @@ fn modular_constr_poly( yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero)); // Needed to compensate for adding mod_is_zero to modulus above, - // since the call eval_packed_generic_add_cc() below subtracts modulus + // since the call eval_packed_generic_addcy() below subtracts modulus // to verify in the case of a DIV. output[0] += div_denom_is_zero; // Verify that the output is reduced, i.e. output < modulus. let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; - // This sets is_greater_than to 0 unless we get mod_is_zero when - // doing a DIV; in that case, we need is_greater_than=1, since - // eval_packed_generic_add_cc checks + // This sets is_less_than to 1 unless we get mod_is_zero when + // doing a DIV; in that case, we need is_less_than=0, since + // eval_packed_generic_addcy checks // - // output + out_aux_red == modulus + is_greater_than*2^256 + // modulus + out_aux_red == output + is_less_than*2^256 // - // and we were given output = out_aux_red - let is_greater_than = mod_is_zero * lv[IS_DIV]; + // and we are given output = out_aux_red when modulus is zero. + let is_less_than = P::ONES - mod_is_zero * lv[IS_DIV]; // NB: output and modulus in lv while out_aux_red and - // is_greater_than (via mod_is_zero) depend on nv, hence the + // is_less_than (via mod_is_zero) depend on nv, hence the // 'is_two_row_op' argument is set to 'true'. - eval_packed_generic_add_cc( + eval_packed_generic_addcy( yield_constr, filter, - &output, - out_aux_red, &modulus, - is_greater_than, + out_aux_red, + &output, + is_less_than, true, ); // restore output[0] @@ -483,16 +503,18 @@ fn modular_constr_poly_ext_circuit, const D: usize> output[0] = builder.add_extension(output[0], div_denom_is_zero); let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; - let is_greater_than = builder.mul_extension(mod_is_zero, lv[IS_DIV]); + let one = builder.one_extension(); + let is_less_than = + builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one); - eval_ext_circuit_add_cc( + eval_ext_circuit_addcy( builder, yield_constr, filter, - &output, - out_aux_red, &modulus, - is_greater_than, + out_aux_red, + &output, + is_less_than, true, ); output[0] = builder.sub_extension(output[0], div_denom_is_zero); @@ -574,7 +596,6 @@ pub(crate) fn eval_ext_circuit, const D: usize>( #[cfg(test)] mod tests { - use itertools::izip; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::{Field, Sample}; use rand::{Rng, SeedableRng}; @@ -620,38 +641,40 @@ mod tests { type F = GoldilocksField; let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); - let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); - let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); for op_filter in [IS_ADDMOD, IS_DIV, IS_SUBMOD, IS_MOD, IS_MULMOD] { - // Reset operation columns, then select one - lv[IS_ADDMOD] = F::ZERO; - lv[IS_SUBMOD] = F::ZERO; - lv[IS_MULMOD] = F::ZERO; - lv[IS_MOD] = F::ZERO; - lv[IS_DIV] = F::ZERO; - lv[op_filter] = F::ONE; - for i in 0..N_RND_TESTS { // set inputs to random values - for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) { - lv[ai] = F::from_canonical_u16(rng.gen()); - lv[bi] = F::from_canonical_u16(rng.gen()); - lv[mi] = F::from_canonical_u16(rng.gen()); - } + let mut lv = [F::default(); NUM_ARITH_COLUMNS] + .map(|_| F::from_canonical_u16(rng.gen::())); + let mut nv = [F::default(); NUM_ARITH_COLUMNS] + .map(|_| F::from_canonical_u16(rng.gen::())); + // Reset operation columns, then select one + lv[IS_ADDMOD] = F::ZERO; + lv[IS_SUBMOD] = F::ZERO; + lv[IS_MULMOD] = F::ZERO; + lv[IS_MOD] = F::ZERO; + lv[IS_DIV] = F::ZERO; + lv[op_filter] = F::ONE; + + let input0 = U256::from(rng.gen::<[u8; 32]>()); + let input1 = U256::from(rng.gen::<[u8; 32]>()); + + let mut modulus_limbs = [0u8; 32]; // For the second half of the tests, set the top // 16-start digits of the modulus to zero so it is // much smaller than the inputs. if i > N_RND_TESTS / 2 { // 1 <= start < N_LIMBS - let start = (rng.gen::() % (N_LIMBS - 1)) + 1; - for mi in MODULAR_MODULUS.skip(start) { - lv[mi] = F::ZERO; + let start = (rng.gen::() % (modulus_limbs.len() - 1)) + 1; + for mi in modulus_limbs.iter_mut().skip(start) { + *mi = 0u8; } } + let modulus = U256::from(modulus_limbs); - generate(&mut lv, &mut nv, op_filter); + generate(&mut lv, &mut nv, op_filter, input0, input1, modulus); let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], @@ -672,29 +695,29 @@ mod tests { type F = GoldilocksField; let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); - let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); - let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_DIV, IS_MOD, IS_MULMOD] { - // Reset operation columns, then select one - lv[IS_ADDMOD] = F::ZERO; - lv[IS_SUBMOD] = F::ZERO; - lv[IS_MULMOD] = F::ZERO; - lv[IS_MOD] = F::ZERO; - lv[IS_DIV] = F::ZERO; - lv[op_filter] = F::ONE; - for _i in 0..N_RND_TESTS { // set inputs to random values and the modulus to zero; // the output is defined to be zero when modulus is zero. + let mut lv = [F::default(); NUM_ARITH_COLUMNS] + .map(|_| F::from_canonical_u16(rng.gen::())); + let mut nv = [F::default(); NUM_ARITH_COLUMNS] + .map(|_| F::from_canonical_u16(rng.gen::())); - for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) { - lv[ai] = F::from_canonical_u16(rng.gen()); - lv[bi] = F::from_canonical_u16(rng.gen()); - lv[mi] = F::ZERO; - } + // Reset operation columns, then select one + lv[IS_ADDMOD] = F::ZERO; + lv[IS_SUBMOD] = F::ZERO; + lv[IS_MULMOD] = F::ZERO; + lv[IS_MOD] = F::ZERO; + lv[IS_DIV] = F::ZERO; + lv[op_filter] = F::ONE; - generate(&mut lv, &mut nv, op_filter); + let input0 = U256::from(rng.gen::<[u8; 32]>()); + let input1 = U256::from(rng.gen::<[u8; 32]>()); + let modulus = U256::zero(); + + generate(&mut lv, &mut nv, op_filter, input0, input1, modulus); // check that the correct output was generated if op_filter == IS_DIV { diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index f12b407a..03acfa97 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -55,9 +55,10 @@ //! file `modular.rs`), we don't need to check that output is reduced, //! since any value of output is less than β^16 and is hence reduced. +use ethereum_types::U256; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; +use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -66,7 +67,12 @@ use crate::arithmetic::columns::*; use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -pub fn generate(lv: &mut [F]) { +pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { + // TODO: It would probably be clearer/cleaner to read the U256 + // into an [i64;N] and then copy that to the lv table. + u256_to_array(&mut lv[MUL_INPUT_0], left_in); + u256_to_array(&mut lv[MUL_INPUT_1], right_in); + let input0 = read_value_i64_limbs(lv, MUL_INPUT_0); let input1 = read_value_i64_limbs(lv, MUL_INPUT_1); @@ -252,7 +258,9 @@ mod tests { lv[bi] = F::from_canonical_u16(rng.gen()); } - generate(&mut lv); + let left_in = U256::from(rng.gen::<[u8; 32]>()); + let right_in = U256::from(rng.gen::<[u8; 32]>()); + generate(&mut lv, left_in, right_in); let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], diff --git a/evm/src/arithmetic/operations.rs b/evm/src/arithmetic/operations.rs deleted file mode 100644 index f8aed566..00000000 --- a/evm/src/arithmetic/operations.rs +++ /dev/null @@ -1,166 +0,0 @@ -use ethereum_types::U256; -use plonky2::hash::hash_types::RichField; -use static_assertions::const_assert; - -use crate::arithmetic::columns::*; -use crate::arithmetic::{addcc, modular, mul}; - -#[inline] -fn u64_to_array(out: &mut [F], x: u64) { - const_assert!(LIMB_BITS == 16); - debug_assert!(out.len() == 4); - - out[0] = F::from_canonical_u16(x as u16); - out[1] = F::from_canonical_u16((x >> 16) as u16); - out[2] = F::from_canonical_u16((x >> 32) as u16); - out[3] = F::from_canonical_u16((x >> 48) as u16); -} - -fn u256_to_array(out: &mut [F], x: U256) { - const_assert!(N_LIMBS == 16); - debug_assert!(out.len() == N_LIMBS); - - u64_to_array(&mut out[0..4], x.0[0]); - u64_to_array(&mut out[4..8], x.0[1]); - u64_to_array(&mut out[8..12], x.0[2]); - u64_to_array(&mut out[12..16], x.0[3]); -} - -pub trait Operation { - /// Convert operation into one or two rows of the trace. - /// - /// Morally these types should be [F; NUM_ARITH_COLUMNS], but we - /// use vectors because that's what utils::transpose expects. - fn to_rows(&self) -> (Vec, Option>); -} - -pub struct SimpleBinaryOp { - /// The operation is identified using the associated filter from - /// `columns::IS_ADD` etc., stored in `op_filter`. - op_filter: usize, - input0: U256, - input1: U256, -} - -impl SimpleBinaryOp { - pub fn new(op_filter: usize, input0: U256, input1: U256) -> Self { - assert!( - op_filter == IS_ADD - || op_filter == IS_SUB - || op_filter == IS_MUL - || op_filter == IS_LT - || op_filter == IS_GT - ); - Self { - op_filter, - input0, - input1, - } - } -} - -impl Operation for SimpleBinaryOp { - fn to_rows(&self) -> (Vec, Option>) { - let mut row = vec![F::ZERO; NUM_ARITH_COLUMNS]; - row[self.op_filter] = F::ONE; - - if self.op_filter == IS_SUB || self.op_filter == IS_GT { - u256_to_array(&mut row[GENERAL_REGISTER_2], self.input0); - u256_to_array(&mut row[GENERAL_REGISTER_0], self.input1); - } else if self.op_filter == IS_LT { - u256_to_array(&mut row[GENERAL_REGISTER_0], self.input0); - u256_to_array(&mut row[GENERAL_REGISTER_2], self.input1); - } else { - assert!( - self.op_filter == IS_ADD || self.op_filter == IS_MUL, - "unrecognised operation" - ); - u256_to_array(&mut row[GENERAL_REGISTER_0], self.input0); - u256_to_array(&mut row[GENERAL_REGISTER_1], self.input1); - } - - if self.op_filter == IS_MUL { - mul::generate(&mut row); - } else { - addcc::generate(&mut row, self.op_filter); - } - (row, None) - } -} - -pub struct ModularBinaryOp { - op_filter: usize, - input0: U256, - input1: U256, - modulus: U256, -} - -impl ModularBinaryOp { - pub fn new(op_filter: usize, input0: U256, input1: U256, modulus: U256) -> Self { - assert!(op_filter == IS_ADDMOD || op_filter == IS_SUBMOD || op_filter == IS_MULMOD); - Self { - op_filter, - input0, - input1, - modulus, - } - } -} - -fn modular_to_rows_helper( - op_filter: usize, - input0: U256, - input1: U256, - modulus: U256, -) -> (Vec, Option>) { - let mut row1 = vec![F::ZERO; NUM_ARITH_COLUMNS]; - let mut row2 = vec![F::ZERO; NUM_ARITH_COLUMNS]; - - row1[op_filter] = F::ONE; - - u256_to_array(&mut row1[MODULAR_INPUT_0], input0); - u256_to_array(&mut row1[MODULAR_INPUT_1], input1); - u256_to_array(&mut row1[MODULAR_MODULUS], modulus); - - modular::generate(&mut row1, &mut row2, op_filter); - - (row1, Some(row2)) -} - -impl Operation for ModularBinaryOp { - fn to_rows(&self) -> (Vec, Option>) { - modular_to_rows_helper(self.op_filter, self.input0, self.input1, self.modulus) - } -} - -pub struct ModOp { - pub input: U256, - pub modulus: U256, -} - -impl Operation for ModOp { - fn to_rows(&self) -> (Vec, Option>) { - modular_to_rows_helper(IS_MOD, self.input, U256::zero(), self.modulus) - } -} - -pub struct DivOp { - pub numerator: U256, - pub denominator: U256, -} - -impl Operation for DivOp { - fn to_rows(&self) -> (Vec, Option>) { - let mut row1 = vec![F::ZERO; NUM_ARITH_COLUMNS]; - let mut row2 = vec![F::ZERO; NUM_ARITH_COLUMNS]; - - row1[IS_DIV] = F::ONE; - - u256_to_array(&mut row1[DIV_NUMERATOR], self.numerator); - u256_to_array(&mut row1[DIV_DENOMINATOR], self.denominator); - - modular::generate(&mut row1, &mut row2, IS_DIV); - - (row1, Some(row2)) - } -} diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index 7eb33099..8b4f546a 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ -1,11 +1,14 @@ use std::ops::{Add, AddAssign, Mul, Neg, Range, Shr, Sub, SubAssign}; +use ethereum_types::U256; use plonky2::field::extension::Extendable; +use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; +use static_assertions::const_assert; -use crate::arithmetic::columns::N_LIMBS; +use crate::arithmetic::columns::{LIMB_BITS, N_LIMBS}; /// Return an array of `N` zeros of type T. pub(crate) fn pol_zero() -> [T; N] @@ -315,24 +318,35 @@ pub(crate) fn read_value(lv: &[T], value_idxs: Range( - lv: &[F], - value_idxs: Range, -) -> [u64; N] { - let limbs: [_; N] = lv[value_idxs].try_into().unwrap(); - limbs.map(|c| F::to_canonical_u64(&c)) -} - /// Read the range `value_idxs` of values from `lv` into an array of /// length `N`, interpreting the values as `i64`s. Panics if the /// length of the range is not `N`. -pub(crate) fn read_value_i64_limbs( +pub(crate) fn read_value_i64_limbs( lv: &[F], value_idxs: Range, ) -> [i64; N] { let limbs: [_; N] = lv[value_idxs].try_into().unwrap(); - limbs.map(|c| F::to_canonical_u64(&c) as i64) + limbs.map(|c| c.to_canonical_u64() as i64) +} + +#[inline] +fn u64_to_array(out: &mut [F], x: u64) { + const_assert!(LIMB_BITS == 16); + debug_assert!(out.len() == 4); + + out[0] = F::from_canonical_u16(x as u16); + out[1] = F::from_canonical_u16((x >> 16) as u16); + out[2] = F::from_canonical_u16((x >> 32) as u16); + out[3] = F::from_canonical_u16((x >> 48) as u16); +} + +// TODO: Refactor/replace u256_limbs in evm/src/util.rs +pub(crate) fn u256_to_array(out: &mut [F], x: U256) { + const_assert!(N_LIMBS == 16); + debug_assert!(out.len() == N_LIMBS); + + u64_to_array(&mut out[0..4], x.0[0]); + u64_to_array(&mut out[4..8], x.0[1]); + u64_to_array(&mut out[8..12], x.0[2]); + u64_to_array(&mut out[12..16], x.0[3]); } diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 63f6795d..1000d2fa 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -13,12 +13,13 @@ pub struct OpsColumnsView { pub sub: T, pub div: T, pub mod_: T, - // TODO: combine ADDMOD, MULMOD into one flag + // TODO: combine ADDMOD, MULMOD and SUBMOD into one flag pub addmod: T, pub mulmod: T, pub addfp254: T, pub mulfp254: T, pub subfp254: T, + pub submod: T, pub lt: T, pub gt: T, pub eq: T, // Note: This column must be 0 when is_cpu_cycle = 0. diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 7f44c7b3..ee96a682 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -50,6 +50,7 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { addfp254: BASIC_BINARY_OP, mulfp254: BASIC_BINARY_OP, subfp254: BASIC_BINARY_OP, + submod: BASIC_TERNARY_OP, lt: BASIC_BINARY_OP, gt: BASIC_BINARY_OP, eq: BASIC_BINARY_OP, diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index b8e46d78..a60141c5 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -51,6 +51,9 @@ fn decode(registers: RegistersState, opcode: u8) -> Result Ok(Operation::BinaryArithmetic( arithmetic::BinaryOperator::SubFp254, )), + (0x0f, true) => Ok(Operation::TernaryArithmetic( + arithmetic::TernaryOperator::SubMod, + )), (0x10, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Lt)), (0x11, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Gt)), (0x12, _) => Ok(Operation::Syscall(opcode)), @@ -167,6 +170,7 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => &mut flags.subfp254, Operation::TernaryArithmetic(arithmetic::TernaryOperator::AddMod) => &mut flags.addmod, Operation::TernaryArithmetic(arithmetic::TernaryOperator::MulMod) => &mut flags.mulmod, + Operation::TernaryArithmetic(arithmetic::TernaryOperator::SubMod) => &mut flags.submod, Operation::KeccakGeneral => &mut flags.keccak_general, Operation::ProverInput => &mut flags.prover_input, Operation::Pop => &mut flags.pop,