From c134b59763d28c76f28204c1032a546e8a0d4f95 Mon Sep 17 00:00:00 2001 From: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com> Date: Thu, 11 May 2023 03:29:06 +1000 Subject: [PATCH] Cross-table lookup for arithmetic stark (#905) * First draft of linking arithmetic Stark into the CTL mechanism. * Handle {ADD,SUB,MUL}FP254 operations explicitly in `modular.rs`. * Adjust argument order; add tests. * Add CTLs for ADD, MUL, SUB, LT and GT. * Add CTLs for {ADD,MUL,SUB}MOD, DIV and MOD. * Add CTLs for {ADD,MUL,SUB}FP254 operations. * Refactor the CPU/arithmetic CTL mapping; add some documentation. * Minor comment fixes. * Combine addcy CTLs at the expense of repeated constraint evaluation. * Combine addcy CTLs at the expense of repeated constraint evaluation. * Merge `*FP254` CTL into main CTL; rename some registers. * Connect extra argument from CPU in binary ops to facilitate combining with ternary ops. * Merge modular ops CTL into main CTL. * Refactor DIV and MOD code into its own module. * Merge DIV and MOD into arithmetic CTL. * Clippy. * Fixes related to merge. * Simplify register naming. * Generate u16 BN254 modulus limbs at compile time. * Clippy. * Add degree bits ranges for Arithmetic table. 
--- evm/src/all_stark.rs | 35 ++- evm/src/arithmetic/addcy.rs | 170 +++++++------ evm/src/arithmetic/arithmetic_stark.rs | 115 +++++++-- evm/src/arithmetic/columns.rs | 78 +++--- evm/src/arithmetic/divmod.rs | 339 +++++++++++++++++++++++++ evm/src/arithmetic/mod.rs | 11 +- evm/src/arithmetic/modular.rs | 292 +++++++++++++-------- evm/src/arithmetic/mul.rs | 25 +- evm/src/arithmetic/utils.rs | 11 - evm/src/cpu/cpu_stark.rs | 50 +++- evm/src/fixed_recursive_verifier.rs | 21 +- evm/src/prover.rs | 19 ++ evm/src/verifier.rs | 10 + evm/src/witness/traces.rs | 21 +- evm/tests/empty_txn_list.rs | 2 +- 15 files changed, 894 insertions(+), 305 deletions(-) create mode 100644 evm/src/arithmetic/divmod.rs diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 6ae6ad3e..b7cb52e5 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -4,6 +4,8 @@ use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use crate::arithmetic::arithmetic_stark; +use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::config::StarkConfig; use crate::cpu::cpu_stark; use crate::cpu::cpu_stark::CpuStark; @@ -22,6 +24,7 @@ use crate::stark::Stark; #[derive(Clone)] pub struct AllStark, const D: usize> { + pub arithmetic_stark: ArithmeticStark, pub cpu_stark: CpuStark, pub keccak_stark: KeccakStark, pub keccak_sponge_stark: KeccakSpongeStark, @@ -33,6 +36,7 @@ pub struct AllStark, const D: usize> { impl, const D: usize> Default for AllStark { fn default() -> Self { Self { + arithmetic_stark: ArithmeticStark::default(), cpu_stark: CpuStark::default(), keccak_stark: KeccakStark::default(), keccak_sponge_stark: KeccakSpongeStark::default(), @@ -46,6 +50,7 @@ impl, const D: usize> Default for AllStark { impl, const D: usize> AllStark { pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { [ + self.arithmetic_stark.num_permutation_batches(config), 
self.cpu_stark.num_permutation_batches(config), self.keccak_stark.num_permutation_batches(config), self.keccak_sponge_stark.num_permutation_batches(config), @@ -56,6 +61,7 @@ impl, const D: usize> AllStark { pub(crate) fn permutation_batch_sizes(&self) -> [usize; NUM_TABLES] { [ + self.arithmetic_stark.permutation_batch_size(), self.cpu_stark.permutation_batch_size(), self.keccak_stark.permutation_batch_size(), self.keccak_sponge_stark.permutation_batch_size(), @@ -67,11 +73,12 @@ impl, const D: usize> AllStark { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Table { - Cpu = 0, - Keccak = 1, - KeccakSponge = 2, - Logic = 3, - Memory = 4, + Arithmetic = 0, + Cpu = 1, + Keccak = 2, + KeccakSponge = 3, + Logic = 4, + Memory = 5, } pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; @@ -79,6 +86,7 @@ pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; impl Table { pub(crate) fn all() -> [Self; NUM_TABLES] { [ + Self::Arithmetic, Self::Cpu, Self::Keccak, Self::KeccakSponge, @@ -89,9 +97,15 @@ impl Table { } pub(crate) fn all_cross_table_lookups() -> Vec> { - let mut ctls = vec![ctl_keccak_sponge(), ctl_keccak(), ctl_logic(), ctl_memory()]; + let mut ctls = vec![ + ctl_arithmetic(), + ctl_keccak_sponge(), + ctl_keccak(), + ctl_logic(), + ctl_memory(), + ]; // TODO: Some CTLs temporarily disabled while we get them working. 
- disable_ctl(&mut ctls[3]); + disable_ctl(&mut ctls[4]); ctls } @@ -102,6 +116,13 @@ fn disable_ctl(ctl: &mut CrossTableLookup) { ctl.looked_table.filter_column = Some(Column::zero()); } +fn ctl_arithmetic() -> CrossTableLookup { + CrossTableLookup::new( + vec![cpu_stark::ctl_arithmetic_rows()], + arithmetic_stark::ctl_arithmetic_rows(), + ) +} + fn ctl_keccak() -> CrossTableLookup { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, diff --git a/evm/src/arithmetic/addcy.rs b/evm/src/arithmetic/addcy.rs index 32fa4a9e..40b7e093 100644 --- a/evm/src/arithmetic/addcy.rs +++ b/evm/src/arithmetic/addcy.rs @@ -28,68 +28,41 @@ use crate::arithmetic::utils::u256_to_array; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; /// Generate row for ADD, SUB, GT and LT operations. -/// -/// A row consists of four values, GENERAL_REGISTER_[012] and -/// GENERAL_REGISTER_BIT. The interpretation of these values for each -/// operation is as follows: -/// -/// ADD: REGISTER_0 + REGISTER_1, output in REGISTER_2, ignore REGISTER_BIT -/// SUB: REGISTER_2 - REGISTER_0, output in REGISTER_1, ignore REGISTER_BIT -/// GT: REGISTER_0 > REGISTER_2, output in REGISTER_BIT, auxiliary output in REGISTER_1 -/// LT: REGISTER_2 < REGISTER_0, output in REGISTER_BIT, auxiliary output in REGISTER_1 pub(crate) fn generate( lv: &mut [F], filter: usize, left_in: U256, right_in: U256, ) { - // Swap left_in and right_in for LT - let (left_in, right_in) = if filter == IS_LT { - (right_in, left_in) - } else { - (left_in, right_in) - }; + u256_to_array(&mut lv[INPUT_REGISTER_0], left_in); + u256_to_array(&mut lv[INPUT_REGISTER_1], right_in); + u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero()); match filter { IS_ADD => { - // x + y == z + cy*2^256 let (result, cy) = left_in.overflowing_add(right_in); - u256_to_array(&mut lv[GENERAL_REGISTER_0], left_in); // x - u256_to_array(&mut lv[GENERAL_REGISTER_1], right_in); // y - u256_to_array(&mut 
lv[GENERAL_REGISTER_2], result); // z - lv[GENERAL_REGISTER_BIT] = F::from_bool(cy); + u256_to_array(&mut lv[AUX_INPUT_REGISTER_0], U256::from(cy as u32)); + u256_to_array(&mut lv[OUTPUT_REGISTER], result); } - IS_SUB | IS_GT | IS_LT => { - // y == z - x + cy*2^256 + IS_SUB => { + let (diff, cy) = left_in.overflowing_sub(right_in); + u256_to_array(&mut lv[AUX_INPUT_REGISTER_0], U256::from(cy as u32)); + u256_to_array(&mut lv[OUTPUT_REGISTER], diff); + } + IS_LT => { + let (diff, cy) = left_in.overflowing_sub(right_in); + u256_to_array(&mut lv[AUX_INPUT_REGISTER_0], diff); + u256_to_array(&mut lv[OUTPUT_REGISTER], U256::from(cy as u32)); + } + IS_GT => { let (diff, cy) = right_in.overflowing_sub(left_in); - u256_to_array(&mut lv[GENERAL_REGISTER_0], left_in); // x - u256_to_array(&mut lv[GENERAL_REGISTER_2], right_in); // z - u256_to_array(&mut lv[GENERAL_REGISTER_1], diff); // y - lv[GENERAL_REGISTER_BIT] = F::from_bool(cy); + u256_to_array(&mut lv[AUX_INPUT_REGISTER_0], diff); + u256_to_array(&mut lv[OUTPUT_REGISTER], U256::from(cy as u32)); } _ => panic!("unexpected operation filter"), }; } -fn eval_packed_generic_check_is_one_bit( - yield_constr: &mut ConstraintConsumer

, - filter: P, - x: P, -) { - yield_constr.constraint(filter * x * (x - P::ONES)); -} - -fn eval_ext_circuit_check_is_one_bit, const D: usize>( - builder: &mut CircuitBuilder, - yield_constr: &mut RecursiveConstraintConsumer, - filter: ExtensionTarget, - x: ExtensionTarget, -) { - let constr = builder.mul_sub_extension(x, x, x); - let filtered_constr = builder.mul_extension(filter, constr); - yield_constr.constraint(builder, filtered_constr); -} - /// 2^-16 mod (2^64 - 2^32 + 1) const GOLDILOCKS_INVERSE_65536: u64 = 18446462594437939201; @@ -126,10 +99,12 @@ pub(crate) fn eval_packed_generic_addcy( x: &[P], y: &[P], z: &[P], - given_cy: P, + given_cy: &[P], is_two_row_op: bool, ) { - debug_assert!(x.len() == N_LIMBS && y.len() == N_LIMBS && z.len() == N_LIMBS); + debug_assert!( + x.len() == N_LIMBS && y.len() == N_LIMBS && z.len() == N_LIMBS && given_cy.len() == N_LIMBS + ); let overflow = P::Scalar::from_canonical_u64(1u64 << LIMB_BITS); let overflow_inv = P::Scalar::from_canonical_u64(GOLDILOCKS_INVERSE_65536); @@ -154,9 +129,22 @@ pub(crate) fn eval_packed_generic_addcy( } if is_two_row_op { - yield_constr.constraint_transition(filter * (cy - given_cy)); + // NB: Mild hack: We don't check that given_cy[0] is 0 or 1 + // when is_two_row_op is true because that's only the case + // when this function is called from + // modular::modular_constr_poly(), in which case (1) this + // condition has already been checked and (2) it exceeds the + // degree budget because given_cy[0] is already degree 2. 
+ yield_constr.constraint_transition(filter * (cy - given_cy[0])); + for i in 1..N_LIMBS { + yield_constr.constraint_transition(filter * given_cy[i]); + } } else { - yield_constr.constraint(filter * (cy - given_cy)); + yield_constr.constraint(filter * given_cy[0] * (given_cy[0] - P::ONES)); + yield_constr.constraint(filter * (cy - given_cy[0])); + for i in 1..N_LIMBS { + yield_constr.constraint(filter * given_cy[i]); + } } } @@ -169,30 +157,32 @@ pub fn eval_packed_generic( let is_lt = lv[IS_LT]; let is_gt = lv[IS_GT]; - let x = &lv[GENERAL_REGISTER_0]; - let y = &lv[GENERAL_REGISTER_1]; - let z = &lv[GENERAL_REGISTER_2]; - let cy = lv[GENERAL_REGISTER_BIT]; + let in0 = &lv[INPUT_REGISTER_0]; + let in1 = &lv[INPUT_REGISTER_1]; + let out = &lv[OUTPUT_REGISTER]; + let aux = &lv[AUX_INPUT_REGISTER_0]; - let op_filter = is_add + is_sub + is_lt + is_gt; - eval_packed_generic_check_is_one_bit(yield_constr, op_filter, cy); - - // x + y = z + cy*2^256 - eval_packed_generic_addcy(yield_constr, op_filter, x, y, z, cy, false); + // x + y = z + w*2^256 + eval_packed_generic_addcy(yield_constr, is_add, in0, in1, out, aux, false); + eval_packed_generic_addcy(yield_constr, is_sub, in1, out, in0, aux, false); + eval_packed_generic_addcy(yield_constr, is_lt, in1, aux, in0, out, false); + eval_packed_generic_addcy(yield_constr, is_gt, in0, aux, in1, out, false); } #[allow(clippy::needless_collect)] pub(crate) fn eval_ext_circuit_addcy, const D: usize>( - builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + builder: &mut CircuitBuilder, yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, x: &[ExtensionTarget], y: &[ExtensionTarget], z: &[ExtensionTarget], - given_cy: ExtensionTarget, + given_cy: &[ExtensionTarget], is_two_row_op: bool, ) { - debug_assert!(x.len() == N_LIMBS && y.len() == N_LIMBS && z.len() == N_LIMBS); + debug_assert!( + x.len() == N_LIMBS && y.len() == N_LIMBS && z.len() == N_LIMBS && given_cy.len() == N_LIMBS + ); // 2^LIMB_BITS in 
the base field let overflow_base = F::from_canonical_u64(1 << LIMB_BITS); @@ -222,17 +212,31 @@ pub(crate) fn eval_ext_circuit_addcy, const D: usiz cy = builder.mul_const_extension(overflow_inv, t); } - let good_cy = builder.sub_extension(cy, given_cy); - let filter = builder.mul_extension(filter, good_cy); + let good_cy = builder.sub_extension(cy, given_cy[0]); + let cy_filter = builder.mul_extension(filter, good_cy); + + // Check given carry is one bit + let bit_constr = builder.mul_sub_extension(given_cy[0], given_cy[0], given_cy[0]); + let bit_filter = builder.mul_extension(filter, bit_constr); + if is_two_row_op { - yield_constr.constraint_transition(builder, filter); + yield_constr.constraint_transition(builder, cy_filter); + for i in 1..N_LIMBS { + let t = builder.mul_extension(filter, given_cy[i]); + yield_constr.constraint_transition(builder, t); + } } else { - yield_constr.constraint(builder, filter); + yield_constr.constraint(builder, bit_filter); + yield_constr.constraint(builder, cy_filter); + for i in 1..N_LIMBS { + let t = builder.mul_extension(filter, given_cy[i]); + yield_constr.constraint(builder, t); + } } } pub fn eval_ext_circuit, const D: usize>( - builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + builder: &mut CircuitBuilder, lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { @@ -241,14 +245,15 @@ pub fn eval_ext_circuit, const D: usize>( let is_lt = lv[IS_LT]; let is_gt = lv[IS_GT]; - let x = &lv[GENERAL_REGISTER_0]; - let y = &lv[GENERAL_REGISTER_1]; - let z = &lv[GENERAL_REGISTER_2]; - let cy = lv[GENERAL_REGISTER_BIT]; + let in0 = &lv[INPUT_REGISTER_0]; + let in1 = &lv[INPUT_REGISTER_1]; + let out = &lv[OUTPUT_REGISTER]; + let aux = &lv[AUX_INPUT_REGISTER_0]; - let op_filter = builder.add_many_extension([is_add, is_sub, is_lt, is_gt]); - eval_ext_circuit_check_is_one_bit(builder, yield_constr, op_filter, cy); - eval_ext_circuit_addcy(builder, yield_constr, op_filter, x, y, z, cy, 
false); + eval_ext_circuit_addcy(builder, yield_constr, is_add, in0, in1, out, aux, false); + eval_ext_circuit_addcy(builder, yield_constr, is_sub, in1, out, in0, aux, false); + eval_ext_circuit_addcy(builder, yield_constr, is_lt, in1, aux, in0, out, false); + eval_ext_circuit_addcy(builder, yield_constr, is_gt, in0, aux, in1, out, false); } #[cfg(test)] @@ -264,7 +269,7 @@ mod tests { // TODO: Should be able to refactor this test to apply to all operations. #[test] - fn generate_eval_consistency_not_addcc() { + fn generate_eval_consistency_not_addcy() { type F = GoldilocksField; let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); @@ -291,7 +296,7 @@ mod tests { } #[test] - fn generate_eval_consistency_addcc() { + fn generate_eval_consistency_addcy() { type F = GoldilocksField; let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); @@ -328,6 +333,21 @@ mod tests { for &acc in &constrant_consumer.constraint_accs { assert_eq!(acc, F::ZERO); } + + let expected = match op_filter { + IS_ADD => left_in.overflowing_add(right_in).0, + IS_SUB => left_in.overflowing_sub(right_in).0, + IS_LT => U256::from((left_in < right_in) as u8), + IS_GT => U256::from((left_in > right_in) as u8), + _ => panic!("unrecognised operation"), + }; + + let mut expected_limbs = [F::ZERO; N_LIMBS]; + u256_to_array(&mut expected_limbs, expected); + assert!(expected_limbs + .iter() + .zip(&lv[OUTPUT_REGISTER]) + .all(|(x, y)| x == y)); } } } diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 0a89f3c6..342bb8c2 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -1,4 +1,5 @@ use std::marker::PhantomData; +use std::ops::Range; use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; @@ -8,15 +9,82 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::util::transpose; +use 
static_assertions::const_assert; -use crate::arithmetic::{addcy, columns, modular, mul, Operation}; +use crate::all_stark::Table; +use crate::arithmetic::{addcy, columns, divmod, modular, mul, Operation}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::{Column, TableWithColumns}; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::permutation::PermutationPair; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; -#[derive(Copy, Clone)] +/// Link the 16-bit columns of the arithmetic table, split into groups +/// of N_LIMBS at a time in `regs`, with the corresponding 32-bit +/// columns of the CPU table. Does this for all ops in `ops`. +/// +/// This is done by taking pairs of columns (x, y) of the arithmetic +/// table and combining them as x + y*2^16 to ensure they equal the +/// corresponding 32-bit number in the CPU table. +fn cpu_arith_data_link(ops: &[usize], regs: &[Range]) -> Vec> { + let limb_base = F::from_canonical_u64(1 << columns::LIMB_BITS); + + let mut res = Column::singles(ops).collect_vec(); + + // The inner for loop below assumes N_LIMBS is even. + const_assert!(columns::N_LIMBS % 2 == 0); + + for reg_cols in regs { + // Loop below assumes we're operating on a "register" of N_LIMBS columns. 
+ debug_assert_eq!(reg_cols.len(), columns::N_LIMBS); + + for i in 0..(columns::N_LIMBS / 2) { + let c0 = reg_cols.start + 2 * i; + let c1 = reg_cols.start + 2 * i + 1; + res.push(Column::linear_combination([(c0, F::ONE), (c1, limb_base)])); + } + } + res +} + +pub fn ctl_arithmetic_rows() -> TableWithColumns { + const ARITH_OPS: [usize; 13] = [ + columns::IS_ADD, + columns::IS_SUB, + columns::IS_MUL, + columns::IS_LT, + columns::IS_GT, + columns::IS_ADDFP254, + columns::IS_MULFP254, + columns::IS_SUBFP254, + columns::IS_ADDMOD, + columns::IS_MULMOD, + columns::IS_SUBMOD, + columns::IS_DIV, + columns::IS_MOD, + ]; + + const REGISTER_MAP: [Range; 4] = [ + columns::INPUT_REGISTER_0, + columns::INPUT_REGISTER_1, + columns::INPUT_REGISTER_2, + columns::OUTPUT_REGISTER, + ]; + + // Create the Arithmetic Table whose columns are those of the + // operations listed in `ops` whose inputs and outputs are given + // by `regs`, where each element of `regs` is a range of columns + // corresponding to a 256-bit input or output register (also `ops` + // is used as the operation filter). 
+ TableWithColumns::new( + Table::Arithmetic, + cpu_arith_data_link(&ARITH_OPS, &REGISTER_MAP), + Some(Column::sum(ARITH_OPS)), + ) +} + -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Default)] pub struct ArithmeticStark { pub f: PhantomData, } @@ -48,8 +116,7 @@ impl ArithmeticStark { } } - #[allow(unused)] - pub(crate) fn generate(&self, operations: Vec) -> Vec> { + pub(crate) fn generate_trace(&self, operations: Vec) -> Vec> { // The number of rows reserved is the smallest value that's // guaranteed to avoid a reallocation: The only ops that use // two rows are the modular operations and DIV, so the only @@ -114,7 +181,8 @@ impl, const D: usize> Stark for ArithmeticSta mul::eval_packed_generic(lv, yield_constr); addcy::eval_packed_generic(lv, yield_constr); - modular::eval_packed_generic(lv, nv, yield_constr); + divmod::eval_packed(lv, nv, yield_constr); + modular::eval_packed(lv, nv, yield_constr); } fn eval_ext_circuit( @@ -144,6 +212,7 @@ impl, const D: usize> Stark for ArithmeticSta mul::eval_ext_circuit(builder, lv, yield_constr); addcy::eval_ext_circuit(builder, lv, yield_constr); + divmod::eval_ext_circuit(builder, lv, nv, yield_constr); modular::eval_ext_circuit(builder, lv, nv, yield_constr); } @@ -176,6 +245,7 @@ mod tests { use rand_chacha::ChaCha8Rng; use super::{columns, ArithmeticStark}; + use crate::arithmetic::columns::OUTPUT_REGISTER; use crate::arithmetic::*; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; @@ -249,7 +319,7 @@ mod tests { let ops: Vec = vec![add, mulmod, addmod, mul, modop, lt1, lt2, lt3, div]; - let pols = stark.generate(ops); + let pols = stark.generate_trace(ops); // Trace should always have NUM_ARITH_COLUMNS columns and // min(RANGE_MAX, operations.len()) rows. In this case there @@ -259,26 +329,23 @@ mod tests { && pols.iter().all(|v| v.len() == super::RANGE_MAX) ); - // Wrap the single value GENERAL_REGISTER_BIT in a Range. 
- let cmp_range = columns::GENERAL_REGISTER_BIT..columns::GENERAL_REGISTER_BIT + 1; - // Each operation has a single word answer that we can check let expected_output = [ - // Row (some ops take two rows), col, expected - (0, &columns::GENERAL_REGISTER_2, 579), // ADD_OUTPUT - (1, &columns::MODULAR_OUTPUT, 703), - (3, &columns::MODULAR_OUTPUT, 794), - (5, &columns::MUL_OUTPUT, 56088), - (6, &columns::MODULAR_OUTPUT, 11), - (8, &cmp_range, 0), - (9, &cmp_range, 1), - (10, &cmp_range, 0), - (11, &columns::DIV_OUTPUT, 9), + // Row (some ops take two rows), expected + (0, 579), // ADD_OUTPUT + (1, 703), + (3, 794), + (5, 56088), + (6, 11), + (8, 0), + (9, 1), + (10, 0), + (11, 9), ]; - for (row, col, expected) in expected_output { + for (row, expected) in expected_output { // First register should match expected value... - let first = col.start; + let first = OUTPUT_REGISTER.start; let out = pols[first].values[row].to_canonical_u64(); assert_eq!( out, expected, @@ -286,7 +353,7 @@ mod tests { first, row, expected, out, ); // ...other registers should be zero - let rest = col.start + 1..col.end; + let rest = OUTPUT_REGISTER.start + 1..OUTPUT_REGISTER.end; assert!(pols[rest].iter().all(|v| v.values[row] == F::ZERO)); } } @@ -314,7 +381,7 @@ mod tests { }) .collect::>(); - let pols = stark.generate(ops); + let pols = stark.generate_trace(ops); // Trace should always have NUM_ARITH_COLUMNS columns and // min(RANGE_MAX, operations.len()) rows. In this case there @@ -335,7 +402,7 @@ mod tests { }) .collect::>(); - let pols = stark.generate(ops); + let pols = stark.generate_trace(ops); // Trace should always have NUM_ARITH_COLUMNS columns and // min(RANGE_MAX, operations.len()) rows. 
In this case there diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index 952a8ed5..98481f64 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -22,16 +22,19 @@ const fn n_limbs() -> usize { /// Number of LIMB_BITS limbs that are in on EVM register-sized number. pub const N_LIMBS: usize = n_limbs(); -pub const IS_ADD: usize = 0; -pub const IS_MUL: usize = IS_ADD + 1; -pub const IS_SUB: usize = IS_MUL + 1; -pub const IS_DIV: usize = IS_SUB + 1; -pub const IS_MOD: usize = IS_DIV + 1; -pub const IS_ADDMOD: usize = IS_MOD + 1; -pub const IS_SUBMOD: usize = IS_ADDMOD + 1; -pub const IS_MULMOD: usize = IS_SUBMOD + 1; -pub const IS_LT: usize = IS_MULMOD + 1; -pub const IS_GT: usize = IS_LT + 1; +pub(crate) const IS_ADD: usize = 0; +pub(crate) const IS_MUL: usize = IS_ADD + 1; +pub(crate) const IS_SUB: usize = IS_MUL + 1; +pub(crate) const IS_DIV: usize = IS_SUB + 1; +pub(crate) const IS_MOD: usize = IS_DIV + 1; +pub(crate) const IS_ADDMOD: usize = IS_MOD + 1; +pub(crate) const IS_MULMOD: usize = IS_ADDMOD + 1; +pub(crate) const IS_ADDFP254: usize = IS_MULMOD + 1; +pub(crate) const IS_MULFP254: usize = IS_ADDFP254 + 1; +pub(crate) const IS_SUBFP254: usize = IS_MULFP254 + 1; +pub(crate) const IS_SUBMOD: usize = IS_SUBFP254 + 1; +pub(crate) const IS_LT: usize = IS_SUBMOD + 1; +pub(crate) const IS_GT: usize = IS_LT + 1; pub(crate) const START_SHARED_COLS: usize = IS_GT + 1; @@ -46,28 +49,28 @@ pub(crate) const START_SHARED_COLS: usize = IS_GT + 1; pub(crate) const NUM_SHARED_COLS: usize = 6 * N_LIMBS; pub(crate) const SHARED_COLS: Range = START_SHARED_COLS..START_SHARED_COLS + NUM_SHARED_COLS; -pub(crate) const GENERAL_REGISTER_0: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; -pub(crate) const GENERAL_REGISTER_1: Range = - GENERAL_REGISTER_0.end..GENERAL_REGISTER_0.end + N_LIMBS; -pub(crate) const GENERAL_REGISTER_2: Range = - GENERAL_REGISTER_1.end..GENERAL_REGISTER_1.end + N_LIMBS; -const GENERAL_REGISTER_3: 
Range = GENERAL_REGISTER_2.end..GENERAL_REGISTER_2.end + N_LIMBS; -// NB: Uses first slot of the GENERAL_REGISTER_3 register. -pub(crate) const GENERAL_REGISTER_BIT: usize = GENERAL_REGISTER_3.start; +pub(crate) const INPUT_REGISTER_0: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; +pub(crate) const INPUT_REGISTER_1: Range = + INPUT_REGISTER_0.end..INPUT_REGISTER_0.end + N_LIMBS; +pub(crate) const INPUT_REGISTER_2: Range = + INPUT_REGISTER_1.end..INPUT_REGISTER_1.end + N_LIMBS; +pub(crate) const OUTPUT_REGISTER: Range = + INPUT_REGISTER_2.end..INPUT_REGISTER_2.end + N_LIMBS; -// NB: Only one of these two sets of columns will be used for a given operation -const GENERAL_REGISTER_4: Range = GENERAL_REGISTER_3.end..GENERAL_REGISTER_3.end + N_LIMBS; -const GENERAL_REGISTER_4_DBL: Range = - GENERAL_REGISTER_3.end..GENERAL_REGISTER_3.end + 2 * N_LIMBS; +// NB: Only one of AUX_INPUT_REGISTER_[01] or AUX_INPUT_REGISTER_DBL +// will be used for a given operation since they overlap +pub(crate) const AUX_INPUT_REGISTER_0: Range = + OUTPUT_REGISTER.end..OUTPUT_REGISTER.end + N_LIMBS; +pub(crate) const AUX_INPUT_REGISTER_1: Range = + AUX_INPUT_REGISTER_0.end..AUX_INPUT_REGISTER_0.end + N_LIMBS; +pub(crate) const AUX_INPUT_REGISTER_DBL: Range = + OUTPUT_REGISTER.end..OUTPUT_REGISTER.end + 2 * N_LIMBS; // The auxiliary input columns overlap the general input columns // because they correspond to the values in the second row for modular // operations. const AUX_REGISTER_0: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; const AUX_REGISTER_1: Range = AUX_REGISTER_0.end..AUX_REGISTER_0.end + 2 * N_LIMBS; -// These auxiliary input columns are awkwardly split across two rows, -// with the first half after the general input columns and the second -// half after the auxiliary input columns. 
const AUX_REGISTER_2: Range = AUX_REGISTER_1.end..AUX_REGISTER_1.end + 2 * N_LIMBS - 1; // Each element c of {MUL,MODULAR}_AUX_REGISTER is -2^20 <= c <= 2^20; @@ -76,11 +79,8 @@ const AUX_REGISTER_2: Range = AUX_REGISTER_1.end..AUX_REGISTER_1.end + 2 pub(crate) const AUX_COEFF_ABS_MAX: i64 = 1 << 20; // MUL takes 5 * N_LIMBS = 80 columns -pub(crate) const MUL_INPUT_0: Range = GENERAL_REGISTER_0; -pub(crate) const MUL_INPUT_1: Range = GENERAL_REGISTER_1; -pub(crate) const MUL_OUTPUT: Range = GENERAL_REGISTER_2; -pub(crate) const MUL_AUX_INPUT_LO: Range = GENERAL_REGISTER_3; -pub(crate) const MUL_AUX_INPUT_HI: Range = GENERAL_REGISTER_4; +pub(crate) const MUL_AUX_INPUT_LO: Range = AUX_INPUT_REGISTER_0; +pub(crate) const MUL_AUX_INPUT_HI: Range = AUX_INPUT_REGISTER_1; // MULMOD takes 4 * N_LIMBS + 3 * 2*N_LIMBS + N_LIMBS = 176 columns // but split over two rows of 96 columns and 80 columns. @@ -88,11 +88,11 @@ pub(crate) const MUL_AUX_INPUT_HI: Range = GENERAL_REGISTER_4; // ADDMOD, SUBMOD, MOD and DIV are currently implemented in terms of // the general modular code, so they also take 144 columns (also split // over two rows). 
-pub(crate) const MODULAR_INPUT_0: Range = GENERAL_REGISTER_0; -pub(crate) const MODULAR_INPUT_1: Range = GENERAL_REGISTER_1; -pub(crate) const MODULAR_MODULUS: Range = GENERAL_REGISTER_2; -pub(crate) const MODULAR_OUTPUT: Range = GENERAL_REGISTER_3; -pub(crate) const MODULAR_QUO_INPUT: Range = GENERAL_REGISTER_4_DBL; +pub(crate) const MODULAR_INPUT_0: Range = INPUT_REGISTER_0; +pub(crate) const MODULAR_INPUT_1: Range = INPUT_REGISTER_1; +pub(crate) const MODULAR_MODULUS: Range = INPUT_REGISTER_2; +pub(crate) const MODULAR_OUTPUT: Range = OUTPUT_REGISTER; +pub(crate) const MODULAR_QUO_INPUT: Range = AUX_INPUT_REGISTER_DBL; pub(crate) const MODULAR_OUT_AUX_RED: Range = AUX_REGISTER_0; // NB: Last value is not used in AUX, it is used in MOD_IS_ZERO pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_REGISTER_1.start; @@ -101,14 +101,6 @@ pub(crate) const MODULAR_AUX_INPUT_HI: Range = AUX_REGISTER_2; // Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV] pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end; -#[allow(unused)] // TODO: Will be used when hooking into the CPU -pub(crate) const DIV_NUMERATOR: Range = MODULAR_INPUT_0; -#[allow(unused)] // TODO: Will be used when hooking into the CPU -pub(crate) const DIV_DENOMINATOR: Range = MODULAR_MODULUS; -#[allow(unused)] // TODO: Will be used when hooking into the CPU -pub(crate) const DIV_OUTPUT: Range = - MODULAR_QUO_INPUT.start..MODULAR_QUO_INPUT.start + N_LIMBS; - // Need one column for the table, then two columns for every value // that needs to be range checked in the trace, namely the permutation // of the column and the permutation of the range. 
The two diff --git a/evm/src/arithmetic/divmod.rs b/evm/src/arithmetic/divmod.rs new file mode 100644 index 00000000..4f2dd748 --- /dev/null +++ b/evm/src/arithmetic/divmod.rs @@ -0,0 +1,339 @@ +use std::ops::Range; + +use ethereum_types::U256; +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::PrimeField64; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::arithmetic::columns::*; +use crate::arithmetic::modular::{ + generate_modular_op, modular_constr_poly, modular_constr_poly_ext_circuit, +}; +use crate::arithmetic::utils::*; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +/// Generate the output and auxiliary values for modular operations. +pub(crate) fn generate( + lv: &mut [F], + nv: &mut [F], + filter: usize, + input0: U256, + input1: U256, + result: U256, +) { + debug_assert!(lv.len() == NUM_ARITH_COLUMNS); + + u256_to_array(&mut lv[INPUT_REGISTER_0], input0); + u256_to_array(&mut lv[INPUT_REGISTER_1], input1); + u256_to_array(&mut lv[OUTPUT_REGISTER], result); + + let input_limbs = read_value_i64_limbs::(lv, INPUT_REGISTER_0); + let pol_input = pol_extend(input_limbs); + let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, INPUT_REGISTER_1); + debug_assert!( + &quo_input[N_LIMBS..].iter().all(|&x| x == F::ZERO), + "expected top half of quo_input to be zero" + ); + + // Initialise whole (double) register to zero; the low half will + // be overwritten via lv[AUX_INPUT_REGISTER_0] below. 
+ for i in MODULAR_QUO_INPUT { + lv[i] = F::ZERO; + } + + match filter { + IS_DIV => { + debug_assert!( + lv[OUTPUT_REGISTER] + .iter() + .zip(&quo_input[..N_LIMBS]) + .all(|(x, y)| x == y), + "computed output doesn't match expected" + ); + lv[AUX_INPUT_REGISTER_0].copy_from_slice(&out); + } + IS_MOD => { + debug_assert!( + lv[OUTPUT_REGISTER].iter().zip(&out).all(|(x, y)| x == y), + "computed output doesn't match expected" + ); + lv[AUX_INPUT_REGISTER_0].copy_from_slice(&quo_input[..N_LIMBS]); + } + _ => panic!("expected filter to be IS_DIV or IS_MOD but it was {filter}"), + }; +} + +/// Verify that num = quo * den + rem and 0 <= rem < den. +fn eval_packed_divmod_helper( + lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, + filter: P, + quo_range: Range, + rem_range: Range, +) { + debug_assert!(quo_range.len() == N_LIMBS); + debug_assert!(rem_range.len() == N_LIMBS); + + yield_constr.constraint_last_row(filter); + + let num = &lv[INPUT_REGISTER_0]; + let den = read_value(lv, INPUT_REGISTER_1); + let quo = { + let mut quo = [P::ZEROS; 2 * N_LIMBS]; + quo[..N_LIMBS].copy_from_slice(&lv[quo_range]); + quo + }; + let rem = read_value(lv, rem_range); + + let mut constr_poly = modular_constr_poly(lv, nv, yield_constr, filter, rem, den, quo); + + let input = num; + pol_sub_assign(&mut constr_poly, input); + + for &c in constr_poly.iter() { + yield_constr.constraint_transition(filter * c); + } +} + +pub(crate) fn eval_packed( + lv: &[P; NUM_ARITH_COLUMNS], + nv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + eval_packed_divmod_helper( + lv, + nv, + yield_constr, + lv[IS_DIV], + OUTPUT_REGISTER, + AUX_INPUT_REGISTER_0, + ); + eval_packed_divmod_helper( + lv, + nv, + yield_constr, + lv[IS_MOD], + AUX_INPUT_REGISTER_0, + OUTPUT_REGISTER, + ); +} + +fn eval_ext_circuit_divmod_helper, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, + filter: ExtensionTarget, + quo_range: Range, + rem_range: Range, +) { + yield_constr.constraint_last_row(builder, filter); + + let num = &lv[INPUT_REGISTER_0]; + let den = read_value(lv, INPUT_REGISTER_1); + let quo = { + let zero = builder.zero_extension(); + let mut quo = [zero; 2 * N_LIMBS]; + quo[..N_LIMBS].copy_from_slice(&lv[quo_range]); + quo + }; + let rem = read_value(lv, rem_range); + + let mut constr_poly = + modular_constr_poly_ext_circuit(lv, nv, builder, yield_constr, filter, rem, den, quo); + + let input = num; + pol_sub_assign_ext_circuit(builder, &mut constr_poly, input); + + for &c in constr_poly.iter() { + let t = builder.mul_extension(filter, c); + yield_constr.constraint_transition(builder, t); + } +} + +pub(crate) fn eval_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + eval_ext_circuit_divmod_helper( + builder, + lv, + nv, + yield_constr, + lv[IS_DIV], + OUTPUT_REGISTER, + AUX_INPUT_REGISTER_0, + ); + eval_ext_circuit_divmod_helper( + builder, + lv, + nv, + yield_constr, + lv[IS_MOD], + AUX_INPUT_REGISTER_0, + OUTPUT_REGISTER, + ); +} + +#[cfg(test)] +mod tests { + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::{Field, Sample}; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; + + use super::*; + use crate::arithmetic::columns::NUM_ARITH_COLUMNS; + use 
crate::constraint_consumer::ConstraintConsumer; + + const N_RND_TESTS: usize = 1000; + const MODULAR_OPS: [usize; 2] = [IS_MOD, IS_DIV]; + + // TODO: Should be able to refactor this test to apply to all operations. + #[test] + fn generate_eval_consistency_not_modular() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); + + // if `IS_MOD == 0`, then the constraints should be met even + // if all values are garbage (and similarly for the other operations). + for op in MODULAR_OPS { + lv[op] = F::ZERO; + } + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_packed(&lv, &nv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + + #[test] + fn generate_eval_consistency() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + + for op_filter in MODULAR_OPS { + for i in 0..N_RND_TESTS { + // set inputs to random values + let mut lv = [F::default(); NUM_ARITH_COLUMNS] + .map(|_| F::from_canonical_u16(rng.gen::())); + let mut nv = [F::default(); NUM_ARITH_COLUMNS] + .map(|_| F::from_canonical_u16(rng.gen::())); + + // Reset operation columns, then select one + for op in MODULAR_OPS { + lv[op] = F::ZERO; + } + lv[op_filter] = F::ONE; + + let input0 = U256::from(rng.gen::<[u8; 32]>()); + let input1 = { + let mut modulus_limbs = [0u8; 32]; + // For the second half of the tests, set the top + // 16-start digits of the "modulus" to zero so it is + // much smaller than the inputs. 
+ if i > N_RND_TESTS / 2 { + // 1 <= start < N_LIMBS + let start = (rng.gen::() % (modulus_limbs.len() - 1)) + 1; + for mi in modulus_limbs.iter_mut().skip(start) { + *mi = 0u8; + } + } + U256::from(modulus_limbs) + }; + + let result = if input1 == U256::zero() { + U256::zero() + } else if op_filter == IS_DIV { + input0 / input1 + } else { + input0 % input1 + }; + generate(&mut lv, &mut nv, op_filter, input0, input1, result); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + ); + eval_packed(&lv, &nv, &mut constraint_consumer); + for &acc in &constraint_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + } + } + + #[test] + fn zero_modulus() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + + for op_filter in MODULAR_OPS { + for _i in 0..N_RND_TESTS { + // set inputs to random values and the modulus to zero; + // the output is defined to be zero when modulus is zero. 
+ let mut lv = [F::default(); NUM_ARITH_COLUMNS] + .map(|_| F::from_canonical_u16(rng.gen::())); + let mut nv = [F::default(); NUM_ARITH_COLUMNS] + .map(|_| F::from_canonical_u16(rng.gen::())); + + // Reset operation columns, then select one + for op in MODULAR_OPS { + lv[op] = F::ZERO; + } + lv[op_filter] = F::ONE; + + let input0 = U256::from(rng.gen::<[u8; 32]>()); + let input1 = U256::zero(); + + generate(&mut lv, &mut nv, op_filter, input0, input1, U256::zero()); + + // check that the correct output was generated + assert!(lv[OUTPUT_REGISTER].iter().all(|&c| c == F::ZERO)); + + let mut constraint_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + ); + eval_packed(&lv, &nv, &mut constraint_consumer); + assert!(constraint_consumer + .constraint_accs + .iter() + .all(|&acc| acc == F::ZERO)); + + // Corrupt one output limb by setting it to a non-zero value + let random_oi = OUTPUT_REGISTER.start + rng.gen::() % N_LIMBS; + lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); + + eval_packed(&lv, &nv, &mut constraint_consumer); + + // Check that at least one of the constraints was non-zero + assert!(constraint_consumer + .constraint_accs + .iter() + .any(|&acc| acc != F::ZERO)); + } + } + } +} diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index c6987ed7..74f08947 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -5,6 +5,7 @@ use crate::extension_tower::BN_BASE; use crate::util::{addmod, mulmod, submod}; mod addcy; +mod divmod; mod modular; mod mul; mod utils; @@ -63,9 +64,9 @@ impl BinaryOperator { BinaryOperator::Mod => columns::IS_MOD, BinaryOperator::Lt => columns::IS_LT, BinaryOperator::Gt => columns::IS_GT, - BinaryOperator::AddFp254 => columns::IS_ADDMOD, - BinaryOperator::MulFp254 => columns::IS_MULMOD, - BinaryOperator::SubFp254 => columns::IS_SUBMOD, + 
BinaryOperator::AddFp254 => columns::IS_ADDFP254, + BinaryOperator::MulFp254 => columns::IS_MULFP254, + BinaryOperator::SubFp254 => columns::IS_SUBFP254, } } } @@ -209,7 +210,9 @@ fn binary_op_to_rows( (row, None) } BinaryOperator::Div | BinaryOperator::Mod => { - ternary_op_to_rows::(op.row_filter(), input0, U256::zero(), input1, result) + let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result); + (row, Some(nv)) } BinaryOperator::AddFp254 | BinaryOperator::MulFp254 | BinaryOperator::SubFp254 => { ternary_op_to_rows::(op.row_filter(), input0, input1, BN_BASE, result) diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 99eccacb..a2e08b55 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -108,6 +108,8 @@ //! only require 96 columns, or 80 if the output doesn't need to be //! reduced. +use std::ops::Range; + use ethereum_types::U256; use num::bigint::Sign; use num::{BigInt, One, Zero}; @@ -117,12 +119,29 @@ use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; +use static_assertions::const_assert; use super::columns; use crate::arithmetic::addcy::{eval_ext_circuit_addcy, eval_packed_generic_addcy}; use crate::arithmetic::columns::*; use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::extension_tower::BN_BASE; + +const fn bn254_modulus_limbs() -> [u16; N_LIMBS] { + const_assert!(N_LIMBS == 16); // Assumed below + let mut limbs = [0u16; N_LIMBS]; + let mut i = 0; + while i < N_LIMBS / 4 { + let x = BN_BASE.0[i]; + limbs[4 * i] = x as u16; + limbs[4 * i + 1] = (x >> 16) as u16; + limbs[4 * i + 2] = (x >> 32) as u16; + limbs[4 * i + 3] = (x >> 48) as u16; + i += 1; + } + limbs +} /// Convert the base-2^16 representation of 
a number into a BigInt. /// @@ -190,29 +209,26 @@ fn bigint_to_columns(num: &BigInt) -> [i64; N] { /// /// NB: `operation` can set the higher order elements in its result to /// zero if they are not used. -fn generate_modular_op( +pub(crate) fn generate_modular_op( lv: &mut [F], nv: &mut [F], filter: usize, - operation: fn([i64; N_LIMBS], [i64; N_LIMBS]) -> [i64; 2 * N_LIMBS - 1], -) { - // Inputs are all range-checked in [0, 2^16), so the "as i64" - // conversion is safe. - - let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0); - let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1); - let mut modulus_limbs = read_value_i64_limbs(lv, MODULAR_MODULUS); + pol_input: [i64; 2 * N_LIMBS - 1], + modulus_range: Range, +) -> ([F; N_LIMBS], [F; 2 * N_LIMBS]) { + assert!(modulus_range.len() == N_LIMBS); + let mut modulus_limbs = read_value_i64_limbs(lv, modulus_range); // BigInts are just used to avoid having to implement modular // reduction. let mut modulus = columns_to_bigint(&modulus_limbs); - // constr_poly is initialised to the calculated input, and is - // used as such for the BigInt reduction; later, other values are - // added/subtracted, which is where its meaning as the "constraint - // polynomial" comes in. + // constr_poly is initialised to the input calculation as + // polynomials, and is used as such for the BigInt reduction; + // later, other values are added/subtracted, which is where its + // meaning as the "constraint polynomial" comes in. 
let mut constr_poly = [0i64; 2 * N_LIMBS]; - constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&operation(input0_limbs, input1_limbs)); + constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&pol_input); // two_exp_256 == 2^256 let two_exp_256 = { @@ -264,8 +280,6 @@ fn generate_modular_op( // Higher order terms of the product must be zero for valid quot and modulus: debug_assert!(&prod[2 * N_LIMBS..].iter().all(|&x| x == 0i64)); - lv[MODULAR_OUTPUT].copy_from_slice(&output_limbs.map(F::from_canonical_i64)); - lv[MODULAR_QUO_INPUT].copy_from_slice("_limbs.map(F::from_noncanonical_i64)); // constr_poly must be zero when evaluated at x = β := // 2^LIMB_BITS, hence it's divisible by (x - β). `aux_limbs` is // the result of removing that root. @@ -286,11 +300,16 @@ fn generate_modular_op( nv[MODULAR_MOD_IS_ZERO] = mod_is_zero; nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(F::from_canonical_i64)); nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * lv[IS_DIV]; + + ( + output_limbs.map(F::from_canonical_i64), + quot_limbs.map(F::from_noncanonical_i64), + ) } /// Generate the output and auxiliary values for modular operations. /// -/// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`. +/// `filter` must be one of `columns::IS_{ADD,MUL,SUB}{MOD,FP254}`. 
pub(crate) fn generate( lv: &mut [F], nv: &mut [F], @@ -305,15 +324,29 @@ pub(crate) fn generate( u256_to_array(&mut lv[MODULAR_INPUT_1], input1); u256_to_array(&mut lv[MODULAR_MODULUS], modulus); - match filter { - columns::IS_ADDMOD => generate_modular_op(lv, nv, filter, pol_add), - columns::IS_SUBMOD => generate_modular_op(lv, nv, filter, pol_sub), - columns::IS_MULMOD => generate_modular_op(lv, nv, filter, pol_mul_wide), - columns::IS_MOD | columns::IS_DIV => { - generate_modular_op(lv, nv, filter, |a, _| pol_extend(a)) - } - _ => panic!("generate modular operation called with unknown opcode"), + if [ + columns::IS_ADDFP254, + columns::IS_SUBFP254, + columns::IS_MULFP254, + ] + .contains(&filter) + { + debug_assert!(modulus == BN_BASE); } + + // Inputs are all in [0, 2^16), so the "as i64" conversion is safe. + let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0); + let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1); + + let pol_input = match filter { + columns::IS_ADDMOD | columns::IS_ADDFP254 => pol_add(input0_limbs, input1_limbs), + columns::IS_SUBMOD | columns::IS_SUBFP254 => pol_sub(input0_limbs, input1_limbs), + columns::IS_MULMOD | columns::IS_MULFP254 => pol_mul_wide(input0_limbs, input1_limbs), + _ => panic!("generate modular operation called with unknown opcode"), + }; + let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, MODULAR_MODULUS); + lv[MODULAR_OUTPUT].copy_from_slice(&out); + lv[MODULAR_QUO_INPUT].copy_from_slice(&quo_input); } /// Build the part of the constraint polynomial that's common to all @@ -324,13 +357,15 @@ pub(crate) fn generate( /// c(x) + q(x) * m(x) + (x - β) * s(x) /// /// and check consistency when m = 0, and that c is reduced. -fn modular_constr_poly( +pub(crate) fn modular_constr_poly( lv: &[P; NUM_ARITH_COLUMNS], nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, filter: P, + mut output: [P; N_LIMBS], + mut modulus: [P; N_LIMBS], + quot: [P; 2 * N_LIMBS], ) -> [P; 2 * N_LIMBS] { - let mut modulus = read_value::(lv, MODULAR_MODULUS); let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; // Check that mod_is_zero is zero or one @@ -345,8 +380,6 @@ fn modular_constr_poly( // modulus = 0. modulus[0] += mod_is_zero; - let mut output = read_value::(lv, MODULAR_OUTPUT); - // Is 1 iff the operation is DIV and the denominator is zero. let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero)); @@ -365,7 +398,8 @@ fn modular_constr_poly( // modulus + out_aux_red == output + is_less_than*2^256 // // and we are given output = out_aux_red when modulus is zero. - let is_less_than = P::ONES - mod_is_zero * lv[IS_DIV]; + let mut is_less_than = [P::ZEROS; N_LIMBS]; + is_less_than[0] = P::ONES - mod_is_zero * lv[IS_DIV]; // NB: output and modulus in lv while out_aux_red and // is_less_than (via mod_is_zero) depend on nv, hence the // 'is_two_row_op' argument is set to 'true'. @@ -375,19 +409,13 @@ fn modular_constr_poly( &modulus, out_aux_red, &output, - is_less_than, + &is_less_than, true, ); // restore output[0] output[0] -= div_denom_is_zero; // prod = q(x) * m(x) - let quot = { - let mut quot = [P::default(); 2 * N_LIMBS]; - quot.copy_from_slice(&lv[MODULAR_QUO_INPUT]); - quot - }; - let prod = pol_mul_wide2(quot, modulus); // higher order terms must be zero for &x in prod[2 * N_LIMBS..].iter() { @@ -419,25 +447,34 @@ fn modular_constr_poly( } /// Add constraints for modular operations. -pub(crate) fn eval_packed_generic( +pub(crate) fn eval_packed( lv: &[P; NUM_ARITH_COLUMNS], nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, ) { // NB: The CTL code guarantees that filter is 0 or 1, i.e. that // only one of the operations below is "live". - let filter = lv[columns::IS_ADDMOD] - + lv[columns::IS_SUBMOD] - + lv[columns::IS_MULMOD] - + lv[columns::IS_MOD] - + lv[columns::IS_DIV]; + let bn254_filter = + lv[columns::IS_ADDFP254] + lv[columns::IS_MULFP254] + lv[columns::IS_SUBFP254]; + let filter = + lv[columns::IS_ADDMOD] + lv[columns::IS_SUBMOD] + lv[columns::IS_MULMOD] + bn254_filter; // Ensure that this operation is not the last row of the table; // needed because we access the next row of the table in nv. yield_constr.constraint_last_row(filter); + // Verify that the modulus is the BN254 modulus for the + // {ADD,MUL,SUB}FP254 operations. + let modulus = read_value::(lv, MODULAR_MODULUS); + for (&mi, bi) in modulus.iter().zip(bn254_modulus_limbs()) { + yield_constr.constraint_transition(bn254_filter * (mi - P::Scalar::from_canonical_u16(bi))); + } + + let output = read_value::(lv, MODULAR_OUTPUT); + let quo_input = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); + // constr_poly has 2*N_LIMBS limbs - let constr_poly = modular_constr_poly(lv, nv, yield_constr, filter); + let constr_poly = modular_constr_poly(lv, nv, yield_constr, filter, output, modulus, quo_input); let input0 = read_value(lv, MODULAR_INPUT_0); let input1 = read_value(lv, MODULAR_INPUT_1); @@ -445,13 +482,15 @@ pub(crate) fn eval_packed_generic( let add_input = pol_add(input0, input1); let sub_input = pol_sub(input0, input1); let mul_input = pol_mul_wide(input0, input1); - let mod_input = pol_extend(input0); + + let add_filter = lv[columns::IS_ADDMOD] + lv[columns::IS_ADDFP254]; + let sub_filter = lv[columns::IS_SUBMOD] + lv[columns::IS_SUBFP254]; + let mul_filter = lv[columns::IS_MULMOD] + lv[columns::IS_MULFP254]; for (input, &filter) in [ - (&add_input, &lv[columns::IS_ADDMOD]), - (&sub_input, &lv[columns::IS_SUBMOD]), - (&mul_input, &lv[columns::IS_MULMOD]), - (&mod_input, &(lv[columns::IS_MOD] + 
lv[columns::IS_DIV])), + (&add_input, &add_filter), + (&sub_input, &sub_filter), + (&mul_input, &mul_filter), ] { // Need constr_poly_copy to be the first argument to // pol_sub_assign, since it is the longer of the two @@ -473,14 +512,16 @@ pub(crate) fn eval_packed_generic( } } -fn modular_constr_poly_ext_circuit, const D: usize>( +pub(crate) fn modular_constr_poly_ext_circuit, const D: usize>( lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], builder: &mut CircuitBuilder, yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, + mut output: [ExtensionTarget; N_LIMBS], + mut modulus: [ExtensionTarget; N_LIMBS], + quot: [ExtensionTarget; 2 * N_LIMBS], ) -> [ExtensionTarget; 2 * N_LIMBS] { - let mut modulus = read_value::(lv, MODULAR_MODULUS); let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero); @@ -494,8 +535,6 @@ fn modular_constr_poly_ext_circuit, const D: usize> modulus[0] = builder.add_extension(modulus[0], mod_is_zero); - let mut output = read_value::(lv, MODULAR_OUTPUT); - let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; let t = builder.mul_sub_extension(mod_is_zero, lv[IS_DIV], div_denom_is_zero); let t = builder.mul_extension(filter, t); @@ -504,7 +543,9 @@ fn modular_constr_poly_ext_circuit, const D: usize> let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; let one = builder.one_extension(); - let is_less_than = + let zero = builder.zero_extension(); + let mut is_less_than = [zero; N_LIMBS]; + is_less_than[0] = builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one); eval_ext_circuit_addcy( @@ -514,16 +555,10 @@ fn modular_constr_poly_ext_circuit, const D: usize> &modulus, out_aux_red, &output, - is_less_than, + &is_less_than, true, ); output[0] = builder.sub_extension(output[0], div_denom_is_zero); - let quot = { - let zero = builder.zero_extension(); - let mut quot = [zero; 2 * N_LIMBS]; - 
quot.copy_from_slice(&lv[MODULAR_QUO_INPUT]); - quot - }; let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); for &x in prod[2 * N_LIMBS..].iter() { @@ -559,31 +594,60 @@ pub(crate) fn eval_ext_circuit, const D: usize>( nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { + let bn254_filter = builder.add_many_extension([ + lv[columns::IS_ADDFP254], + lv[columns::IS_MULFP254], + lv[columns::IS_SUBFP254], + ]); let filter = builder.add_many_extension([ lv[columns::IS_ADDMOD], lv[columns::IS_SUBMOD], lv[columns::IS_MULMOD], - lv[columns::IS_MOD], - lv[columns::IS_DIV], + bn254_filter, ]); yield_constr.constraint_last_row(builder, filter); - let constr_poly = modular_constr_poly_ext_circuit(lv, nv, builder, yield_constr, filter); + let modulus = read_value::(lv, MODULAR_MODULUS); + for (&mi, bi) in modulus.iter().zip(bn254_modulus_limbs()) { + // bn254_filter * (mi - bi) + let t = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u16(bi), + mi, + bn254_filter, + bn254_filter, + ); + yield_constr.constraint_transition(builder, t); + } + + let output = read_value::(lv, MODULAR_OUTPUT); + let quo_input = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); + + let constr_poly = modular_constr_poly_ext_circuit( + lv, + nv, + builder, + yield_constr, + filter, + output, + modulus, + quo_input, + ); let input0 = read_value(lv, MODULAR_INPUT_0); let input1 = read_value(lv, MODULAR_INPUT_1); let add_input = pol_add_ext_circuit(builder, input0, input1); let sub_input = pol_sub_ext_circuit(builder, input0, input1); let mul_input = pol_mul_wide_ext_circuit(builder, input0, input1); - let mod_input = pol_extend_ext_circuit(builder, input0); - let mod_div_filter = builder.add_extension(lv[columns::IS_MOD], lv[columns::IS_DIV]); + let add_filter = builder.add_extension(lv[columns::IS_ADDMOD], lv[columns::IS_ADDFP254]); + let sub_filter = builder.add_extension(lv[columns::IS_SUBMOD], lv[columns::IS_SUBFP254]); + let 
mul_filter = builder.add_extension(lv[columns::IS_MULMOD], lv[columns::IS_MULFP254]); for (input, &filter) in [ - (&add_input, &lv[columns::IS_ADDMOD]), - (&sub_input, &lv[columns::IS_SUBMOD]), - (&mul_input, &lv[columns::IS_MULMOD]), - (&mod_input, &mod_div_filter), + (&add_input, &add_filter), + (&sub_input, &sub_filter), + (&mul_input, &mul_filter), ] { let mut constr_poly_copy = constr_poly; pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input); @@ -604,8 +668,17 @@ mod tests { use super::*; use crate::arithmetic::columns::NUM_ARITH_COLUMNS; use crate::constraint_consumer::ConstraintConsumer; + use crate::extension_tower::BN_BASE; const N_RND_TESTS: usize = 1000; + const MODULAR_OPS: [usize; 6] = [ + IS_ADDMOD, + IS_SUBMOD, + IS_MULMOD, + IS_ADDFP254, + IS_SUBFP254, + IS_MULFP254, + ]; // TODO: Should be able to refactor this test to apply to all operations. #[test] @@ -617,12 +690,12 @@ mod tests { let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); // if `IS_ADDMOD == 0`, then the constraints should be met even - // if all values are garbage. - lv[IS_ADDMOD] = F::ZERO; - lv[IS_SUBMOD] = F::ZERO; - lv[IS_MULMOD] = F::ZERO; - lv[IS_MOD] = F::ZERO; + // if all values are garbage (and similarly for the other operations). 
+ for op in MODULAR_OPS { + lv[op] = F::ZERO; + } lv[IS_DIV] = F::ZERO; + lv[IS_MOD] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], @@ -630,7 +703,7 @@ mod tests { GoldilocksField::ONE, GoldilocksField::ONE, ); - eval_packed_generic(&lv, &nv, &mut constraint_consumer); + eval_packed(&lv, &nv, &mut constraint_consumer); for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } @@ -642,7 +715,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); - for op_filter in [IS_ADDMOD, IS_DIV, IS_SUBMOD, IS_MOD, IS_MULMOD] { + for op_filter in MODULAR_OPS { for i in 0..N_RND_TESTS { // set inputs to random values let mut lv = [F::default(); NUM_ARITH_COLUMNS] @@ -651,28 +724,32 @@ mod tests { .map(|_| F::from_canonical_u16(rng.gen::())); // Reset operation columns, then select one - lv[IS_ADDMOD] = F::ZERO; - lv[IS_SUBMOD] = F::ZERO; - lv[IS_MULMOD] = F::ZERO; - lv[IS_MOD] = F::ZERO; + for op in MODULAR_OPS { + lv[op] = F::ZERO; + } lv[IS_DIV] = F::ZERO; + lv[IS_MOD] = F::ZERO; lv[op_filter] = F::ONE; let input0 = U256::from(rng.gen::<[u8; 32]>()); let input1 = U256::from(rng.gen::<[u8; 32]>()); - let mut modulus_limbs = [0u8; 32]; - // For the second half of the tests, set the top - // 16-start digits of the modulus to zero so it is - // much smaller than the inputs. - if i > N_RND_TESTS / 2 { - // 1 <= start < N_LIMBS - let start = (rng.gen::() % (modulus_limbs.len() - 1)) + 1; - for mi in modulus_limbs.iter_mut().skip(start) { - *mi = 0u8; + let modulus = if [IS_ADDFP254, IS_MULFP254, IS_SUBFP254].contains(&op_filter) { + BN_BASE + } else { + let mut modulus_limbs = [0u8; 32]; + // For the second half of the tests, set the top + // 16-start digits of the modulus to zero so it is + // much smaller than the inputs. 
+ if i > N_RND_TESTS / 2 { + // 1 <= start < N_LIMBS + let start = (rng.gen::() % (modulus_limbs.len() - 1)) + 1; + for mi in modulus_limbs.iter_mut().skip(start) { + *mi = 0u8; + } } - } - let modulus = U256::from(modulus_limbs); + U256::from(modulus_limbs) + }; generate(&mut lv, &mut nv, op_filter, input0, input1, modulus); @@ -682,7 +759,7 @@ mod tests { GoldilocksField::ZERO, GoldilocksField::ZERO, ); - eval_packed_generic(&lv, &nv, &mut constraint_consumer); + eval_packed(&lv, &nv, &mut constraint_consumer); for &acc in &constraint_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } @@ -696,7 +773,7 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); - for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_DIV, IS_MOD, IS_MULMOD] { + for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_MULMOD] { for _i in 0..N_RND_TESTS { // set inputs to random values and the modulus to zero; // the output is defined to be zero when modulus is zero. @@ -706,11 +783,11 @@ mod tests { .map(|_| F::from_canonical_u16(rng.gen::())); // Reset operation columns, then select one - lv[IS_ADDMOD] = F::ZERO; - lv[IS_SUBMOD] = F::ZERO; - lv[IS_MULMOD] = F::ZERO; - lv[IS_MOD] = F::ZERO; + for op in MODULAR_OPS { + lv[op] = F::ZERO; + } lv[IS_DIV] = F::ZERO; + lv[IS_MOD] = F::ZERO; lv[op_filter] = F::ONE; let input0 = U256::from(rng.gen::<[u8; 32]>()); @@ -720,11 +797,7 @@ mod tests { generate(&mut lv, &mut nv, op_filter, input0, input1, modulus); // check that the correct output was generated - if op_filter == IS_DIV { - assert!(lv[DIV_OUTPUT].iter().all(|&c| c == F::ZERO)); - } else { - assert!(lv[MODULAR_OUTPUT].iter().all(|&c| c == F::ZERO)); - } + assert!(lv[MODULAR_OUTPUT].iter().all(|&c| c == F::ZERO)); let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], @@ -732,22 +805,17 @@ mod tests { GoldilocksField::ZERO, GoldilocksField::ZERO, ); - eval_packed_generic(&lv, &nv, &mut constraint_consumer); + 
eval_packed(&lv, &nv, &mut constraint_consumer); assert!(constraint_consumer .constraint_accs .iter() .all(|&acc| acc == F::ZERO)); // Corrupt one output limb by setting it to a non-zero value - if op_filter == IS_DIV { - let random_oi = DIV_OUTPUT.start + rng.gen::() % N_LIMBS; - lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); - } else { - let random_oi = MODULAR_OUTPUT.start + rng.gen::() % N_LIMBS; - lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); - }; + let random_oi = MODULAR_OUTPUT.start + rng.gen::() % N_LIMBS; + lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); - eval_packed_generic(&lv, &nv, &mut constraint_consumer); + eval_packed(&lv, &nv, &mut constraint_consumer); // Check that at least one of the constraints was non-zero assert!(constraint_consumer diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index 03acfa97..597d4051 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -70,11 +70,12 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { // TODO: It would probably be clearer/cleaner to read the U256 // into an [i64;N] and then copy that to the lv table. 
- u256_to_array(&mut lv[MUL_INPUT_0], left_in); - u256_to_array(&mut lv[MUL_INPUT_1], right_in); + u256_to_array(&mut lv[INPUT_REGISTER_0], left_in); + u256_to_array(&mut lv[INPUT_REGISTER_1], right_in); + u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero()); - let input0 = read_value_i64_limbs(lv, MUL_INPUT_0); - let input1 = read_value_i64_limbs(lv, MUL_INPUT_1); + let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_0); + let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_1); const MASK: i64 = (1i64 << LIMB_BITS) - 1i64; @@ -96,7 +97,7 @@ pub fn generate(lv: &mut [F], left_in: U256, right_in: U256) { // aux_limbs to handle the fact that unreduced_prod will // inevitably contain one digit's worth that is > 2^256. - lv[MUL_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); + lv[OUTPUT_REGISTER].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); pol_sub_assign(&mut unreduced_prod, &output_limbs); let mut aux_limbs = pol_remove_root_2exp::(unreduced_prod); @@ -121,9 +122,9 @@ pub fn eval_packed_generic( let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); let is_mul = lv[IS_MUL]; - let input0_limbs = read_value::(lv, MUL_INPUT_0); - let input1_limbs = read_value::(lv, MUL_INPUT_1); - let output_limbs = read_value::(lv, MUL_OUTPUT); + let input0_limbs = read_value::(lv, INPUT_REGISTER_0); + let input1_limbs = read_value::(lv, INPUT_REGISTER_1); + let output_limbs = read_value::(lv, OUTPUT_REGISTER); let aux_limbs = { // MUL_AUX_INPUT was offset by 2^20 in generation, so we undo @@ -173,9 +174,9 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_mul = lv[IS_MUL]; - let input0_limbs = read_value::(lv, MUL_INPUT_0); - let input1_limbs = read_value::(lv, MUL_INPUT_1); - let output_limbs = read_value::(lv, MUL_OUTPUT); + let input0_limbs = read_value::(lv, INPUT_REGISTER_0); + let input1_limbs = read_value::(lv, INPUT_REGISTER_1); + let output_limbs = read_value::(lv, 
OUTPUT_REGISTER); let aux_limbs = { let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS)); @@ -253,7 +254,7 @@ mod tests { for _i in 0..N_RND_TESTS { // set inputs to random values - for (ai, bi) in MUL_INPUT_0.zip(MUL_INPUT_1) { + for (ai, bi) in INPUT_REGISTER_0.zip(INPUT_REGISTER_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index 8b4f546a..6ea375fe 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ -227,17 +227,6 @@ where zero_extend } -pub(crate) fn pol_extend_ext_circuit, const D: usize>( - builder: &mut CircuitBuilder, - a: [ExtensionTarget; N_LIMBS], -) -> [ExtensionTarget; 2 * N_LIMBS - 1] { - let zero = builder.zero_extension(); - let mut zero_extend = [zero; 2 * N_LIMBS - 1]; - - zero_extend[..N_LIMBS].copy_from_slice(&a); - zero_extend -} - /// Given polynomial a(x) = \sum_{i=0}^{N-2} a[i] x^i and an element /// `root`, return b = (x - root) * a(x). 
pub(crate) fn pol_adjoin_root(a: [T; N], root: U) -> [T; N] diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index a8c83e4f..069a1609 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -8,6 +8,7 @@ use plonky2::field::packed::PackedField; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use crate::all_stark::Table; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::membus::NUM_GP_CHANNELS; @@ -15,7 +16,7 @@ use crate::cpu::{ bootstrap_kernel, contextops, control_flow, decode, dup_swap, gas, jumps, membus, memio, modfp254, pc, shift, simple_logic, stack, stack_bounds, syscalls, }; -use crate::cross_table_lookup::Column; +use crate::cross_table_lookup::{Column, TableWithColumns}; use crate::memory::segments::Segment; use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::stark::Stark; @@ -45,8 +46,10 @@ pub fn ctl_filter_keccak_sponge() -> Column { Column::single(COL_MAP.is_keccak_sponge) } -pub fn ctl_data_logic() -> Vec> { - let mut res = Column::singles([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]).collect_vec(); +/// Create the vector of Columns corresponding to the two inputs and +/// one output of a binary operation. +fn ctl_data_binops(ops: &[usize]) -> Vec> { + let mut res = Column::singles(ops).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[0].value)); res.extend(Column::singles(COL_MAP.mem_channels[1].value)); res.extend(Column::singles( @@ -55,10 +58,51 @@ pub fn ctl_data_logic() -> Vec> { res } +/// Create the vector of Columns corresponding to the three inputs and +/// one output of a ternary operation. 
+fn ctl_data_ternops(ops: &[usize]) -> Vec> { + let mut res = Column::singles(ops).collect_vec(); + res.extend(Column::singles(COL_MAP.mem_channels[0].value)); + res.extend(Column::singles(COL_MAP.mem_channels[1].value)); + res.extend(Column::singles(COL_MAP.mem_channels[2].value)); + res.extend(Column::singles( + COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, + )); + res +} + +pub fn ctl_data_logic() -> Vec> { + ctl_data_binops(&[COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]) +} + pub fn ctl_filter_logic() -> Column { Column::sum([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]) } +pub fn ctl_arithmetic_rows() -> TableWithColumns { + const OPS: [usize; 13] = [ + COL_MAP.op.add, + COL_MAP.op.sub, + COL_MAP.op.mul, + COL_MAP.op.lt, + COL_MAP.op.gt, + COL_MAP.op.addfp254, + COL_MAP.op.mulfp254, + COL_MAP.op.subfp254, + COL_MAP.op.addmod, + COL_MAP.op.mulmod, + COL_MAP.op.submod, + COL_MAP.op.div, + COL_MAP.op.mod_, + ]; + // Create the CPU Table whose columns are those with the three + // inputs and one output of the ternary operations listed in `ops` + // (also `ops` is used as the operation filter). The list of + // operations includes binary operations which will simply ignore + // the third input. 
+ TableWithColumns::new(Table::Cpu, ctl_data_ternops(&OPS), Some(Column::sum(OPS))) +} + pub const MEM_CODE_CHANNEL_IDX: usize = 0; pub const MEM_GP_CHANNELS_IDX_START: usize = MEM_CODE_CHANNEL_IDX + 1; diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 483aaf54..1f9bf820 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -26,6 +26,7 @@ use plonky2::util::timing::TimingTree; use plonky2_util::log2_ceil; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; +use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::config::StarkConfig; use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; @@ -265,6 +266,7 @@ where F: RichField + Extendable, C: GenericConfig + 'static, C::Hasher: AlgebraicHasher, + [(); ArithmeticStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -338,43 +340,50 @@ where degree_bits_ranges: &[Range; NUM_TABLES], stark_config: &StarkConfig, ) -> Self { + let arithmetic = RecursiveCircuitsForTable::new( + Table::Arithmetic, + &all_stark.arithmetic_stark, + degree_bits_ranges[0].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); let cpu = RecursiveCircuitsForTable::new( Table::Cpu, &all_stark.cpu_stark, - degree_bits_ranges[0].clone(), + degree_bits_ranges[1].clone(), &all_stark.cross_table_lookups, stark_config, ); let keccak = RecursiveCircuitsForTable::new( Table::Keccak, &all_stark.keccak_stark, - degree_bits_ranges[1].clone(), + degree_bits_ranges[2].clone(), &all_stark.cross_table_lookups, stark_config, ); let keccak_sponge = RecursiveCircuitsForTable::new( Table::KeccakSponge, &all_stark.keccak_sponge_stark, - degree_bits_ranges[2].clone(), + degree_bits_ranges[3].clone(), &all_stark.cross_table_lookups, stark_config, ); let logic = RecursiveCircuitsForTable::new( Table::Logic, 
&all_stark.logic_stark, - degree_bits_ranges[3].clone(), + degree_bits_ranges[4].clone(), &all_stark.cross_table_lookups, stark_config, ); let memory = RecursiveCircuitsForTable::new( Table::Memory, &all_stark.memory_stark, - degree_bits_ranges[4].clone(), + degree_bits_ranges[5].clone(), &all_stark.cross_table_lookups, stark_config, ); - let by_table = [cpu, keccak, keccak_sponge, logic, memory]; + let by_table = [arithmetic, cpu, keccak, keccak_sponge, logic, memory]; let root = Self::create_root_circuit(&by_table, stark_config); let aggregation = Self::create_aggregation_circuit(&root); let block = Self::create_block_circuit(&aggregation); diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 9f087c69..414b8d50 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -20,6 +20,7 @@ use plonky2_maybe_rayon::*; use plonky2_util::{log2_ceil, log2_strict}; use crate::all_stark::{AllStark, Table, NUM_TABLES}; +use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; @@ -50,6 +51,7 @@ pub fn prove( where F: RichField + Extendable, C: GenericConfig, + [(); ArithmeticStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -71,6 +73,7 @@ pub fn prove_with_outputs( where F: RichField + Extendable, C: GenericConfig, + [(); ArithmeticStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -98,6 +101,7 @@ pub(crate) fn prove_with_traces( where F: RichField + Extendable, C: GenericConfig, + [(); ArithmeticStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -185,12 +189,26 @@ fn prove_with_commitments( where F: RichField + Extendable, C: GenericConfig, + [(); ArithmeticStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); 
KeccakSpongeStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, { + let arithmetic_proof = timed!( + timing, + "prove Arithmetic STARK", + prove_single_table( + &all_stark.arithmetic_stark, + config, + &trace_poly_values[Table::Arithmetic as usize], + &trace_commitments[Table::Arithmetic as usize], + &ctl_data_per_table[Table::Arithmetic as usize], + challenger, + timing, + )? + ); let cpu_proof = timed!( timing, "prove CPU STARK", @@ -257,6 +275,7 @@ where )? ); Ok([ + arithmetic_proof, cpu_proof, keccak_proof, keccak_sponge_proof, diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 807ca203..5e68350b 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -9,6 +9,7 @@ use plonky2::plonk::config::GenericConfig; use plonky2::plonk::plonk_common::reduce_with_powers; use crate::all_stark::{AllStark, Table}; +use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; @@ -31,6 +32,7 @@ pub fn verify_proof, C: GenericConfig, co config: &StarkConfig, ) -> Result<()> where + [(); ArithmeticStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -45,6 +47,7 @@ where let nums_permutation_zs = all_stark.nums_permutation_zs(config); let AllStark { + arithmetic_stark, cpu_stark, keccak_stark, keccak_sponge_stark, @@ -60,6 +63,13 @@ where &nums_permutation_zs, ); + verify_stark_proof_with_challenges( + arithmetic_stark, + &all_proof.stark_proofs[Table::Arithmetic as usize].proof, + &stark_challenges[Table::Arithmetic as usize], + &ctl_vars_per_table[Table::Arithmetic as usize], + config, + )?; verify_stark_proof_with_challenges( cpu_stark, &all_proof.stark_proofs[Table::Cpu as usize].proof, diff --git a/evm/src/witness/traces.rs b/evm/src/witness/traces.rs index c904d2e5..4a1c8d85 100644 --- a/evm/src/witness/traces.rs +++ b/evm/src/witness/traces.rs @@ 
-18,19 +18,19 @@ use crate::{arithmetic, keccak, logic}; #[derive(Clone, Copy, Debug)] pub struct TraceCheckpoint { + pub(self) arithmetic_len: usize, pub(self) cpu_len: usize, pub(self) keccak_len: usize, pub(self) keccak_sponge_len: usize, pub(self) logic_len: usize, - pub(self) arithmetic_len: usize, pub(self) memory_len: usize, } #[derive(Debug)] pub(crate) struct Traces { + pub(crate) arithmetic_ops: Vec, pub(crate) cpu: Vec>, pub(crate) logic_ops: Vec, - pub(crate) arithmetic: Vec, pub(crate) memory_ops: Vec, pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, pub(crate) keccak_sponge_ops: Vec, @@ -39,9 +39,9 @@ pub(crate) struct Traces { impl Traces { pub fn new() -> Self { Traces { + arithmetic_ops: vec![], cpu: vec![], logic_ops: vec![], - arithmetic: vec![], memory_ops: vec![], keccak_inputs: vec![], keccak_sponge_ops: vec![], @@ -50,22 +50,22 @@ impl Traces { pub fn checkpoint(&self) -> TraceCheckpoint { TraceCheckpoint { + arithmetic_len: self.arithmetic_ops.len(), cpu_len: self.cpu.len(), keccak_len: self.keccak_inputs.len(), keccak_sponge_len: self.keccak_sponge_ops.len(), logic_len: self.logic_ops.len(), - arithmetic_len: self.arithmetic.len(), memory_len: self.memory_ops.len(), } } pub fn rollback(&mut self, checkpoint: TraceCheckpoint) { + self.arithmetic_ops.truncate(checkpoint.arithmetic_len); self.cpu.truncate(checkpoint.cpu_len); self.keccak_inputs.truncate(checkpoint.keccak_len); self.keccak_sponge_ops .truncate(checkpoint.keccak_sponge_len); self.logic_ops.truncate(checkpoint.logic_len); - self.arithmetic.truncate(checkpoint.arithmetic_len); self.memory_ops.truncate(checkpoint.memory_len); } @@ -82,7 +82,7 @@ impl Traces { } pub fn push_arithmetic(&mut self, op: arithmetic::Operation) { - self.arithmetic.push(op); + self.arithmetic_ops.push(op); } pub fn push_memory(&mut self, op: MemoryOp) { @@ -122,14 +122,20 @@ impl Traces { { let cap_elements = config.fri_config.num_cap_elements(); let Traces { + arithmetic_ops, cpu, 
logic_ops, - arithmetic: _, // TODO memory_ops, keccak_inputs, keccak_sponge_ops, } = self; + let arithmetic_trace = timed!( + timing, + "generate arithmetic trace", + all_stark.arithmetic_stark.generate_trace(arithmetic_ops) + ); + let cpu_rows = cpu.into_iter().map(|x| x.into()).collect(); let cpu_trace = trace_rows_to_poly_values(cpu_rows); let keccak_trace = timed!( @@ -160,6 +166,7 @@ impl Traces { ); [ + arithmetic_trace, cpu_trace, keccak_trace, keccak_sponge_trace, diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index ec2d999b..4b2a2762 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -97,7 +97,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let all_circuits = AllRecursiveCircuits::::new( &all_stark, - &[9..15, 9..15, 9..10, 9..12, 9..18], // Minimal ranges to prove an empty list + &[9..18, 9..15, 9..15, 9..10, 9..12, 9..18], // Minimal ranges to prove an empty list &config, );