Merge pull request #603 from mir-protocol/memory_simplifications

Simplify memory table
Daniel Lubarov 2022-07-12 14:59:21 -07:00 committed by GitHub
commit f053932791
2 changed files with 42 additions and 140 deletions
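
The change removes the duplicated, sorted copy of every memory column (SORTED_TIMESTAMP, SORTED_IS_READ, SORTED_ADDR_*, and the sorted value limbs) along with the permutation argument that tied them to the unsorted columns. Instead, the memory operations are sorted by (context, segment, virtual address, timestamp) before the trace is generated, so a single set of columns is already in address order. Below is a minimal sketch of that ordering step, using a simplified stand-in for the crate's MemoryOp type (the field names follow the sort key in the diff below), not the actual implementation:

// Simplified stand-in for MemoryOp<F>: plain integers instead of field elements.
struct MemoryOp {
    context: u64,
    segment: u64,
    virt: u64,
    timestamp: u64,
}

fn sort_ops(mut memory_ops: Vec<MemoryOp>) -> Vec<MemoryOp> {
    // Same key as in generate_trace_rows below: address components first, then time.
    memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp));
    memory_ops
}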

View File

@@ -4,6 +4,7 @@ use std::ops::Range;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
// Columns for memory operations, ordered by (addr, timestamp).
pub(crate) const TIMESTAMP: usize = 0;
pub(crate) const IS_READ: usize = TIMESTAMP + 1;
pub(crate) const ADDR_CONTEXT: usize = IS_READ + 1;
@@ -17,24 +18,11 @@ pub(crate) const fn value_limb(i: usize) -> usize {
VALUE_START + i
}
// Separate columns for the same memory operations, sorted by (addr, timestamp).
pub(crate) const SORTED_TIMESTAMP: usize = VALUE_START + VALUE_LIMBS;
pub(crate) const SORTED_IS_READ: usize = SORTED_TIMESTAMP + 1;
pub(crate) const SORTED_ADDR_CONTEXT: usize = SORTED_IS_READ + 1;
pub(crate) const SORTED_ADDR_SEGMENT: usize = SORTED_ADDR_CONTEXT + 1;
pub(crate) const SORTED_ADDR_VIRTUAL: usize = SORTED_ADDR_SEGMENT + 1;
const SORTED_VALUE_START: usize = SORTED_ADDR_VIRTUAL + 1;
pub(crate) const fn sorted_value_limb(i: usize) -> usize {
debug_assert!(i < VALUE_LIMBS);
SORTED_VALUE_START + i
}
// Flags to indicate whether this part of the address differs from the next row (in the sorted
// columns), and the previous parts do not differ.
// That is, e.g., `SEGMENT_FIRST_CHANGE` is `F::ONE` iff `SORTED_ADDR_CONTEXT` is the same in this
// row and the next, but `SORTED_ADDR_SEGMENT` is not.
pub(crate) const CONTEXT_FIRST_CHANGE: usize = SORTED_VALUE_START + VALUE_LIMBS;
// Flags to indicate whether this part of the address differs from the next row,
// and the previous parts do not differ.
// That is, e.g., `SEGMENT_FIRST_CHANGE` is `F::ONE` iff `ADDR_CONTEXT` is the same in this
// row and the next, but `ADDR_SEGMENT` is not.
pub(crate) const CONTEXT_FIRST_CHANGE: usize = VALUE_START + VALUE_LIMBS;
pub(crate) const SEGMENT_FIRST_CHANGE: usize = CONTEXT_FIRST_CHANGE + 1;
pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1;
@@ -45,7 +33,7 @@ pub(crate) const fn is_channel(channel: usize) -> usize {
IS_CHANNEL_START + channel
}
// We use a range check to ensure sorting.
// We use a range check to enforce the ordering.
pub(crate) const RANGE_CHECK: usize = IS_CHANNEL_START + NUM_CHANNELS;
// The counter column (used for the range check) starts from 0 and increments.
pub(crate) const COUNTER: usize = RANGE_CHECK + 1;
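
The FIRST_CHANGE flags above feed the range check: for each row at most one of them is set, marking which address component is the first to differ from the next row. The following is an illustrative sketch of how such flags can be computed over the sorted address columns; it mirrors the role of generate_first_change_flags in the second file, with plain integers standing in for field elements, and the crate's actual body may differ:

fn first_change_flags(
    context: &[u64],
    segment: &[u64],
    virtuals: &[u64],
) -> (Vec<bool>, Vec<bool>, Vec<bool>) {
    let n = context.len();
    let mut context_fc = vec![false; n];
    let mut segment_fc = vec![false; n];
    let mut virtual_fc = vec![false; n];
    for i in 0..n.saturating_sub(1) {
        // CONTEXT_FIRST_CHANGE: the context differs from the next row.
        context_fc[i] = context[i] != context[i + 1];
        // SEGMENT_FIRST_CHANGE: the context is unchanged but the segment differs.
        segment_fc[i] = !context_fc[i] && segment[i] != segment[i + 1];
        // VIRTUAL_FIRST_CHANGE: context and segment are unchanged but the
        // virtual address differs.
        virtual_fc[i] = !context_fc[i] && !segment_fc[i] && virtuals[i] != virtuals[i + 1];
    }
    (context_fc, segment_fc, virtual_fc)
}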

View File

@@ -1,7 +1,7 @@
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use itertools::{izip, multiunzip, Itertools};
use itertools::Itertools;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
@@ -15,13 +15,11 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cross_table_lookup::Column;
use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols};
use crate::memory::columns::{
is_channel, sorted_value_limb, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL,
COLUMNS_TO_PAD, CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS,
RANGE_CHECK, RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, SORTED_ADDR_CONTEXT,
SORTED_ADDR_SEGMENT, SORTED_ADDR_VIRTUAL, SORTED_IS_READ, SORTED_TIMESTAMP, TIMESTAMP,
VIRTUAL_FIRST_CHANGE,
is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, COLUMNS_TO_PAD,
CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK,
RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
};
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::memory::NUM_CHANNELS;
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
@@ -130,36 +128,6 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
memory_ops
}
pub fn sort_memory_ops<F: RichField>(
timestamp: &[F],
is_read: &[F],
context: &[F],
segment: &[F],
virtuals: &[F],
values: &[[F; 8]],
) -> (Vec<F>, Vec<F>, Vec<F>, Vec<F>, Vec<F>, Vec<[F; 8]>) {
let mut ops: Vec<(F, F, F, F, F, [F; 8])> = izip!(
timestamp.iter().cloned(),
is_read.iter().cloned(),
context.iter().cloned(),
segment.iter().cloned(),
virtuals.iter().cloned(),
values.iter().cloned(),
)
.collect();
ops.sort_unstable_by_key(|&(t, _, c, s, v, _)| {
(
c.to_noncanonical_u64(),
s.to_noncanonical_u64(),
v.to_noncanonical_u64(),
t.to_noncanonical_u64(),
)
});
multiunzip(ops)
}
pub fn generate_first_change_flags<F: RichField>(
context: &[F],
segment: &[F],
@@ -227,8 +195,10 @@ pub fn generate_range_check_value<F: RichField>(
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
pub(crate) fn generate_trace_rows(
&self,
memory_ops: Vec<MemoryOp<F>>,
mut memory_ops: Vec<MemoryOp<F>>,
) -> Vec<[F; NUM_COLUMNS]> {
memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp));
let num_ops = memory_ops.len();
let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]);
@@ -271,39 +241,18 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
let num_trace_rows = trace_cols[0].len();
let timestamp = &trace_cols[TIMESTAMP];
let is_read = &trace_cols[IS_READ];
let context = &trace_cols[ADDR_CONTEXT];
let segment = &trace_cols[ADDR_SEGMENT];
let virtuals = &trace_cols[ADDR_VIRTUAL];
let values: Vec<[F; 8]> = (0..num_trace_rows)
.map(|i| {
let arr: [F; 8] = (0..8)
.map(|j| &trace_cols[value_limb(j)][i])
.cloned()
.collect_vec()
.try_into()
.unwrap();
arr
})
.collect();
let (
sorted_timestamp,
sorted_is_read,
sorted_context,
sorted_segment,
sorted_virtual,
sorted_values,
) = sort_memory_ops(timestamp, is_read, context, segment, virtuals, &values);
let (context_first_change, segment_first_change, virtual_first_change) =
generate_first_change_flags(&sorted_context, &sorted_segment, &sorted_virtual);
generate_first_change_flags(context, segment, virtuals);
let (range_check_value, max_diff) = generate_range_check_value(
&sorted_context,
&sorted_segment,
&sorted_virtual,
&sorted_timestamp,
context,
segment,
virtuals,
timestamp,
&context_first_change,
&segment_first_change,
&virtual_first_change,
@@ -311,17 +260,6 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two();
let to_pad = to_pad_to - num_trace_rows;
trace_cols[SORTED_TIMESTAMP] = sorted_timestamp;
trace_cols[SORTED_IS_READ] = sorted_is_read;
trace_cols[SORTED_ADDR_CONTEXT] = sorted_context;
trace_cols[SORTED_ADDR_SEGMENT] = sorted_segment;
trace_cols[SORTED_ADDR_VIRTUAL] = sorted_virtual;
for i in 0..num_trace_rows {
for j in 0..VALUE_LIMBS {
trace_cols[sorted_value_limb(j)][i] = sorted_values[i][j];
}
}
trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change;
trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change;
trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change;
@@ -375,22 +313,18 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
{
let one = P::from(FE::ONE);
let timestamp = vars.local_values[SORTED_TIMESTAMP];
let addr_context = vars.local_values[SORTED_ADDR_CONTEXT];
let addr_segment = vars.local_values[SORTED_ADDR_SEGMENT];
let addr_virtual = vars.local_values[SORTED_ADDR_VIRTUAL];
let values: Vec<_> = (0..8)
.map(|i| vars.local_values[sorted_value_limb(i)])
.collect();
let timestamp = vars.local_values[TIMESTAMP];
let addr_context = vars.local_values[ADDR_CONTEXT];
let addr_segment = vars.local_values[ADDR_SEGMENT];
let addr_virtual = vars.local_values[ADDR_VIRTUAL];
let values: Vec<_> = (0..8).map(|i| vars.local_values[value_limb(i)]).collect();
let next_timestamp = vars.next_values[SORTED_TIMESTAMP];
let next_is_read = vars.next_values[SORTED_IS_READ];
let next_addr_context = vars.next_values[SORTED_ADDR_CONTEXT];
let next_addr_segment = vars.next_values[SORTED_ADDR_SEGMENT];
let next_addr_virtual = vars.next_values[SORTED_ADDR_VIRTUAL];
let next_values: Vec<_> = (0..8)
.map(|i| vars.next_values[sorted_value_limb(i)])
.collect();
let next_timestamp = vars.next_values[TIMESTAMP];
let next_is_read = vars.next_values[IS_READ];
let next_addr_context = vars.next_values[ADDR_CONTEXT];
let next_addr_segment = vars.next_values[ADDR_SEGMENT];
let next_addr_virtual = vars.next_values[ADDR_VIRTUAL];
let next_values: Vec<_> = (0..8).map(|i| vars.next_values[value_limb(i)]).collect();
// Indicator that this is a real row, not a row of padding.
// TODO: enforce that all padding is at the beginning.
@@ -458,22 +392,18 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
) {
let one = builder.one_extension();
let addr_context = vars.local_values[SORTED_ADDR_CONTEXT];
let addr_segment = vars.local_values[SORTED_ADDR_SEGMENT];
let addr_virtual = vars.local_values[SORTED_ADDR_VIRTUAL];
let values: Vec<_> = (0..8)
.map(|i| vars.local_values[sorted_value_limb(i)])
.collect();
let timestamp = vars.local_values[SORTED_TIMESTAMP];
let addr_context = vars.local_values[ADDR_CONTEXT];
let addr_segment = vars.local_values[ADDR_SEGMENT];
let addr_virtual = vars.local_values[ADDR_VIRTUAL];
let values: Vec<_> = (0..8).map(|i| vars.local_values[value_limb(i)]).collect();
let timestamp = vars.local_values[TIMESTAMP];
let next_addr_context = vars.next_values[SORTED_ADDR_CONTEXT];
let next_addr_segment = vars.next_values[SORTED_ADDR_SEGMENT];
let next_addr_virtual = vars.next_values[SORTED_ADDR_VIRTUAL];
let next_values: Vec<_> = (0..8)
.map(|i| vars.next_values[sorted_value_limb(i)])
.collect();
let next_is_read = vars.next_values[SORTED_IS_READ];
let next_timestamp = vars.next_values[SORTED_TIMESTAMP];
let next_addr_context = vars.next_values[ADDR_CONTEXT];
let next_addr_segment = vars.next_values[ADDR_SEGMENT];
let next_addr_virtual = vars.next_values[ADDR_VIRTUAL];
let next_values: Vec<_> = (0..8).map(|i| vars.next_values[value_limb(i)]).collect();
let next_is_read = vars.next_values[IS_READ];
let next_timestamp = vars.next_values[TIMESTAMP];
// Indicator that this is a real row, not a row of padding.
let mut valid_row = vars.local_values[is_channel(0)];
@@ -590,23 +520,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
}
fn permutation_pairs(&self) -> Vec<PermutationPair> {
let mut unsorted_cols = vec![TIMESTAMP, IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL];
unsorted_cols.extend((0..VALUE_LIMBS).map(value_limb));
let mut sorted_cols = vec![
SORTED_TIMESTAMP,
SORTED_IS_READ,
SORTED_ADDR_CONTEXT,
SORTED_ADDR_SEGMENT,
SORTED_ADDR_VIRTUAL,
];
sorted_cols.extend((0..VALUE_LIMBS).map(sorted_value_limb));
let column_pairs: Vec<_> = unsorted_cols
.into_iter()
.zip(sorted_cols.iter().cloned())
.collect();
vec![
PermutationPair { column_pairs },
PermutationPair::singletons(RANGE_CHECK, RANGE_CHECK_PERMUTED),
PermutationPair::singletons(COUNTER, COUNTER_PERMUTED),
]
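
For reference, the ordering itself is still enforced through the RANGE_CHECK / COUNTER pair kept in permutation_pairs above: each row's RANGE_CHECK value must appear among the counter values 0, 1, 2, ..., which is why the trace is padded to at least max_diff + 1 rows. The sketch below shows one plausible way to derive that per-row value from the sorted columns and the first-change flags; the exact expression lives in generate_range_check_value, which this diff only calls, so treat the arithmetic as an illustration rather than the crate's definition:

// Hypothetical illustration with plain u64 columns; assumes the rows are already
// sorted by (context, segment, virtual, timestamp), so no subtraction underflows.
fn range_check_values(
    context: &[u64],
    segment: &[u64],
    virtuals: &[u64],
    timestamp: &[u64],
    context_fc: &[bool],
    segment_fc: &[bool],
    virtual_fc: &[bool],
) -> (Vec<u64>, u64) {
    let n = context.len();
    let mut rc = vec![0u64; n];
    let mut max_diff = 0;
    for i in 0..n.saturating_sub(1) {
        rc[i] = if context_fc[i] {
            // Context strictly increases.
            context[i + 1] - context[i] - 1
        } else if segment_fc[i] {
            // Same context, segment strictly increases.
            segment[i + 1] - segment[i] - 1
        } else if virtual_fc[i] {
            // Same context and segment, virtual address strictly increases.
            virtuals[i + 1] - virtuals[i] - 1
        } else {
            // Same address: timestamps must be non-decreasing.
            timestamp[i + 1] - timestamp[i]
        };
        max_diff = max_diff.max(rc[i]);
    }
    // The trace is then padded to (max_diff + 1).max(n).next_power_of_two(),
    // so every rc value is covered by the 0..len COUNTER column.
    (rc, max_diff)
}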