Transpose memory columns (make it an array of channel structs) (#700)

This commit is contained in:
Jacqueline Nabaglo 2022-08-26 22:05:16 -07:00 committed by GitHub
parent 08758a3b9d
commit 013bf6471d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 100 additions and 88 deletions

View File

@ -344,19 +344,16 @@ mod tests {
if is_actual_op {
let row: &mut cpu::columns::CpuColumnsView<F> = cpu_trace_rows[clock].borrow_mut();
row.mem_channel_used[channel] = F::ONE;
row.clock = F::from_canonical_usize(clock);
row.mem_is_read[channel] = memory_trace[memory::columns::IS_READ].values[i];
row.mem_addr_context[channel] =
memory_trace[memory::columns::ADDR_CONTEXT].values[i];
row.mem_addr_segment[channel] =
memory_trace[memory::columns::ADDR_SEGMENT].values[i];
row.mem_addr_virtual[channel] =
memory_trace[memory::columns::ADDR_VIRTUAL].values[i];
let channel = &mut row.mem_channels[channel];
channel.used = F::ONE;
channel.is_read = memory_trace[memory::columns::IS_READ].values[i];
channel.addr_context = memory_trace[memory::columns::ADDR_CONTEXT].values[i];
channel.addr_segment = memory_trace[memory::columns::ADDR_SEGMENT].values[i];
channel.addr_virtual = memory_trace[memory::columns::ADDR_VIRTUAL].values[i];
for j in 0..8 {
row.mem_value[channel][j] =
memory_trace[memory::columns::value_limb(j)].values[i];
channel.value[j] = memory_trace[memory::columns::value_limb(j)].values[i];
}
}
}
@ -382,16 +379,24 @@ mod tests {
);
let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0);
for (col_cpu, limb_cols_logic) in row.mem_value[0].iter_mut().zip(input0_bit_cols) {
for (col_cpu, limb_cols_logic) in
row.mem_channels[0].value.iter_mut().zip(input0_bit_cols)
{
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1);
for (col_cpu, limb_cols_logic) in row.mem_value[1].iter_mut().zip(input1_bit_cols) {
for (col_cpu, limb_cols_logic) in
row.mem_channels[1].value.iter_mut().zip(input1_bit_cols)
{
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
for (col_cpu, col_logic) in row.mem_value[2].iter_mut().zip(logic::columns::RESULT) {
for (col_cpu, col_logic) in row.mem_channels[2]
.value
.iter_mut()
.zip(logic::columns::RESULT)
{
*col_cpu = logic_trace[col_logic].values[i];
}
@ -409,7 +414,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x0a); // `EXP` is implemented in software
row.is_kernel_mode = F::ONE;
row.program_counter = last_row.program_counter + F::ONE;
row.mem_value[0] = [
row.mem_channels[0].value = [
row.program_counter,
F::ONE,
F::ZERO,
@ -431,7 +436,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0xf9);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_usize(KERNEL.global_labels["sys_exp"]);
row.mem_value[0] = [
row.mem_channels[0].value = [
F::from_canonical_u16(15682),
F::ONE,
F::ZERO,
@ -453,7 +458,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_u16(15682);
row.mem_value[0] = [
row.mem_channels[0].value = [
F::from_canonical_u16(15106),
F::ZERO,
F::ZERO,
@ -463,7 +468,7 @@ mod tests {
F::ZERO,
F::ZERO,
];
row.mem_value[1] = [
row.mem_channels[1].value = [
F::ONE,
F::ZERO,
F::ZERO,
@ -490,7 +495,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0xf9);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_u16(15106);
row.mem_value[0] = [
row.mem_channels[0].value = [
F::from_canonical_u16(63064),
F::ZERO,
F::ZERO,
@ -512,7 +517,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(63064);
row.mem_value[0] = [
row.mem_channels[0].value = [
F::from_canonical_u16(3754),
F::ZERO,
F::ZERO,
@ -522,7 +527,7 @@ mod tests {
F::ZERO,
F::ZERO,
];
row.mem_value[1] = [
row.mem_channels[1].value = [
F::ONE,
F::ZERO,
F::ZERO,
@ -550,7 +555,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x57);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(3754);
row.mem_value[0] = [
row.mem_channels[0].value = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
@ -560,7 +565,7 @@ mod tests {
F::ZERO,
F::ZERO,
];
row.mem_value[1] = [
row.mem_channels[1].value = [
F::ZERO,
F::ZERO,
F::ZERO,
@ -588,7 +593,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x57);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(37543);
row.mem_value[0] = [
row.mem_channels[0].value = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
@ -617,7 +622,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ZERO;
row.program_counter = last_row.program_counter + F::ONE;
row.mem_value[0] = [
row.mem_channels[0].value = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
@ -627,7 +632,7 @@ mod tests {
F::ZERO,
F::ZERO,
];
row.mem_value[1] = [
row.mem_channels[1].value = [
F::ONE,
F::ZERO,
F::ZERO,

View File

@ -102,10 +102,10 @@ pub(crate) struct CpuLogicView<T: Copy> {
#[derive(Copy, Clone)]
pub(crate) struct CpuJumpsView<T: Copy> {
/// `input0` is `mem_value[0]`. It's the top stack value at entry (for jumps, the address; for
/// `EXIT_KERNEL`, the address and new privilege level).
/// `input1` is `mem_value[1]`. For `JUMPI`, it's the second stack value (the predicate). For
/// `JUMP`, 1.
/// `input0` is `mem_channel[0].value`. It's the top stack value at entry (for jumps, the
/// address; for `EXIT_KERNEL`, the address and new privilege level).
/// `input1` is `mem_channel[1].value`. For `JUMPI`, it's the second stack value (the
/// predicate). For `JUMP`, 1.
/// Inverse of `input0[1] + ... + input0[7]`, if one exists; otherwise, an arbitrary value.
/// Needed to prove that `input0` is nonzero.

View File

@ -11,6 +11,19 @@ use crate::memory;
mod general;
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MemoryChannelView<T: Copy> {
/// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise
/// 0.
pub used: T,
pub is_read: T,
pub addr_context: T,
pub addr_segment: T,
pub addr_virtual: T,
pub value: [T; memory::VALUE_LIMBS],
}
#[repr(C)]
#[derive(Eq, PartialEq, Debug)]
pub struct CpuColumnsView<T: Copy> {
@ -159,14 +172,7 @@ pub struct CpuColumnsView<T: Copy> {
pub(crate) general: CpuGeneralColumnsView<T>,
pub(crate) clock: T,
/// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise
/// 0.
pub mem_channel_used: [T; memory::NUM_CHANNELS],
pub mem_is_read: [T; memory::NUM_CHANNELS],
pub mem_addr_context: [T; memory::NUM_CHANNELS],
pub mem_addr_segment: [T; memory::NUM_CHANNELS],
pub mem_addr_virtual: [T; memory::NUM_CHANNELS],
pub mem_value: [[T; memory::VALUE_LIMBS]; memory::NUM_CHANNELS],
pub mem_channels: [MemoryChannelView<T>; memory::NUM_CHANNELS],
}
// `u8` is guaranteed to have a `size_of` of 1.

View File

@ -28,9 +28,9 @@ pub fn ctl_data_keccak_memory<F: Field>() -> Vec<Column<F>> {
// channel 1: stack[-1] = context
// channel 2: stack[-2] = segment
// channel 3: stack[-3] = virtual
let context = Column::single(COL_MAP.mem_value[1][0]);
let segment = Column::single(COL_MAP.mem_value[2][0]);
let virt = Column::single(COL_MAP.mem_value[3][0]);
let context = Column::single(COL_MAP.mem_channels[1].value[0]);
let segment = Column::single(COL_MAP.mem_channels[2].value[0]);
let virt = Column::single(COL_MAP.mem_channels[3].value[0]);
let num_channels = F::from_canonical_usize(NUM_CHANNELS);
let clock = Column::linear_combination([(COL_MAP.clock, num_channels)]);
@ -48,9 +48,9 @@ pub fn ctl_filter_keccak_memory<F: Field>() -> Column<F> {
pub fn ctl_data_logic<F: Field>() -> Vec<Column<F>> {
let mut res = Column::singles([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]).collect_vec();
res.extend(Column::singles(COL_MAP.mem_value[0]));
res.extend(Column::singles(COL_MAP.mem_value[1]));
res.extend(Column::singles(COL_MAP.mem_value[2]));
res.extend(Column::singles(COL_MAP.mem_channels[0].value));
res.extend(Column::singles(COL_MAP.mem_channels[1].value));
res.extend(Column::singles(COL_MAP.mem_channels[2].value));
res
}
@ -60,14 +60,15 @@ pub fn ctl_filter_logic<F: Field>() -> Column<F> {
pub fn ctl_data_memory<F: Field>(channel: usize) -> Vec<Column<F>> {
debug_assert!(channel < NUM_CHANNELS);
let channel_map = COL_MAP.mem_channels[channel];
let mut cols: Vec<Column<F>> = Column::singles([
COL_MAP.mem_is_read[channel],
COL_MAP.mem_addr_context[channel],
COL_MAP.mem_addr_segment[channel],
COL_MAP.mem_addr_virtual[channel],
channel_map.is_read,
channel_map.addr_context,
channel_map.addr_segment,
channel_map.addr_virtual,
])
.collect_vec();
cols.extend(Column::singles(COL_MAP.mem_value[channel]));
cols.extend(Column::singles(channel_map.value));
let scalar = F::from_canonical_usize(NUM_CHANNELS);
let addend = F::from_canonical_usize(channel);
@ -80,7 +81,7 @@ pub fn ctl_data_memory<F: Field>(channel: usize) -> Vec<Column<F>> {
}
pub fn ctl_filter_memory<F: Field>(channel: usize) -> Column<F> {
Column::single(COL_MAP.mem_channel_used[channel])
Column::single(COL_MAP.mem_channels[channel].used)
}
#[derive(Copy, Clone, Default)]

View File

@ -17,7 +17,7 @@ pub fn eval_packed_exit_kernel<P: PackedField>(
nv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let input = lv.mem_value[0];
let input = lv.mem_channels[0].value;
// If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode
// flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the
@ -36,7 +36,7 @@ pub fn eval_ext_circuit_exit_kernel<F: RichField + Extendable<D>, const D: usize
nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let input = lv.mem_value[0];
let input = lv.mem_channels[0].value;
let filter = builder.mul_extension(lv.is_cpu_cycle, lv.is_exit_kernel);
// If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode
@ -58,8 +58,8 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
let jumps_lv = lv.general.jumps();
let input0 = lv.mem_value[0];
let input1 = lv.mem_value[1];
let input0 = lv.mem_channels[0].value;
let input1 = lv.mem_channels[1].value;
let filter = lv.is_jump + lv.is_jumpi; // `JUMP` or `JUMPI`
// If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
@ -160,8 +160,8 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let jumps_lv = lv.general.jumps();
let input0 = lv.mem_value[0];
let input1 = lv.mem_value[1];
let input0 = lv.mem_channels[0].value;
let input1 = lv.mem_channels[1].value;
let filter = builder.add_extension(lv.is_jump, lv.is_jumpi); // `JUMP` or `JUMPI`
// If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.

View File

@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::CpuColumnsView;
pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
let input0 = lv.mem_value[0];
let input0 = lv.mem_channels[0].value;
let eq_filter = lv.is_eq.to_canonical_u64();
let iszero_filter = lv.is_iszero.to_canonical_u64();
@ -20,20 +20,20 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
return;
}
let input1 = &mut lv.mem_value[1];
let input1 = &mut lv.mem_channels[1].value;
if iszero_filter != 0 {
for limb in input1.iter_mut() {
*limb = F::ZERO;
}
}
let input1 = lv.mem_value[1];
let input1 = lv.mem_channels[1].value;
let num_unequal_limbs = izip!(input0, input1)
.map(|(limb0, limb1)| (limb0 != limb1) as usize)
.sum();
let equal = num_unequal_limbs == 0;
let output = &mut lv.mem_value[2];
let output = &mut lv.mem_channels[2].value;
output[0] = F::from_bool(equal);
for limb in &mut output[1..] {
*limb = F::ZERO;
@ -58,9 +58,9 @@ pub fn eval_packed<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
let logic = lv.general.logic();
let input0 = lv.mem_value[0];
let input1 = lv.mem_value[1];
let output = lv.mem_value[2];
let input0 = lv.mem_channels[0].value;
let input1 = lv.mem_channels[1].value;
let output = lv.mem_channels[2].value;
let eq_filter = lv.is_eq;
let iszero_filter = lv.is_iszero;
@ -106,9 +106,9 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let one = builder.one_extension();
let logic = lv.general.logic();
let input0 = lv.mem_value[0];
let input1 = lv.mem_value[1];
let output = lv.mem_value[2];
let input0 = lv.mem_channels[0].value;
let input1 = lv.mem_channels[1].value;
let output = lv.mem_channels[2].value;
let eq_filter = lv.is_eq;
let iszero_filter = lv.is_iszero;

View File

@ -17,8 +17,8 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
}
assert_eq!(is_not_filter, 1);
let input = lv.mem_value[0];
let output = &mut lv.mem_value[1];
let input = lv.mem_channels[0].value;
let output = &mut lv.mem_channels[1].value;
for (input, output_ref) in input.into_iter().zip(output.iter_mut()) {
let input = input.to_canonical_u64();
assert_eq!(input >> LIMB_SIZE, 0);
@ -32,8 +32,8 @@ pub fn eval_packed<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
// This is simple: just do output = 0xffffffff - input.
let input = lv.mem_value[0];
let output = lv.mem_value[1];
let input = lv.mem_channels[0].value;
let output = lv.mem_channels[1].value;
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.is_not;
let filter = cycle_filter * is_not_filter;
@ -49,8 +49,8 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let input = lv.mem_value[0];
let output = lv.mem_value[1];
let input = lv.mem_channels[0].value;
let output = lv.mem_channels[1].value;
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.is_not;
let filter = builder.mul_extension(cycle_filter, is_not_filter);

View File

@ -48,7 +48,7 @@ pub fn eval_packed<P: PackedField>(
// If syscall: set kernel mode
yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES));
let output = lv.mem_value[0];
let output = lv.mem_channels[0].value;
// If syscall: push current PC to stack
yield_constr.constraint(filter * (output[0] - lv.program_counter));
// If syscall: push current kernel flag to stack (share register with PC)
@ -91,7 +91,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
yield_constr.constraint_transition(builder, constr);
}
let output = lv.mem_value[0];
let output = lv.mem_channels[0].value;
// If syscall: push current PC to stack
{
let constr = builder.sub_extension(output[0], lv.program_counter);

View File

@ -77,13 +77,13 @@ impl<F: Field> GenerationState<F> {
let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index;
let value = self.get_mem(context, segment, virt, timestamp);
self.current_cpu_row.mem_channel_used[channel_index] = F::ONE;
self.current_cpu_row.mem_is_read[channel_index] = F::ONE;
self.current_cpu_row.mem_addr_context[channel_index] = F::from_canonical_usize(context);
self.current_cpu_row.mem_addr_segment[channel_index] =
F::from_canonical_usize(segment as usize);
self.current_cpu_row.mem_addr_virtual[channel_index] = F::from_canonical_usize(virt);
self.current_cpu_row.mem_value[channel_index] = u256_limbs(value);
let channel = &mut self.current_cpu_row.mem_channels[channel_index];
channel.used = F::ONE;
channel.is_read = F::ONE;
channel.addr_context = F::from_canonical_usize(context);
channel.addr_segment = F::from_canonical_usize(segment as usize);
channel.addr_virtual = F::from_canonical_usize(virt);
channel.value = u256_limbs(value);
value
}
@ -133,13 +133,13 @@ impl<F: Field> GenerationState<F> {
let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index;
self.set_mem(context, segment, virt, value, timestamp);
self.current_cpu_row.mem_channel_used[channel_index] = F::ONE;
self.current_cpu_row.mem_is_read[channel_index] = F::ZERO; // For clarity; should already be 0.
self.current_cpu_row.mem_addr_context[channel_index] = F::from_canonical_usize(context);
self.current_cpu_row.mem_addr_segment[channel_index] =
F::from_canonical_usize(segment as usize);
self.current_cpu_row.mem_addr_virtual[channel_index] = F::from_canonical_usize(virt);
self.current_cpu_row.mem_value[channel_index] = u256_limbs(value);
let channel = &mut self.current_cpu_row.mem_channels[channel_index];
channel.used = F::ONE;
channel.is_read = F::ZERO; // For clarity; should already be 0.
channel.addr_context = F::from_canonical_usize(context);
channel.addr_segment = F::from_canonical_usize(segment as usize);
channel.addr_virtual = F::from_canonical_usize(virt);
channel.value = u256_limbs(value);
}
/// Write some memory, and log the operation.