diff --git a/ecdsa/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs index 1dbe4657..faae365c 100644 --- a/ecdsa/src/gadgets/biguint.rs +++ b/ecdsa/src/gadgets/biguint.rs @@ -7,10 +7,10 @@ use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartitionWitness, Witness}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_field::extension::Extendable; -use plonky2_field::types::PrimeField; +use plonky2_field::types::{PrimeField, PrimeField64}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit; -use plonky2_u32::witness::{generated_values_set_u32_target, witness_set_u32_target}; +use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32}; #[derive(Clone, Debug)] pub struct BigUintTarget { @@ -270,41 +270,44 @@ impl, const D: usize> CircuitBuilderBiguint } } -pub fn witness_get_biguint_target, F: PrimeField>( - witness: &W, - bt: BigUintTarget, -) -> BigUint { - bt.limbs - .into_iter() - .rev() - .fold(BigUint::zero(), |acc, limb| { - (acc << 32) + witness.get_target(limb.0).to_canonical_biguint() - }) +pub trait WitnessBigUint: Witness { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint; + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); } -pub fn witness_set_biguint_target, F: PrimeField>( - witness: &mut W, - target: &BigUintTarget, - value: &BigUint, -) { - let mut limbs = value.to_u32_digits(); - assert!(target.num_limbs() >= limbs.len()); - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - witness_set_u32_target(witness, target.limbs[i], limbs[i]); +impl, F: PrimeField64> WitnessBigUint for T { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + target + .limbs + .into_iter() + .rev() + .fold(BigUint::zero(), |acc, limb| { + (acc << 32) + self.get_target(limb.0).to_canonical_biguint() + }) + } + + fn set_biguint_target(&mut self, target: &BigUintTarget, 
value: &BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.limbs[i], limbs[i]); + } } } -pub fn buffer_set_biguint_target( - buffer: &mut GeneratedValues, - target: &BigUintTarget, - value: &BigUint, -) { - let mut limbs = value.to_u32_digits(); - assert!(target.num_limbs() >= limbs.len()); - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - generated_values_set_u32_target(buffer, target.get_limb(i), limbs[i]); +pub trait GeneratedValuesBigUint { + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); +} + +impl GeneratedValuesBigUint for GeneratedValues { + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.get_limb(i), limbs[i]); + } } } @@ -330,12 +333,12 @@ impl, const D: usize> SimpleGenerator } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness_get_biguint_target(witness, self.a.clone()); - let b = witness_get_biguint_target(witness, self.b.clone()); + let a = witness.get_biguint_target(self.a.clone()); + let b = witness.get_biguint_target(self.b.clone()); let (div, rem) = a.div_rem(&b); - buffer_set_biguint_target(out_buffer, &self.div, &div); - buffer_set_biguint_target(out_buffer, &self.rem, &rem); + out_buffer.set_biguint_target(&self.div, &div); + out_buffer.set_biguint_target(&self.rem, &rem); } } @@ -350,7 +353,7 @@ mod tests { }; use rand::Rng; - use crate::gadgets::biguint::{witness_set_biguint_target, CircuitBuilderBiguint}; + use crate::gadgets::biguint::{CircuitBuilderBiguint, WitnessBigUint}; #[test] fn test_biguint_add() -> Result<()> { @@ -373,9 +376,9 @@ mod tests { let expected_z = 
builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); - witness_set_biguint_target(&mut pw, &x, &x_value); - witness_set_biguint_target(&mut pw, &y, &y_value); - witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); @@ -433,9 +436,9 @@ mod tests { let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); - witness_set_biguint_target(&mut pw, &x, &x_value); - witness_set_biguint_target(&mut pw, &y, &y_value); - witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); diff --git a/ecdsa/src/gadgets/curve_fixed_base.rs b/ecdsa/src/gadgets/curve_fixed_base.rs index 44dc9488..0fd8e841 100644 --- a/ecdsa/src/gadgets/curve_fixed_base.rs +++ b/ecdsa/src/gadgets/curve_fixed_base.rs @@ -76,7 +76,7 @@ mod tests { use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; - use crate::gadgets::biguint::witness_set_biguint_target; + use crate::gadgets::biguint::WitnessBigUint; use crate::gadgets::curve::CircuitBuilderCurve; use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; use crate::gadgets::nonnative::CircuitBuilderNonNative; @@ -101,7 +101,7 @@ mod tests { builder.curve_assert_valid(&res_expected); let n_target = builder.add_virtual_nonnative_target::(); - witness_set_biguint_target(&mut pw, &n_target.value, &n.to_canonical_biguint()); + pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); let res_target = fixed_base_curve_mul_circuit(&mut builder, 
g, &n_target); builder.curve_assert_valid(&res_target); diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs index 8e62e906..4302023e 100644 --- a/ecdsa/src/gadgets/glv.rs +++ b/ecdsa/src/gadgets/glv.rs @@ -12,7 +12,7 @@ use plonky2_field::types::{Field, PrimeField}; use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; use crate::curve::secp256k1::Secp256K1; -use crate::gadgets::biguint::{buffer_set_biguint_target, witness_get_biguint_target}; +use crate::gadgets::biguint::{GeneratedValuesBigUint, WitnessBigUint}; use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; use crate::gadgets::curve_msm::curve_msm_circuit; use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; @@ -116,15 +116,14 @@ impl, const D: usize> SimpleGenerator } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let k = Secp256K1Scalar::from_noncanonical_biguint(witness_get_biguint_target( - witness, - self.k.value.clone(), - )); + let k = Secp256K1Scalar::from_noncanonical_biguint( + witness.get_biguint_target(self.k.value.clone()), + ); let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - buffer_set_biguint_target(out_buffer, &self.k1.value, &k1.to_canonical_biguint()); - buffer_set_biguint_target(out_buffer, &self.k2.value, &k2.to_canonical_biguint()); + out_buffer.set_biguint_target(&self.k1.value, &k1.to_canonical_biguint()); + out_buffer.set_biguint_target(&self.k2.value, &k2.to_canonical_biguint()); out_buffer.set_bool_target(self.k1_neg, k1_neg); out_buffer.set_bool_target(self.k2_neg, k2_neg); } diff --git a/ecdsa/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs index 393aac75..c6ff4753 100644 --- a/ecdsa/src/gadgets/nonnative.rs +++ b/ecdsa/src/gadgets/nonnative.rs @@ -10,11 +10,11 @@ use plonky2_field::types::PrimeField; use plonky2_field::{extension::Extendable, types::Field}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use 
plonky2_u32::gadgets::range_check::range_check_u32_circuit; -use plonky2_u32::witness::generated_values_set_u32_target; +use plonky2_u32::witness::GeneratedValuesU32; use plonky2_util::ceil_div_usize; use crate::gadgets::biguint::{ - buffer_set_biguint_target, witness_get_biguint_target, BigUintTarget, CircuitBuilderBiguint, + BigUintTarget, CircuitBuilderBiguint, GeneratedValuesBigUint, WitnessBigUint, }; #[derive(Clone, Debug)] @@ -467,14 +467,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, - self.a.value.clone(), - )); - let b = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, - self.b.value.clone(), - )); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); let sum_biguint = a_biguint + b_biguint; @@ -485,7 +479,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat (false, sum_biguint) }; - buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); + out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); out_buffer.set_bool_target(self.overflow, overflow); } } @@ -514,10 +508,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat .summands .iter() .map(|summand| { - FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, - summand.value.clone(), - )) + FF::from_noncanonical_biguint(witness.get_biguint_target(summand.value.clone())) }) .collect(); let summand_biguints: Vec<_> = summands @@ -533,8 +524,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); let overflow = overflow_biguint.to_u64_digits()[0] as u32; - buffer_set_biguint_target(out_buffer, 
&self.sum.value, &sum_reduced); - generated_values_set_u32_target(out_buffer, self.overflow, overflow); + out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); + out_buffer.set_u32_target(self.overflow, overflow); } } @@ -562,14 +553,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, - self.a.value.clone(), - )); - let b = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, - self.b.value.clone(), - )); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); @@ -580,7 +565,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat (modulus + a_biguint - b_biguint, true) }; - buffer_set_biguint_target(out_buffer, &self.diff.value, &diff_biguint); + out_buffer.set_biguint_target(&self.diff.value, &diff_biguint); out_buffer.set_bool_target(self.overflow, overflow); } } @@ -609,14 +594,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, - self.a.value.clone(), - )); - let b = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, - self.b.value.clone(), - )); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); @@ -625,8 +604,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let modulus = FF::order(); let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); - 
buffer_set_biguint_target(out_buffer, &self.prod.value, &prod_reduced); - buffer_set_biguint_target(out_buffer, &self.overflow, &overflow_biguint); + out_buffer.set_biguint_target(&self.prod.value, &prod_reduced); + out_buffer.set_biguint_target(&self.overflow, &overflow_biguint); } } @@ -646,10 +625,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let x = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, - self.x.value.clone(), - )); + let x = FF::from_noncanonical_biguint(witness.get_biguint_target(self.x.value.clone())); let inv = x.inverse(); let x_biguint = x.to_canonical_biguint(); @@ -658,8 +634,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let modulus = FF::order(); let (div, _rem) = prod.div_rem(&modulus); - buffer_set_biguint_target(out_buffer, &self.div, &div); - buffer_set_biguint_target(out_buffer, &self.inv, &inv_biguint); + out_buffer.set_biguint_target(&self.div, &div); + out_buffer.set_biguint_target(&self.inv, &inv_biguint); } } diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 0d822ec6..2b0fa6b9 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -41,28 +41,24 @@ impl, const D: usize> Default for AllStark { } impl, const D: usize> AllStark { - pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> Vec { - let ans = vec![ + pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { + [ self.cpu_stark.num_permutation_batches(config), self.keccak_stark.num_permutation_batches(config), self.keccak_memory_stark.num_permutation_batches(config), self.logic_stark.num_permutation_batches(config), self.memory_stark.num_permutation_batches(config), - ]; - debug_assert_eq!(ans.len(), Table::num_tables()); - ans + ] } - pub(crate) fn permutation_batch_sizes(&self) -> Vec { - let ans = vec![ + pub(crate) fn permutation_batch_sizes(&self) -> [usize; NUM_TABLES] { + [ 
self.cpu_stark.permutation_batch_size(), self.keccak_stark.permutation_batch_size(), self.keccak_memory_stark.permutation_batch_size(), self.logic_stark.permutation_batch_size(), self.memory_stark.permutation_batch_size(), - ]; - debug_assert_eq!(ans.len(), Table::num_tables()); - ans + ] } } @@ -75,11 +71,7 @@ pub enum Table { Memory = 4, } -impl Table { - pub(crate) const fn num_tables() -> usize { - Table::Memory as usize + 1 - } -} +pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; #[allow(unused)] // TODO: Should be used soon. pub(crate) fn all_cross_table_lookups() -> Vec> { @@ -326,6 +318,46 @@ mod tests { cpu_trace_rows.push(row.into()); } + // Pad to `num_memory_ops` for memory testing. + for _ in cpu_trace_rows.len()..num_memory_ops { + let mut row: cpu::columns::CpuColumnsView = + [F::ZERO; CpuStark::::COLUMNS].into(); + row.opcode_bits = bits_from_opcode(0x5b); + row.is_cpu_cycle = F::ONE; + row.is_kernel_mode = F::ONE; + row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"]); + cpu_stark.generate(row.borrow_mut()); + cpu_trace_rows.push(row.into()); + } + + for i in 0..num_memory_ops { + let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i] + .to_canonical_u64() + .try_into() + .unwrap(); + let clock = mem_timestamp / NUM_CHANNELS; + let channel = mem_timestamp % NUM_CHANNELS; + + let filter = memory_trace[memory::columns::FILTER].values[i]; + assert!(filter.is_one() || filter.is_zero()); + let is_actual_op = filter.is_one(); + + if is_actual_op { + let row: &mut cpu::columns::CpuColumnsView = cpu_trace_rows[clock].borrow_mut(); + row.clock = F::from_canonical_usize(clock); + + let channel = &mut row.mem_channels[channel]; + channel.used = F::ONE; + channel.is_read = memory_trace[memory::columns::IS_READ].values[i]; + channel.addr_context = memory_trace[memory::columns::ADDR_CONTEXT].values[i]; + channel.addr_segment = memory_trace[memory::columns::ADDR_SEGMENT].values[i]; + 
channel.addr_virtual = memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; + for j in 0..8 { + channel.value[j] = memory_trace[memory::columns::value_limb(j)].values[i]; + } + } + } + for i in 0..num_logic_rows { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); @@ -346,55 +378,31 @@ mod tests { }, ); - let logic = row.general.logic_mut(); - let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0); - for (col_cpu, limb_cols_logic) in logic.input0.iter_mut().zip(input0_bit_cols) { + for (col_cpu, limb_cols_logic) in + row.mem_channels[0].value.iter_mut().zip(input0_bit_cols) + { *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); } let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1); - for (col_cpu, limb_cols_logic) in logic.input1.iter_mut().zip(input1_bit_cols) { + for (col_cpu, limb_cols_logic) in + row.mem_channels[1].value.iter_mut().zip(input1_bit_cols) + { *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); } - for (col_cpu, col_logic) in logic.output.iter_mut().zip(logic::columns::RESULT) { + for (col_cpu, col_logic) in row.mem_channels[2] + .value + .iter_mut() + .zip(logic::columns::RESULT) + { *col_cpu = logic_trace[col_logic].values[i]; } cpu_stark.generate(row.borrow_mut()); cpu_trace_rows.push(row.into()); } - for i in 0..num_memory_ops { - let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i] - .to_canonical_u64() - .try_into() - .unwrap(); - let clock = mem_timestamp / NUM_CHANNELS; - let channel = mem_timestamp % NUM_CHANNELS; - - let filter = memory_trace[memory::columns::FILTER].values[i]; - assert!(filter.is_one() || filter.is_zero()); - let is_actual_op = filter.is_one(); - - if is_actual_op { - let row: &mut cpu::columns::CpuColumnsView = cpu_trace_rows[clock].borrow_mut(); - - row.mem_channel_used[channel] = F::ONE; - row.clock = 
F::from_canonical_usize(clock); - row.mem_is_read[channel] = memory_trace[memory::columns::IS_READ].values[i]; - row.mem_addr_context[channel] = - memory_trace[memory::columns::ADDR_CONTEXT].values[i]; - row.mem_addr_segment[channel] = - memory_trace[memory::columns::ADDR_SEGMENT].values[i]; - row.mem_addr_virtual[channel] = - memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; - for j in 0..8 { - row.mem_value[channel][j] = - memory_trace[memory::columns::value_limb(j)].values[i]; - } - } - } // Trap to kernel { @@ -406,7 +414,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x0a); // `EXP` is implemented in software row.is_kernel_mode = F::ONE; row.program_counter = last_row.program_counter + F::ONE; - row.general.syscalls_mut().output = [ + row.mem_channels[0].value = [ row.program_counter, F::ONE, F::ZERO, @@ -428,7 +436,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0xf9); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_usize(KERNEL.global_labels["sys_exp"]); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(15682), F::ONE, F::ZERO, @@ -450,7 +458,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_u16(15682); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(15106), F::ZERO, F::ZERO, @@ -460,7 +468,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.general.jumps_mut().input1 = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, @@ -487,7 +495,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0xf9); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_u16(15106); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(63064), F::ZERO, F::ZERO, @@ -509,7 +517,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(63064); - 
row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(3754), F::ZERO, F::ZERO, @@ -519,7 +527,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.general.jumps_mut().input1 = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, @@ -547,7 +555,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x57); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(3754); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -557,7 +565,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.general.jumps_mut().input1 = [ + row.mem_channels[1].value = [ F::ZERO, F::ZERO, F::ZERO, @@ -585,7 +593,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x57); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(37543); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -614,7 +622,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ZERO; row.program_counter = last_row.program_counter + F::ONE; - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -624,7 +632,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.general.jumps_mut().input1 = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, @@ -695,7 +703,7 @@ mod tests { &mut memory_trace, ); - let traces = vec![ + let traces = [ cpu_trace, keccak_trace, keccak_memory_trace, diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index affd676d..134788dc 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -9,7 +9,6 @@ pub(crate) union CpuGeneralColumnsView { arithmetic: CpuArithmeticView, logic: CpuLogicView, jumps: CpuJumpsView, - syscalls: CpuSyscallsView, } impl CpuGeneralColumnsView { @@ -52,16 +51,6 @@ impl CpuGeneralColumnsView { pub(crate) fn jumps_mut(&mut self) -> &mut CpuJumpsView { 
unsafe { &mut self.jumps } } - - // SAFETY: Each view is a valid interpretation of the underlying array. - pub(crate) fn syscalls(&self) -> &CpuSyscallsView { - unsafe { &self.syscalls } - } - - // SAFETY: Each view is a valid interpretation of the underlying array. - pub(crate) fn syscalls_mut(&mut self) -> &mut CpuSyscallsView { - unsafe { &mut self.syscalls } - } } impl PartialEq for CpuGeneralColumnsView { @@ -107,23 +96,16 @@ pub(crate) struct CpuArithmeticView { #[derive(Copy, Clone)] pub(crate) struct CpuLogicView { - // Assuming a limb size of 32 bits. - pub(crate) input0: [T; 8], - pub(crate) input1: [T; 8], - pub(crate) output: [T; 8], - - // Pseudoinverse of `(input0 - input1)`. Used prove that they are unequal. + // Pseudoinverse of `(input0 - input1)`. Used prove that they are unequal. Assumes 32-bit limbs. pub(crate) diff_pinv: [T; 8], } #[derive(Copy, Clone)] pub(crate) struct CpuJumpsView { - /// Assuming a limb size of 32 bits. - /// The top stack value at entry (for jumps, the address; for `EXIT_KERNEL`, the address and new - /// privilege level). - pub(crate) input0: [T; 8], - /// For `JUMPI`, the second stack value (the predicate). For `JUMP`, 1. - pub(crate) input1: [T; 8], + /// `input0` is `mem_channel[0].value`. It's the top stack value at entry (for jumps, the + /// address; for `EXIT_KERNEL`, the address and new privilege level). + /// `input1` is `mem_channel[1].value`. For `JUMPI`, it's the second stack value (the + /// predicate). For `JUMP`, 1. /// Inverse of `input0[1] + ... + input0[7]`, if one exists; otherwise, an arbitrary value. /// Needed to prove that `input0` is nonzero. @@ -162,15 +144,5 @@ pub(crate) struct CpuJumpsView { pub(crate) should_trap: T, } -#[derive(Copy, Clone)] -pub(crate) struct CpuSyscallsView { - /// Assuming a limb size of 32 bits. - /// The output contains the context that is required to from the system call in `EXIT_KERNEL`. 
- /// `output[0]` contains the program counter at the time the system call was made (the address - /// of the syscall instruction). `output[1]` is 1 if we were in kernel mode at the time and 0 - /// otherwise. `output[2]`, ..., `output[7]` are zero. - pub(crate) output: [T; 8], -} - // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_SHARED_COLUMNS: usize = size_of::>(); diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index 564ea246..567c5a97 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -3,14 +3,28 @@ use std::borrow::{Borrow, BorrowMut}; use std::fmt::Debug; -use std::mem::{size_of, transmute, transmute_copy, ManuallyDrop}; +use std::mem::{size_of, transmute}; use std::ops::{Index, IndexMut}; use crate::cpu::columns::general::CpuGeneralColumnsView; use crate::memory; +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; mod general; +#[repr(C)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct MemoryChannelView { + /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise + /// 0. + pub used: T, + pub is_read: T, + pub addr_context: T, + pub addr_segment: T, + pub addr_virtual: T, + pub value: [T; memory::VALUE_LIMBS], +} + #[repr(C)] #[derive(Eq, PartialEq, Debug)] pub struct CpuColumnsView { @@ -124,28 +138,7 @@ pub struct CpuColumnsView { pub is_revert: T, pub is_selfdestruct: T, - // An instruction is invalid if _any_ of the below flags is 1. 
- pub is_invalid_0: T, - pub is_invalid_1: T, - pub is_invalid_2: T, - pub is_invalid_3: T, - pub is_invalid_4: T, - pub is_invalid_5: T, - pub is_invalid_6: T, - pub is_invalid_7: T, - pub is_invalid_8: T, - pub is_invalid_9: T, - pub is_invalid_10: T, - pub is_invalid_11: T, - pub is_invalid_12: T, - pub is_invalid_13: T, - pub is_invalid_14: T, - pub is_invalid_15: T, - pub is_invalid_16: T, - pub is_invalid_17: T, - pub is_invalid_18: T, - pub is_invalid_19: T, - pub is_invalid_20: T, + pub is_invalid: T, /// If CPU cycle: the opcode, broken up into bits in little-endian order. pub opcode_bits: [T; 8], @@ -159,27 +152,12 @@ pub struct CpuColumnsView { pub(crate) general: CpuGeneralColumnsView, pub(crate) clock: T, - /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise - /// 0. - pub mem_channel_used: [T; memory::NUM_CHANNELS], - pub mem_is_read: [T; memory::NUM_CHANNELS], - pub mem_addr_context: [T; memory::NUM_CHANNELS], - pub mem_addr_segment: [T; memory::NUM_CHANNELS], - pub mem_addr_virtual: [T; memory::NUM_CHANNELS], - pub mem_value: [[T; memory::VALUE_LIMBS]; memory::NUM_CHANNELS], + pub mem_channels: [MemoryChannelView; memory::NUM_CHANNELS], } // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_CPU_COLUMNS: usize = size_of::>(); -unsafe fn transmute_no_compile_time_size_checks(value: T) -> U { - debug_assert_eq!(size_of::(), size_of::()); - // Need ManuallyDrop so that `value` is not dropped by this function. - let value = ManuallyDrop::new(value); - // Copy the bit pattern. The original value is no longer safe to use. 
- transmute_copy(&value) -} - impl From<[T; NUM_CPU_COLUMNS]> for CpuColumnsView { fn from(value: [T; NUM_CPU_COLUMNS]) -> Self { unsafe { transmute_no_compile_time_size_checks(value) } @@ -239,12 +217,7 @@ where } const fn make_col_map() -> CpuColumnsView { - let mut indices_arr = [0; NUM_CPU_COLUMNS]; - let mut i = 0; - while i < NUM_CPU_COLUMNS { - indices_arr[i] = i; - i += 1; - } + let indices_arr = indices_arr::(); unsafe { transmute::<[usize; NUM_CPU_COLUMNS], CpuColumnsView>(indices_arr) } } diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 852b7b54..9fd4792d 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -28,9 +28,9 @@ pub fn ctl_data_keccak_memory() -> Vec> { // channel 1: stack[-1] = context // channel 2: stack[-2] = segment // channel 3: stack[-3] = virtual - let context = Column::single(COL_MAP.mem_value[1][0]); - let segment = Column::single(COL_MAP.mem_value[2][0]); - let virt = Column::single(COL_MAP.mem_value[3][0]); + let context = Column::single(COL_MAP.mem_channels[1].value[0]); + let segment = Column::single(COL_MAP.mem_channels[2].value[0]); + let virt = Column::single(COL_MAP.mem_channels[3].value[0]); let num_channels = F::from_canonical_usize(NUM_CHANNELS); let clock = Column::linear_combination([(COL_MAP.clock, num_channels)]); @@ -48,10 +48,9 @@ pub fn ctl_filter_keccak_memory() -> Column { pub fn ctl_data_logic() -> Vec> { let mut res = Column::singles([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]).collect_vec(); - let logic = COL_MAP.general.logic(); - res.extend(Column::singles(logic.input0)); - res.extend(Column::singles(logic.input1)); - res.extend(Column::singles(logic.output)); + res.extend(Column::singles(COL_MAP.mem_channels[0].value)); + res.extend(Column::singles(COL_MAP.mem_channels[1].value)); + res.extend(Column::singles(COL_MAP.mem_channels[2].value)); res } @@ -61,14 +60,15 @@ pub fn ctl_filter_logic() -> Column { pub fn ctl_data_memory(channel: usize) -> Vec> { 
debug_assert!(channel < NUM_CHANNELS); + let channel_map = COL_MAP.mem_channels[channel]; let mut cols: Vec> = Column::singles([ - COL_MAP.mem_is_read[channel], - COL_MAP.mem_addr_context[channel], - COL_MAP.mem_addr_segment[channel], - COL_MAP.mem_addr_virtual[channel], + channel_map.is_read, + channel_map.addr_context, + channel_map.addr_segment, + channel_map.addr_virtual, ]) .collect_vec(); - cols.extend(Column::singles(COL_MAP.mem_value[channel])); + cols.extend(Column::singles(channel_map.value)); let scalar = F::from_canonical_usize(NUM_CHANNELS); let addend = F::from_canonical_usize(channel); @@ -81,7 +81,7 @@ pub fn ctl_data_memory(channel: usize) -> Vec> { } pub fn ctl_filter_memory(channel: usize) -> Column { - Column::single(COL_MAP.mem_channel_used[channel]) + Column::single(COL_MAP.mem_channels[channel].used) } #[derive(Copy, Clone, Default)] diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index e58b474d..7ca9a650 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -6,14 +6,6 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP}; -#[derive(PartialEq, Eq)] -enum Availability { - All, - User, - Kernel, -} -use Availability::{All, Kernel, User}; - /// List of opcode blocks /// Each block corresponds to exactly one flag, and each flag corresponds to exactly one block. /// Each block of opcodes: @@ -28,124 +20,132 @@ use Availability::{All, Kernel, User}; /// The exception is the PANIC instruction which is user-only without a corresponding kernel block. /// This makes the proof unverifiable when PANIC is executed in kernel mode, which is the intended /// behavior. 
-const OPCODES: [(u8, usize, Availability, usize); 113] = [ - // (start index of block, number of top bits to check (log2), availability, flag column) - (0x00, 0, All, COL_MAP.is_stop), - (0x01, 0, All, COL_MAP.is_add), - (0x02, 0, All, COL_MAP.is_mul), - (0x03, 0, All, COL_MAP.is_sub), - (0x04, 0, All, COL_MAP.is_div), - (0x05, 0, All, COL_MAP.is_sdiv), - (0x06, 0, All, COL_MAP.is_mod), - (0x07, 0, All, COL_MAP.is_smod), - (0x08, 0, All, COL_MAP.is_addmod), - (0x09, 0, All, COL_MAP.is_mulmod), - (0x0a, 0, All, COL_MAP.is_exp), - (0x0b, 0, All, COL_MAP.is_signextend), - (0x0c, 2, All, COL_MAP.is_invalid_0), // 0x0c-0x0f - (0x10, 0, All, COL_MAP.is_lt), - (0x11, 0, All, COL_MAP.is_gt), - (0x12, 0, All, COL_MAP.is_slt), - (0x13, 0, All, COL_MAP.is_sgt), - (0x14, 0, All, COL_MAP.is_eq), - (0x15, 0, All, COL_MAP.is_iszero), - (0x16, 0, All, COL_MAP.is_and), - (0x17, 0, All, COL_MAP.is_or), - (0x18, 0, All, COL_MAP.is_xor), - (0x19, 0, All, COL_MAP.is_not), - (0x1a, 0, All, COL_MAP.is_byte), - (0x1b, 0, All, COL_MAP.is_shl), - (0x1c, 0, All, COL_MAP.is_shr), - (0x1d, 0, All, COL_MAP.is_sar), - (0x1e, 1, All, COL_MAP.is_invalid_1), // 0x1e-0x1f - (0x20, 0, All, COL_MAP.is_keccak256), - (0x21, 0, All, COL_MAP.is_invalid_2), - (0x22, 1, All, COL_MAP.is_invalid_3), // 0x22-0x23 - (0x24, 2, All, COL_MAP.is_invalid_4), // 0x24-0x27 - (0x28, 3, All, COL_MAP.is_invalid_5), // 0x28-0x2f - (0x30, 0, All, COL_MAP.is_address), - (0x31, 0, All, COL_MAP.is_balance), - (0x32, 0, All, COL_MAP.is_origin), - (0x33, 0, All, COL_MAP.is_caller), - (0x34, 0, All, COL_MAP.is_callvalue), - (0x35, 0, All, COL_MAP.is_calldataload), - (0x36, 0, All, COL_MAP.is_calldatasize), - (0x37, 0, All, COL_MAP.is_calldatacopy), - (0x38, 0, All, COL_MAP.is_codesize), - (0x39, 0, All, COL_MAP.is_codecopy), - (0x3a, 0, All, COL_MAP.is_gasprice), - (0x3b, 0, All, COL_MAP.is_extcodesize), - (0x3c, 0, All, COL_MAP.is_extcodecopy), - (0x3d, 0, All, COL_MAP.is_returndatasize), - (0x3e, 0, All, 
COL_MAP.is_returndatacopy), - (0x3f, 0, All, COL_MAP.is_extcodehash), - (0x40, 0, All, COL_MAP.is_blockhash), - (0x41, 0, All, COL_MAP.is_coinbase), - (0x42, 0, All, COL_MAP.is_timestamp), - (0x43, 0, All, COL_MAP.is_number), - (0x44, 0, All, COL_MAP.is_difficulty), - (0x45, 0, All, COL_MAP.is_gaslimit), - (0x46, 0, All, COL_MAP.is_chainid), - (0x47, 0, All, COL_MAP.is_selfbalance), - (0x48, 0, All, COL_MAP.is_basefee), - (0x49, 0, User, COL_MAP.is_invalid_6), - (0x49, 0, Kernel, COL_MAP.is_prover_input), - (0x4a, 1, All, COL_MAP.is_invalid_7), // 0x4a-0x4b - (0x4c, 2, All, COL_MAP.is_invalid_8), // 0x4c-0x4f - (0x50, 0, All, COL_MAP.is_pop), - (0x51, 0, All, COL_MAP.is_mload), - (0x52, 0, All, COL_MAP.is_mstore), - (0x53, 0, All, COL_MAP.is_mstore8), - (0x54, 0, All, COL_MAP.is_sload), - (0x55, 0, All, COL_MAP.is_sstore), - (0x56, 0, All, COL_MAP.is_jump), - (0x57, 0, All, COL_MAP.is_jumpi), - (0x58, 0, All, COL_MAP.is_pc), - (0x59, 0, All, COL_MAP.is_msize), - (0x5a, 0, All, COL_MAP.is_gas), - (0x5b, 0, All, COL_MAP.is_jumpdest), - (0x5c, 2, User, COL_MAP.is_invalid_9), // 0x5c-5f - (0x5c, 0, Kernel, COL_MAP.is_get_state_root), - (0x5d, 0, Kernel, COL_MAP.is_set_state_root), - (0x5e, 0, Kernel, COL_MAP.is_get_receipt_root), - (0x5f, 0, Kernel, COL_MAP.is_set_receipt_root), - (0x60, 5, All, COL_MAP.is_push), // 0x60-0x7f - (0x80, 4, All, COL_MAP.is_dup), // 0x80-0x8f - (0x90, 4, All, COL_MAP.is_swap), // 0x90-0x9f - (0xa0, 0, All, COL_MAP.is_log0), - (0xa1, 0, All, COL_MAP.is_log1), - (0xa2, 0, All, COL_MAP.is_log2), - (0xa3, 0, All, COL_MAP.is_log3), - (0xa4, 0, All, COL_MAP.is_log4), - (0xa5, 0, User, COL_MAP.is_invalid_10), +/// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to +/// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid. 
+const OPCODES: [(u8, usize, bool, usize); 92] = [ + // (start index of block, number of top bits to check (log2), kernel-only, flag column) + (0x00, 0, false, COL_MAP.is_stop), + (0x01, 0, false, COL_MAP.is_add), + (0x02, 0, false, COL_MAP.is_mul), + (0x03, 0, false, COL_MAP.is_sub), + (0x04, 0, false, COL_MAP.is_div), + (0x05, 0, false, COL_MAP.is_sdiv), + (0x06, 0, false, COL_MAP.is_mod), + (0x07, 0, false, COL_MAP.is_smod), + (0x08, 0, false, COL_MAP.is_addmod), + (0x09, 0, false, COL_MAP.is_mulmod), + (0x0a, 0, false, COL_MAP.is_exp), + (0x0b, 0, false, COL_MAP.is_signextend), + (0x10, 0, false, COL_MAP.is_lt), + (0x11, 0, false, COL_MAP.is_gt), + (0x12, 0, false, COL_MAP.is_slt), + (0x13, 0, false, COL_MAP.is_sgt), + (0x14, 0, false, COL_MAP.is_eq), + (0x15, 0, false, COL_MAP.is_iszero), + (0x16, 0, false, COL_MAP.is_and), + (0x17, 0, false, COL_MAP.is_or), + (0x18, 0, false, COL_MAP.is_xor), + (0x19, 0, false, COL_MAP.is_not), + (0x1a, 0, false, COL_MAP.is_byte), + (0x1b, 0, false, COL_MAP.is_shl), + (0x1c, 0, false, COL_MAP.is_shr), + (0x1d, 0, false, COL_MAP.is_sar), + (0x20, 0, false, COL_MAP.is_keccak256), + (0x30, 0, false, COL_MAP.is_address), + (0x31, 0, false, COL_MAP.is_balance), + (0x32, 0, false, COL_MAP.is_origin), + (0x33, 0, false, COL_MAP.is_caller), + (0x34, 0, false, COL_MAP.is_callvalue), + (0x35, 0, false, COL_MAP.is_calldataload), + (0x36, 0, false, COL_MAP.is_calldatasize), + (0x37, 0, false, COL_MAP.is_calldatacopy), + (0x38, 0, false, COL_MAP.is_codesize), + (0x39, 0, false, COL_MAP.is_codecopy), + (0x3a, 0, false, COL_MAP.is_gasprice), + (0x3b, 0, false, COL_MAP.is_extcodesize), + (0x3c, 0, false, COL_MAP.is_extcodecopy), + (0x3d, 0, false, COL_MAP.is_returndatasize), + (0x3e, 0, false, COL_MAP.is_returndatacopy), + (0x3f, 0, false, COL_MAP.is_extcodehash), + (0x40, 0, false, COL_MAP.is_blockhash), + (0x41, 0, false, COL_MAP.is_coinbase), + (0x42, 0, false, COL_MAP.is_timestamp), + (0x43, 0, false, COL_MAP.is_number), + (0x44, 0, 
false, COL_MAP.is_difficulty), + (0x45, 0, false, COL_MAP.is_gaslimit), + (0x46, 0, false, COL_MAP.is_chainid), + (0x47, 0, false, COL_MAP.is_selfbalance), + (0x48, 0, false, COL_MAP.is_basefee), + (0x49, 0, true, COL_MAP.is_prover_input), + (0x50, 0, false, COL_MAP.is_pop), + (0x51, 0, false, COL_MAP.is_mload), + (0x52, 0, false, COL_MAP.is_mstore), + (0x53, 0, false, COL_MAP.is_mstore8), + (0x54, 0, false, COL_MAP.is_sload), + (0x55, 0, false, COL_MAP.is_sstore), + (0x56, 0, false, COL_MAP.is_jump), + (0x57, 0, false, COL_MAP.is_jumpi), + (0x58, 0, false, COL_MAP.is_pc), + (0x59, 0, false, COL_MAP.is_msize), + (0x5a, 0, false, COL_MAP.is_gas), + (0x5b, 0, false, COL_MAP.is_jumpdest), + (0x5c, 0, true, COL_MAP.is_get_state_root), + (0x5d, 0, true, COL_MAP.is_set_state_root), + (0x5e, 0, true, COL_MAP.is_get_receipt_root), + (0x5f, 0, true, COL_MAP.is_set_receipt_root), + (0x60, 5, false, COL_MAP.is_push), // 0x60-0x7f + (0x80, 4, false, COL_MAP.is_dup), // 0x80-0x8f + (0x90, 4, false, COL_MAP.is_swap), // 0x90-0x9f + (0xa0, 0, false, COL_MAP.is_log0), + (0xa1, 0, false, COL_MAP.is_log1), + (0xa2, 0, false, COL_MAP.is_log2), + (0xa3, 0, false, COL_MAP.is_log3), + (0xa4, 0, false, COL_MAP.is_log4), // Opcode 0xa5 is PANIC when Kernel. Make the proof unverifiable by giving it no flag to decode to. 
- (0xa6, 1, All, COL_MAP.is_invalid_11), // 0xa6-0xa7 - (0xa8, 3, All, COL_MAP.is_invalid_12), // 0xa8-0xaf - (0xb0, 4, All, COL_MAP.is_invalid_13), // 0xb0-0xbf - (0xc0, 5, All, COL_MAP.is_invalid_14), // 0xc0-0xdf - (0xe0, 4, All, COL_MAP.is_invalid_15), // 0xe0-0xef - (0xf0, 0, All, COL_MAP.is_create), - (0xf1, 0, All, COL_MAP.is_call), - (0xf2, 0, All, COL_MAP.is_callcode), - (0xf3, 0, All, COL_MAP.is_return), - (0xf4, 0, All, COL_MAP.is_delegatecall), - (0xf5, 0, All, COL_MAP.is_create2), - (0xf6, 1, User, COL_MAP.is_invalid_16), // 0xf6-0xf7 - (0xf6, 0, Kernel, COL_MAP.is_get_context), - (0xf7, 0, Kernel, COL_MAP.is_set_context), - (0xf8, 1, User, COL_MAP.is_invalid_17), // 0xf8-0xf9 - (0xf8, 0, Kernel, COL_MAP.is_consume_gas), - (0xf9, 0, Kernel, COL_MAP.is_exit_kernel), - (0xfa, 0, All, COL_MAP.is_staticcall), - (0xfb, 0, User, COL_MAP.is_invalid_18), - (0xfb, 0, Kernel, COL_MAP.is_mload_general), - (0xfc, 0, User, COL_MAP.is_invalid_19), - (0xfc, 0, Kernel, COL_MAP.is_mstore_general), - (0xfd, 0, All, COL_MAP.is_revert), - (0xfe, 0, All, COL_MAP.is_invalid_20), - (0xff, 0, All, COL_MAP.is_selfdestruct), + (0xf0, 0, false, COL_MAP.is_create), + (0xf1, 0, false, COL_MAP.is_call), + (0xf2, 0, false, COL_MAP.is_callcode), + (0xf3, 0, false, COL_MAP.is_return), + (0xf4, 0, false, COL_MAP.is_delegatecall), + (0xf5, 0, false, COL_MAP.is_create2), + (0xf6, 0, true, COL_MAP.is_get_context), + (0xf7, 0, true, COL_MAP.is_set_context), + (0xf8, 0, true, COL_MAP.is_consume_gas), + (0xf9, 0, true, COL_MAP.is_exit_kernel), + (0xfa, 0, false, COL_MAP.is_staticcall), + (0xfb, 0, true, COL_MAP.is_mload_general), + (0xfc, 0, true, COL_MAP.is_mstore_general), + (0xfd, 0, false, COL_MAP.is_revert), + (0xff, 0, false, COL_MAP.is_selfdestruct), ]; +/// Bitfield of invalid opcodes, in little-endian order. +pub(crate) const fn invalid_opcodes_user() -> [u8; 32] { + let mut res = [u8::MAX; 32]; // Start with all opcodes marked invalid. 
+ + let mut i = 0; + while i < OPCODES.len() { + let (block_start, lb_block_len, kernel_only, _) = OPCODES[i]; + i += 1; + + if kernel_only { + continue; + } + + let block_len = 1 << lb_block_len; + let block_start = block_start as usize; + let block_end = block_start + block_len; + let mut j = block_start; + while j < block_end { + let byte = j / u8::BITS as usize; + let bit = j % u8::BITS as usize; + res[byte] &= !(1 << bit); // Mark opcode as valid by zeroing the bit. + j += 1; + } + } + res +} + pub fn generate(lv: &mut CpuColumnsView) { let cycle_filter = lv.is_cpu_cycle; if cycle_filter == F::ZERO { @@ -184,15 +184,19 @@ pub fn generate(lv: &mut CpuColumnsView) { assert!(kernel <= 1); let kernel = kernel != 0; - for (oc, block_length, availability, col) in OPCODES { - let available = match availability { - All => true, - User => !kernel, - Kernel => kernel, - }; + let mut any_flag_set = false; + for (oc, block_length, kernel_only, col) in OPCODES { + let available = !kernel_only || kernel; let opcode_match = top_bits[8 - block_length] == oc; - lv[col] = F::from_bool(available && opcode_match); + let flag = available && opcode_match; + lv[col] = F::from_bool(flag); + if flag && any_flag_set { + panic!("opcode matched multiple flags"); + } + any_flag_set = any_flag_set || flag; } + // is_invalid is a catch-all for opcodes we can't decode. + lv.is_invalid = F::from_bool(!any_flag_set); } /// Break up an opcode (which is 8 bits long) into its eight bits. @@ -230,21 +234,22 @@ pub fn eval_packed_generic( let flag = lv[flag_col]; yield_constr.constraint(cycle_filter * flag * (flag - P::ONES)); } + yield_constr.constraint(cycle_filter * lv.is_invalid * (lv.is_invalid - P::ONES)); // Now check that exactly one is 1. let flag_sum: P = OPCODES .into_iter() .map(|(_, _, _, flag_col)| lv[flag_col]) - .sum(); + .sum::

() + + lv.is_invalid; yield_constr.constraint(cycle_filter * (P::ONES - flag_sum)); // Finally, classify all opcodes, together with the kernel flag, into blocks - for (oc, block_length, availability, col) in OPCODES { - // 0 if the block/flag is available to us (is always available, is user-only and we are in - // user mode, or kernel-only and we are in kernel mode) and 1 otherwise. - let unavailable = match availability { - All => P::ZEROS, - User => kernel_mode, - Kernel => P::ONES - kernel_mode, + for (oc, block_length, kernel_only, col) in OPCODES { + // 0 if the block/flag is available to us (is always available or we are in kernel mode) and + // 1 otherwise. + let unavailable = match kernel_only { + false => P::ZEROS, + true => P::ONES - kernel_mode, }; // 0 if all the opcode bits match, and something in {1, ..., 8}, otherwise. let opcode_mismatch: P = lv @@ -299,6 +304,11 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } + { + let constr = builder.mul_sub_extension(lv.is_invalid, lv.is_invalid, lv.is_invalid); + let constr = builder.mul_extension(cycle_filter, constr); + yield_constr.constraint(builder, constr); + } // Now check that exactly one is 1. { let mut constr = builder.one_extension(); @@ -306,18 +316,18 @@ pub fn eval_ext_circuit, const D: usize>( let flag = lv[flag_col]; constr = builder.sub_extension(constr, flag); } + constr = builder.sub_extension(constr, lv.is_invalid); constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } // Finally, classify all opcodes, together with the kernel flag, into blocks - for (oc, block_length, availability, col) in OPCODES { - // 0 if the block/flag is available to us (is always available, is user-only and we are in - // user mode, or kernel-only and we are in kernel mode) and 1 otherwise. 
- let unavailable = match availability { - All => builder.zero_extension(), - User => kernel_mode, - Kernel => builder.sub_extension(one, kernel_mode), + for (oc, block_length, kernel_only, col) in OPCODES { + // 0 if the block/flag is available to us (is always available or we are in kernel mode) and + // 1 otherwise. + let unavailable = match kernel_only { + false => builder.zero_extension(), + true => builder.sub_extension(one, kernel_mode), }; // 0 if all the opcode bits match, and something in {1, ..., 8}, otherwise. let opcode_mismatch = lv diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index 10c9503a..219b39dd 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -17,16 +17,16 @@ pub fn eval_packed_exit_kernel( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let jumps_lv = lv.general.jumps(); + let input = lv.mem_channels[0].value; // If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode // flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the // kernel to set them to zero). yield_constr.constraint_transition( - lv.is_cpu_cycle * lv.is_exit_kernel * (jumps_lv.input0[0] - nv.program_counter), + lv.is_cpu_cycle * lv.is_exit_kernel * (input[0] - nv.program_counter), ); yield_constr.constraint_transition( - lv.is_cpu_cycle * lv.is_exit_kernel * (jumps_lv.input0[1] - nv.is_kernel_mode), + lv.is_cpu_cycle * lv.is_exit_kernel * (input[1] - nv.is_kernel_mode), ); } @@ -36,18 +36,18 @@ pub fn eval_ext_circuit_exit_kernel, const D: usize nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let jumps_lv = lv.general.jumps(); + let input = lv.mem_channels[0].value; let filter = builder.mul_extension(lv.is_cpu_cycle, lv.is_exit_kernel); // If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode // flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the // kernel to set them to zero). - let pc_constr = builder.sub_extension(jumps_lv.input0[0], nv.program_counter); + let pc_constr = builder.sub_extension(input[0], nv.program_counter); let pc_constr = builder.mul_extension(filter, pc_constr); yield_constr.constraint_transition(builder, pc_constr); - let kernel_constr = builder.sub_extension(jumps_lv.input0[1], nv.is_kernel_mode); + let kernel_constr = builder.sub_extension(input[1], nv.is_kernel_mode); let kernel_constr = builder.mul_extension(filter, kernel_constr); yield_constr.constraint_transition(builder, kernel_constr); } @@ -58,12 +58,14 @@ pub fn eval_packed_jump_jumpi( yield_constr: &mut ConstraintConsumer

, ) { let jumps_lv = lv.general.jumps(); + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; let filter = lv.is_jump + lv.is_jumpi; // `JUMP` or `JUMPI` // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`. - yield_constr.constraint(lv.is_jump * (jumps_lv.input1[0] - P::ONES)); - for &limb in &jumps_lv.input1[1..] { + yield_constr.constraint(lv.is_jump * (input1[0] - P::ONES)); + for &limb in &input1[1..] { // Set all limbs (other than the least-significant limb) to 0. // NB: Technically, they don't have to be 0, as long as the sum // `input1[0] + ... + input1[7]` cannot overflow. @@ -75,7 +77,7 @@ pub fn eval_packed_jump_jumpi( yield_constr .constraint(filter * jumps_lv.input0_upper_zero * (jumps_lv.input0_upper_zero - P::ONES)); // The below sum cannot overflow due to the limb size. - let input0_upper_sum: P = jumps_lv.input0[1..].iter().copied().sum(); + let input0_upper_sum: P = input0[1..].iter().copied().sum(); // `input0_upper_zero` = 1 implies `input0_upper_sum` = 0. yield_constr.constraint(filter * jumps_lv.input0_upper_zero * input0_upper_sum); // `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can only @@ -113,7 +115,7 @@ pub fn eval_packed_jump_jumpi( // Validate `should_continue` // This sum cannot overflow (due to limb size). - let input1_sum: P = jumps_lv.input1.into_iter().sum(); + let input1_sum: P = input1.into_iter().sum(); // `should_continue` = 1 implies `input1_sum` = 0. 
yield_constr.constraint(filter * jumps_lv.should_continue * input1_sum); // `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if @@ -147,9 +149,8 @@ pub fn eval_packed_jump_jumpi( yield_constr.constraint_transition( filter * jumps_lv.should_continue * (nv.program_counter - lv.program_counter - P::ONES), ); - yield_constr.constraint_transition( - filter * jumps_lv.should_jump * (nv.program_counter - jumps_lv.input0[0]), - ); + yield_constr + .constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - input0[0])); } pub fn eval_ext_circuit_jump_jumpi, const D: usize>( @@ -159,15 +160,17 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> yield_constr: &mut RecursiveConstraintConsumer, ) { let jumps_lv = lv.general.jumps(); + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; let filter = builder.add_extension(lv.is_jump, lv.is_jumpi); // `JUMP` or `JUMPI` // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`. { - let constr = builder.mul_sub_extension(lv.is_jump, jumps_lv.input1[0], lv.is_jump); + let constr = builder.mul_sub_extension(lv.is_jump, input1[0], lv.is_jump); yield_constr.constraint(builder, constr); } - for &limb in &jumps_lv.input1[1..] { + for &limb in &input1[1..] { // Set all limbs (other than the least-significant limb) to 0. // NB: Technically, they don't have to be 0, as long as the sum // `input1[0] + ... + input1[7]` cannot overflow. @@ -188,7 +191,7 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> } { // The below sum cannot overflow due to the limb size. - let input0_upper_sum = builder.add_many_extension(jumps_lv.input0[1..].iter()); + let input0_upper_sum = builder.add_many_extension(input0[1..].iter()); // `input0_upper_zero` = 1 implies `input0_upper_sum` = 0. 
let constr = builder.mul_extension(jumps_lv.input0_upper_zero, input0_upper_sum); @@ -251,7 +254,7 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> // Validate `should_continue` { // This sum cannot overflow (due to limb size). - let input1_sum = builder.add_many_extension(jumps_lv.input1.into_iter()); + let input1_sum = builder.add_many_extension(input1.into_iter()); // `should_continue` = 1 implies `input1_sum` = 0. let constr = builder.mul_extension(jumps_lv.should_continue, input1_sum); @@ -326,7 +329,7 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> } // ...or jumping. { - let constr = builder.sub_extension(nv.program_counter, jumps_lv.input0[0]); + let constr = builder.sub_extension(nv.program_counter, input0[0]); let constr = builder.mul_extension(jumps_lv.should_jump, constr); let constr = builder.mul_extension(filter, constr); yield_constr.constraint_transition(builder, constr); diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index eb55238b..dda006e6 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -15,6 +15,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/core/create.asm"), include_str!("asm/core/create_addresses.asm"), include_str!("asm/core/intrinsic_gas.asm"), + include_str!("asm/core/invalid.asm"), include_str!("asm/core/nonce.asm"), include_str!("asm/core/process_txn.asm"), include_str!("asm/core/terminate.asm"), diff --git a/evm/src/cpu/kernel/asm/core/invalid.asm b/evm/src/cpu/kernel/asm/core/invalid.asm new file mode 100644 index 00000000..6a7f4c17 --- /dev/null +++ b/evm/src/cpu/kernel/asm/core/invalid.asm @@ -0,0 +1,26 @@ +global handle_invalid: + // stack: trap_info + + // if the kernel is trying to execute an invalid instruction, then we've already screwed up and + // there's no chance of getting a useful proof, so we just panic + DUP1 + // stack: trap_info, trap_info + %shr_const(32) + // stack: is_kernel, trap_info + %jumpi(panic) + + // 
check if the opcode that triggered this trap is _actually_ invalid + // stack: program_counter (is_kernel == 0, so trap_info == program_counter) + %mload_current_code + // stack: opcode + PUSH @INVALID_OPCODES_USER + // stack: invalid_opcodes_user, opcode + SWAP1 + // stack: opcode, invalid_opcodes_user + SHR + %and_const(1) + // stack: opcode_is_invalid + // if the opcode is indeed invalid, then perform an exceptional exit + %jumpi(fault_exception) + // otherwise, panic because this trap should not have been entered + PANIC diff --git a/evm/src/cpu/kernel/asm/memory/core.asm b/evm/src/cpu/kernel/asm/memory/core.asm index 2c896345..73bafbee 100644 --- a/evm/src/cpu/kernel/asm/memory/core.asm +++ b/evm/src/cpu/kernel/asm/memory/core.asm @@ -26,6 +26,13 @@ // stack: (empty) %endmacro +// Load a single byte from user code. +%macro mload_current_code + // stack: offset + %mload_current(@SEGMENT_CODE) + // stack: value +%endmacro + // Load a single value from the given segment of kernel (context 0) memory. %macro mload_kernel(segment) // stack: offset diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index e8dd9eb8..13965e39 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -44,6 +44,13 @@ %endrep %endmacro +%macro and_const(c) + // stack: input, ... + PUSH $c + AND + // stack: input & c, ... +%endmacro + %macro add_const(c) // stack: input, ... PUSH $c @@ -101,6 +108,13 @@ // stack: input << c, ... %endmacro +%macro shr_const(c) + // stack: input, ... + PUSH $c + SHR + // stack: input >> c, ... +%endmacro + %macro eq_const(c) // stack: input, ... 
PUSH $c diff --git a/evm/src/cpu/kernel/constants.rs b/evm/src/cpu/kernel/constants.rs index 5bc5908e..98fe57c6 100644 --- a/evm/src/cpu/kernel/constants.rs +++ b/evm/src/cpu/kernel/constants.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use ethereum_types::U256; use hex_literal::hex; +use crate::cpu::decode::invalid_opcodes_user; use crate::cpu::kernel::context_metadata::ContextMetadata; use crate::cpu::kernel::global_metadata::GlobalMetadata; use crate::cpu::kernel::txn_fields::NormalizedTxnField; @@ -29,6 +30,10 @@ pub fn evm_constants() -> HashMap { for txn_field in ContextMetadata::all() { c.insert(txn_field.var_name().into(), (txn_field as u32).into()); } + c.insert( + "INVALID_OPCODES_USER".into(), + U256::from_little_endian(&invalid_opcodes_user()), + ); c } diff --git a/evm/src/cpu/kernel/keccak_util.rs b/evm/src/cpu/kernel/keccak_util.rs index 1498ba08..52cc0f08 100644 --- a/evm/src/cpu/kernel/keccak_util.rs +++ b/evm/src/cpu/kernel/keccak_util.rs @@ -1,3 +1,5 @@ +use tiny_keccak::keccakf; + /// A Keccak-f based hash. /// /// This hash does not use standard Keccak padding, since we don't care about extra zeros at the @@ -9,6 +11,42 @@ pub(crate) fn hash_kernel(_code: &[u8]) -> [u32; 8] { } /// Like tiny-keccak's `keccakf`, but deals with `u32` limbs instead of `u64` limbs. 
-pub(crate) fn keccakf_u32s(_state: &mut [u32; 50]) { - // TODO: Implement +pub(crate) fn keccakf_u32s(state_u32s: &mut [u32; 50]) { + let mut state_u64s: [u64; 25] = std::array::from_fn(|i| { + let lo = state_u32s[i * 2] as u64; + let hi = state_u32s[i * 2 + 1] as u64; + lo | (hi << 32) + }); + keccakf(&mut state_u64s); + *state_u32s = std::array::from_fn(|i| { + let u64_limb = state_u64s[i / 2]; + let is_hi = i % 2; + (u64_limb >> (is_hi * 32)) as u32 + }); +} + +#[cfg(test)] +mod tests { + use tiny_keccak::keccakf; + + use crate::cpu::kernel::keccak_util::keccakf_u32s; + + #[test] + #[rustfmt::skip] + fn test_consistency() { + // We will hash the same data using keccakf and keccakf_u32s. + // The inputs were randomly generated in Python. + let mut state_u64s: [u64; 25] = [0x5dc43ed05dc64048, 0x7bb9e18cdc853880, 0xc1fde300665b008f, 0xeeab85e089d5e431, 0xf7d61298e9ef27ea, 0xc2c5149d1a492455, 0x37a2f4eca0c2d2f2, 0xa35e50c015b3e85c, 0xd2daeced29446ebe, 0x245845f1bac1b98e, 0x3b3aa8783f30a9bf, 0x209ca9a81956d241, 0x8b8ea714da382165, 0x6063e67e202c6d29, 0xf4bac2ded136b907, 0xb17301b461eae65, 0xa91ff0e134ed747c, 0xcc080b28d0c20f1d, 0xf0f79cbec4fb551c, 0x25e04cb0aa930cad, 0x803113d1b541a202, 0xfaf1e4e7cd23b7ec, 0x36a03bbf2469d3b0, 0x25217341908cdfc0, 0xe9cd83f88fdcd500]; + let mut state_u32s: [u32; 50] = [0x5dc64048, 0x5dc43ed0, 0xdc853880, 0x7bb9e18c, 0x665b008f, 0xc1fde300, 0x89d5e431, 0xeeab85e0, 0xe9ef27ea, 0xf7d61298, 0x1a492455, 0xc2c5149d, 0xa0c2d2f2, 0x37a2f4ec, 0x15b3e85c, 0xa35e50c0, 0x29446ebe, 0xd2daeced, 0xbac1b98e, 0x245845f1, 0x3f30a9bf, 0x3b3aa878, 0x1956d241, 0x209ca9a8, 0xda382165, 0x8b8ea714, 0x202c6d29, 0x6063e67e, 0xd136b907, 0xf4bac2de, 0x461eae65, 0xb17301b, 0x34ed747c, 0xa91ff0e1, 0xd0c20f1d, 0xcc080b28, 0xc4fb551c, 0xf0f79cbe, 0xaa930cad, 0x25e04cb0, 0xb541a202, 0x803113d1, 0xcd23b7ec, 0xfaf1e4e7, 0x2469d3b0, 0x36a03bbf, 0x908cdfc0, 0x25217341, 0x8fdcd500, 0xe9cd83f8]; + + // The first output was generated using tiny-keccak; the second was 
derived from it. + let out_u64s: [u64; 25] = [0x8a541df597e79a72, 0x5c26b8c84faaebb3, 0xc0e8f4e67ca50497, 0x95d98a688de12dec, 0x1c837163975ffaed, 0x9481ec7ef948900e, 0x6a072c65d050a9a1, 0x3b2817da6d615bee, 0x7ffb3c4f8b94bf21, 0x85d6c418cced4a11, 0x18edbe0442884135, 0x2bf265ef3204b7fd, 0xc1e12ce30630d105, 0x8c554dbc61844574, 0x5504db652ce9e42c, 0x2217f3294d0dabe5, 0x7df8eebbcf5b74df, 0x3a56ebb61956f501, 0x7840219dc6f37cc, 0x23194159c967947, 0x9da289bf616ba14d, 0x5a90aaeeca9e9e5b, 0x885dcdc4a549b4e3, 0x46cb188c20947df7, 0x1ef285948ee3d8ab]; + let out_u32s: [u32; 50] = [0x97e79a72, 0x8a541df5, 0x4faaebb3, 0x5c26b8c8, 0x7ca50497, 0xc0e8f4e6, 0x8de12dec, 0x95d98a68, 0x975ffaed, 0x1c837163, 0xf948900e, 0x9481ec7e, 0xd050a9a1, 0x6a072c65, 0x6d615bee, 0x3b2817da, 0x8b94bf21, 0x7ffb3c4f, 0xcced4a11, 0x85d6c418, 0x42884135, 0x18edbe04, 0x3204b7fd, 0x2bf265ef, 0x630d105, 0xc1e12ce3, 0x61844574, 0x8c554dbc, 0x2ce9e42c, 0x5504db65, 0x4d0dabe5, 0x2217f329, 0xcf5b74df, 0x7df8eebb, 0x1956f501, 0x3a56ebb6, 0xdc6f37cc, 0x7840219, 0x9c967947, 0x2319415, 0x616ba14d, 0x9da289bf, 0xca9e9e5b, 0x5a90aaee, 0xa549b4e3, 0x885dcdc4, 0x20947df7, 0x46cb188c, 0x8ee3d8ab, 0x1ef28594]; + + keccakf(&mut state_u64s); + keccakf_u32s(&mut state_u32s); + + assert_eq!(state_u64s, out_u64s); + assert_eq!(state_u32s, out_u32s); + } } diff --git a/evm/src/cpu/kernel/parser.rs b/evm/src/cpu/kernel/parser.rs index 9ed578d4..35bde4b6 100644 --- a/evm/src/cpu/kernel/parser.rs +++ b/evm/src/cpu/kernel/parser.rs @@ -89,14 +89,14 @@ fn parse_macro_call(item: Pair) -> Item { fn parse_repeat(item: Pair) -> Item { assert_eq!(item.as_rule(), Rule::repeat); - let mut inner = item.into_inner().peekable(); + let mut inner = item.into_inner(); let count = parse_literal_u256(inner.next().unwrap()); Item::Repeat(count, inner.map(parse_item).collect()) } fn parse_stack(item: Pair) -> Item { assert_eq!(item.as_rule(), Rule::stack); - let mut inner = item.into_inner().peekable(); + let mut inner = item.into_inner(); let 
params = inner.next().unwrap(); assert_eq!(params.as_rule(), Rule::paramlist); diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index e1b33dc9..6b7294a8 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -8,7 +8,8 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::CpuColumnsView; pub fn generate(lv: &mut CpuColumnsView) { - let logic = lv.general.logic_mut(); + let input0 = lv.mem_channels[0].value; + let eq_filter = lv.is_eq.to_canonical_u64(); let iszero_filter = lv.is_iszero.to_canonical_u64(); assert!(eq_filter <= 1); @@ -19,19 +20,22 @@ pub fn generate(lv: &mut CpuColumnsView) { return; } + let input1 = &mut lv.mem_channels[1].value; if iszero_filter != 0 { - for limb in logic.input1.iter_mut() { + for limb in input1.iter_mut() { *limb = F::ZERO; } } - let num_unequal_limbs = izip!(logic.input0, logic.input1) + let input1 = lv.mem_channels[1].value; + let num_unequal_limbs = izip!(input0, input1) .map(|(limb0, limb1)| (limb0 != limb1) as usize) .sum(); let equal = num_unequal_limbs == 0; - logic.output[0] = F::from_bool(equal); - for limb in &mut logic.output[1..] { + let output = &mut lv.mem_channels[2].value; + output[0] = F::from_bool(equal); + for limb in &mut output[1..] { *limb = F::ZERO; } @@ -40,10 +44,11 @@ pub fn generate(lv: &mut CpuColumnsView) { // Then `diff @ x = num_unequal_limbs`, where `@` denotes the dot product. We set // `diff_pinv = num_unequal_limbs^-1 * x` if `num_unequal_limbs != 0` and 0 otherwise. We have // `diff @ diff_pinv = 1 - equal` as desired. 
+ let logic = lv.general.logic_mut(); let num_unequal_limbs_inv = F::from_canonical_usize(num_unequal_limbs) .try_inverse() .unwrap_or(F::ZERO); - for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), logic.input0, logic.input1) { + for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), input0, input1) { *limb_pinv = (limb0 - limb1).try_inverse().unwrap_or(F::ZERO) * num_unequal_limbs_inv; } } @@ -53,27 +58,31 @@ pub fn eval_packed( yield_constr: &mut ConstraintConsumer

, ) { let logic = lv.general.logic(); + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; + let output = lv.mem_channels[2].value; + let eq_filter = lv.is_eq; let iszero_filter = lv.is_iszero; let eq_or_iszero_filter = eq_filter + iszero_filter; - let equal = logic.output[0]; + let equal = output[0]; let unequal = P::ONES - equal; // Handle `EQ` and `ISZERO`. Most limbs of the output are 0, but the least-significant one is // either 0 or 1. yield_constr.constraint(eq_or_iszero_filter * equal * unequal); - for &limb in &logic.output[1..] { + for &limb in &output[1..] { yield_constr.constraint(eq_or_iszero_filter * limb); } // If `ISZERO`, constrain input1 to be zero, effectively implementing ISZERO(x) as EQ(x, 0). - for limb in logic.input1 { + for limb in input1 { yield_constr.constraint(iszero_filter * limb); } // `equal` implies `input0[i] == input1[i]` for all `i`. - for (limb0, limb1) in izip!(logic.input0, logic.input1) { + for (limb0, limb1) in izip!(input0, input1) { let diff = limb0 - limb1; yield_constr.constraint(eq_or_iszero_filter * equal * diff); } @@ -82,7 +91,7 @@ pub fn eval_packed( // If `unequal`, find `diff_pinv` such that `(input0 - input1) @ diff_pinv == 1`, where `@` // denotes the dot product (there will be many such `diff_pinv`). This can only be done if // `input0 != input1`. 
- let dot: P = izip!(logic.input0, logic.input1, logic.diff_pinv) + let dot: P = izip!(input0, input1, logic.diff_pinv) .map(|(limb0, limb1, diff_pinv_el)| (limb0 - limb1) * diff_pinv_el) .sum(); yield_constr.constraint(eq_or_iszero_filter * (dot - unequal)); @@ -97,11 +106,15 @@ pub fn eval_ext_circuit, const D: usize>( let one = builder.one_extension(); let logic = lv.general.logic(); + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; + let output = lv.mem_channels[2].value; + let eq_filter = lv.is_eq; let iszero_filter = lv.is_iszero; let eq_or_iszero_filter = builder.add_extension(eq_filter, iszero_filter); - let equal = logic.output[0]; + let equal = output[0]; let unequal = builder.sub_extension(one, equal); // Handle `EQ` and `ISZERO`. Most limbs of the output are 0, but the least-significant one is @@ -111,19 +124,19 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_extension(eq_or_iszero_filter, constr); yield_constr.constraint(builder, constr); } - for &limb in &logic.output[1..] { + for &limb in &output[1..] { let constr = builder.mul_extension(eq_or_iszero_filter, limb); yield_constr.constraint(builder, constr); } // If `ISZERO`, constrain input1 to be zero, effectively implementing ISZERO(x) as EQ(x, 0). - for limb in logic.input1 { + for limb in input1 { let constr = builder.mul_extension(iszero_filter, limb); yield_constr.constraint(builder, constr); } // `equal` implies `input0[i] == input1[i]` for all `i`. - for (limb0, limb1) in izip!(logic.input0, logic.input1) { + for (limb0, limb1) in izip!(input0, input1) { let diff = builder.sub_extension(limb0, limb1); let constr = builder.mul_extension(equal, diff); let constr = builder.mul_extension(eq_or_iszero_filter, constr); @@ -135,7 +148,7 @@ pub fn eval_ext_circuit, const D: usize>( // denotes the dot product (there will be many such `diff_pinv`). This can only be done if // `input0 != input1`. 
{ - let dot: ExtensionTarget = izip!(logic.input0, logic.input1, logic.diff_pinv).fold( + let dot: ExtensionTarget = izip!(input0, input1, logic.diff_pinv).fold( zero, |cumul, (limb0, limb1, diff_pinv_el)| { let diff = builder.sub_extension(limb0, limb1); diff --git a/evm/src/cpu/simple_logic/not.rs b/evm/src/cpu/simple_logic/not.rs index bcff3344..83d43276 100644 --- a/evm/src/cpu/simple_logic/not.rs +++ b/evm/src/cpu/simple_logic/not.rs @@ -17,8 +17,9 @@ pub fn generate(lv: &mut CpuColumnsView) { } assert_eq!(is_not_filter, 1); - let logic = lv.general.logic_mut(); - for (input, output_ref) in logic.input0.into_iter().zip(logic.output.iter_mut()) { + let input = lv.mem_channels[0].value; + let output = &mut lv.mem_channels[1].value; + for (input, output_ref) in input.into_iter().zip(output.iter_mut()) { let input = input.to_canonical_u64(); assert_eq!(input >> LIMB_SIZE, 0); let output = input ^ ALL_1_LIMB; @@ -30,14 +31,16 @@ pub fn eval_packed( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - // This is simple: just do output = 0xffff - input. - let logic = lv.general.logic(); + // This is simple: just do output = 0xffffffff - input. + let input = lv.mem_channels[0].value; + let output = lv.mem_channels[1].value; let cycle_filter = lv.is_cpu_cycle; let is_not_filter = lv.is_not; let filter = cycle_filter * is_not_filter; - for (input, output) in logic.input0.into_iter().zip(logic.output) { - yield_constr - .constraint(filter * (output + input - P::Scalar::from_canonical_u64(ALL_1_LIMB))); + for (input_limb, output_limb) in input.into_iter().zip(output) { + yield_constr.constraint( + filter * (output_limb + input_limb - P::Scalar::from_canonical_u64(ALL_1_LIMB)), + ); } } @@ -46,12 +49,13 @@ pub fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let logic = lv.general.logic(); + let input = lv.mem_channels[0].value; + let output = lv.mem_channels[1].value; let cycle_filter = lv.is_cpu_cycle; let is_not_filter = lv.is_not; let filter = builder.mul_extension(cycle_filter, is_not_filter); - for (input, output) in logic.input0.into_iter().zip(logic.output) { - let constr = builder.add_extension(output, input); + for (input_limb, output_limb) in input.into_iter().zip(output) { + let constr = builder.add_extension(output_limb, input_limb); let constr = builder.arithmetic_extension( F::ONE, -F::from_canonical_u64(ALL_1_LIMB), diff --git a/evm/src/cpu/syscalls.rs b/evm/src/cpu/syscalls.rs index a676a6a2..b0b63be8 100644 --- a/evm/src/cpu/syscalls.rs +++ b/evm/src/cpu/syscalls.rs @@ -13,12 +13,16 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NUM_SYSCALLS: usize = 2; +const NUM_SYSCALLS: usize = 3; fn make_syscall_list() -> [(usize, usize); NUM_SYSCALLS] { let kernel = Lazy::force(&KERNEL); - [(COL_MAP.is_stop, "sys_stop"), (COL_MAP.is_exp, "sys_exp")] - 
.map(|(col_index, handler_name)| (col_index, kernel.global_labels[handler_name])) + [ + (COL_MAP.is_stop, "sys_stop"), + (COL_MAP.is_exp, "sys_exp"), + (COL_MAP.is_invalid, "handle_invalid"), + ] + .map(|(col_index, handler_name)| (col_index, kernel.global_labels[handler_name])) } static TRAP_LIST: Lazy<[(usize, usize); NUM_SYSCALLS]> = Lazy::new(make_syscall_list); @@ -28,7 +32,6 @@ pub fn eval_packed( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let lv_syscalls = lv.general.syscalls(); let syscall_list = Lazy::force(&TRAP_LIST); // 1 if _any_ syscall, else 0. let should_syscall: P = syscall_list @@ -48,12 +51,14 @@ pub fn eval_packed( yield_constr.constraint_transition(filter * (nv.program_counter - syscall_dst)); // If syscall: set kernel mode yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES)); + + let output = lv.mem_channels[0].value; // If syscall: push current PC to stack - yield_constr.constraint(filter * (lv_syscalls.output[0] - lv.program_counter)); + yield_constr.constraint(filter * (output[0] - lv.program_counter)); // If syscall: push current kernel flag to stack (share register with PC) - yield_constr.constraint(filter * (lv_syscalls.output[1] - lv.is_kernel_mode)); + yield_constr.constraint(filter * (output[1] - lv.is_kernel_mode)); // If syscall: zero the rest of that register - for &limb in &lv_syscalls.output[2..] { + for &limb in &output[2..] { yield_constr.constraint(filter * limb); } } @@ -64,7 +69,6 @@ pub fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let lv_syscalls = lv.general.syscalls(); let syscall_list = Lazy::force(&TRAP_LIST); // 1 if _any_ syscall, else 0. 
let should_syscall = @@ -90,20 +94,22 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter); yield_constr.constraint_transition(builder, constr); } + + let output = lv.mem_channels[0].value; // If syscall: push current PC to stack { - let constr = builder.sub_extension(lv_syscalls.output[0], lv.program_counter); + let constr = builder.sub_extension(output[0], lv.program_counter); let constr = builder.mul_extension(filter, constr); yield_constr.constraint(builder, constr); } // If syscall: push current kernel flag to stack (share register with PC) { - let constr = builder.sub_extension(lv_syscalls.output[1], lv.is_kernel_mode); + let constr = builder.sub_extension(output[1], lv.is_kernel_mode); let constr = builder.mul_extension(filter, constr); yield_constr.constraint(builder, constr); } // If syscall: zero the rest of that register - for &limb in &lv_syscalls.output[2..] { + for &limb in &output[2..] { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 5d242ced..602ff6c5 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::iter::repeat; use anyhow::{ensure, Result}; @@ -13,7 +14,7 @@ use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::GenericConfig; -use crate::all_stark::Table; +use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::permutation::{ @@ -38,8 +39,10 @@ impl Column { } } - pub fn singles>(cs: I) -> impl Iterator { - cs.into_iter().map(Self::single) + pub fn singles>>( + cs: I, + ) -> impl Iterator { + cs.into_iter().map(|c| Self::single(*c.borrow())) } pub fn constant(constant: F) -> Self { @@ -53,6 +56,10 @@ impl 
Column { Self::constant(F::ZERO) } + pub fn one() -> Self { + Self::constant(F::ONE) + } + pub fn linear_combination_with_constant>( iter: I, constant: F, @@ -74,16 +81,20 @@ impl Column { Self::linear_combination_with_constant(iter, F::ZERO) } - pub fn le_bits>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().zip(F::TWO.powers())) + pub fn le_bits>>(cs: I) -> Self { + Self::linear_combination(cs.into_iter().map(|c| *c.borrow()).zip(F::TWO.powers())) } - pub fn le_bytes>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().zip(F::from_canonical_u16(256).powers())) + pub fn le_bytes>>(cs: I) -> Self { + Self::linear_combination( + cs.into_iter() + .map(|c| *c.borrow()) + .zip(F::from_canonical_u16(256).powers()), + ) } - pub fn sum>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().zip(repeat(F::ONE))) + pub fn sum>>(cs: I) -> Self { + Self::linear_combination(cs.into_iter().map(|c| *c.borrow()).zip(repeat(F::ONE))) } pub fn eval(&self, v: &[P]) -> P @@ -216,12 +227,12 @@ impl CtlData { pub fn cross_table_lookup_data, const D: usize>( config: &StarkConfig, - trace_poly_values: &[Vec>], + trace_poly_values: &[Vec>; NUM_TABLES], cross_table_lookups: &[CrossTableLookup], challenger: &mut Challenger, -) -> Vec> { +) -> [CtlData; NUM_TABLES] { let challenges = get_grand_product_challenge_set(challenger, config.num_challenges); - let mut ctl_data_per_table = vec![CtlData::default(); trace_poly_values.len()]; + let mut ctl_data_per_table = [0; NUM_TABLES].map(|_| CtlData::default()); for CrossTableLookup { looking_tables, looked_table, @@ -337,12 +348,11 @@ impl<'a, F: RichField + Extendable, const D: usize> CtlCheckVars<'a, F, F::Extension, F::Extension, D> { pub(crate) fn from_proofs>( - proofs: &[StarkProof], + proofs: &[StarkProof; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, - num_permutation_zs: &[usize], - ) -> Vec> { - debug_assert_eq!(proofs.len(), num_permutation_zs.len()); + 
num_permutation_zs: &[usize; NUM_TABLES], + ) -> [Vec; NUM_TABLES] { let mut ctl_zs = proofs .iter() .zip(num_permutation_zs) @@ -354,7 +364,7 @@ impl<'a, F: RichField + Extendable, const D: usize> }) .collect::>(); - let mut ctl_vars_per_table = vec![vec![]; proofs.len()]; + let mut ctl_vars_per_table = [0; NUM_TABLES].map(|_| vec![]); for CrossTableLookup { looking_tables, looked_table, @@ -441,12 +451,11 @@ pub struct CtlCheckVarsTarget<'a, F: Field, const D: usize> { impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { pub(crate) fn from_proofs( - proofs: &[StarkProofTarget], + proofs: &[StarkProofTarget; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, - num_permutation_zs: &[usize], - ) -> Vec> { - debug_assert_eq!(proofs.len(), num_permutation_zs.len()); + num_permutation_zs: &[usize; NUM_TABLES], + ) -> [Vec; NUM_TABLES] { let mut ctl_zs = proofs .iter() .zip(num_permutation_zs) @@ -458,7 +467,7 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { }) .collect::>(); - let mut ctl_vars_per_table = vec![vec![]; proofs.len()]; + let mut ctl_vars_per_table = [0; NUM_TABLES].map(|_| vec![]); for CrossTableLookup { looking_tables, looked_table, @@ -612,7 +621,7 @@ pub(crate) fn verify_cross_table_lookups< const D: usize, >( cross_table_lookups: Vec>, - proofs: &[StarkProof], + proofs: &[StarkProof; NUM_TABLES], challenges: GrandProductChallengeSet, config: &StarkConfig, ) -> Result<()> { @@ -670,7 +679,7 @@ pub(crate) fn verify_cross_table_lookups_circuit< >( builder: &mut CircuitBuilder, cross_table_lookups: Vec>, - proofs: &[StarkProofTarget], + proofs: &[StarkProofTarget; NUM_TABLES], challenges: GrandProductChallengeSet, inner_config: &StarkConfig, ) { diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 67b65c31..5b0b3c8f 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -4,7 +4,8 @@ use plonky2::field::polynomial::PolynomialValues; 
use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; -use crate::all_stark::AllStark; +use crate::all_stark::{AllStark, NUM_TABLES}; +use crate::config::StarkConfig; use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel; use crate::cpu::columns::NUM_CPU_COLUMNS; use crate::cpu::kernel::global_metadata::GlobalMetadata; @@ -45,7 +46,8 @@ pub struct GenerationInputs { pub(crate) fn generate_traces, const D: usize>( all_stark: &AllStark, inputs: GenerationInputs, -) -> (Vec>>, PublicValues) { + config: &StarkConfig, +) -> ([Vec>; NUM_TABLES], PublicValues) { let mut state = GenerationState::::default(); generate_bootstrap_kernel::(&mut state); @@ -83,6 +85,7 @@ pub(crate) fn generate_traces, const D: usize>( current_cpu_row, memory, keccak_inputs, + keccak_memory_inputs, logic_ops, .. } = state; @@ -90,9 +93,18 @@ pub(crate) fn generate_traces, const D: usize>( let cpu_trace = trace_rows_to_poly_values(cpu_rows); let keccak_trace = all_stark.keccak_stark.generate_trace(keccak_inputs); + let keccak_memory_trace = all_stark + .keccak_memory_stark + .generate_trace(keccak_memory_inputs, 1 << config.fri_config.cap_height); let logic_trace = all_stark.logic_stark.generate_trace(logic_ops); let memory_trace = all_stark.memory_stark.generate_trace(memory.log); - let traces = vec![cpu_trace, keccak_trace, logic_trace, memory_trace]; + let traces = [ + cpu_trace, + keccak_trace, + keccak_memory_trace, + logic_trace, + memory_trace, + ]; let public_values = PublicValues { trie_roots_before, diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 866f9fd7..4cbe61c8 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -77,13 +77,13 @@ impl GenerationState { let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index; let value = self.get_mem(context, segment, virt, timestamp); - self.current_cpu_row.mem_channel_used[channel_index] = F::ONE; - self.current_cpu_row.mem_is_read[channel_index] = F::ONE; 
- self.current_cpu_row.mem_addr_context[channel_index] = F::from_canonical_usize(context); - self.current_cpu_row.mem_addr_segment[channel_index] = - F::from_canonical_usize(segment as usize); - self.current_cpu_row.mem_addr_virtual[channel_index] = F::from_canonical_usize(virt); - self.current_cpu_row.mem_value[channel_index] = u256_limbs(value); + let channel = &mut self.current_cpu_row.mem_channels[channel_index]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(context); + channel.addr_segment = F::from_canonical_usize(segment as usize); + channel.addr_virtual = F::from_canonical_usize(virt); + channel.value = u256_limbs(value); value } @@ -133,13 +133,13 @@ impl GenerationState { let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index; self.set_mem(context, segment, virt, value, timestamp); - self.current_cpu_row.mem_channel_used[channel_index] = F::ONE; - self.current_cpu_row.mem_is_read[channel_index] = F::ZERO; // For clarity; should already be 0. - self.current_cpu_row.mem_addr_context[channel_index] = F::from_canonical_usize(context); - self.current_cpu_row.mem_addr_segment[channel_index] = - F::from_canonical_usize(segment as usize); - self.current_cpu_row.mem_addr_virtual[channel_index] = F::from_canonical_usize(virt); - self.current_cpu_row.mem_value[channel_index] = u256_limbs(value); + let channel = &mut self.current_cpu_row.mem_channels[channel_index]; + channel.used = F::ONE; + channel.is_read = F::ZERO; // For clarity; should already be 0. + channel.addr_context = F::from_canonical_usize(context); + channel.addr_segment = F::from_canonical_usize(segment as usize); + channel.addr_virtual = F::from_canonical_usize(virt); + channel.value = u256_limbs(value); } /// Write some memory, and log the operation. 
diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 52c2b796..6545a1af 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -1,4 +1,3 @@ -use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::fri::proof::{FriProof, FriProofTarget}; use plonky2::hash::hash_types::RichField; @@ -32,16 +31,18 @@ impl, C: GenericConfig, const D: usize> A let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); + let num_permutation_zs = all_stark.nums_permutation_zs(config); + let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); + AllProofChallenges { - stark_challenges: izip!( - &self.stark_proofs, - all_stark.nums_permutation_zs(config), - all_stark.permutation_batch_sizes() - ) - .map(|(proof, num_perm, batch_size)| { - proof.get_challenges(&mut challenger, num_perm > 0, batch_size, config) - }) - .collect(), + stark_challenges: std::array::from_fn(|i| { + self.stark_proofs[i].get_challenges( + &mut challenger, + num_permutation_zs[i] > 0, + num_permutation_batch_sizes[i], + config, + ) + }), ctl_challenges, } } @@ -66,22 +67,19 @@ impl AllProofTarget { let ctl_challenges = get_grand_product_challenge_set_target(builder, &mut challenger, config.num_challenges); + let num_permutation_zs = all_stark.nums_permutation_zs(config); + let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); + AllProofChallengesTarget { - stark_challenges: izip!( - &self.stark_proofs, - all_stark.nums_permutation_zs(config), - all_stark.permutation_batch_sizes() - ) - .map(|(proof, num_perm, batch_size)| { - proof.get_challenges::( + stark_challenges: std::array::from_fn(|i| { + self.stark_proofs[i].get_challenges::( builder, &mut challenger, - num_perm > 0, - batch_size, + num_permutation_zs[i] > 0, + num_permutation_batch_sizes[i], config, ) - }) - .collect(), + }), ctl_challenges, } } diff --git a/evm/src/keccak_sponge/columns.rs b/evm/src/keccak_sponge/columns.rs new file 
mode 100644 index 00000000..08194e87 --- /dev/null +++ b/evm/src/keccak_sponge/columns.rs @@ -0,0 +1,114 @@ +use std::borrow::{Borrow, BorrowMut}; +use std::mem::{size_of, transmute}; + +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; + +pub(crate) const KECCAK_WIDTH_BYTES: usize = 200; +pub(crate) const KECCAK_WIDTH_U32S: usize = KECCAK_WIDTH_BYTES / 4; +pub(crate) const KECCAK_RATE_BYTES: usize = 136; +pub(crate) const KECCAK_RATE_U32S: usize = KECCAK_RATE_BYTES / 4; +pub(crate) const KECCAK_CAPACITY_BYTES: usize = 64; +pub(crate) const KECCAK_CAPACITY_U32S: usize = KECCAK_CAPACITY_BYTES / 4; + +#[repr(C)] +#[derive(Eq, PartialEq, Debug)] +pub(crate) struct KeccakSpongeColumnsView { + /// 1 if this row represents a full input block, i.e. one in which each byte is an input byte, + /// not a padding byte; 0 otherwise. + pub is_full_input_block: T, + + /// 1 if this row represents the final block of a sponge, in which case some or all of the bytes + /// in the block will be padding bytes; 0 otherwise. + pub is_final_block: T, + + // The address at which we will read the input block. + pub context: T, + pub segment: T, + pub virt: T, + + /// The timestamp at which inputs should be read from memory. + pub timestamp: T, + + /// The length of the original input, in bytes. + pub len: T, + + /// The number of input bytes that have already been absorbed prior to this block. + pub already_absorbed_bytes: T, + + /// If this row represents a final block row, the `i`th entry should be 1 if the final chunk of + /// input has length `i` (in other words if `len - already_absorbed == i`), otherwise 0. + /// + /// If this row represents a full input block, this should contain all 0s. + pub is_final_input_len: [T; KECCAK_RATE_BYTES], + + /// The initial rate part of the sponge, at the start of this step. + pub original_rate_u32s: [T; KECCAK_RATE_U32S], + + /// The capacity part of the sponge, encoded as 32-bit chunks, at the start of this step. 
+ pub original_capacity_u32s: [T; KECCAK_CAPACITY_U32S], + + /// The block being absorbed, which may contain input bytes and/or padding bytes. + pub block_bytes: [T; KECCAK_RATE_BYTES], + + /// The rate part of the sponge, encoded as 32-bit chunks, after the current block is xor'd in, + /// but before the permutation is applied. + pub xored_rate_u32s: [T; KECCAK_RATE_U32S], + + /// The entire state (rate + capacity) of the sponge, encoded as 32-bit chunks, after the + /// permutation is applied. + pub updated_state_u32s: [T; KECCAK_WIDTH_U32S], +} + +// `u8` is guaranteed to have a `size_of` of 1. +pub const NUM_KECCAK_SPONGE_COLUMNS: usize = size_of::>(); + +impl From<[T; NUM_KECCAK_SPONGE_COLUMNS]> for KeccakSpongeColumnsView { + fn from(value: [T; NUM_KECCAK_SPONGE_COLUMNS]) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl From> for [T; NUM_KECCAK_SPONGE_COLUMNS] { + fn from(value: KeccakSpongeColumnsView) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl Borrow> for [T; NUM_KECCAK_SPONGE_COLUMNS] { + fn borrow(&self) -> &KeccakSpongeColumnsView { + unsafe { transmute(self) } + } +} + +impl BorrowMut> for [T; NUM_KECCAK_SPONGE_COLUMNS] { + fn borrow_mut(&mut self) -> &mut KeccakSpongeColumnsView { + unsafe { transmute(self) } + } +} + +impl Borrow<[T; NUM_KECCAK_SPONGE_COLUMNS]> for KeccakSpongeColumnsView { + fn borrow(&self) -> &[T; NUM_KECCAK_SPONGE_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl BorrowMut<[T; NUM_KECCAK_SPONGE_COLUMNS]> for KeccakSpongeColumnsView { + fn borrow_mut(&mut self) -> &mut [T; NUM_KECCAK_SPONGE_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl Default for KeccakSpongeColumnsView { + fn default() -> Self { + [T::default(); NUM_KECCAK_SPONGE_COLUMNS].into() + } +} + +const fn make_col_map() -> KeccakSpongeColumnsView { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_KECCAK_SPONGE_COLUMNS], KeccakSpongeColumnsView>(indices_arr) + 
} +} + +pub(crate) const KECCAK_SPONGE_COL_MAP: KeccakSpongeColumnsView = make_col_map(); diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs new file mode 100644 index 00000000..afde02c2 --- /dev/null +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -0,0 +1,468 @@ +use std::borrow::Borrow; +use std::iter::{once, repeat}; +use std::marker::PhantomData; +use std::mem::size_of; + +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use plonky2_util::ceil_div_usize; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::kernel::keccak_util::keccakf_u32s; +use crate::cross_table_lookup::Column; +use crate::keccak_sponge::columns::*; +use crate::memory::segments::Segment; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; +use crate::vars::StarkEvaluationTargets; +use crate::vars::StarkEvaluationVars; + +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looked_data() -> Vec> { + let cols = KECCAK_SPONGE_COL_MAP; + let outputs = Column::singles(&cols.updated_state_u32s[..8]); + Column::singles([ + cols.context, + cols.segment, + cols.virt, + cols.timestamp, + cols.len, + ]) + .chain(outputs) + .collect() +} + +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looking_keccak() -> Vec> { + let cols = KECCAK_SPONGE_COL_MAP; + Column::singles( + [ + cols.original_rate_u32s.as_slice(), + &cols.original_capacity_u32s, + &cols.updated_state_u32s, + ] + .concat(), + ) + .collect() +} + +#[allow(unused)] // TODO: Should be used soon. 
+pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { + let cols = KECCAK_SPONGE_COL_MAP; + + let mut res = vec![Column::constant(F::ONE)]; // is_read + + res.extend(Column::singles([cols.context, cols.segment])); + + // The address of the byte being read is `virt + already_absorbed_bytes + i`. + res.push(Column::linear_combination_with_constant( + [(cols.virt, F::ONE), (cols.already_absorbed_bytes, F::ONE)], + F::from_canonical_usize(i), + )); + + // The i'th input byte being read. + res.push(Column::single(cols.block_bytes[i])); + + // Since we're reading a single byte, the higher limbs must be zero. + res.extend((1..8).map(|_| Column::zero())); + + res.push(Column::single(cols.timestamp)); + + assert_eq!( + res.len(), + crate::memory::memory_stark::ctl_data::().len() + ); + res +} + +/// CTL for performing the `i`th logic CTL. Since we need to do 136 byte XORs, and the logic CTL can +/// XOR 32 bytes per CTL, there are 5 such CTLs. +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looking_logic(i: usize) -> Vec> { + const U32S_PER_CTL: usize = 8; + const U8S_PER_CTL: usize = 32; + + debug_assert!(i < ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL)); + let cols = KECCAK_SPONGE_COL_MAP; + + let mut res = vec![ + Column::zero(), // is_and + Column::zero(), // is_or + Column::one(), // is_xor + ]; + + // Input 0 contains some of the sponge's original rate chunks. If this is the last CTL, we won't + // need to use all of the CTL's inputs, so we will pass some zeros. + res.extend( + Column::singles(&cols.original_rate_u32s[i * U32S_PER_CTL..]) + .chain(repeat(Column::zero())) + .take(U32S_PER_CTL), + ); + + // Input 1 contains some of block's chunks. Again, for the last CTL it will include some zeros. + res.extend( + cols.block_bytes[i * U8S_PER_CTL..] + .chunks(size_of::()) + .map(|chunk| Column::le_bytes(chunk)) + .chain(repeat(Column::zero())) + .take(U8S_PER_CTL), + ); + + // The output contains the XOR'd rate part. 
+ res.extend( + Column::singles(&cols.xored_rate_u32s[i * U32S_PER_CTL..]) + .chain(repeat(Column::zero())) + .take(U32S_PER_CTL), + ); + + res +} + +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looked_filter() -> Column { + // The CPU table is only interested in our final-block rows, since those contain the final + // sponge output. + Column::single(KECCAK_SPONGE_COL_MAP.is_final_block) +} + +#[allow(unused)] // TODO: Should be used soon. +/// CTL filter for reading the `i`th byte of input from memory. +pub(crate) fn ctl_looking_memory_filter(i: usize) -> Column { + // We perform the `i`th read if either + // - this is a full input block, or + // - this is a final block of length `i` or greater + let cols = KECCAK_SPONGE_COL_MAP; + Column::sum(once(&cols.is_full_input_block).chain(&cols.is_final_input_len[i..])) +} + +/// Information about a Keccak sponge operation needed for witness generation. +#[derive(Debug)] +pub(crate) struct KeccakSpongeOp { + // The address at which inputs are read. + pub(crate) context: usize, + pub(crate) segment: Segment, + pub(crate) virt: usize, + + /// The timestamp at which inputs are read. + pub(crate) timestamp: usize, + + /// The length of the input, in bytes. + pub(crate) len: usize, + + /// The input that was read. + pub(crate) input: Vec, +} + +#[derive(Copy, Clone, Default)] +pub(crate) struct KeccakSpongeStark { + f: PhantomData, +} + +impl, const D: usize> KeccakSpongeStark { + #[allow(unused)] // TODO: Should be used soon. + pub(crate) fn generate_trace( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec> { + let mut timing = TimingTree::new("generate trace", log::Level::Debug); + + // Generate the witness row-wise. 
+ let trace_rows = timed!( + &mut timing, + "generate trace rows", + self.generate_trace_rows(operations, min_rows) + ); + + let trace_polys = timed!( + &mut timing, + "convert to PolynomialValues", + trace_rows_to_poly_values(trace_rows) + ); + + timing.print(); + trace_polys + } + + fn generate_trace_rows( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec<[F; NUM_KECCAK_SPONGE_COLUMNS]> { + let num_rows = operations.len().max(min_rows).next_power_of_two(); + operations + .into_iter() + .flat_map(|op| self.generate_rows_for_op(op)) + .chain(repeat(self.generate_padding_row())) + .take(num_rows) + .collect() + } + + fn generate_rows_for_op(&self, op: KeccakSpongeOp) -> Vec<[F; NUM_KECCAK_SPONGE_COLUMNS]> { + let mut rows = vec![]; + + let mut sponge_state = [0u32; KECCAK_WIDTH_U32S]; + + let mut input_blocks = op.input.chunks_exact(KECCAK_RATE_BYTES); + let mut already_absorbed_bytes = 0; + for block in input_blocks.by_ref() { + let row = self.generate_full_input_row( + &op, + already_absorbed_bytes, + sponge_state, + block.try_into().unwrap(), + ); + + sponge_state = row.updated_state_u32s.map(|f| f.to_canonical_u64() as u32); + + rows.push(row.into()); + already_absorbed_bytes += KECCAK_RATE_BYTES; + } + + rows.push( + self.generate_final_row( + &op, + already_absorbed_bytes, + sponge_state, + input_blocks.remainder(), + ) + .into(), + ); + + rows + } + + fn generate_full_input_row( + &self, + op: &KeccakSpongeOp, + already_absorbed_bytes: usize, + sponge_state: [u32; KECCAK_WIDTH_U32S], + block: [u8; KECCAK_RATE_BYTES], + ) -> KeccakSpongeColumnsView { + let mut row = KeccakSpongeColumnsView { + is_full_input_block: F::ONE, + ..Default::default() + }; + + row.block_bytes = block.map(F::from_canonical_u8); + + Self::generate_common_fields(&mut row, op, already_absorbed_bytes, sponge_state); + row + } + + fn generate_final_row( + &self, + op: &KeccakSpongeOp, + already_absorbed_bytes: usize, + sponge_state: [u32; KECCAK_WIDTH_U32S], + final_inputs: &[u8], 
+ ) -> KeccakSpongeColumnsView { + assert_eq!(already_absorbed_bytes + final_inputs.len(), op.len); + + let mut row = KeccakSpongeColumnsView { + is_final_block: F::ONE, + ..Default::default() + }; + + for (block_byte, input_byte) in row.block_bytes.iter_mut().zip(final_inputs) { + *block_byte = F::from_canonical_u8(*input_byte); + } + + // pad10*1 rule + if final_inputs.len() == KECCAK_RATE_BYTES - 1 { + // Both 1s are placed in the same byte. + row.block_bytes[final_inputs.len()] = F::from_canonical_u8(0b10000001); + } else { + row.block_bytes[final_inputs.len()] = F::ONE; + row.block_bytes[KECCAK_RATE_BYTES - 1] = F::from_canonical_u8(0b10000000); + } + + row.is_final_input_len[final_inputs.len()] = F::ONE; + + Self::generate_common_fields(&mut row, op, already_absorbed_bytes, sponge_state); + row + } + + /// Generate fields that are common to both full-input-block rows and final-block rows. + /// Also updates the sponge state with a single absorption. + fn generate_common_fields( + row: &mut KeccakSpongeColumnsView, + op: &KeccakSpongeOp, + already_absorbed_bytes: usize, + mut sponge_state: [u32; KECCAK_WIDTH_U32S], + ) { + row.context = F::from_canonical_usize(op.context); + row.segment = F::from_canonical_usize(op.segment as usize); + row.virt = F::from_canonical_usize(op.virt); + row.timestamp = F::from_canonical_usize(op.timestamp); + row.len = F::from_canonical_usize(op.len); + row.already_absorbed_bytes = F::from_canonical_usize(already_absorbed_bytes); + + row.original_rate_u32s = sponge_state[..KECCAK_RATE_U32S] + .iter() + .map(|x| F::from_canonical_u32(*x)) + .collect_vec() + .try_into() + .unwrap(); + + row.original_capacity_u32s = sponge_state[KECCAK_RATE_U32S..] 
+ .iter() + .map(|x| F::from_canonical_u32(*x)) + .collect_vec() + .try_into() + .unwrap(); + + let block_u32s = (0..KECCAK_RATE_U32S).map(|i| { + u32::from_le_bytes( + row.block_bytes[i * 4..(i + 1) * 4] + .iter() + .map(|x| x.to_canonical_u64() as u8) + .collect_vec() + .try_into() + .unwrap(), + ) + }); + + // xor in the block + for (state_i, block_i) in sponge_state.iter_mut().zip(block_u32s) { + *state_i ^= block_i; + } + let xored_rate_u32s: [u32; KECCAK_RATE_U32S] = sponge_state[..KECCAK_RATE_U32S] + .to_vec() + .try_into() + .unwrap(); + row.xored_rate_u32s = xored_rate_u32s.map(F::from_canonical_u32); + + keccakf_u32s(&mut sponge_state); + row.updated_state_u32s = sponge_state.map(F::from_canonical_u32); + } + + fn generate_padding_row(&self) -> [F; NUM_KECCAK_SPONGE_COLUMNS] { + // The default instance has is_full_input_block = is_final_block = 0, + // indicating that it's a dummy/padding row. + KeccakSpongeColumnsView::default().into() + } +} + +impl, const D: usize> Stark for KeccakSpongeStark { + const COLUMNS: usize = NUM_KECCAK_SPONGE_COLUMNS; + + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + _yield_constr: &mut ConstraintConsumer

, + ) where + FE: FieldExtension, + P: PackedField, + { + let _local_values: &KeccakSpongeColumnsView

= vars.local_values.borrow(); + + // TODO: Each flag (full-input block, final block or implied dummy flag) must be boolean. + // TODO: before_rate_bits, block_bits and is_final_input_len must contain booleans. + + // TODO: Sum of is_final_input_len should equal is_final_block (which will be 0 or 1). + + // TODO: If this is the first row, the original sponge state should be 0 and already_absorbed_bytes = 0. + // TODO: If this is a final block, the next row's original sponge state should be 0 and already_absorbed_bytes = 0. + + // TODO: If this is a full-input block, the next row's address, time and len must match. + // TODO: If this is a full-input block, the next row's "before" should match our "after" state. + // TODO: If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. + + // TODO: A dummy row is always followed by another dummy row, so the prover can't put dummy rows "in between" to avoid the above checks. + + // TODO: is_final_input_len implies `len - already_absorbed == i`. 
+ } + + fn eval_ext_circuit( + &self, + _builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: StarkEvaluationTargets, + _yield_constr: &mut RecursiveConstraintConsumer, + ) { + let _local_values: &KeccakSpongeColumnsView> = + vars.local_values.borrow(); + + // TODO + } + + fn constraint_degree(&self) -> usize { + 3 + } +} + +#[cfg(test)] +mod tests { + use std::borrow::Borrow; + + use anyhow::Result; + use itertools::Itertools; + use keccak_hash::keccak; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::PrimeField64; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::keccak_sponge::columns::KeccakSpongeColumnsView; + use crate::keccak_sponge::keccak_sponge_stark::{KeccakSpongeOp, KeccakSpongeStark}; + use crate::memory::segments::Segment; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + + #[test] + fn test_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = KeccakSpongeStark; + + let stark = S::default(); + test_stark_low_degree(stark) + } + + #[test] + fn test_stark_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = KeccakSpongeStark; + + let stark = S::default(); + test_stark_circuit_constraints::(stark) + } + + #[test] + fn test_generation() -> Result<()> { + const D: usize = 2; + type F = GoldilocksField; + type S = KeccakSpongeStark; + + let input = vec![1, 2, 3]; + let expected_output = keccak(&input); + + let op = KeccakSpongeOp { + context: 0, + segment: Segment::Code, + virt: 0, + timestamp: 0, + len: input.len(), + input, + }; + let stark = S::default(); + let rows = stark.generate_rows_for_op(op); + assert_eq!(rows.len(), 1); + let last_row: &KeccakSpongeColumnsView = rows.last().unwrap().borrow(); + let output = last_row.updated_state_u32s[..8] + .iter() + .flat_map(|x| (x.to_canonical_u64() as 
u32).to_le_bytes()) + .collect_vec(); + + assert_eq!(output, expected_output.0); + Ok(()) + } +} diff --git a/evm/src/keccak_sponge/mod.rs b/evm/src/keccak_sponge/mod.rs new file mode 100644 index 00000000..92b7f0c1 --- /dev/null +++ b/evm/src/keccak_sponge/mod.rs @@ -0,0 +1,6 @@ +//! The Keccak sponge STARK is used to hash a variable amount of data which is read from memory. +//! It connects to the memory STARK to read input data, and to the Keccak-f STARK to evaluate the +//! permutation at each absorption step. + +pub mod columns; +pub mod keccak_sponge_stark; diff --git a/evm/src/lib.rs b/evm/src/lib.rs index 0a31a7ba..6f332b59 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::needless_range_loop)] #![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] +#![feature(let_chains)] #![feature(generic_const_exprs)] pub mod all_stark; @@ -14,6 +15,7 @@ pub mod generation; mod get_challenges; pub mod keccak; pub mod keccak_memory; +pub mod keccak_sponge; pub mod logic; pub mod lookup; pub mod memory; diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 54f3618b..b23e7a37 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -13,38 +13,32 @@ use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::config::GenericConfig; +use crate::all_stark::NUM_TABLES; use crate::config::StarkConfig; use crate::permutation::GrandProductChallengeSet; #[derive(Debug, Clone)] pub struct AllProof, C: GenericConfig, const D: usize> { - pub stark_proofs: Vec>, + pub stark_proofs: [StarkProof; NUM_TABLES], pub public_values: PublicValues, } impl, C: GenericConfig, const D: usize> AllProof { - pub fn degree_bits(&self, config: &StarkConfig) -> Vec { - self.stark_proofs - .iter() - .map(|proof| proof.recover_degree_bits(config)) - .collect() + pub fn degree_bits(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { + std::array::from_fn(|i| self.stark_proofs[i].recover_degree_bits(config)) } - pub 
fn nums_ctl_zs(&self) -> Vec { - self.stark_proofs - .iter() - .map(|proof| proof.num_ctl_zs()) - .collect() + pub fn nums_ctl_zs(&self) -> [usize; NUM_TABLES] { + std::array::from_fn(|i| self.stark_proofs[i].num_ctl_zs()) } -} pub(crate) struct AllProofChallenges, const D: usize> { - pub stark_challenges: Vec>, + pub stark_challenges: [StarkProofChallenges; NUM_TABLES], pub ctl_challenges: GrandProductChallengeSet, } pub struct AllProofTarget { - pub stark_proofs: Vec>, + pub stark_proofs: [StarkProofTarget; NUM_TABLES], pub public_values: PublicValuesTarget, } @@ -99,7 +93,7 @@ pub struct BlockMetadataTarget { } pub(crate) struct AllProofChallengesTarget { - pub stark_challenges: Vec>, + pub stark_challenges: [StarkProofChallengesTarget; NUM_TABLES], pub ctl_challenges: GrandProductChallengeSet, } diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 75152d61..31e76a1c 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -17,7 +17,7 @@ use plonky2::util::timing::TimingTree; use plonky2::util::transpose; use plonky2_util::{log2_ceil, log2_strict}; -use crate::all_stark::{AllStark, Table}; +use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; @@ -53,7 +53,7 @@ where [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, { - let (traces, public_values) = generate_traces(all_stark, inputs); + let (traces, public_values) = generate_traces(all_stark, inputs, config); prove_with_traces(all_stark, config, traces, public_values, timing) } @@ -61,7 +61,7 @@ where pub(crate) fn prove_with_traces( all_stark: &AllStark, config: &StarkConfig, - trace_poly_values: Vec>>, + trace_poly_values: [Vec>; NUM_TABLES], public_values: PublicValues, timing: &mut TimingTree, ) -> Result> @@ -75,9 +75,6 @@ where [(); LogicStark::::COLUMNS]:, [(); MemoryStark::::COLUMNS]:, { - let num_starks = Table::num_tables(); - debug_assert_eq!(num_starks, 
trace_poly_values.len()); - let rate_bits = config.fri_config.rate_bits; let cap_height = config.fri_config.cap_height; @@ -163,14 +160,13 @@ where timing, )?; - let stark_proofs = vec![ + let stark_proofs = [ cpu_proof, keccak_proof, keccak_memory_proof, logic_proof, memory_proof, ]; - debug_assert_eq!(stark_proofs.len(), num_starks); Ok(AllProof { stark_proofs, diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 9c23e01e..ecc16a70 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -423,7 +423,7 @@ pub fn add_virtual_all_proof, const D: usize>( degree_bits: &[usize], nums_ctl_zs: &[usize], ) -> AllProofTarget { - let stark_proofs = vec![ + let stark_proofs = [ add_virtual_stark_proof( builder, all_stark.cpu_stark, @@ -460,7 +460,6 @@ pub fn add_virtual_all_proof, const D: usize>( nums_ctl_zs[Table::Memory as usize], ), ]; - assert_eq!(stark_proofs.len(), Table::num_tables()); let public_values = add_virtual_public_values(builder); AllProofTarget { diff --git a/evm/src/stark_testing.rs b/evm/src/stark_testing.rs index 5cd83e41..81b0f68f 100644 --- a/evm/src/stark_testing.rs +++ b/evm/src/stark_testing.rs @@ -60,17 +60,20 @@ where }) .collect::>(); - let constraint_eval_degree = PolynomialValues::new(constraint_evals).degree(); - let maximum_degree = WITNESS_SIZE * stark.constraint_degree() - 1; + let constraint_poly_values = PolynomialValues::new(constraint_evals); + if !constraint_poly_values.is_zero() { + let constraint_eval_degree = constraint_poly_values.degree(); + let maximum_degree = WITNESS_SIZE * stark.constraint_degree() - 1; - ensure!( - constraint_eval_degree <= maximum_degree, - "Expected degrees at most {} * {} - 1 = {}, actual {:?}", - WITNESS_SIZE, - stark.constraint_degree(), - maximum_degree, - constraint_eval_degree - ); + ensure!( + constraint_eval_degree <= maximum_degree, + "Expected degrees at most {} * {} - 1 = {}, actual {:?}", + WITNESS_SIZE, + stark.constraint_degree(), + 
maximum_degree, + constraint_eval_degree + ); + } Ok(()) } diff --git a/evm/src/util.rs b/evm/src/util.rs index ae5281db..12aead46 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -1,3 +1,5 @@ +use std::mem::{size_of, transmute_copy, ManuallyDrop}; + use ethereum_types::{H160, U256}; use itertools::Itertools; use plonky2::field::extension::Extendable; @@ -67,3 +69,21 @@ pub(crate) fn h160_limbs(h160: H160) -> [F; 5] { .try_into() .unwrap() } + +pub(crate) const fn indices_arr() -> [usize; N] { + let mut indices_arr = [0; N]; + let mut i = 0; + while i < N { + indices_arr[i] = i; + i += 1; + } + indices_arr +} + +pub(crate) unsafe fn transmute_no_compile_time_size_checks(value: T) -> U { + debug_assert_eq!(size_of::(), size_of::()); + // Need ManuallyDrop so that `value` is not dropped by this function. + let value = ManuallyDrop::new(value); + // Copy the bit pattern. The original value is no longer safe to use. + transmute_copy(&value) +} diff --git a/field/src/polynomial/mod.rs b/field/src/polynomial/mod.rs index 20f1c318..09ed69c7 100644 --- a/field/src/polynomial/mod.rs +++ b/field/src/polynomial/mod.rs @@ -37,6 +37,10 @@ impl PolynomialValues { Self::constant(F::ZERO, len) } + pub fn is_zero(&self) -> bool { + self.values.iter().all(|x| x.is_zero()) + } + /// Returns the polynomial whole value is one at the given index, and zero elsewhere. 
pub fn selector(len: usize, index: usize) -> Self { let mut result = Self::zero(len); diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index 97dedf28..23caeac1 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -505,7 +505,7 @@ impl, const D: usize> SimpleGenerator { fn dependencies(&self) -> Vec { let mut deps = self.numerator.to_target_array().to_vec(); - deps.extend(&self.denominator.to_target_array()); + deps.extend(self.denominator.to_target_array()); deps } diff --git a/plonky2/src/gates/interpolation.rs b/plonky2/src/gates/interpolation.rs index 1983e5aa..a619d1f2 100644 --- a/plonky2/src/gates/interpolation.rs +++ b/plonky2/src/gates/interpolation.rs @@ -100,13 +100,13 @@ impl, const D: usize> Gate for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = interpolant.eval_base(point); - constraints.extend(&(value - computed_value).to_basefield_array()); + constraints.extend((value - computed_value).to_basefield_array()); } let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval(evaluation_point); - constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); + constraints.extend((evaluation_value - computed_evaluation_value).to_basefield_array()); constraints } @@ -151,7 +151,7 @@ impl, const D: usize> Gate let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = interpolant.eval_scalar(builder, point); constraints.extend( - &builder + builder .sub_ext_algebra(value, computed_value) .to_ext_target_array(), ); @@ -161,7 +161,7 @@ impl, const D: usize> Gate let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); let computed_evaluation_value = 
interpolant.eval(builder, evaluation_point); constraints.extend( - &builder + builder .sub_ext_algebra(evaluation_value, computed_evaluation_value) .to_ext_target_array(), ); diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index 217f4f0a..dabadfa4 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -113,7 +113,7 @@ impl, const D: usize> Gate for LowDegreeInter { let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = altered_interpolant.eval_base(point); - constraints.extend(&(value - computed_value).to_basefield_array()); + constraints.extend((value - computed_value).to_basefield_array()); } let evaluation_point_powers = (1..self.num_points()) @@ -128,7 +128,7 @@ impl, const D: usize> Gate for LowDegreeInter } let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); - constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); + constraints.extend((evaluation_value - computed_evaluation_value).to_basefield_array()); constraints } @@ -225,7 +225,7 @@ impl, const D: usize> Gate for LowDegreeInter let point = builder.constant_extension(point); let computed_value = altered_interpolant.eval_scalar(builder, point); constraints.extend( - &builder + builder .sub_ext_algebra(value, computed_value) .to_ext_target_array(), ); @@ -253,7 +253,7 @@ impl, const D: usize> Gate for LowDegreeInter // let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); // let computed_evaluation_value = interpolant.eval(builder, evaluation_point); constraints.extend( - &builder + builder .sub_ext_algebra(evaluation_value, computed_evaluation_value) .to_ext_target_array(), ); diff --git a/u32/src/gadgets/arithmetic_u32.rs b/u32/src/gadgets/arithmetic_u32.rs index 7a7731b1..7475681c 
100644 --- a/u32/src/gadgets/arithmetic_u32.rs +++ b/u32/src/gadgets/arithmetic_u32.rs @@ -10,7 +10,7 @@ use plonky2_field::extension::Extendable; use crate::gates::add_many_u32::U32AddManyGate; use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::subtraction_u32::U32SubtractionGate; -use crate::witness::generated_values_set_u32_target; +use crate::witness::GeneratedValuesU32; #[derive(Clone, Copy, Debug)] pub struct U32Target(pub Target); @@ -249,8 +249,8 @@ impl, const D: usize> SimpleGenerator let low = x_u64 as u32; let high = (x_u64 >> 32) as u32; - generated_values_set_u32_target(out_buffer, self.low, low); - generated_values_set_u32_target(out_buffer, self.high, high); + out_buffer.set_u32_target(self.low, low); + out_buffer.set_u32_target(self.high, high); } } diff --git a/u32/src/witness.rs b/u32/src/witness.rs index 1b88d60d..ddc3432f 100644 --- a/u32/src/witness.rs +++ b/u32/src/witness.rs @@ -1,21 +1,33 @@ use plonky2::iop::generator::GeneratedValues; use plonky2::iop::witness::Witness; -use plonky2_field::types::Field; +use plonky2_field::types::{Field, PrimeField64}; use crate::gadgets::arithmetic_u32::U32Target; -pub fn generated_values_set_u32_target( - buffer: &mut GeneratedValues, - target: U32Target, - value: u32, -) { - buffer.set_target(target.0, F::from_canonical_u32(value)) +pub trait WitnessU32: Witness { + fn set_u32_target(&mut self, target: U32Target, value: u32); + fn get_u32_target(&self, target: U32Target) -> (u32, u32); } -pub fn witness_set_u32_target, F: Field>( - witness: &mut W, - target: U32Target, - value: u32, -) { - witness.set_target(target.0, F::from_canonical_u32(value)) +impl, F: PrimeField64> WitnessU32 for T { + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)); + } + + fn get_u32_target(&self, target: U32Target) -> (u32, u32) { + let x_u64 = self.get_target(target.0).to_canonical_u64(); + let low = x_u64 as u32; + let high = (x_u64 
>> 32) as u32; + (low, high) + } +} + +pub trait GeneratedValuesU32 { + fn set_u32_target(&mut self, target: U32Target, value: u32); +} + +impl GeneratedValuesU32 for GeneratedValues { + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } }