Merge branch 'main' into evm_generation

Daniel Lubarov 2022-07-11 17:23:22 -07:00
commit ef3addea2c
6 changed files with 454 additions and 125 deletions

View File

@@ -59,7 +59,7 @@ impl<F: RichField + Extendable<D>, const D: usize> AllStark<F, D> {
}
}
-#[derive(Copy, Clone)]
+#[derive(Debug, Copy, Clone)]
pub enum Table {
Cpu = 0,
Keccak = 1,
@@ -132,7 +132,7 @@ mod tests {
use ethereum_types::U256;
use itertools::{izip, Itertools};
use plonky2::field::polynomial::PolynomialValues;
-use plonky2::field::types::Field;
+use plonky2::field::types::{Field, PrimeField64};
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::CircuitConfig;
@@ -143,6 +143,7 @@ mod tests {
use crate::all_stark::AllStark;
use crate::config::StarkConfig;
use crate::cpu::cpu_stark::CpuStark;
+use crate::cross_table_lookup::testutils::check_ctls;
use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS};
use crate::logic::{self, LogicStark, Operation};
use crate::memory::memory_stark::{generate_random_memory_ops, MemoryStark};
@@ -197,9 +198,11 @@ mod tests {
num_memory_ops: usize,
memory_stark: &MemoryStark<F, D>,
rng: &mut R,
-) -> Vec<PolynomialValues<F>> {
+) -> (Vec<PolynomialValues<F>>, usize) {
let memory_ops = generate_random_memory_ops(num_memory_ops, rng);
-memory_stark.generate_trace(memory_ops)
+let trace = memory_stark.generate_trace(memory_ops);
+let num_ops = trace[0].values.len();
+(trace, num_ops)
}
fn make_cpu_trace(
@@ -288,32 +291,34 @@
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
-let mut current_cpu_index = 0;
-let mut last_timestamp = memory_trace[memory::columns::TIMESTAMP].values[0];
for i in 0..num_memory_ops {
-let mem_timestamp = memory_trace[memory::columns::TIMESTAMP].values[i];
-let clock = mem_timestamp;
-let op = (0..NUM_CHANNELS)
-.filter(|&o| memory_trace[memory::columns::is_channel(o)].values[i] == F::ONE)
-.collect_vec()[0];
+let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i]
+.to_canonical_u64()
+.try_into()
+.unwrap();
+let clock = mem_timestamp / NUM_CHANNELS;
+let channel = mem_timestamp % NUM_CHANNELS;
-if mem_timestamp != last_timestamp {
-current_cpu_index += 1;
-last_timestamp = mem_timestamp;
-}
+let is_padding_row = (0..NUM_CHANNELS)
+.map(|c| memory_trace[memory::columns::is_channel(c)].values[i])
+.all(|x| x == F::ZERO);
-let row: &mut cpu::columns::CpuColumnsView<F> =
-cpu_trace_rows[current_cpu_index].borrow_mut();
+if !is_padding_row {
+let row: &mut cpu::columns::CpuColumnsView<F> = cpu_trace_rows[clock].borrow_mut();
-row.mem_channel_used[op] = F::ONE;
-row.clock = clock;
-row.mem_is_read[op] = memory_trace[memory::columns::IS_READ].values[i];
-row.mem_addr_context[op] = memory_trace[memory::columns::ADDR_CONTEXT].values[i];
-row.mem_addr_segment[op] = memory_trace[memory::columns::ADDR_SEGMENT].values[i];
-row.mem_addr_virtual[op] = memory_trace[memory::columns::ADDR_VIRTUAL].values[i];
-for j in 0..8 {
-row.mem_value[op][j] = memory_trace[memory::columns::value_limb(j)].values[i];
+row.mem_channel_used[channel] = F::ONE;
+row.clock = F::from_canonical_usize(clock);
+row.mem_is_read[channel] = memory_trace[memory::columns::IS_READ].values[i];
+row.mem_addr_context[channel] =
+memory_trace[memory::columns::ADDR_CONTEXT].values[i];
+row.mem_addr_segment[channel] =
+memory_trace[memory::columns::ADDR_SEGMENT].values[i];
+row.mem_addr_virtual[channel] =
+memory_trace[memory::columns::ADDR_VIRTUAL].values[i];
+for j in 0..8 {
+row.mem_value[channel][j] =
+memory_trace[memory::columns::value_limb(j)].values[i];
+}
}
}
@@ -336,6 +341,8 @@
let keccak_trace = make_keccak_trace(num_keccak_perms, &all_stark.keccak_stark, &mut rng);
let logic_trace = make_logic_trace(num_logic_rows, &all_stark.logic_stark, &mut rng);
-let mut memory_trace = make_memory_trace(num_memory_ops, &all_stark.memory_stark, &mut rng);
+let mem_trace = make_memory_trace(num_memory_ops, &all_stark.memory_stark, &mut rng);
+let mut memory_trace = mem_trace.0;
+let num_memory_ops = mem_trace.1;
let cpu_trace = make_cpu_trace(
num_keccak_perms,
num_logic_rows,
@@ -346,10 +353,13 @@
&mut memory_trace,
);
+let traces = vec![cpu_trace, keccak_trace, logic_trace, memory_trace];
+check_ctls(&traces, &all_stark.cross_table_lookups);
let proof = prove::<F, C, D>(
&all_stark,
config,
-vec![cpu_trace, keccak_trace, logic_trace, memory_trace],
+traces,
vec![vec![]; 4],
&mut TimingTree::default(),
)?;

View File

@@ -40,7 +40,6 @@ pub fn ctl_filter_logic<F: Field>() -> Column<F> {
pub fn ctl_data_memory<F: Field>(channel: usize) -> Vec<Column<F>> {
debug_assert!(channel < NUM_CHANNELS);
let mut cols: Vec<Column<F>> = Column::singles([
-COL_MAP.clock,
COL_MAP.mem_is_read[channel],
COL_MAP.mem_addr_context[channel],
COL_MAP.mem_addr_segment[channel],
@@ -48,6 +47,14 @@ pub fn ctl_data_memory<F: Field>(channel: usize) -> Vec<Column<F>> {
])
.collect_vec();
cols.extend(Column::singles(COL_MAP.mem_value[channel]));
+let scalar = F::from_canonical_usize(NUM_CHANNELS);
+let addend = F::from_canonical_usize(channel);
+cols.push(Column::linear_combination_with_constant(
+vec![(COL_MAP.clock, scalar)],
+addend,
+));
cols
}
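The CTL tuple for each memory channel must agree between the two tables: the memory side now exposes TIMESTAMP as its last column, and the CPU side reconstructs the same value from its clock as clock * NUM_CHANNELS + channel via the linear combination pushed above. A minimal standalone sketch of that encoding, not part of the diff, with NUM_CHANNELS = 4 as an assumed placeholder value:

const NUM_CHANNELS: usize = 4; // placeholder; the real constant lives in crate::memory

fn encode_timestamp(clock: usize, channel: usize) -> usize {
    // CPU side: linear combination clock * NUM_CHANNELS + channel.
    clock * NUM_CHANNELS + channel
}

fn decode_timestamp(timestamp: usize) -> (usize, usize) {
    // Trace-generation side: recover (clock, channel) from a memory timestamp.
    (timestamp / NUM_CHANNELS, timestamp % NUM_CHANNELS)
}

fn main() {
    // Round-tripping shows the two sides compute the same CTL value.
    for ts in 0..2 * NUM_CHANNELS {
        let (clock, channel) = decode_timestamp(ts);
        assert_eq!(encode_timestamp(clock, channel), ts);
    }
}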

View File

@@ -649,3 +649,137 @@ pub(crate) fn verify_cross_table_lookups_circuit<
}
debug_assert!(ctl_zs_openings.iter_mut().all(|iter| iter.next().is_none()));
}
#[cfg(test)]
pub(crate) mod testutils {
use std::collections::HashMap;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use crate::all_stark::Table;
use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns};
type MultiSet<F> = HashMap<Vec<F>, Vec<(Table, usize)>>;
/// Check that the provided traces and cross-table lookups are consistent.
pub(crate) fn check_ctls<F: Field>(
trace_poly_values: &[Vec<PolynomialValues<F>>],
cross_table_lookups: &[CrossTableLookup<F>],
) {
for (i, ctl) in cross_table_lookups.iter().enumerate() {
check_ctl(trace_poly_values, ctl, i);
}
}
fn check_ctl<F: Field>(
trace_poly_values: &[Vec<PolynomialValues<F>>],
ctl: &CrossTableLookup<F>,
ctl_index: usize,
) {
let CrossTableLookup {
looking_tables,
looked_table,
default,
} = ctl;
// Build multisets `m` such that `(table, i)` is in `m[row]` iff the `i`-th row of `table` equals `row`
// and its filter is 1. Without default values, the CTL check holds iff `looking_multiset == looked_multiset`.
let mut looking_multiset = MultiSet::<F>::new();
let mut looked_multiset = MultiSet::<F>::new();
for table in looking_tables {
process_table(trace_poly_values, table, &mut looking_multiset);
}
process_table(trace_poly_values, looked_table, &mut looked_multiset);
let empty = &vec![];
// Check that every row in the looking tables appears in the looked table the same number of times
// with some special logic for the default row.
for (row, looking_locations) in &looking_multiset {
let looked_locations = looked_multiset.get(row).unwrap_or(empty);
if let Some(default) = default {
if row == default {
continue;
}
}
check_locations(looking_locations, looked_locations, ctl_index, row);
}
let extra_default_count = default.as_ref().map(|d| {
let looking_default_locations = looking_multiset.get(d).unwrap_or(empty);
let looked_default_locations = looked_multiset.get(d).unwrap_or(empty);
looking_default_locations
.len()
.checked_sub(looked_default_locations.len())
.unwrap_or_else(|| {
// Underflow means the looked side has more default rows than the looking side, which is an
// error; re-run check_locations to panic with a descriptive message.
check_locations(
looking_default_locations,
looked_default_locations,
ctl_index,
d,
);
unreachable!()
})
});
// Check that the number of extra default rows is correct.
if let Some(count) = extra_default_count {
assert_eq!(
count,
looking_tables
.iter()
.map(|table| trace_poly_values[table.table as usize][0].len())
.sum::<usize>()
- trace_poly_values[looked_table.table as usize][0].len()
);
}
// Check that every row in the looked table appears in the looking tables the same number of times.
for (row, looked_locations) in &looked_multiset {
let looking_locations = looking_multiset.get(row).unwrap_or(empty);
check_locations(looking_locations, looked_locations, ctl_index, row);
}
}
fn process_table<F: Field>(
trace_poly_values: &[Vec<PolynomialValues<F>>],
table: &TableWithColumns<F>,
multiset: &mut MultiSet<F>,
) {
let trace = &trace_poly_values[table.table as usize];
for i in 0..trace[0].len() {
let filter = if let Some(column) = &table.filter_column {
column.eval_table(trace, i)
} else {
F::ONE
};
if filter.is_one() {
let row = table
.columns
.iter()
.map(|c| c.eval_table(trace, i))
.collect::<Vec<_>>();
multiset.entry(row).or_default().push((table.table, i));
} else {
assert_eq!(filter, F::ZERO, "Non-binary filter?")
}
}
}
fn check_locations<F: Field>(
looking_locations: &[(Table, usize)],
looked_locations: &[(Table, usize)],
ctl_index: usize,
row: &[F],
) {
if looking_locations.len() != looked_locations.len() {
panic!(
"CTL #{ctl_index}:\n\
Row {row:?} is present {l0} times in the looking tables, but {l1} times in the looked table.\n\
Looking locations (Table, Row index): {looking_locations:?}.\n\
Looked locations (Table, Row index): {looked_locations:?}.",
l0 = looking_locations.len(),
l1 = looked_locations.len(),
);
}
}
}
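In plain terms, check_ctl turns each side of the lookup into a multiset: every row that passes its filter becomes a key mapped to the locations where it occurs, and the looking tables together must produce exactly the rows of the looked table (up to the default-row accounting above). A self-contained sketch of that comparison, assuming plain u64 rows and omitting default-row handling:

use std::collections::HashMap;

// Count filtered rows; a simplified stand-in for process_table.
fn multiset(rows: &[(Vec<u64>, bool)]) -> HashMap<Vec<u64>, usize> {
    let mut m = HashMap::new();
    for (row, filter) in rows {
        if *filter {
            *m.entry(row.clone()).or_insert(0) += 1;
        }
    }
    m
}

fn main() {
    // Two looking tables; rows with a false filter are ignored.
    let looking_a = vec![(vec![1, 2], true), (vec![9, 9], false)];
    let looking_b = vec![(vec![3, 4], true)];
    let looked = vec![(vec![1, 2], true), (vec![3, 4], true)];

    let mut looking = multiset(&looking_a);
    for (row, count) in multiset(&looking_b) {
        *looking.entry(row).or_insert(0) += count;
    }
    // The CTL consistency condition: equal multisets on both sides.
    assert_eq!(looking, multiset(&looked));
}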

View File

@@ -1,5 +1,7 @@
//! Memory registers.
+use std::ops::Range;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
pub(crate) const TIMESTAMP: usize = 0;
@@ -36,20 +38,22 @@ pub(crate) const CONTEXT_FIRST_CHANGE: usize = SORTED_VALUE_START + VALUE_LIMBS;
pub(crate) const SEGMENT_FIRST_CHANGE: usize = CONTEXT_FIRST_CHANGE + 1;
pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1;
+// Flags to indicate if this operation came from the `i`th channel of the memory bus.
+const IS_CHANNEL_START: usize = VIRTUAL_FIRST_CHANGE + 1;
+pub(crate) const fn is_channel(channel: usize) -> usize {
+debug_assert!(channel < NUM_CHANNELS);
+IS_CHANNEL_START + channel
+}
// We use a range check to ensure sorting.
-pub(crate) const RANGE_CHECK: usize = VIRTUAL_FIRST_CHANGE + 1;
+pub(crate) const RANGE_CHECK: usize = IS_CHANNEL_START + NUM_CHANNELS;
// The counter column (used for the range check) starts from 0 and increments.
pub(crate) const COUNTER: usize = RANGE_CHECK + 1;
// Helper columns for the permutation argument used to enforce the range check.
pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1;
pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1;
-// Flags to indicate if this operation came from the `i`th channel of the memory bus.
-const IS_CHANNEL_START: usize = COUNTER_PERMUTED + 1;
-#[allow(dead_code)]
-pub(crate) const fn is_channel(channel: usize) -> usize {
-debug_assert!(channel < NUM_CHANNELS);
-IS_CHANNEL_START + channel
-}
+// Columns to be padded at the top with zeroes, before the permutation argument takes place.
+pub(crate) const COLUMNS_TO_PAD: Range<usize> = TIMESTAMP..RANGE_CHECK + 1;
-pub(crate) const NUM_COLUMNS: usize = IS_CHANNEL_START + NUM_CHANNELS;
+pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1;
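Moving the channel flags ahead of the range-check column makes COLUMNS_TO_PAD a single contiguous prefix ending at RANGE_CHECK, while the counter and permutation helper columns stay at the end. A sketch of the resulting offset arithmetic; NUM_CHANNELS = 4 and the VIRTUAL_FIRST_CHANGE offset are illustrative placeholders, not the crate's actual values:

const NUM_CHANNELS: usize = 4; // placeholder
const VIRTUAL_FIRST_CHANGE: usize = 20; // placeholder offset

const IS_CHANNEL_START: usize = VIRTUAL_FIRST_CHANGE + 1;
const RANGE_CHECK: usize = IS_CHANNEL_START + NUM_CHANNELS;
const COUNTER: usize = RANGE_CHECK + 1;
const RANGE_CHECK_PERMUTED: usize = COUNTER + 1;
const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1;
const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1;

fn main() {
    // Channel flags sit between VIRTUAL_FIRST_CHANGE and RANGE_CHECK, so the
    // padded prefix 0..=RANGE_CHECK covers them along with the sorted columns.
    assert_eq!(RANGE_CHECK, VIRTUAL_FIRST_CHANGE + 1 + NUM_CHANNELS);
    assert_eq!(NUM_COLUMNS, RANGE_CHECK + 4);
}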

View File

@@ -11,17 +11,17 @@ use plonky2::timed;
use plonky2::util::timing::TimingTree;
use rand::Rng;
-use super::columns::is_channel;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column;
use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols};
use crate::memory::columns::{
-sorted_value_limb, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE,
-COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED,
-SEGMENT_FIRST_CHANGE, SORTED_ADDR_CONTEXT, SORTED_ADDR_SEGMENT, SORTED_ADDR_VIRTUAL,
-SORTED_IS_READ, SORTED_TIMESTAMP, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
+is_channel, sorted_value_limb, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL,
+COLUMNS_TO_PAD, CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS,
+RANGE_CHECK, RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, SORTED_ADDR_CONTEXT,
+SORTED_ADDR_SEGMENT, SORTED_ADDR_VIRTUAL, SORTED_IS_READ, SORTED_TIMESTAMP, TIMESTAMP,
+VIRTUAL_FIRST_CHANGE,
};
-use crate::memory::NUM_CHANNELS;
+use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
@@ -30,9 +30,10 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
pub(crate) const NUM_PUBLIC_INPUTS: usize = 0;
pub fn ctl_data<F: Field>() -> Vec<Column<F>> {
-let mut res = Column::singles([TIMESTAMP, IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL])
-.collect_vec();
+let mut res =
+Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec();
res.extend(Column::singles((0..8).map(value_limb)));
+res.push(Column::single(TIMESTAMP));
res
}
@@ -64,7 +65,7 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
let mut current_memory_values: HashMap<(usize, usize, usize), [F; 8]> = HashMap::new();
let num_cycles = num_ops / 2;
-for i in 0..num_cycles {
+for clock in 0..num_cycles {
let mut used_indices = HashSet::new();
let mut new_writes_this_cycle = HashMap::new();
let mut has_read = false;
@@ -75,7 +76,7 @@
}
used_indices.insert(channel_index);
-let is_read = if i == 0 {
+let is_read = if clock == 0 {
false
} else {
!has_read && rng.gen()
@@ -110,9 +111,10 @@
(context, segment, virt, vals)
};
+let timestamp = clock * NUM_CHANNELS + channel_index;
memory_ops.push(MemoryOp {
channel_index,
-timestamp: i,
+timestamp,
is_read,
context,
segment,
@@ -199,7 +201,7 @@
context_first_change: &[F],
segment_first_change: &[F],
virtual_first_change: &[F],
-) -> Vec<F> {
+) -> (Vec<F>, usize) {
let num_ops = context.len();
let mut range_check = Vec::new();
@@ -208,7 +210,6 @@
- context_first_change[idx]
- segment_first_change[idx]
- virtual_first_change[idx];
range_check.push(
context_first_change[idx] * (context[idx + 1] - context[idx] - F::ONE)
+ segment_first_change[idx] * (segment[idx + 1] - segment[idx] - F::ONE)
@@ -216,10 +217,11 @@
+ this_address_unchanged * (timestamp[idx + 1] - timestamp[idx] - F::ONE),
);
}
range_check.push(F::ZERO);
-range_check
+let max_diff = range_check.iter().map(F::to_canonical_u64).max().unwrap() as usize;
+(range_check, max_diff)
}
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
@@ -253,6 +255,9 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
self.generate_memory(&mut trace_cols);
+// The number of rows may have changed, if the range check required padding.
+let num_ops = trace_cols[0].len();
let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops];
for (i, col) in trace_cols.iter().enumerate() {
for (j, &val) in col.iter().enumerate() {
@@ -294,7 +299,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
let (context_first_change, segment_first_change, virtual_first_change) =
generate_first_change_flags(&sorted_context, &sorted_segment, &sorted_virtual);
-let range_check_value = generate_range_check_value(
+let (range_check_value, max_diff) = generate_range_check_value(
&sorted_context,
&sorted_segment,
&sorted_virtual,
@@ -303,6 +308,8 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
&segment_first_change,
&virtual_first_change,
);
+let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two();
+let to_pad = to_pad_to - num_trace_rows;
trace_cols[SORTED_TIMESTAMP] = sorted_timestamp;
trace_cols[SORTED_IS_READ] = sorted_is_read;
@@ -310,7 +317,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_cols[SORTED_ADDR_SEGMENT] = sorted_segment;
trace_cols[SORTED_ADDR_VIRTUAL] = sorted_virtual;
for i in 0..num_trace_rows {
-for j in 0..8 {
+for j in 0..VALUE_LIMBS {
trace_cols[sorted_value_limb(j)][i] = sorted_values[i][j];
}
}
@@ -320,9 +327,12 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change;
trace_cols[RANGE_CHECK] = range_check_value;
-trace_cols[COUNTER] = (0..num_trace_rows)
-.map(|i| F::from_canonical_usize(i))
-.collect();
+for col in COLUMNS_TO_PAD {
+trace_cols[col].splice(0..0, vec![F::ZERO; to_pad]);
+}
+trace_cols[COUNTER] = (0..to_pad_to).map(|i| F::from_canonical_usize(i)).collect();
let (permuted_inputs, permuted_table) =
permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]);
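The padding target follows from the permutation argument just above: RANGE_CHECK is permuted against COUNTER, which enumerates 0..trace_length, so every range-check value must be smaller than the trace length, and the length itself must be a power of two. A sketch of the computation as a hypothetical helper, not a function from the diff:

fn pad_target(max_diff: usize, num_trace_rows: usize) -> usize {
    // COUNTER runs over 0..len, so len must exceed max_diff; FRI needs a power of two.
    (max_diff + 1).max(num_trace_rows).next_power_of_two()
}

fn main() {
    assert_eq!(pad_target(5, 100), 128); // trace length dominates
    assert_eq!(pad_target(300, 100), 512); // range-check value dominates
}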
@@ -382,6 +392,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
.map(|i| vars.next_values[sorted_value_limb(i)])
.collect();
+// Indicator that this is a real row, not a row of padding.
+// TODO: enforce that all padding is at the beginning.
+let valid_row: P = (0..NUM_CHANNELS)
+.map(|c| vars.local_values[is_channel(c)])
+.sum();
let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE];
let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE];
let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE];
@@ -408,15 +424,21 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
.constraint_transition(virtual_first_change * (next_addr_context - addr_context));
yield_constr
.constraint_transition(virtual_first_change * (next_addr_segment - addr_segment));
-yield_constr.constraint_transition(address_unchanged * (next_addr_context - addr_context));
-yield_constr.constraint_transition(address_unchanged * (next_addr_segment - addr_segment));
-yield_constr.constraint_transition(address_unchanged * (next_addr_virtual - addr_virtual));
+yield_constr.constraint_transition(
+valid_row * address_unchanged * (next_addr_context - addr_context),
+);
+yield_constr.constraint_transition(
+valid_row * address_unchanged * (next_addr_segment - addr_segment),
+);
+yield_constr.constraint_transition(
+valid_row * address_unchanged * (next_addr_virtual - addr_virtual),
+);
// Third set of ordering constraints: range-check difference in the column that should be increasing.
let computed_range_check = context_first_change * (next_addr_context - addr_context - one)
+ segment_first_change * (next_addr_segment - addr_segment - one)
+ virtual_first_change * (next_addr_virtual - addr_virtual - one)
-+ address_unchanged * (next_timestamp - timestamp - one);
++ valid_row * address_unchanged * (next_timestamp - timestamp - one);
yield_constr.constraint_transition(range_check - computed_range_check);
// Enumerate purportedly-ordered log.
@@ -453,6 +475,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let next_is_read = vars.next_values[SORTED_IS_READ];
let next_timestamp = vars.next_values[SORTED_TIMESTAMP];
+// Indicator that this is a real row, not a row of padding.
+let mut valid_row = vars.local_values[is_channel(0)];
+for c in 1..NUM_CHANNELS {
+valid_row = builder.add_extension(valid_row, vars.local_values[is_channel(c)]);
+}
let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE];
let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE];
let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE];
@@ -497,11 +525,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.mul_extension(virtual_first_change, addr_segment_diff);
yield_constr.constraint_transition(builder, virtual_first_change_check_2);
let address_unchanged_check_1 = builder.mul_extension(address_unchanged, addr_context_diff);
-yield_constr.constraint_transition(builder, address_unchanged_check_1);
+let address_unchanged_check_1_valid =
+builder.mul_extension(valid_row, address_unchanged_check_1);
+yield_constr.constraint_transition(builder, address_unchanged_check_1_valid);
let address_unchanged_check_2 = builder.mul_extension(address_unchanged, addr_segment_diff);
-yield_constr.constraint_transition(builder, address_unchanged_check_2);
+let address_unchanged_check_2_valid =
+builder.mul_extension(valid_row, address_unchanged_check_2);
+yield_constr.constraint_transition(builder, address_unchanged_check_2_valid);
let address_unchanged_check_3 = builder.mul_extension(address_unchanged, addr_virtual_diff);
-yield_constr.constraint_transition(builder, address_unchanged_check_3);
+let address_unchanged_check_3_valid =
+builder.mul_extension(valid_row, address_unchanged_check_3);
+yield_constr.constraint_transition(builder, address_unchanged_check_3_valid);
// Third set of ordering constraints: range-check difference in the column that should be increasing.
let context_diff = {
@@ -524,6 +558,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.sub_extension(diff, one)
};
let timestamp_range_check = builder.mul_extension(address_unchanged, timestamp_diff);
+let timestamp_range_check = builder.mul_extension(valid_row, timestamp_range_check);
let computed_range_check = {
let mut sum = builder.add_extension(context_range_check, segment_range_check);
@@ -555,7 +590,23 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
}
fn permutation_pairs(&self) -> Vec<PermutationPair> {
+let mut unsorted_cols = vec![TIMESTAMP, IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL];
+unsorted_cols.extend((0..VALUE_LIMBS).map(value_limb));
+let mut sorted_cols = vec![
+SORTED_TIMESTAMP,
+SORTED_IS_READ,
+SORTED_ADDR_CONTEXT,
+SORTED_ADDR_SEGMENT,
+SORTED_ADDR_VIRTUAL,
+];
+sorted_cols.extend((0..VALUE_LIMBS).map(sorted_value_limb));
+let column_pairs: Vec<_> = unsorted_cols
+.into_iter()
+.zip(sorted_cols.iter().cloned())
+.collect();
vec![
+PermutationPair { column_pairs },
PermutationPair::singletons(RANGE_CHECK, RANGE_CHECK_PERMUTED),
PermutationPair::singletons(COUNTER, COUNTER_PERMUTED),
]

View File

@@ -36,31 +36,36 @@ impl<F: RichField + Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
}
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
-let wires_per_op = 5 + Self::num_limbs();
-let routed_wires_per_op = 5;
-(config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
+let wires_per_op = Self::routed_wires_per_op() + Self::num_limbs();
+(config.num_wires / wires_per_op).min(config.num_routed_wires / Self::routed_wires_per_op())
}
pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
-5 * i
+Self::routed_wires_per_op() * i
}
pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
-5 * i + 1
+Self::routed_wires_per_op() * i + 1
}
pub fn wire_ith_addend(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
-5 * i + 2
+Self::routed_wires_per_op() * i + 2
}
pub fn wire_ith_output_low_half(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
-5 * i + 3
+Self::routed_wires_per_op() * i + 3
}
pub fn wire_ith_output_high_half(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
-5 * i + 4
+Self::routed_wires_per_op() * i + 4
}
+pub fn wire_ith_inverse(&self, i: usize) -> usize {
+debug_assert!(i < self.num_ops);
+Self::routed_wires_per_op() * i + 5
+}
pub fn limb_bits() -> usize {
@@ -69,11 +74,13 @@ impl<F: RichField + Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
pub fn num_limbs() -> usize {
64 / Self::limb_bits()
}
+pub fn routed_wires_per_op() -> usize {
+6
+}
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
debug_assert!(i < self.num_ops);
debug_assert!(j < Self::num_limbs());
-5 * self.num_ops + Self::num_limbs() * i + j
+Self::routed_wires_per_op() * self.num_ops + Self::num_limbs() * i + j
}
}
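With the inverse wire added, each op now routes six wires (two multiplicands, the addend, both output halves, and the inverse), and the limb wires for all ops are packed after the routed block. A standalone sketch of the index arithmetic; NUM_LIMBS = 16 is an assumed value, since the gate actually derives it from limb_bits():

const ROUTED_WIRES_PER_OP: usize = 6; // m0, m1, addend, out_lo, out_hi, inverse
const NUM_LIMBS: usize = 16; // assumed: 64 bits / 4-bit limbs

fn wire_ith_inverse(i: usize) -> usize {
    ROUTED_WIRES_PER_OP * i + 5
}

fn wire_ith_output_jth_limb(num_ops: usize, i: usize, j: usize) -> usize {
    // All routed wires for every op come first, then each op's limb wires.
    ROUTED_WIRES_PER_OP * num_ops + NUM_LIMBS * i + j
}

fn main() {
    let num_ops = 3;
    assert_eq!(wire_ith_inverse(2), 17);
    assert_eq!(wire_ith_output_jth_limb(num_ops, 0, 0), 18);
    assert_eq!(wire_ith_output_jth_limb(num_ops, 1, 2), 36);
}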
@@ -93,9 +100,28 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticG
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
+let inverse = vars.local_wires[self.wire_ith_inverse(i)];
-let base = F::Extension::from_canonical_u64(1 << 32u64);
-let combined_output = output_high * base + output_low;
+// Check canonicity of combined_output = output_high * 2^32 + output_low
+let combined_output = {
+let base = F::Extension::from_canonical_u64(1 << 32u64);
+let one = F::Extension::ONE;
+let u32_max = F::Extension::from_canonical_u32(u32::MAX);
+// This is zero if and only if the high limb is `u32::MAX`.
+// u32::MAX - output_high
+let diff = u32_max - output_high;
+// If this is zero, the diff is invertible, so the high limb is not `u32::MAX`.
+// inverse * diff - 1
+let hi_not_max = inverse * diff - one;
+// If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero.
+// hi_not_max * output_low
+let hi_not_max_or_lo_zero = hi_not_max * output_low;
+constraints.push(hi_not_max_or_lo_zero);
+output_high * base + output_low
+};
constraints.push(combined_output - computed_output);
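The new constraint enforces canonicity of the witness: over the Goldilocks field, p = 2^64 - 2^32 + 1, so output_high * 2^32 + output_low overflows the field exactly when the high limb is u32::MAX and the low limb is nonzero. The inverse wire witnesses that the high limb differs from u32::MAX whenever the low limb is nonzero. A numeric sketch of the constraint, using u128 modular arithmetic in place of field types:

const P: u128 = 0xFFFF_FFFF_0000_0001; // Goldilocks modulus

fn constraint(output_high: u128, output_low: u128, inverse: u128) -> u128 {
    let diff = (P + u32::MAX as u128 - output_high) % P;
    let hi_not_max = (inverse * diff % P + P - 1) % P; // inverse * diff - 1
    hi_not_max * output_low % P
}

// Fermat inverse x^(p-2) mod p; slow, but fine for a sketch.
fn inv(x: u128) -> u128 {
    let (mut base, mut exp, mut acc) = (x % P, P - 2, 1u128);
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    acc
}

fn main() {
    // high != u32::MAX: the prover supplies the inverse of the difference.
    assert_eq!(constraint(7, 123, inv(u32::MAX as u128 - 7)), 0);
    // high == u32::MAX, low == 0: the value 2^64 - 2^32 is still canonical.
    assert_eq!(constraint(u32::MAX as u128, 0, 0), 0);
    // high == u32::MAX, low != 0: non-canonical, so the constraint is nonzero.
    assert_ne!(constraint(u32::MAX as u128, 1, 0), 0);
}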
@@ -152,10 +178,27 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticG
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
+let inverse = vars.local_wires[self.wire_ith_inverse(i)];
-let base: F::Extension = F::from_canonical_u64(1 << 32u64).into();
-let base_target = builder.constant_extension(base);
-let combined_output = builder.mul_add_extension(output_high, base_target, output_low);
+// Check canonicity of combined_output = output_high * 2^32 + output_low
+let combined_output = {
+let base: F::Extension = F::from_canonical_u64(1 << 32u64).into();
+let base_target = builder.constant_extension(base);
+let one = builder.one_extension();
+let u32_max =
+builder.constant_extension(F::Extension::from_canonical_u32(u32::MAX));
+// This is zero if and only if the high limb is `u32::MAX`.
+let diff = builder.sub_extension(u32_max, output_high);
+// If this is zero, the diff is invertible, so the high limb is not `u32::MAX`.
+let hi_not_max = builder.mul_sub_extension(inverse, diff, one);
+// If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero.
+let hi_not_max_or_lo_zero = builder.mul_extension(hi_not_max, output_low);
+constraints.push(hi_not_max_or_lo_zero);
+builder.mul_add_extension(output_high, base_target, output_low)
+};
constraints.push(builder.sub_extension(combined_output, computed_output));
@@ -211,7 +254,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticG
}
fn num_wires(&self) -> usize {
-self.num_ops * (5 + Self::num_limbs())
+self.num_ops * (Self::routed_wires_per_op() + Self::num_limbs())
}
fn num_constants(&self) -> usize {
@@ -223,7 +266,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticG
}
fn num_constraints(&self) -> usize {
-self.num_ops * (3 + Self::num_limbs())
+self.num_ops * (4 + Self::num_limbs())
}
}
@@ -244,9 +287,27 @@ impl<F: RichField + Extendable<D>, const D: usize> PackedEvaluableBase<F, D>
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
+let inverse = vars.local_wires[self.wire_ith_inverse(i)];
-let base = F::from_canonical_u64(1 << 32u64);
-let combined_output = output_high * base + output_low;
+let combined_output = {
+let base = P::from(F::from_canonical_u64(1 << 32u64));
+let one = P::ONES;
+let u32_max = P::from(F::from_canonical_u32(u32::MAX));
+// This is zero if and only if the high limb is `u32::MAX`.
+// u32::MAX - output_high
+let diff = u32_max - output_high;
+// If this is zero, the diff is invertible, so the high limb is not `u32::MAX`.
+// inverse * diff - 1
+let hi_not_max = inverse * diff - one;
+// If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero.
+// hi_not_max * output_low
+let hi_not_max_or_lo_zero = hi_not_max * output_low;
+yield_constr.one(hi_not_max_or_lo_zero);
+output_high * base + output_low
+};
yield_constr.one(combined_output - computed_output);
@@ -322,6 +383,15 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
out_buffer.set_wire(output_high_wire, output_high);
out_buffer.set_wire(output_low_wire, output_low);
+let diff = u32::MAX as u64 - output_high_u64;
+let inverse = if diff == 0 {
+F::ZERO
+} else {
+F::from_canonical_u64(diff).inverse()
+};
+let inverse_wire = local_wire(self.gate.wire_ith_inverse(self.i));
+out_buffer.set_wire(inverse_wire, inverse);
let num_limbs = U32ArithmeticGate::<F, D>::num_limbs();
let limb_base = 1 << U32ArithmeticGate::<F, D>::limb_bits();
let output_limbs_u64 = unfold((), move |_| {
@@ -347,8 +417,10 @@ mod tests {
use plonky2::gates::gate::Gate;
use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree};
use plonky2::hash::hash_types::HashOut;
+use plonky2::hash::hash_types::RichField;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::plonk::vars::EvaluationVars;
+use plonky2_field::extension::Extendable;
use plonky2_field::goldilocks_field::GoldilocksField;
use plonky2_field::types::Field;
use rand::Rng;
@@ -374,6 +446,59 @@ mod tests {
})
}
fn get_wires<
F: RichField + Extendable<D>,
FF: From<F>,
const D: usize,
const NUM_U32_ARITHMETIC_OPS: usize,
>(
multiplicands_0: Vec<u64>,
multiplicands_1: Vec<u64>,
addends: Vec<u64>,
) -> Vec<FF> {
let mut v0 = Vec::new();
let mut v1 = Vec::new();
let limb_bits = U32ArithmeticGate::<F, D>::limb_bits();
let num_limbs = U32ArithmeticGate::<F, D>::num_limbs();
let limb_base = 1 << limb_bits;
for c in 0..NUM_U32_ARITHMETIC_OPS {
let m0 = multiplicands_0[c];
let m1 = multiplicands_1[c];
let a = addends[c];
let mut output = m0 * m1 + a;
let output_low = output & ((1 << 32) - 1);
let output_high = output >> 32;
let diff = u32::MAX as u64 - output_high;
let inverse = if diff == 0 {
F::ZERO
} else {
F::from_canonical_u64(diff).inverse()
};
let mut output_limbs = Vec::with_capacity(num_limbs);
for _i in 0..num_limbs {
output_limbs.push(output % limb_base);
output /= limb_base;
}
let mut output_limbs_f: Vec<_> = output_limbs
.into_iter()
.map(F::from_canonical_u64)
.collect();
v0.push(F::from_canonical_u64(m0));
v0.push(F::from_canonical_u64(m1));
v0.push(F::from_noncanonical_u64(a));
v0.push(F::from_canonical_u64(output_low));
v0.push(F::from_canonical_u64(output_high));
v0.push(inverse);
v1.append(&mut output_limbs_f);
}
v0.iter().chain(v1.iter()).map(|&x| x.into()).collect()
}
#[test]
fn test_gate_constraint() {
const D: usize = 2;
@@ -382,47 +507,6 @@
type FF = <C as GenericConfig<D>>::FE;
const NUM_U32_ARITHMETIC_OPS: usize = 3;
-fn get_wires(
-multiplicands_0: Vec<u64>,
-multiplicands_1: Vec<u64>,
-addends: Vec<u64>,
-) -> Vec<FF> {
-let mut v0 = Vec::new();
-let mut v1 = Vec::new();
-let limb_bits = U32ArithmeticGate::<F, D>::limb_bits();
-let num_limbs = U32ArithmeticGate::<F, D>::num_limbs();
-let limb_base = 1 << limb_bits;
-for c in 0..NUM_U32_ARITHMETIC_OPS {
-let m0 = multiplicands_0[c];
-let m1 = multiplicands_1[c];
-let a = addends[c];
-let mut output = m0 * m1 + a;
-let output_low = output & ((1 << 32) - 1);
-let output_high = output >> 32;
-let mut output_limbs = Vec::with_capacity(num_limbs);
-for _i in 0..num_limbs {
-output_limbs.push(output % limb_base);
-output /= limb_base;
-}
-let mut output_limbs_f: Vec<_> = output_limbs
-.into_iter()
-.map(F::from_canonical_u64)
-.collect();
-v0.push(F::from_canonical_u64(m0));
-v0.push(F::from_canonical_u64(m1));
-v0.push(F::from_canonical_u64(a));
-v0.push(F::from_canonical_u64(output_low));
-v0.push(F::from_canonical_u64(output_high));
-v1.append(&mut output_limbs_f);
-}
-v0.iter().chain(v1.iter()).map(|&x| x.into()).collect()
-}
let mut rng = rand::thread_rng();
let multiplicands_0: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS)
.map(|_| rng.gen::<u32>() as u64)
@@ -441,7 +525,11 @@
let vars = EvaluationVars {
local_constants: &[],
-local_wires: &get_wires(multiplicands_0, multiplicands_1, addends),
+local_wires: &get_wires::<F, FF, D, NUM_U32_ARITHMETIC_OPS>(
+multiplicands_0,
+multiplicands_1,
+addends,
+),
public_inputs_hash: &HashOut::rand(),
};
@@ -450,4 +538,39 @@
"Gate constraints are not satisfied."
);
}
#[test]
fn test_canonicity() {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type FF = <C as GenericConfig<D>>::FE;
const NUM_U32_ARITHMETIC_OPS: usize = 3;
let multiplicands_0 = vec![0; NUM_U32_ARITHMETIC_OPS];
let multiplicands_1 = vec![0; NUM_U32_ARITHMETIC_OPS];
// A non-canonical addend will produce a non-canonical output using
// get_wires.
let addends = vec![0xFFFFFFFF00000001; NUM_U32_ARITHMETIC_OPS];
let gate = U32ArithmeticGate::<F, D> {
num_ops: NUM_U32_ARITHMETIC_OPS,
_phantom: PhantomData,
};
let vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires::<F, FF, D, NUM_U32_ARITHMETIC_OPS>(
multiplicands_0,
multiplicands_1,
addends,
),
public_inputs_hash: &HashOut::rand(),
};
assert!(
!gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
"Non-canonical output should not pass constraints."
);
}
}