Connect SHL/SHR operations to the Arithmetic table (#1166)

* Add corresponding arithmetic operations to shift ones

* Include SHL/SHR in the arithmetic CTL

* Prevent overflow

* Expand documentation for ctl_data_ternops()
This commit is contained in:
Robin Salen 2023-08-09 09:17:06 -04:00 committed by GitHub
parent df07ae093a
commit 5f4b15af7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 11 deletions

View File

@ -108,7 +108,10 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> { fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new( CrossTableLookup::new(
vec![cpu_stark::ctl_arithmetic_rows()], vec![
cpu_stark::ctl_arithmetic_base_rows(),
cpu_stark::ctl_arithmetic_shift_rows(),
],
arithmetic_stark::ctl_arithmetic_rows(), arithmetic_stark::ctl_arithmetic_rows(),
) )
} }

View File

@ -59,12 +59,23 @@ fn ctl_data_binops<F: Field>(ops: &[usize]) -> Vec<Column<F>> {
} }
/// Create the vector of Columns corresponding to the three inputs and /// Create the vector of Columns corresponding to the three inputs and
/// one output of a ternary operation. /// one output of a ternary operation. By default, ternary operations use
fn ctl_data_ternops<F: Field>(ops: &[usize]) -> Vec<Column<F>> { /// the first three memory channels, and the last one for the result (binary
/// operations do not use the third inputs).
///
/// Shift operations are different, as they are simulated with `MUL` or `DIV`
/// on the arithmetic side. We first convert the shift into the multiplicand
/// (in case of `SHL`) or the divisor (in case of `SHR`), making the first memory
/// channel not directly usable. We overcome this by adding an offset of 1 in
/// case of shift operations, which will skip the first memory channel and use the
/// next three as ternary inputs. Because both `MUL` and `DIV` are binary operations,
/// the last memory channel used for the inputs will be safely ignored.
fn ctl_data_ternops<F: Field>(ops: &[usize], is_shift: bool) -> Vec<Column<F>> {
let offset = is_shift as usize;
let mut res = Column::singles(ops).collect_vec(); let mut res = Column::singles(ops).collect_vec();
res.extend(Column::singles(COL_MAP.mem_channels[0].value)); res.extend(Column::singles(COL_MAP.mem_channels[offset].value));
res.extend(Column::singles(COL_MAP.mem_channels[1].value)); res.extend(Column::singles(COL_MAP.mem_channels[offset + 1].value));
res.extend(Column::singles(COL_MAP.mem_channels[2].value)); res.extend(Column::singles(COL_MAP.mem_channels[offset + 2].value));
res.extend(Column::singles( res.extend(Column::singles(
COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value,
)); ));
@ -79,7 +90,7 @@ pub fn ctl_filter_logic<F: Field>() -> Column<F> {
Column::sum([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]) Column::sum([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor])
} }
pub fn ctl_arithmetic_rows<F: Field>() -> TableWithColumns<F> { pub fn ctl_arithmetic_base_rows<F: Field>() -> TableWithColumns<F> {
const OPS: [usize; 14] = [ const OPS: [usize; 14] = [
COL_MAP.op.add, COL_MAP.op.add,
COL_MAP.op.sub, COL_MAP.op.sub,
@ -101,7 +112,42 @@ pub fn ctl_arithmetic_rows<F: Field>() -> TableWithColumns<F> {
// (also `ops` is used as the operation filter). The list of // (also `ops` is used as the operation filter). The list of
// operations includes binary operations which will simply ignore // operations includes binary operations which will simply ignore
// the third input. // the third input.
TableWithColumns::new(Table::Cpu, ctl_data_ternops(&OPS), Some(Column::sum(OPS))) TableWithColumns::new(
Table::Cpu,
ctl_data_ternops(&OPS, false),
Some(Column::sum(OPS)),
)
}
pub fn ctl_arithmetic_shift_rows<F: Field>() -> TableWithColumns<F> {
const OPS: [usize; 14] = [
COL_MAP.op.add,
COL_MAP.op.sub,
// SHL is interpreted as MUL on the arithmetic side
COL_MAP.op.shl,
COL_MAP.op.lt,
COL_MAP.op.gt,
COL_MAP.op.addfp254,
COL_MAP.op.mulfp254,
COL_MAP.op.subfp254,
COL_MAP.op.addmod,
COL_MAP.op.mulmod,
COL_MAP.op.submod,
// SHR is interpreted as DIV on the arithmetic side
COL_MAP.op.shr,
COL_MAP.op.mod_,
COL_MAP.op.byte,
];
// Create the CPU Table whose columns are those with the three
// inputs and one output of the ternary operations listed in `ops`
// (also `ops` is used as the operation filter). The list of
// operations includes binary operations which will simply ignore
// the third input.
TableWithColumns::new(
Table::Cpu,
ctl_data_ternops(&OPS, true),
Some(Column::sum([COL_MAP.op.shl, COL_MAP.op.shr])),
)
} }
pub const MEM_CODE_CHANNEL_IDX: usize = 0; pub const MEM_CODE_CHANNEL_IDX: usize = 0;

View File

@ -54,7 +54,7 @@ pub(crate) fn eval_packed<P: PackedField>(
// (in the case of left shift) or DIV (in the case of right shift) // (in the case of left shift) or DIV (in the case of right shift)
// in the arithmetic table. Specifically, the mapping is // in the arithmetic table. Specifically, the mapping is
// //
// 0 -> 0 (value to be shifted is the same) // 1 -> 0 (value to be shifted is the same)
// 2 -> 1 (two_exp becomes the multiplicand (resp. divisor)) // 2 -> 1 (two_exp becomes the multiplicand (resp. divisor))
// last -> last (output is the same) // last -> last (output is the same)
} }

View File

@ -3,6 +3,7 @@ use itertools::Itertools;
use keccak_hash::keccak; use keccak_hash::keccak;
use plonky2::field::types::Field; use plonky2::field::types::Field;
use crate::arithmetic::BinaryOperator;
use crate::cpu::columns::CpuColumnsView; use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::assembler::BYTES_PER_OFFSET; use crate::cpu::kernel::assembler::BYTES_PER_OFFSET;
@ -470,6 +471,7 @@ fn append_shift<F: Field>(
state: &mut GenerationState<F>, state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>, mut row: CpuColumnsView<F>,
input0: U256, input0: U256,
input1: U256,
log_in0: MemoryOp, log_in0: MemoryOp,
log_in1: MemoryOp, log_in1: MemoryOp,
result: U256, result: U256,
@ -489,6 +491,20 @@ fn append_shift<F: Field>(
channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt);
} }
// Convert the shift, and log the corresponding arithmetic operation.
let input0 = if input0 > U256::from(255u64) {
U256::zero()
} else {
U256::one() << input0
};
let operator = if row.op.shl.is_one() {
BinaryOperator::Mul
} else {
BinaryOperator::Div
};
let operation = arithmetic::Operation::binary(operator, input1, input0);
state.traces.push_arithmetic(operation);
state.traces.push_memory(log_in0); state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1); state.traces.push_memory(log_in1);
state.traces.push_memory(log_out); state.traces.push_memory(log_out);
@ -508,7 +524,7 @@ pub(crate) fn generate_shl<F: Field>(
} else { } else {
input1 << input0 input1 << input0
}; };
append_shift(state, row, input0, log_in0, log_in1, result) append_shift(state, row, input0, input1, log_in0, log_in1, result)
} }
pub(crate) fn generate_shr<F: Field>( pub(crate) fn generate_shr<F: Field>(
@ -523,7 +539,7 @@ pub(crate) fn generate_shr<F: Field>(
} else { } else {
input1 >> input0 input1 >> input0
}; };
append_shift(state, row, input0, log_in0, log_in1, result) append_shift(state, row, input0, input1, log_in0, log_in1, result)
} }
pub(crate) fn generate_syscall<F: Field>( pub(crate) fn generate_syscall<F: Field>(