CTL: limbs (CPU) <-> bits (logic) (#577)

* CTL: limbs (CPU) <-> bits (logic)

* Minor: stray TODO mark

* Document Zero op

* Util for constructing an int from bits
This commit is contained in:
Jacqueline Nabaglo 2022-06-25 13:34:04 -07:00 committed by GitHub
parent 46df1bb6b2
commit 912281de9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 115 deletions

View File

@ -87,7 +87,7 @@ mod tests {
add_virtual_all_proof, set_all_proof_target, verify_proof_circuit,
};
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
use crate::util::{limb_from_bits_le, trace_rows_to_poly_values};
use crate::verifier::verify_proof;
use crate::{cpu, keccak, memory};
@ -116,11 +116,11 @@ mod tests {
let mut row = [F::ZERO; logic::columns::NUM_COLUMNS];
assert_eq!(logic::PACKED_LIMB_BITS, 16);
for col in logic::columns::INPUT0_PACKED {
row[col] = F::from_canonical_u16(rng.gen());
for col in logic::columns::INPUT0 {
row[col] = F::from_bool(rng.gen());
}
for col in logic::columns::INPUT1_PACKED {
row[col] = F::from_canonical_u16(rng.gen());
for col in logic::columns::INPUT1 {
row[col] = F::from_bool(rng.gen());
}
let op: usize = rng.gen_range(0..3);
let op_col = [
@ -207,14 +207,19 @@ mod tests {
.map(|(col, opcode)| logic_trace[col].values[i] * F::from_canonical_u64(opcode))
.sum();
for (cols_cpu, cols_logic) in [
(cpu::columns::LOGIC_INPUT0, logic::columns::INPUT0_PACKED),
(cpu::columns::LOGIC_INPUT1, logic::columns::INPUT1_PACKED),
(cpu::columns::LOGIC_OUTPUT, logic::columns::RESULT),
(cpu::columns::LOGIC_INPUT0, logic::columns::INPUT0),
(cpu::columns::LOGIC_INPUT1, logic::columns::INPUT1),
] {
for (col_cpu, col_logic) in cols_cpu.zip(cols_logic) {
row[col_cpu] = logic_trace[col_logic].values[i];
for (col_cpu, limb_cols_logic) in
cols_cpu.zip(logic::columns::limb_bit_cols_for_input(cols_logic))
{
row[col_cpu] =
limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
}
for (col_cpu, col_logic) in cpu::columns::LOGIC_OUTPUT.zip(logic::columns::RESULT) {
row[col_cpu] = logic_trace[col_logic].values[i];
}
cpu_stark.generate(&mut row);
cpu_trace_rows.push(row);
}

View File

@ -9,6 +9,7 @@ use plonky2::hash::hash_types::RichField;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column;
use crate::stark::Stark;
use crate::util::{limb_from_bits_le, limb_from_bits_le_recursive};
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
// Total number of bits per input/output.
@ -27,11 +28,11 @@ pub(crate) mod columns {
pub const IS_AND: usize = 0;
pub const IS_OR: usize = IS_AND + 1;
pub const IS_XOR: usize = IS_OR + 1;
pub const INPUT0_PACKED: Range<usize> = (IS_XOR + 1)..(IS_XOR + 1) + PACKED_LEN;
pub const INPUT1_PACKED: Range<usize> = INPUT0_PACKED.end..INPUT0_PACKED.end + PACKED_LEN;
pub const RESULT: Range<usize> = INPUT1_PACKED.end..INPUT1_PACKED.end + PACKED_LEN;
pub const INPUT0_BITS: Range<usize> = RESULT.end..RESULT.end + VAL_BITS;
pub const INPUT1_BITS: Range<usize> = INPUT0_BITS.end..INPUT0_BITS.end + VAL_BITS;
// The inputs are decomposed into bits.
pub const INPUT0: Range<usize> = (IS_XOR + 1)..(IS_XOR + 1) + VAL_BITS;
pub const INPUT1: Range<usize> = INPUT0.end..INPUT0.end + VAL_BITS;
// The result is packed in limbs of `PACKED_LIMB_BITS` bits.
pub const RESULT: Range<usize> = INPUT1.end..INPUT1.end + PACKED_LEN;
pub fn limb_bit_cols_for_input(input_bits: Range<usize>) -> impl Iterator<Item = Range<usize>> {
(0..PACKED_LEN).map(move |i| {
@ -41,7 +42,7 @@ pub(crate) mod columns {
})
}
pub const NUM_COLUMNS: usize = INPUT1_BITS.end;
pub const NUM_COLUMNS: usize = RESULT.end;
}
pub fn ctl_data<F: Field>() -> Vec<Column<F>> {
@ -50,8 +51,8 @@ pub fn ctl_data<F: Field>() -> Vec<Column<F>> {
Column::single(columns::IS_OR),
Column::single(columns::IS_XOR),
];
res.extend(columns::INPUT0_PACKED.map(Column::single));
res.extend(columns::INPUT1_PACKED.map(Column::single));
res.extend(columns::limb_bit_cols_for_input(columns::INPUT0).map(Column::le_bits));
res.extend(columns::limb_bit_cols_for_input(columns::INPUT1).map(Column::le_bits));
res.extend(columns::RESULT.map(Column::single));
res
}
@ -66,6 +67,9 @@ pub struct LogicStark<F, const D: usize> {
}
enum Op {
// The `Zero` op is just for convenience. The all-zero row already satisfies the constraints;
// `Zero` lets us call `generate` on it without crashing.
Zero,
And,
Or,
Xor,
@ -78,7 +82,7 @@ fn check_op_flags<F: RichField>(lv: &[F; columns::NUM_COLUMNS]) -> Op {
assert!(is_or <= 1);
let is_xor = lv[columns::IS_XOR].to_canonical_u64();
assert!(is_xor <= 1);
assert_eq!(is_and + is_or + is_xor, 1);
assert!(is_and + is_or + is_xor <= 1);
if is_and == 1 {
Op::And
} else if is_or == 1 {
@ -86,54 +90,29 @@ fn check_op_flags<F: RichField>(lv: &[F; columns::NUM_COLUMNS]) -> Op {
} else if is_xor == 1 {
Op::Xor
} else {
panic!("unknown operation")
Op::Zero
}
}
fn check_limb_length<F: RichField>(lv: &[F; columns::NUM_COLUMNS]) {
for (packed_input_cols, bit_cols) in [
(columns::INPUT0_PACKED, columns::INPUT0_BITS),
(columns::INPUT1_PACKED, columns::INPUT1_BITS),
] {
let limb_bit_cols_iter = columns::limb_bit_cols_for_input(bit_cols);
// Not actually reading/writing the bit columns, but this is a convenient way of
// calculating the size of each limb.
for (packed_limb_col, limb_bit_cols) in packed_input_cols.zip(limb_bit_cols_iter) {
let packed_limb = lv[packed_limb_col].to_canonical_u64();
let limb_length = limb_bit_cols.end - limb_bit_cols.start;
assert_eq!(packed_limb >> limb_length, 0);
}
}
}
fn make_bit_decomposition<F: RichField>(lv: &mut [F; columns::NUM_COLUMNS]) {
for (packed_input_cols, bit_cols) in [
(columns::INPUT0_PACKED, columns::INPUT0_BITS),
(columns::INPUT1_PACKED, columns::INPUT1_BITS),
] {
for (i, limb_col) in packed_input_cols.enumerate() {
let limb = lv[limb_col].to_canonical_u64();
let limb_bits_cols = bit_cols
.clone()
.skip(i * PACKED_LIMB_BITS)
.take(PACKED_LIMB_BITS);
for (j, col) in limb_bits_cols.enumerate() {
let bit = (limb >> j) & 1;
lv[col] = F::from_canonical_u64(bit);
}
fn check_bits<F: RichField>(lv: &[F; columns::NUM_COLUMNS]) {
for bit_cols in [columns::INPUT0, columns::INPUT1] {
for bit_col in bit_cols {
let bit = lv[bit_col].to_canonical_u64();
assert!(bit <= 1);
}
}
}
fn make_result<F: RichField>(lv: &mut [F; columns::NUM_COLUMNS], op: Op) {
for (res_col, limb_in0_col, limb_in1_col) in izip!(
for (res_col, limb_in0_cols, limb_in1_cols) in izip!(
columns::RESULT,
columns::INPUT0_PACKED,
columns::INPUT1_PACKED
columns::limb_bit_cols_for_input(columns::INPUT0),
columns::limb_bit_cols_for_input(columns::INPUT1),
) {
let limb_in0 = lv[limb_in0_col].to_canonical_u64();
let limb_in1 = lv[limb_in1_col].to_canonical_u64();
let limb_in0: u64 = limb_from_bits_le(limb_in0_cols.map(|col| lv[col])).to_canonical_u64();
let limb_in1: u64 = limb_from_bits_le(limb_in1_cols.map(|col| lv[col])).to_canonical_u64();
let res = match op {
Op::Zero => 0,
Op::And => limb_in0 & limb_in1,
Op::Or => limb_in0 | limb_in1,
Op::Xor => limb_in0 ^ limb_in1,
@ -145,8 +124,7 @@ fn make_result<F: RichField>(lv: &mut [F; columns::NUM_COLUMNS], op: Op) {
impl<F: RichField, const D: usize> LogicStark<F, D> {
pub fn generate(&self, lv: &mut [F; columns::NUM_COLUMNS]) {
let op = check_op_flags(lv);
check_limb_length(lv);
make_bit_decomposition(lv);
check_bits(lv);
make_result(lv, op);
}
}
@ -178,43 +156,22 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
let and_coeff = is_and - is_or - is_xor * FE::TWO;
// Ensure that all bits are indeed bits.
for input_bits_cols in [columns::INPUT0_BITS, columns::INPUT1_BITS] {
for input_bits_cols in [columns::INPUT0, columns::INPUT1] {
for i in input_bits_cols {
let bit = lv[i];
yield_constr.constraint(bit * (bit - P::ONES));
}
}
// Check that the bits match the packed inputs.
for (input_bits_cols, input_packed_cols) in [
(columns::INPUT0_BITS, columns::INPUT0_PACKED),
(columns::INPUT1_BITS, columns::INPUT1_PACKED),
] {
for (limb_bits_cols, limb_col) in
columns::limb_bit_cols_for_input(input_bits_cols).zip(input_packed_cols)
{
let limb_from_bits: P = limb_bits_cols
.enumerate()
.map(|(i, bit_col)| {
let bit = lv[bit_col];
bit * FE::from_canonical_u64(1 << i)
})
.sum();
let limb = lv[limb_col];
yield_constr.constraint(limb - limb_from_bits);
}
}
// Form the result
for (result_col, x_col, y_col, x_bits_cols, y_bits_cols) in izip!(
for (result_col, x_bits_cols, y_bits_cols) in izip!(
columns::RESULT,
columns::INPUT0_PACKED,
columns::INPUT1_PACKED,
columns::limb_bit_cols_for_input(columns::INPUT0_BITS),
columns::limb_bit_cols_for_input(columns::INPUT1_BITS),
columns::limb_bit_cols_for_input(columns::INPUT0),
columns::limb_bit_cols_for_input(columns::INPUT1),
) {
let x = lv[x_col];
let y = lv[y_col];
let x: P = limb_from_bits_le(x_bits_cols.clone().map(|col| lv[col]));
let y: P = limb_from_bits_le(y_bits_cols.clone().map(|col| lv[col]));
let x_bits = x_bits_cols.map(|i| lv[i]);
let y_bits = y_bits_cols.map(|i| lv[i]);
@ -251,7 +208,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
};
// Ensure that all bits are indeed bits.
for input_bits_cols in [columns::INPUT0_BITS, columns::INPUT1_BITS] {
for input_bits_cols in [columns::INPUT0, columns::INPUT1] {
for i in input_bits_cols {
let bit = lv[i];
let constr = builder.mul_sub_extension(bit, bit, bit);
@ -259,37 +216,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
}
}
// Check that the bits match the packed inputs.
for (input_bits_cols, input_packed_cols) in [
(columns::INPUT0_BITS, columns::INPUT0_PACKED),
(columns::INPUT1_BITS, columns::INPUT1_PACKED),
] {
for (limb_bits_cols, limb_col) in
columns::limb_bit_cols_for_input(input_bits_cols).zip(input_packed_cols)
{
let limb_from_bits = limb_bits_cols.enumerate().fold(
builder.zero_extension(),
|acc, (i, bit_col)| {
let bit = lv[bit_col];
builder.mul_const_add_extension(F::from_canonical_u64(1 << i), bit, acc)
},
);
let limb = lv[limb_col];
let constr = builder.sub_extension(limb, limb_from_bits);
yield_constr.constraint(builder, constr);
}
}
// Form the result
for (result_col, x_col, y_col, x_bits_cols, y_bits_cols) in izip!(
for (result_col, x_bits_cols, y_bits_cols) in izip!(
columns::RESULT,
columns::INPUT0_PACKED,
columns::INPUT1_PACKED,
columns::limb_bit_cols_for_input(columns::INPUT0_BITS),
columns::limb_bit_cols_for_input(columns::INPUT1_BITS),
columns::limb_bit_cols_for_input(columns::INPUT0),
columns::limb_bit_cols_for_input(columns::INPUT1),
) {
let x = lv[x_col];
let y = lv[y_col];
let x = limb_from_bits_le_recursive(builder, x_bits_cols.clone().map(|i| lv[i]));
let y = limb_from_bits_le_recursive(builder, y_bits_cols.clone().map(|i| lv[i]));
let x_bits = x_bits_cols.map(|i| lv[i]);
let y_bits = y_bits_cols.map(|i| lv[i]);

View File

@ -1,8 +1,34 @@
use itertools::Itertools;
use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::Field;
use plonky2::field::packed_field::PackedField;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::util::transpose;
/// Construct an integer from its constituent bits (in little-endian order)
pub fn limb_from_bits_le<P: PackedField>(iter: impl IntoIterator<Item = P>) -> P {
// TODO: This is technically wrong, as 1 << i won't be canonical for all fields...
iter.into_iter()
.enumerate()
.map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i))
.sum()
}
/// Construct an integer from its constituent bits (in little-endian order): recursive edition
pub fn limb_from_bits_le_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
iter: impl IntoIterator<Item = ExtensionTarget<D>>,
) -> ExtensionTarget<D> {
iter.into_iter()
.enumerate()
.fold(builder.zero_extension(), |acc, (i, bit)| {
// TODO: This is technically wrong, as 1 << i won't be canonical for all fields...
builder.mul_const_add_extension(F::from_canonical_u64(1 << i), bit, acc)
})
}
/// A helper function to transpose a row-wise trace and put it in the format that `prove` expects.
pub fn trace_rows_to_poly_values<F: Field, const COLUMNS: usize>(
trace_rows: Vec<[F; COLUMNS]>,