From 8beba56903381ff37bfff9aa8a3b7c6fe67d1b15 Mon Sep 17 00:00:00 2001 From: Hamy Ratoanina Date: Mon, 28 Aug 2023 16:32:04 -0400 Subject: [PATCH 1/3] Constrain next row's stack length --- evm/src/cpu/cpu_stark.rs | 8 ++++---- evm/src/cpu/jumps.rs | 14 +++++++++++--- evm/src/cpu/simple_logic/eq_iszero.rs | 7 ++++++- evm/src/cpu/simple_logic/mod.rs | 6 ++++-- evm/src/cpu/stack.rs | 23 +++++++++++++++++++++-- 5 files changed, 46 insertions(+), 12 deletions(-) diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 7fd0c76f..5aa64c1a 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -247,8 +247,8 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark( let is_jumpi = filter * lv.opcode_bits[0]; // Stack constraints. - stack::eval_packed_one(lv, is_jump, stack::JUMP_OP.unwrap(), yield_constr); - stack::eval_packed_one(lv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr); + stack::eval_packed_one(lv, nv, is_jump, stack::JUMP_OP.unwrap(), yield_constr); + stack::eval_packed_one(lv, nv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr); // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`. @@ -151,10 +151,18 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> let is_jumpi = builder.mul_extension(filter, lv.opcode_bits[0]); // Stack constraints. - stack::eval_ext_circuit_one(builder, lv, is_jump, stack::JUMP_OP.unwrap(), yield_constr); stack::eval_ext_circuit_one( builder, lv, + nv, + is_jump, + stack::JUMP_OP.unwrap(), + yield_constr, + ); + stack::eval_ext_circuit_one( + builder, + lv, + nv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr, diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index f16901f5..7be021ca 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -51,6 +51,7 @@ pub fn generate_pinv_diff(val0: U256, val1: U256, lv: &mut CpuColumnsV pub fn eval_packed( lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { let logic = lv.general.logic(); @@ -94,9 +95,10 @@ pub fn eval_packed( yield_constr.constraint(eq_or_iszero_filter * (dot - unequal)); // Stack constraints. - stack::eval_packed_one(lv, eq_filter, EQ_STACK_BEHAVIOR.unwrap(), yield_constr); + stack::eval_packed_one(lv, nv, eq_filter, EQ_STACK_BEHAVIOR.unwrap(), yield_constr); stack::eval_packed_one( lv, + nv, iszero_filter, IS_ZERO_STACK_BEHAVIOR.unwrap(), yield_constr, @@ -106,6 +108,7 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let zero = builder.zero_extension(); @@ -173,6 +176,7 @@ pub fn eval_ext_circuit, const D: usize>( stack::eval_ext_circuit_one( builder, lv, + nv, eq_filter, EQ_STACK_BEHAVIOR.unwrap(), yield_constr, @@ -180,6 +184,7 @@ pub fn eval_ext_circuit, const D: usize>( stack::eval_ext_circuit_one( builder, lv, + nv, iszero_filter, IS_ZERO_STACK_BEHAVIOR.unwrap(), yield_constr, diff --git a/evm/src/cpu/simple_logic/mod.rs b/evm/src/cpu/simple_logic/mod.rs index 03d2dd15..9b4e60b0 100644 --- a/evm/src/cpu/simple_logic/mod.rs +++ b/evm/src/cpu/simple_logic/mod.rs @@ -11,17 +11,19 @@ use crate::cpu::columns::CpuColumnsView; pub fn eval_packed( lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { not::eval_packed(lv, yield_constr); - eq_iszero::eval_packed(lv, yield_constr); + eq_iszero::eval_packed(lv, nv, yield_constr); } pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { not::eval_ext_circuit(builder, lv, yield_constr); - eq_iszero::eval_ext_circuit(builder, lv, yield_constr); + eq_iszero::eval_ext_circuit(builder, lv, nv, yield_constr); } diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 8ffc152d..198c76db 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -140,6 +140,7 @@ pub(crate) const IS_ZERO_STACK_BEHAVIOR: Option = BASIC_UNARY_OP; pub(crate) fn eval_packed_one( lv: &CpuColumnsView

, + nv: &CpuColumnsView

, filter: P, stack_behavior: StackBehavior, yield_constr: &mut ConstraintConsumer

, @@ -185,15 +186,21 @@ pub(crate) fn eval_packed_one( yield_constr.constraint(filter * channel.used); } } + + // Constrain new stack length. + let num_pops = P::Scalar::from_canonical_usize(stack_behavior.num_pops); + let push = P::Scalar::from_canonical_usize(stack_behavior.pushes as usize); + yield_constr.constraint_transition(filter * (nv.stack_len - (lv.stack_len - num_pops + push))); } pub fn eval_packed( lv: &CpuColumnsView

, + nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { for (op, stack_behavior) in izip!(lv.op.into_iter(), STACK_BEHAVIORS.into_iter()) { if let Some(stack_behavior) = stack_behavior { - eval_packed_one(lv, op, stack_behavior, yield_constr); + eval_packed_one(lv, nv, op, stack_behavior, yield_constr); } } } @@ -201,6 +208,7 @@ pub fn eval_packed( pub(crate) fn eval_ext_circuit_one, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, filter: ExtensionTarget, stack_behavior: StackBehavior, yield_constr: &mut RecursiveConstraintConsumer, @@ -298,16 +306,27 @@ pub(crate) fn eval_ext_circuit_one, const D: usize> yield_constr.constraint(builder, constr); } } + + // Constrain new stack length. + let diff = builder.constant_extension( + F::Extension::from_canonical_usize(stack_behavior.num_pops) + - F::Extension::from_canonical_usize(stack_behavior.pushes as usize), + ); + let diff = builder.sub_extension(lv.stack_len, diff); + let diff = builder.sub_extension(nv.stack_len, diff); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint_transition(builder, constr); } pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { for (op, stack_behavior) in izip!(lv.op.into_iter(), STACK_BEHAVIORS.into_iter()) { if let Some(stack_behavior) = stack_behavior { - eval_ext_circuit_one(builder, lv, op, stack_behavior, yield_constr); + eval_ext_circuit_one(builder, lv, nv, op, stack_behavior, yield_constr); } } } From 06bc73f7ea62d91684596df4535feb27c9c70ef8 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:36:48 -0400 Subject: [PATCH 2/3] Combine arithmetic flags on the CPU side (#1187) * Combine FP254 flags * Combine basic binary ops together and do CTL with opcode value * Combine ternary ops together * Combine MUL DIV and MOD * Combine shift operations * Combine byte with other binary ops * Fix tests * Clean leftover comment * Update from latest main * Put the 'is_simulated' flag inside the Operation enum * Cleaner way to handle "simulated" operations SHL and SHR. * Fix comments. * Minor: suggestion for re-expressing `combined_ops`. * Update comment --------- Co-authored-by: Hamish Ivey-Law --- evm/src/arithmetic/arithmetic_stark.rs | 53 ++++++++++++------- evm/src/arithmetic/columns.rs | 4 +- evm/src/arithmetic/divmod.rs | 15 ++++-- evm/src/arithmetic/mod.rs | 32 ++++++++++-- evm/src/arithmetic/mul.rs | 6 ++- evm/src/cpu/columns/ops.rs | 30 +++-------- evm/src/cpu/control_flow.rs | 20 ++------ evm/src/cpu/cpu_stark.rs | 66 +++++++----------------- evm/src/cpu/decode.rs | 61 +++++++++++----------- evm/src/cpu/gas.rs | 70 ++++++++++++++++++++------ evm/src/cpu/modfp254.rs | 4 +- evm/src/cpu/shift.rs | 4 +- evm/src/cpu/stack.rs | 24 ++------- evm/src/witness/gas.rs | 4 +- evm/src/witness/operation.rs | 13 +++-- evm/src/witness/transition.rs | 31 ++++-------- 16 files changed, 223 insertions(+), 214 deletions(-) diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 4695798a..5441cf27 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -27,10 +27,17 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// This is done by taking pairs of columns (x, y) of the arithmetic /// table and combining them as x + y*2^16 to ensure they equal the /// corresponding 32-bit number in the CPU table. -fn cpu_arith_data_link(ops: &[usize], regs: &[Range]) -> Vec> { +fn cpu_arith_data_link( + combined_ops: &[(usize, u8)], + regs: &[Range], +) -> Vec> { let limb_base = F::from_canonical_u64(1 << columns::LIMB_BITS); - let mut res = Column::singles(ops).collect_vec(); + let mut res = vec![Column::linear_combination( + combined_ops + .iter() + .map(|&(col, code)| (col, F::from_canonical_u8(code))), + )]; // The inner for loop below assumes N_LIMBS is even. const_assert!(columns::N_LIMBS % 2 == 0); @@ -49,21 +56,27 @@ fn cpu_arith_data_link(ops: &[usize], regs: &[Range]) -> Vec() -> TableWithColumns { - const ARITH_OPS: [usize; 14] = [ - columns::IS_ADD, - columns::IS_SUB, - columns::IS_MUL, - columns::IS_LT, - columns::IS_GT, - columns::IS_ADDFP254, - columns::IS_MULFP254, - columns::IS_SUBFP254, - columns::IS_ADDMOD, - columns::IS_MULMOD, - columns::IS_SUBMOD, - columns::IS_DIV, - columns::IS_MOD, - columns::IS_BYTE, + // We scale each filter flag with the associated opcode value. + // If an arithmetic operation is happening on the CPU side, + // the CTL will enforce that the reconstructed opcode value + // from the opcode bits matches. + const COMBINED_OPS: [(usize, u8); 16] = [ + (columns::IS_ADD, 0x01), + (columns::IS_MUL, 0x02), + (columns::IS_SUB, 0x03), + (columns::IS_DIV, 0x04), + (columns::IS_MOD, 0x06), + (columns::IS_ADDMOD, 0x08), + (columns::IS_MULMOD, 0x09), + (columns::IS_ADDFP254, 0x0c), + (columns::IS_MULFP254, 0x0d), + (columns::IS_SUBFP254, 0x0e), + (columns::IS_SUBMOD, 0x0f), + (columns::IS_LT, 0x10), + (columns::IS_GT, 0x11), + (columns::IS_BYTE, 0x1a), + (columns::IS_SHL, 0x1b), + (columns::IS_SHR, 0x1c), ]; const REGISTER_MAP: [Range; 4] = [ @@ -73,6 +86,8 @@ pub fn ctl_arithmetic_rows() -> TableWithColumns { columns::OUTPUT_REGISTER, ]; + let filter_column = Some(Column::sum(COMBINED_OPS.iter().map(|(c, _v)| *c))); + // Create the Arithmetic Table whose columns are those of the // operations listed in `ops` whose inputs and outputs are given // by `regs`, where each element of `regs` is a range of columns @@ -80,8 +95,8 @@ pub fn ctl_arithmetic_rows() -> TableWithColumns { // is used as the operation filter). TableWithColumns::new( Table::Arithmetic, - cpu_arith_data_link(&ARITH_OPS, ®ISTER_MAP), - Some(Column::sum(ARITH_OPS)), + cpu_arith_data_link(&COMBINED_OPS, ®ISTER_MAP), + filter_column, ) } diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index afdd5832..48e00f8e 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -36,8 +36,10 @@ pub(crate) const IS_SUBMOD: usize = IS_SUBFP254 + 1; pub(crate) const IS_LT: usize = IS_SUBMOD + 1; pub(crate) const IS_GT: usize = IS_LT + 1; pub(crate) const IS_BYTE: usize = IS_GT + 1; +pub(crate) const IS_SHL: usize = IS_BYTE + 1; +pub(crate) const IS_SHR: usize = IS_SHL + 1; -pub(crate) const START_SHARED_COLS: usize = IS_BYTE + 1; +pub(crate) const START_SHARED_COLS: usize = IS_SHR + 1; /// Within the Arithmetic Unit, there are shared columns which can be /// used by any arithmetic circuit, depending on which one is active diff --git a/evm/src/arithmetic/divmod.rs b/evm/src/arithmetic/divmod.rs index 4f2dd748..258c131f 100644 --- a/evm/src/arithmetic/divmod.rs +++ b/evm/src/arithmetic/divmod.rs @@ -45,7 +45,7 @@ pub(crate) fn generate( } match filter { - IS_DIV => { + IS_DIV | IS_SHR => { debug_assert!( lv[OUTPUT_REGISTER] .iter() @@ -104,11 +104,14 @@ pub(crate) fn eval_packed( nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer

, ) { + // Constrain IS_SHR independently, so that it doesn't impact the + // constraints when combining the flag with IS_DIV. + yield_constr.constraint_last_row(lv[IS_SHR]); eval_packed_divmod_helper( lv, nv, yield_constr, - lv[IS_DIV], + lv[IS_DIV] + lv[IS_SHR], OUTPUT_REGISTER, AUX_INPUT_REGISTER_0, ); @@ -161,12 +164,14 @@ pub(crate) fn eval_ext_circuit, const D: usize>( nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { + yield_constr.constraint_last_row(builder, lv[IS_SHR]); + let div_shr_flag = builder.add_extension(lv[IS_DIV], lv[IS_SHR]); eval_ext_circuit_divmod_helper( builder, lv, nv, yield_constr, - lv[IS_DIV], + div_shr_flag, OUTPUT_REGISTER, AUX_INPUT_REGISTER_0, ); @@ -209,6 +214,8 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + // Deactivate the SHR flag so that a DIV operation is not triggered. + lv[IS_SHR] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], @@ -240,6 +247,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[op_filter] = F::ONE; let input0 = U256::from(rng.gen::<[u8; 32]>()); @@ -300,6 +308,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[op_filter] = F::ONE; let input0 = U256::from(rng.gen::<[u8; 32]>()); diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index d9d63a0b..bd6d56e8 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -27,15 +27,17 @@ pub(crate) enum BinaryOperator { MulFp254, SubFp254, Byte, + Shl, // simulated with MUL + Shr, // simulated with DIV } impl BinaryOperator { pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 { match self { BinaryOperator::Add => input0.overflowing_add(input1).0, - BinaryOperator::Mul => input0.overflowing_mul(input1).0, + BinaryOperator::Mul | BinaryOperator::Shl => input0.overflowing_mul(input1).0, BinaryOperator::Sub => input0.overflowing_sub(input1).0, - BinaryOperator::Div => { + BinaryOperator::Div | BinaryOperator::Shr => { if input1.is_zero() { U256::zero() } else { @@ -77,6 +79,8 @@ impl BinaryOperator { BinaryOperator::MulFp254 => columns::IS_MULFP254, BinaryOperator::SubFp254 => columns::IS_SUBFP254, BinaryOperator::Byte => columns::IS_BYTE, + BinaryOperator::Shl => columns::IS_SHL, + BinaryOperator::Shr => columns::IS_SHR, } } } @@ -107,6 +111,7 @@ impl TernaryOperator { } } +/// An enum representing arithmetic operations that can be either binary or ternary. #[derive(Debug)] pub(crate) enum Operation { BinaryOperation { @@ -125,6 +130,21 @@ pub(crate) enum Operation { } impl Operation { + /// Create a binary operator with given inputs. + /// + /// NB: This works as you would expect, EXCEPT for SHL and SHR, + /// whose inputs need a small amount of preprocessing. Specifically, + /// to create `SHL(shift, value)`, call (note the reversal of + /// argument order): + /// + /// `Operation::binary(BinaryOperator::Shl, value, 1 << shift)` + /// + /// Similarly, to create `SHR(shift, value)`, call + /// + /// `Operation::binary(BinaryOperator::Shr, value, 1 << shift)` + /// + /// See witness/operation.rs::append_shift() for an example (indeed + /// the only call site for such inputs). pub(crate) fn binary(operator: BinaryOperator, input0: U256, input1: U256) -> Self { let result = operator.result(input0, input1); Self::BinaryOperation { @@ -164,6 +184,10 @@ impl Operation { /// use vectors because that's what utils::transpose (who consumes /// the result of this function as part of the range check code) /// expects. + /// + /// The `is_simulated` bool indicates whether we use a native arithmetic + /// operation or simulate one with another. This is used to distinguish + /// SHL and SHR operations that are simulated through MUL and DIV respectively. fn to_rows(&self) -> (Vec, Option>) { match *self { Operation::BinaryOperation { @@ -214,11 +238,11 @@ fn binary_op_to_rows( addcy::generate(&mut row, op.row_filter(), input0, input1); (row, None) } - BinaryOperator::Mul => { + BinaryOperator::Mul | BinaryOperator::Shl => { mul::generate(&mut row, input0, input1); (row, None) } - BinaryOperator::Div | BinaryOperator::Mod => { + BinaryOperator::Div | BinaryOperator::Mod | BinaryOperator::Shr => { let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result); (row, Some(nv)) diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index 597d4051..efb4d822 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -121,7 +121,7 @@ pub fn eval_packed_generic( ) { let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); - let is_mul = lv[IS_MUL]; + let is_mul = lv[IS_MUL] + lv[IS_SHL]; let input0_limbs = read_value::(lv, INPUT_REGISTER_0); let input1_limbs = read_value::(lv, INPUT_REGISTER_1); let output_limbs = read_value::(lv, OUTPUT_REGISTER); @@ -173,7 +173,7 @@ pub fn eval_ext_circuit, const D: usize>( lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { - let is_mul = lv[IS_MUL]; + let is_mul = builder.add_extension(lv[IS_MUL], lv[IS_SHL]); let input0_limbs = read_value::(lv, INPUT_REGISTER_0); let input1_limbs = read_value::(lv, INPUT_REGISTER_1); let output_limbs = read_value::(lv, OUTPUT_REGISTER); @@ -229,6 +229,8 @@ mod tests { // if `IS_MUL == 0`, then the constraints should be met even // if all values are garbage. lv[IS_MUL] = F::ZERO; + // Deactivate the SHL flag so that a MUL operation is not triggered. + lv[IS_SHL] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 6c68a183..b8a4d8a6 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -7,33 +7,17 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; #[repr(C)] #[derive(Clone, Copy, Eq, PartialEq, Debug)] pub struct OpsColumnsView { - // TODO: combine ADD, MUL, SUB, DIV, MOD, ADDFP254, MULFP254, SUBFP254, LT, and GT into one flag - pub add: T, - pub mul: T, - pub sub: T, - pub div: T, - pub mod_: T, - // TODO: combine ADDMOD, MULMOD and SUBMOD into one flag - pub addmod: T, - pub mulmod: T, - pub addfp254: T, - pub mulfp254: T, - pub subfp254: T, - pub submod: T, - pub lt: T, - pub gt: T, - pub eq_iszero: T, // Combines EQ and ISZERO flags. - pub logic_op: T, // Combines AND, OR and XOR flags. + pub binary_op: T, // Combines ADD, MUL, SUB, DIV, MOD, LT, GT and BYTE flags. + pub ternary_op: T, // Combines ADDMOD, MULMOD and SUBMOD flags. + pub fp254_op: T, // Combines ADD_FP254, MUL_FP254 and SUB_FP254 flags. + pub eq_iszero: T, // Combines EQ and ISZERO flags. + pub logic_op: T, // Combines AND, OR and XOR flags. pub not: T, - pub byte: T, - // TODO: combine SHL and SHR into one flag - pub shl: T, - pub shr: T, + pub shift: T, // Combines SHL and SHR flags. pub keccak_general: T, pub prover_input: T, pub pop: T, - // TODO: combine JUMP and JUMPI into one flag - pub jumps: T, // Note: This column must be 0 when is_cpu_cycle = 0. + pub jumps: T, // Combines JUMP and JUMPI flags. pub pc: T, pub jumpdest: T, pub push0: T, diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index 0bea5c7c..8d0ee264 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -8,24 +8,14 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NATIVE_INSTRUCTIONS: [usize; 28] = [ - COL_MAP.op.add, - COL_MAP.op.mul, - COL_MAP.op.sub, - COL_MAP.op.div, - COL_MAP.op.mod_, - COL_MAP.op.addmod, - COL_MAP.op.mulmod, - COL_MAP.op.addfp254, - COL_MAP.op.mulfp254, - COL_MAP.op.subfp254, - COL_MAP.op.lt, - COL_MAP.op.gt, +const NATIVE_INSTRUCTIONS: [usize; 18] = [ + COL_MAP.op.binary_op, + COL_MAP.op.ternary_op, + COL_MAP.op.fp254_op, COL_MAP.op.eq_iszero, COL_MAP.op.logic_op, COL_MAP.op.not, - COL_MAP.op.shl, - COL_MAP.op.shr, + COL_MAP.op.shift, COL_MAP.op.keccak_general, COL_MAP.op.prover_input, COL_MAP.op.pop, diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 25e7cc6b..820ccd3d 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -48,9 +48,8 @@ pub fn ctl_filter_keccak_sponge() -> Column { /// Create the vector of Columns corresponding to the two inputs and /// one output of a binary operation. -fn ctl_data_binops(ops: &[usize]) -> Vec> { - let mut res = Column::singles(ops).collect_vec(); - res.extend(Column::singles(COL_MAP.mem_channels[0].value)); +fn ctl_data_binops() -> Vec> { + let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[1].value)); res.extend(Column::singles( COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, @@ -70,10 +69,9 @@ fn ctl_data_binops(ops: &[usize]) -> Vec> { /// case of shift operations, which will skip the first memory channel and use the /// next three as ternary inputs. Because both `MUL` and `DIV` are binary operations, /// the last memory channel used for the inputs will be safely ignored. -fn ctl_data_ternops(ops: &[usize], is_shift: bool) -> Vec> { +fn ctl_data_ternops(is_shift: bool) -> Vec> { let offset = is_shift as usize; - let mut res = Column::singles(ops).collect_vec(); - res.extend(Column::singles(COL_MAP.mem_channels[offset].value)); + let mut res = Column::singles(COL_MAP.mem_channels[offset].value).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[offset + 1].value)); res.extend(Column::singles(COL_MAP.mem_channels[offset + 2].value)); res.extend(Column::singles( @@ -85,7 +83,7 @@ fn ctl_data_ternops(ops: &[usize], is_shift: bool) -> Vec> { pub fn ctl_data_logic() -> Vec> { // Instead of taking single columns, we reconstruct the entire opcode value directly. let mut res = vec![Column::le_bits(COL_MAP.opcode_bits)]; - res.extend(ctl_data_binops(&[])); + res.extend(ctl_data_binops()); res } @@ -94,22 +92,9 @@ pub fn ctl_filter_logic() -> Column { } pub fn ctl_arithmetic_base_rows() -> TableWithColumns { - const OPS: [usize; 14] = [ - COL_MAP.op.add, - COL_MAP.op.sub, - COL_MAP.op.mul, - COL_MAP.op.lt, - COL_MAP.op.gt, - COL_MAP.op.addfp254, - COL_MAP.op.mulfp254, - COL_MAP.op.subfp254, - COL_MAP.op.addmod, - COL_MAP.op.mulmod, - COL_MAP.op.submod, - COL_MAP.op.div, - COL_MAP.op.mod_, - COL_MAP.op.byte, - ]; + // Instead of taking single columns, we reconstruct the entire opcode value directly. + let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; + columns.extend(ctl_data_ternops(false)); // Create the CPU Table whose columns are those with the three // inputs and one output of the ternary operations listed in `ops` // (also `ops` is used as the operation filter). The list of @@ -117,40 +102,25 @@ pub fn ctl_arithmetic_base_rows() -> TableWithColumns { // the third input. TableWithColumns::new( Table::Cpu, - ctl_data_ternops(&OPS, false), - Some(Column::sum(OPS)), + columns, + Some(Column::sum([ + COL_MAP.op.binary_op, + COL_MAP.op.fp254_op, + COL_MAP.op.ternary_op, + ])), ) } pub fn ctl_arithmetic_shift_rows() -> TableWithColumns { - const OPS: [usize; 14] = [ - COL_MAP.op.add, - COL_MAP.op.sub, - // SHL is interpreted as MUL on the arithmetic side - COL_MAP.op.shl, - COL_MAP.op.lt, - COL_MAP.op.gt, - COL_MAP.op.addfp254, - COL_MAP.op.mulfp254, - COL_MAP.op.subfp254, - COL_MAP.op.addmod, - COL_MAP.op.mulmod, - COL_MAP.op.submod, - // SHR is interpreted as DIV on the arithmetic side - COL_MAP.op.shr, - COL_MAP.op.mod_, - COL_MAP.op.byte, - ]; + // Instead of taking single columns, we reconstruct the entire opcode value directly. + let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; + columns.extend(ctl_data_ternops(true)); // Create the CPU Table whose columns are those with the three // inputs and one output of the ternary operations listed in `ops` // (also `ops` is used as the operation filter). The list of // operations includes binary operations which will simply ignore // the third input. - TableWithColumns::new( - Table::Cpu, - ctl_data_ternops(&OPS, true), - Some(Column::sum([COL_MAP.op.shl, COL_MAP.op.shr])), - ) + TableWithColumns::new(Table::Cpu, columns, Some(Column::single(COL_MAP.op.shift))) } pub fn ctl_data_byte_packing() -> Vec> { diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index 9a9c5723..cc87281c 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -22,26 +22,15 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP}; /// behavior. /// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to /// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid. -const OPCODES: [(u8, usize, bool, usize); 33] = [ +const OPCODES: [(u8, usize, bool, usize); 18] = [ // (start index of block, number of top bits to check (log2), kernel-only, flag column) - (0x01, 0, false, COL_MAP.op.add), - (0x02, 0, false, COL_MAP.op.mul), - (0x03, 0, false, COL_MAP.op.sub), - (0x04, 0, false, COL_MAP.op.div), - (0x06, 0, false, COL_MAP.op.mod_), - (0x08, 0, false, COL_MAP.op.addmod), - (0x09, 0, false, COL_MAP.op.mulmod), - (0x0c, 0, true, COL_MAP.op.addfp254), - (0x0d, 0, true, COL_MAP.op.mulfp254), - (0x0e, 0, true, COL_MAP.op.subfp254), - (0x10, 0, false, COL_MAP.op.lt), - (0x11, 0, false, COL_MAP.op.gt), + // ADD, MUL, SUB, DIV, MOD, LT, GT and BYTE flags are handled partly manually here, and partly through the Arithmetic table CTL. + // ADDMOD, MULMOD and SUBMOD flags are handled partly manually here, and partly through the Arithmetic table CTL. + // FP254 operation flags are handled partly manually here, and partly through the Arithmetic table CTL. (0x14, 1, false, COL_MAP.op.eq_iszero), // AND, OR and XOR flags are handled partly manually here, and partly through the Logic table CTL. (0x19, 0, false, COL_MAP.op.not), - (0x1a, 0, false, COL_MAP.op.byte), - (0x1b, 0, false, COL_MAP.op.shl), - (0x1c, 0, false, COL_MAP.op.shr), + // SHL and SHR flags are handled partly manually here, and partly through the Logic table CTL. (0x21, 0, true, COL_MAP.op.keccak_general), (0x49, 0, true, COL_MAP.op.prover_input), (0x50, 0, false, COL_MAP.op.pop), @@ -60,6 +49,17 @@ const OPCODES: [(u8, usize, bool, usize); 33] = [ (0xfc, 0, true, COL_MAP.op.mstore_general), ]; +/// List of combined opcodes requiring a special handling. +/// Each index in the list corresponds to an arbitrary combination +/// of opcodes defined in evm/src/cpu/columns/ops.rs. +const COMBINED_OPCODES: [usize; 5] = [ + COL_MAP.op.logic_op, + COL_MAP.op.fp254_op, + COL_MAP.op.binary_op, + COL_MAP.op.ternary_op, + COL_MAP.op.shift, +]; + pub fn generate(lv: &mut CpuColumnsView) { let cycle_filter: F = COL_MAP.op.iter().map(|&col_i| lv[col_i]).sum(); @@ -134,17 +134,17 @@ pub fn eval_packed_generic( let flag = lv[flag_col]; yield_constr.constraint(flag * (flag - P::ONES)); } - // Manually check the logic_op flag combining AND, OR and XOR. - let flag = lv.op.logic_op; - yield_constr.constraint(flag * (flag - P::ONES)); + // Also check that the combined instruction flags are valid. + for flag_idx in COMBINED_OPCODES { + yield_constr.constraint(lv[flag_idx] * (lv[flag_idx] - P::ONES)); + } - // Now check that they sum to 0 or 1. - // Includes the logic_op flag encompassing AND, OR and XOR opcodes. + // Now check that they sum to 0 or 1, including the combined flags. let flag_sum: P = OPCODES .into_iter() .map(|(_, _, _, flag_col)| lv[flag_col]) - .sum::

() - + lv.op.logic_op; + .chain(COMBINED_OPCODES.map(|op| lv[op])) + .sum::

(); yield_constr.constraint(flag_sum * (flag_sum - P::ONES)); // Finally, classify all opcodes, together with the kernel flag, into blocks @@ -204,15 +204,16 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_sub_extension(flag, flag, flag); yield_constr.constraint(builder, constr); } - // Manually check the logic_op flag combining AND, OR and XOR. - let flag = lv.op.logic_op; - let constr = builder.mul_sub_extension(flag, flag, flag); - yield_constr.constraint(builder, constr); + // Also check that the combined instruction flags are valid. + for flag_idx in COMBINED_OPCODES { + let constr = builder.mul_sub_extension(lv[flag_idx], lv[flag_idx], lv[flag_idx]); + yield_constr.constraint(builder, constr); + } - // Now check that they sum to 0 or 1. - // Includes the logic_op flag encompassing AND, OR and XOR opcodes. + // Now check that they sum to 0 or 1, including the combined flags. { - let mut flag_sum = lv.op.logic_op; + let mut flag_sum = + builder.add_many_extension(COMBINED_OPCODES.into_iter().map(|idx| lv[idx])); for (_, _, _, flag_col) in OPCODES { let flag = lv[flag_col]; flag_sum = builder.add_extension(flag_sum, flag); diff --git a/evm/src/cpu/gas.rs b/evm/src/cpu/gas.rs index e967c07e..a4a499ad 100644 --- a/evm/src/cpu/gas.rs +++ b/evm/src/cpu/gas.rs @@ -19,25 +19,13 @@ const G_MID: Option = Some(8); const G_HIGH: Option = Some(10); const SIMPLE_OPCODES: OpsColumnsView> = OpsColumnsView { - add: G_VERYLOW, - mul: G_LOW, - sub: G_VERYLOW, - div: G_LOW, - mod_: G_LOW, - addmod: G_MID, - mulmod: G_MID, - addfp254: KERNEL_ONLY_INSTR, - mulfp254: KERNEL_ONLY_INSTR, - subfp254: KERNEL_ONLY_INSTR, - submod: KERNEL_ONLY_INSTR, - lt: G_VERYLOW, - gt: G_VERYLOW, + binary_op: None, // This is handled manually below + ternary_op: None, // This is handled manually below + fp254_op: KERNEL_ONLY_INSTR, eq_iszero: G_VERYLOW, logic_op: G_VERYLOW, not: G_VERYLOW, - byte: G_VERYLOW, - shl: G_VERYLOW, - shr: G_VERYLOW, + shift: G_VERYLOW, keccak_general: KERNEL_ONLY_INSTR, prover_input: KERNEL_ONLY_INSTR, pop: G_BASE, @@ -97,6 +85,21 @@ fn eval_packed_accumulate( let jump_gas_cost = P::Scalar::from_canonical_u32(G_MID.unwrap()) + lv.opcode_bits[0] * P::Scalar::from_canonical_u32(G_HIGH.unwrap() - G_MID.unwrap()); yield_constr.constraint_transition(lv.op.jumps * (nv.gas - lv.gas - jump_gas_cost)); + + // For binary_ops. + // MUL, DIV and MOD are differentiated from ADD, SUB, LT, GT and BYTE by their first and fifth bits set to 0. + let cost_filter = lv.opcode_bits[0] + lv.opcode_bits[4] - lv.opcode_bits[0] * lv.opcode_bits[4]; + let binary_op_cost = P::Scalar::from_canonical_u32(G_LOW.unwrap()) + + cost_filter + * (P::Scalar::from_canonical_u32(G_VERYLOW.unwrap()) + - P::Scalar::from_canonical_u32(G_LOW.unwrap())); + yield_constr.constraint_transition(lv.op.binary_op * (nv.gas - lv.gas - binary_op_cost)); + + // For ternary_ops. + // SUBMOD is differentiated by its second bit set to 1. + let ternary_op_cost = P::Scalar::from_canonical_u32(G_MID.unwrap()) + - lv.opcode_bits[1] * P::Scalar::from_canonical_u32(G_MID.unwrap()); + yield_constr.constraint_transition(lv.op.ternary_op * (nv.gas - lv.gas - ternary_op_cost)); } fn eval_packed_init( @@ -186,6 +189,41 @@ fn eval_ext_circuit_accumulate, const D: usize>( let gas_diff = builder.sub_extension(nv_lv_diff, jump_gas_cost); let constr = builder.mul_extension(filter, gas_diff); yield_constr.constraint_transition(builder, constr); + + // For binary_ops. + // MUL, DIV and MOD are differentiated from ADD, SUB, LT, GT and BYTE by their first and fifth bits set to 0. + let filter = lv.op.binary_op; + let cost_filter = { + let a = builder.add_extension(lv.opcode_bits[0], lv.opcode_bits[4]); + let b = builder.mul_extension(lv.opcode_bits[0], lv.opcode_bits[4]); + builder.sub_extension(a, b) + }; + let binary_op_cost = builder.mul_const_extension( + F::from_canonical_u32(G_VERYLOW.unwrap()) - F::from_canonical_u32(G_LOW.unwrap()), + cost_filter, + ); + let binary_op_cost = + builder.add_const_extension(binary_op_cost, F::from_canonical_u32(G_LOW.unwrap())); + + let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas); + let gas_diff = builder.sub_extension(nv_lv_diff, binary_op_cost); + let constr = builder.mul_extension(filter, gas_diff); + yield_constr.constraint_transition(builder, constr); + + // For ternary_ops. + // SUBMOD is differentiated by its second bit set to 1. + let filter = lv.op.ternary_op; + let ternary_op_cost = builder.mul_const_extension( + F::from_canonical_u32(G_MID.unwrap()).neg(), + lv.opcode_bits[1], + ); + let ternary_op_cost = + builder.add_const_extension(ternary_op_cost, F::from_canonical_u32(G_MID.unwrap())); + + let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas); + let gas_diff = builder.sub_extension(nv_lv_diff, ternary_op_cost); + let constr = builder.mul_extension(filter, gas_diff); + yield_constr.constraint_transition(builder, constr); } fn eval_ext_circuit_init, const D: usize>( diff --git a/evm/src/cpu/modfp254.rs b/evm/src/cpu/modfp254.rs index e6a2815d..86f08052 100644 --- a/evm/src/cpu/modfp254.rs +++ b/evm/src/cpu/modfp254.rs @@ -19,7 +19,7 @@ pub fn eval_packed( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let filter = lv.op.addfp254 + lv.op.mulfp254 + lv.op.subfp254; + let filter = lv.op.fp254_op; // We want to use all the same logic as the usual mod operations, but without needing to read // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's @@ -36,7 +36,7 @@ pub fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = builder.add_many_extension([lv.op.addfp254, lv.op.mulfp254, lv.op.subfp254]); + let filter = lv.op.fp254_op; // We want to use all the same logic as the usual mod operations, but without needing to read // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's diff --git a/evm/src/cpu/shift.rs b/evm/src/cpu/shift.rs index a8acf5d4..a4249297 100644 --- a/evm/src/cpu/shift.rs +++ b/evm/src/cpu/shift.rs @@ -13,7 +13,7 @@ pub(crate) fn eval_packed( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let is_shift = lv.op.shl + lv.op.shr; + let is_shift = lv.op.shift; let displacement = lv.mem_channels[0]; // holds the shift displacement d let two_exp = lv.mem_channels[2]; // holds 2^d @@ -64,7 +64,7 @@ pub(crate) fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let is_shift = builder.add_extension(lv.op.shl, lv.op.shr); + let is_shift = lv.op.shift; let displacement = lv.mem_channels[0]; let two_exp = lv.mem_channels[2]; diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index cfeaa1b0..a0c8df5c 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -50,29 +50,13 @@ pub(crate) const JUMPI_OP: Option = Some(StackBehavior { // except the first `num_pops` and the last `pushes as usize` channels have their read flag and // address constrained automatically in this file. const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { - add: BASIC_BINARY_OP, - mul: BASIC_BINARY_OP, - sub: BASIC_BINARY_OP, - div: BASIC_BINARY_OP, - mod_: BASIC_BINARY_OP, - addmod: BASIC_TERNARY_OP, - mulmod: BASIC_TERNARY_OP, - addfp254: BASIC_BINARY_OP, - mulfp254: BASIC_BINARY_OP, - subfp254: BASIC_BINARY_OP, - submod: BASIC_TERNARY_OP, - lt: BASIC_BINARY_OP, - gt: BASIC_BINARY_OP, + binary_op: BASIC_BINARY_OP, + ternary_op: BASIC_TERNARY_OP, + fp254_op: BASIC_BINARY_OP, eq_iszero: None, // EQ is binary, IS_ZERO is unary. logic_op: BASIC_BINARY_OP, not: BASIC_UNARY_OP, - byte: BASIC_BINARY_OP, - shl: Some(StackBehavior { - num_pops: 2, - pushes: true, - disable_other_channels: false, - }), - shr: Some(StackBehavior { + shift: Some(StackBehavior { num_pops: 2, pushes: true, disable_other_channels: false, diff --git a/evm/src/witness/gas.rs b/evm/src/witness/gas.rs index 3a46c044..aa312078 100644 --- a/evm/src/witness/gas.rs +++ b/evm/src/witness/gas.rs @@ -25,8 +25,8 @@ pub(crate) fn gas_to_charge(op: Operation) -> u64 { BinaryArithmetic(Lt) => G_VERYLOW, BinaryArithmetic(Gt) => G_VERYLOW, BinaryArithmetic(Byte) => G_VERYLOW, - Shl => G_VERYLOW, - Shr => G_VERYLOW, + BinaryArithmetic(Shl) => G_VERYLOW, + BinaryArithmetic(Shr) => G_VERYLOW, BinaryArithmetic(AddFp254) => KERNEL_ONLY_INSTR, BinaryArithmetic(MulFp254) => KERNEL_ONLY_INSTR, BinaryArithmetic(SubFp254) => KERNEL_ONLY_INSTR, diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 7d07576d..23d64be4 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -29,8 +29,6 @@ use crate::{arithmetic, logic}; pub(crate) enum Operation { Iszero, Not, - Shl, - Shr, Syscall(u8, usize, bool), // (syscall number, minimum stack length, increases stack length) Eq, BinaryLogic(logic::Op), @@ -473,6 +471,7 @@ pub(crate) fn generate_iszero( fn append_shift( state: &mut GenerationState, mut row: CpuColumnsView, + is_shl: bool, input0: U256, input1: U256, log_in0: MemoryOp, @@ -500,10 +499,10 @@ fn append_shift( } else { U256::one() << input0 }; - let operator = if row.op.shl.is_one() { - BinaryOperator::Mul + let operator = if is_shl { + BinaryOperator::Shl } else { - BinaryOperator::Div + BinaryOperator::Shr }; let operation = arithmetic::Operation::binary(operator, input1, input0); @@ -527,7 +526,7 @@ pub(crate) fn generate_shl( } else { input1 << input0 }; - append_shift(state, row, input0, input1, log_in0, log_in1, result) + append_shift(state, row, true, input0, input1, log_in0, log_in1, result) } pub(crate) fn generate_shr( @@ -542,7 +541,7 @@ pub(crate) fn generate_shr( } else { input1 >> input0 }; - append_shift(state, row, input0, input1, log_in0, log_in1, result) + append_shift(state, row, false, input0, input1, log_in0, log_in1, result) } pub(crate) fn generate_syscall( diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 6e279cdf..9532aa33 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -70,8 +70,8 @@ fn decode(registers: RegistersState, opcode: u8) -> Result Ok(Operation::BinaryArithmetic( arithmetic::BinaryOperator::Byte, )), - (0x1b, _) => Ok(Operation::Shl), - (0x1c, _) => Ok(Operation::Shr), + (0x1b, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl)), + (0x1c, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr)), (0x1d, _) => Ok(Operation::Syscall(opcode, 2, false)), // SAR (0x20, _) => Ok(Operation::Syscall(opcode, 2, false)), // KECCAK256 (0x21, true) => Ok(Operation::KeccakGeneral), @@ -162,22 +162,13 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { Operation::Not => &mut flags.not, Operation::Syscall(_, _, _) => &mut flags.syscall, Operation::BinaryLogic(_) => &mut flags.logic_op, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Add) => &mut flags.add, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mul) => &mut flags.mul, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Sub) => &mut flags.sub, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Div) => &mut flags.div, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mod) => &mut flags.mod_, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Lt) => &mut flags.lt, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Gt) => &mut flags.gt, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Byte) => &mut flags.byte, - Operation::Shl => &mut flags.shl, - Operation::Shr => &mut flags.shr, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) => &mut flags.addfp254, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) => &mut flags.mulfp254, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => &mut flags.subfp254, - Operation::TernaryArithmetic(arithmetic::TernaryOperator::AddMod) => &mut flags.addmod, - Operation::TernaryArithmetic(arithmetic::TernaryOperator::MulMod) => &mut flags.mulmod, - Operation::TernaryArithmetic(arithmetic::TernaryOperator::SubMod) => &mut flags.submod, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => &mut flags.fp254_op, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => &mut flags.shift, + Operation::BinaryArithmetic(_) => &mut flags.binary_op, + Operation::TernaryArithmetic(_) => &mut flags.ternary_op, Operation::KeccakGeneral => &mut flags.keccak_general, Operation::ProverInput => &mut flags.prover_input, Operation::Pop => &mut flags.pop, @@ -204,8 +195,8 @@ fn perform_op( Operation::Swap(n) => generate_swap(n, state, row)?, Operation::Iszero => generate_iszero(state, row)?, Operation::Not => generate_not(state, row)?, - Operation::Shl => generate_shl(state, row)?, - Operation::Shr => generate_shr(state, row)?, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => generate_shl(state, row)?, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => generate_shr(state, row)?, Operation::Syscall(opcode, stack_values_read, stack_len_increased) => { generate_syscall(opcode, stack_values_read, stack_len_increased, state, row)? } From 19220b21d71b828c3d952e11f2c7716244e0ec43 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:27:38 -0400 Subject: [PATCH 3/3] Remove redundant Keccak sponge cols (#1233) * Rename columns in KeccakSponge for clarity * Remove redundant columns * Apply comments --- evm/src/keccak_sponge/columns.rs | 13 +- evm/src/keccak_sponge/keccak_sponge_stark.rs | 160 ++++++++++++------- 2 files changed, 113 insertions(+), 60 deletions(-) diff --git a/evm/src/keccak_sponge/columns.rs b/evm/src/keccak_sponge/columns.rs index 44f66a5d..431c09e0 100644 --- a/evm/src/keccak_sponge/columns.rs +++ b/evm/src/keccak_sponge/columns.rs @@ -5,11 +5,14 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; pub(crate) const KECCAK_WIDTH_BYTES: usize = 200; pub(crate) const KECCAK_WIDTH_U32S: usize = KECCAK_WIDTH_BYTES / 4; +pub(crate) const KECCAK_WIDTH_MINUS_DIGEST_U32S: usize = + (KECCAK_WIDTH_BYTES - KECCAK_DIGEST_BYTES) / 4; pub(crate) const KECCAK_RATE_BYTES: usize = 136; pub(crate) const KECCAK_RATE_U32S: usize = KECCAK_RATE_BYTES / 4; pub(crate) const KECCAK_CAPACITY_BYTES: usize = 64; pub(crate) const KECCAK_CAPACITY_U32S: usize = KECCAK_CAPACITY_BYTES / 4; pub(crate) const KECCAK_DIGEST_BYTES: usize = 32; +pub(crate) const KECCAK_DIGEST_U32S: usize = KECCAK_DIGEST_BYTES / 4; #[repr(C)] #[derive(Eq, PartialEq, Debug)] @@ -52,10 +55,14 @@ pub(crate) struct KeccakSpongeColumnsView { pub xored_rate_u32s: [T; KECCAK_RATE_U32S], /// The entire state (rate + capacity) of the sponge, encoded as 32-bit chunks, after the - /// permutation is applied. - pub updated_state_u32s: [T; KECCAK_WIDTH_U32S], + /// permutation is applied, minus the first limbs where the digest is extracted from. + /// Those missing limbs can be recomputed from their corresponding bytes stored in + /// `updated_digest_state_bytes`. + pub partial_updated_state_u32s: [T; KECCAK_WIDTH_MINUS_DIGEST_U32S], - pub updated_state_bytes: [T; KECCAK_DIGEST_BYTES], + /// The first part of the state of the sponge, seen as bytes, after the permutation is applied. + /// This also represents the output digest of the Keccak sponge during the squeezing phase. + pub updated_digest_state_bytes: [T; KECCAK_DIGEST_BYTES], } // `u8` is guaranteed to have a `size_of` of 1. diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index 5f1a49cc..d78e9651 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -28,7 +28,7 @@ pub(crate) fn ctl_looked_data() -> Vec> { let mut outputs = Vec::with_capacity(8); for i in (0..8).rev() { let cur_col = Column::linear_combination( - cols.updated_state_bytes[i * 4..(i + 1) * 4] + cols.updated_digest_state_bytes[i * 4..(i + 1) * 4] .iter() .enumerate() .map(|(j, &c)| (c, F::from_canonical_u64(1 << (24 - 8 * j)))), @@ -49,15 +49,30 @@ pub(crate) fn ctl_looked_data() -> Vec> { pub(crate) fn ctl_looking_keccak() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; - Column::singles( + let mut res: Vec<_> = Column::singles( [ cols.xored_rate_u32s.as_slice(), &cols.original_capacity_u32s, - &cols.updated_state_u32s, ] .concat(), ) - .collect() + .collect(); + + // We recover the 32-bit digest limbs from their corresponding bytes, + // and then append them to the rest of the updated state limbs. + let digest_u32s = cols.updated_digest_state_bytes.chunks_exact(4).map(|c| { + Column::linear_combination( + c.iter() + .enumerate() + .map(|(i, &b)| (b, F::from_canonical_usize(1 << (8 * i)))), + ) + }); + + res.extend(digest_u32s); + + res.extend(Column::singles(&cols.partial_updated_state_u32s)); + + res } pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { @@ -239,7 +254,21 @@ impl, const D: usize> KeccakSpongeStark { block.try_into().unwrap(), ); - sponge_state = row.updated_state_u32s.map(|f| f.to_canonical_u64() as u32); + sponge_state[..KECCAK_DIGEST_U32S] + .iter_mut() + .zip(row.updated_digest_state_bytes.chunks_exact(4)) + .for_each(|(s, bs)| { + *s = bs + .iter() + .enumerate() + .map(|(i, b)| (b.to_canonical_u64() as u32) << (8 * i)) + .sum(); + }); + + sponge_state[KECCAK_DIGEST_U32S..] + .iter_mut() + .zip(row.partial_updated_state_u32s) + .for_each(|(s, x)| *s = x.to_canonical_u64() as u32); rows.push(row.into()); already_absorbed_bytes += KECCAK_RATE_BYTES; @@ -357,24 +386,33 @@ impl, const D: usize> KeccakSpongeStark { row.xored_rate_u32s = xored_rate_u32s.map(F::from_canonical_u32); keccakf_u32s(&mut sponge_state); - row.updated_state_u32s = sponge_state.map(F::from_canonical_u32); - let is_final_block = row.is_final_input_len.iter().copied().sum::() == F::ONE; - if is_final_block { - for (l, &elt) in row.updated_state_u32s[..8].iter().enumerate() { + // Store all but the first `KECCAK_DIGEST_U32S` limbs in the updated state. + // Those missing limbs will be broken down into bytes and stored separately. + row.partial_updated_state_u32s.copy_from_slice( + &sponge_state[KECCAK_DIGEST_U32S..] + .iter() + .copied() + .map(|i| F::from_canonical_u32(i)) + .collect::>(), + ); + sponge_state[..KECCAK_DIGEST_U32S] + .iter() + .enumerate() + .for_each(|(l, &elt)| { let mut cur_elt = elt; (0..4).for_each(|i| { - row.updated_state_bytes[l * 4 + i] = - F::from_canonical_u32((cur_elt.to_canonical_u64() & 0xFF) as u32); - cur_elt = F::from_canonical_u64(cur_elt.to_canonical_u64() >> 8); + row.updated_digest_state_bytes[l * 4 + i] = + F::from_canonical_u32(cur_elt & 0xFF); + cur_elt >>= 8; }); - let mut s = row.updated_state_bytes[l * 4].to_canonical_u64(); + // 32-bit limb reconstruction consistency check. + let mut s = row.updated_digest_state_bytes[l * 4].to_canonical_u64(); for i in 1..4 { - s += row.updated_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i); + s += row.updated_digest_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i); } - assert_eq!(elt, F::from_canonical_u64(s), "not equal"); - } - } + assert_eq!(elt as u64, s, "not equal"); + }) } fn generate_padding_row(&self) -> [F; NUM_KECCAK_SPONGE_COLUMNS] { @@ -445,26 +483,39 @@ impl, const D: usize> Stark for KeccakSpongeS ); // If this is a full-input block, the next row's "before" should match our "after" state. + for (current_bytes_after, next_before) in local_values + .updated_digest_state_bytes + .chunks_exact(4) + .zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S]) + { + let mut current_after = current_bytes_after[0]; + for i in 1..4 { + current_after += + current_bytes_after[i] * P::from(FE::from_canonical_usize(1 << (8 * i))); + } + yield_constr + .constraint_transition(is_full_input_block * (*next_before - current_after)); + } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .zip(next_values.original_rate_u32s.iter()) + .zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter()) { yield_constr.constraint_transition(is_full_input_block * (next_before - current_after)); } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .skip(KECCAK_RATE_U32S) + .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S) .zip(next_values.original_capacity_u32s.iter()) { yield_constr.constraint_transition(is_full_input_block * (next_before - current_after)); } - // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. + // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`. yield_constr.constraint_transition( is_full_input_block - * (already_absorbed_bytes + P::from(FE::from_canonical_u64(136)) + * (already_absorbed_bytes + P::from(FE::from_canonical_usize(KECCAK_RATE_BYTES)) - next_values.already_absorbed_bytes), ); @@ -481,16 +532,6 @@ impl, const D: usize> Stark for KeccakSpongeS let entry_match = offset - P::from(FE::from_canonical_usize(i)); yield_constr.constraint(is_final_len * entry_match); } - - // Adding constraints for byte columns. - for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() { - let mut s = local_values.updated_state_bytes[l * 4]; - for i in 1..4 { - s += local_values.updated_state_bytes[l * 4 + i] - * P::from(FE::from_canonical_usize(1 << (8 * i))); - } - yield_constr.constraint(is_final_block * (s - elt)); - } } fn eval_ext_circuit( @@ -566,19 +607,36 @@ impl, const D: usize> Stark for KeccakSpongeS yield_constr.constraint_transition(builder, constraint); // If this is a full-input block, the next row's "before" should match our "after" state. + for (current_bytes_after, next_before) in local_values + .updated_digest_state_bytes + .chunks_exact(4) + .zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S]) + { + let mut current_after = current_bytes_after[0]; + for i in 1..4 { + current_after = builder.mul_const_add_extension( + F::from_canonical_usize(1 << (8 * i)), + current_bytes_after[i], + current_after, + ); + } + let diff = builder.sub_extension(*next_before, current_after); + let constraint = builder.mul_extension(is_full_input_block, diff); + yield_constr.constraint_transition(builder, constraint); + } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .zip(next_values.original_rate_u32s.iter()) + .zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter()) { let diff = builder.sub_extension(next_before, current_after); let constraint = builder.mul_extension(is_full_input_block, diff); yield_constr.constraint_transition(builder, constraint); } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .skip(KECCAK_RATE_U32S) + .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S) .zip(next_values.original_capacity_u32s.iter()) { let diff = builder.sub_extension(next_before, current_after); @@ -586,9 +644,11 @@ impl, const D: usize> Stark for KeccakSpongeS yield_constr.constraint_transition(builder, constraint); } - // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. - let absorbed_bytes = - builder.add_const_extension(already_absorbed_bytes, F::from_canonical_u64(136)); + // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`. + let absorbed_bytes = builder.add_const_extension( + already_absorbed_bytes, + F::from_canonical_usize(KECCAK_RATE_BYTES), + ); let absorbed_diff = builder.sub_extension(absorbed_bytes, next_values.already_absorbed_bytes); let constraint = builder.mul_extension(is_full_input_block, absorbed_diff); @@ -615,21 +675,6 @@ impl, const D: usize> Stark for KeccakSpongeS let constraint = builder.mul_extension(is_final_len, entry_match); yield_constr.constraint(builder, constraint); } - - // Adding constraints for byte columns. - for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() { - let mut s = local_values.updated_state_bytes[l * 4]; - for i in 1..4 { - s = builder.mul_const_add_extension( - F::from_canonical_usize(1 << (8 * i)), - local_values.updated_state_bytes[l * 4 + i], - s, - ); - } - let constraint = builder.sub_extension(s, elt); - let constraint = builder.mul_extension(is_final_block, constraint); - yield_constr.constraint(builder, constraint); - } } fn constraint_degree(&self) -> usize { @@ -698,9 +743,10 @@ mod tests { let rows = stark.generate_rows_for_op(op); assert_eq!(rows.len(), 1); let last_row: &KeccakSpongeColumnsView = rows.last().unwrap().borrow(); - let output = last_row.updated_state_u32s[..8] + let output = last_row + .updated_digest_state_bytes .iter() - .flat_map(|x| (x.to_canonical_u64() as u32).to_le_bytes()) + .map(|x| x.to_canonical_u64() as u8) .collect_vec(); assert_eq!(output, expected_output.0);