Merge pull request #605 from mir-protocol/memory_misc

More realistic padding rows in memory table
This commit is contained in:
Daniel Lubarov 2022-07-13 10:55:04 -07:00 committed by GitHub
commit d36eda20e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 61 deletions

View File

@ -62,7 +62,7 @@ impl<F: Field> GenerationState<F> {
let context = self.current_context; let context = self.current_context;
let value = self.memory.contexts[context].segments[segment].get(virt); let value = self.memory.contexts[context].segments[segment].get(virt);
self.memory.log.push(MemoryOp { self.memory.log.push(MemoryOp {
channel_index, channel_index: Some(channel_index),
timestamp, timestamp,
is_read: true, is_read: true,
context, context,
@ -84,7 +84,7 @@ impl<F: Field> GenerationState<F> {
let timestamp = self.cpu_rows.len(); let timestamp = self.cpu_rows.len();
let context = self.current_context; let context = self.current_context;
self.memory.log.push(MemoryOp { self.memory.log.push(MemoryOp {
channel_index, channel_index: Some(channel_index),
timestamp, timestamp,
is_read: false, is_read: false,
context, context,

View File

@ -1,7 +1,5 @@
//! Memory registers. //! Memory registers.
use std::ops::Range;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
// Columns for memory operations, ordered by (addr, timestamp). // Columns for memory operations, ordered by (addr, timestamp).
@ -41,7 +39,4 @@ pub(crate) const COUNTER: usize = RANGE_CHECK + 1;
pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1; pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1;
pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1; pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1;
// Columns to be padded at the top with zeroes, before the permutation argument takes place.
pub(crate) const COLUMNS_TO_PAD: Range<usize> = TIMESTAMP..RANGE_CHECK + 1;
pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1; pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1;

View File

@ -15,11 +15,11 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cross_table_lookup::Column; use crate::cross_table_lookup::Column;
use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols};
use crate::memory::columns::{ use crate::memory::columns::{
is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, COLUMNS_TO_PAD, is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE,
CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED,
RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
}; };
use crate::memory::NUM_CHANNELS; use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::permutation::PermutationPair; use crate::permutation::PermutationPair;
use crate::stark::Stark; use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values; use crate::util::trace_rows_to_poly_values;
@ -44,9 +44,10 @@ pub struct MemoryStark<F, const D: usize> {
pub(crate) f: PhantomData<F>, pub(crate) f: PhantomData<F>,
} }
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct MemoryOp<F> { pub struct MemoryOp<F> {
pub channel_index: usize, /// The channel this operation came from, or `None` if it's a dummy operation for padding.
pub channel_index: Option<usize>,
pub timestamp: usize, pub timestamp: usize,
pub is_read: bool, pub is_read: bool,
pub context: usize, pub context: usize,
@ -111,7 +112,7 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
let timestamp = clock * NUM_CHANNELS + channel_index; let timestamp = clock * NUM_CHANNELS + channel_index;
memory_ops.push(MemoryOp { memory_ops.push(MemoryOp {
channel_index, channel_index: Some(channel_index),
timestamp, timestamp,
is_read, is_read,
context, context,
@ -128,6 +129,25 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
memory_ops memory_ops
} }
fn get_max_range_check<F: Field>(memory_ops: &[MemoryOp<F>]) -> usize {
memory_ops
.iter()
.tuple_windows()
.map(|(curr, next)| {
if curr.context != next.context {
next.context - curr.context - 1
} else if curr.segment != next.segment {
next.segment - curr.segment - 1
} else if curr.virt != next.virt {
next.virt - curr.virt - 1
} else {
next.timestamp - curr.timestamp - 1
}
})
.max()
.unwrap_or(0)
}
pub fn generate_first_change_flags<F: RichField>( pub fn generate_first_change_flags<F: RichField>(
context: &[F], context: &[F],
segment: &[F], segment: &[F],
@ -169,7 +189,7 @@ pub fn generate_range_check_value<F: RichField>(
context_first_change: &[F], context_first_change: &[F],
segment_first_change: &[F], segment_first_change: &[F],
virtual_first_change: &[F], virtual_first_change: &[F],
) -> (Vec<F>, usize) { ) -> Vec<F> {
let num_ops = context.len(); let num_ops = context.len();
let mut range_check = Vec::new(); let mut range_check = Vec::new();
@ -187,9 +207,7 @@ pub fn generate_range_check_value<F: RichField>(
} }
range_check.push(F::ZERO); range_check.push(F::ZERO);
let max_diff = range_check.iter().map(F::to_canonical_u64).max().unwrap() as usize; range_check
(range_check, max_diff)
} }
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> { impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
@ -198,7 +216,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
mut memory_ops: Vec<MemoryOp<F>>, mut memory_ops: Vec<MemoryOp<F>>,
) -> Vec<[F; NUM_COLUMNS]> { ) -> Vec<[F; NUM_COLUMNS]> {
memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp)); memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp));
Self::pad_memory_ops(&mut memory_ops);
let num_ops = memory_ops.len(); let num_ops = memory_ops.len();
let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]); let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]);
@ -212,22 +230,21 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
virt, virt,
value, value,
} = memory_ops[i]; } = memory_ops[i];
trace_cols[is_channel(channel_index)][i] = F::ONE; if let Some(channel) = channel_index {
trace_cols[is_channel(channel)][i] = F::ONE;
}
trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp); trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp);
trace_cols[IS_READ][i] = F::from_bool(is_read); trace_cols[IS_READ][i] = F::from_bool(is_read);
trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context); trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context);
trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment); trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment);
trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt); trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt);
for j in 0..8 { for j in 0..VALUE_LIMBS {
trace_cols[value_limb(j)][i] = value[j]; trace_cols[value_limb(j)][i] = value[j];
} }
} }
self.generate_memory(&mut trace_cols); self.generate_memory(&mut trace_cols);
// The number of rows may have changed, if the range check required padding.
let num_ops = trace_cols[0].len();
let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops]; let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops];
for (i, col) in trace_cols.iter().enumerate() { for (i, col) in trace_cols.iter().enumerate() {
for (j, &val) in col.iter().enumerate() { for (j, &val) in col.iter().enumerate() {
@ -237,6 +254,29 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_rows trace_rows
} }
fn pad_memory_ops(memory_ops: &mut Vec<MemoryOp<F>>) {
let num_ops = memory_ops.len();
let max_range_check = get_max_range_check(memory_ops);
let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two();
let to_pad = num_ops_padded - num_ops;
let last_op = memory_ops.last().expect("No memory ops?").clone();
// We essentially repeat the last operation until our operation list has the desired size,
// with a few changes:
// - We change its channel to `None` to indicate that this is a dummy operation.
// - We increment its timestamp in order to pass the ordering check.
// - We make sure it's a read, sine dummy operations must be reads.
for i in 0..to_pad {
memory_ops.push(MemoryOp {
channel_index: None,
timestamp: last_op.timestamp + i + 1,
is_read: true,
..last_op
});
}
}
fn generate_memory(&self, trace_cols: &mut [Vec<F>]) { fn generate_memory(&self, trace_cols: &mut [Vec<F>]) {
let num_trace_rows = trace_cols[0].len(); let num_trace_rows = trace_cols[0].len();
@ -248,7 +288,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
let (context_first_change, segment_first_change, virtual_first_change) = let (context_first_change, segment_first_change, virtual_first_change) =
generate_first_change_flags(context, segment, virtuals); generate_first_change_flags(context, segment, virtuals);
let (range_check_value, max_diff) = generate_range_check_value( trace_cols[RANGE_CHECK] = generate_range_check_value(
context, context,
segment, segment,
virtuals, virtuals,
@ -257,20 +297,14 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
&segment_first_change, &segment_first_change,
&virtual_first_change, &virtual_first_change,
); );
let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two();
let to_pad = to_pad_to - num_trace_rows;
trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change; trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change;
trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change; trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change;
trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change; trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change;
trace_cols[RANGE_CHECK] = range_check_value; trace_cols[COUNTER] = (0..num_trace_rows)
.map(|i| F::from_canonical_usize(i))
for col in COLUMNS_TO_PAD { .collect();
trace_cols[col].splice(0..0, vec![F::ZERO; to_pad]);
}
trace_cols[COUNTER] = (0..to_pad_to).map(|i| F::from_canonical_usize(i)).collect();
let (permuted_inputs, permuted_table) = let (permuted_inputs, permuted_table) =
permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]); permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]);
@ -326,11 +360,23 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let next_addr_virtual = vars.next_values[ADDR_VIRTUAL]; let next_addr_virtual = vars.next_values[ADDR_VIRTUAL];
let next_values: Vec<_> = (0..8).map(|i| vars.next_values[value_limb(i)]).collect(); let next_values: Vec<_> = (0..8).map(|i| vars.next_values[value_limb(i)]).collect();
// Indicator that this is a real row, not a row of padding. // Each `is_channel` value must be 0 or 1.
// TODO: enforce that all padding is at the beginning. for c in 0..NUM_CHANNELS {
let valid_row: P = (0..NUM_CHANNELS) let is_channel = vars.local_values[is_channel(c)];
yield_constr.constraint(is_channel * (is_channel - P::ONES));
}
// The sum of `is_channel` flags, `has_channel`, must also be 0 or 1.
let has_channel: P = (0..NUM_CHANNELS)
.map(|c| vars.local_values[is_channel(c)]) .map(|c| vars.local_values[is_channel(c)])
.sum(); .sum();
yield_constr.constraint(has_channel * (has_channel - P::ONES));
// If this is a dummy row (with no channel), it must be a read. This means the prover can
// insert reads which never appear in the CPU trace (which are harmless), but not writes.
let is_dummy = P::ONES - has_channel;
let is_write = P::ONES - vars.local_values[IS_READ];
yield_constr.constraint(is_dummy * is_write);
let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE]; let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE];
let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE]; let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE];
@ -358,21 +404,15 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
.constraint_transition(virtual_first_change * (next_addr_context - addr_context)); .constraint_transition(virtual_first_change * (next_addr_context - addr_context));
yield_constr yield_constr
.constraint_transition(virtual_first_change * (next_addr_segment - addr_segment)); .constraint_transition(virtual_first_change * (next_addr_segment - addr_segment));
yield_constr.constraint_transition( yield_constr.constraint_transition(address_unchanged * (next_addr_context - addr_context));
valid_row * address_unchanged * (next_addr_context - addr_context), yield_constr.constraint_transition(address_unchanged * (next_addr_segment - addr_segment));
); yield_constr.constraint_transition(address_unchanged * (next_addr_virtual - addr_virtual));
yield_constr.constraint_transition(
valid_row * address_unchanged * (next_addr_segment - addr_segment),
);
yield_constr.constraint_transition(
valid_row * address_unchanged * (next_addr_virtual - addr_virtual),
);
// Third set of ordering constraints: range-check difference in the column that should be increasing. // Third set of ordering constraints: range-check difference in the column that should be increasing.
let computed_range_check = context_first_change * (next_addr_context - addr_context - one) let computed_range_check = context_first_change * (next_addr_context - addr_context - one)
+ segment_first_change * (next_addr_segment - addr_segment - one) + segment_first_change * (next_addr_segment - addr_segment - one)
+ virtual_first_change * (next_addr_virtual - addr_virtual - one) + virtual_first_change * (next_addr_virtual - addr_virtual - one)
+ valid_row * address_unchanged * (next_timestamp - timestamp - one); + address_unchanged * (next_timestamp - timestamp - one);
yield_constr.constraint_transition(range_check - computed_range_check); yield_constr.constraint_transition(range_check - computed_range_check);
// Enumerate purportedly-ordered log. // Enumerate purportedly-ordered log.
@ -405,12 +445,26 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let next_is_read = vars.next_values[IS_READ]; let next_is_read = vars.next_values[IS_READ];
let next_timestamp = vars.next_values[TIMESTAMP]; let next_timestamp = vars.next_values[TIMESTAMP];
// Indicator that this is a real row, not a row of padding. // Each `is_channel` value must be 0 or 1.
let mut valid_row = vars.local_values[is_channel(0)]; for c in 0..NUM_CHANNELS {
for c in 1..NUM_CHANNELS { let is_channel = vars.local_values[is_channel(c)];
valid_row = builder.add_extension(valid_row, vars.local_values[is_channel(c)]); let constraint = builder.mul_sub_extension(is_channel, is_channel, is_channel);
yield_constr.constraint(builder, constraint);
} }
// The sum of `is_channel` flags, `has_channel`, must also be 0 or 1.
let has_channel =
builder.add_many_extension((0..NUM_CHANNELS).map(|c| vars.local_values[is_channel(c)]));
let has_channel_bool = builder.mul_sub_extension(has_channel, has_channel, has_channel);
yield_constr.constraint(builder, has_channel_bool);
// If this is a dummy row (with no channel), it must be a read. This means the prover can
// insert reads which never appear in the CPU trace (which are harmless), but not writes.
let is_dummy = builder.sub_extension(one, has_channel);
let is_write = builder.sub_extension(one, vars.local_values[IS_READ]);
let is_dummy_write = builder.mul_extension(is_dummy, is_write);
yield_constr.constraint(builder, is_dummy_write);
let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE]; let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE];
let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE]; let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE];
let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE]; let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE];
@ -455,17 +509,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.mul_extension(virtual_first_change, addr_segment_diff); builder.mul_extension(virtual_first_change, addr_segment_diff);
yield_constr.constraint_transition(builder, virtual_first_change_check_2); yield_constr.constraint_transition(builder, virtual_first_change_check_2);
let address_unchanged_check_1 = builder.mul_extension(address_unchanged, addr_context_diff); let address_unchanged_check_1 = builder.mul_extension(address_unchanged, addr_context_diff);
let address_unchanged_check_1_valid = yield_constr.constraint_transition(builder, address_unchanged_check_1);
builder.mul_extension(valid_row, address_unchanged_check_1);
yield_constr.constraint_transition(builder, address_unchanged_check_1_valid);
let address_unchanged_check_2 = builder.mul_extension(address_unchanged, addr_segment_diff); let address_unchanged_check_2 = builder.mul_extension(address_unchanged, addr_segment_diff);
let address_unchanged_check_2_valid = yield_constr.constraint_transition(builder, address_unchanged_check_2);
builder.mul_extension(valid_row, address_unchanged_check_2);
yield_constr.constraint_transition(builder, address_unchanged_check_2_valid);
let address_unchanged_check_3 = builder.mul_extension(address_unchanged, addr_virtual_diff); let address_unchanged_check_3 = builder.mul_extension(address_unchanged, addr_virtual_diff);
let address_unchanged_check_3_valid = yield_constr.constraint_transition(builder, address_unchanged_check_3);
builder.mul_extension(valid_row, address_unchanged_check_3);
yield_constr.constraint_transition(builder, address_unchanged_check_3_valid);
// Third set of ordering constraints: range-check difference in the column that should be increasing. // Third set of ordering constraints: range-check difference in the column that should be increasing.
let context_diff = { let context_diff = {
@ -488,7 +536,6 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.sub_extension(diff, one) builder.sub_extension(diff, one)
}; };
let timestamp_range_check = builder.mul_extension(address_unchanged, timestamp_diff); let timestamp_range_check = builder.mul_extension(address_unchanged, timestamp_diff);
let timestamp_range_check = builder.mul_extension(valid_row, timestamp_range_check);
let computed_range_check = { let computed_range_check = {
let mut sum = builder.add_extension(context_range_check, segment_range_check); let mut sum = builder.add_extension(context_range_check, segment_range_check);