Merge pull request #605 from mir-protocol/memory_misc

More realistic padding rows in memory table
This commit is contained in:
Daniel Lubarov 2022-07-13 10:55:04 -07:00 committed by GitHub
commit d36eda20e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 61 deletions

View File

@ -62,7 +62,7 @@ impl<F: Field> GenerationState<F> {
let context = self.current_context;
let value = self.memory.contexts[context].segments[segment].get(virt);
self.memory.log.push(MemoryOp {
channel_index,
channel_index: Some(channel_index),
timestamp,
is_read: true,
context,
@ -84,7 +84,7 @@ impl<F: Field> GenerationState<F> {
let timestamp = self.cpu_rows.len();
let context = self.current_context;
self.memory.log.push(MemoryOp {
channel_index,
channel_index: Some(channel_index),
timestamp,
is_read: false,
context,

View File

@ -1,7 +1,5 @@
//! Memory registers.
use std::ops::Range;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
// Columns for memory operations, ordered by (addr, timestamp).
@ -41,7 +39,4 @@ pub(crate) const COUNTER: usize = RANGE_CHECK + 1;
pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1;
pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1;
// Columns to be padded at the top with zeroes, before the permutation argument takes place.
pub(crate) const COLUMNS_TO_PAD: Range<usize> = TIMESTAMP..RANGE_CHECK + 1;
pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1;

View File

@ -15,11 +15,11 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cross_table_lookup::Column;
use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols};
use crate::memory::columns::{
is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, COLUMNS_TO_PAD,
CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK,
RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE,
COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED,
SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
};
use crate::memory::NUM_CHANNELS;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
@ -44,9 +44,10 @@ pub struct MemoryStark<F, const D: usize> {
pub(crate) f: PhantomData<F>,
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct MemoryOp<F> {
pub channel_index: usize,
/// The channel this operation came from, or `None` if it's a dummy operation for padding.
pub channel_index: Option<usize>,
pub timestamp: usize,
pub is_read: bool,
pub context: usize,
@ -111,7 +112,7 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
let timestamp = clock * NUM_CHANNELS + channel_index;
memory_ops.push(MemoryOp {
channel_index,
channel_index: Some(channel_index),
timestamp,
is_read,
context,
@ -128,6 +129,25 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
memory_ops
}
fn get_max_range_check<F: Field>(memory_ops: &[MemoryOp<F>]) -> usize {
memory_ops
.iter()
.tuple_windows()
.map(|(curr, next)| {
if curr.context != next.context {
next.context - curr.context - 1
} else if curr.segment != next.segment {
next.segment - curr.segment - 1
} else if curr.virt != next.virt {
next.virt - curr.virt - 1
} else {
next.timestamp - curr.timestamp - 1
}
})
.max()
.unwrap_or(0)
}
pub fn generate_first_change_flags<F: RichField>(
context: &[F],
segment: &[F],
@ -169,7 +189,7 @@ pub fn generate_range_check_value<F: RichField>(
context_first_change: &[F],
segment_first_change: &[F],
virtual_first_change: &[F],
) -> (Vec<F>, usize) {
) -> Vec<F> {
let num_ops = context.len();
let mut range_check = Vec::new();
@ -187,9 +207,7 @@ pub fn generate_range_check_value<F: RichField>(
}
range_check.push(F::ZERO);
let max_diff = range_check.iter().map(F::to_canonical_u64).max().unwrap() as usize;
(range_check, max_diff)
range_check
}
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
@ -198,7 +216,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
mut memory_ops: Vec<MemoryOp<F>>,
) -> Vec<[F; NUM_COLUMNS]> {
memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp));
Self::pad_memory_ops(&mut memory_ops);
let num_ops = memory_ops.len();
let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]);
@ -212,22 +230,21 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
virt,
value,
} = memory_ops[i];
trace_cols[is_channel(channel_index)][i] = F::ONE;
if let Some(channel) = channel_index {
trace_cols[is_channel(channel)][i] = F::ONE;
}
trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp);
trace_cols[IS_READ][i] = F::from_bool(is_read);
trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context);
trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment);
trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt);
for j in 0..8 {
for j in 0..VALUE_LIMBS {
trace_cols[value_limb(j)][i] = value[j];
}
}
self.generate_memory(&mut trace_cols);
// The number of rows may have changed, if the range check required padding.
let num_ops = trace_cols[0].len();
let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops];
for (i, col) in trace_cols.iter().enumerate() {
for (j, &val) in col.iter().enumerate() {
@ -237,6 +254,29 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_rows
}
fn pad_memory_ops(memory_ops: &mut Vec<MemoryOp<F>>) {
let num_ops = memory_ops.len();
let max_range_check = get_max_range_check(memory_ops);
let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two();
let to_pad = num_ops_padded - num_ops;
let last_op = memory_ops.last().expect("No memory ops?").clone();
// We essentially repeat the last operation until our operation list has the desired size,
// with a few changes:
// - We change its channel to `None` to indicate that this is a dummy operation.
// - We increment its timestamp in order to pass the ordering check.
// - We make sure it's a read, sine dummy operations must be reads.
for i in 0..to_pad {
memory_ops.push(MemoryOp {
channel_index: None,
timestamp: last_op.timestamp + i + 1,
is_read: true,
..last_op
});
}
}
fn generate_memory(&self, trace_cols: &mut [Vec<F>]) {
let num_trace_rows = trace_cols[0].len();
@ -248,7 +288,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
let (context_first_change, segment_first_change, virtual_first_change) =
generate_first_change_flags(context, segment, virtuals);
let (range_check_value, max_diff) = generate_range_check_value(
trace_cols[RANGE_CHECK] = generate_range_check_value(
context,
segment,
virtuals,
@ -257,20 +297,14 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
&segment_first_change,
&virtual_first_change,
);
let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two();
let to_pad = to_pad_to - num_trace_rows;
trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change;
trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change;
trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change;
trace_cols[RANGE_CHECK] = range_check_value;
for col in COLUMNS_TO_PAD {
trace_cols[col].splice(0..0, vec![F::ZERO; to_pad]);
}
trace_cols[COUNTER] = (0..to_pad_to).map(|i| F::from_canonical_usize(i)).collect();
trace_cols[COUNTER] = (0..num_trace_rows)
.map(|i| F::from_canonical_usize(i))
.collect();
let (permuted_inputs, permuted_table) =
permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]);
@ -326,11 +360,23 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let next_addr_virtual = vars.next_values[ADDR_VIRTUAL];
let next_values: Vec<_> = (0..8).map(|i| vars.next_values[value_limb(i)]).collect();
// Indicator that this is a real row, not a row of padding.
// TODO: enforce that all padding is at the beginning.
let valid_row: P = (0..NUM_CHANNELS)
// Each `is_channel` value must be 0 or 1.
for c in 0..NUM_CHANNELS {
let is_channel = vars.local_values[is_channel(c)];
yield_constr.constraint(is_channel * (is_channel - P::ONES));
}
// The sum of `is_channel` flags, `has_channel`, must also be 0 or 1.
let has_channel: P = (0..NUM_CHANNELS)
.map(|c| vars.local_values[is_channel(c)])
.sum();
yield_constr.constraint(has_channel * (has_channel - P::ONES));
// If this is a dummy row (with no channel), it must be a read. This means the prover can
// insert reads which never appear in the CPU trace (which are harmless), but not writes.
let is_dummy = P::ONES - has_channel;
let is_write = P::ONES - vars.local_values[IS_READ];
yield_constr.constraint(is_dummy * is_write);
let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE];
let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE];
@ -358,21 +404,15 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
.constraint_transition(virtual_first_change * (next_addr_context - addr_context));
yield_constr
.constraint_transition(virtual_first_change * (next_addr_segment - addr_segment));
yield_constr.constraint_transition(
valid_row * address_unchanged * (next_addr_context - addr_context),
);
yield_constr.constraint_transition(
valid_row * address_unchanged * (next_addr_segment - addr_segment),
);
yield_constr.constraint_transition(
valid_row * address_unchanged * (next_addr_virtual - addr_virtual),
);
yield_constr.constraint_transition(address_unchanged * (next_addr_context - addr_context));
yield_constr.constraint_transition(address_unchanged * (next_addr_segment - addr_segment));
yield_constr.constraint_transition(address_unchanged * (next_addr_virtual - addr_virtual));
// Third set of ordering constraints: range-check difference in the column that should be increasing.
let computed_range_check = context_first_change * (next_addr_context - addr_context - one)
+ segment_first_change * (next_addr_segment - addr_segment - one)
+ virtual_first_change * (next_addr_virtual - addr_virtual - one)
+ valid_row * address_unchanged * (next_timestamp - timestamp - one);
+ address_unchanged * (next_timestamp - timestamp - one);
yield_constr.constraint_transition(range_check - computed_range_check);
// Enumerate purportedly-ordered log.
@ -405,12 +445,26 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let next_is_read = vars.next_values[IS_READ];
let next_timestamp = vars.next_values[TIMESTAMP];
// Indicator that this is a real row, not a row of padding.
let mut valid_row = vars.local_values[is_channel(0)];
for c in 1..NUM_CHANNELS {
valid_row = builder.add_extension(valid_row, vars.local_values[is_channel(c)]);
// Each `is_channel` value must be 0 or 1.
for c in 0..NUM_CHANNELS {
let is_channel = vars.local_values[is_channel(c)];
let constraint = builder.mul_sub_extension(is_channel, is_channel, is_channel);
yield_constr.constraint(builder, constraint);
}
// The sum of `is_channel` flags, `has_channel`, must also be 0 or 1.
let has_channel =
builder.add_many_extension((0..NUM_CHANNELS).map(|c| vars.local_values[is_channel(c)]));
let has_channel_bool = builder.mul_sub_extension(has_channel, has_channel, has_channel);
yield_constr.constraint(builder, has_channel_bool);
// If this is a dummy row (with no channel), it must be a read. This means the prover can
// insert reads which never appear in the CPU trace (which are harmless), but not writes.
let is_dummy = builder.sub_extension(one, has_channel);
let is_write = builder.sub_extension(one, vars.local_values[IS_READ]);
let is_dummy_write = builder.mul_extension(is_dummy, is_write);
yield_constr.constraint(builder, is_dummy_write);
let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE];
let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE];
let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE];
@ -455,17 +509,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.mul_extension(virtual_first_change, addr_segment_diff);
yield_constr.constraint_transition(builder, virtual_first_change_check_2);
let address_unchanged_check_1 = builder.mul_extension(address_unchanged, addr_context_diff);
let address_unchanged_check_1_valid =
builder.mul_extension(valid_row, address_unchanged_check_1);
yield_constr.constraint_transition(builder, address_unchanged_check_1_valid);
yield_constr.constraint_transition(builder, address_unchanged_check_1);
let address_unchanged_check_2 = builder.mul_extension(address_unchanged, addr_segment_diff);
let address_unchanged_check_2_valid =
builder.mul_extension(valid_row, address_unchanged_check_2);
yield_constr.constraint_transition(builder, address_unchanged_check_2_valid);
yield_constr.constraint_transition(builder, address_unchanged_check_2);
let address_unchanged_check_3 = builder.mul_extension(address_unchanged, addr_virtual_diff);
let address_unchanged_check_3_valid =
builder.mul_extension(valid_row, address_unchanged_check_3);
yield_constr.constraint_transition(builder, address_unchanged_check_3_valid);
yield_constr.constraint_transition(builder, address_unchanged_check_3);
// Third set of ordering constraints: range-check difference in the column that should be increasing.
let context_diff = {
@ -488,7 +536,6 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.sub_extension(diff, one)
};
let timestamp_range_check = builder.mul_extension(address_unchanged, timestamp_diff);
let timestamp_range_check = builder.mul_extension(valid_row, timestamp_range_check);
let computed_range_check = {
let mut sum = builder.add_extension(context_range_check, segment_range_check);