From 013bf6471d1543c5e0d957eb564fab534ad7e99b Mon Sep 17 00:00:00 2001 From: Jacqueline Nabaglo Date: Fri, 26 Aug 2022 22:05:16 -0700 Subject: [PATCH] Transpose memory columns (make it an array of channel structs) (#700) --- evm/src/all_stark.rs | 57 +++++++++++++++------------ evm/src/cpu/columns/general.rs | 8 ++-- evm/src/cpu/columns/mod.rs | 22 +++++++---- evm/src/cpu/cpu_stark.rs | 25 ++++++------ evm/src/cpu/jumps.rs | 12 +++--- evm/src/cpu/simple_logic/eq_iszero.rs | 20 +++++----- evm/src/cpu/simple_logic/not.rs | 12 +++--- evm/src/cpu/syscalls.rs | 4 +- evm/src/generation/state.rs | 28 ++++++------- 9 files changed, 100 insertions(+), 88 deletions(-) diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index e8b44d23..5fd262ac 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -344,19 +344,16 @@ mod tests { if is_actual_op { let row: &mut cpu::columns::CpuColumnsView = cpu_trace_rows[clock].borrow_mut(); - - row.mem_channel_used[channel] = F::ONE; row.clock = F::from_canonical_usize(clock); - row.mem_is_read[channel] = memory_trace[memory::columns::IS_READ].values[i]; - row.mem_addr_context[channel] = - memory_trace[memory::columns::ADDR_CONTEXT].values[i]; - row.mem_addr_segment[channel] = - memory_trace[memory::columns::ADDR_SEGMENT].values[i]; - row.mem_addr_virtual[channel] = - memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; + + let channel = &mut row.mem_channels[channel]; + channel.used = F::ONE; + channel.is_read = memory_trace[memory::columns::IS_READ].values[i]; + channel.addr_context = memory_trace[memory::columns::ADDR_CONTEXT].values[i]; + channel.addr_segment = memory_trace[memory::columns::ADDR_SEGMENT].values[i]; + channel.addr_virtual = memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; for j in 0..8 { - row.mem_value[channel][j] = - memory_trace[memory::columns::value_limb(j)].values[i]; + channel.value[j] = memory_trace[memory::columns::value_limb(j)].values[i]; } } } @@ -382,16 +379,24 @@ mod tests { ); let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0); - for (col_cpu, limb_cols_logic) in row.mem_value[0].iter_mut().zip(input0_bit_cols) { + for (col_cpu, limb_cols_logic) in + row.mem_channels[0].value.iter_mut().zip(input0_bit_cols) + { *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); } let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1); - for (col_cpu, limb_cols_logic) in row.mem_value[1].iter_mut().zip(input1_bit_cols) { + for (col_cpu, limb_cols_logic) in + row.mem_channels[1].value.iter_mut().zip(input1_bit_cols) + { *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); } - for (col_cpu, col_logic) in row.mem_value[2].iter_mut().zip(logic::columns::RESULT) { + for (col_cpu, col_logic) in row.mem_channels[2] + .value + .iter_mut() + .zip(logic::columns::RESULT) + { *col_cpu = logic_trace[col_logic].values[i]; } @@ -409,7 +414,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x0a); // `EXP` is implemented in software row.is_kernel_mode = F::ONE; row.program_counter = last_row.program_counter + F::ONE; - row.mem_value[0] = [ + row.mem_channels[0].value = [ row.program_counter, F::ONE, F::ZERO, @@ -431,7 +436,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0xf9); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_usize(KERNEL.global_labels["sys_exp"]); - row.mem_value[0] = [ + row.mem_channels[0].value = [ F::from_canonical_u16(15682), F::ONE, F::ZERO, @@ -453,7 +458,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_u16(15682); - row.mem_value[0] = [ + row.mem_channels[0].value = [ F::from_canonical_u16(15106), F::ZERO, F::ZERO, @@ -463,7 +468,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.mem_value[1] = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, @@ -490,7 +495,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0xf9); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_u16(15106); - row.mem_value[0] = [ + row.mem_channels[0].value = [ F::from_canonical_u16(63064), F::ZERO, F::ZERO, @@ -512,7 +517,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(63064); - row.mem_value[0] = [ + row.mem_channels[0].value = [ F::from_canonical_u16(3754), F::ZERO, F::ZERO, @@ -522,7 +527,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.mem_value[1] = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, @@ -550,7 +555,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x57); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(3754); - row.mem_value[0] = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -560,7 +565,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.mem_value[1] = [ + row.mem_channels[1].value = [ F::ZERO, F::ZERO, F::ZERO, @@ -588,7 +593,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x57); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(37543); - row.mem_value[0] = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -617,7 +622,7 @@ mod tests { row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ZERO; row.program_counter = last_row.program_counter + F::ONE; - row.mem_value[0] = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -627,7 +632,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.mem_value[1] = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index 43f987e9..134788dc 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -102,10 +102,10 @@ pub(crate) struct CpuLogicView { #[derive(Copy, Clone)] pub(crate) struct CpuJumpsView { - /// `input0` is `mem_value[0]`. It's the top stack value at entry (for jumps, the address; for - /// `EXIT_KERNEL`, the address and new privilege level). - /// `input1` is `mem_value[1]`. For `JUMPI`, it's the second stack value (the predicate). For - /// `JUMP`, 1. + /// `input0` is `mem_channel[0].value`. It's the top stack value at entry (for jumps, the + /// address; for `EXIT_KERNEL`, the address and new privilege level). + /// `input1` is `mem_channel[1].value`. For `JUMPI`, it's the second stack value (the + /// predicate). For `JUMP`, 1. /// Inverse of `input0[1] + ... + input0[7]`, if one exists; otherwise, an arbitrary value. /// Needed to prove that `input0` is nonzero. diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index 564ea246..34a02837 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -11,6 +11,19 @@ use crate::memory; mod general; +#[repr(C)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct MemoryChannelView { + /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise + /// 0. + pub used: T, + pub is_read: T, + pub addr_context: T, + pub addr_segment: T, + pub addr_virtual: T, + pub value: [T; memory::VALUE_LIMBS], +} + #[repr(C)] #[derive(Eq, PartialEq, Debug)] pub struct CpuColumnsView { @@ -159,14 +172,7 @@ pub struct CpuColumnsView { pub(crate) general: CpuGeneralColumnsView, pub(crate) clock: T, - /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise - /// 0. - pub mem_channel_used: [T; memory::NUM_CHANNELS], - pub mem_is_read: [T; memory::NUM_CHANNELS], - pub mem_addr_context: [T; memory::NUM_CHANNELS], - pub mem_addr_segment: [T; memory::NUM_CHANNELS], - pub mem_addr_virtual: [T; memory::NUM_CHANNELS], - pub mem_value: [[T; memory::VALUE_LIMBS]; memory::NUM_CHANNELS], + pub mem_channels: [MemoryChannelView; memory::NUM_CHANNELS], } // `u8` is guaranteed to have a `size_of` of 1. diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 39518a43..9fd4792d 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -28,9 +28,9 @@ pub fn ctl_data_keccak_memory() -> Vec> { // channel 1: stack[-1] = context // channel 2: stack[-2] = segment // channel 3: stack[-3] = virtual - let context = Column::single(COL_MAP.mem_value[1][0]); - let segment = Column::single(COL_MAP.mem_value[2][0]); - let virt = Column::single(COL_MAP.mem_value[3][0]); + let context = Column::single(COL_MAP.mem_channels[1].value[0]); + let segment = Column::single(COL_MAP.mem_channels[2].value[0]); + let virt = Column::single(COL_MAP.mem_channels[3].value[0]); let num_channels = F::from_canonical_usize(NUM_CHANNELS); let clock = Column::linear_combination([(COL_MAP.clock, num_channels)]); @@ -48,9 +48,9 @@ pub fn ctl_filter_keccak_memory() -> Column { pub fn ctl_data_logic() -> Vec> { let mut res = Column::singles([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]).collect_vec(); - res.extend(Column::singles(COL_MAP.mem_value[0])); - res.extend(Column::singles(COL_MAP.mem_value[1])); - res.extend(Column::singles(COL_MAP.mem_value[2])); + res.extend(Column::singles(COL_MAP.mem_channels[0].value)); + res.extend(Column::singles(COL_MAP.mem_channels[1].value)); + res.extend(Column::singles(COL_MAP.mem_channels[2].value)); res } @@ -60,14 +60,15 @@ pub fn ctl_filter_logic() -> Column { pub fn ctl_data_memory(channel: usize) -> Vec> { debug_assert!(channel < NUM_CHANNELS); + let channel_map = COL_MAP.mem_channels[channel]; let mut cols: Vec> = Column::singles([ - COL_MAP.mem_is_read[channel], - COL_MAP.mem_addr_context[channel], - COL_MAP.mem_addr_segment[channel], - COL_MAP.mem_addr_virtual[channel], + channel_map.is_read, + channel_map.addr_context, + channel_map.addr_segment, + channel_map.addr_virtual, ]) .collect_vec(); - cols.extend(Column::singles(COL_MAP.mem_value[channel])); + cols.extend(Column::singles(channel_map.value)); let scalar = F::from_canonical_usize(NUM_CHANNELS); let addend = F::from_canonical_usize(channel); @@ -80,7 +81,7 @@ pub fn ctl_data_memory(channel: usize) -> Vec> { } pub fn ctl_filter_memory(channel: usize) -> Column { - Column::single(COL_MAP.mem_channel_used[channel]) + Column::single(COL_MAP.mem_channels[channel].used) } #[derive(Copy, Clone, Default)] diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index bac10eb6..219b39dd 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -17,7 +17,7 @@ pub fn eval_packed_exit_kernel( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let input = lv.mem_value[0]; + let input = lv.mem_channels[0].value; // If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode // flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the @@ -36,7 +36,7 @@ pub fn eval_ext_circuit_exit_kernel, const D: usize nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let input = lv.mem_value[0]; + let input = lv.mem_channels[0].value; let filter = builder.mul_extension(lv.is_cpu_cycle, lv.is_exit_kernel); // If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode @@ -58,8 +58,8 @@ pub fn eval_packed_jump_jumpi( yield_constr: &mut ConstraintConsumer

, ) { let jumps_lv = lv.general.jumps(); - let input0 = lv.mem_value[0]; - let input1 = lv.mem_value[1]; + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; let filter = lv.is_jump + lv.is_jumpi; // `JUMP` or `JUMPI` // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. @@ -160,8 +160,8 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> yield_constr: &mut RecursiveConstraintConsumer, ) { let jumps_lv = lv.general.jumps(); - let input0 = lv.mem_value[0]; - let input1 = lv.mem_value[1]; + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; let filter = builder.add_extension(lv.is_jump, lv.is_jumpi); // `JUMP` or `JUMPI` // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index c3d9bc99..6b7294a8 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::CpuColumnsView; pub fn generate(lv: &mut CpuColumnsView) { - let input0 = lv.mem_value[0]; + let input0 = lv.mem_channels[0].value; let eq_filter = lv.is_eq.to_canonical_u64(); let iszero_filter = lv.is_iszero.to_canonical_u64(); @@ -20,20 +20,20 @@ pub fn generate(lv: &mut CpuColumnsView) { return; } - let input1 = &mut lv.mem_value[1]; + let input1 = &mut lv.mem_channels[1].value; if iszero_filter != 0 { for limb in input1.iter_mut() { *limb = F::ZERO; } } - let input1 = lv.mem_value[1]; + let input1 = lv.mem_channels[1].value; let num_unequal_limbs = izip!(input0, input1) .map(|(limb0, limb1)| (limb0 != limb1) as usize) .sum(); let equal = num_unequal_limbs == 0; - let output = &mut lv.mem_value[2]; + let output = &mut lv.mem_channels[2].value; output[0] = F::from_bool(equal); for limb in &mut output[1..] { *limb = F::ZERO; @@ -58,9 +58,9 @@ pub fn eval_packed( yield_constr: &mut ConstraintConsumer

, ) { let logic = lv.general.logic(); - let input0 = lv.mem_value[0]; - let input1 = lv.mem_value[1]; - let output = lv.mem_value[2]; + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; + let output = lv.mem_channels[2].value; let eq_filter = lv.is_eq; let iszero_filter = lv.is_iszero; @@ -106,9 +106,9 @@ pub fn eval_ext_circuit, const D: usize>( let one = builder.one_extension(); let logic = lv.general.logic(); - let input0 = lv.mem_value[0]; - let input1 = lv.mem_value[1]; - let output = lv.mem_value[2]; + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; + let output = lv.mem_channels[2].value; let eq_filter = lv.is_eq; let iszero_filter = lv.is_iszero; diff --git a/evm/src/cpu/simple_logic/not.rs b/evm/src/cpu/simple_logic/not.rs index d9a16a66..83d43276 100644 --- a/evm/src/cpu/simple_logic/not.rs +++ b/evm/src/cpu/simple_logic/not.rs @@ -17,8 +17,8 @@ pub fn generate(lv: &mut CpuColumnsView) { } assert_eq!(is_not_filter, 1); - let input = lv.mem_value[0]; - let output = &mut lv.mem_value[1]; + let input = lv.mem_channels[0].value; + let output = &mut lv.mem_channels[1].value; for (input, output_ref) in input.into_iter().zip(output.iter_mut()) { let input = input.to_canonical_u64(); assert_eq!(input >> LIMB_SIZE, 0); @@ -32,8 +32,8 @@ pub fn eval_packed( yield_constr: &mut ConstraintConsumer

, ) { // This is simple: just do output = 0xffffffff - input. - let input = lv.mem_value[0]; - let output = lv.mem_value[1]; + let input = lv.mem_channels[0].value; + let output = lv.mem_channels[1].value; let cycle_filter = lv.is_cpu_cycle; let is_not_filter = lv.is_not; let filter = cycle_filter * is_not_filter; @@ -49,8 +49,8 @@ pub fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let input = lv.mem_value[0]; - let output = lv.mem_value[1]; + let input = lv.mem_channels[0].value; + let output = lv.mem_channels[1].value; let cycle_filter = lv.is_cpu_cycle; let is_not_filter = lv.is_not; let filter = builder.mul_extension(cycle_filter, is_not_filter); diff --git a/evm/src/cpu/syscalls.rs b/evm/src/cpu/syscalls.rs index 1ca45bc3..116713ae 100644 --- a/evm/src/cpu/syscalls.rs +++ b/evm/src/cpu/syscalls.rs @@ -48,7 +48,7 @@ pub fn eval_packed( // If syscall: set kernel mode yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES)); - let output = lv.mem_value[0]; + let output = lv.mem_channels[0].value; // If syscall: push current PC to stack yield_constr.constraint(filter * (output[0] - lv.program_counter)); // If syscall: push current kernel flag to stack (share register with PC) @@ -91,7 +91,7 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint_transition(builder, constr); } - let output = lv.mem_value[0]; + let output = lv.mem_channels[0].value; // If syscall: push current PC to stack { let constr = builder.sub_extension(output[0], lv.program_counter); diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 866f9fd7..4cbe61c8 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -77,13 +77,13 @@ impl GenerationState { let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index; let value = self.get_mem(context, segment, virt, timestamp); - self.current_cpu_row.mem_channel_used[channel_index] = F::ONE; - self.current_cpu_row.mem_is_read[channel_index] = F::ONE; - self.current_cpu_row.mem_addr_context[channel_index] = F::from_canonical_usize(context); - self.current_cpu_row.mem_addr_segment[channel_index] = - F::from_canonical_usize(segment as usize); - self.current_cpu_row.mem_addr_virtual[channel_index] = F::from_canonical_usize(virt); - self.current_cpu_row.mem_value[channel_index] = u256_limbs(value); + let channel = &mut self.current_cpu_row.mem_channels[channel_index]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(context); + channel.addr_segment = F::from_canonical_usize(segment as usize); + channel.addr_virtual = F::from_canonical_usize(virt); + channel.value = u256_limbs(value); value } @@ -133,13 +133,13 @@ impl GenerationState { let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index; self.set_mem(context, segment, virt, value, timestamp); - self.current_cpu_row.mem_channel_used[channel_index] = F::ONE; - self.current_cpu_row.mem_is_read[channel_index] = F::ZERO; // For clarity; should already be 0. - self.current_cpu_row.mem_addr_context[channel_index] = F::from_canonical_usize(context); - self.current_cpu_row.mem_addr_segment[channel_index] = - F::from_canonical_usize(segment as usize); - self.current_cpu_row.mem_addr_virtual[channel_index] = F::from_canonical_usize(virt); - self.current_cpu_row.mem_value[channel_index] = u256_limbs(value); + let channel = &mut self.current_cpu_row.mem_channels[channel_index]; + channel.used = F::ONE; + channel.is_read = F::ZERO; // For clarity; should already be 0. + channel.addr_context = F::from_canonical_usize(context); + channel.addr_segment = F::from_canonical_usize(segment as usize); + channel.addr_virtual = F::from_canonical_usize(virt); + channel.value = u256_limbs(value); } /// Write some memory, and log the operation.