Merge pull request #609 from mir-protocol/row_wise_memory_gen

Generate most of the memory table while it's in row-wise form
This commit is contained in:
Daniel Lubarov 2022-07-13 17:09:56 -07:00 committed by GitHub
commit 8751aaec7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -9,7 +9,9 @@ use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2::util::transpose;
use rand::Rng;
use rayon::prelude::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column;
@ -22,7 +24,6 @@ use crate::memory::columns::{
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
pub(crate) const NUM_PUBLIC_INPUTS: usize = 0;
@ -56,6 +57,28 @@ pub struct MemoryOp<F> {
pub value: [F; 8],
}
impl<F: Field> MemoryOp<F> {
/// Generate a row for a given memory operation. Note that this does not generate columns which
/// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later.
/// It also does not generate columns such as `COUNTER`, which are generated later, after the
/// trace has been transposed into column-major form.
fn to_row(&self) -> [F; NUM_COLUMNS] {
let mut row = [F::ZERO; NUM_COLUMNS];
if let Some(channel) = self.channel_index {
row[is_channel(channel)] = F::ONE;
}
row[TIMESTAMP] = F::from_canonical_usize(self.timestamp);
row[IS_READ] = F::from_bool(self.is_read);
row[ADDR_CONTEXT] = F::from_canonical_usize(self.context);
row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment);
row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt);
for j in 0..VALUE_LIMBS {
row[value_limb(j)] = self.value[j];
}
row
}
}
pub fn generate_random_memory_ops<F: RichField, R: Rng>(
num_ops: usize,
rng: &mut R,
@ -148,112 +171,76 @@ fn get_max_range_check<F: Field>(memory_ops: &[MemoryOp<F>]) -> usize {
.unwrap_or(0)
}
pub fn generate_first_change_flags<F: RichField>(
context: &[F],
segment: &[F],
virtuals: &[F],
) -> (Vec<F>, Vec<F>, Vec<F>) {
let num_ops = context.len();
let mut context_first_change = Vec::with_capacity(num_ops);
let mut segment_first_change = Vec::with_capacity(num_ops);
let mut virtual_first_change = Vec::with_capacity(num_ops);
/// Generates the `_FIRST_CHANGE` columns and the `RANGE_CHECK` column in the trace.
pub fn generate_first_change_flags_and_rc<F: RichField>(trace_rows: &mut [[F; NUM_COLUMNS]]) {
let num_ops = trace_rows.len();
for idx in 0..num_ops - 1 {
let this_context_first_change = context[idx] != context[idx + 1];
let this_segment_first_change =
segment[idx] != segment[idx + 1] && !this_context_first_change;
let this_virtual_first_change = virtuals[idx] != virtuals[idx + 1]
&& !this_segment_first_change
&& !this_context_first_change;
let row = trace_rows[idx].as_slice();
let next_row = trace_rows[idx + 1].as_slice();
context_first_change.push(F::from_bool(this_context_first_change));
segment_first_change.push(F::from_bool(this_segment_first_change));
virtual_first_change.push(F::from_bool(this_virtual_first_change));
let context = row[ADDR_CONTEXT];
let segment = row[ADDR_SEGMENT];
let virt = row[ADDR_VIRTUAL];
let timestamp = row[TIMESTAMP];
let next_context = next_row[ADDR_CONTEXT];
let next_segment = next_row[ADDR_SEGMENT];
let next_virt = next_row[ADDR_VIRTUAL];
let next_timestamp = next_row[TIMESTAMP];
let context_changed = context != next_context;
let segment_changed = segment != next_segment;
let virtual_changed = virt != next_virt;
let context_first_change = context_changed;
let segment_first_change = segment_changed && !context_first_change;
let virtual_first_change =
virtual_changed && !segment_first_change && !context_first_change;
let row = trace_rows[idx].as_mut_slice();
row[CONTEXT_FIRST_CHANGE] = F::from_bool(context_first_change);
row[SEGMENT_FIRST_CHANGE] = F::from_bool(segment_first_change);
row[VIRTUAL_FIRST_CHANGE] = F::from_bool(virtual_first_change);
row[RANGE_CHECK] = if context_first_change {
next_context - context - F::ONE
} else if segment_first_change {
next_segment - segment - F::ONE
} else if virtual_first_change {
next_virt - virt - F::ONE
} else {
next_timestamp - timestamp - F::ONE
};
}
context_first_change.push(F::ZERO);
segment_first_change.push(F::ZERO);
virtual_first_change.push(F::ZERO);
(
context_first_change,
segment_first_change,
virtual_first_change,
)
}
pub fn generate_range_check_value<F: RichField>(
context: &[F],
segment: &[F],
virtuals: &[F],
timestamp: &[F],
context_first_change: &[F],
segment_first_change: &[F],
virtual_first_change: &[F],
) -> Vec<F> {
let num_ops = context.len();
let mut range_check = Vec::new();
for idx in 0..num_ops - 1 {
let this_address_unchanged = F::ONE
- context_first_change[idx]
- segment_first_change[idx]
- virtual_first_change[idx];
range_check.push(
context_first_change[idx] * (context[idx + 1] - context[idx] - F::ONE)
+ segment_first_change[idx] * (segment[idx + 1] - segment[idx] - F::ONE)
+ virtual_first_change[idx] * (virtuals[idx + 1] - virtuals[idx] - F::ONE)
+ this_address_unchanged * (timestamp[idx + 1] - timestamp[idx] - F::ONE),
);
}
range_check.push(F::ZERO);
range_check
}
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
pub(crate) fn generate_trace_rows(
&self,
mut memory_ops: Vec<MemoryOp<F>>,
) -> Vec<[F; NUM_COLUMNS]> {
/// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated
/// later, after transposing to column-major form.
fn generate_trace_row_major(&self, mut memory_ops: Vec<MemoryOp<F>>) -> Vec<[F; NUM_COLUMNS]> {
memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp));
Self::pad_memory_ops(&mut memory_ops);
let num_ops = memory_ops.len();
let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]);
for i in 0..num_ops {
let MemoryOp {
channel_index,
timestamp,
is_read,
context,
segment,
virt,
value,
} = memory_ops[i];
if let Some(channel) = channel_index {
trace_cols[is_channel(channel)][i] = F::ONE;
}
trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp);
trace_cols[IS_READ][i] = F::from_bool(is_read);
trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context);
trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment);
trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt);
for j in 0..VALUE_LIMBS {
trace_cols[value_limb(j)][i] = value[j];
}
}
self.generate_memory(&mut trace_cols);
let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops];
for (i, col) in trace_cols.iter().enumerate() {
for (j, &val) in col.iter().enumerate() {
trace_rows[j][i] = val;
}
}
let mut trace_rows = memory_ops
.into_par_iter()
.map(|op| op.to_row())
.collect::<Vec<_>>();
generate_first_change_flags_and_rc(trace_rows.as_mut_slice());
trace_rows
}
/// Generates the `COUNTER`, `RANGE_CHECK_PERMUTED` and `COUNTER_PERMUTED` columns, given a
/// trace in column-major form.
fn generate_trace_col_major(trace_col_vecs: &mut [Vec<F>]) {
let height = trace_col_vecs[0].len();
trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect();
let (permuted_inputs, permuted_table) =
permuted_cols(&trace_col_vecs[RANGE_CHECK], &trace_col_vecs[COUNTER]);
trace_col_vecs[RANGE_CHECK_PERMUTED] = permuted_inputs;
trace_col_vecs[COUNTER_PERMUTED] = permuted_table;
}
fn pad_memory_ops(memory_ops: &mut Vec<MemoryOp<F>>) {
let num_ops = memory_ops.len();
let max_range_check = get_max_range_check(memory_ops);
@ -277,56 +264,27 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
}
}
fn generate_memory(&self, trace_cols: &mut [Vec<F>]) {
let num_trace_rows = trace_cols[0].len();
let timestamp = &trace_cols[TIMESTAMP];
let context = &trace_cols[ADDR_CONTEXT];
let segment = &trace_cols[ADDR_SEGMENT];
let virtuals = &trace_cols[ADDR_VIRTUAL];
let (context_first_change, segment_first_change, virtual_first_change) =
generate_first_change_flags(context, segment, virtuals);
trace_cols[RANGE_CHECK] = generate_range_check_value(
context,
segment,
virtuals,
timestamp,
&context_first_change,
&segment_first_change,
&virtual_first_change,
);
trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change;
trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change;
trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change;
trace_cols[COUNTER] = (0..num_trace_rows)
.map(|i| F::from_canonical_usize(i))
.collect();
let (permuted_inputs, permuted_table) =
permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]);
trace_cols[RANGE_CHECK_PERMUTED] = permuted_inputs;
trace_cols[COUNTER_PERMUTED] = permuted_table;
}
pub fn generate_trace(&self, memory_ops: Vec<MemoryOp<F>>) -> Vec<PolynomialValues<F>> {
let mut timing = TimingTree::new("generate trace", log::Level::Debug);
// Generate the witness.
// Generate most of the trace in row-major form.
let trace_rows = timed!(
&mut timing,
"generate trace rows",
self.generate_trace_rows(memory_ops)
self.generate_trace_row_major(memory_ops)
);
let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect();
let trace_polys = timed!(
&mut timing,
"convert to PolynomialValues",
trace_rows_to_poly_values(trace_rows)
);
// Transpose to column-major form.
let mut trace_col_vecs = transpose(&trace_row_vecs);
// A few final generation steps, which work better in column-major form.
Self::generate_trace_col_major(&mut trace_col_vecs);
let trace_polys = trace_col_vecs
.into_iter()
.map(|column| PolynomialValues::new(column))
.collect();
timing.print();
trace_polys