diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 26a43d59..22f7c123 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -80,7 +80,7 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { let mut ctls = vec![ctl_keccak(), ctl_logic(), ctl_memory(), ctl_keccak_sponge()]; // TODO: Some CTLs temporarily disabled while we get them working. disable_ctl(&mut ctls[0]); - disable_ctl(&mut ctls[1]); // Enable once we populate logic log in keccak_sponge_log. + disable_ctl(&mut ctls[1]); disable_ctl(&mut ctls[2]); disable_ctl(&mut ctls[3]); ctls diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index 60f7f28a..1493b292 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -33,11 +33,23 @@ pub(crate) enum BinaryOperator { impl BinaryOperator { pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 { match self { - BinaryOperator::Add => input0 + input1, - BinaryOperator::Mul => input0 * input1, - BinaryOperator::Sub => input0 - input1, - BinaryOperator::Div => input0 / input1, - BinaryOperator::Mod => input0 % input1, + BinaryOperator::Add => input0.overflowing_add(input1).0, + BinaryOperator::Mul => input0.overflowing_mul(input1).0, + BinaryOperator::Sub => input0.overflowing_sub(input1).0, + BinaryOperator::Div => { + if input1.is_zero() { + U256::zero() + } else { + input0 / input1 + } + } + BinaryOperator::Mod => { + if input1.is_zero() { + U256::zero() + } else { + input0 % input1 + } + } BinaryOperator::Lt => { if input0 < input1 { U256::one() @@ -52,8 +64,20 @@ impl BinaryOperator { U256::zero() } } - BinaryOperator::Shl => input0 << input1, - BinaryOperator::Shr => input0 >> input1, + BinaryOperator::Shl => { + if input0 > 255.into() { + U256::zero() + } else { + input1 << input0 + } + } + BinaryOperator::Shr => { + if input0 > 255.into() { + U256::zero() + } else { + input1 >> input0 + } + } BinaryOperator::AddFp254 => addmod(input0, input1, bn_base_order()), BinaryOperator::MulFp254 => mulmod(input0, input1, bn_base_order()), BinaryOperator::SubFp254 => submod(input0, input1, bn_base_order()), diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index abd87113..d9ea7d2c 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -490,12 +490,10 @@ impl<'a> Interpreter<'a> { fn run_byte(&mut self) { let i = self.pop(); let x = self.pop(); - let result = if i > 32.into() { - 0 + let result = if i < 32.into() { + x.byte(31 - i.as_usize()) } else { - let mut bytes = [0; 32]; - x.to_big_endian(&mut bytes); - bytes[i.as_usize()] + 0 }; self.push(result.into()); } diff --git a/evm/src/cpu/shift.rs b/evm/src/cpu/shift.rs index d383b6b2..bbabf173 100644 --- a/evm/src/cpu/shift.rs +++ b/evm/src/cpu/shift.rs @@ -14,7 +14,7 @@ pub(crate) fn eval_packed( yield_constr: &mut ConstraintConsumer

, ) { let is_shift = lv.op.shl + lv.op.shr; - let displacement = lv.mem_channels[1]; // holds the shift displacement d + let displacement = lv.mem_channels[0]; // holds the shift displacement d let two_exp = lv.mem_channels[2]; // holds 2^d // Not needed here; val is the input and we're verifying that output is @@ -65,7 +65,7 @@ pub(crate) fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_shift = builder.add_extension(lv.op.shl, lv.op.shr); - let displacement = lv.mem_channels[1]; + let displacement = lv.mem_channels[0]; let two_exp = lv.mem_channels[2]; let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64); diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index d46b64d8..8b662a6d 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -24,7 +24,7 @@ use crate::proof::{BlockMetadata, PublicValues, TrieRoots}; use crate::witness::memory::MemoryAddress; use crate::witness::transition::transition; -pub(crate) mod mpt; +pub mod mpt; pub(crate) mod prover_input; pub(crate) mod rlp; pub(crate) mod state; @@ -74,6 +74,11 @@ pub(crate) fn generate_traces, const D: usize>( timed!(timing, "simulate CPU", simulate_cpu(&mut state)); + log::info!( + "Trace lengths (before padding): {:?}", + state.traces.checkpoint() + ); + let read_metadata = |field| { state.memory.get(MemoryAddress::new( 0, diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index a5be1205..15b92f45 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -9,11 +9,11 @@ use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::generation::TrieInputs; #[derive(RlpEncodable, RlpDecodable, Debug)] -pub(crate) struct AccountRlp { - pub(crate) nonce: U256, - pub(crate) balance: U256, - pub(crate) storage_root: H256, - pub(crate) code_hash: H256, +pub struct AccountRlp { + pub nonce: U256, + pub balance: U256, + pub storage_root: H256, + pub code_hash: H256, } pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec { diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 443b0380..13a4180b 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -80,7 +80,7 @@ fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize { } else if curr.address.virt != next.address.virt { next.address.virt - curr.address.virt - 1 } else { - next.timestamp - curr.timestamp - 1 + next.timestamp - curr.timestamp } }) .max() @@ -124,7 +124,7 @@ pub fn generate_first_change_flags_and_rc(trace_rows: &mut [[F; NU } else if virtual_first_change { next_virt - virt - F::ONE } else { - next_timestamp - timestamp - F::ONE + next_timestamp - timestamp }; } } @@ -283,7 +283,7 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark( let [(input0, log_in0), (input1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let operation = arithmetic::Operation::binary(operator, input0, input1); + let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?; + if operator == arithmetic::BinaryOperator::Shl || operator == arithmetic::BinaryOperator::Shr { + const LOOKUP_CHANNEL: usize = 2; + let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize); + if input0.bits() <= 32 { + let (_, read) = + mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, state, &mut row); + state.traces.push_memory(read); + } else { + // The shift constraints still expect the address to be set, even though no read will occur. + let mut channel = &mut row.mem_channels[LOOKUP_CHANNEL]; + channel.addr_context = F::from_canonical_usize(lookup_addr.context); + channel.addr_segment = F::from_canonical_usize(lookup_addr.segment); + channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); + } + } + state.traces.push_arithmetic(operation); state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); @@ -121,6 +138,7 @@ pub(crate) fn generate_keccak_general( val.as_u32() as u8 }) .collect_vec(); + log::debug!("Hashing {:?}", input); let hash = keccak(&input); let log_push = stack_push_log_and_fill(state, &mut row, hash.into_uint())?; diff --git a/evm/src/witness/traces.rs b/evm/src/witness/traces.rs index 60e0d8af..41b654fb 100644 --- a/evm/src/witness/traces.rs +++ b/evm/src/witness/traces.rs @@ -18,6 +18,8 @@ use crate::{arithmetic, keccak, logic}; #[derive(Clone, Copy, Debug)] pub struct TraceCheckpoint { pub(self) cpu_len: usize, + pub(self) keccak_len: usize, + pub(self) keccak_sponge_len: usize, pub(self) logic_len: usize, pub(self) arithmetic_len: usize, pub(self) memory_len: usize, @@ -48,19 +50,22 @@ impl Traces { pub fn checkpoint(&self) -> TraceCheckpoint { TraceCheckpoint { cpu_len: self.cpu.len(), + keccak_len: self.keccak_inputs.len(), + keccak_sponge_len: self.keccak_sponge_ops.len(), logic_len: self.logic_ops.len(), arithmetic_len: self.arithmetic.len(), memory_len: self.memory_ops.len(), - // TODO others } } pub fn rollback(&mut self, checkpoint: TraceCheckpoint) { self.cpu.truncate(checkpoint.cpu_len); + self.keccak_inputs.truncate(checkpoint.keccak_len); + self.keccak_sponge_ops + .truncate(checkpoint.keccak_sponge_len); self.logic_ops.truncate(checkpoint.logic_len); self.arithmetic.truncate(checkpoint.arithmetic_len); self.memory_ops.truncate(checkpoint.memory_len); - // TODO others } pub fn mem_ops_since(&self, checkpoint: TraceCheckpoint) -> &[MemoryOp] { diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 74a9957f..39aac810 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -219,6 +219,12 @@ fn perform_op( _ => 1, }; + if let Some(label) = KERNEL.offset_label(state.registers.program_counter) { + if !label.starts_with("halt_pc") { + log::debug!("At {label}"); + } + } + Ok(()) } diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index 788f47e0..08d68edc 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -3,11 +3,12 @@ use plonky2::field::types::Field; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::keccak_util::keccakf_u8s; -use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS}; use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE; use crate::generation::state::GenerationState; use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_WIDTH_BYTES}; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; +use crate::logic; use crate::memory::segments::Segment; use crate::witness::errors::ProgramError; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; @@ -174,39 +175,76 @@ pub(crate) fn stack_push_log_and_fill( Ok(res) } +fn xor_into_sponge( + state: &mut GenerationState, + sponge_state: &mut [u8; KECCAK_WIDTH_BYTES], + block: &[u8; KECCAK_RATE_BYTES], +) { + for i in (0..KECCAK_RATE_BYTES).step_by(32) { + let range = i..KECCAK_RATE_BYTES.min(i + 32); + let lhs = U256::from_little_endian(&sponge_state[range.clone()]); + let rhs = U256::from_little_endian(&block[range]); + state + .traces + .push_logic(logic::Operation::new(logic::Op::Xor, lhs, rhs)); + } + for i in 0..KECCAK_RATE_BYTES { + sponge_state[i] ^= block[i]; + } +} + pub(crate) fn keccak_sponge_log( state: &mut GenerationState, base_address: MemoryAddress, input: Vec, ) { + let clock = state.traces.clock(); + + let mut address = base_address; let mut input_blocks = input.chunks_exact(KECCAK_RATE_BYTES); let mut sponge_state = [0u8; KECCAK_WIDTH_BYTES]; for block in input_blocks.by_ref() { - sponge_state[..KECCAK_RATE_BYTES].copy_from_slice(block); + for &byte in block { + state.traces.push_memory(MemoryOp::new( + MemoryChannel::Code, + clock, + address, + MemoryOpKind::Read, + byte.into(), + )); + address.increment(); + } + xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap()); state.traces.push_keccak_bytes(sponge_state); - // TODO: Also push logic rows for XORs. - // TODO: Also push memory read rows. keccakf_u8s(&mut sponge_state); } - let final_inputs = input_blocks.remainder(); - sponge_state[..final_inputs.len()].copy_from_slice(final_inputs); - // pad10*1 rule - sponge_state[final_inputs.len()..KECCAK_RATE_BYTES].fill(0); - if final_inputs.len() == KECCAK_RATE_BYTES - 1 { - // Both 1s are placed in the same byte. - sponge_state[final_inputs.len()] = 0b10000001; - } else { - sponge_state[final_inputs.len()] = 1; - sponge_state[KECCAK_RATE_BYTES - 1] = 0b10000000; + for &byte in input_blocks.remainder() { + state.traces.push_memory(MemoryOp::new( + MemoryChannel::Code, + clock, + address, + MemoryOpKind::Read, + byte.into(), + )); + address.increment(); } + let mut final_block = [0u8; KECCAK_RATE_BYTES]; + final_block[..input_blocks.remainder().len()].copy_from_slice(input_blocks.remainder()); + // pad10*1 rule + if input_blocks.remainder().len() == KECCAK_RATE_BYTES - 1 { + // Both 1s are placed in the same byte. + final_block[input_blocks.remainder().len()] = 0b10000001; + } else { + final_block[input_blocks.remainder().len()] = 1; + final_block[KECCAK_RATE_BYTES - 1] = 0b10000000; + } + xor_into_sponge(state, &mut sponge_state, &final_block); state.traces.push_keccak_bytes(sponge_state); - // TODO: Also push logic rows for XORs. - // TODO: Also push memory read rows. state.traces.push_keccak_sponge(KeccakSpongeOp { base_address, - timestamp: state.traces.clock(), + timestamp: clock * NUM_CHANNELS, input, }); } diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index f6ae9910..abeef644 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -26,14 +26,6 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let block_metadata = BlockMetadata::default(); - // TODO: This trie isn't working yet. - // let state_trie = PartialTrie::Leaf { - // nibbles: Nibbles { - // count: 5, - // packed: 0xABCDE.into(), - // }, - // value: vec![1, 2, 3], - // }; let state_trie = PartialTrie::Empty; let transactions_trie = PartialTrie::Empty; let receipts_trie = PartialTrie::Empty;