diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 62a6c2cc..5ec4041b 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -13,7 +13,7 @@ use crate::keccak::keccak_stark; use crate::keccak::keccak_stark::KeccakStark; use crate::keccak_sponge::columns::KECCAK_RATE_BYTES; use crate::keccak_sponge::keccak_sponge_stark; -use crate::keccak_sponge::keccak_sponge_stark::{num_logic_ctls, KeccakSpongeStark}; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic; use crate::logic::LogicStark; use crate::memory::memory_stark; @@ -89,11 +89,9 @@ impl Table { } pub(crate) fn all_cross_table_lookups() -> Vec> { - let mut ctls = vec![ctl_keccak(), ctl_logic(), ctl_memory(), ctl_keccak_sponge()]; + let mut ctls = vec![ctl_keccak_sponge(), ctl_keccak(), ctl_logic(), ctl_memory()]; // TODO: Some CTLs temporarily disabled while we get them working. disable_ctl(&mut ctls[0]); - disable_ctl(&mut ctls[1]); - disable_ctl(&mut ctls[2]); disable_ctl(&mut ctls[3]); ctls } @@ -140,12 +138,11 @@ fn ctl_logic() -> CrossTableLookup { Some(cpu_stark::ctl_filter_logic()), ); let mut all_lookers = vec![cpu_looking]; - for i in 0..num_logic_ctls() { + for i in 0..keccak_sponge_stark::num_logic_ctls() { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, keccak_sponge_stark::ctl_looking_logic(i), - // TODO: Double check, but I think it's the same filter for memory and logic? - Some(keccak_sponge_stark::ctl_looking_memory_filter(i)), + Some(keccak_sponge_stark::ctl_looking_logic_filter()), ); all_lookers.push(keccak_sponge_looking); } diff --git a/evm/src/cpu/bootstrap_kernel.rs b/evm/src/cpu/bootstrap_kernel.rs index ba45f738..192c4e7a 100644 --- a/evm/src/cpu/bootstrap_kernel.rs +++ b/evm/src/cpu/bootstrap_kernel.rs @@ -43,17 +43,18 @@ pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState final_cpu_row.is_bootstrap_kernel = F::ONE; final_cpu_row.is_keccak_sponge = F::ONE; // The Keccak sponge CTL uses memory value columns for its inputs and outputs. - final_cpu_row.mem_channels[0].value[0] = F::ZERO; - final_cpu_row.mem_channels[1].value[0] = F::from_canonical_usize(Segment::Code as usize); - final_cpu_row.mem_channels[2].value[0] = F::ZERO; - final_cpu_row.mem_channels[3].value[0] = F::from_canonical_usize(state.traces.clock()); + final_cpu_row.mem_channels[0].value[0] = F::ZERO; // context + final_cpu_row.mem_channels[1].value[0] = F::from_canonical_usize(Segment::Code as usize); // segment + final_cpu_row.mem_channels[2].value[0] = F::ZERO; // virt + final_cpu_row.mem_channels[3].value[0] = F::from_canonical_usize(KERNEL.code.len()); // len final_cpu_row.mem_channels[4].value = KERNEL.code_hash.map(F::from_canonical_u32); - state.traces.push_cpu(final_cpu_row); keccak_sponge_log( state, MemoryAddress::new(0, Segment::Code, 0), KERNEL.code.clone(), ); + state.traces.push_cpu(final_cpu_row); + log::info!("Bootstrapping took {} cycles", state.traces.clock()); } pub(crate) fn eval_bootstrap_kernel>( diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 1af8428d..b0e61dd2 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -10,6 +10,7 @@ use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; +use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::{ bootstrap_kernel, contextops, control_flow, decode, dup_swap, jumps, membus, memio, modfp254, pc, shift, simple_logic, stack, stack_bounds, syscalls, @@ -36,7 +37,7 @@ pub fn ctl_data_keccak_sponge() -> Vec> { let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); let mut cols = vec![context, segment, virt, len, timestamp]; - cols.extend(COL_MAP.mem_channels[3].value.map(Column::single)); + cols.extend(COL_MAP.mem_channels[4].value.map(Column::single)); cols } @@ -48,7 +49,9 @@ pub fn ctl_data_logic() -> Vec> { let mut res = Column::singles([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[0].value)); res.extend(Column::singles(COL_MAP.mem_channels[1].value)); - res.extend(Column::singles(COL_MAP.mem_channels[2].value)); + res.extend(Column::singles( + COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, + )); res } diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index c26f5229..0064c3aa 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -40,7 +40,10 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { let storage_tries_by_state_key = trie_inputs .storage_tries .iter() - .map(|(address, storage_trie)| (Nibbles::from(keccak(address)), storage_trie)) + .map(|(address, storage_trie)| { + let key = Nibbles::from_bytes_be(keccak(address).as_bytes()).unwrap(); + (key, storage_trie) + }) .collect(); mpt_prover_inputs_state_trie( diff --git a/evm/src/keccak/columns.rs b/evm/src/keccak/columns.rs index 8313c676..039db078 100644 --- a/evm/src/keccak/columns.rs +++ b/evm/src/keccak/columns.rs @@ -9,6 +9,10 @@ pub const fn reg_step(i: usize) -> usize { i } +/// A register which indicates if a row should be included in the CTL. Should be 1 only for certain +/// rows which are final steps, i.e. with `reg_step(23) = 1`. +pub const REG_FILTER: usize = NUM_ROUNDS; + /// Registers to hold permutation inputs. /// `reg_input_limb(2*i) -> input[i] as u32` /// `reg_input_limb(2*i+1) -> input[i] >> 32` @@ -20,7 +24,7 @@ pub fn reg_input_limb(i: usize) -> Column { let y = i_u64 / 5; let x = i_u64 % 5; - let reg_low_limb = reg_a(x, y); + let reg_low_limb = reg_preimage(x, y); let is_high_limb = i % 2; Column::single(reg_low_limb + is_high_limb) } @@ -48,7 +52,15 @@ const R: [[u8; 5]; 5] = [ [27, 20, 39, 8, 14], ]; -const START_A: usize = NUM_ROUNDS; +const START_PREIMAGE: usize = NUM_ROUNDS + 1; +/// Registers to hold the original input to a permutation, i.e. the input to the first round. +pub(crate) const fn reg_preimage(x: usize, y: usize) -> usize { + debug_assert!(x < 5); + debug_assert!(y < 5); + START_PREIMAGE + (x * 5 + y) * 2 +} + +const START_A: usize = START_PREIMAGE + 5 * 5 * 2; pub(crate) const fn reg_a(x: usize, y: usize) -> usize { debug_assert!(x < 5); debug_assert!(y < 5); diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 7be421fb..df842a41 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -14,7 +14,8 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cross_table_lookup::Column; use crate::keccak::columns::{ reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime, - reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_step, NUM_COLUMNS, + reg_b, reg_c, reg_c_prime, reg_input_limb, reg_output_limb, reg_preimage, reg_step, + NUM_COLUMNS, REG_FILTER, }; use crate::keccak::constants::{rc_value, rc_value_bit}; use crate::keccak::logic::{ @@ -38,8 +39,7 @@ pub fn ctl_data() -> Vec> { } pub fn ctl_filter() -> Column { - // TODO: Also need to filter out padding rows somehow. - Column::single(reg_step(NUM_ROUNDS - 1)) + Column::single(REG_FILTER) } #[derive(Copy, Clone, Default)] @@ -60,7 +60,10 @@ impl, const D: usize> KeccakStark { .next_power_of_two(); let mut rows = Vec::with_capacity(num_rows); for input in inputs.iter() { - rows.extend(self.generate_trace_rows_for_perm(*input)); + let mut rows_for_perm = self.generate_trace_rows_for_perm(*input); + // Since this is a real operation, not padding, we set the filter to 1 on the last row. + rows_for_perm[NUM_ROUNDS - 1][REG_FILTER] = F::ONE; + rows.extend(rows_for_perm); } let pad_rows = self.generate_trace_rows_for_perm([0; NUM_INPUTS]); @@ -71,9 +74,26 @@ impl, const D: usize> KeccakStark { rows } - fn generate_trace_rows_for_perm(&self, input: [u64; NUM_INPUTS]) -> Vec<[F; NUM_COLUMNS]> { - let mut rows = vec![[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS]; + fn generate_trace_rows_for_perm( + &self, + input: [u64; NUM_INPUTS], + ) -> [[F; NUM_COLUMNS]; NUM_ROUNDS] { + let mut rows = [[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS]; + // Populate the preimage for each row. + for round in 0..24 { + for x in 0..5 { + for y in 0..5 { + let input_xy = input[y * 5 + x]; + let reg_preimage_lo = reg_preimage(x, y); + let reg_preimage_hi = reg_preimage_lo + 1; + rows[round][reg_preimage_lo] = F::from_canonical_u64(input_xy & 0xFFFFFFFF); + rows[round][reg_preimage_hi] = F::from_canonical_u64(input_xy >> 32); + } + } + } + + // Populate the round input for the first round. for x in 0..5 { for y in 0..5 { let input_xy = input[y * 5 + x]; @@ -237,6 +257,24 @@ impl, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, yield_constr: &mut RecursiveConstraintConsumer, ) { + let one_ext = builder.one_extension(); let two = builder.two(); let two_ext = builder.two_extension(); let four_ext = builder.constant_extension(F::Extension::from_canonical_u8(4)); eval_round_flags_recursively(builder, vars, yield_constr); + // The filter must be 0 or 1. + let filter = vars.local_values[REG_FILTER]; + let constraint = builder.mul_sub_extension(filter, filter, filter); + yield_constr.constraint(builder, constraint); + + // If this is not the final step, the filter must be off. + let final_step = vars.local_values[reg_step(NUM_ROUNDS - 1)]; + let not_final_step = builder.sub_extension(one_ext, final_step); + let constraint = builder.mul_extension(not_final_step, filter); + yield_constr.constraint(builder, constraint); + + // If this is not the final step, the local and next preimages must match. + for x in 0..5 { + for y in 0..5 { + let preimage = reg_preimage(x, y); + let diff = + builder.sub_extension(vars.local_values[preimage], vars.next_values[preimage]); + let constraint = builder.mul_extension(not_final_step, diff); + yield_constr.constraint_transition(builder, constraint); + } + } + // C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). for x in 0..5 { for z in 0..64 { diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index ebefce06..3fe59f74 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -41,7 +41,7 @@ pub(crate) fn ctl_looking_keccak() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; Column::singles( [ - cols.original_rate_u32s.as_slice(), + cols.xored_rate_u32s.as_slice(), &cols.original_capacity_u32s, &cols.updated_state_u32s, ] @@ -140,13 +140,19 @@ pub(crate) fn ctl_looking_memory_filter(i: usize) -> Column { Column::sum(once(&cols.is_full_input_block).chain(&cols.is_final_input_len[i..])) } +/// CTL filter for looking at XORs in the logic table. +pub(crate) fn ctl_looking_logic_filter() -> Column { + let cols = KECCAK_SPONGE_COL_MAP; + Column::sum([cols.is_full_input_block, cols.is_final_block]) +} + pub(crate) fn ctl_looking_keccak_filter() -> Column { let cols = KECCAK_SPONGE_COL_MAP; Column::sum([cols.is_full_input_block, cols.is_final_block]) } /// Information about a Keccak sponge operation needed for witness generation. -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) struct KeccakSpongeOp { /// The base address at which inputs are read. pub(crate) base_address: MemoryAddress, @@ -192,13 +198,15 @@ impl, const D: usize> KeccakSpongeStark { operations: Vec, min_rows: usize, ) -> Vec<[F; NUM_KECCAK_SPONGE_COLUMNS]> { - let num_rows = operations.len().max(min_rows).next_power_of_two(); - operations - .into_iter() - .flat_map(|op| self.generate_rows_for_op(op)) - .chain(repeat(self.generate_padding_row())) - .take(num_rows) - .collect() + let mut rows = vec![]; + for op in operations { + rows.extend(self.generate_rows_for_op(op)); + } + let padded_rows = rows.len().max(min_rows).next_power_of_two(); + for _ in rows.len()..padded_rows { + rows.push(self.generate_padding_row()); + } + rows } fn generate_rows_for_op(&self, op: KeccakSpongeOp) -> Vec<[F; NUM_KECCAK_SPONGE_COLUMNS]> { diff --git a/evm/src/witness/mem_tx.rs b/evm/src/witness/mem_tx.rs deleted file mode 100644 index 7cc33653..00000000 --- a/evm/src/witness/mem_tx.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::witness::memory::{MemoryOp, MemoryOpKind, MemoryState}; - -pub fn apply_mem_ops(state: &mut MemoryState, mut ops: Vec) { - ops.sort_unstable_by_key(|mem_op| mem_op.timestamp); - - for op in ops { - let MemoryOp { address, op, .. } = op; - if let MemoryOpKind::Write(val) = op { - state.set(address, val); - } - } -} diff --git a/evm/tests/transfer_to_new_addr.rs b/evm/tests/transfer_to_new_addr.rs index 351506da..e4fe8eb4 100644 --- a/evm/tests/transfer_to_new_addr.rs +++ b/evm/tests/transfer_to_new_addr.rs @@ -33,8 +33,8 @@ fn test_simple_transfer() -> anyhow::Result<()> { let to = hex!("a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0"); let sender_state_key = keccak(sender); let to_state_key = keccak(to); - let sender_nibbles = Nibbles::from(sender_state_key); - let to_nibbles = Nibbles::from(to_state_key); + let sender_nibbles = Nibbles::from_bytes_be(sender_state_key.as_bytes()).unwrap(); + let to_nibbles = Nibbles::from_bytes_be(to_state_key.as_bytes()).unwrap(); let value = U256::from(100u32); let sender_account_before = AccountRlp {