Make jumps, logic, and syscalls read from/write to memory columns (#699)

* Make jumps, logic, and syscalls read from/write to memory columns

* Change CTL convention (outputs precede inputs)

* Change convention so outputs follow inputs in memory channel order
Jacqueline Nabaglo 2022-08-26 16:39:39 -05:00 committed by GitHub
parent 70971aee2d
commit f48de368a9
7 changed files with 137 additions and 133 deletions
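
The net convention, per the final bullet, is that an operation's operands and result now live in the row's shared memory channels, with outputs occupying the channels after the inputs. A minimal standalone sketch of that layout — the channel count, limb count, and names here are illustrative assumptions, not repo code:

const NUM_MEM_CHANNELS: usize = 3; // enough channels for two inputs + one output (assumption)
const LIMBS: usize = 8; // 256-bit values as eight 32-bit limbs

struct Row {
    mem_value: [[u32; LIMBS]; NUM_MEM_CHANNELS],
}

// Outputs follow inputs in memory channel order.
fn fill_binary_op(row: &mut Row, input0: [u32; LIMBS], input1: [u32; LIMBS], output: [u32; LIMBS]) {
    row.mem_value[0] = input0;
    row.mem_value[1] = input1;
    row.mem_value[2] = output;
}

fn main() {
    let mut row = Row { mem_value: [[0; LIMBS]; NUM_MEM_CHANNELS] };
    fill_binary_op(&mut row, [1; LIMBS], [2; LIMBS], [3; LIMBS]);
    assert_eq!(row.mem_value[2], [3; LIMBS]);
}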

View File

@ -318,45 +318,18 @@ mod tests {
cpu_trace_rows.push(row.into());
}
for i in 0..num_logic_rows {
// Pad to `num_memory_ops` for memory testing.
for _ in cpu_trace_rows.len()..num_memory_ops {
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.opcode_bits = bits_from_opcode(0x5b);
row.is_cpu_cycle = F::ONE;
row.is_kernel_mode = F::ONE;
// Since these are the first cycle rows, we must start with PC=route_txn and then increment it.
row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"] + i);
row.opcode_bits = bits_from_opcode(
if logic_trace[logic::columns::IS_AND].values[i] != F::ZERO {
0x16
} else if logic_trace[logic::columns::IS_OR].values[i] != F::ZERO {
0x17
} else if logic_trace[logic::columns::IS_XOR].values[i] != F::ZERO {
0x18
} else {
panic!()
},
);
let logic = row.general.logic_mut();
let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0);
for (col_cpu, limb_cols_logic) in logic.input0.iter_mut().zip(input0_bit_cols) {
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1);
for (col_cpu, limb_cols_logic) in logic.input1.iter_mut().zip(input1_bit_cols) {
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
for (col_cpu, col_logic) in logic.output.iter_mut().zip(logic::columns::RESULT) {
*col_cpu = logic_trace[col_logic].values[i];
}
row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"]);
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
for i in 0..num_memory_ops {
let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i]
.to_canonical_u64()
@ -388,6 +361,44 @@ mod tests {
}
}
for i in 0..num_logic_rows {
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_cpu_cycle = F::ONE;
row.is_kernel_mode = F::ONE;
// Since these are the first cycle rows, we must start with PC=route_txn and then increment it.
row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"] + i);
row.opcode_bits = bits_from_opcode(
if logic_trace[logic::columns::IS_AND].values[i] != F::ZERO {
0x16
} else if logic_trace[logic::columns::IS_OR].values[i] != F::ZERO {
0x17
} else if logic_trace[logic::columns::IS_XOR].values[i] != F::ZERO {
0x18
} else {
panic!()
},
);
let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0);
for (col_cpu, limb_cols_logic) in row.mem_value[0].iter_mut().zip(input0_bit_cols) {
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1);
for (col_cpu, limb_cols_logic) in row.mem_value[1].iter_mut().zip(input1_bit_cols) {
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
for (col_cpu, col_logic) in row.mem_value[2].iter_mut().zip(logic::columns::RESULT) {
*col_cpu = logic_trace[col_logic].values[i];
}
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// Trap to kernel
{
let mut row: cpu::columns::CpuColumnsView<F> =
@ -398,7 +409,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x0a); // `EXP` is implemented in software
row.is_kernel_mode = F::ONE;
row.program_counter = last_row.program_counter + F::ONE;
row.general.syscalls_mut().output = [
row.mem_value[0] = [
row.program_counter,
F::ONE,
F::ZERO,
@ -420,7 +431,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0xf9);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_usize(KERNEL.global_labels["sys_exp"]);
row.general.jumps_mut().input0 = [
row.mem_value[0] = [
F::from_canonical_u16(15682),
F::ONE,
F::ZERO,
@ -442,7 +453,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_u16(15682);
row.general.jumps_mut().input0 = [
row.mem_value[0] = [
F::from_canonical_u16(15106),
F::ZERO,
F::ZERO,
@ -452,7 +463,7 @@ mod tests {
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input1 = [
row.mem_value[1] = [
F::ONE,
F::ZERO,
F::ZERO,
@ -479,7 +490,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0xf9);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_u16(15106);
row.general.jumps_mut().input0 = [
row.mem_value[0] = [
F::from_canonical_u16(63064),
F::ZERO,
F::ZERO,
@ -501,7 +512,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(63064);
row.general.jumps_mut().input0 = [
row.mem_value[0] = [
F::from_canonical_u16(3754),
F::ZERO,
F::ZERO,
@ -511,7 +522,7 @@ mod tests {
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input1 = [
row.mem_value[1] = [
F::ONE,
F::ZERO,
F::ZERO,
@ -539,7 +550,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x57);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(3754);
row.general.jumps_mut().input0 = [
row.mem_value[0] = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
@ -549,7 +560,7 @@ mod tests {
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input1 = [
row.mem_value[1] = [
F::ZERO,
F::ZERO,
F::ZERO,
@ -577,7 +588,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x57);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(37543);
row.general.jumps_mut().input0 = [
row.mem_value[0] = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
@ -606,7 +617,7 @@ mod tests {
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ZERO;
row.program_counter = last_row.program_counter + F::ONE;
row.general.jumps_mut().input0 = [
row.mem_value[0] = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
@ -616,7 +627,7 @@ mod tests {
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input1 = [
row.mem_value[1] = [
F::ONE,
F::ZERO,
F::ZERO,
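
These test rows rely on `bits_from_opcode` and `limb_from_bits_le` to convert between opcode bytes, bit columns, and 32-bit limbs. A hypothetical sketch of the little-endian bit convention a helper like `bits_from_opcode` must follow (plain bools stand in for field elements here):

fn bits_from_opcode(opcode: u8) -> [bool; 8] {
    // Least-significant bit first.
    core::array::from_fn(|i| (opcode >> i) & 1 == 1)
}

fn main() {
    let bits = bits_from_opcode(0x56); // JUMP, 0b01010110
    assert_eq!(bits, [false, true, true, false, true, false, true, false]);
}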

View File

@ -9,7 +9,6 @@ pub(crate) union CpuGeneralColumnsView<T: Copy> {
arithmetic: CpuArithmeticView<T>,
logic: CpuLogicView<T>,
jumps: CpuJumpsView<T>,
syscalls: CpuSyscallsView<T>,
}
impl<T: Copy> CpuGeneralColumnsView<T> {
@ -52,16 +51,6 @@ impl<T: Copy> CpuGeneralColumnsView<T> {
pub(crate) fn jumps_mut(&mut self) -> &mut CpuJumpsView<T> {
unsafe { &mut self.jumps }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn syscalls(&self) -> &CpuSyscallsView<T> {
unsafe { &self.syscalls }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn syscalls_mut(&mut self) -> &mut CpuSyscallsView<T> {
unsafe { &mut self.syscalls }
}
}
impl<T: Copy + PartialEq> PartialEq<Self> for CpuGeneralColumnsView<T> {
@ -107,23 +96,16 @@ pub(crate) struct CpuArithmeticView<T: Copy> {
#[derive(Copy, Clone)]
pub(crate) struct CpuLogicView<T: Copy> {
// Assuming a limb size of 32 bits.
pub(crate) input0: [T; 8],
pub(crate) input1: [T; 8],
pub(crate) output: [T; 8],
// Pseudoinverse of `(input0 - input1)`. Used to prove that they are unequal.
// Pseudoinverse of `(input0 - input1)`. Used to prove that they are unequal. Assumes 32-bit limbs.
pub(crate) diff_pinv: [T; 8],
}
#[derive(Copy, Clone)]
pub(crate) struct CpuJumpsView<T: Copy> {
/// Assuming a limb size of 32 bits.
/// The top stack value at entry (for jumps, the address; for `EXIT_KERNEL`, the address and new
/// privilege level).
pub(crate) input0: [T; 8],
/// For `JUMPI`, the second stack value (the predicate). For `JUMP`, 1.
pub(crate) input1: [T; 8],
/// `input0` is `mem_value[0]`. It's the top stack value at entry (for jumps, the address; for
/// `EXIT_KERNEL`, the address and new privilege level).
/// `input1` is `mem_value[1]`. For `JUMPI`, it's the second stack value (the predicate). For
/// `JUMP`, 1.
/// Inverse of `input0[1] + ... + input0[7]`, if one exists; otherwise, an arbitrary value.
/// Needed to prove that `input0` is nonzero.
@ -162,15 +144,5 @@ pub(crate) struct CpuJumpsView<T: Copy> {
pub(crate) should_trap: T,
}
#[derive(Copy, Clone)]
pub(crate) struct CpuSyscallsView<T: Copy> {
/// Assuming a limb size of 32 bits.
/// The output contains the context that is required to return from the system call in `EXIT_KERNEL`.
/// `output[0]` contains the program counter at the time the system call was made (the address
/// of the syscall instruction). `output[1]` is 1 if we were in kernel mode at the time and 0
/// otherwise. `output[2]`, ..., `output[7]` are zero.
pub(crate) output: [T; 8],
}
// `u8` is guaranteed to have a `size_of` of 1.
pub const NUM_SHARED_COLUMNS: usize = size_of::<CpuGeneralColumnsView<u8>>();
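
The inverse column documented above (`input0_upper_sum_inv` in the jump constraints below) is the standard nonzero-witness pattern: to show a field element `s` is nonzero, the prover supplies `w` and the constraint system checks `s * w == 1`; no such `w` exists when `s == 0`. A toy version over Z_p, with p = 7 chosen purely for illustration:

const P: u64 = 7;

// s^(P-2) mod P, defined only for s != 0 (Fermat's little theorem).
fn inv(s: u64) -> Option<u64> {
    if s % P == 0 {
        return None;
    }
    let mut acc = 1u64;
    for _ in 0..(P - 2) {
        acc = acc * s % P;
    }
    Some(acc)
}

fn main() {
    let s = 3u64; // the sum the prover claims is nonzero
    let w = inv(s).expect("no witness exists when s == 0");
    assert_eq!(s * w % P, 1); // the product the constraint system checks
}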

View File

@ -48,10 +48,9 @@ pub fn ctl_filter_keccak_memory<F: Field>() -> Column<F> {
pub fn ctl_data_logic<F: Field>() -> Vec<Column<F>> {
let mut res = Column::singles([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]).collect_vec();
let logic = COL_MAP.general.logic();
res.extend(Column::singles(logic.input0));
res.extend(Column::singles(logic.input1));
res.extend(Column::singles(logic.output));
res.extend(Column::singles(COL_MAP.mem_value[0]));
res.extend(Column::singles(COL_MAP.mem_value[1]));
res.extend(Column::singles(COL_MAP.mem_value[2]));
res
}

View File

@ -17,16 +17,16 @@ pub fn eval_packed_exit_kernel<P: PackedField>(
nv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let jumps_lv = lv.general.jumps();
let input = lv.mem_value[0];
// If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode
// flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the
// kernel to set them to zero).
yield_constr.constraint_transition(
lv.is_cpu_cycle * lv.is_exit_kernel * (jumps_lv.input0[0] - nv.program_counter),
lv.is_cpu_cycle * lv.is_exit_kernel * (input[0] - nv.program_counter),
);
yield_constr.constraint_transition(
lv.is_cpu_cycle * lv.is_exit_kernel * (jumps_lv.input0[1] - nv.is_kernel_mode),
lv.is_cpu_cycle * lv.is_exit_kernel * (input[1] - nv.is_kernel_mode),
);
}
@ -36,18 +36,18 @@ pub fn eval_ext_circuit_exit_kernel<F: RichField + Extendable<D>, const D: usize
nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let jumps_lv = lv.general.jumps();
let input = lv.mem_value[0];
let filter = builder.mul_extension(lv.is_cpu_cycle, lv.is_exit_kernel);
// If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode
// flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the
// kernel to set them to zero).
let pc_constr = builder.sub_extension(jumps_lv.input0[0], nv.program_counter);
let pc_constr = builder.sub_extension(input[0], nv.program_counter);
let pc_constr = builder.mul_extension(filter, pc_constr);
yield_constr.constraint_transition(builder, pc_constr);
let kernel_constr = builder.sub_extension(jumps_lv.input0[1], nv.is_kernel_mode);
let kernel_constr = builder.sub_extension(input[1], nv.is_kernel_mode);
let kernel_constr = builder.mul_extension(filter, kernel_constr);
yield_constr.constraint_transition(builder, kernel_constr);
}
@ -58,12 +58,14 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
let jumps_lv = lv.general.jumps();
let input0 = lv.mem_value[0];
let input1 = lv.mem_value[1];
let filter = lv.is_jump + lv.is_jumpi; // `JUMP` or `JUMPI`
// If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
// In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`.
yield_constr.constraint(lv.is_jump * (jumps_lv.input1[0] - P::ONES));
for &limb in &jumps_lv.input1[1..] {
yield_constr.constraint(lv.is_jump * (input1[0] - P::ONES));
for &limb in &input1[1..] {
// Set all limbs (other than the least-significant limb) to 0.
// NB: Technically, they don't have to be 0, as long as the sum
// `input1[0] + ... + input1[7]` cannot overflow.
@ -75,7 +77,7 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
yield_constr
.constraint(filter * jumps_lv.input0_upper_zero * (jumps_lv.input0_upper_zero - P::ONES));
// The below sum cannot overflow due to the limb size.
let input0_upper_sum: P = jumps_lv.input0[1..].iter().copied().sum();
let input0_upper_sum: P = input0[1..].iter().copied().sum();
// `input0_upper_zero` = 1 implies `input0_upper_sum` = 0.
yield_constr.constraint(filter * jumps_lv.input0_upper_zero * input0_upper_sum);
// `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can only
@ -113,7 +115,7 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
// Validate `should_continue`
// This sum cannot overflow (due to limb size).
let input1_sum: P = jumps_lv.input1.into_iter().sum();
let input1_sum: P = input1.into_iter().sum();
// `should_continue` = 1 implies `input1_sum` = 0.
yield_constr.constraint(filter * jumps_lv.should_continue * input1_sum);
// `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if
@ -147,9 +149,8 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
yield_constr.constraint_transition(
filter * jumps_lv.should_continue * (nv.program_counter - lv.program_counter - P::ONES),
);
yield_constr.constraint_transition(
filter * jumps_lv.should_jump * (nv.program_counter - jumps_lv.input0[0]),
);
yield_constr
.constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - input0[0]));
}
pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>(
@ -159,15 +160,17 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let jumps_lv = lv.general.jumps();
let input0 = lv.mem_value[0];
let input1 = lv.mem_value[1];
let filter = builder.add_extension(lv.is_jump, lv.is_jumpi); // `JUMP` or `JUMPI`
// If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
// In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`.
{
let constr = builder.mul_sub_extension(lv.is_jump, jumps_lv.input1[0], lv.is_jump);
let constr = builder.mul_sub_extension(lv.is_jump, input1[0], lv.is_jump);
yield_constr.constraint(builder, constr);
}
for &limb in &jumps_lv.input1[1..] {
for &limb in &input1[1..] {
// Set all limbs (other than the least-significant limb) to 0.
// NB: Technically, they don't have to be 0, as long as the sum
// `input1[0] + ... + input1[7]` cannot overflow.
@ -188,7 +191,7 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
}
{
// The below sum cannot overflow due to the limb size.
let input0_upper_sum = builder.add_many_extension(jumps_lv.input0[1..].iter());
let input0_upper_sum = builder.add_many_extension(input0[1..].iter());
// `input0_upper_zero` = 1 implies `input0_upper_sum` = 0.
let constr = builder.mul_extension(jumps_lv.input0_upper_zero, input0_upper_sum);
@ -251,7 +254,7 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
// Validate `should_continue`
{
// This sum cannot overflow (due to limb size).
let input1_sum = builder.add_many_extension(jumps_lv.input1.into_iter());
let input1_sum = builder.add_many_extension(input1.into_iter());
// `should_continue` = 1 implies `input1_sum` = 0.
let constr = builder.mul_extension(jumps_lv.should_continue, input1_sum);
@ -326,7 +329,7 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
}
// ...or jumping.
{
let constr = builder.sub_extension(nv.program_counter, jumps_lv.input0[0]);
let constr = builder.sub_extension(nv.program_counter, input0[0]);
let constr = builder.mul_extension(jumps_lv.should_jump, constr);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint_transition(builder, constr);
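
The jump constraints above encode the following semantics: `JUMP(addr)` is exactly `JUMPI(addr, cond=1)`, and `JUMPI` falls through to `pc + 1` when the predicate is zero. An interpreter-style sketch of those semantics (hypothetical helper, not the STARK itself):

fn next_pc(pc: u64, dst: u64, predicate: u64, is_jump: bool) -> u64 {
    let predicate = if is_jump { 1 } else { predicate }; // JUMP => cond = 1
    if predicate != 0 { dst } else { pc + 1 } // jump, or fall through
}

fn main() {
    assert_eq!(next_pc(10, 42, 0, true), 42);  // JUMP always jumps
    assert_eq!(next_pc(10, 42, 0, false), 11); // JUMPI falls through on 0
    assert_eq!(next_pc(10, 42, 5, false), 42); // any nonzero predicate jumps
}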

View File

@ -8,7 +8,8 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::CpuColumnsView;
pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
let logic = lv.general.logic_mut();
let input0 = lv.mem_value[0];
let eq_filter = lv.is_eq.to_canonical_u64();
let iszero_filter = lv.is_iszero.to_canonical_u64();
assert!(eq_filter <= 1);
@ -19,19 +20,22 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
return;
}
let input1 = &mut lv.mem_value[1];
if iszero_filter != 0 {
for limb in logic.input1.iter_mut() {
for limb in input1.iter_mut() {
*limb = F::ZERO;
}
}
let num_unequal_limbs = izip!(logic.input0, logic.input1)
let input1 = lv.mem_value[1];
let num_unequal_limbs = izip!(input0, input1)
.map(|(limb0, limb1)| (limb0 != limb1) as usize)
.sum();
let equal = num_unequal_limbs == 0;
logic.output[0] = F::from_bool(equal);
for limb in &mut logic.output[1..] {
let output = &mut lv.mem_value[2];
output[0] = F::from_bool(equal);
for limb in &mut output[1..] {
*limb = F::ZERO;
}
@ -40,10 +44,11 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
// Then `diff @ x = num_unequal_limbs`, where `@` denotes the dot product. We set
// `diff_pinv = num_unequal_limbs^-1 * x` if `num_unequal_limbs != 0` and 0 otherwise. We have
// `diff @ diff_pinv = 1 - equal` as desired.
let logic = lv.general.logic_mut();
let num_unequal_limbs_inv = F::from_canonical_usize(num_unequal_limbs)
.try_inverse()
.unwrap_or(F::ZERO);
for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), logic.input0, logic.input1) {
for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), input0, input1) {
*limb_pinv = (limb0 - limb1).try_inverse().unwrap_or(F::ZERO) * num_unequal_limbs_inv;
}
}
@ -53,27 +58,31 @@ pub fn eval_packed<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
let logic = lv.general.logic();
let input0 = lv.mem_value[0];
let input1 = lv.mem_value[1];
let output = lv.mem_value[2];
let eq_filter = lv.is_eq;
let iszero_filter = lv.is_iszero;
let eq_or_iszero_filter = eq_filter + iszero_filter;
let equal = logic.output[0];
let equal = output[0];
let unequal = P::ONES - equal;
// Handle `EQ` and `ISZERO`. Most limbs of the output are 0, but the least-significant one is
// either 0 or 1.
yield_constr.constraint(eq_or_iszero_filter * equal * unequal);
for &limb in &logic.output[1..] {
for &limb in &output[1..] {
yield_constr.constraint(eq_or_iszero_filter * limb);
}
// If `ISZERO`, constrain input1 to be zero, effectively implementing ISZERO(x) as EQ(x, 0).
for limb in logic.input1 {
for limb in input1 {
yield_constr.constraint(iszero_filter * limb);
}
// `equal` implies `input0[i] == input1[i]` for all `i`.
for (limb0, limb1) in izip!(logic.input0, logic.input1) {
for (limb0, limb1) in izip!(input0, input1) {
let diff = limb0 - limb1;
yield_constr.constraint(eq_or_iszero_filter * equal * diff);
}
@ -82,7 +91,7 @@ pub fn eval_packed<P: PackedField>(
// If `unequal`, find `diff_pinv` such that `(input0 - input1) @ diff_pinv == 1`, where `@`
// denotes the dot product (there will be many such `diff_pinv`). This can only be done if
// `input0 != input1`.
let dot: P = izip!(logic.input0, logic.input1, logic.diff_pinv)
let dot: P = izip!(input0, input1, logic.diff_pinv)
.map(|(limb0, limb1, diff_pinv_el)| (limb0 - limb1) * diff_pinv_el)
.sum();
yield_constr.constraint(eq_or_iszero_filter * (dot - unequal));
@ -97,11 +106,15 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let one = builder.one_extension();
let logic = lv.general.logic();
let input0 = lv.mem_value[0];
let input1 = lv.mem_value[1];
let output = lv.mem_value[2];
let eq_filter = lv.is_eq;
let iszero_filter = lv.is_iszero;
let eq_or_iszero_filter = builder.add_extension(eq_filter, iszero_filter);
let equal = logic.output[0];
let equal = output[0];
let unequal = builder.sub_extension(one, equal);
// Handle `EQ` and `ISZERO`. Most limbs of the output are 0, but the least-significant one is
@ -111,19 +124,19 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let constr = builder.mul_extension(eq_or_iszero_filter, constr);
yield_constr.constraint(builder, constr);
}
for &limb in &logic.output[1..] {
for &limb in &output[1..] {
let constr = builder.mul_extension(eq_or_iszero_filter, limb);
yield_constr.constraint(builder, constr);
}
// If `ISZERO`, constrain input1 to be zero, effectively implementing ISZERO(x) as EQ(x, 0).
for limb in logic.input1 {
for limb in input1 {
let constr = builder.mul_extension(iszero_filter, limb);
yield_constr.constraint(builder, constr);
}
// `equal` implies `input0[i] == input1[i]` for all `i`.
for (limb0, limb1) in izip!(logic.input0, logic.input1) {
for (limb0, limb1) in izip!(input0, input1) {
let diff = builder.sub_extension(limb0, limb1);
let constr = builder.mul_extension(equal, diff);
let constr = builder.mul_extension(eq_or_iszero_filter, constr);
@ -135,7 +148,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
// denotes the dot product (there will be many such `diff_pinv`). This can only be done if
// `input0 != input1`.
{
let dot: ExtensionTarget<D> = izip!(logic.input0, logic.input1, logic.diff_pinv).fold(
let dot: ExtensionTarget<D> = izip!(input0, input1, logic.diff_pinv).fold(
zero,
|cumul, (limb0, limb1, diff_pinv_el)| {
let diff = builder.sub_extension(limb0, limb1);
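
The `diff_pinv` trick above can be checked numerically: take `x_i = (input0_i - input1_i)^-1` where the limbs differ and 0 elsewhere, so `diff @ x` equals the number of unequal limbs; scaling `x` by the inverse of that count gives a dot product of exactly 1 whenever the inputs differ. A toy check mod a small prime (p = 101 is an assumption for illustration):

const P: u64 = 101;

// a^(P-2) mod P by repeated multiplication; P is tiny, so this is fine.
fn inv(a: u64) -> u64 {
    let mut acc = 1u64;
    for _ in 0..(P - 2) {
        acc = acc * a % P;
    }
    acc
}

fn main() {
    let input0 = [3u64, 5, 7, 9];
    let input1 = [3u64, 6, 7, 10]; // limbs 1 and 3 differ
    let diff: Vec<u64> = input0.iter().zip(&input1).map(|(&a, &b)| (a + P - b) % P).collect();
    let num_unequal = diff.iter().filter(|&&d| d != 0).count() as u64;
    let scale = inv(num_unequal); // num_unequal != 0 here, so this inverse exists
    let diff_pinv: Vec<u64> = diff.iter().map(|&d| if d == 0 { 0 } else { inv(d) * scale % P }).collect();
    let dot = diff.iter().zip(&diff_pinv).map(|(&d, &x)| d * x % P).sum::<u64>() % P;
    assert_eq!(dot, 1); // dot == 1 - equal, and these inputs are unequal
}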

View File

@ -17,8 +17,9 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
}
assert_eq!(is_not_filter, 1);
let logic = lv.general.logic_mut();
for (input, output_ref) in logic.input0.into_iter().zip(logic.output.iter_mut()) {
let input = lv.mem_value[0];
let output = &mut lv.mem_value[1];
for (input, output_ref) in input.into_iter().zip(output.iter_mut()) {
let input = input.to_canonical_u64();
assert_eq!(input >> LIMB_SIZE, 0);
let output = input ^ ALL_1_LIMB;
@ -30,14 +31,16 @@ pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
// This is simple: just do output = 0xffff - input.
let logic = lv.general.logic();
// This is simple: just do output = 0xffffffff - input.
let input = lv.mem_value[0];
let output = lv.mem_value[1];
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.is_not;
let filter = cycle_filter * is_not_filter;
for (input, output) in logic.input0.into_iter().zip(logic.output) {
yield_constr
.constraint(filter * (output + input - P::Scalar::from_canonical_u64(ALL_1_LIMB)));
for (input_limb, output_limb) in input.into_iter().zip(output) {
yield_constr.constraint(
filter * (output_limb + input_limb - P::Scalar::from_canonical_u64(ALL_1_LIMB)),
);
}
}
@ -46,12 +49,13 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let logic = lv.general.logic();
let input = lv.mem_value[0];
let output = lv.mem_value[1];
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.is_not;
let filter = builder.mul_extension(cycle_filter, is_not_filter);
for (input, output) in logic.input0.into_iter().zip(logic.output) {
let constr = builder.add_extension(output, input);
for (input_limb, output_limb) in input.into_iter().zip(output) {
let constr = builder.add_extension(output_limb, input_limb);
let constr = builder.arithmetic_extension(
F::ONE,
-F::from_canonical_u64(ALL_1_LIMB),
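
The constraint `output + input == ALL_1_LIMB` is sound because, for any value that fits in the limb, bitwise complement within the limb coincides with subtraction from the all-ones limb — which is exactly what `generate` asserts with `input >> LIMB_SIZE == 0`. A quick check of that identity, assuming 32-bit limbs as in the updated comment:

const LIMB_SIZE: u32 = 32; // assumption: 32-bit limbs, matching the comment above
const ALL_1_LIMB: u64 = (1u64 << LIMB_SIZE) - 1;

fn main() {
    for input in [0u64, 1, 0xdead_beef, ALL_1_LIMB] {
        assert_eq!(input >> LIMB_SIZE, 0); // the value must fit in the limb
        assert_eq!(input ^ ALL_1_LIMB, ALL_1_LIMB - input);
    }
}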

View File

@ -28,7 +28,6 @@ pub fn eval_packed<P: PackedField>(
nv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let lv_syscalls = lv.general.syscalls();
let syscall_list = Lazy::force(&TRAP_LIST);
// 1 if _any_ syscall, else 0.
let should_syscall: P = syscall_list
@ -48,12 +47,14 @@ pub fn eval_packed<P: PackedField>(
yield_constr.constraint_transition(filter * (nv.program_counter - syscall_dst));
// If syscall: set kernel mode
yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES));
let output = lv.mem_value[0];
// If syscall: push current PC to stack
yield_constr.constraint(filter * (lv_syscalls.output[0] - lv.program_counter));
yield_constr.constraint(filter * (output[0] - lv.program_counter));
// If syscall: push current kernel flag to stack (share register with PC)
yield_constr.constraint(filter * (lv_syscalls.output[1] - lv.is_kernel_mode));
yield_constr.constraint(filter * (output[1] - lv.is_kernel_mode));
// If syscall: zero the rest of that register
for &limb in &lv_syscalls.output[2..] {
for &limb in &output[2..] {
yield_constr.constraint(filter * limb);
}
}
@ -64,7 +65,6 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let lv_syscalls = lv.general.syscalls();
let syscall_list = Lazy::force(&TRAP_LIST);
// 1 if _any_ syscall, else 0.
let should_syscall =
@ -90,20 +90,22 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter);
yield_constr.constraint_transition(builder, constr);
}
let output = lv.mem_value[0];
// If syscall: push current PC to stack
{
let constr = builder.sub_extension(lv_syscalls.output[0], lv.program_counter);
let constr = builder.sub_extension(output[0], lv.program_counter);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
// If syscall: push current kernel flag to stack (share register with PC)
{
let constr = builder.sub_extension(lv_syscalls.output[1], lv.is_kernel_mode);
let constr = builder.sub_extension(output[1], lv.is_kernel_mode);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
// If syscall: zero the rest of that register
for &limb in &lv_syscalls.output[2..] {
for &limb in &output[2..] {
let constr = builder.mul_extension(filter, limb);
yield_constr.constraint(builder, constr);
}
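
Taken together with the `EXIT_KERNEL` constraints earlier in the diff: a syscall packs its return context into one memory channel (limb 0 holds the PC, limb 1 the kernel flag, the rest are constrained to zero), and `EXIT_KERNEL` later restores both. A sketch with hypothetical helpers, not repo code:

fn syscall_context(pc: u64, is_kernel_mode: bool) -> [u64; 8] {
    let mut ctx = [0u64; 8];
    ctx[0] = pc; // program counter at the time of the syscall
    ctx[1] = is_kernel_mode as u64; // privilege level at the time of the syscall
    ctx // limbs 2..8 stay zero, matching the `output[2..]` constraints
}

fn exit_kernel(ctx: [u64; 8]) -> (u64, bool) {
    // Restore the PC and privilege level saved at syscall time.
    (ctx[0], ctx[1] != 0)
}

fn main() {
    let ctx = syscall_context(15682, true);
    assert_eq!(exit_kernel(ctx), (15682, true));
}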