mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-04 06:43:07 +00:00
Merge pull request #609 from mir-protocol/row_wise_memory_gen
Generate most of the memory table while it's in row-wise form
This commit is contained in:
commit
8751aaec7a
@ -9,7 +9,9 @@ use plonky2::field::types::Field;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::timed;
|
||||
use plonky2::util::timing::TimingTree;
|
||||
use plonky2::util::transpose;
|
||||
use rand::Rng;
|
||||
use rayon::prelude::*;
|
||||
|
||||
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
|
||||
use crate::cross_table_lookup::Column;
|
||||
@ -22,7 +24,6 @@ use crate::memory::columns::{
|
||||
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
|
||||
use crate::permutation::PermutationPair;
|
||||
use crate::stark::Stark;
|
||||
use crate::util::trace_rows_to_poly_values;
|
||||
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
|
||||
|
||||
pub(crate) const NUM_PUBLIC_INPUTS: usize = 0;
|
||||
@ -56,6 +57,28 @@ pub struct MemoryOp<F> {
|
||||
pub value: [F; 8],
|
||||
}
|
||||
|
||||
impl<F: Field> MemoryOp<F> {
|
||||
/// Generate a row for a given memory operation. Note that this does not generate columns which
|
||||
/// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later.
|
||||
/// It also does not generate columns such as `COUNTER`, which are generated later, after the
|
||||
/// trace has been transposed into column-major form.
|
||||
fn to_row(&self) -> [F; NUM_COLUMNS] {
|
||||
let mut row = [F::ZERO; NUM_COLUMNS];
|
||||
if let Some(channel) = self.channel_index {
|
||||
row[is_channel(channel)] = F::ONE;
|
||||
}
|
||||
row[TIMESTAMP] = F::from_canonical_usize(self.timestamp);
|
||||
row[IS_READ] = F::from_bool(self.is_read);
|
||||
row[ADDR_CONTEXT] = F::from_canonical_usize(self.context);
|
||||
row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment);
|
||||
row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt);
|
||||
for j in 0..VALUE_LIMBS {
|
||||
row[value_limb(j)] = self.value[j];
|
||||
}
|
||||
row
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_random_memory_ops<F: RichField, R: Rng>(
|
||||
num_ops: usize,
|
||||
rng: &mut R,
|
||||
@ -148,112 +171,76 @@ fn get_max_range_check<F: Field>(memory_ops: &[MemoryOp<F>]) -> usize {
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn generate_first_change_flags<F: RichField>(
|
||||
context: &[F],
|
||||
segment: &[F],
|
||||
virtuals: &[F],
|
||||
) -> (Vec<F>, Vec<F>, Vec<F>) {
|
||||
let num_ops = context.len();
|
||||
let mut context_first_change = Vec::with_capacity(num_ops);
|
||||
let mut segment_first_change = Vec::with_capacity(num_ops);
|
||||
let mut virtual_first_change = Vec::with_capacity(num_ops);
|
||||
/// Generates the `_FIRST_CHANGE` columns and the `RANGE_CHECK` column in the trace.
|
||||
pub fn generate_first_change_flags_and_rc<F: RichField>(trace_rows: &mut [[F; NUM_COLUMNS]]) {
|
||||
let num_ops = trace_rows.len();
|
||||
for idx in 0..num_ops - 1 {
|
||||
let this_context_first_change = context[idx] != context[idx + 1];
|
||||
let this_segment_first_change =
|
||||
segment[idx] != segment[idx + 1] && !this_context_first_change;
|
||||
let this_virtual_first_change = virtuals[idx] != virtuals[idx + 1]
|
||||
&& !this_segment_first_change
|
||||
&& !this_context_first_change;
|
||||
let row = trace_rows[idx].as_slice();
|
||||
let next_row = trace_rows[idx + 1].as_slice();
|
||||
|
||||
context_first_change.push(F::from_bool(this_context_first_change));
|
||||
segment_first_change.push(F::from_bool(this_segment_first_change));
|
||||
virtual_first_change.push(F::from_bool(this_virtual_first_change));
|
||||
let context = row[ADDR_CONTEXT];
|
||||
let segment = row[ADDR_SEGMENT];
|
||||
let virt = row[ADDR_VIRTUAL];
|
||||
let timestamp = row[TIMESTAMP];
|
||||
let next_context = next_row[ADDR_CONTEXT];
|
||||
let next_segment = next_row[ADDR_SEGMENT];
|
||||
let next_virt = next_row[ADDR_VIRTUAL];
|
||||
let next_timestamp = next_row[TIMESTAMP];
|
||||
|
||||
let context_changed = context != next_context;
|
||||
let segment_changed = segment != next_segment;
|
||||
let virtual_changed = virt != next_virt;
|
||||
|
||||
let context_first_change = context_changed;
|
||||
let segment_first_change = segment_changed && !context_first_change;
|
||||
let virtual_first_change =
|
||||
virtual_changed && !segment_first_change && !context_first_change;
|
||||
|
||||
let row = trace_rows[idx].as_mut_slice();
|
||||
row[CONTEXT_FIRST_CHANGE] = F::from_bool(context_first_change);
|
||||
row[SEGMENT_FIRST_CHANGE] = F::from_bool(segment_first_change);
|
||||
row[VIRTUAL_FIRST_CHANGE] = F::from_bool(virtual_first_change);
|
||||
|
||||
row[RANGE_CHECK] = if context_first_change {
|
||||
next_context - context - F::ONE
|
||||
} else if segment_first_change {
|
||||
next_segment - segment - F::ONE
|
||||
} else if virtual_first_change {
|
||||
next_virt - virt - F::ONE
|
||||
} else {
|
||||
next_timestamp - timestamp - F::ONE
|
||||
};
|
||||
}
|
||||
|
||||
context_first_change.push(F::ZERO);
|
||||
segment_first_change.push(F::ZERO);
|
||||
virtual_first_change.push(F::ZERO);
|
||||
|
||||
(
|
||||
context_first_change,
|
||||
segment_first_change,
|
||||
virtual_first_change,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn generate_range_check_value<F: RichField>(
|
||||
context: &[F],
|
||||
segment: &[F],
|
||||
virtuals: &[F],
|
||||
timestamp: &[F],
|
||||
context_first_change: &[F],
|
||||
segment_first_change: &[F],
|
||||
virtual_first_change: &[F],
|
||||
) -> Vec<F> {
|
||||
let num_ops = context.len();
|
||||
let mut range_check = Vec::new();
|
||||
|
||||
for idx in 0..num_ops - 1 {
|
||||
let this_address_unchanged = F::ONE
|
||||
- context_first_change[idx]
|
||||
- segment_first_change[idx]
|
||||
- virtual_first_change[idx];
|
||||
range_check.push(
|
||||
context_first_change[idx] * (context[idx + 1] - context[idx] - F::ONE)
|
||||
+ segment_first_change[idx] * (segment[idx + 1] - segment[idx] - F::ONE)
|
||||
+ virtual_first_change[idx] * (virtuals[idx + 1] - virtuals[idx] - F::ONE)
|
||||
+ this_address_unchanged * (timestamp[idx + 1] - timestamp[idx] - F::ONE),
|
||||
);
|
||||
}
|
||||
range_check.push(F::ZERO);
|
||||
|
||||
range_check
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
|
||||
pub(crate) fn generate_trace_rows(
|
||||
&self,
|
||||
mut memory_ops: Vec<MemoryOp<F>>,
|
||||
) -> Vec<[F; NUM_COLUMNS]> {
|
||||
/// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated
|
||||
/// later, after transposing to column-major form.
|
||||
fn generate_trace_row_major(&self, mut memory_ops: Vec<MemoryOp<F>>) -> Vec<[F; NUM_COLUMNS]> {
|
||||
memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp));
|
||||
|
||||
Self::pad_memory_ops(&mut memory_ops);
|
||||
let num_ops = memory_ops.len();
|
||||
|
||||
let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]);
|
||||
for i in 0..num_ops {
|
||||
let MemoryOp {
|
||||
channel_index,
|
||||
timestamp,
|
||||
is_read,
|
||||
context,
|
||||
segment,
|
||||
virt,
|
||||
value,
|
||||
} = memory_ops[i];
|
||||
if let Some(channel) = channel_index {
|
||||
trace_cols[is_channel(channel)][i] = F::ONE;
|
||||
}
|
||||
trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp);
|
||||
trace_cols[IS_READ][i] = F::from_bool(is_read);
|
||||
trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context);
|
||||
trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment);
|
||||
trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt);
|
||||
for j in 0..VALUE_LIMBS {
|
||||
trace_cols[value_limb(j)][i] = value[j];
|
||||
}
|
||||
}
|
||||
|
||||
self.generate_memory(&mut trace_cols);
|
||||
|
||||
let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops];
|
||||
for (i, col) in trace_cols.iter().enumerate() {
|
||||
for (j, &val) in col.iter().enumerate() {
|
||||
trace_rows[j][i] = val;
|
||||
}
|
||||
}
|
||||
let mut trace_rows = memory_ops
|
||||
.into_par_iter()
|
||||
.map(|op| op.to_row())
|
||||
.collect::<Vec<_>>();
|
||||
generate_first_change_flags_and_rc(trace_rows.as_mut_slice());
|
||||
trace_rows
|
||||
}
|
||||
|
||||
/// Generates the `COUNTER`, `RANGE_CHECK_PERMUTED` and `COUNTER_PERMUTED` columns, given a
|
||||
/// trace in column-major form.
|
||||
fn generate_trace_col_major(trace_col_vecs: &mut [Vec<F>]) {
|
||||
let height = trace_col_vecs[0].len();
|
||||
trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect();
|
||||
|
||||
let (permuted_inputs, permuted_table) =
|
||||
permuted_cols(&trace_col_vecs[RANGE_CHECK], &trace_col_vecs[COUNTER]);
|
||||
trace_col_vecs[RANGE_CHECK_PERMUTED] = permuted_inputs;
|
||||
trace_col_vecs[COUNTER_PERMUTED] = permuted_table;
|
||||
}
|
||||
|
||||
fn pad_memory_ops(memory_ops: &mut Vec<MemoryOp<F>>) {
|
||||
let num_ops = memory_ops.len();
|
||||
let max_range_check = get_max_range_check(memory_ops);
|
||||
@ -277,56 +264,27 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_memory(&self, trace_cols: &mut [Vec<F>]) {
|
||||
let num_trace_rows = trace_cols[0].len();
|
||||
|
||||
let timestamp = &trace_cols[TIMESTAMP];
|
||||
let context = &trace_cols[ADDR_CONTEXT];
|
||||
let segment = &trace_cols[ADDR_SEGMENT];
|
||||
let virtuals = &trace_cols[ADDR_VIRTUAL];
|
||||
|
||||
let (context_first_change, segment_first_change, virtual_first_change) =
|
||||
generate_first_change_flags(context, segment, virtuals);
|
||||
|
||||
trace_cols[RANGE_CHECK] = generate_range_check_value(
|
||||
context,
|
||||
segment,
|
||||
virtuals,
|
||||
timestamp,
|
||||
&context_first_change,
|
||||
&segment_first_change,
|
||||
&virtual_first_change,
|
||||
);
|
||||
|
||||
trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change;
|
||||
trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change;
|
||||
trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change;
|
||||
|
||||
trace_cols[COUNTER] = (0..num_trace_rows)
|
||||
.map(|i| F::from_canonical_usize(i))
|
||||
.collect();
|
||||
|
||||
let (permuted_inputs, permuted_table) =
|
||||
permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]);
|
||||
trace_cols[RANGE_CHECK_PERMUTED] = permuted_inputs;
|
||||
trace_cols[COUNTER_PERMUTED] = permuted_table;
|
||||
}
|
||||
|
||||
pub fn generate_trace(&self, memory_ops: Vec<MemoryOp<F>>) -> Vec<PolynomialValues<F>> {
|
||||
let mut timing = TimingTree::new("generate trace", log::Level::Debug);
|
||||
|
||||
// Generate the witness.
|
||||
// Generate most of the trace in row-major form.
|
||||
let trace_rows = timed!(
|
||||
&mut timing,
|
||||
"generate trace rows",
|
||||
self.generate_trace_rows(memory_ops)
|
||||
self.generate_trace_row_major(memory_ops)
|
||||
);
|
||||
let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect();
|
||||
|
||||
let trace_polys = timed!(
|
||||
&mut timing,
|
||||
"convert to PolynomialValues",
|
||||
trace_rows_to_poly_values(trace_rows)
|
||||
);
|
||||
// Transpose to column-major form.
|
||||
let mut trace_col_vecs = transpose(&trace_row_vecs);
|
||||
|
||||
// A few final generation steps, which work better in column-major form.
|
||||
Self::generate_trace_col_major(&mut trace_col_vecs);
|
||||
|
||||
let trace_polys = trace_col_vecs
|
||||
.into_iter()
|
||||
.map(|column| PolynomialValues::new(column))
|
||||
.collect();
|
||||
|
||||
timing.print();
|
||||
trace_polys
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user