diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index db7436ba..affd676d 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -107,10 +107,13 @@ pub(crate) struct CpuArithmeticView { #[derive(Copy, Clone)] pub(crate) struct CpuLogicView { - // Assuming a limb size of 16 bits. This can be changed, but it must be <= 28 bits. - pub(crate) input0: [T; 16], - pub(crate) input1: [T; 16], - pub(crate) output: [T; 16], + // Assuming a limb size of 32 bits. + pub(crate) input0: [T; 8], + pub(crate) input1: [T; 8], + pub(crate) output: [T; 8], + + // Pseudoinverse of `(input0 - input1)`. Used prove that they are unequal. + pub(crate) diff_pinv: [T; 8], } #[derive(Copy, Clone)] diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index 824ae13d..3016b2fd 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -158,9 +158,6 @@ pub struct CpuColumnsView { pub(crate) general: CpuGeneralColumnsView, - pub simple_logic_diff: T, - pub simple_logic_diff_inv: T, - pub(crate) clock: T, /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise /// 0. diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index 75bb8bb6..e1b33dc9 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -1,3 +1,4 @@ +use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::hash::hash_types::RichField; @@ -6,8 +7,6 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -const LIMB_SIZE: usize = 16; - pub fn generate(lv: &mut CpuColumnsView) { let logic = lv.general.logic_mut(); let eq_filter = lv.is_eq.to_canonical_u64(); @@ -16,34 +15,36 @@ pub fn generate(lv: &mut CpuColumnsView) { assert!(iszero_filter <= 1); assert!(eq_filter + iszero_filter <= 1); - if eq_filter != 1 && iszero_filter != 1 { + if eq_filter + iszero_filter == 0 { return; } - let diffs = if eq_filter == 1 { - logic - .input0 - .into_iter() - .zip(logic.input1) - .map(|(in0, in1)| { - assert_eq!(in0.to_canonical_u64() >> LIMB_SIZE, 0); - assert_eq!(in1.to_canonical_u64() >> LIMB_SIZE, 0); - let diff = in0 - in1; - diff.square() - }) - .sum() - } else if iszero_filter == 1 { - logic.input0.into_iter().sum() - } else { - panic!() - }; + if iszero_filter != 0 { + for limb in logic.input1.iter_mut() { + *limb = F::ZERO; + } + } - lv.simple_logic_diff = diffs; - lv.simple_logic_diff_inv = diffs.try_inverse().unwrap_or(F::ZERO); + let num_unequal_limbs = izip!(logic.input0, logic.input1) + .map(|(limb0, limb1)| (limb0 != limb1) as usize) + .sum(); + let equal = num_unequal_limbs == 0; - logic.output[0] = F::from_bool(diffs == F::ZERO); - for out_limb_ref in logic.output[1..].iter_mut() { - *out_limb_ref = F::ZERO; + logic.output[0] = F::from_bool(equal); + for limb in &mut logic.output[1..] { + *limb = F::ZERO; + } + + // Form `diff_pinv`. + // Let `diff = input0 - input1`. Consider `x[i] = diff[i]^-1` if `diff[i] != 0` and 0 otherwise. + // Then `diff @ x = num_unequal_limbs`, where `@` denotes the dot product. We set + // `diff_pinv = num_unequal_limbs^-1 * x` if `num_unequal_limbs != 0` and 0 otherwise. We have + // `diff @ diff_pinv = 1 - equal` as desired. + let num_unequal_limbs_inv = F::from_canonical_usize(num_unequal_limbs) + .try_inverse() + .unwrap_or(F::ZERO); + for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), logic.input0, logic.input1) { + *limb_pinv = (limb0 - limb1).try_inverse().unwrap_or(F::ZERO) * num_unequal_limbs_inv; } } @@ -56,36 +57,35 @@ pub fn eval_packed( let iszero_filter = lv.is_iszero; let eq_or_iszero_filter = eq_filter + iszero_filter; - let ls_bit = logic.output[0]; + let equal = logic.output[0]; + let unequal = P::ONES - equal; - // Handle EQ and ISZERO. Most limbs of the output are 0, but the least-significant one is + // Handle `EQ` and `ISZERO`. Most limbs of the output are 0, but the least-significant one is // either 0 or 1. - yield_constr.constraint(eq_or_iszero_filter * ls_bit * (ls_bit - P::ONES)); - - for &bit in &logic.output[1..] { - yield_constr.constraint(eq_or_iszero_filter * bit); + yield_constr.constraint(eq_or_iszero_filter * equal * unequal); + for &limb in &logic.output[1..] { + yield_constr.constraint(eq_or_iszero_filter * limb); } - // Check SIMPLE_LOGIC_DIFF - let diffs = lv.simple_logic_diff; - let diffs_inv = lv.simple_logic_diff_inv; - { - let input0_sum: P = logic.input0.into_iter().sum(); - yield_constr.constraint(iszero_filter * (diffs - input0_sum)); - - let sum_squared_diffs: P = logic - .input0 - .into_iter() - .zip(logic.input1) - .map(|(in0, in1)| (in0 - in1).square()) - .sum(); - yield_constr.constraint(eq_filter * (diffs - sum_squared_diffs)); + // If `ISZERO`, constrain input1 to be zero, effectively implementing ISZERO(x) as EQ(x, 0). + for limb in logic.input1 { + yield_constr.constraint(iszero_filter * limb); } - // diffs != 0 => ls_bit == 0 - yield_constr.constraint(eq_or_iszero_filter * diffs * ls_bit); - // ls_bit == 0 => diffs != 0 (we provide a diffs_inv) - yield_constr.constraint(eq_or_iszero_filter * (diffs * diffs_inv + ls_bit - P::ONES)); + // `equal` implies `input0[i] == input1[i]` for all `i`. + for (limb0, limb1) in izip!(logic.input0, logic.input1) { + let diff = limb0 - limb1; + yield_constr.constraint(eq_or_iszero_filter * equal * diff); + } + + // `input0[i] == input1[i]` for all `i` implies `equal`. + // If `unequal`, find `diff_pinv` such that `(input0 - input1) @ diff_pinv == 1`, where `@` + // denotes the dot product (there will be many such `diff_pinv`). This can only be done if + // `input0 != input1`. + let dot: P = izip!(logic.input0, logic.input1, logic.diff_pinv) + .map(|(limb0, limb1, diff_pinv_el)| (limb0 - limb1) * diff_pinv_el) + .sum(); + yield_constr.constraint(eq_or_iszero_filter * (dot - unequal)); } pub fn eval_ext_circuit, const D: usize>( @@ -93,61 +93,57 @@ pub fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { + let zero = builder.zero_extension(); + let one = builder.one_extension(); + let logic = lv.general.logic(); let eq_filter = lv.is_eq; let iszero_filter = lv.is_iszero; let eq_or_iszero_filter = builder.add_extension(eq_filter, iszero_filter); - let ls_bit = logic.output[0]; + let equal = logic.output[0]; + let unequal = builder.sub_extension(one, equal); - // Handle EQ and ISZERO. Most limbs of the output are 0, but the least-significant one is + // Handle `EQ` and `ISZERO`. Most limbs of the output are 0, but the least-significant one is // either 0 or 1. { - let constr = builder.mul_sub_extension(ls_bit, ls_bit, ls_bit); + let constr = builder.mul_extension(equal, unequal); + let constr = builder.mul_extension(eq_or_iszero_filter, constr); + yield_constr.constraint(builder, constr); + } + for &limb in &logic.output[1..] { + let constr = builder.mul_extension(eq_or_iszero_filter, limb); + yield_constr.constraint(builder, constr); + } + + // If `ISZERO`, constrain input1 to be zero, effectively implementing ISZERO(x) as EQ(x, 0). + for limb in logic.input1 { + let constr = builder.mul_extension(iszero_filter, limb); + yield_constr.constraint(builder, constr); + } + + // `equal` implies `input0[i] == input1[i]` for all `i`. + for (limb0, limb1) in izip!(logic.input0, logic.input1) { + let diff = builder.sub_extension(limb0, limb1); + let constr = builder.mul_extension(equal, diff); let constr = builder.mul_extension(eq_or_iszero_filter, constr); yield_constr.constraint(builder, constr); } - for &bit in &logic.output[1..] { - let constr = builder.mul_extension(eq_or_iszero_filter, bit); - yield_constr.constraint(builder, constr); - } - - // Check SIMPLE_LOGIC_DIFF - let diffs = lv.simple_logic_diff; - let diffs_inv = lv.simple_logic_diff_inv; + // `input0[i] == input1[i]` for all `i` implies `equal`. + // If `unequal`, find `diff_pinv` such that `(input0 - input1) @ diff_pinv == 1`, where `@` + // denotes the dot product (there will be many such `diff_pinv`). This can only be done if + // `input0 != input1`. { - let input0_sum = builder.add_many_extension(logic.input0); - { - let constr = builder.sub_extension(diffs, input0_sum); - let constr = builder.mul_extension(iszero_filter, constr); - yield_constr.constraint(builder, constr); - } - - let sum_squared_diffs = logic.input0.into_iter().zip(logic.input1).fold( - builder.zero_extension(), - |acc, (in0, in1)| { - let diff = builder.sub_extension(in0, in1); - builder.mul_add_extension(diff, diff, acc) + let dot: ExtensionTarget = izip!(logic.input0, logic.input1, logic.diff_pinv).fold( + zero, + |cumul, (limb0, limb1, diff_pinv_el)| { + let diff = builder.sub_extension(limb0, limb1); + builder.mul_add_extension(diff, diff_pinv_el, cumul) }, ); - { - let constr = builder.sub_extension(diffs, sum_squared_diffs); - let constr = builder.mul_extension(eq_filter, constr); - yield_constr.constraint(builder, constr); - } - } - - { - // diffs != 0 => ls_bit == 0 - let constr = builder.mul_extension(diffs, ls_bit); + let constr = builder.sub_extension(dot, unequal); let constr = builder.mul_extension(eq_or_iszero_filter, constr); yield_constr.constraint(builder, constr); } - { - // ls_bit == 0 => diffs != 0 (we provide a diffs_inv) - let constr = builder.mul_add_extension(diffs, diffs_inv, ls_bit); - let constr = builder.mul_sub_extension(eq_or_iszero_filter, constr, eq_or_iszero_filter); - yield_constr.constraint(builder, constr); - } } diff --git a/evm/src/cpu/simple_logic/not.rs b/evm/src/cpu/simple_logic/not.rs index efbf51a6..bcff3344 100644 --- a/evm/src/cpu/simple_logic/not.rs +++ b/evm/src/cpu/simple_logic/not.rs @@ -7,7 +7,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -const LIMB_SIZE: usize = 16; +const LIMB_SIZE: usize = 32; const ALL_1_LIMB: u64 = (1 << LIMB_SIZE) - 1; pub fn generate(lv: &mut CpuColumnsView) { diff --git a/evm/src/logic.rs b/evm/src/logic.rs index bde5d645..119c3d32 100644 --- a/evm/src/logic.rs +++ b/evm/src/logic.rs @@ -17,7 +17,7 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; // Total number of bits per input/output. const VAL_BITS: usize = 256; // Number of bits stored per field element. Ensure that this fits; it is not checked. -pub(crate) const PACKED_LIMB_BITS: usize = 16; +pub(crate) const PACKED_LIMB_BITS: usize = 32; // Number of field elements needed to store each input/output at the specified packing. const PACKED_LEN: usize = (VAL_BITS + PACKED_LIMB_BITS - 1) / PACKED_LIMB_BITS;