diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 130343b3..d1f993cd 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -114,6 +114,8 @@ fn ctl_memory(channel: usize) -> CrossTableLookup { #[cfg(test)] mod tests { + use std::borrow::BorrowMut; + use anyhow::Result; use itertools::{izip, Itertools}; use plonky2::field::polynomial::PolynomialValues; @@ -127,7 +129,6 @@ mod tests { use crate::all_stark::{all_cross_table_lookups, AllStark}; use crate::config::StarkConfig; - use crate::cpu::columns::{KECCAK_INPUT_LIMBS, KECCAK_OUTPUT_LIMBS}; use crate::cpu::cpu_stark::CpuStark; use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; use crate::logic::{self, LogicStark}; @@ -234,23 +235,28 @@ mod tests { }) .collect(); - let mut cpu_trace_rows = vec![]; + let mut cpu_trace_rows: Vec<[F; CpuStark::::COLUMNS]> = vec![]; for i in 0..num_keccak_perms { - let mut row = [F::ZERO; CpuStark::::COLUMNS]; - row[cpu::columns::IS_KECCAK] = F::ONE; - for (j, input, output) in - izip!(0..2 * NUM_INPUTS, KECCAK_INPUT_LIMBS, KECCAK_OUTPUT_LIMBS) - { - row[input] = keccak_input_limbs[i][j]; - row[output] = keccak_output_limbs[i][j]; + let mut row: cpu::columns::CpuColumnsView = + [F::ZERO; CpuStark::::COLUMNS].into(); + row.is_keccak = F::ONE; + for (j, input, output) in izip!( + 0..2 * NUM_INPUTS, + row.keccak_input_limbs.iter_mut(), + row.keccak_output_limbs.iter_mut() + ) { + *input = keccak_input_limbs[i][j]; + *output = keccak_output_limbs[i][j]; } - cpu_stark.generate(&mut row); - cpu_trace_rows.push(row); + cpu_stark.generate(row.borrow_mut()); + cpu_trace_rows.push(row.into()); } + for i in 0..num_logic_rows { - let mut row = [F::ZERO; CpuStark::::COLUMNS]; - row[cpu::columns::IS_CPU_CYCLE] = F::ONE; - row[cpu::columns::OPCODE] = [ + let mut row: cpu::columns::CpuColumnsView = + [F::ZERO; CpuStark::::COLUMNS].into(); + row.is_cpu_cycle = F::ONE; + row.opcode = [ (logic::columns::IS_AND, 0x16), (logic::columns::IS_OR, 0x17), (logic::columns::IS_XOR, 0x18), @@ -259,22 +265,24 @@ mod tests { .map(|(col, opcode)| logic_trace[col].values[i] * F::from_canonical_u64(opcode)) .sum(); for (cols_cpu, cols_logic) in [ - (cpu::columns::LOGIC_INPUT0, logic::columns::INPUT0), - (cpu::columns::LOGIC_INPUT1, logic::columns::INPUT1), + (&mut row.logic_input0, logic::columns::INPUT0), + (&mut row.logic_input1, logic::columns::INPUT1), ] { - for (col_cpu, limb_cols_logic) in - cols_cpu.zip(logic::columns::limb_bit_cols_for_input(cols_logic)) + for (col_cpu, limb_cols_logic) in cols_cpu + .iter_mut() + .zip(logic::columns::limb_bit_cols_for_input(cols_logic)) { - row[col_cpu] = + *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); } } - for (col_cpu, col_logic) in cpu::columns::LOGIC_OUTPUT.zip(logic::columns::RESULT) { - row[col_cpu] = logic_trace[col_logic].values[i]; + for (col_cpu, col_logic) in row.logic_output.iter_mut().zip(logic::columns::RESULT) { + *col_cpu = logic_trace[col_logic].values[i]; } - cpu_stark.generate(&mut row); - cpu_trace_rows.push(row); + cpu_stark.generate(row.borrow_mut()); + cpu_trace_rows.push(row.into()); } + let mut current_cpu_index = 0; let mut last_timestamp = memory_trace[memory::columns::TIMESTAMP].values[0]; for i in 0..num_memory_ops { @@ -289,19 +297,17 @@ mod tests { last_timestamp = mem_timestamp; } - cpu_trace_rows[current_cpu_index][cpu::columns::mem_channel_used(op)] = F::ONE; - cpu_trace_rows[current_cpu_index][cpu::columns::CLOCK] = clock; - cpu_trace_rows[current_cpu_index][cpu::columns::mem_is_read(op)] = - memory_trace[memory::columns::IS_READ].values[i]; - cpu_trace_rows[current_cpu_index][cpu::columns::mem_addr_context(op)] = - memory_trace[memory::columns::ADDR_CONTEXT].values[i]; - cpu_trace_rows[current_cpu_index][cpu::columns::mem_addr_segment(op)] = - memory_trace[memory::columns::ADDR_SEGMENT].values[i]; - cpu_trace_rows[current_cpu_index][cpu::columns::mem_addr_virtual(op)] = - memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; + let row: &mut cpu::columns::CpuColumnsView = + cpu_trace_rows[current_cpu_index].borrow_mut(); + + row.mem_channel_used[op] = F::ONE; + row.clock = clock; + row.mem_is_read[op] = memory_trace[memory::columns::IS_READ].values[i]; + row.mem_addr_context[op] = memory_trace[memory::columns::ADDR_CONTEXT].values[i]; + row.mem_addr_segment[op] = memory_trace[memory::columns::ADDR_SEGMENT].values[i]; + row.mem_addr_virtual[op] = memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; for j in 0..8 { - cpu_trace_rows[current_cpu_index][cpu::columns::mem_value(op, j)] = - memory_trace[memory::columns::value_limb(j)].values[i]; + row.mem_value[op][j] = memory_trace[memory::columns::value_limb(j)].values[i]; } } trace_rows_to_poly_values(cpu_trace_rows) diff --git a/evm/src/cpu/columns.rs b/evm/src/cpu/columns.rs index 8b65b84f..1fce6cb7 100644 --- a/evm/src/cpu/columns.rs +++ b/evm/src/cpu/columns.rs @@ -1,207 +1,239 @@ // TODO: remove when possible. #![allow(dead_code)] -use std::ops::Range; +use std::borrow::{Borrow, BorrowMut}; +use std::mem::{size_of, transmute, transmute_copy, ManuallyDrop}; +use std::ops::{Index, IndexMut}; use crate::memory; -/// Filter. 1 if the row is part of bootstrapping the kernel code, 0 otherwise. -pub const IS_BOOTSTRAP_KERNEL: usize = 0; +#[repr(C)] +pub struct CpuColumnsView { + /// Filter. 1 if the row is part of bootstrapping the kernel code, 0 otherwise. + pub is_bootstrap_kernel: T, -/// Filter. 1 if the row is part of bootstrapping a contract's code, 0 otherwise. -pub const IS_BOOTSTRAP_CONTRACT: usize = IS_BOOTSTRAP_KERNEL + 1; + /// Filter. 1 if the row is part of bootstrapping a contract's code, 0 otherwise. + pub is_bootstrap_contract: T, -/// Filter. 1 if the row corresponds to a cycle of execution and 0 otherwise. -/// Lets us re-use decode columns in non-cycle rows. -pub const IS_CPU_CYCLE: usize = IS_BOOTSTRAP_CONTRACT + 1; + /// Filter. 1 if the row corresponds to a cycle of execution and 0 otherwise. + /// Lets us re-use decode columns in non-cycle rows. + pub is_cpu_cycle: T, -/// If CPU cycle: The opcode being decoded, in {0, ..., 255}. -pub const OPCODE: usize = IS_CPU_CYCLE + 1; + /// If CPU cycle: The opcode being decoded, in {0, ..., 255}. + pub opcode: T, -// If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. Invalid -// opcodes are split between a number of flags for practical reasons. Exactly one of these flags -// must be 1. -pub const IS_STOP: usize = OPCODE + 1; -pub const IS_ADD: usize = IS_STOP + 1; -pub const IS_MUL: usize = IS_ADD + 1; -pub const IS_SUB: usize = IS_MUL + 1; -pub const IS_DIV: usize = IS_SUB + 1; -pub const IS_SDIV: usize = IS_DIV + 1; -pub const IS_MOD: usize = IS_SDIV + 1; -pub const IS_SMOD: usize = IS_MOD + 1; -pub const IS_ADDMOD: usize = IS_SMOD + 1; -pub const IS_MULMOD: usize = IS_ADDMOD + 1; -pub const IS_EXP: usize = IS_MULMOD + 1; -pub const IS_SIGNEXTEND: usize = IS_EXP + 1; -pub const IS_LT: usize = IS_SIGNEXTEND + 1; -pub const IS_GT: usize = IS_LT + 1; -pub const IS_SLT: usize = IS_GT + 1; -pub const IS_SGT: usize = IS_SLT + 1; -pub const IS_EQ: usize = IS_SGT + 1; // Note: This column must be 0 when is_cpu_cycle = 0. -pub const IS_ISZERO: usize = IS_EQ + 1; // Note: This column must be 0 when is_cpu_cycle = 0. -pub const IS_AND: usize = IS_ISZERO + 1; -pub const IS_OR: usize = IS_AND + 1; -pub const IS_XOR: usize = IS_OR + 1; -pub const IS_NOT: usize = IS_XOR + 1; -pub const IS_BYTE: usize = IS_NOT + 1; -pub const IS_SHL: usize = IS_BYTE + 1; -pub const IS_SHR: usize = IS_SHL + 1; -pub const IS_SAR: usize = IS_SHR + 1; -pub const IS_SHA3: usize = IS_SAR + 1; -pub const IS_ADDRESS: usize = IS_SHA3 + 1; -pub const IS_BALANCE: usize = IS_ADDRESS + 1; -pub const IS_ORIGIN: usize = IS_BALANCE + 1; -pub const IS_CALLER: usize = IS_ORIGIN + 1; -pub const IS_CALLVALUE: usize = IS_CALLER + 1; -pub const IS_CALLDATALOAD: usize = IS_CALLVALUE + 1; -pub const IS_CALLDATASIZE: usize = IS_CALLDATALOAD + 1; -pub const IS_CALLDATACOPY: usize = IS_CALLDATASIZE + 1; -pub const IS_CODESIZE: usize = IS_CALLDATACOPY + 1; -pub const IS_CODECOPY: usize = IS_CODESIZE + 1; -pub const IS_GASPRICE: usize = IS_CODECOPY + 1; -pub const IS_EXTCODESIZE: usize = IS_GASPRICE + 1; -pub const IS_EXTCODECOPY: usize = IS_EXTCODESIZE + 1; -pub const IS_RETURNDATASIZE: usize = IS_EXTCODECOPY + 1; -pub const IS_RETURNDATACOPY: usize = IS_RETURNDATASIZE + 1; -pub const IS_EXTCODEHASH: usize = IS_RETURNDATACOPY + 1; -pub const IS_BLOCKHASH: usize = IS_EXTCODEHASH + 1; -pub const IS_COINBASE: usize = IS_BLOCKHASH + 1; -pub const IS_TIMESTAMP: usize = IS_COINBASE + 1; -pub const IS_NUMBER: usize = IS_TIMESTAMP + 1; -pub const IS_DIFFICULTY: usize = IS_NUMBER + 1; -pub const IS_GASLIMIT: usize = IS_DIFFICULTY + 1; -pub const IS_CHAINID: usize = IS_GASLIMIT + 1; -pub const IS_SELFBALANCE: usize = IS_CHAINID + 1; -pub const IS_BASEFEE: usize = IS_SELFBALANCE + 1; -pub const IS_POP: usize = IS_BASEFEE + 1; -pub const IS_MLOAD: usize = IS_POP + 1; -pub const IS_MSTORE: usize = IS_MLOAD + 1; -pub const IS_MSTORE8: usize = IS_MSTORE + 1; -pub const IS_SLOAD: usize = IS_MSTORE8 + 1; -pub const IS_SSTORE: usize = IS_SLOAD + 1; -pub const IS_JUMP: usize = IS_SSTORE + 1; -pub const IS_JUMPI: usize = IS_JUMP + 1; -pub const IS_PC: usize = IS_JUMPI + 1; -pub const IS_MSIZE: usize = IS_PC + 1; -pub const IS_GAS: usize = IS_MSIZE + 1; -pub const IS_JUMPDEST: usize = IS_GAS + 1; -// Find the number of bytes to push by reading the bottom 5 bits of the opcode. -pub const IS_PUSH: usize = IS_JUMPDEST + 1; -// Find the stack offset to duplicate by reading the bottom 4 bits of the opcode. -pub const IS_DUP: usize = IS_PUSH + 1; -// Find the stack offset to swap with by reading the bottom 4 bits of the opcode. -pub const IS_SWAP: usize = IS_DUP + 1; -pub const IS_LOG0: usize = IS_SWAP + 1; -pub const IS_LOG1: usize = IS_LOG0 + 1; -pub const IS_LOG2: usize = IS_LOG1 + 1; -pub const IS_LOG3: usize = IS_LOG2 + 1; -pub const IS_LOG4: usize = IS_LOG3 + 1; -pub const IS_CREATE: usize = IS_LOG4 + 1; -pub const IS_CALL: usize = IS_CREATE + 1; -pub const IS_CALLCODE: usize = IS_CALL + 1; -pub const IS_RETURN: usize = IS_CALLCODE + 1; -pub const IS_DELEGATECALL: usize = IS_RETURN + 1; -pub const IS_CREATE2: usize = IS_DELEGATECALL + 1; -pub const IS_STATICCALL: usize = IS_CREATE2 + 1; -pub const IS_REVERT: usize = IS_STATICCALL + 1; -pub const IS_SELFDESTRUCT: usize = IS_REVERT + 1; + // If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. + // Invalid opcodes are split between a number of flags for practical reasons. Exactly one of + // these flags must be 1. + pub is_stop: T, + pub is_add: T, + pub is_mul: T, + pub is_sub: T, + pub is_div: T, + pub is_sdiv: T, + pub is_mod: T, + pub is_smod: T, + pub is_addmod: T, + pub is_mulmod: T, + pub is_exp: T, + pub is_signextend: T, + pub is_lt: T, + pub is_gt: T, + pub is_slt: T, + pub is_sgt: T, + pub is_eq: T, // Note: This column must be 0 when is_cpu_cycle = 0. + pub is_iszero: T, // Note: This column must be 0 when is_cpu_cycle = 0. + pub is_and: T, + pub is_or: T, + pub is_xor: T, + pub is_not: T, + pub is_byte: T, + pub is_shl: T, + pub is_shr: T, + pub is_sar: T, + pub is_sha3: T, + pub is_address: T, + pub is_balance: T, + pub is_origin: T, + pub is_caller: T, + pub is_callvalue: T, + pub is_calldataload: T, + pub is_calldatasize: T, + pub is_calldatacopy: T, + pub is_codesize: T, + pub is_codecopy: T, + pub is_gasprice: T, + pub is_extcodesize: T, + pub is_extcodecopy: T, + pub is_returndatasize: T, + pub is_returndatacopy: T, + pub is_extcodehash: T, + pub is_blockhash: T, + pub is_coinbase: T, + pub is_timestamp: T, + pub is_number: T, + pub is_difficulty: T, + pub is_gaslimit: T, + pub is_chainid: T, + pub is_selfbalance: T, + pub is_basefee: T, + pub is_pop: T, + pub is_mload: T, + pub is_mstore: T, + pub is_mstore8: T, + pub is_sload: T, + pub is_sstore: T, + pub is_jump: T, + pub is_jumpi: T, + pub is_pc: T, + pub is_msize: T, + pub is_gas: T, + pub is_jumpdest: T, + pub is_push: T, + pub is_dup: T, + pub is_swap: T, + pub is_log0: T, + pub is_log1: T, + pub is_log2: T, + pub is_log3: T, + pub is_log4: T, + pub is_create: T, + pub is_call: T, + pub is_callcode: T, + pub is_return: T, + pub is_delegatecall: T, + pub is_create2: T, + pub is_staticcall: T, + pub is_revert: T, + pub is_selfdestruct: T, -pub const IS_INVALID_0: usize = IS_SELFDESTRUCT + 1; -pub const IS_INVALID_1: usize = IS_INVALID_0 + 1; -pub const IS_INVALID_2: usize = IS_INVALID_1 + 1; -pub const IS_INVALID_3: usize = IS_INVALID_2 + 1; -pub const IS_INVALID_4: usize = IS_INVALID_3 + 1; -pub const IS_INVALID_5: usize = IS_INVALID_4 + 1; -pub const IS_INVALID_6: usize = IS_INVALID_5 + 1; -pub const IS_INVALID_7: usize = IS_INVALID_6 + 1; -pub const IS_INVALID_8: usize = IS_INVALID_7 + 1; -pub const IS_INVALID_9: usize = IS_INVALID_8 + 1; -pub const IS_INVALID_10: usize = IS_INVALID_9 + 1; -pub const IS_INVALID_11: usize = IS_INVALID_10 + 1; -pub const IS_INVALID_12: usize = IS_INVALID_11 + 1; -pub const IS_INVALID_13: usize = IS_INVALID_12 + 1; -pub const IS_INVALID_14: usize = IS_INVALID_13 + 1; -pub const IS_INVALID_15: usize = IS_INVALID_14 + 1; -pub const IS_INVALID_16: usize = IS_INVALID_15 + 1; -pub const IS_INVALID_17: usize = IS_INVALID_16 + 1; -pub const IS_INVALID_18: usize = IS_INVALID_17 + 1; -pub const IS_INVALID_19: usize = IS_INVALID_18 + 1; -pub const IS_INVALID_20: usize = IS_INVALID_19 + 1; -// An instruction is invalid if _any_ of the above flags is 1. + // An instruction is invalid if _any_ of the below flags is 1. + pub is_invalid_0: T, + pub is_invalid_1: T, + pub is_invalid_2: T, + pub is_invalid_3: T, + pub is_invalid_4: T, + pub is_invalid_5: T, + pub is_invalid_6: T, + pub is_invalid_7: T, + pub is_invalid_8: T, + pub is_invalid_9: T, + pub is_invalid_10: T, + pub is_invalid_11: T, + pub is_invalid_12: T, + pub is_invalid_13: T, + pub is_invalid_14: T, + pub is_invalid_15: T, + pub is_invalid_16: T, + pub is_invalid_17: T, + pub is_invalid_18: T, + pub is_invalid_19: T, + pub is_invalid_20: T, -pub const START_INSTRUCTION_FLAGS: usize = IS_STOP; -pub const END_INSTRUCTION_FLAGS: usize = IS_INVALID_20 + 1; + /// If CPU cycle: the opcode, broken up into bits in **big-endian** order. + pub opcode_bits: [T; 8], -/// If CPU cycle: the opcode, broken up into bits. -/// **Big-endian** order. -pub const OPCODE_BITS: [usize; 8] = [ - END_INSTRUCTION_FLAGS, - END_INSTRUCTION_FLAGS + 1, - END_INSTRUCTION_FLAGS + 2, - END_INSTRUCTION_FLAGS + 3, - END_INSTRUCTION_FLAGS + 4, - END_INSTRUCTION_FLAGS + 5, - END_INSTRUCTION_FLAGS + 6, - END_INSTRUCTION_FLAGS + 7, -]; + /// Filter. 1 iff a Keccak permutation is computed on this row. + pub is_keccak: T, + pub keccak_input_limbs: [T; 50], + pub keccak_output_limbs: [T; 50], -/// Filter. 1 iff a Keccak permutation is computed on this row. -pub const IS_KECCAK: usize = OPCODE_BITS[OPCODE_BITS.len() - 1] + 1; + // Assuming a limb size of 16 bits. This can be changed, but it must be <= 28 bits. + // TODO: These input/output columns can be shared between the logic operations and others. + pub logic_input0: [T; 16], + pub logic_input1: [T; 16], + pub logic_output: [T; 16], + pub simple_logic_diff: T, + pub simple_logic_diff_inv: T, -pub const START_KECCAK_INPUT: usize = IS_KECCAK + 1; -pub const KECCAK_INPUT_LIMBS: Range = START_KECCAK_INPUT..START_KECCAK_INPUT + 50; - -pub const START_KECCAK_OUTPUT: usize = KECCAK_INPUT_LIMBS.end; -pub const KECCAK_OUTPUT_LIMBS: Range = START_KECCAK_OUTPUT..START_KECCAK_OUTPUT + 50; - -// Assuming a limb size of 16 bits. This can be changed, but it must be <= 28 bits. -// TODO: These input/output columns can be shared between the logic operations and others. -pub const LOGIC_INPUT0: Range = KECCAK_OUTPUT_LIMBS.end..KECCAK_OUTPUT_LIMBS.end + 16; -pub const LOGIC_INPUT1: Range = LOGIC_INPUT0.end..LOGIC_INPUT0.end + 16; -pub const LOGIC_OUTPUT: Range = LOGIC_INPUT1.end..LOGIC_INPUT1.end + 16; - -pub const SIMPLE_LOGIC_DIFF: usize = LOGIC_OUTPUT.end; -pub const SIMPLE_LOGIC_DIFF_INV: usize = SIMPLE_LOGIC_DIFF + 1; - -pub(crate) const CLOCK: usize = SIMPLE_LOGIC_DIFF_INV + 1; - -/// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise 0. -const MEM_CHANNEL_USED_START: usize = CLOCK + 1; -pub const fn mem_channel_used(channel: usize) -> usize { - debug_assert!(channel < memory::NUM_CHANNELS); - MEM_CHANNEL_USED_START + channel + pub(crate) clock: T, + /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise + /// 0. + pub mem_channel_used: [T; memory::NUM_CHANNELS], + pub mem_is_read: [T; memory::NUM_CHANNELS], + pub mem_addr_context: [T; memory::NUM_CHANNELS], + pub mem_addr_segment: [T; memory::NUM_CHANNELS], + pub mem_addr_virtual: [T; memory::NUM_CHANNELS], + pub mem_value: [[T; memory::VALUE_LIMBS]; memory::NUM_CHANNELS], } -const MEM_ISREAD_START: usize = MEM_CHANNEL_USED_START + memory::NUM_CHANNELS; -pub const fn mem_is_read(channel: usize) -> usize { - debug_assert!(channel < memory::NUM_CHANNELS); - MEM_ISREAD_START + channel +// `u8` is guaranteed to have a `size_of` of 1. +pub const NUM_CPU_COLUMNS: usize = size_of::>(); + +unsafe fn transmute_no_compile_time_size_checks(value: T) -> U { + debug_assert_eq!(size_of::(), size_of::()); + // Need ManuallyDrop so that `value` is not dropped by this function. + let value = ManuallyDrop::new(value); + // Copy the bit pattern. The original value is no longer safe to use. + transmute_copy(&value) } -const MEM_ADDR_CONTEXT_START: usize = MEM_ISREAD_START + memory::NUM_CHANNELS; -pub const fn mem_addr_context(channel: usize) -> usize { - debug_assert!(channel < memory::NUM_CHANNELS); - MEM_ADDR_CONTEXT_START + channel +impl From<[T; NUM_CPU_COLUMNS]> for CpuColumnsView { + fn from(value: [T; NUM_CPU_COLUMNS]) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } } -const MEM_ADDR_SEGMENT_START: usize = MEM_ADDR_CONTEXT_START + memory::NUM_CHANNELS; -pub const fn mem_addr_segment(channel: usize) -> usize { - debug_assert!(channel < memory::NUM_CHANNELS); - MEM_ADDR_SEGMENT_START + channel +impl From> for [T; NUM_CPU_COLUMNS] { + fn from(value: CpuColumnsView) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } } -const MEM_ADDR_VIRTUAL_START: usize = MEM_ADDR_SEGMENT_START + memory::NUM_CHANNELS; -pub const fn mem_addr_virtual(channel: usize) -> usize { - debug_assert!(channel < memory::NUM_CHANNELS); - MEM_ADDR_VIRTUAL_START + channel +impl Borrow> for [T; NUM_CPU_COLUMNS] { + fn borrow(&self) -> &CpuColumnsView { + unsafe { transmute(self) } + } } -const MEM_ADDR_VALUE_START: usize = MEM_ADDR_VIRTUAL_START + memory::NUM_CHANNELS; -pub const fn mem_value(channel: usize, limb: usize) -> usize { - debug_assert!(channel < memory::NUM_CHANNELS); - debug_assert!(limb < memory::VALUE_LIMBS); - MEM_ADDR_VALUE_START + channel * memory::VALUE_LIMBS + limb +impl BorrowMut> for [T; NUM_CPU_COLUMNS] { + fn borrow_mut(&mut self) -> &mut CpuColumnsView { + unsafe { transmute(self) } + } } -pub const NUM_CPU_COLUMNS: usize = - MEM_ADDR_VALUE_START + memory::NUM_CHANNELS * memory::VALUE_LIMBS; +impl Borrow<[T; NUM_CPU_COLUMNS]> for CpuColumnsView { + fn borrow(&self) -> &[T; NUM_CPU_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl BorrowMut<[T; NUM_CPU_COLUMNS]> for CpuColumnsView { + fn borrow_mut(&mut self) -> &mut [T; NUM_CPU_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl Index for CpuColumnsView +where + [T]: Index, +{ + type Output = <[T] as Index>::Output; + + fn index(&self, index: I) -> &Self::Output { + let arr: &[T; NUM_CPU_COLUMNS] = self.borrow(); + <[T] as Index>::index(arr, index) + } +} + +impl IndexMut for CpuColumnsView +where + [T]: IndexMut, +{ + fn index_mut(&mut self, index: I) -> &mut Self::Output { + let arr: &mut [T; NUM_CPU_COLUMNS] = self.borrow_mut(); + <[T] as IndexMut>::index_mut(arr, index) + } +} + +const fn make_col_map() -> CpuColumnsView { + let mut indices_arr = [0; NUM_CPU_COLUMNS]; + let mut i = 0; + while i < NUM_CPU_COLUMNS { + indices_arr[i] = i; + i += 1; + } + unsafe { transmute::<[usize; NUM_CPU_COLUMNS], CpuColumnsView>(indices_arr) } +} + +pub const COL_MAP: CpuColumnsView = make_col_map(); diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 1c8522c3..ee0cf98e 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -1,3 +1,4 @@ +use std::borrow::{Borrow, BorrowMut}; use std::marker::PhantomData; use itertools::Itertools; @@ -7,53 +8,51 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::{columns, decode, simple_logic}; +use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; +use crate::cpu::{decode, simple_logic}; use crate::cross_table_lookup::Column; use crate::memory::NUM_CHANNELS; -use crate::permutation::PermutationPair; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub fn ctl_data_keccak() -> Vec> { - let mut res: Vec<_> = columns::KECCAK_INPUT_LIMBS.map(Column::single).collect(); - res.extend(columns::KECCAK_OUTPUT_LIMBS.map(Column::single)); + let mut res: Vec<_> = Column::singles(COL_MAP.keccak_input_limbs).collect(); + res.extend(Column::singles(COL_MAP.keccak_output_limbs)); res } pub fn ctl_filter_keccak() -> Column { - Column::single(columns::IS_KECCAK) + Column::single(COL_MAP.is_keccak) } pub fn ctl_data_logic() -> Vec> { - let mut res = Column::singles([columns::IS_AND, columns::IS_OR, columns::IS_XOR]).collect_vec(); - res.extend(columns::LOGIC_INPUT0.map(Column::single)); - res.extend(columns::LOGIC_INPUT1.map(Column::single)); - res.extend(columns::LOGIC_OUTPUT.map(Column::single)); + let mut res = Column::singles([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]).collect_vec(); + res.extend(Column::singles(COL_MAP.logic_input0)); + res.extend(Column::singles(COL_MAP.logic_input1)); + res.extend(Column::singles(COL_MAP.logic_output)); res } pub fn ctl_filter_logic() -> Column { - Column::sum([columns::IS_AND, columns::IS_OR, columns::IS_XOR]) + Column::sum([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]) } pub fn ctl_data_memory(channel: usize) -> Vec> { debug_assert!(channel < NUM_CHANNELS); let mut cols: Vec> = Column::singles([ - columns::CLOCK, - columns::mem_is_read(channel), - columns::mem_addr_context(channel), - columns::mem_addr_segment(channel), - columns::mem_addr_virtual(channel), + COL_MAP.clock, + COL_MAP.mem_is_read[channel], + COL_MAP.mem_addr_context[channel], + COL_MAP.mem_addr_segment[channel], + COL_MAP.mem_addr_virtual[channel], ]) .collect_vec(); - cols.extend(Column::singles( - (0..8).map(|j| columns::mem_value(channel, j)), - )); + cols.extend(Column::singles(COL_MAP.mem_value[channel])); cols } pub fn ctl_filter_memory(channel: usize) -> Column { - Column::single(columns::mem_channel_used(channel)) + Column::single(COL_MAP.mem_channel_used[channel]) } #[derive(Copy, Clone)] @@ -62,14 +61,15 @@ pub struct CpuStark { } impl CpuStark { - pub fn generate(&self, local_values: &mut [F; columns::NUM_CPU_COLUMNS]) { + pub fn generate(&self, local_values: &mut [F; NUM_CPU_COLUMNS]) { + let local_values: &mut CpuColumnsView<_> = local_values.borrow_mut(); decode::generate(local_values); simple_logic::generate(local_values); } } impl, const D: usize> Stark for CpuStark { - const COLUMNS: usize = columns::NUM_CPU_COLUMNS; + const COLUMNS: usize = NUM_CPU_COLUMNS; const PUBLIC_INPUTS: usize = 0; fn eval_packed_generic( @@ -80,8 +80,9 @@ impl, const D: usize> Stark for CpuStark, P: PackedField, { - decode::eval_packed_generic(vars.local_values, yield_constr); - simple_logic::eval_packed(vars.local_values, yield_constr); + let local_values = vars.local_values.borrow(); + decode::eval_packed_generic(local_values, yield_constr); + simple_logic::eval_packed(local_values, yield_constr); } fn eval_ext_circuit( @@ -90,17 +91,14 @@ impl, const D: usize> Stark for CpuStark, yield_constr: &mut RecursiveConstraintConsumer, ) { - decode::eval_ext_circuit(builder, vars.local_values, yield_constr); - simple_logic::eval_ext_circuit(builder, vars.local_values, yield_constr); + let local_values = vars.local_values.borrow(); + decode::eval_ext_circuit(builder, local_values, yield_constr); + simple_logic::eval_ext_circuit(builder, local_values, yield_constr); } fn constraint_degree(&self) -> usize { 3 } - - fn permutation_pairs(&self) -> Vec { - vec![] - } } #[cfg(test)] diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index 69d6c810..0b091558 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -5,7 +5,7 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::columns; +use crate::cpu::columns::{CpuColumnsView, COL_MAP}; // List of opcode blocks // Each block corresponds to exactly one flag, and each flag corresponds to exactly one block. @@ -17,127 +17,126 @@ use crate::cpu::columns; // top 8-n bits. const OPCODES: [(u64, usize, usize); 102] = [ // (start index of block, number of top bits to check (log2), flag column) - (0x00, 0, columns::IS_STOP), - (0x01, 0, columns::IS_ADD), - (0x02, 0, columns::IS_MUL), - (0x03, 0, columns::IS_SUB), - (0x04, 0, columns::IS_DIV), - (0x05, 0, columns::IS_SDIV), - (0x06, 0, columns::IS_MOD), - (0x07, 0, columns::IS_SMOD), - (0x08, 0, columns::IS_ADDMOD), - (0x09, 0, columns::IS_MULMOD), - (0x0a, 0, columns::IS_EXP), - (0x0b, 0, columns::IS_SIGNEXTEND), - (0x0c, 2, columns::IS_INVALID_0), // 0x0c-0x0f - (0x10, 0, columns::IS_LT), - (0x11, 0, columns::IS_GT), - (0x12, 0, columns::IS_SLT), - (0x13, 0, columns::IS_SGT), - (0x14, 0, columns::IS_EQ), - (0x15, 0, columns::IS_ISZERO), - (0x16, 0, columns::IS_AND), - (0x17, 0, columns::IS_OR), - (0x18, 0, columns::IS_XOR), - (0x19, 0, columns::IS_NOT), - (0x1a, 0, columns::IS_BYTE), - (0x1b, 0, columns::IS_SHL), - (0x1c, 0, columns::IS_SHR), - (0x1d, 0, columns::IS_SAR), - (0x1e, 1, columns::IS_INVALID_1), // 0x1e-0x1f - (0x20, 0, columns::IS_SHA3), - (0x21, 0, columns::IS_INVALID_2), - (0x22, 1, columns::IS_INVALID_3), // 0x22-0x23 - (0x24, 2, columns::IS_INVALID_4), // 0x24-0x27 - (0x28, 3, columns::IS_INVALID_5), // 0x28-0x2f - (0x30, 0, columns::IS_ADDRESS), - (0x31, 0, columns::IS_BALANCE), - (0x32, 0, columns::IS_ORIGIN), - (0x33, 0, columns::IS_CALLER), - (0x34, 0, columns::IS_CALLVALUE), - (0x35, 0, columns::IS_CALLDATALOAD), - (0x36, 0, columns::IS_CALLDATASIZE), - (0x37, 0, columns::IS_CALLDATACOPY), - (0x38, 0, columns::IS_CODESIZE), - (0x39, 0, columns::IS_CODECOPY), - (0x3a, 0, columns::IS_GASPRICE), - (0x3b, 0, columns::IS_EXTCODESIZE), - (0x3c, 0, columns::IS_EXTCODECOPY), - (0x3d, 0, columns::IS_RETURNDATASIZE), - (0x3e, 0, columns::IS_RETURNDATACOPY), - (0x3f, 0, columns::IS_EXTCODEHASH), - (0x40, 0, columns::IS_BLOCKHASH), - (0x41, 0, columns::IS_COINBASE), - (0x42, 0, columns::IS_TIMESTAMP), - (0x43, 0, columns::IS_NUMBER), - (0x44, 0, columns::IS_DIFFICULTY), - (0x45, 0, columns::IS_GASLIMIT), - (0x46, 0, columns::IS_CHAINID), - (0x47, 0, columns::IS_SELFBALANCE), - (0x48, 0, columns::IS_BASEFEE), - (0x49, 0, columns::IS_INVALID_6), - (0x4a, 1, columns::IS_INVALID_7), // 0x4a-0x4b - (0x4c, 2, columns::IS_INVALID_8), // 0x4c-0x4f - (0x50, 0, columns::IS_POP), - (0x51, 0, columns::IS_MLOAD), - (0x52, 0, columns::IS_MSTORE), - (0x53, 0, columns::IS_MSTORE8), - (0x54, 0, columns::IS_SLOAD), - (0x55, 0, columns::IS_SSTORE), - (0x56, 0, columns::IS_JUMP), - (0x57, 0, columns::IS_JUMPI), - (0x58, 0, columns::IS_PC), - (0x59, 0, columns::IS_MSIZE), - (0x5a, 0, columns::IS_GAS), - (0x5b, 0, columns::IS_JUMPDEST), - (0x5c, 2, columns::IS_INVALID_9), // 0x5c-0x5f - (0x60, 5, columns::IS_PUSH), // 0x60-0x7f - (0x80, 4, columns::IS_DUP), // 0x80-0x8f - (0x90, 4, columns::IS_SWAP), // 0x90-0x9f - (0xa0, 0, columns::IS_LOG0), - (0xa1, 0, columns::IS_LOG1), - (0xa2, 0, columns::IS_LOG2), - (0xa3, 0, columns::IS_LOG3), - (0xa4, 0, columns::IS_LOG4), - (0xa5, 0, columns::IS_INVALID_10), - (0xa6, 1, columns::IS_INVALID_11), // 0xa6-0xa7 - (0xa8, 3, columns::IS_INVALID_12), // 0xa8-0xaf - (0xb0, 4, columns::IS_INVALID_13), // 0xb0-0xbf - (0xc0, 5, columns::IS_INVALID_14), // 0xc0-0xdf - (0xe0, 4, columns::IS_INVALID_15), // 0xe0-0xef - (0xf0, 0, columns::IS_CREATE), - (0xf1, 0, columns::IS_CALL), - (0xf2, 0, columns::IS_CALLCODE), - (0xf3, 0, columns::IS_RETURN), - (0xf4, 0, columns::IS_DELEGATECALL), - (0xf5, 0, columns::IS_CREATE2), - (0xf6, 1, columns::IS_INVALID_16), // 0xf6-0xf7 - (0xf8, 1, columns::IS_INVALID_17), // 0xf8-0xf9 - (0xfa, 0, columns::IS_STATICCALL), - (0xfb, 0, columns::IS_INVALID_18), - (0xfc, 0, columns::IS_INVALID_19), - (0xfd, 0, columns::IS_REVERT), - (0xfe, 0, columns::IS_INVALID_20), - (0xff, 0, columns::IS_SELFDESTRUCT), + (0x00, 0, COL_MAP.is_stop), + (0x01, 0, COL_MAP.is_add), + (0x02, 0, COL_MAP.is_mul), + (0x03, 0, COL_MAP.is_sub), + (0x04, 0, COL_MAP.is_div), + (0x05, 0, COL_MAP.is_sdiv), + (0x06, 0, COL_MAP.is_mod), + (0x07, 0, COL_MAP.is_smod), + (0x08, 0, COL_MAP.is_addmod), + (0x09, 0, COL_MAP.is_mulmod), + (0x0a, 0, COL_MAP.is_exp), + (0x0b, 0, COL_MAP.is_signextend), + (0x0c, 2, COL_MAP.is_invalid_0), // 0x0c-0x0f + (0x10, 0, COL_MAP.is_lt), + (0x11, 0, COL_MAP.is_gt), + (0x12, 0, COL_MAP.is_slt), + (0x13, 0, COL_MAP.is_sgt), + (0x14, 0, COL_MAP.is_eq), + (0x15, 0, COL_MAP.is_iszero), + (0x16, 0, COL_MAP.is_and), + (0x17, 0, COL_MAP.is_or), + (0x18, 0, COL_MAP.is_xor), + (0x19, 0, COL_MAP.is_not), + (0x1a, 0, COL_MAP.is_byte), + (0x1b, 0, COL_MAP.is_shl), + (0x1c, 0, COL_MAP.is_shr), + (0x1d, 0, COL_MAP.is_sar), + (0x1e, 1, COL_MAP.is_invalid_1), // 0x1e-0x1f + (0x20, 0, COL_MAP.is_sha3), + (0x21, 0, COL_MAP.is_invalid_2), + (0x22, 1, COL_MAP.is_invalid_3), // 0x22-0x23 + (0x24, 2, COL_MAP.is_invalid_4), // 0x24-0x27 + (0x28, 3, COL_MAP.is_invalid_5), // 0x28-0x2f + (0x30, 0, COL_MAP.is_address), + (0x31, 0, COL_MAP.is_balance), + (0x32, 0, COL_MAP.is_origin), + (0x33, 0, COL_MAP.is_caller), + (0x34, 0, COL_MAP.is_callvalue), + (0x35, 0, COL_MAP.is_calldataload), + (0x36, 0, COL_MAP.is_calldatasize), + (0x37, 0, COL_MAP.is_calldatacopy), + (0x38, 0, COL_MAP.is_codesize), + (0x39, 0, COL_MAP.is_codecopy), + (0x3a, 0, COL_MAP.is_gasprice), + (0x3b, 0, COL_MAP.is_extcodesize), + (0x3c, 0, COL_MAP.is_extcodecopy), + (0x3d, 0, COL_MAP.is_returndatasize), + (0x3e, 0, COL_MAP.is_returndatacopy), + (0x3f, 0, COL_MAP.is_extcodehash), + (0x40, 0, COL_MAP.is_blockhash), + (0x41, 0, COL_MAP.is_coinbase), + (0x42, 0, COL_MAP.is_timestamp), + (0x43, 0, COL_MAP.is_number), + (0x44, 0, COL_MAP.is_difficulty), + (0x45, 0, COL_MAP.is_gaslimit), + (0x46, 0, COL_MAP.is_chainid), + (0x47, 0, COL_MAP.is_selfbalance), + (0x48, 0, COL_MAP.is_basefee), + (0x49, 0, COL_MAP.is_invalid_6), + (0x4a, 1, COL_MAP.is_invalid_7), // 0x4a-0x4b + (0x4c, 2, COL_MAP.is_invalid_8), // 0x4c-0x4f + (0x50, 0, COL_MAP.is_pop), + (0x51, 0, COL_MAP.is_mload), + (0x52, 0, COL_MAP.is_mstore), + (0x53, 0, COL_MAP.is_mstore8), + (0x54, 0, COL_MAP.is_sload), + (0x55, 0, COL_MAP.is_sstore), + (0x56, 0, COL_MAP.is_jump), + (0x57, 0, COL_MAP.is_jumpi), + (0x58, 0, COL_MAP.is_pc), + (0x59, 0, COL_MAP.is_msize), + (0x5a, 0, COL_MAP.is_gas), + (0x5b, 0, COL_MAP.is_jumpdest), + (0x5c, 2, COL_MAP.is_invalid_9), // 0x5c-0x5f + (0x60, 5, COL_MAP.is_push), // 0x60-0x7f + (0x80, 4, COL_MAP.is_dup), // 0x80-0x8f + (0x90, 4, COL_MAP.is_swap), // 0x90-0x9f + (0xa0, 0, COL_MAP.is_log0), + (0xa1, 0, COL_MAP.is_log1), + (0xa2, 0, COL_MAP.is_log2), + (0xa3, 0, COL_MAP.is_log3), + (0xa4, 0, COL_MAP.is_log4), + (0xa5, 0, COL_MAP.is_invalid_10), + (0xa6, 1, COL_MAP.is_invalid_11), // 0xa6-0xa7 + (0xa8, 3, COL_MAP.is_invalid_12), // 0xa8-0xaf + (0xb0, 4, COL_MAP.is_invalid_13), // 0xb0-0xbf + (0xc0, 5, COL_MAP.is_invalid_14), // 0xc0-0xdf + (0xe0, 4, COL_MAP.is_invalid_15), // 0xe0-0xef + (0xf0, 0, COL_MAP.is_create), + (0xf1, 0, COL_MAP.is_call), + (0xf2, 0, COL_MAP.is_callcode), + (0xf3, 0, COL_MAP.is_return), + (0xf4, 0, COL_MAP.is_delegatecall), + (0xf5, 0, COL_MAP.is_create2), + (0xf6, 1, COL_MAP.is_invalid_16), // 0xf6-0xf7 + (0xf8, 1, COL_MAP.is_invalid_17), // 0xf8-0xf9 + (0xfa, 0, COL_MAP.is_staticcall), + (0xfb, 0, COL_MAP.is_invalid_18), + (0xfc, 0, COL_MAP.is_invalid_19), + (0xfd, 0, COL_MAP.is_revert), + (0xfe, 0, COL_MAP.is_invalid_20), + (0xff, 0, COL_MAP.is_selfdestruct), ]; -pub fn generate(lv: &mut [F; columns::NUM_CPU_COLUMNS]) { - let cycle_filter = lv[columns::IS_CPU_CYCLE]; +pub fn generate(lv: &mut CpuColumnsView) { + let cycle_filter = lv.is_cpu_cycle; if cycle_filter == F::ZERO { // These columns cannot be shared. - lv[columns::IS_EQ] = F::ZERO; - lv[columns::IS_ISZERO] = F::ZERO; + lv.is_eq = F::ZERO; + lv.is_iszero = F::ZERO; return; } // This assert is not _strictly_ necessary, but I include it as a sanity check. assert_eq!(cycle_filter, F::ONE, "cycle_filter should be 0 or 1"); - let opcode = lv[columns::OPCODE].to_canonical_u64(); + let opcode = lv.opcode.to_canonical_u64(); assert!(opcode < 256, "opcode should be in {{0, ..., 255}}"); - for (i, &col) in columns::OPCODE_BITS.iter().enumerate() { - let bit = (opcode >> (7 - i)) & 1; - lv[col] = F::from_canonical_u64(bit); + for (i, bit) in lv.opcode_bits.iter_mut().enumerate() { + *bit = F::from_canonical_u64((opcode >> (7 - i)) & 1); } let top_bits: [u64; 9] = [ @@ -158,14 +157,14 @@ pub fn generate(lv: &mut [F; columns::NUM_CPU_COLUMNS]) { } pub fn eval_packed_generic( - lv: &[P; columns::NUM_CPU_COLUMNS], + lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let cycle_filter = lv[columns::IS_CPU_CYCLE]; + let cycle_filter = lv.is_cpu_cycle; // Ensure that the opcode bits are valid: each has to be either 0 or 1, and they must match // the opcode. Note that this also validates that this implicitly range-checks the opcode. - let bits = columns::OPCODE_BITS.map(|i| lv[i]); + let bits = lv.opcode_bits; // First check that the bits are either 0 or 1. for bit in bits { yield_constr.constraint(cycle_filter * bit * (bit - P::ONES)); @@ -181,18 +180,19 @@ pub fn eval_packed_generic( }; // Now check that they match the opcode. - let opcode = lv[columns::OPCODE]; + let opcode = lv.opcode; yield_constr.constraint(cycle_filter * (opcode - top_bits[8])); // Check that the instruction flags are valid. // First, check that they are all either 0 or 1. - for &flag in &lv[columns::START_INSTRUCTION_FLAGS..columns::END_INSTRUCTION_FLAGS] { + for (_, _, flag_col) in OPCODES { + let flag = lv[flag_col]; yield_constr.constraint(cycle_filter * flag * (flag - P::ONES)); } // Now check that exactly one is 1. - let flag_sum: P = (columns::START_INSTRUCTION_FLAGS..columns::END_INSTRUCTION_FLAGS) + let flag_sum: P = OPCODES .into_iter() - .map(|i| lv[i]) + .map(|(_, _, flag_col)| lv[flag_col]) .sum(); yield_constr.constraint(cycle_filter * (P::ONES - flag_sum)); @@ -205,14 +205,14 @@ pub fn eval_packed_generic( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - lv: &[ExtensionTarget; columns::NUM_CPU_COLUMNS], + lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let cycle_filter = lv[columns::IS_CPU_CYCLE]; + let cycle_filter = lv.is_cpu_cycle; // Ensure that the opcode bits are valid: each has to be either 0 or 1, and they must match // the opcode. Note that this also validates that this implicitly range-checks the opcode. - let bits = columns::OPCODE_BITS.map(|i| lv[i]); + let bits = lv.opcode_bits; // First check that the bits are either 0 or 1. for bit in bits { let constr = builder.mul_sub_extension(bit, bit, bit); @@ -234,14 +234,15 @@ pub fn eval_ext_circuit, const D: usize>( // Now check that the bits match the opcode. { - let constr = builder.sub_extension(lv[columns::OPCODE], top_bits[8]); + let constr = builder.sub_extension(lv.opcode, top_bits[8]); let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); }; // Check that the instruction flags are valid. // First, check that they are all either 0 or 1. - for &flag in &lv[columns::START_INSTRUCTION_FLAGS..columns::END_INSTRUCTION_FLAGS] { + for (_, _, flag_col) in OPCODES { + let flag = lv[flag_col]; let constr = builder.mul_sub_extension(flag, flag, flag); let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); @@ -249,7 +250,8 @@ pub fn eval_ext_circuit, const D: usize>( // Now check that they sum to 1. { let mut constr = builder.one_extension(); - for &flag in &lv[columns::START_INSTRUCTION_FLAGS..columns::END_INSTRUCTION_FLAGS] { + for (_, _, flag_col) in OPCODES { + let flag = lv[flag_col]; constr = builder.sub_extension(constr, flag); } constr = builder.mul_extension(cycle_filter, constr); diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index 82af4c1b..97e000b6 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -4,13 +4,13 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::columns; +use crate::cpu::columns::CpuColumnsView; const LIMB_SIZE: usize = 16; -pub fn generate(lv: &mut [F; columns::NUM_CPU_COLUMNS]) { - let eq_filter = lv[columns::IS_EQ].to_canonical_u64(); - let iszero_filter = lv[columns::IS_ISZERO].to_canonical_u64(); +pub fn generate(lv: &mut CpuColumnsView) { + let eq_filter = lv.is_eq.to_canonical_u64(); + let iszero_filter = lv.is_iszero.to_canonical_u64(); assert!(eq_filter <= 1); assert!(iszero_filter <= 1); assert!(eq_filter + iszero_filter <= 1); @@ -20,11 +20,10 @@ pub fn generate(lv: &mut [F; columns::NUM_CPU_COLUMNS]) { } let diffs = if eq_filter == 1 { - columns::LOGIC_INPUT0 - .zip(columns::LOGIC_INPUT1) - .map(|(in0_col, in1_col)| { - let in0 = lv[in0_col]; - let in1 = lv[in1_col]; + lv.logic_input0 + .into_iter() + .zip(lv.logic_input1) + .map(|(in0, in1)| { assert_eq!(in0.to_canonical_u64() >> LIMB_SIZE, 0); assert_eq!(in1.to_canonical_u64() >> LIMB_SIZE, 0); let diff = in0 - in1; @@ -32,54 +31,50 @@ pub fn generate(lv: &mut [F; columns::NUM_CPU_COLUMNS]) { }) .sum() } else if iszero_filter == 1 { - columns::LOGIC_INPUT0.map(|i| lv[i]).sum() + lv.logic_input0.into_iter().sum() } else { panic!() }; - lv[columns::SIMPLE_LOGIC_DIFF] = diffs; - lv[columns::SIMPLE_LOGIC_DIFF_INV] = diffs.try_inverse().unwrap_or(F::ZERO); + lv.simple_logic_diff = diffs; + lv.simple_logic_diff_inv = diffs.try_inverse().unwrap_or(F::ZERO); - lv[columns::LOGIC_OUTPUT.start] = F::from_bool(diffs == F::ZERO); - for i in columns::LOGIC_OUTPUT.start + 1..columns::LOGIC_OUTPUT.end { - lv[i] = F::ZERO; + lv.logic_output[0] = F::from_bool(diffs == F::ZERO); + for out_limb_ref in lv.logic_output[1..].iter_mut() { + *out_limb_ref = F::ZERO; } } pub fn eval_packed( - lv: &[P; columns::NUM_CPU_COLUMNS], + lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let eq_filter = lv[columns::IS_EQ]; - let iszero_filter = lv[columns::IS_ISZERO]; + let eq_filter = lv.is_eq; + let iszero_filter = lv.is_iszero; let eq_or_iszero_filter = eq_filter + iszero_filter; - let ls_bit = lv[columns::LOGIC_OUTPUT.start]; + let ls_bit = lv.logic_output[0]; // Handle EQ and ISZERO. Most limbs of the output are 0, but the least-significant one is // either 0 or 1. yield_constr.constraint(eq_or_iszero_filter * ls_bit * (ls_bit - P::ONES)); - for bit_col in columns::LOGIC_OUTPUT.start + 1..columns::LOGIC_OUTPUT.end { - let bit = lv[bit_col]; + for &bit in &lv.logic_output[1..] { yield_constr.constraint(eq_or_iszero_filter * bit); } // Check SIMPLE_LOGIC_DIFF - let diffs = lv[columns::SIMPLE_LOGIC_DIFF]; - let diffs_inv = lv[columns::SIMPLE_LOGIC_DIFF_INV]; + let diffs = lv.simple_logic_diff; + let diffs_inv = lv.simple_logic_diff_inv; { - let input0_sum: P = columns::LOGIC_INPUT0.map(|i| lv[i]).sum(); + let input0_sum: P = lv.logic_input0.into_iter().sum(); yield_constr.constraint(iszero_filter * (diffs - input0_sum)); - let sum_squared_diffs: P = columns::LOGIC_INPUT0 - .zip(columns::LOGIC_INPUT1) - .map(|(in0_col, in1_col)| { - let in0 = lv[in0_col]; - let in1 = lv[in1_col]; - let diff = in0 - in1; - diff.square() - }) + let sum_squared_diffs: P = lv + .logic_input0 + .into_iter() + .zip(lv.logic_input1) + .map(|(in0, in1)| (in0 - in1).square()) .sum(); yield_constr.constraint(eq_filter * (diffs - sum_squared_diffs)); } @@ -92,14 +87,14 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - lv: &[ExtensionTarget; columns::NUM_CPU_COLUMNS], + lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let eq_filter = lv[columns::IS_EQ]; - let iszero_filter = lv[columns::IS_ISZERO]; + let eq_filter = lv.is_eq; + let iszero_filter = lv.is_iszero; let eq_or_iszero_filter = builder.add_extension(eq_filter, iszero_filter); - let ls_bit = lv[columns::LOGIC_OUTPUT.start]; + let ls_bit = lv.logic_output[0]; // Handle EQ and ISZERO. Most limbs of the output are 0, but the least-significant one is // either 0 or 1. @@ -109,28 +104,25 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint(builder, constr); } - for bit_col in columns::LOGIC_OUTPUT.start + 1..columns::LOGIC_OUTPUT.end { - let bit = lv[bit_col]; + for &bit in &lv.logic_output[1..] { let constr = builder.mul_extension(eq_or_iszero_filter, bit); yield_constr.constraint(builder, constr); } // Check SIMPLE_LOGIC_DIFF - let diffs = lv[columns::SIMPLE_LOGIC_DIFF]; - let diffs_inv = lv[columns::SIMPLE_LOGIC_DIFF_INV]; + let diffs = lv.simple_logic_diff; + let diffs_inv = lv.simple_logic_diff_inv; { - let input0_sum = builder.add_many_extension(columns::LOGIC_INPUT0.map(|i| lv[i])); + let input0_sum = builder.add_many_extension(lv.logic_input0); { let constr = builder.sub_extension(diffs, input0_sum); let constr = builder.mul_extension(iszero_filter, constr); yield_constr.constraint(builder, constr); } - let sum_squared_diffs = columns::LOGIC_INPUT0.zip(columns::LOGIC_INPUT1).fold( + let sum_squared_diffs = lv.logic_input0.into_iter().zip(lv.logic_input1).fold( builder.zero_extension(), - |acc, (in0_col, in1_col)| { - let in0 = lv[in0_col]; - let in1 = lv[in1_col]; + |acc, (in0, in1)| { let diff = builder.sub_extension(in0, in1); builder.mul_add_extension(diff, diff, acc) }, diff --git a/evm/src/cpu/simple_logic/mod.rs b/evm/src/cpu/simple_logic/mod.rs index 368c13bd..963b11b2 100644 --- a/evm/src/cpu/simple_logic/mod.rs +++ b/evm/src/cpu/simple_logic/mod.rs @@ -7,10 +7,10 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::columns; +use crate::cpu::columns::CpuColumnsView; -pub fn generate(lv: &mut [F; columns::NUM_CPU_COLUMNS]) { - let cycle_filter = lv[columns::IS_CPU_CYCLE].to_canonical_u64(); +pub fn generate(lv: &mut CpuColumnsView) { + let cycle_filter = lv.is_cpu_cycle.to_canonical_u64(); if cycle_filter == 0 { return; } @@ -21,7 +21,7 @@ pub fn generate(lv: &mut [F; columns::NUM_CPU_COLUMNS]) { } pub fn eval_packed( - lv: &[P; columns::NUM_CPU_COLUMNS], + lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { not::eval_packed(lv, yield_constr); @@ -30,7 +30,7 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - lv: &[ExtensionTarget; columns::NUM_CPU_COLUMNS], + lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { not::eval_ext_circuit(builder, lv, yield_constr); diff --git a/evm/src/cpu/simple_logic/not.rs b/evm/src/cpu/simple_logic/not.rs index 019ffb80..d1ba4d46 100644 --- a/evm/src/cpu/simple_logic/not.rs +++ b/evm/src/cpu/simple_logic/not.rs @@ -5,37 +5,35 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::columns; +use crate::cpu::columns::CpuColumnsView; const LIMB_SIZE: usize = 16; const ALL_1_LIMB: u64 = (1 << LIMB_SIZE) - 1; -pub fn generate(lv: &mut [F; columns::NUM_CPU_COLUMNS]) { - let is_not_filter = lv[columns::IS_NOT].to_canonical_u64(); +pub fn generate(lv: &mut CpuColumnsView) { + let is_not_filter = lv.is_not.to_canonical_u64(); if is_not_filter == 0 { return; } assert_eq!(is_not_filter, 1); - for (input_col, output_col) in columns::LOGIC_INPUT0.zip(columns::LOGIC_OUTPUT) { - let input = lv[input_col].to_canonical_u64(); + for (input, output_ref) in lv.logic_input0.into_iter().zip(lv.logic_output.iter_mut()) { + let input = input.to_canonical_u64(); assert_eq!(input >> LIMB_SIZE, 0); let output = input ^ ALL_1_LIMB; - lv[output_col] = F::from_canonical_u64(output); + *output_ref = F::from_canonical_u64(output); } } pub fn eval_packed( - lv: &[P; columns::NUM_CPU_COLUMNS], + lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { // This is simple: just do output = 0xffff - input. - let cycle_filter = lv[columns::IS_CPU_CYCLE]; - let is_not_filter = lv[columns::IS_NOT]; + let cycle_filter = lv.is_cpu_cycle; + let is_not_filter = lv.is_not; let filter = cycle_filter * is_not_filter; - for (input_col, output_col) in columns::LOGIC_INPUT0.zip(columns::LOGIC_OUTPUT) { - let input = lv[input_col]; - let output = lv[output_col]; + for (input, output) in lv.logic_input0.into_iter().zip(lv.logic_output) { yield_constr .constraint(filter * (output + input - P::Scalar::from_canonical_u64(ALL_1_LIMB))); } @@ -43,15 +41,13 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - lv: &[ExtensionTarget; columns::NUM_CPU_COLUMNS], + lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let cycle_filter = lv[columns::IS_CPU_CYCLE]; - let is_not_filter = lv[columns::IS_NOT]; + let cycle_filter = lv.is_cpu_cycle; + let is_not_filter = lv.is_not; let filter = builder.mul_extension(cycle_filter, is_not_filter); - for (input_col, output_col) in columns::LOGIC_INPUT0.zip(columns::LOGIC_OUTPUT) { - let input = lv[input_col]; - let output = lv[output_col]; + for (input, output) in lv.logic_input0.into_iter().zip(lv.logic_output) { let constr = builder.add_extension(output, input); let constr = builder.arithmetic_extension( F::ONE,