Cross-table lookup for arithmetic Stark (#905)

* First draft of linking arithmetic Stark into the CTL mechanism.

* Handle {ADD,SUB,MUL}FP254 operations explicitly in `modular.rs`.

* Adjust argument order; add tests.

* Add CTLs for ADD, MUL, SUB, LT and GT.

* Add CTLs for {ADD,MUL,SUB}MOD, DIV and MOD.

* Add CTLs for {ADD,MUL,SUB}FP254 operations.

* Refactor the CPU/arithmetic CTL mapping; add some documentation.

* Minor comment fixes.

* Combine addcy CTLs at the expense of repeated constraint evaluation.

* Merge `*FP254` CTL into main CTL; rename some registers.

* Connect extra argument from CPU in binary ops to facilitate combining with ternary ops.

* Merge modular ops CTL into main CTL.

* Refactor DIV and MOD code into its own module.

* Merge DIV and MOD into arithmetic CTL.

* Clippy.

* Fixes related to merge.

* Simplify register naming.

* Generate u16 BN254 modulus limbs at compile time.

* Clippy.

* Add degree bits ranges for Arithmetic table.
Hamish Ivey-Law 2023-05-11 03:29:06 +10:00 committed by GitHub
parent 779456c2c9
commit c134b59763
15 changed files with 894 additions and 305 deletions

View File

@@ -4,6 +4,8 @@ use plonky2::field::extension::Extendable;
 use plonky2::field::types::Field;
 use plonky2::hash::hash_types::RichField;
+use crate::arithmetic::arithmetic_stark;
+use crate::arithmetic::arithmetic_stark::ArithmeticStark;
 use crate::config::StarkConfig;
 use crate::cpu::cpu_stark;
 use crate::cpu::cpu_stark::CpuStark;
@@ -22,6 +24,7 @@ use crate::stark::Stark;
 #[derive(Clone)]
 pub struct AllStark<F: RichField + Extendable<D>, const D: usize> {
+    pub arithmetic_stark: ArithmeticStark<F, D>,
     pub cpu_stark: CpuStark<F, D>,
     pub keccak_stark: KeccakStark<F, D>,
     pub keccak_sponge_stark: KeccakSpongeStark<F, D>,
@@ -33,6 +36,7 @@ pub struct AllStark<F: RichField + Extendable<D>, const D: usize> {
 impl<F: RichField + Extendable<D>, const D: usize> Default for AllStark<F, D> {
     fn default() -> Self {
         Self {
+            arithmetic_stark: ArithmeticStark::default(),
             cpu_stark: CpuStark::default(),
             keccak_stark: KeccakStark::default(),
             keccak_sponge_stark: KeccakSpongeStark::default(),
@@ -46,6 +50,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Default for AllStark<F, D> {
 impl<F: RichField + Extendable<D>, const D: usize> AllStark<F, D> {
     pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> [usize; NUM_TABLES] {
         [
+            self.arithmetic_stark.num_permutation_batches(config),
             self.cpu_stark.num_permutation_batches(config),
             self.keccak_stark.num_permutation_batches(config),
             self.keccak_sponge_stark.num_permutation_batches(config),
@@ -56,6 +61,7 @@ impl<F: RichField + Extendable<D>, const D: usize> AllStark<F, D> {
     pub(crate) fn permutation_batch_sizes(&self) -> [usize; NUM_TABLES] {
         [
+            self.arithmetic_stark.permutation_batch_size(),
             self.cpu_stark.permutation_batch_size(),
             self.keccak_stark.permutation_batch_size(),
             self.keccak_sponge_stark.permutation_batch_size(),
@@ -67,11 +73,12 @@ impl<F: RichField + Extendable<D>, const D: usize> AllStark<F, D> {
 #[derive(Debug, Copy, Clone, Eq, PartialEq)]
 pub enum Table {
-    Cpu = 0,
-    Keccak = 1,
-    KeccakSponge = 2,
-    Logic = 3,
-    Memory = 4,
+    Arithmetic = 0,
+    Cpu = 1,
+    Keccak = 2,
+    KeccakSponge = 3,
+    Logic = 4,
+    Memory = 5,
 }
 pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1;
@@ -79,6 +86,7 @@ pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1;
 impl Table {
     pub(crate) fn all() -> [Self; NUM_TABLES] {
         [
+            Self::Arithmetic,
             Self::Cpu,
             Self::Keccak,
             Self::KeccakSponge,
@@ -89,9 +97,15 @@ impl Table {
     }
     pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
-        let mut ctls = vec![ctl_keccak_sponge(), ctl_keccak(), ctl_logic(), ctl_memory()];
+        let mut ctls = vec![
+            ctl_arithmetic(),
+            ctl_keccak_sponge(),
+            ctl_keccak(),
+            ctl_logic(),
+            ctl_memory(),
+        ];
         // TODO: Some CTLs temporarily disabled while we get them working.
-        disable_ctl(&mut ctls[3]);
+        disable_ctl(&mut ctls[4]);
         ctls
     }
@@ -102,6 +116,13 @@ fn disable_ctl<F: Field>(ctl: &mut CrossTableLookup<F>) {
     ctl.looked_table.filter_column = Some(Column::zero());
 }
+fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> {
+    CrossTableLookup::new(
+        vec![cpu_stark::ctl_arithmetic_rows()],
+        arithmetic_stark::ctl_arithmetic_rows(),
+    )
+}
 fn ctl_keccak<F: Field>() -> CrossTableLookup<F> {
     let keccak_sponge_looking = TableWithColumns::new(
         Table::KeccakSponge,
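
A side effect of the renumbering above: every existing table shifts up by one and NUM_TABLES becomes 6, which is why the temporarily disabled memory CTL is now ctls[4] instead of ctls[3]. A minimal sketch of the invariants this implies (illustrative test, not part of the commit):

    #[test]
    fn table_indices_after_adding_arithmetic() {
        // Arithmetic takes slot 0 and everything else shifts up by one.
        assert_eq!(Table::Arithmetic as usize, 0);
        assert_eq!(Table::Cpu as usize, 1);
        assert_eq!(Table::Memory as usize, 5);
        assert_eq!(NUM_TABLES, 6);
        // all_cross_table_lookups() now pushes ctl_arithmetic() first, so the
        // memory CTL sits at index 4 rather than 3.
    }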

View File

@@ -28,68 +28,41 @@ use crate::arithmetic::utils::u256_to_array;
 use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
 /// Generate row for ADD, SUB, GT and LT operations.
-///
-/// A row consists of four values, GENERAL_REGISTER_[012] and
-/// GENERAL_REGISTER_BIT. The interpretation of these values for each
-/// operation is as follows:
-///
-/// ADD: REGISTER_0 + REGISTER_1, output in REGISTER_2, ignore REGISTER_BIT
-/// SUB: REGISTER_2 - REGISTER_0, output in REGISTER_1, ignore REGISTER_BIT
-/// GT: REGISTER_0 > REGISTER_2, output in REGISTER_BIT, auxiliary output in REGISTER_1
-/// LT: REGISTER_2 < REGISTER_0, output in REGISTER_BIT, auxiliary output in REGISTER_1
 pub(crate) fn generate<F: PrimeField64>(
     lv: &mut [F],
     filter: usize,
     left_in: U256,
     right_in: U256,
 ) {
-    // Swap left_in and right_in for LT
-    let (left_in, right_in) = if filter == IS_LT {
-        (right_in, left_in)
-    } else {
-        (left_in, right_in)
-    };
+    u256_to_array(&mut lv[INPUT_REGISTER_0], left_in);
+    u256_to_array(&mut lv[INPUT_REGISTER_1], right_in);
+    u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero());
     match filter {
         IS_ADD => {
-            // x + y == z + cy*2^256
             let (result, cy) = left_in.overflowing_add(right_in);
-            u256_to_array(&mut lv[GENERAL_REGISTER_0], left_in); // x
-            u256_to_array(&mut lv[GENERAL_REGISTER_1], right_in); // y
-            u256_to_array(&mut lv[GENERAL_REGISTER_2], result); // z
-            lv[GENERAL_REGISTER_BIT] = F::from_bool(cy);
+            u256_to_array(&mut lv[AUX_INPUT_REGISTER_0], U256::from(cy as u32));
+            u256_to_array(&mut lv[OUTPUT_REGISTER], result);
         }
-        IS_SUB | IS_GT | IS_LT => {
-            // y == z - x + cy*2^256
-            let (diff, cy) = right_in.overflowing_sub(left_in);
-            u256_to_array(&mut lv[GENERAL_REGISTER_0], left_in); // x
-            u256_to_array(&mut lv[GENERAL_REGISTER_2], right_in); // z
-            u256_to_array(&mut lv[GENERAL_REGISTER_1], diff); // y
-            lv[GENERAL_REGISTER_BIT] = F::from_bool(cy);
+        IS_SUB => {
+            let (diff, cy) = left_in.overflowing_sub(right_in);
+            u256_to_array(&mut lv[AUX_INPUT_REGISTER_0], U256::from(cy as u32));
+            u256_to_array(&mut lv[OUTPUT_REGISTER], diff);
+        }
+        IS_LT => {
+            let (diff, cy) = left_in.overflowing_sub(right_in);
+            u256_to_array(&mut lv[AUX_INPUT_REGISTER_0], diff);
+            u256_to_array(&mut lv[OUTPUT_REGISTER], U256::from(cy as u32));
+        }
+        IS_GT => {
+            let (diff, cy) = right_in.overflowing_sub(left_in);
+            u256_to_array(&mut lv[AUX_INPUT_REGISTER_0], diff);
+            u256_to_array(&mut lv[OUTPUT_REGISTER], U256::from(cy as u32));
         }
         _ => panic!("unexpected operation filter"),
     };
 }
-fn eval_packed_generic_check_is_one_bit<P: PackedField>(
-    yield_constr: &mut ConstraintConsumer<P>,
-    filter: P,
-    x: P,
-) {
-    yield_constr.constraint(filter * x * (x - P::ONES));
-}
-fn eval_ext_circuit_check_is_one_bit<F: RichField + Extendable<D>, const D: usize>(
-    builder: &mut CircuitBuilder<F, D>,
-    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
-    filter: ExtensionTarget<D>,
-    x: ExtensionTarget<D>,
-) {
-    let constr = builder.mul_sub_extension(x, x, x);
-    let filtered_constr = builder.mul_extension(filter, constr);
-    yield_constr.constraint(builder, filtered_constr);
-}
 /// 2^-16 mod (2^64 - 2^32 + 1)
 const GOLDILOCKS_INVERSE_65536: u64 = 18446462594437939201;
@@ -126,10 +99,12 @@ pub(crate) fn eval_packed_generic_addcy<P: PackedField>(
     x: &[P],
     y: &[P],
     z: &[P],
-    given_cy: P,
+    given_cy: &[P],
     is_two_row_op: bool,
 ) {
-    debug_assert!(x.len() == N_LIMBS && y.len() == N_LIMBS && z.len() == N_LIMBS);
+    debug_assert!(
+        x.len() == N_LIMBS && y.len() == N_LIMBS && z.len() == N_LIMBS && given_cy.len() == N_LIMBS
+    );
     let overflow = P::Scalar::from_canonical_u64(1u64 << LIMB_BITS);
     let overflow_inv = P::Scalar::from_canonical_u64(GOLDILOCKS_INVERSE_65536);
@@ -154,9 +129,22 @@
     }
     if is_two_row_op {
-        yield_constr.constraint_transition(filter * (cy - given_cy));
+        // NB: Mild hack: We don't check that given_cy[0] is 0 or 1
+        // when is_two_row_op is true because that's only the case
+        // when this function is called from
+        // modular::modular_constr_poly(), in which case (1) this
+        // condition has already been checked and (2) it exceeds the
+        // degree budget because given_cy[0] is already degree 2.
+        yield_constr.constraint_transition(filter * (cy - given_cy[0]));
+        for i in 1..N_LIMBS {
+            yield_constr.constraint_transition(filter * given_cy[i]);
+        }
     } else {
-        yield_constr.constraint(filter * (cy - given_cy));
+        yield_constr.constraint(filter * given_cy[0] * (given_cy[0] - P::ONES));
+        yield_constr.constraint(filter * (cy - given_cy[0]));
+        for i in 1..N_LIMBS {
+            yield_constr.constraint(filter * given_cy[i]);
+        }
     }
 }
@@ -169,30 +157,32 @@ pub fn eval_packed_generic<P: PackedField>(
     let is_lt = lv[IS_LT];
     let is_gt = lv[IS_GT];
-    let x = &lv[GENERAL_REGISTER_0];
-    let y = &lv[GENERAL_REGISTER_1];
-    let z = &lv[GENERAL_REGISTER_2];
-    let cy = lv[GENERAL_REGISTER_BIT];
-    let op_filter = is_add + is_sub + is_lt + is_gt;
-    eval_packed_generic_check_is_one_bit(yield_constr, op_filter, cy);
-    // x + y = z + cy*2^256
-    eval_packed_generic_addcy(yield_constr, op_filter, x, y, z, cy, false);
+    let in0 = &lv[INPUT_REGISTER_0];
+    let in1 = &lv[INPUT_REGISTER_1];
+    let out = &lv[OUTPUT_REGISTER];
+    let aux = &lv[AUX_INPUT_REGISTER_0];
+    // x + y = z + w*2^256
+    eval_packed_generic_addcy(yield_constr, is_add, in0, in1, out, aux, false);
+    eval_packed_generic_addcy(yield_constr, is_sub, in1, out, in0, aux, false);
+    eval_packed_generic_addcy(yield_constr, is_lt, in1, aux, in0, out, false);
+    eval_packed_generic_addcy(yield_constr, is_gt, in0, aux, in1, out, false);
 }
 #[allow(clippy::needless_collect)]
 pub(crate) fn eval_ext_circuit_addcy<F: RichField + Extendable<D>, const D: usize>(
-    builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
+    builder: &mut CircuitBuilder<F, D>,
     yield_constr: &mut RecursiveConstraintConsumer<F, D>,
     filter: ExtensionTarget<D>,
     x: &[ExtensionTarget<D>],
     y: &[ExtensionTarget<D>],
     z: &[ExtensionTarget<D>],
-    given_cy: ExtensionTarget<D>,
+    given_cy: &[ExtensionTarget<D>],
     is_two_row_op: bool,
 ) {
-    debug_assert!(x.len() == N_LIMBS && y.len() == N_LIMBS && z.len() == N_LIMBS);
+    debug_assert!(
+        x.len() == N_LIMBS && y.len() == N_LIMBS && z.len() == N_LIMBS && given_cy.len() == N_LIMBS
+    );
     // 2^LIMB_BITS in the base field
     let overflow_base = F::from_canonical_u64(1 << LIMB_BITS);
@@ -222,17 +212,31 @@ pub(crate) fn eval_ext_circuit_addcy<F: RichField + Extendable<D>, const D: usiz
         cy = builder.mul_const_extension(overflow_inv, t);
     }
-    let good_cy = builder.sub_extension(cy, given_cy);
-    let filter = builder.mul_extension(filter, good_cy);
+    let good_cy = builder.sub_extension(cy, given_cy[0]);
+    let cy_filter = builder.mul_extension(filter, good_cy);
+    // Check given carry is one bit
+    let bit_constr = builder.mul_sub_extension(given_cy[0], given_cy[0], given_cy[0]);
+    let bit_filter = builder.mul_extension(filter, bit_constr);
     if is_two_row_op {
-        yield_constr.constraint_transition(builder, filter);
+        yield_constr.constraint_transition(builder, cy_filter);
+        for i in 1..N_LIMBS {
+            let t = builder.mul_extension(filter, given_cy[i]);
+            yield_constr.constraint_transition(builder, t);
+        }
     } else {
-        yield_constr.constraint(builder, filter);
+        yield_constr.constraint(builder, bit_filter);
+        yield_constr.constraint(builder, cy_filter);
+        for i in 1..N_LIMBS {
+            let t = builder.mul_extension(filter, given_cy[i]);
+            yield_constr.constraint(builder, t);
+        }
     }
 }
 pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
-    builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
+    builder: &mut CircuitBuilder<F, D>,
     lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
     yield_constr: &mut RecursiveConstraintConsumer<F, D>,
 ) {
@@ -241,14 +245,15 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
     let is_lt = lv[IS_LT];
     let is_gt = lv[IS_GT];
-    let x = &lv[GENERAL_REGISTER_0];
-    let y = &lv[GENERAL_REGISTER_1];
-    let z = &lv[GENERAL_REGISTER_2];
-    let cy = lv[GENERAL_REGISTER_BIT];
-    let op_filter = builder.add_many_extension([is_add, is_sub, is_lt, is_gt]);
-    eval_ext_circuit_check_is_one_bit(builder, yield_constr, op_filter, cy);
-    eval_ext_circuit_addcy(builder, yield_constr, op_filter, x, y, z, cy, false);
+    let in0 = &lv[INPUT_REGISTER_0];
+    let in1 = &lv[INPUT_REGISTER_1];
+    let out = &lv[OUTPUT_REGISTER];
+    let aux = &lv[AUX_INPUT_REGISTER_0];
+    eval_ext_circuit_addcy(builder, yield_constr, is_add, in0, in1, out, aux, false);
+    eval_ext_circuit_addcy(builder, yield_constr, is_sub, in1, out, in0, aux, false);
+    eval_ext_circuit_addcy(builder, yield_constr, is_lt, in1, aux, in0, out, false);
+    eval_ext_circuit_addcy(builder, yield_constr, is_gt, in0, aux, in1, out, false);
 }
 #[cfg(test)]
@@ -264,7 +269,7 @@ mod tests {
     // TODO: Should be able to refactor this test to apply to all operations.
     #[test]
-    fn generate_eval_consistency_not_addcc() {
+    fn generate_eval_consistency_not_addcy() {
         type F = GoldilocksField;
         let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
@@ -291,7 +296,7 @@
     }
     #[test]
-    fn generate_eval_consistency_addcc() {
+    fn generate_eval_consistency_addcy() {
         type F = GoldilocksField;
         let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
@@ -328,6 +333,21 @@ mod tests {
             for &acc in &constrant_consumer.constraint_accs {
                 assert_eq!(acc, F::ZERO);
             }
+            let expected = match op_filter {
+                IS_ADD => left_in.overflowing_add(right_in).0,
+                IS_SUB => left_in.overflowing_sub(right_in).0,
+                IS_LT => U256::from((left_in < right_in) as u8),
+                IS_GT => U256::from((left_in > right_in) as u8),
+                _ => panic!("unrecognised operation"),
+            };
+            let mut expected_limbs = [F::ZERO; N_LIMBS];
+            u256_to_array(&mut expected_limbs, expected);
+            assert!(expected_limbs
+                .iter()
+                .zip(&lv[OUTPUT_REGISTER])
+                .all(|(x, y)| x == y));
         }
     }
 }
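
The four eval_packed_generic_addcy calls above express ADD, SUB, LT and GT as one x + y = z + cy*2^256 relation with the registers permuted per operation (for SUB the check is in1 + out = in0 + aux*2^256; for LT and GT the borrow becomes the output). A standalone sketch of those witness relations using plain U256 arithmetic (illustrative only, not part of the commit):

    use ethereum_types::U256;

    #[test]
    fn sub_and_lt_share_the_add_with_carry_relation() {
        let in0 = U256::from(3u64);
        let in1 = U256::from(10u64);
        // SUB writes the wrapped difference into `out` and the borrow into `aux`.
        let (out, borrow) = in0.overflowing_sub(in1);
        // The constraint checked for SUB is in1 + out = in0 + borrow*2^256,
        // i.e. adding `out` back to `in1` wraps around to `in0`.
        let (recombined, _) = in1.overflowing_add(out);
        assert_eq!(recombined, in0);
        // For LT the same subtraction is reused, but the borrow is the result.
        assert_eq!(borrow, in0 < in1);
    }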

View File

@@ -1,4 +1,5 @@
 use std::marker::PhantomData;
+use std::ops::Range;
 use itertools::Itertools;
 use plonky2::field::extension::{Extendable, FieldExtension};
@@ -8,15 +9,82 @@ use plonky2::field::types::Field;
 use plonky2::hash::hash_types::RichField;
 use plonky2::plonk::circuit_builder::CircuitBuilder;
 use plonky2::util::transpose;
+use static_assertions::const_assert;
-use crate::arithmetic::{addcy, columns, modular, mul, Operation};
+use crate::all_stark::Table;
+use crate::arithmetic::{addcy, columns, divmod, modular, mul, Operation};
 use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
+use crate::cross_table_lookup::{Column, TableWithColumns};
 use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols};
 use crate::permutation::PermutationPair;
 use crate::stark::Stark;
 use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
-#[derive(Copy, Clone)]
+/// Link the 16-bit columns of the arithmetic table, split into groups
+/// of N_LIMBS at a time in `regs`, with the corresponding 32-bit
+/// columns of the CPU table. Does this for all ops in `ops`.
+///
+/// This is done by taking pairs of columns (x, y) of the arithmetic
+/// table and combining them as x + y*2^16 to ensure they equal the
+/// corresponding 32-bit number in the CPU table.
+fn cpu_arith_data_link<F: Field>(ops: &[usize], regs: &[Range<usize>]) -> Vec<Column<F>> {
+    let limb_base = F::from_canonical_u64(1 << columns::LIMB_BITS);
+    let mut res = Column::singles(ops).collect_vec();
+    // The inner for loop below assumes N_LIMBS is even.
+    const_assert!(columns::N_LIMBS % 2 == 0);
+    for reg_cols in regs {
+        // Loop below assumes we're operating on a "register" of N_LIMBS columns.
+        debug_assert_eq!(reg_cols.len(), columns::N_LIMBS);
+        for i in 0..(columns::N_LIMBS / 2) {
+            let c0 = reg_cols.start + 2 * i;
+            let c1 = reg_cols.start + 2 * i + 1;
+            res.push(Column::linear_combination([(c0, F::ONE), (c1, limb_base)]));
+        }
+    }
+    res
+}
+pub fn ctl_arithmetic_rows<F: Field>() -> TableWithColumns<F> {
+    const ARITH_OPS: [usize; 13] = [
+        columns::IS_ADD,
+        columns::IS_SUB,
+        columns::IS_MUL,
+        columns::IS_LT,
+        columns::IS_GT,
+        columns::IS_ADDFP254,
+        columns::IS_MULFP254,
+        columns::IS_SUBFP254,
+        columns::IS_ADDMOD,
+        columns::IS_MULMOD,
+        columns::IS_SUBMOD,
+        columns::IS_DIV,
+        columns::IS_MOD,
+    ];
+    const REGISTER_MAP: [Range<usize>; 4] = [
+        columns::INPUT_REGISTER_0,
+        columns::INPUT_REGISTER_1,
+        columns::INPUT_REGISTER_2,
+        columns::OUTPUT_REGISTER,
+    ];
+    // Create the Arithmetic Table whose columns are those of the
+    // operations listed in `ops` whose inputs and outputs are given
+    // by `regs`, where each element of `regs` is a range of columns
+    // corresponding to a 256-bit input or output register (also `ops`
+    // is used as the operation filter).
+    TableWithColumns::new(
+        Table::Arithmetic,
+        cpu_arith_data_link(&ARITH_OPS, &REGISTER_MAP),
+        Some(Column::sum(ARITH_OPS)),
+    )
+}
+#[derive(Copy, Clone, Default)]
 pub struct ArithmeticStark<F, const D: usize> {
     pub f: PhantomData<F>,
 }
@@ -48,8 +116,7 @@ impl<F: RichField, const D: usize> ArithmeticStark<F, D> {
        }
    }
-    #[allow(unused)]
-    pub(crate) fn generate(&self, operations: Vec<Operation>) -> Vec<PolynomialValues<F>> {
+    pub(crate) fn generate_trace(&self, operations: Vec<Operation>) -> Vec<PolynomialValues<F>> {
         // The number of rows reserved is the smallest value that's
         // guaranteed to avoid a reallocation: The only ops that use
         // two rows are the modular operations and DIV, so the only
@@ -114,7 +181,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
         mul::eval_packed_generic(lv, yield_constr);
         addcy::eval_packed_generic(lv, yield_constr);
-        modular::eval_packed_generic(lv, nv, yield_constr);
+        divmod::eval_packed(lv, nv, yield_constr);
+        modular::eval_packed(lv, nv, yield_constr);
     }
     fn eval_ext_circuit(
@@ -144,6 +212,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
         mul::eval_ext_circuit(builder, lv, yield_constr);
         addcy::eval_ext_circuit(builder, lv, yield_constr);
+        divmod::eval_ext_circuit(builder, lv, nv, yield_constr);
         modular::eval_ext_circuit(builder, lv, nv, yield_constr);
     }
@@ -176,6 +245,7 @@ mod tests {
     use rand_chacha::ChaCha8Rng;
     use super::{columns, ArithmeticStark};
+    use crate::arithmetic::columns::OUTPUT_REGISTER;
     use crate::arithmetic::*;
     use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
@@ -249,7 +319,7 @@ mod tests {
         let ops: Vec<Operation> = vec![add, mulmod, addmod, mul, modop, lt1, lt2, lt3, div];
-        let pols = stark.generate(ops);
+        let pols = stark.generate_trace(ops);
         // Trace should always have NUM_ARITH_COLUMNS columns and
         // min(RANGE_MAX, operations.len()) rows. In this case there
@@ -259,26 +329,23 @@
             && pols.iter().all(|v| v.len() == super::RANGE_MAX)
         );
-        // Wrap the single value GENERAL_REGISTER_BIT in a Range.
-        let cmp_range = columns::GENERAL_REGISTER_BIT..columns::GENERAL_REGISTER_BIT + 1;
         // Each operation has a single word answer that we can check
         let expected_output = [
-            // Row (some ops take two rows), col, expected
-            (0, &columns::GENERAL_REGISTER_2, 579), // ADD_OUTPUT
-            (1, &columns::MODULAR_OUTPUT, 703),
-            (3, &columns::MODULAR_OUTPUT, 794),
-            (5, &columns::MUL_OUTPUT, 56088),
-            (6, &columns::MODULAR_OUTPUT, 11),
-            (8, &cmp_range, 0),
-            (9, &cmp_range, 1),
-            (10, &cmp_range, 0),
-            (11, &columns::DIV_OUTPUT, 9),
+            // Row (some ops take two rows), expected
+            (0, 579), // ADD_OUTPUT
+            (1, 703),
+            (3, 794),
+            (5, 56088),
+            (6, 11),
+            (8, 0),
+            (9, 1),
+            (10, 0),
+            (11, 9),
         ];
-        for (row, col, expected) in expected_output {
+        for (row, expected) in expected_output {
             // First register should match expected value...
-            let first = col.start;
+            let first = OUTPUT_REGISTER.start;
             let out = pols[first].values[row].to_canonical_u64();
             assert_eq!(
                 out, expected,
@@ -286,7 +353,7 @@
                 first, row, expected, out,
             );
             // ...other registers should be zero
-            let rest = col.start + 1..col.end;
+            let rest = OUTPUT_REGISTER.start + 1..OUTPUT_REGISTER.end;
             assert!(pols[rest].iter().all(|v| v.values[row] == F::ZERO));
         }
     }
@@ -314,7 +381,7 @@ mod tests {
             })
             .collect::<Vec<_>>();
-        let pols = stark.generate(ops);
+        let pols = stark.generate_trace(ops);
         // Trace should always have NUM_ARITH_COLUMNS columns and
         // min(RANGE_MAX, operations.len()) rows. In this case there
@@ -335,7 +402,7 @@
             })
             .collect::<Vec<_>>();
-        let pols = stark.generate(ops);
+        let pols = stark.generate_trace(ops);
         // Trace should always have NUM_ARITH_COLUMNS columns and
         // min(RANGE_MAX, operations.len()) rows. In this case there
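
To make the x + y*2^16 packing used by cpu_arith_data_link above concrete: each pair of adjacent 16-bit limbs in the arithmetic table must recombine into the 32-bit limb the CPU table carries. A minimal sketch (illustrative only; pack_limbs is a made-up helper, not part of the commit):

    /// Mirror of the linear combination c0 + c1 * 2^16 that
    /// cpu_arith_data_link builds for each pair of 16-bit columns.
    fn pack_limbs(lo: u16, hi: u16) -> u32 {
        (lo as u32) + ((hi as u32) << 16)
    }

    #[test]
    fn limb_packing_matches_cpu_word() {
        // 0xDEAD_BEEF split into base-2^16 limbs, little-endian.
        let lo = 0xBEEFu16;
        let hi = 0xDEADu16;
        assert_eq!(pack_limbs(lo, hi), 0xDEAD_BEEF);
    }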

View File

@@ -22,16 +22,19 @@ const fn n_limbs() -> usize {
 /// Number of LIMB_BITS limbs that are in on EVM register-sized number.
 pub const N_LIMBS: usize = n_limbs();
-pub const IS_ADD: usize = 0;
-pub const IS_MUL: usize = IS_ADD + 1;
-pub const IS_SUB: usize = IS_MUL + 1;
-pub const IS_DIV: usize = IS_SUB + 1;
-pub const IS_MOD: usize = IS_DIV + 1;
-pub const IS_ADDMOD: usize = IS_MOD + 1;
-pub const IS_SUBMOD: usize = IS_ADDMOD + 1;
-pub const IS_MULMOD: usize = IS_SUBMOD + 1;
-pub const IS_LT: usize = IS_MULMOD + 1;
-pub const IS_GT: usize = IS_LT + 1;
+pub(crate) const IS_ADD: usize = 0;
+pub(crate) const IS_MUL: usize = IS_ADD + 1;
+pub(crate) const IS_SUB: usize = IS_MUL + 1;
+pub(crate) const IS_DIV: usize = IS_SUB + 1;
+pub(crate) const IS_MOD: usize = IS_DIV + 1;
+pub(crate) const IS_ADDMOD: usize = IS_MOD + 1;
+pub(crate) const IS_MULMOD: usize = IS_ADDMOD + 1;
+pub(crate) const IS_ADDFP254: usize = IS_MULMOD + 1;
+pub(crate) const IS_MULFP254: usize = IS_ADDFP254 + 1;
+pub(crate) const IS_SUBFP254: usize = IS_MULFP254 + 1;
+pub(crate) const IS_SUBMOD: usize = IS_SUBFP254 + 1;
+pub(crate) const IS_LT: usize = IS_SUBMOD + 1;
+pub(crate) const IS_GT: usize = IS_LT + 1;
 pub(crate) const START_SHARED_COLS: usize = IS_GT + 1;
@@ -46,28 +49,28 @@ pub(crate) const START_SHARED_COLS: usize = IS_GT + 1;
 pub(crate) const NUM_SHARED_COLS: usize = 6 * N_LIMBS;
 pub(crate) const SHARED_COLS: Range<usize> = START_SHARED_COLS..START_SHARED_COLS + NUM_SHARED_COLS;
-pub(crate) const GENERAL_REGISTER_0: Range<usize> = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS;
-pub(crate) const GENERAL_REGISTER_1: Range<usize> =
-    GENERAL_REGISTER_0.end..GENERAL_REGISTER_0.end + N_LIMBS;
-pub(crate) const GENERAL_REGISTER_2: Range<usize> =
-    GENERAL_REGISTER_1.end..GENERAL_REGISTER_1.end + N_LIMBS;
-const GENERAL_REGISTER_3: Range<usize> = GENERAL_REGISTER_2.end..GENERAL_REGISTER_2.end + N_LIMBS;
-// NB: Uses first slot of the GENERAL_REGISTER_3 register.
-pub(crate) const GENERAL_REGISTER_BIT: usize = GENERAL_REGISTER_3.start;
-// NB: Only one of these two sets of columns will be used for a given operation
-const GENERAL_REGISTER_4: Range<usize> = GENERAL_REGISTER_3.end..GENERAL_REGISTER_3.end + N_LIMBS;
-const GENERAL_REGISTER_4_DBL: Range<usize> =
-    GENERAL_REGISTER_3.end..GENERAL_REGISTER_3.end + 2 * N_LIMBS;
+pub(crate) const INPUT_REGISTER_0: Range<usize> = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS;
+pub(crate) const INPUT_REGISTER_1: Range<usize> =
+    INPUT_REGISTER_0.end..INPUT_REGISTER_0.end + N_LIMBS;
+pub(crate) const INPUT_REGISTER_2: Range<usize> =
+    INPUT_REGISTER_1.end..INPUT_REGISTER_1.end + N_LIMBS;
+pub(crate) const OUTPUT_REGISTER: Range<usize> =
+    INPUT_REGISTER_2.end..INPUT_REGISTER_2.end + N_LIMBS;
+// NB: Only one of AUX_INPUT_REGISTER_[01] or AUX_INPUT_REGISTER_DBL
+// will be used for a given operation since they overlap
+pub(crate) const AUX_INPUT_REGISTER_0: Range<usize> =
+    OUTPUT_REGISTER.end..OUTPUT_REGISTER.end + N_LIMBS;
+pub(crate) const AUX_INPUT_REGISTER_1: Range<usize> =
+    AUX_INPUT_REGISTER_0.end..AUX_INPUT_REGISTER_0.end + N_LIMBS;
+pub(crate) const AUX_INPUT_REGISTER_DBL: Range<usize> =
+    OUTPUT_REGISTER.end..OUTPUT_REGISTER.end + 2 * N_LIMBS;
 // The auxiliary input columns overlap the general input columns
 // because they correspond to the values in the second row for modular
 // operations.
 const AUX_REGISTER_0: Range<usize> = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS;
 const AUX_REGISTER_1: Range<usize> = AUX_REGISTER_0.end..AUX_REGISTER_0.end + 2 * N_LIMBS;
-// These auxiliary input columns are awkwardly split across two rows,
-// with the first half after the general input columns and the second
-// half after the auxiliary input columns.
 const AUX_REGISTER_2: Range<usize> = AUX_REGISTER_1.end..AUX_REGISTER_1.end + 2 * N_LIMBS - 1;
 // Each element c of {MUL,MODULAR}_AUX_REGISTER is -2^20 <= c <= 2^20;
@@ -76,11 +79,8 @@ const AUX_REGISTER_2: Range<usize> = AUX_REGISTER_1.end..AUX_REGISTER_1.end + 2
 pub(crate) const AUX_COEFF_ABS_MAX: i64 = 1 << 20;
 // MUL takes 5 * N_LIMBS = 80 columns
-pub(crate) const MUL_INPUT_0: Range<usize> = GENERAL_REGISTER_0;
-pub(crate) const MUL_INPUT_1: Range<usize> = GENERAL_REGISTER_1;
-pub(crate) const MUL_OUTPUT: Range<usize> = GENERAL_REGISTER_2;
-pub(crate) const MUL_AUX_INPUT_LO: Range<usize> = GENERAL_REGISTER_3;
-pub(crate) const MUL_AUX_INPUT_HI: Range<usize> = GENERAL_REGISTER_4;
+pub(crate) const MUL_AUX_INPUT_LO: Range<usize> = AUX_INPUT_REGISTER_0;
+pub(crate) const MUL_AUX_INPUT_HI: Range<usize> = AUX_INPUT_REGISTER_1;
 // MULMOD takes 4 * N_LIMBS + 3 * 2*N_LIMBS + N_LIMBS = 176 columns
 // but split over two rows of 96 columns and 80 columns.
@@ -88,11 +88,11 @@ pub(crate) const MUL_AUX_INPUT_HI: Range<usize> = GENERAL_REGISTER_4;
 // ADDMOD, SUBMOD, MOD and DIV are currently implemented in terms of
 // the general modular code, so they also take 144 columns (also split
 // over two rows).
-pub(crate) const MODULAR_INPUT_0: Range<usize> = GENERAL_REGISTER_0;
-pub(crate) const MODULAR_INPUT_1: Range<usize> = GENERAL_REGISTER_1;
-pub(crate) const MODULAR_MODULUS: Range<usize> = GENERAL_REGISTER_2;
-pub(crate) const MODULAR_OUTPUT: Range<usize> = GENERAL_REGISTER_3;
-pub(crate) const MODULAR_QUO_INPUT: Range<usize> = GENERAL_REGISTER_4_DBL;
+pub(crate) const MODULAR_INPUT_0: Range<usize> = INPUT_REGISTER_0;
+pub(crate) const MODULAR_INPUT_1: Range<usize> = INPUT_REGISTER_1;
+pub(crate) const MODULAR_MODULUS: Range<usize> = INPUT_REGISTER_2;
+pub(crate) const MODULAR_OUTPUT: Range<usize> = OUTPUT_REGISTER;
+pub(crate) const MODULAR_QUO_INPUT: Range<usize> = AUX_INPUT_REGISTER_DBL;
 pub(crate) const MODULAR_OUT_AUX_RED: Range<usize> = AUX_REGISTER_0;
 // NB: Last value is not used in AUX, it is used in MOD_IS_ZERO
 pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_REGISTER_1.start;
@@ -101,14 +101,6 @@ pub(crate) const MODULAR_AUX_INPUT_HI: Range<usize> = AUX_REGISTER_2;
 // Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV]
 pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end;
-#[allow(unused)] // TODO: Will be used when hooking into the CPU
-pub(crate) const DIV_NUMERATOR: Range<usize> = MODULAR_INPUT_0;
-#[allow(unused)] // TODO: Will be used when hooking into the CPU
-pub(crate) const DIV_DENOMINATOR: Range<usize> = MODULAR_MODULUS;
-#[allow(unused)] // TODO: Will be used when hooking into the CPU
-pub(crate) const DIV_OUTPUT: Range<usize> =
-    MODULAR_QUO_INPUT.start..MODULAR_QUO_INPUT.start + N_LIMBS;
 // Need one column for the table, then two columns for every value
 // that needs to be range checked in the trace, namely the permutation
 // of the column and the permutation of the range. The two
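
A small sketch of the register layout just defined (illustrative test, not in the commit): AUX_INPUT_REGISTER_DBL is exactly the concatenation of AUX_INPUT_REGISTER_0 and AUX_INPUT_REGISTER_1, which is why an operation may use one view or the other but never both.

    #[test]
    fn aux_register_views_alias() {
        // Both views start where OUTPUT_REGISTER ends.
        assert_eq!(AUX_INPUT_REGISTER_0.start, OUTPUT_REGISTER.end);
        assert_eq!(AUX_INPUT_REGISTER_DBL.start, OUTPUT_REGISTER.end);
        // The double-width view is the concatenation of the two single views.
        assert_eq!(AUX_INPUT_REGISTER_1.start, AUX_INPUT_REGISTER_0.end);
        assert_eq!(AUX_INPUT_REGISTER_DBL.end, AUX_INPUT_REGISTER_1.end);
    }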

View File

@@ -0,0 +1,339 @@
use std::ops::Range;
use ethereum_types::U256;
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::PrimeField64;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use crate::arithmetic::columns::*;
use crate::arithmetic::modular::{
generate_modular_op, modular_constr_poly, modular_constr_poly_ext_circuit,
};
use crate::arithmetic::utils::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
/// Generate the output and auxiliary values for modular operations.
pub(crate) fn generate<F: PrimeField64>(
lv: &mut [F],
nv: &mut [F],
filter: usize,
input0: U256,
input1: U256,
result: U256,
) {
debug_assert!(lv.len() == NUM_ARITH_COLUMNS);
u256_to_array(&mut lv[INPUT_REGISTER_0], input0);
u256_to_array(&mut lv[INPUT_REGISTER_1], input1);
u256_to_array(&mut lv[OUTPUT_REGISTER], result);
let input_limbs = read_value_i64_limbs::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
let pol_input = pol_extend(input_limbs);
let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, INPUT_REGISTER_1);
debug_assert!(
&quo_input[N_LIMBS..].iter().all(|&x| x == F::ZERO),
"expected top half of quo_input to be zero"
);
// Initialise whole (double) register to zero; the low half will
// be overwritten via lv[AUX_INPUT_REGISTER] below.
for i in MODULAR_QUO_INPUT {
lv[i] = F::ZERO;
}
match filter {
IS_DIV => {
debug_assert!(
lv[OUTPUT_REGISTER]
.iter()
.zip(&quo_input[..N_LIMBS])
.all(|(x, y)| x == y),
"computed output doesn't match expected"
);
lv[AUX_INPUT_REGISTER_0].copy_from_slice(&out);
}
IS_MOD => {
debug_assert!(
lv[OUTPUT_REGISTER].iter().zip(&out).all(|(x, y)| x == y),
"computed output doesn't match expected"
);
lv[AUX_INPUT_REGISTER_0].copy_from_slice(&quo_input[..N_LIMBS]);
}
_ => panic!("expected filter to be IS_DIV or IS_MOD but it was {filter}"),
};
}
/// Verify that num = quo * den + rem and 0 <= rem < den.
fn eval_packed_divmod_helper<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
filter: P,
quo_range: Range<usize>,
rem_range: Range<usize>,
) {
debug_assert!(quo_range.len() == N_LIMBS);
debug_assert!(rem_range.len() == N_LIMBS);
yield_constr.constraint_last_row(filter);
let num = &lv[INPUT_REGISTER_0];
let den = read_value(lv, INPUT_REGISTER_1);
let quo = {
let mut quo = [P::ZEROS; 2 * N_LIMBS];
quo[..N_LIMBS].copy_from_slice(&lv[quo_range]);
quo
};
let rem = read_value(lv, rem_range);
let mut constr_poly = modular_constr_poly(lv, nv, yield_constr, filter, rem, den, quo);
let input = num;
pol_sub_assign(&mut constr_poly, input);
for &c in constr_poly.iter() {
yield_constr.constraint_transition(filter * c);
}
}
pub(crate) fn eval_packed<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
) {
eval_packed_divmod_helper(
lv,
nv,
yield_constr,
lv[IS_DIV],
OUTPUT_REGISTER,
AUX_INPUT_REGISTER_0,
);
eval_packed_divmod_helper(
lv,
nv,
yield_constr,
lv[IS_MOD],
AUX_INPUT_REGISTER_0,
OUTPUT_REGISTER,
);
}
fn eval_ext_circuit_divmod_helper<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
filter: ExtensionTarget<D>,
quo_range: Range<usize>,
rem_range: Range<usize>,
) {
yield_constr.constraint_last_row(builder, filter);
let num = &lv[INPUT_REGISTER_0];
let den = read_value(lv, INPUT_REGISTER_1);
let quo = {
let zero = builder.zero_extension();
let mut quo = [zero; 2 * N_LIMBS];
quo[..N_LIMBS].copy_from_slice(&lv[quo_range]);
quo
};
let rem = read_value(lv, rem_range);
let mut constr_poly =
modular_constr_poly_ext_circuit(lv, nv, builder, yield_constr, filter, rem, den, quo);
let input = num;
pol_sub_assign_ext_circuit(builder, &mut constr_poly, input);
for &c in constr_poly.iter() {
let t = builder.mul_extension(filter, c);
yield_constr.constraint_transition(builder, t);
}
}
pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
eval_ext_circuit_divmod_helper(
builder,
lv,
nv,
yield_constr,
lv[IS_DIV],
OUTPUT_REGISTER,
AUX_INPUT_REGISTER_0,
);
eval_ext_circuit_divmod_helper(
builder,
lv,
nv,
yield_constr,
lv[IS_MOD],
AUX_INPUT_REGISTER_0,
OUTPUT_REGISTER,
);
}
#[cfg(test)]
mod tests {
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::{Field, Sample};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use super::*;
use crate::arithmetic::columns::NUM_ARITH_COLUMNS;
use crate::constraint_consumer::ConstraintConsumer;
const N_RND_TESTS: usize = 1000;
const MODULAR_OPS: [usize; 2] = [IS_MOD, IS_DIV];
// TODO: Should be able to refactor this test to apply to all operations.
#[test]
fn generate_eval_consistency_not_modular() {
type F = GoldilocksField;
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
// if `IS_MOD == 0`, then the constraints should be met even
// if all values are garbage (and similarly for the other operations).
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ONE,
);
eval_packed(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
}
#[test]
fn generate_eval_consistency() {
type F = GoldilocksField;
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
for op_filter in MODULAR_OPS {
for i in 0..N_RND_TESTS {
// set inputs to random values
let mut lv = [F::default(); NUM_ARITH_COLUMNS]
.map(|_| F::from_canonical_u16(rng.gen::<u16>()));
let mut nv = [F::default(); NUM_ARITH_COLUMNS]
.map(|_| F::from_canonical_u16(rng.gen::<u16>()));
// Reset operation columns, then select one
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[op_filter] = F::ONE;
let input0 = U256::from(rng.gen::<[u8; 32]>());
let input1 = {
let mut modulus_limbs = [0u8; 32];
// For the second half of the tests, set the top
// 16-start digits of the "modulus" to zero so it is
// much smaller than the inputs.
if i > N_RND_TESTS / 2 {
// 1 <= start < N_LIMBS
let start = (rng.gen::<usize>() % (modulus_limbs.len() - 1)) + 1;
for mi in modulus_limbs.iter_mut().skip(start) {
*mi = 0u8;
}
}
U256::from(modulus_limbs)
};
let result = if input1 == U256::zero() {
U256::zero()
} else if op_filter == IS_DIV {
input0 / input1
} else {
input0 % input1
};
generate(&mut lv, &mut nv, op_filter, input0, input1, result);
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ZERO,
GoldilocksField::ZERO,
);
eval_packed(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
}
}
}
#[test]
fn zero_modulus() {
type F = GoldilocksField;
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
for op_filter in MODULAR_OPS {
for _i in 0..N_RND_TESTS {
// set inputs to random values and the modulus to zero;
// the output is defined to be zero when modulus is zero.
let mut lv = [F::default(); NUM_ARITH_COLUMNS]
.map(|_| F::from_canonical_u16(rng.gen::<u16>()));
let mut nv = [F::default(); NUM_ARITH_COLUMNS]
.map(|_| F::from_canonical_u16(rng.gen::<u16>()));
// Reset operation columns, then select one
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[op_filter] = F::ONE;
let input0 = U256::from(rng.gen::<[u8; 32]>());
let input1 = U256::zero();
generate(&mut lv, &mut nv, op_filter, input0, input1, U256::zero());
// check that the correct output was generated
assert!(lv[OUTPUT_REGISTER].iter().all(|&c| c == F::ZERO));
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ZERO,
GoldilocksField::ZERO,
);
eval_packed(&lv, &nv, &mut constraint_consumer);
assert!(constraint_consumer
.constraint_accs
.iter()
.all(|&acc| acc == F::ZERO));
// Corrupt one output limb by setting it to a non-zero value
let random_oi = OUTPUT_REGISTER.start + rng.gen::<usize>() % N_LIMBS;
lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX));
eval_packed(&lv, &nv, &mut constraint_consumer);
// Check that at least one of the constraints was non-zero
assert!(constraint_consumer
.constraint_accs
.iter()
.any(|&acc| acc != F::ZERO));
}
}
}
}
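
Both DIV and MOD above are verified against the same identity num = quo * den + rem with 0 <= rem < den; the only difference is whether the quotient or the remainder lands in OUTPUT_REGISTER, with the other value going to AUX_INPUT_REGISTER_0. A standalone sketch of the witness values involved (plain U256 arithmetic; divmod_witness is a made-up name, not part of the commit):

    use ethereum_types::U256;

    /// For DIV the quotient goes to OUTPUT_REGISTER and the remainder to
    /// AUX_INPUT_REGISTER_0; for MOD the two swap places. Both satisfy
    /// num = quo * den + rem (the EVM defines x/0 = x%0 = 0).
    fn divmod_witness(num: U256, den: U256, is_div: bool) -> (U256, U256) {
        if den.is_zero() {
            return (U256::zero(), U256::zero());
        }
        let (quo, rem) = (num / den, num % den);
        debug_assert_eq!(quo * den + rem, num);
        if is_div { (quo, rem) } else { (rem, quo) }
    }

    #[test]
    fn div_and_mod_share_one_witness() {
        let num = U256::from(100u64);
        let den = U256::from(7u64);
        assert_eq!(divmod_witness(num, den, true), (U256::from(14u64), U256::from(2u64)));
        assert_eq!(divmod_witness(num, den, false), (U256::from(2u64), U256::from(14u64)));
    }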

View File

@@ -5,6 +5,7 @@ use crate::extension_tower::BN_BASE;
 use crate::util::{addmod, mulmod, submod};
 mod addcy;
+mod divmod;
 mod modular;
 mod mul;
 mod utils;
@@ -63,9 +64,9 @@ impl BinaryOperator {
             BinaryOperator::Mod => columns::IS_MOD,
             BinaryOperator::Lt => columns::IS_LT,
             BinaryOperator::Gt => columns::IS_GT,
-            BinaryOperator::AddFp254 => columns::IS_ADDMOD,
-            BinaryOperator::MulFp254 => columns::IS_MULMOD,
-            BinaryOperator::SubFp254 => columns::IS_SUBMOD,
+            BinaryOperator::AddFp254 => columns::IS_ADDFP254,
+            BinaryOperator::MulFp254 => columns::IS_MULFP254,
+            BinaryOperator::SubFp254 => columns::IS_SUBFP254,
         }
     }
 }
@@ -209,7 +210,9 @@ fn binary_op_to_rows<F: PrimeField64>(
             (row, None)
         }
         BinaryOperator::Div | BinaryOperator::Mod => {
-            ternary_op_to_rows::<F>(op.row_filter(), input0, U256::zero(), input1, result)
+            let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
+            divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result);
+            (row, Some(nv))
         }
         BinaryOperator::AddFp254 | BinaryOperator::MulFp254 | BinaryOperator::SubFp254 => {
             ternary_op_to_rows::<F>(op.row_filter(), input0, input1, BN_BASE, result)

View File

@ -108,6 +108,8 @@
//! only require 96 columns, or 80 if the output doesn't need to be //! only require 96 columns, or 80 if the output doesn't need to be
//! reduced. //! reduced.
use std::ops::Range;
use ethereum_types::U256; use ethereum_types::U256;
use num::bigint::Sign; use num::bigint::Sign;
use num::{BigInt, One, Zero}; use num::{BigInt, One, Zero};
@ -117,12 +119,29 @@ use plonky2::field::types::{Field, PrimeField64};
use plonky2::hash::hash_types::RichField; use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_builder::CircuitBuilder;
use static_assertions::const_assert;
use super::columns; use super::columns;
use crate::arithmetic::addcy::{eval_ext_circuit_addcy, eval_packed_generic_addcy}; use crate::arithmetic::addcy::{eval_ext_circuit_addcy, eval_packed_generic_addcy};
use crate::arithmetic::columns::*; use crate::arithmetic::columns::*;
use crate::arithmetic::utils::*; use crate::arithmetic::utils::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::extension_tower::BN_BASE;
const fn bn254_modulus_limbs() -> [u16; N_LIMBS] {
const_assert!(N_LIMBS == 16); // Assumed below
let mut limbs = [0u16; N_LIMBS];
let mut i = 0;
while i < N_LIMBS / 4 {
let x = BN_BASE.0[i];
limbs[4 * i] = x as u16;
limbs[4 * i + 1] = (x >> 16) as u16;
limbs[4 * i + 2] = (x >> 32) as u16;
limbs[4 * i + 3] = (x >> 48) as u16;
i += 1;
}
limbs
}
/// Convert the base-2^16 representation of a number into a BigInt. /// Convert the base-2^16 representation of a number into a BigInt.
/// ///
@ -190,29 +209,26 @@ fn bigint_to_columns<const N: usize>(num: &BigInt) -> [i64; N] {
/// ///
/// NB: `operation` can set the higher order elements in its result to /// NB: `operation` can set the higher order elements in its result to
/// zero if they are not used. /// zero if they are not used.
fn generate_modular_op<F: PrimeField64>( pub(crate) fn generate_modular_op<F: PrimeField64>(
lv: &mut [F], lv: &mut [F],
nv: &mut [F], nv: &mut [F],
filter: usize, filter: usize,
operation: fn([i64; N_LIMBS], [i64; N_LIMBS]) -> [i64; 2 * N_LIMBS - 1], pol_input: [i64; 2 * N_LIMBS - 1],
) { modulus_range: Range<usize>,
// Inputs are all range-checked in [0, 2^16), so the "as i64" ) -> ([F; N_LIMBS], [F; 2 * N_LIMBS]) {
// conversion is safe. assert!(modulus_range.len() == N_LIMBS);
let mut modulus_limbs = read_value_i64_limbs(lv, modulus_range);
let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0);
let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1);
let mut modulus_limbs = read_value_i64_limbs(lv, MODULAR_MODULUS);
// BigInts are just used to avoid having to implement modular // BigInts are just used to avoid having to implement modular
// reduction. // reduction.
let mut modulus = columns_to_bigint(&modulus_limbs); let mut modulus = columns_to_bigint(&modulus_limbs);
// constr_poly is initialised to the calculated input, and is // constr_poly is initialised to the input calculation as
// used as such for the BigInt reduction; later, other values are // polynomials, and is used as such for the BigInt reduction;
// added/subtracted, which is where its meaning as the "constraint // later, other values are added/subtracted, which is where its
// polynomial" comes in. // meaning as the "constraint polynomial" comes in.
let mut constr_poly = [0i64; 2 * N_LIMBS]; let mut constr_poly = [0i64; 2 * N_LIMBS];
constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&operation(input0_limbs, input1_limbs)); constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&pol_input);
// two_exp_256 == 2^256 // two_exp_256 == 2^256
let two_exp_256 = { let two_exp_256 = {
@ -264,8 +280,6 @@ fn generate_modular_op<F: PrimeField64>(
// Higher order terms of the product must be zero for valid quot and modulus: // Higher order terms of the product must be zero for valid quot and modulus:
debug_assert!(&prod[2 * N_LIMBS..].iter().all(|&x| x == 0i64)); debug_assert!(&prod[2 * N_LIMBS..].iter().all(|&x| x == 0i64));
lv[MODULAR_OUTPUT].copy_from_slice(&output_limbs.map(F::from_canonical_i64));
lv[MODULAR_QUO_INPUT].copy_from_slice(&quot_limbs.map(F::from_noncanonical_i64));
// constr_poly must be zero when evaluated at x = β := // constr_poly must be zero when evaluated at x = β :=
// 2^LIMB_BITS, hence it's divisible by (x - β). `aux_limbs` is // 2^LIMB_BITS, hence it's divisible by (x - β). `aux_limbs` is
// the result of removing that root. // the result of removing that root.
@ -286,11 +300,16 @@ fn generate_modular_op<F: PrimeField64>(
nv[MODULAR_MOD_IS_ZERO] = mod_is_zero; nv[MODULAR_MOD_IS_ZERO] = mod_is_zero;
nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(F::from_canonical_i64)); nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(F::from_canonical_i64));
nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * lv[IS_DIV]; nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * lv[IS_DIV];
(
output_limbs.map(F::from_canonical_i64),
quot_limbs.map(F::from_noncanonical_i64),
)
} }
/// Generate the output and auxiliary values for modular operations. /// Generate the output and auxiliary values for modular operations.
/// ///
/// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`. /// `filter` must be one of `columns::IS_{ADD,MUL,SUB}{MOD,FP254}`.
pub(crate) fn generate<F: PrimeField64>( pub(crate) fn generate<F: PrimeField64>(
lv: &mut [F], lv: &mut [F],
nv: &mut [F], nv: &mut [F],
@ -305,15 +324,29 @@ pub(crate) fn generate<F: PrimeField64>(
u256_to_array(&mut lv[MODULAR_INPUT_1], input1); u256_to_array(&mut lv[MODULAR_INPUT_1], input1);
u256_to_array(&mut lv[MODULAR_MODULUS], modulus); u256_to_array(&mut lv[MODULAR_MODULUS], modulus);
match filter { if [
columns::IS_ADDMOD => generate_modular_op(lv, nv, filter, pol_add), columns::IS_ADDFP254,
columns::IS_SUBMOD => generate_modular_op(lv, nv, filter, pol_sub), columns::IS_SUBFP254,
columns::IS_MULMOD => generate_modular_op(lv, nv, filter, pol_mul_wide), columns::IS_MULFP254,
columns::IS_MOD | columns::IS_DIV => { ]
generate_modular_op(lv, nv, filter, |a, _| pol_extend(a)) .contains(&filter)
{
debug_assert!(modulus == BN_BASE);
} }
// Inputs are all in [0, 2^16), so the "as i64" conversion is safe.
let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0);
let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1);
let pol_input = match filter {
columns::IS_ADDMOD | columns::IS_ADDFP254 => pol_add(input0_limbs, input1_limbs),
columns::IS_SUBMOD | columns::IS_SUBFP254 => pol_sub(input0_limbs, input1_limbs),
columns::IS_MULMOD | columns::IS_MULFP254 => pol_mul_wide(input0_limbs, input1_limbs),
_ => panic!("generate modular operation called with unknown opcode"), _ => panic!("generate modular operation called with unknown opcode"),
} };
let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, MODULAR_MODULUS);
lv[MODULAR_OUTPUT].copy_from_slice(&out);
lv[MODULAR_QUO_INPUT].copy_from_slice(&quo_input);
} }
/// Build the part of the constraint polynomial that's common to all /// Build the part of the constraint polynomial that's common to all
@ -324,13 +357,15 @@ pub(crate) fn generate<F: PrimeField64>(
/// c(x) + q(x) * m(x) + (x - β) * s(x) /// c(x) + q(x) * m(x) + (x - β) * s(x)
/// ///
/// and check consistency when m = 0, and that c is reduced. /// and check consistency when m = 0, and that c is reduced.
fn modular_constr_poly<P: PackedField>( pub(crate) fn modular_constr_poly<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS], lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS], nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>, yield_constr: &mut ConstraintConsumer<P>,
filter: P, filter: P,
mut output: [P; N_LIMBS],
mut modulus: [P; N_LIMBS],
quot: [P; 2 * N_LIMBS],
) -> [P; 2 * N_LIMBS] { ) -> [P; 2 * N_LIMBS] {
let mut modulus = read_value::<N_LIMBS, _>(lv, MODULAR_MODULUS);
let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; let mod_is_zero = nv[MODULAR_MOD_IS_ZERO];
// Check that mod_is_zero is zero or one // Check that mod_is_zero is zero or one
@ -345,8 +380,6 @@ fn modular_constr_poly<P: PackedField>(
// modulus = 0. // modulus = 0.
modulus[0] += mod_is_zero; modulus[0] += mod_is_zero;
let mut output = read_value::<N_LIMBS, _>(lv, MODULAR_OUTPUT);
// Is 1 iff the operation is DIV and the denominator is zero. // Is 1 iff the operation is DIV and the denominator is zero.
let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO];
yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero)); yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero));
@ -365,7 +398,8 @@ fn modular_constr_poly<P: PackedField>(
// modulus + out_aux_red == output + is_less_than*2^256 // modulus + out_aux_red == output + is_less_than*2^256
// //
// and we are given output = out_aux_red when modulus is zero. // and we are given output = out_aux_red when modulus is zero.
let is_less_than = P::ONES - mod_is_zero * lv[IS_DIV]; let mut is_less_than = [P::ZEROS; N_LIMBS];
is_less_than[0] = P::ONES - mod_is_zero * lv[IS_DIV];
// NB: output and modulus in lv while out_aux_red and // NB: output and modulus in lv while out_aux_red and
// is_less_than (via mod_is_zero) depend on nv, hence the // is_less_than (via mod_is_zero) depend on nv, hence the
// 'is_two_row_op' argument is set to 'true'. // 'is_two_row_op' argument is set to 'true'.
@ -375,19 +409,13 @@ fn modular_constr_poly<P: PackedField>(
&modulus, &modulus,
out_aux_red, out_aux_red,
&output, &output,
is_less_than, &is_less_than,
true, true,
); );
// restore output[0] // restore output[0]
output[0] -= div_denom_is_zero; output[0] -= div_denom_is_zero;
// prod = q(x) * m(x) // prod = q(x) * m(x)
let quot = {
let mut quot = [P::default(); 2 * N_LIMBS];
quot.copy_from_slice(&lv[MODULAR_QUO_INPUT]);
quot
};
let prod = pol_mul_wide2(quot, modulus); let prod = pol_mul_wide2(quot, modulus);
// higher order terms must be zero // higher order terms must be zero
for &x in prod[2 * N_LIMBS..].iter() { for &x in prod[2 * N_LIMBS..].iter() {
@ -419,25 +447,34 @@ fn modular_constr_poly<P: PackedField>(
} }
/// Add constraints for modular operations. /// Add constraints for modular operations.
pub(crate) fn eval_packed_generic<P: PackedField>( pub(crate) fn eval_packed<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS], lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS], nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>, yield_constr: &mut ConstraintConsumer<P>,
) { ) {
// NB: The CTL code guarantees that filter is 0 or 1, i.e. that // NB: The CTL code guarantees that filter is 0 or 1, i.e. that
// only one of the operations below is "live". // only one of the operations below is "live".
let filter = lv[columns::IS_ADDMOD] let bn254_filter =
+ lv[columns::IS_SUBMOD] lv[columns::IS_ADDFP254] + lv[columns::IS_MULFP254] + lv[columns::IS_SUBFP254];
+ lv[columns::IS_MULMOD] let filter =
+ lv[columns::IS_MOD] lv[columns::IS_ADDMOD] + lv[columns::IS_SUBMOD] + lv[columns::IS_MULMOD] + bn254_filter;
+ lv[columns::IS_DIV];
// Ensure that this operation is not the last row of the table; // Ensure that this operation is not the last row of the table;
// needed because we access the next row of the table in nv. // needed because we access the next row of the table in nv.
yield_constr.constraint_last_row(filter); yield_constr.constraint_last_row(filter);
// Verify that the modulus is the BN254 modulus for the
// {ADD,MUL,SUB}FP254 operations.
let modulus = read_value::<N_LIMBS, _>(lv, MODULAR_MODULUS);
for (&mi, bi) in modulus.iter().zip(bn254_modulus_limbs()) {
yield_constr.constraint_transition(bn254_filter * (mi - P::Scalar::from_canonical_u16(bi)));
}
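// A hedged sketch (not the crate's actual helper) of how the u16 limbs compared
// against above can be produced at compile time, assuming the standard BN254
// base-field modulus given here as little-endian u64 words.
const BN254_MODULUS_WORDS: [u64; 4] = [
    0x3c208c16d87cfd47,
    0x97816a916871ca8d,
    0xb85045b68181585d,
    0x30644e72e131a029,
];
const fn modulus_limbs_u16(words: [u64; 4]) -> [u16; 16] {
    let mut limbs = [0u16; 16];
    let mut i = 0;
    while i < 16 {
        // limb i holds bits [16*i, 16*i + 16) of the 256-bit modulus
        limbs[i] = (words[i / 4] >> (16 * (i % 4))) as u16;
        i += 1;
    }
    limbs
}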
let output = read_value::<N_LIMBS, _>(lv, MODULAR_OUTPUT);
let quo_input = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT);
// constr_poly has 2*N_LIMBS limbs // constr_poly has 2*N_LIMBS limbs
let constr_poly = modular_constr_poly(lv, nv, yield_constr, filter); let constr_poly = modular_constr_poly(lv, nv, yield_constr, filter, output, modulus, quo_input);
let input0 = read_value(lv, MODULAR_INPUT_0); let input0 = read_value(lv, MODULAR_INPUT_0);
let input1 = read_value(lv, MODULAR_INPUT_1); let input1 = read_value(lv, MODULAR_INPUT_1);
@ -445,13 +482,15 @@ pub(crate) fn eval_packed_generic<P: PackedField>(
let add_input = pol_add(input0, input1); let add_input = pol_add(input0, input1);
let sub_input = pol_sub(input0, input1); let sub_input = pol_sub(input0, input1);
let mul_input = pol_mul_wide(input0, input1); let mul_input = pol_mul_wide(input0, input1);
let mod_input = pol_extend(input0);
let add_filter = lv[columns::IS_ADDMOD] + lv[columns::IS_ADDFP254];
let sub_filter = lv[columns::IS_SUBMOD] + lv[columns::IS_SUBFP254];
let mul_filter = lv[columns::IS_MULMOD] + lv[columns::IS_MULFP254];
for (input, &filter) in [ for (input, &filter) in [
(&add_input, &lv[columns::IS_ADDMOD]), (&add_input, &add_filter),
(&sub_input, &lv[columns::IS_SUBMOD]), (&sub_input, &sub_filter),
(&mul_input, &lv[columns::IS_MULMOD]), (&mul_input, &mul_filter),
(&mod_input, &(lv[columns::IS_MOD] + lv[columns::IS_DIV])),
] { ] {
// Need constr_poly_copy to be the first argument to // Need constr_poly_copy to be the first argument to
// pol_sub_assign, since it is the longer of the two // pol_sub_assign, since it is the longer of the two
@ -473,14 +512,16 @@ pub(crate) fn eval_packed_generic<P: PackedField>(
} }
} }
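// Each combined filter above (e.g. IS_ADDMOD + IS_ADDFP254) lets a general modular
// operation and its FP254 counterpart share a single evaluation of constr_poly;
// the earlier BN254 check is what pins the modulus in the FP254 case.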
fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>( pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS], lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS], nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>, yield_constr: &mut RecursiveConstraintConsumer<F, D>,
filter: ExtensionTarget<D>, filter: ExtensionTarget<D>,
mut output: [ExtensionTarget<D>; N_LIMBS],
mut modulus: [ExtensionTarget<D>; N_LIMBS],
quot: [ExtensionTarget<D>; 2 * N_LIMBS],
) -> [ExtensionTarget<D>; 2 * N_LIMBS] { ) -> [ExtensionTarget<D>; 2 * N_LIMBS] {
let mut modulus = read_value::<N_LIMBS, _>(lv, MODULAR_MODULUS);
let mod_is_zero = nv[MODULAR_MOD_IS_ZERO]; let mod_is_zero = nv[MODULAR_MOD_IS_ZERO];
let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero); let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero);
@ -494,8 +535,6 @@ fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>
modulus[0] = builder.add_extension(modulus[0], mod_is_zero); modulus[0] = builder.add_extension(modulus[0], mod_is_zero);
let mut output = read_value::<N_LIMBS, _>(lv, MODULAR_OUTPUT);
let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO]; let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO];
let t = builder.mul_sub_extension(mod_is_zero, lv[IS_DIV], div_denom_is_zero); let t = builder.mul_sub_extension(mod_is_zero, lv[IS_DIV], div_denom_is_zero);
let t = builder.mul_extension(filter, t); let t = builder.mul_extension(filter, t);
@ -504,7 +543,9 @@ fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>
let out_aux_red = &nv[MODULAR_OUT_AUX_RED]; let out_aux_red = &nv[MODULAR_OUT_AUX_RED];
let one = builder.one_extension(); let one = builder.one_extension();
let is_less_than =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one);
let zero = builder.zero_extension();
let mut is_less_than = [zero; N_LIMBS];
is_less_than[0] =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one);
eval_ext_circuit_addcy( eval_ext_circuit_addcy(
@ -514,16 +555,10 @@ fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>
&modulus, &modulus,
out_aux_red, out_aux_red,
&output, &output,
is_less_than, &is_less_than,
true, true,
); );
output[0] = builder.sub_extension(output[0], div_denom_is_zero); output[0] = builder.sub_extension(output[0], div_denom_is_zero);
let quot = {
let zero = builder.zero_extension();
let mut quot = [zero; 2 * N_LIMBS];
quot.copy_from_slice(&lv[MODULAR_QUO_INPUT]);
quot
};
let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus);
for &x in prod[2 * N_LIMBS..].iter() { for &x in prod[2 * N_LIMBS..].iter() {
@ -559,31 +594,60 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS], nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>, yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) { ) {
let bn254_filter = builder.add_many_extension([
lv[columns::IS_ADDFP254],
lv[columns::IS_MULFP254],
lv[columns::IS_SUBFP254],
]);
let filter = builder.add_many_extension([ let filter = builder.add_many_extension([
lv[columns::IS_ADDMOD], lv[columns::IS_ADDMOD],
lv[columns::IS_SUBMOD], lv[columns::IS_SUBMOD],
lv[columns::IS_MULMOD], lv[columns::IS_MULMOD],
lv[columns::IS_MOD], bn254_filter,
lv[columns::IS_DIV],
]); ]);
yield_constr.constraint_last_row(builder, filter); yield_constr.constraint_last_row(builder, filter);
let constr_poly = modular_constr_poly_ext_circuit(lv, nv, builder, yield_constr, filter); let modulus = read_value::<N_LIMBS, _>(lv, MODULAR_MODULUS);
for (&mi, bi) in modulus.iter().zip(bn254_modulus_limbs()) {
// bn254_filter * (mi - bi)
let t = builder.arithmetic_extension(
F::ONE,
-F::from_canonical_u16(bi),
mi,
bn254_filter,
bn254_filter,
);
yield_constr.constraint_transition(builder, t);
}
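// Assuming plonky2's arithmetic_extension(c0, c1, x, y, z) computes c0*x*y + c1*z,
// the gate above evaluates to ONE * mi * bn254_filter - bi * bn254_filter,
// i.e. bn254_filter * (mi - bi), mirroring the packed-field constraint.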
let output = read_value::<N_LIMBS, _>(lv, MODULAR_OUTPUT);
let quo_input = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT);
let constr_poly = modular_constr_poly_ext_circuit(
lv,
nv,
builder,
yield_constr,
filter,
output,
modulus,
quo_input,
);
let input0 = read_value(lv, MODULAR_INPUT_0); let input0 = read_value(lv, MODULAR_INPUT_0);
let input1 = read_value(lv, MODULAR_INPUT_1); let input1 = read_value(lv, MODULAR_INPUT_1);
let add_input = pol_add_ext_circuit(builder, input0, input1); let add_input = pol_add_ext_circuit(builder, input0, input1);
let sub_input = pol_sub_ext_circuit(builder, input0, input1); let sub_input = pol_sub_ext_circuit(builder, input0, input1);
let mul_input = pol_mul_wide_ext_circuit(builder, input0, input1); let mul_input = pol_mul_wide_ext_circuit(builder, input0, input1);
let mod_input = pol_extend_ext_circuit(builder, input0);
let mod_div_filter = builder.add_extension(lv[columns::IS_MOD], lv[columns::IS_DIV]);
let add_filter = builder.add_extension(lv[columns::IS_ADDMOD], lv[columns::IS_ADDFP254]);
let sub_filter = builder.add_extension(lv[columns::IS_SUBMOD], lv[columns::IS_SUBFP254]);
let mul_filter = builder.add_extension(lv[columns::IS_MULMOD], lv[columns::IS_MULFP254]);
for (input, &filter) in [ for (input, &filter) in [
(&add_input, &lv[columns::IS_ADDMOD]), (&add_input, &add_filter),
(&sub_input, &lv[columns::IS_SUBMOD]), (&sub_input, &sub_filter),
(&mul_input, &lv[columns::IS_MULMOD]), (&mul_input, &mul_filter),
(&mod_input, &mod_div_filter),
] { ] {
let mut constr_poly_copy = constr_poly; let mut constr_poly_copy = constr_poly;
pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input); pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input);
@ -604,8 +668,17 @@ mod tests {
use super::*; use super::*;
use crate::arithmetic::columns::NUM_ARITH_COLUMNS; use crate::arithmetic::columns::NUM_ARITH_COLUMNS;
use crate::constraint_consumer::ConstraintConsumer; use crate::constraint_consumer::ConstraintConsumer;
use crate::extension_tower::BN_BASE;
const N_RND_TESTS: usize = 1000; const N_RND_TESTS: usize = 1000;
const MODULAR_OPS: [usize; 6] = [
IS_ADDMOD,
IS_SUBMOD,
IS_MULMOD,
IS_ADDFP254,
IS_SUBFP254,
IS_MULFP254,
];
// TODO: Should be able to refactor this test to apply to all operations. // TODO: Should be able to refactor this test to apply to all operations.
#[test] #[test]
@ -617,12 +690,12 @@ mod tests {
let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng)); let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
// if `IS_ADDMOD == 0`, then the constraints should be met even // if `IS_ADDMOD == 0`, then the constraints should be met even
// if all values are garbage. // if all values are garbage (and similarly for the other operations).
lv[IS_ADDMOD] = F::ZERO;
lv[IS_SUBMOD] = F::ZERO;
lv[IS_MULMOD] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[IS_DIV] = F::ZERO;
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
let mut constraint_consumer = ConstraintConsumer::new( let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
@ -630,7 +703,7 @@ mod tests {
GoldilocksField::ONE, GoldilocksField::ONE,
GoldilocksField::ONE, GoldilocksField::ONE,
); );
eval_packed_generic(&lv, &nv, &mut constraint_consumer); eval_packed(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs { for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO); assert_eq!(acc, GoldilocksField::ZERO);
} }
@ -642,7 +715,7 @@ mod tests {
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
for op_filter in [IS_ADDMOD, IS_DIV, IS_SUBMOD, IS_MOD, IS_MULMOD] { for op_filter in MODULAR_OPS {
for i in 0..N_RND_TESTS { for i in 0..N_RND_TESTS {
// set inputs to random values // set inputs to random values
let mut lv = [F::default(); NUM_ARITH_COLUMNS] let mut lv = [F::default(); NUM_ARITH_COLUMNS]
@ -651,16 +724,19 @@ mod tests {
.map(|_| F::from_canonical_u16(rng.gen::<u16>())); .map(|_| F::from_canonical_u16(rng.gen::<u16>()));
// Reset operation columns, then select one // Reset operation columns, then select one
lv[IS_ADDMOD] = F::ZERO;
lv[IS_SUBMOD] = F::ZERO;
lv[IS_MULMOD] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[IS_DIV] = F::ZERO;
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[op_filter] = F::ONE; lv[op_filter] = F::ONE;
let input0 = U256::from(rng.gen::<[u8; 32]>()); let input0 = U256::from(rng.gen::<[u8; 32]>());
let input1 = U256::from(rng.gen::<[u8; 32]>()); let input1 = U256::from(rng.gen::<[u8; 32]>());
let modulus = if [IS_ADDFP254, IS_MULFP254, IS_SUBFP254].contains(&op_filter) {
BN_BASE
} else {
let mut modulus_limbs = [0u8; 32]; let mut modulus_limbs = [0u8; 32];
// For the second half of the tests, set the top // For the second half of the tests, set the top
// 16-start digits of the modulus to zero so it is // 16-start digits of the modulus to zero so it is
@ -672,7 +748,8 @@ mod tests {
*mi = 0u8; *mi = 0u8;
} }
} }
let modulus = U256::from(modulus_limbs); U256::from(modulus_limbs)
};
generate(&mut lv, &mut nv, op_filter, input0, input1, modulus); generate(&mut lv, &mut nv, op_filter, input0, input1, modulus);
@ -682,7 +759,7 @@ mod tests {
GoldilocksField::ZERO, GoldilocksField::ZERO,
GoldilocksField::ZERO, GoldilocksField::ZERO,
); );
eval_packed_generic(&lv, &nv, &mut constraint_consumer); eval_packed(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs { for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO); assert_eq!(acc, GoldilocksField::ZERO);
} }
@ -696,7 +773,7 @@ mod tests {
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_DIV, IS_MOD, IS_MULMOD] { for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_MULMOD] {
for _i in 0..N_RND_TESTS { for _i in 0..N_RND_TESTS {
// set inputs to random values and the modulus to zero; // set inputs to random values and the modulus to zero;
// the output is defined to be zero when modulus is zero. // the output is defined to be zero when modulus is zero.
@ -706,11 +783,11 @@ mod tests {
.map(|_| F::from_canonical_u16(rng.gen::<u16>())); .map(|_| F::from_canonical_u16(rng.gen::<u16>()));
// Reset operation columns, then select one // Reset operation columns, then select one
lv[IS_ADDMOD] = F::ZERO;
lv[IS_SUBMOD] = F::ZERO;
lv[IS_MULMOD] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[IS_DIV] = F::ZERO;
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[op_filter] = F::ONE; lv[op_filter] = F::ONE;
let input0 = U256::from(rng.gen::<[u8; 32]>()); let input0 = U256::from(rng.gen::<[u8; 32]>());
@ -720,11 +797,7 @@ mod tests {
generate(&mut lv, &mut nv, op_filter, input0, input1, modulus); generate(&mut lv, &mut nv, op_filter, input0, input1, modulus);
// check that the correct output was generated // check that the correct output was generated
if op_filter == IS_DIV {
assert!(lv[DIV_OUTPUT].iter().all(|&c| c == F::ZERO));
} else {
assert!(lv[MODULAR_OUTPUT].iter().all(|&c| c == F::ZERO)); assert!(lv[MODULAR_OUTPUT].iter().all(|&c| c == F::ZERO));
}
let mut constraint_consumer = ConstraintConsumer::new( let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
@ -732,22 +805,17 @@ mod tests {
GoldilocksField::ZERO, GoldilocksField::ZERO,
GoldilocksField::ZERO, GoldilocksField::ZERO,
); );
eval_packed_generic(&lv, &nv, &mut constraint_consumer); eval_packed(&lv, &nv, &mut constraint_consumer);
assert!(constraint_consumer assert!(constraint_consumer
.constraint_accs .constraint_accs
.iter() .iter()
.all(|&acc| acc == F::ZERO)); .all(|&acc| acc == F::ZERO));
// Corrupt one output limb by setting it to a non-zero value // Corrupt one output limb by setting it to a non-zero value
if op_filter == IS_DIV {
let random_oi = DIV_OUTPUT.start + rng.gen::<usize>() % N_LIMBS;
lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX));
} else {
let random_oi = MODULAR_OUTPUT.start + rng.gen::<usize>() % N_LIMBS; let random_oi = MODULAR_OUTPUT.start + rng.gen::<usize>() % N_LIMBS;
lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX));
};
eval_packed_generic(&lv, &nv, &mut constraint_consumer); eval_packed(&lv, &nv, &mut constraint_consumer);
// Check that at least one of the constraints was non-zero // Check that at least one of the constraints was non-zero
assert!(constraint_consumer assert!(constraint_consumer

@ -70,11 +70,12 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
pub fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) { pub fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) {
// TODO: It would probably be clearer/cleaner to read the U256 // TODO: It would probably be clearer/cleaner to read the U256
// into an [i64;N] and then copy that to the lv table. // into an [i64;N] and then copy that to the lv table.
u256_to_array(&mut lv[MUL_INPUT_0], left_in); u256_to_array(&mut lv[INPUT_REGISTER_0], left_in);
u256_to_array(&mut lv[MUL_INPUT_1], right_in); u256_to_array(&mut lv[INPUT_REGISTER_1], right_in);
u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero());
let input0 = read_value_i64_limbs(lv, MUL_INPUT_0); let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_0);
let input1 = read_value_i64_limbs(lv, MUL_INPUT_1); let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_1);
const MASK: i64 = (1i64 << LIMB_BITS) - 1i64; const MASK: i64 = (1i64 << LIMB_BITS) - 1i64;
@ -96,7 +97,7 @@ pub fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) {
// aux_limbs to handle the fact that unreduced_prod will // aux_limbs to handle the fact that unreduced_prod will
// inevitably contain one digit's worth that is > 2^256. // inevitably contain one digit's worth that is > 2^256.
lv[MUL_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); lv[OUTPUT_REGISTER].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c)));
pol_sub_assign(&mut unreduced_prod, &output_limbs); pol_sub_assign(&mut unreduced_prod, &output_limbs);
let mut aux_limbs = pol_remove_root_2exp::<LIMB_BITS, _, N_LIMBS>(unreduced_prod); let mut aux_limbs = pol_remove_root_2exp::<LIMB_BITS, _, N_LIMBS>(unreduced_prod);
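// Standalone sketch (ad hoc name, plain i64 limbs) of the "remove root 2^LIMB_BITS"
// step used above: a limb polynomial that vanishes at x = 2^EXP divides exactly by
// (x - 2^EXP), and the quotient is obtained by synthetic division.
fn remove_root_2exp<const EXP: u32, const N: usize>(a: [i64; N]) -> [i64; N] {
    let mut q = [0i64; N];
    let mut rem = 0i64;
    for i in (0..N).rev() {
        q[i] = rem;
        rem = a[i] + (rem << EXP); // running remainder; small limbs assumed, no overflow
    }
    debug_assert_eq!(rem, 0, "a(2^EXP) must be zero");
    q
}
// e.g. remove_root_2exp::<16, 4>([-3 << 16, 3 - (1 << 16), 1, 0]) == [3, 1, 0, 0].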
@ -121,9 +122,9 @@ pub fn eval_packed_generic<P: PackedField>(
let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS);
let is_mul = lv[IS_MUL]; let is_mul = lv[IS_MUL];
let input0_limbs = read_value::<N_LIMBS, _>(lv, MUL_INPUT_0); let input0_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
let input1_limbs = read_value::<N_LIMBS, _>(lv, MUL_INPUT_1); let input1_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_1);
let output_limbs = read_value::<N_LIMBS, _>(lv, MUL_OUTPUT); let output_limbs = read_value::<N_LIMBS, _>(lv, OUTPUT_REGISTER);
let aux_limbs = { let aux_limbs = {
// MUL_AUX_INPUT was offset by 2^20 in generation, so we undo // MUL_AUX_INPUT was offset by 2^20 in generation, so we undo
@ -173,9 +174,9 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
yield_constr: &mut RecursiveConstraintConsumer<F, D>, yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) { ) {
let is_mul = lv[IS_MUL]; let is_mul = lv[IS_MUL];
let input0_limbs = read_value::<N_LIMBS, _>(lv, MUL_INPUT_0); let input0_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
let input1_limbs = read_value::<N_LIMBS, _>(lv, MUL_INPUT_1); let input1_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_1);
let output_limbs = read_value::<N_LIMBS, _>(lv, MUL_OUTPUT); let output_limbs = read_value::<N_LIMBS, _>(lv, OUTPUT_REGISTER);
let aux_limbs = { let aux_limbs = {
let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS)); let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS));
@ -253,7 +254,7 @@ mod tests {
for _i in 0..N_RND_TESTS { for _i in 0..N_RND_TESTS {
// set inputs to random values // set inputs to random values
for (ai, bi) in MUL_INPUT_0.zip(MUL_INPUT_1) { for (ai, bi) in INPUT_REGISTER_0.zip(INPUT_REGISTER_1) {
lv[ai] = F::from_canonical_u16(rng.gen()); lv[ai] = F::from_canonical_u16(rng.gen());
lv[bi] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen());
} }

@ -227,17 +227,6 @@ where
zero_extend zero_extend
} }
pub(crate) fn pol_extend_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
a: [ExtensionTarget<D>; N_LIMBS],
) -> [ExtensionTarget<D>; 2 * N_LIMBS - 1] {
let zero = builder.zero_extension();
let mut zero_extend = [zero; 2 * N_LIMBS - 1];
zero_extend[..N_LIMBS].copy_from_slice(&a);
zero_extend
}
/// Given polynomial a(x) = \sum_{i=0}^{N-2} a[i] x^i and an element /// Given polynomial a(x) = \sum_{i=0}^{N-2} a[i] x^i and an element
/// `root`, return b = (x - root) * a(x). /// `root`, return b = (x - root) * a(x).
pub(crate) fn pol_adjoin_root<T, U, const N: usize>(a: [T; N], root: U) -> [T; N] pub(crate) fn pol_adjoin_root<T, U, const N: usize>(a: [T; N], root: U) -> [T; N]

@ -8,6 +8,7 @@ use plonky2::field::packed::PackedField;
use plonky2::field::types::Field; use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField; use plonky2::hash::hash_types::RichField;
use crate::all_stark::Table;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS};
use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::membus::NUM_GP_CHANNELS;
@ -15,7 +16,7 @@ use crate::cpu::{
bootstrap_kernel, contextops, control_flow, decode, dup_swap, gas, jumps, membus, memio, bootstrap_kernel, contextops, control_flow, decode, dup_swap, gas, jumps, membus, memio,
modfp254, pc, shift, simple_logic, stack, stack_bounds, syscalls, modfp254, pc, shift, simple_logic, stack, stack_bounds, syscalls,
}; };
use crate::cross_table_lookup::Column; use crate::cross_table_lookup::{Column, TableWithColumns};
use crate::memory::segments::Segment; use crate::memory::segments::Segment;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::stark::Stark; use crate::stark::Stark;
@ -45,8 +46,10 @@ pub fn ctl_filter_keccak_sponge<F: Field>() -> Column<F> {
Column::single(COL_MAP.is_keccak_sponge) Column::single(COL_MAP.is_keccak_sponge)
} }
pub fn ctl_data_logic<F: Field>() -> Vec<Column<F>> { /// Create the vector of Columns corresponding to the two inputs and
let mut res = Column::singles([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]).collect_vec(); /// one output of a binary operation.
fn ctl_data_binops<F: Field>(ops: &[usize]) -> Vec<Column<F>> {
let mut res = Column::singles(ops).collect_vec();
res.extend(Column::singles(COL_MAP.mem_channels[0].value)); res.extend(Column::singles(COL_MAP.mem_channels[0].value));
res.extend(Column::singles(COL_MAP.mem_channels[1].value)); res.extend(Column::singles(COL_MAP.mem_channels[1].value));
res.extend(Column::singles( res.extend(Column::singles(
@ -55,10 +58,51 @@ pub fn ctl_data_logic<F: Field>() -> Vec<Column<F>> {
res res
} }
/// Create the vector of Columns corresponding to the three inputs and
/// one output of a ternary operation.
fn ctl_data_ternops<F: Field>(ops: &[usize]) -> Vec<Column<F>> {
let mut res = Column::singles(ops).collect_vec();
res.extend(Column::singles(COL_MAP.mem_channels[0].value));
res.extend(Column::singles(COL_MAP.mem_channels[1].value));
res.extend(Column::singles(COL_MAP.mem_channels[2].value));
res.extend(Column::singles(
COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value,
));
res
}
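// The resulting column order is: the `ops` filter columns, then the values of
// memory channels 0, 1 and 2 (the up-to-three inputs), then the value of the last
// general-purpose channel, which carries the operation's result.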
pub fn ctl_data_logic<F: Field>() -> Vec<Column<F>> {
ctl_data_binops(&[COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor])
}
pub fn ctl_filter_logic<F: Field>() -> Column<F> { pub fn ctl_filter_logic<F: Field>() -> Column<F> {
Column::sum([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor]) Column::sum([COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor])
} }
pub fn ctl_arithmetic_rows<F: Field>() -> TableWithColumns<F> {
const OPS: [usize; 13] = [
COL_MAP.op.add,
COL_MAP.op.sub,
COL_MAP.op.mul,
COL_MAP.op.lt,
COL_MAP.op.gt,
COL_MAP.op.addfp254,
COL_MAP.op.mulfp254,
COL_MAP.op.subfp254,
COL_MAP.op.addmod,
COL_MAP.op.mulmod,
COL_MAP.op.submod,
COL_MAP.op.div,
COL_MAP.op.mod_,
];
// Create the CPU Table whose columns are those with the three
// inputs and one output of the ternary operations listed in `ops`
// (also `ops` is used as the operation filter). The list of
// operations includes binary operations which will simply ignore
// the third input.
TableWithColumns::new(Table::Cpu, ctl_data_ternops(&OPS), Some(Column::sum(OPS)))
}
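// Assuming the CPU decode constraints keep the operation flags one-hot (at most one
// set per row), Column::sum(OPS) evaluates to 0 or 1 and selects exactly the rows
// that the arithmetic table must reproduce through the cross-table lookup.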
pub const MEM_CODE_CHANNEL_IDX: usize = 0; pub const MEM_CODE_CHANNEL_IDX: usize = 0;
pub const MEM_GP_CHANNELS_IDX_START: usize = MEM_CODE_CHANNEL_IDX + 1; pub const MEM_GP_CHANNELS_IDX_START: usize = MEM_CODE_CHANNEL_IDX + 1;

@ -26,6 +26,7 @@ use plonky2::util::timing::TimingTree;
use plonky2_util::log2_ceil; use plonky2_util::log2_ceil;
use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES};
use crate::arithmetic::arithmetic_stark::ArithmeticStark;
use crate::config::StarkConfig; use crate::config::StarkConfig;
use crate::cpu::cpu_stark::CpuStark; use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup};
@ -265,6 +266,7 @@ where
F: RichField + Extendable<D>, F: RichField + Extendable<D>,
C: GenericConfig<D, F = F> + 'static, C: GenericConfig<D, F = F> + 'static,
C::Hasher: AlgebraicHasher<F>, C::Hasher: AlgebraicHasher<F>,
[(); ArithmeticStark::<F, D>::COLUMNS]:,
[(); CpuStark::<F, D>::COLUMNS]:, [(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:, [(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:, [(); KeccakSpongeStark::<F, D>::COLUMNS]:,
@ -338,43 +340,50 @@ where
degree_bits_ranges: &[Range<usize>; NUM_TABLES], degree_bits_ranges: &[Range<usize>; NUM_TABLES],
stark_config: &StarkConfig, stark_config: &StarkConfig,
) -> Self { ) -> Self {
let arithmetic = RecursiveCircuitsForTable::new(
Table::Arithmetic,
&all_stark.arithmetic_stark,
degree_bits_ranges[0].clone(),
&all_stark.cross_table_lookups,
stark_config,
);
let cpu = RecursiveCircuitsForTable::new( let cpu = RecursiveCircuitsForTable::new(
Table::Cpu, Table::Cpu,
&all_stark.cpu_stark, &all_stark.cpu_stark,
degree_bits_ranges[0].clone(), degree_bits_ranges[1].clone(),
&all_stark.cross_table_lookups, &all_stark.cross_table_lookups,
stark_config, stark_config,
); );
let keccak = RecursiveCircuitsForTable::new( let keccak = RecursiveCircuitsForTable::new(
Table::Keccak, Table::Keccak,
&all_stark.keccak_stark, &all_stark.keccak_stark,
degree_bits_ranges[1].clone(), degree_bits_ranges[2].clone(),
&all_stark.cross_table_lookups, &all_stark.cross_table_lookups,
stark_config, stark_config,
); );
let keccak_sponge = RecursiveCircuitsForTable::new( let keccak_sponge = RecursiveCircuitsForTable::new(
Table::KeccakSponge, Table::KeccakSponge,
&all_stark.keccak_sponge_stark, &all_stark.keccak_sponge_stark,
degree_bits_ranges[2].clone(), degree_bits_ranges[3].clone(),
&all_stark.cross_table_lookups, &all_stark.cross_table_lookups,
stark_config, stark_config,
); );
let logic = RecursiveCircuitsForTable::new( let logic = RecursiveCircuitsForTable::new(
Table::Logic, Table::Logic,
&all_stark.logic_stark, &all_stark.logic_stark,
degree_bits_ranges[3].clone(), degree_bits_ranges[4].clone(),
&all_stark.cross_table_lookups, &all_stark.cross_table_lookups,
stark_config, stark_config,
); );
let memory = RecursiveCircuitsForTable::new( let memory = RecursiveCircuitsForTable::new(
Table::Memory, Table::Memory,
&all_stark.memory_stark, &all_stark.memory_stark,
degree_bits_ranges[4].clone(), degree_bits_ranges[5].clone(),
&all_stark.cross_table_lookups, &all_stark.cross_table_lookups,
stark_config, stark_config,
); );
let by_table = [cpu, keccak, keccak_sponge, logic, memory]; let by_table = [arithmetic, cpu, keccak, keccak_sponge, logic, memory];
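// With Arithmetic inserted at index 0, the order here and in degree_bits_ranges is
// Arithmetic, Cpu, Keccak, KeccakSponge, Logic, Memory, which is why every
// pre-existing table's range index above moved up by one.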
let root = Self::create_root_circuit(&by_table, stark_config); let root = Self::create_root_circuit(&by_table, stark_config);
let aggregation = Self::create_aggregation_circuit(&root); let aggregation = Self::create_aggregation_circuit(&root);
let block = Self::create_block_circuit(&aggregation); let block = Self::create_block_circuit(&aggregation);

@ -20,6 +20,7 @@ use plonky2_maybe_rayon::*;
use plonky2_util::{log2_ceil, log2_strict}; use plonky2_util::{log2_ceil, log2_strict};
use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::all_stark::{AllStark, Table, NUM_TABLES};
use crate::arithmetic::arithmetic_stark::ArithmeticStark;
use crate::config::StarkConfig; use crate::config::StarkConfig;
use crate::constraint_consumer::ConstraintConsumer; use crate::constraint_consumer::ConstraintConsumer;
use crate::cpu::cpu_stark::CpuStark; use crate::cpu::cpu_stark::CpuStark;
@ -50,6 +51,7 @@ pub fn prove<F, C, const D: usize>(
where where
F: RichField + Extendable<D>, F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>, C: GenericConfig<D, F = F>,
[(); ArithmeticStark::<F, D>::COLUMNS]:,
[(); CpuStark::<F, D>::COLUMNS]:, [(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:, [(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:, [(); KeccakSpongeStark::<F, D>::COLUMNS]:,
@ -71,6 +73,7 @@ pub fn prove_with_outputs<F, C, const D: usize>(
where where
F: RichField + Extendable<D>, F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>, C: GenericConfig<D, F = F>,
[(); ArithmeticStark::<F, D>::COLUMNS]:,
[(); CpuStark::<F, D>::COLUMNS]:, [(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:, [(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:, [(); KeccakSpongeStark::<F, D>::COLUMNS]:,
@ -98,6 +101,7 @@ pub(crate) fn prove_with_traces<F, C, const D: usize>(
where where
F: RichField + Extendable<D>, F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>, C: GenericConfig<D, F = F>,
[(); ArithmeticStark::<F, D>::COLUMNS]:,
[(); CpuStark::<F, D>::COLUMNS]:, [(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:, [(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:, [(); KeccakSpongeStark::<F, D>::COLUMNS]:,
@ -185,12 +189,26 @@ fn prove_with_commitments<F, C, const D: usize>(
where where
F: RichField + Extendable<D>, F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>, C: GenericConfig<D, F = F>,
[(); ArithmeticStark::<F, D>::COLUMNS]:,
[(); CpuStark::<F, D>::COLUMNS]:, [(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:, [(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:, [(); KeccakSpongeStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::COLUMNS]:, [(); LogicStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::COLUMNS]:, [(); MemoryStark::<F, D>::COLUMNS]:,
{ {
let arithmetic_proof = timed!(
timing,
"prove Arithmetic STARK",
prove_single_table(
&all_stark.arithmetic_stark,
config,
&trace_poly_values[Table::Arithmetic as usize],
&trace_commitments[Table::Arithmetic as usize],
&ctl_data_per_table[Table::Arithmetic as usize],
challenger,
timing,
)?
);
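// The arithmetic proof is generated first and placed at index Table::Arithmetic in
// the array returned below, matching the new table ordering.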
let cpu_proof = timed!( let cpu_proof = timed!(
timing, timing,
"prove CPU STARK", "prove CPU STARK",
@ -257,6 +275,7 @@ where
)? )?
); );
Ok([ Ok([
arithmetic_proof,
cpu_proof, cpu_proof,
keccak_proof, keccak_proof,
keccak_sponge_proof, keccak_sponge_proof,

@ -9,6 +9,7 @@ use plonky2::plonk::config::GenericConfig;
use plonky2::plonk::plonk_common::reduce_with_powers; use plonky2::plonk::plonk_common::reduce_with_powers;
use crate::all_stark::{AllStark, Table}; use crate::all_stark::{AllStark, Table};
use crate::arithmetic::arithmetic_stark::ArithmeticStark;
use crate::config::StarkConfig; use crate::config::StarkConfig;
use crate::constraint_consumer::ConstraintConsumer; use crate::constraint_consumer::ConstraintConsumer;
use crate::cpu::cpu_stark::CpuStark; use crate::cpu::cpu_stark::CpuStark;
@ -31,6 +32,7 @@ pub fn verify_proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, co
config: &StarkConfig, config: &StarkConfig,
) -> Result<()> ) -> Result<()>
where where
[(); ArithmeticStark::<F, D>::COLUMNS]:,
[(); CpuStark::<F, D>::COLUMNS]:, [(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:, [(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:, [(); KeccakSpongeStark::<F, D>::COLUMNS]:,
@ -45,6 +47,7 @@ where
let nums_permutation_zs = all_stark.nums_permutation_zs(config); let nums_permutation_zs = all_stark.nums_permutation_zs(config);
let AllStark { let AllStark {
arithmetic_stark,
cpu_stark, cpu_stark,
keccak_stark, keccak_stark,
keccak_sponge_stark, keccak_sponge_stark,
@ -60,6 +63,13 @@ where
&nums_permutation_zs, &nums_permutation_zs,
); );
verify_stark_proof_with_challenges(
arithmetic_stark,
&all_proof.stark_proofs[Table::Arithmetic as usize].proof,
&stark_challenges[Table::Arithmetic as usize],
&ctl_vars_per_table[Table::Arithmetic as usize],
config,
)?;
verify_stark_proof_with_challenges( verify_stark_proof_with_challenges(
cpu_stark, cpu_stark,
&all_proof.stark_proofs[Table::Cpu as usize].proof, &all_proof.stark_proofs[Table::Cpu as usize].proof,

@ -18,19 +18,19 @@ use crate::{arithmetic, keccak, logic};
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct TraceCheckpoint { pub struct TraceCheckpoint {
pub(self) arithmetic_len: usize,
pub(self) cpu_len: usize, pub(self) cpu_len: usize,
pub(self) keccak_len: usize, pub(self) keccak_len: usize,
pub(self) keccak_sponge_len: usize, pub(self) keccak_sponge_len: usize,
pub(self) logic_len: usize, pub(self) logic_len: usize,
pub(self) arithmetic_len: usize,
pub(self) memory_len: usize, pub(self) memory_len: usize,
} }
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Traces<T: Copy> { pub(crate) struct Traces<T: Copy> {
pub(crate) arithmetic_ops: Vec<arithmetic::Operation>,
pub(crate) cpu: Vec<CpuColumnsView<T>>, pub(crate) cpu: Vec<CpuColumnsView<T>>,
pub(crate) logic_ops: Vec<logic::Operation>, pub(crate) logic_ops: Vec<logic::Operation>,
pub(crate) arithmetic: Vec<arithmetic::Operation>,
pub(crate) memory_ops: Vec<MemoryOp>, pub(crate) memory_ops: Vec<MemoryOp>,
pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>,
pub(crate) keccak_sponge_ops: Vec<KeccakSpongeOp>, pub(crate) keccak_sponge_ops: Vec<KeccakSpongeOp>,
@ -39,9 +39,9 @@ pub(crate) struct Traces<T: Copy> {
impl<T: Copy> Traces<T> { impl<T: Copy> Traces<T> {
pub fn new() -> Self { pub fn new() -> Self {
Traces { Traces {
arithmetic_ops: vec![],
cpu: vec![], cpu: vec![],
logic_ops: vec![], logic_ops: vec![],
arithmetic: vec![],
memory_ops: vec![], memory_ops: vec![],
keccak_inputs: vec![], keccak_inputs: vec![],
keccak_sponge_ops: vec![], keccak_sponge_ops: vec![],
@ -50,22 +50,22 @@ impl<T: Copy> Traces<T> {
pub fn checkpoint(&self) -> TraceCheckpoint { pub fn checkpoint(&self) -> TraceCheckpoint {
TraceCheckpoint { TraceCheckpoint {
arithmetic_len: self.arithmetic_ops.len(),
cpu_len: self.cpu.len(), cpu_len: self.cpu.len(),
keccak_len: self.keccak_inputs.len(), keccak_len: self.keccak_inputs.len(),
keccak_sponge_len: self.keccak_sponge_ops.len(), keccak_sponge_len: self.keccak_sponge_ops.len(),
logic_len: self.logic_ops.len(), logic_len: self.logic_ops.len(),
arithmetic_len: self.arithmetic.len(),
memory_len: self.memory_ops.len(), memory_len: self.memory_ops.len(),
} }
} }
pub fn rollback(&mut self, checkpoint: TraceCheckpoint) { pub fn rollback(&mut self, checkpoint: TraceCheckpoint) {
self.arithmetic_ops.truncate(checkpoint.arithmetic_len);
self.cpu.truncate(checkpoint.cpu_len); self.cpu.truncate(checkpoint.cpu_len);
self.keccak_inputs.truncate(checkpoint.keccak_len); self.keccak_inputs.truncate(checkpoint.keccak_len);
self.keccak_sponge_ops self.keccak_sponge_ops
.truncate(checkpoint.keccak_sponge_len); .truncate(checkpoint.keccak_sponge_len);
self.logic_ops.truncate(checkpoint.logic_len); self.logic_ops.truncate(checkpoint.logic_len);
self.arithmetic.truncate(checkpoint.arithmetic_len);
self.memory_ops.truncate(checkpoint.memory_len); self.memory_ops.truncate(checkpoint.memory_len);
} }
@ -82,7 +82,7 @@ impl<T: Copy> Traces<T> {
} }
pub fn push_arithmetic(&mut self, op: arithmetic::Operation) { pub fn push_arithmetic(&mut self, op: arithmetic::Operation) {
self.arithmetic.push(op); self.arithmetic_ops.push(op);
} }
pub fn push_memory(&mut self, op: MemoryOp) { pub fn push_memory(&mut self, op: MemoryOp) {
@ -122,14 +122,20 @@ impl<T: Copy> Traces<T> {
{ {
let cap_elements = config.fri_config.num_cap_elements(); let cap_elements = config.fri_config.num_cap_elements();
let Traces { let Traces {
arithmetic_ops,
cpu, cpu,
logic_ops, logic_ops,
arithmetic: _, // TODO
memory_ops, memory_ops,
keccak_inputs, keccak_inputs,
keccak_sponge_ops, keccak_sponge_ops,
} = self; } = self;
let arithmetic_trace = timed!(
timing,
"generate arithmetic trace",
all_stark.arithmetic_stark.generate_trace(arithmetic_ops)
);
let cpu_rows = cpu.into_iter().map(|x| x.into()).collect(); let cpu_rows = cpu.into_iter().map(|x| x.into()).collect();
let cpu_trace = trace_rows_to_poly_values(cpu_rows); let cpu_trace = trace_rows_to_poly_values(cpu_rows);
let keccak_trace = timed!( let keccak_trace = timed!(
@ -160,6 +166,7 @@ impl<T: Copy> Traces<T> {
); );
[ [
arithmetic_trace,
cpu_trace, cpu_trace,
keccak_trace, keccak_trace,
keccak_sponge_trace, keccak_sponge_trace,

@ -97,7 +97,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
let all_circuits = AllRecursiveCircuits::<F, C, D>::new( let all_circuits = AllRecursiveCircuits::<F, C, D>::new(
&all_stark, &all_stark,
&[9..15, 9..15, 9..10, 9..12, 9..18], // Minimal ranges to prove an empty list &[9..18, 9..15, 9..15, 9..10, 9..12, 9..18], // Minimal ranges to prove an empty list
&config, &config,
); );
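// The added leading range (9..18) is the degree-bits range for the new Arithmetic
// table, which now occupies index 0 of degree_bits_ranges.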