Merge pull request #621 from mir-protocol/cpu_shared_cols

Shared CPU columns
This commit is contained in:
Daniel Lubarov 2022-07-28 14:10:34 -07:00 committed by GitHub
commit 431bb5e66e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 166 additions and 63 deletions

View File

@ -130,7 +130,7 @@ mod tests {
use anyhow::Result;
use ethereum_types::U256;
use itertools::{izip, Itertools};
use itertools::Itertools;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::{Field, PrimeField64};
use plonky2::iop::witness::PartialWitness;
@ -247,13 +247,10 @@ mod tests {
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_keccak = F::ONE;
for (j, input, output) in izip!(
0..2 * NUM_INPUTS,
row.keccak_input_limbs.iter_mut(),
row.keccak_output_limbs.iter_mut()
) {
*input = keccak_input_limbs[i][j];
*output = keccak_output_limbs[i][j];
let keccak = row.general.keccak_mut();
for j in 0..2 * NUM_INPUTS {
keccak.input_limbs[j] = keccak_input_limbs[i][j];
keccak.output_limbs[j] = keccak_output_limbs[i][j];
}
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
@ -271,21 +268,22 @@ mod tests {
.into_iter()
.map(|(col, opcode)| logic_trace[col].values[i] * F::from_canonical_u64(opcode))
.sum();
for (cols_cpu, cols_logic) in [
(&mut row.logic_input0, logic::columns::INPUT0),
(&mut row.logic_input1, logic::columns::INPUT1),
] {
for (col_cpu, limb_cols_logic) in cols_cpu
.iter_mut()
.zip(logic::columns::limb_bit_cols_for_input(cols_logic))
{
*col_cpu =
limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
let logic = row.general.logic_mut();
let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0);
for (col_cpu, limb_cols_logic) in logic.input0.iter_mut().zip(input0_bit_cols) {
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
for (col_cpu, col_logic) in row.logic_output.iter_mut().zip(logic::columns::RESULT) {
let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1);
for (col_cpu, limb_cols_logic) in logic.input1.iter_mut().zip(input1_bit_cols) {
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
for (col_cpu, col_logic) in logic.output.iter_mut().zip(logic::columns::RESULT) {
*col_cpu = logic_trace[col_logic].values[i];
}
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}

View File

@ -56,7 +56,8 @@ pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>
}
sponge_state[sponge_input_pos] = packed_bytes;
state.current_cpu_row.keccak_input_limbs = sponge_state.map(F::from_canonical_u32);
let keccak = state.current_cpu_row.general.keccak_mut();
keccak.input_limbs = sponge_state.map(F::from_canonical_u32);
state.commit_cpu_row();
sponge_input_pos = (sponge_input_pos + 1) % KECCAK_RATE_LIMBS;
@ -65,7 +66,8 @@ pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>
if sponge_input_pos == 0 {
state.current_cpu_row.is_keccak = F::ONE;
keccakf_u32s(&mut sponge_state);
state.current_cpu_row.keccak_output_limbs = sponge_state.map(F::from_canonical_u32);
let keccak = state.current_cpu_row.general.keccak_mut();
keccak.output_limbs = sponge_state.map(F::from_canonical_u32);
}
}
}
@ -97,7 +99,7 @@ pub(crate) fn eval_bootstrap_kernel<F: Field, P: PackedField<Scalar = F>>(
for (&expected, actual) in KERNEL
.code_hash
.iter()
.zip(local_values.keccak_output_limbs)
.zip(local_values.general.keccak().output_limbs)
{
let expected = P::from(F::from_canonical_u32(expected));
let diff = expected - actual;
@ -137,7 +139,7 @@ pub(crate) fn eval_bootstrap_kernel_circuit<F: RichField + Extendable<D>, const
for (&expected, actual) in KERNEL
.code_hash
.iter()
.zip(local_values.keccak_output_limbs)
.zip(local_values.general.keccak().output_limbs)
{
let expected = builder.constant_extension(F::Extension::from_canonical_u32(expected));
let diff = builder.sub_extension(expected, actual);

View File

@ -0,0 +1,95 @@
use std::borrow::{Borrow, BorrowMut};
use std::fmt::{Debug, Formatter};
use std::mem::{size_of, transmute};
/// General purpose columns, which can have different meanings depending on what CTL or other
/// operation is occurring at this row.
pub(crate) union CpuGeneralColumnsView<T: Copy> {
keccak: CpuKeccakView<T>,
arithmetic: CpuArithmeticView<T>,
logic: CpuLogicView<T>,
}
impl<T: Copy> CpuGeneralColumnsView<T> {
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn keccak(&self) -> &CpuKeccakView<T> {
unsafe { &self.keccak }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn keccak_mut(&mut self) -> &mut CpuKeccakView<T> {
unsafe { &mut self.keccak }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn arithmetic(&self) -> &CpuArithmeticView<T> {
unsafe { &self.arithmetic }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn arithmetic_mut(&mut self) -> &mut CpuArithmeticView<T> {
unsafe { &mut self.arithmetic }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn logic(&self) -> &CpuLogicView<T> {
unsafe { &self.logic }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn logic_mut(&mut self) -> &mut CpuLogicView<T> {
unsafe { &mut self.logic }
}
}
impl<T: Copy + PartialEq> PartialEq<Self> for CpuGeneralColumnsView<T> {
fn eq(&self, other: &Self) -> bool {
let self_arr: &[T; NUM_SHARED_COLUMNS] = self.borrow();
let other_arr: &[T; NUM_SHARED_COLUMNS] = other.borrow();
self_arr == other_arr
}
}
impl<T: Copy + Eq> Eq for CpuGeneralColumnsView<T> {}
impl<T: Copy + Debug> Debug for CpuGeneralColumnsView<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let self_arr: &[T; NUM_SHARED_COLUMNS] = self.borrow();
Debug::fmt(self_arr, f)
}
}
impl<T: Copy> Borrow<[T; NUM_SHARED_COLUMNS]> for CpuGeneralColumnsView<T> {
fn borrow(&self) -> &[T; NUM_SHARED_COLUMNS] {
unsafe { transmute(self) }
}
}
impl<T: Copy> BorrowMut<[T; NUM_SHARED_COLUMNS]> for CpuGeneralColumnsView<T> {
fn borrow_mut(&mut self) -> &mut [T; NUM_SHARED_COLUMNS] {
unsafe { transmute(self) }
}
}
#[derive(Copy, Clone)]
pub(crate) struct CpuKeccakView<T: Copy> {
pub(crate) input_limbs: [T; 50],
pub(crate) output_limbs: [T; 50],
}
#[derive(Copy, Clone)]
pub(crate) struct CpuArithmeticView<T: Copy> {
// TODO: Add "looking" columns for the arithmetic CTL.
tmp: T, // temporary, to suppress errors
}
#[derive(Copy, Clone)]
pub(crate) struct CpuLogicView<T: Copy> {
// Assuming a limb size of 16 bits. This can be changed, but it must be <= 28 bits.
pub(crate) input0: [T; 16],
pub(crate) input1: [T; 16],
pub(crate) output: [T; 16],
}
// `u8` is guaranteed to have a `size_of` of 1.
pub const NUM_SHARED_COLUMNS: usize = size_of::<CpuGeneralColumnsView<u8>>();

View File

@ -2,14 +2,18 @@
#![allow(dead_code)]
use std::borrow::{Borrow, BorrowMut};
use std::fmt::Debug;
use std::mem::{size_of, transmute, transmute_copy, ManuallyDrop};
use std::ops::{Index, IndexMut};
use crate::cpu::columns::general::CpuGeneralColumnsView;
use crate::memory;
mod general;
#[repr(C)]
#[derive(Eq, PartialEq, Debug)]
pub struct CpuColumnsView<T> {
pub struct CpuColumnsView<T: Copy> {
/// Filter. 1 if the row is part of bootstrapping the kernel code, 0 otherwise.
pub is_bootstrap_kernel: T,
@ -147,14 +151,9 @@ pub struct CpuColumnsView<T> {
/// Filter. 1 iff a Keccak permutation is computed on this row.
pub is_keccak: T,
pub keccak_input_limbs: [T; 50],
pub keccak_output_limbs: [T; 50],
// Assuming a limb size of 16 bits. This can be changed, but it must be <= 28 bits.
// TODO: These input/output columns can be shared between the logic operations and others.
pub logic_input0: [T; 16],
pub logic_input1: [T; 16],
pub logic_output: [T; 16],
pub(crate) general: CpuGeneralColumnsView<T>,
pub simple_logic_diff: T,
pub simple_logic_diff_inv: T,
@ -180,43 +179,43 @@ unsafe fn transmute_no_compile_time_size_checks<T, U>(value: T) -> U {
transmute_copy(&value)
}
impl<T> From<[T; NUM_CPU_COLUMNS]> for CpuColumnsView<T> {
impl<T: Copy> From<[T; NUM_CPU_COLUMNS]> for CpuColumnsView<T> {
fn from(value: [T; NUM_CPU_COLUMNS]) -> Self {
unsafe { transmute_no_compile_time_size_checks(value) }
}
}
impl<T> From<CpuColumnsView<T>> for [T; NUM_CPU_COLUMNS] {
impl<T: Copy> From<CpuColumnsView<T>> for [T; NUM_CPU_COLUMNS] {
fn from(value: CpuColumnsView<T>) -> Self {
unsafe { transmute_no_compile_time_size_checks(value) }
}
}
impl<T> Borrow<CpuColumnsView<T>> for [T; NUM_CPU_COLUMNS] {
impl<T: Copy> Borrow<CpuColumnsView<T>> for [T; NUM_CPU_COLUMNS] {
fn borrow(&self) -> &CpuColumnsView<T> {
unsafe { transmute(self) }
}
}
impl<T> BorrowMut<CpuColumnsView<T>> for [T; NUM_CPU_COLUMNS] {
impl<T: Copy> BorrowMut<CpuColumnsView<T>> for [T; NUM_CPU_COLUMNS] {
fn borrow_mut(&mut self) -> &mut CpuColumnsView<T> {
unsafe { transmute(self) }
}
}
impl<T> Borrow<[T; NUM_CPU_COLUMNS]> for CpuColumnsView<T> {
impl<T: Copy> Borrow<[T; NUM_CPU_COLUMNS]> for CpuColumnsView<T> {
fn borrow(&self) -> &[T; NUM_CPU_COLUMNS] {
unsafe { transmute(self) }
}
}
impl<T> BorrowMut<[T; NUM_CPU_COLUMNS]> for CpuColumnsView<T> {
impl<T: Copy> BorrowMut<[T; NUM_CPU_COLUMNS]> for CpuColumnsView<T> {
fn borrow_mut(&mut self) -> &mut [T; NUM_CPU_COLUMNS] {
unsafe { transmute(self) }
}
}
impl<T, I> Index<I> for CpuColumnsView<T>
impl<T: Copy, I> Index<I> for CpuColumnsView<T>
where
[T]: Index<I>,
{
@ -228,7 +227,7 @@ where
}
}
impl<T, I> IndexMut<I> for CpuColumnsView<T>
impl<T: Copy, I> IndexMut<I> for CpuColumnsView<T>
where
[T]: IndexMut<I>,
{

View File

@ -16,8 +16,9 @@ use crate::stark::Stark;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
pub fn ctl_data_keccak<F: Field>() -> Vec<Column<F>> {
let mut res: Vec<_> = Column::singles(COL_MAP.keccak_input_limbs).collect();
res.extend(Column::singles(COL_MAP.keccak_output_limbs));
let keccak = COL_MAP.general.keccak();
let mut res: Vec<_> = Column::singles(keccak.input_limbs).collect();
res.extend(Column::singles(keccak.output_limbs));
res
}
@ -27,9 +28,10 @@ pub fn ctl_filter_keccak<F: Field>() -> Column<F> {
pub fn ctl_data_logic<F: Field>() -> Vec<Column<F>> {
let mut res = Column::singles([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]).collect_vec();
res.extend(Column::singles(COL_MAP.logic_input0));
res.extend(Column::singles(COL_MAP.logic_input1));
res.extend(Column::singles(COL_MAP.logic_output));
let logic = COL_MAP.general.logic();
res.extend(Column::singles(logic.input0));
res.extend(Column::singles(logic.input1));
res.extend(Column::singles(logic.output));
res
}

View File

@ -9,6 +9,7 @@ use crate::cpu::columns::CpuColumnsView;
const LIMB_SIZE: usize = 16;
pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
let logic = lv.general.logic_mut();
let eq_filter = lv.is_eq.to_canonical_u64();
let iszero_filter = lv.is_iszero.to_canonical_u64();
assert!(eq_filter <= 1);
@ -20,9 +21,10 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
}
let diffs = if eq_filter == 1 {
lv.logic_input0
logic
.input0
.into_iter()
.zip(lv.logic_input1)
.zip(logic.input1)
.map(|(in0, in1)| {
assert_eq!(in0.to_canonical_u64() >> LIMB_SIZE, 0);
assert_eq!(in1.to_canonical_u64() >> LIMB_SIZE, 0);
@ -31,7 +33,7 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
})
.sum()
} else if iszero_filter == 1 {
lv.logic_input0.into_iter().sum()
logic.input0.into_iter().sum()
} else {
panic!()
};
@ -39,8 +41,8 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
lv.simple_logic_diff = diffs;
lv.simple_logic_diff_inv = diffs.try_inverse().unwrap_or(F::ZERO);
lv.logic_output[0] = F::from_bool(diffs == F::ZERO);
for out_limb_ref in lv.logic_output[1..].iter_mut() {
logic.output[0] = F::from_bool(diffs == F::ZERO);
for out_limb_ref in logic.output[1..].iter_mut() {
*out_limb_ref = F::ZERO;
}
}
@ -49,17 +51,18 @@ pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let logic = lv.general.logic();
let eq_filter = lv.is_eq;
let iszero_filter = lv.is_iszero;
let eq_or_iszero_filter = eq_filter + iszero_filter;
let ls_bit = lv.logic_output[0];
let ls_bit = logic.output[0];
// Handle EQ and ISZERO. Most limbs of the output are 0, but the least-significant one is
// either 0 or 1.
yield_constr.constraint(eq_or_iszero_filter * ls_bit * (ls_bit - P::ONES));
for &bit in &lv.logic_output[1..] {
for &bit in &logic.output[1..] {
yield_constr.constraint(eq_or_iszero_filter * bit);
}
@ -67,13 +70,13 @@ pub fn eval_packed<P: PackedField>(
let diffs = lv.simple_logic_diff;
let diffs_inv = lv.simple_logic_diff_inv;
{
let input0_sum: P = lv.logic_input0.into_iter().sum();
let input0_sum: P = logic.input0.into_iter().sum();
yield_constr.constraint(iszero_filter * (diffs - input0_sum));
let sum_squared_diffs: P = lv
.logic_input0
let sum_squared_diffs: P = logic
.input0
.into_iter()
.zip(lv.logic_input1)
.zip(logic.input1)
.map(|(in0, in1)| (in0 - in1).square())
.sum();
yield_constr.constraint(eq_filter * (diffs - sum_squared_diffs));
@ -90,11 +93,12 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let logic = lv.general.logic();
let eq_filter = lv.is_eq;
let iszero_filter = lv.is_iszero;
let eq_or_iszero_filter = builder.add_extension(eq_filter, iszero_filter);
let ls_bit = lv.logic_output[0];
let ls_bit = logic.output[0];
// Handle EQ and ISZERO. Most limbs of the output are 0, but the least-significant one is
// either 0 or 1.
@ -104,7 +108,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
yield_constr.constraint(builder, constr);
}
for &bit in &lv.logic_output[1..] {
for &bit in &logic.output[1..] {
let constr = builder.mul_extension(eq_or_iszero_filter, bit);
yield_constr.constraint(builder, constr);
}
@ -113,14 +117,14 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let diffs = lv.simple_logic_diff;
let diffs_inv = lv.simple_logic_diff_inv;
{
let input0_sum = builder.add_many_extension(lv.logic_input0);
let input0_sum = builder.add_many_extension(logic.input0);
{
let constr = builder.sub_extension(diffs, input0_sum);
let constr = builder.mul_extension(iszero_filter, constr);
yield_constr.constraint(builder, constr);
}
let sum_squared_diffs = lv.logic_input0.into_iter().zip(lv.logic_input1).fold(
let sum_squared_diffs = logic.input0.into_iter().zip(logic.input1).fold(
builder.zero_extension(),
|acc, (in0, in1)| {
let diff = builder.sub_extension(in0, in1);

View File

@ -17,7 +17,8 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
}
assert_eq!(is_not_filter, 1);
for (input, output_ref) in lv.logic_input0.into_iter().zip(lv.logic_output.iter_mut()) {
let logic = lv.general.logic_mut();
for (input, output_ref) in logic.input0.into_iter().zip(logic.output.iter_mut()) {
let input = input.to_canonical_u64();
assert_eq!(input >> LIMB_SIZE, 0);
let output = input ^ ALL_1_LIMB;
@ -30,10 +31,11 @@ pub fn eval_packed<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
// This is simple: just do output = 0xffff - input.
let logic = lv.general.logic();
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.is_not;
let filter = cycle_filter * is_not_filter;
for (input, output) in lv.logic_input0.into_iter().zip(lv.logic_output) {
for (input, output) in logic.input0.into_iter().zip(logic.output) {
yield_constr
.constraint(filter * (output + input - P::Scalar::from_canonical_u64(ALL_1_LIMB)));
}
@ -44,10 +46,11 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let logic = lv.general.logic();
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.is_not;
let filter = builder.mul_extension(cycle_filter, is_not_filter);
for (input, output) in lv.logic_input0.into_iter().zip(lv.logic_output) {
for (input, output) in logic.input0.into_iter().zip(logic.output) {
let constr = builder.add_extension(output, input);
let constr = builder.arithmetic_extension(
F::ONE,