Merge pull request #599 from mir-protocol/memory_ctl

Separate timestamps per memory operation
This commit is contained in:
Nicholas Ward 2022-07-11 10:45:28 -07:00 committed by GitHub
commit 29ef56eb69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 130 additions and 63 deletions

View File

@ -119,7 +119,7 @@ mod tests {
use anyhow::Result;
use itertools::{izip, Itertools};
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use plonky2::field::types::{Field, PrimeField64};
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::CircuitConfig;
@ -196,9 +196,11 @@ mod tests {
num_memory_ops: usize,
memory_stark: &MemoryStark<F, D>,
rng: &mut R,
) -> Vec<PolynomialValues<F>> {
) -> (Vec<PolynomialValues<F>>, usize) {
let memory_ops = generate_random_memory_ops(num_memory_ops, rng);
memory_stark.generate_trace(memory_ops)
let trace = memory_stark.generate_trace(memory_ops);
let num_ops = trace[0].values.len();
(trace, num_ops)
}
fn make_cpu_trace(
@ -282,32 +284,34 @@ mod tests {
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
let mut current_cpu_index = 0;
let mut last_timestamp = memory_trace[memory::columns::TIMESTAMP].values[0];
for i in 0..num_memory_ops {
let mem_timestamp = memory_trace[memory::columns::TIMESTAMP].values[i];
let clock = mem_timestamp;
let op = (0..NUM_CHANNELS)
.filter(|&o| memory_trace[memory::columns::is_channel(o)].values[i] == F::ONE)
.collect_vec()[0];
let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i]
.to_canonical_u64()
.try_into()
.unwrap();
let clock = mem_timestamp / NUM_CHANNELS;
let channel = mem_timestamp % NUM_CHANNELS;
if mem_timestamp != last_timestamp {
current_cpu_index += 1;
last_timestamp = mem_timestamp;
}
let is_padding_row = (0..NUM_CHANNELS)
.map(|c| memory_trace[memory::columns::is_channel(c)].values[i])
.all(|x| x == F::ZERO);
let row: &mut cpu::columns::CpuColumnsView<F> =
cpu_trace_rows[current_cpu_index].borrow_mut();
if !is_padding_row {
let row: &mut cpu::columns::CpuColumnsView<F> = cpu_trace_rows[clock].borrow_mut();
row.mem_channel_used[op] = F::ONE;
row.clock = clock;
row.mem_is_read[op] = memory_trace[memory::columns::IS_READ].values[i];
row.mem_addr_context[op] = memory_trace[memory::columns::ADDR_CONTEXT].values[i];
row.mem_addr_segment[op] = memory_trace[memory::columns::ADDR_SEGMENT].values[i];
row.mem_addr_virtual[op] = memory_trace[memory::columns::ADDR_VIRTUAL].values[i];
for j in 0..8 {
row.mem_value[op][j] = memory_trace[memory::columns::value_limb(j)].values[i];
row.mem_channel_used[channel] = F::ONE;
row.clock = F::from_canonical_usize(clock);
row.mem_is_read[channel] = memory_trace[memory::columns::IS_READ].values[i];
row.mem_addr_context[channel] =
memory_trace[memory::columns::ADDR_CONTEXT].values[i];
row.mem_addr_segment[channel] =
memory_trace[memory::columns::ADDR_SEGMENT].values[i];
row.mem_addr_virtual[channel] =
memory_trace[memory::columns::ADDR_VIRTUAL].values[i];
for j in 0..8 {
row.mem_value[channel][j] =
memory_trace[memory::columns::value_limb(j)].values[i];
}
}
}
trace_rows_to_poly_values(cpu_trace_rows)
@ -337,7 +341,9 @@ mod tests {
let keccak_trace = make_keccak_trace(num_keccak_perms, &keccak_stark, &mut rng);
let logic_trace = make_logic_trace(num_logic_rows, &logic_stark, &mut rng);
let mut memory_trace = make_memory_trace(num_memory_ops, &memory_stark, &mut rng);
let mem_trace = make_memory_trace(num_memory_ops, &memory_stark, &mut rng);
let mut memory_trace = mem_trace.0;
let num_memory_ops = mem_trace.1;
let cpu_trace = make_cpu_trace(
num_keccak_perms,
num_logic_rows,

View File

@ -40,7 +40,6 @@ pub fn ctl_filter_logic<F: Field>() -> Column<F> {
pub fn ctl_data_memory<F: Field>(channel: usize) -> Vec<Column<F>> {
debug_assert!(channel < NUM_CHANNELS);
let mut cols: Vec<Column<F>> = Column::singles([
COL_MAP.clock,
COL_MAP.mem_is_read[channel],
COL_MAP.mem_addr_context[channel],
COL_MAP.mem_addr_segment[channel],
@ -48,6 +47,14 @@ pub fn ctl_data_memory<F: Field>(channel: usize) -> Vec<Column<F>> {
])
.collect_vec();
cols.extend(Column::singles(COL_MAP.mem_value[channel]));
let scalar = F::from_canonical_usize(NUM_CHANNELS);
let addend = F::from_canonical_usize(channel);
cols.push(Column::linear_combination_with_constant(
vec![(COL_MAP.clock, scalar)],
addend,
));
cols
}

View File

@ -1,5 +1,7 @@
//! Memory registers.
use std::ops::Range;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
pub(crate) const TIMESTAMP: usize = 0;
@ -36,20 +38,22 @@ pub(crate) const CONTEXT_FIRST_CHANGE: usize = SORTED_VALUE_START + VALUE_LIMBS;
pub(crate) const SEGMENT_FIRST_CHANGE: usize = CONTEXT_FIRST_CHANGE + 1;
pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1;
// Flags to indicate if this operation came from the `i`th channel of the memory bus.
const IS_CHANNEL_START: usize = VIRTUAL_FIRST_CHANGE + 1;
pub(crate) const fn is_channel(channel: usize) -> usize {
debug_assert!(channel < NUM_CHANNELS);
IS_CHANNEL_START + channel
}
// We use a range check to ensure sorting.
pub(crate) const RANGE_CHECK: usize = VIRTUAL_FIRST_CHANGE + 1;
pub(crate) const RANGE_CHECK: usize = IS_CHANNEL_START + NUM_CHANNELS;
// The counter column (used for the range check) starts from 0 and increments.
pub(crate) const COUNTER: usize = RANGE_CHECK + 1;
// Helper columns for the permutation argument used to enforce the range check.
pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1;
pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1;
// Flags to indicate if this operation came from the `i`th channel of the memory bus.
const IS_CHANNEL_START: usize = COUNTER_PERMUTED + 1;
#[allow(dead_code)]
pub(crate) const fn is_channel(channel: usize) -> usize {
debug_assert!(channel < NUM_CHANNELS);
IS_CHANNEL_START + channel
}
// Columns to be padded at the top with zeroes, before the permutation argument takes place.
pub(crate) const COLUMNS_TO_PAD: Range<usize> = TIMESTAMP..RANGE_CHECK + 1;
pub(crate) const NUM_COLUMNS: usize = IS_CHANNEL_START + NUM_CHANNELS;
pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1;

View File

@ -11,17 +11,17 @@ use plonky2::timed;
use plonky2::util::timing::TimingTree;
use rand::Rng;
use super::columns::is_channel;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column;
use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols};
use crate::memory::columns::{
sorted_value_limb, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE,
COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED,
SEGMENT_FIRST_CHANGE, SORTED_ADDR_CONTEXT, SORTED_ADDR_SEGMENT, SORTED_ADDR_VIRTUAL,
SORTED_IS_READ, SORTED_TIMESTAMP, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
is_channel, sorted_value_limb, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL,
COLUMNS_TO_PAD, CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS,
RANGE_CHECK, RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, SORTED_ADDR_CONTEXT,
SORTED_ADDR_SEGMENT, SORTED_ADDR_VIRTUAL, SORTED_IS_READ, SORTED_TIMESTAMP, TIMESTAMP,
VIRTUAL_FIRST_CHANGE,
};
use crate::memory::NUM_CHANNELS;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
@ -30,9 +30,10 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
pub(crate) const NUM_PUBLIC_INPUTS: usize = 0;
pub fn ctl_data<F: Field>() -> Vec<Column<F>> {
let mut res = Column::singles([TIMESTAMP, IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL])
.collect_vec();
let mut res =
Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec();
res.extend(Column::singles((0..8).map(value_limb)));
res.push(Column::single(TIMESTAMP));
res
}
@ -63,8 +64,7 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
let mut current_memory_values: HashMap<(F, F, F), [F; 8]> = HashMap::new();
let num_cycles = num_ops / 2;
for i in 0..num_cycles {
let timestamp = F::from_canonical_usize(i);
for clock in 0..num_cycles {
let mut used_indices = HashSet::new();
let mut new_writes_this_cycle = HashMap::new();
let mut has_read = false;
@ -75,7 +75,7 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
}
used_indices.insert(channel_index);
let is_read = if i == 0 {
let is_read = if clock == 0 {
false
} else {
!has_read && rng.gen()
@ -111,6 +111,7 @@ pub fn generate_random_memory_ops<F: RichField, R: Rng>(
(context, segment, virt, vals)
};
let timestamp = F::from_canonical_usize(clock * NUM_CHANNELS + channel_index);
memory_ops.push(MemoryOp {
channel_index,
timestamp,
@ -200,7 +201,7 @@ pub fn generate_range_check_value<F: RichField>(
context_first_change: &[F],
segment_first_change: &[F],
virtual_first_change: &[F],
) -> Vec<F> {
) -> (Vec<F>, usize) {
let num_ops = context.len();
let mut range_check = Vec::new();
@ -209,7 +210,6 @@ pub fn generate_range_check_value<F: RichField>(
- context_first_change[idx]
- segment_first_change[idx]
- virtual_first_change[idx];
range_check.push(
context_first_change[idx] * (context[idx + 1] - context[idx] - F::ONE)
+ segment_first_change[idx] * (segment[idx + 1] - segment[idx] - F::ONE)
@ -217,10 +217,11 @@ pub fn generate_range_check_value<F: RichField>(
+ this_address_unchanged * (timestamp[idx + 1] - timestamp[idx] - F::ONE),
);
}
range_check.push(F::ZERO);
range_check
let max_diff = range_check.iter().map(F::to_canonical_u64).max().unwrap() as usize;
(range_check, max_diff)
}
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
@ -254,6 +255,9 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
self.generate_memory(&mut trace_cols);
// The number of rows may have changed, if the range check required padding.
let num_ops = trace_cols[0].len();
let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops];
for (i, col) in trace_cols.iter().enumerate() {
for (j, &val) in col.iter().enumerate() {
@ -295,7 +299,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
let (context_first_change, segment_first_change, virtual_first_change) =
generate_first_change_flags(&sorted_context, &sorted_segment, &sorted_virtual);
let range_check_value = generate_range_check_value(
let (range_check_value, max_diff) = generate_range_check_value(
&sorted_context,
&sorted_segment,
&sorted_virtual,
@ -304,6 +308,8 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
&segment_first_change,
&virtual_first_change,
);
let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two();
let to_pad = to_pad_to - num_trace_rows;
trace_cols[SORTED_TIMESTAMP] = sorted_timestamp;
trace_cols[SORTED_IS_READ] = sorted_is_read;
@ -311,7 +317,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_cols[SORTED_ADDR_SEGMENT] = sorted_segment;
trace_cols[SORTED_ADDR_VIRTUAL] = sorted_virtual;
for i in 0..num_trace_rows {
for j in 0..8 {
for j in 0..VALUE_LIMBS {
trace_cols[sorted_value_limb(j)][i] = sorted_values[i][j];
}
}
@ -321,9 +327,12 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change;
trace_cols[RANGE_CHECK] = range_check_value;
trace_cols[COUNTER] = (0..num_trace_rows)
.map(|i| F::from_canonical_usize(i))
.collect();
for col in COLUMNS_TO_PAD {
trace_cols[col].splice(0..0, vec![F::ZERO; to_pad]);
}
trace_cols[COUNTER] = (0..to_pad_to).map(|i| F::from_canonical_usize(i)).collect();
let (permuted_inputs, permuted_table) =
permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]);
@ -383,6 +392,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
.map(|i| vars.next_values[sorted_value_limb(i)])
.collect();
// Indicator that this is a real row, not a row of padding.
// TODO: enforce that all padding is at the beginning.
let valid_row: P = (0..NUM_CHANNELS)
.map(|c| vars.local_values[is_channel(c)])
.sum();
let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE];
let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE];
let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE];
@ -409,15 +424,21 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
.constraint_transition(virtual_first_change * (next_addr_context - addr_context));
yield_constr
.constraint_transition(virtual_first_change * (next_addr_segment - addr_segment));
yield_constr.constraint_transition(address_unchanged * (next_addr_context - addr_context));
yield_constr.constraint_transition(address_unchanged * (next_addr_segment - addr_segment));
yield_constr.constraint_transition(address_unchanged * (next_addr_virtual - addr_virtual));
yield_constr.constraint_transition(
valid_row * address_unchanged * (next_addr_context - addr_context),
);
yield_constr.constraint_transition(
valid_row * address_unchanged * (next_addr_segment - addr_segment),
);
yield_constr.constraint_transition(
valid_row * address_unchanged * (next_addr_virtual - addr_virtual),
);
// Third set of ordering constraints: range-check difference in the column that should be increasing.
let computed_range_check = context_first_change * (next_addr_context - addr_context - one)
+ segment_first_change * (next_addr_segment - addr_segment - one)
+ virtual_first_change * (next_addr_virtual - addr_virtual - one)
+ address_unchanged * (next_timestamp - timestamp - one);
+ valid_row * address_unchanged * (next_timestamp - timestamp - one);
yield_constr.constraint_transition(range_check - computed_range_check);
// Enumerate purportedly-ordered log.
@ -454,6 +475,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let next_is_read = vars.next_values[SORTED_IS_READ];
let next_timestamp = vars.next_values[SORTED_TIMESTAMP];
// Indicator that this is a real row, not a row of padding.
let mut valid_row = vars.local_values[is_channel(0)];
for c in 1..NUM_CHANNELS {
valid_row = builder.add_extension(valid_row, vars.local_values[is_channel(c)]);
}
let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE];
let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE];
let virtual_first_change = vars.local_values[VIRTUAL_FIRST_CHANGE];
@ -498,11 +525,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.mul_extension(virtual_first_change, addr_segment_diff);
yield_constr.constraint_transition(builder, virtual_first_change_check_2);
let address_unchanged_check_1 = builder.mul_extension(address_unchanged, addr_context_diff);
yield_constr.constraint_transition(builder, address_unchanged_check_1);
let address_unchanged_check_1_valid =
builder.mul_extension(valid_row, address_unchanged_check_1);
yield_constr.constraint_transition(builder, address_unchanged_check_1_valid);
let address_unchanged_check_2 = builder.mul_extension(address_unchanged, addr_segment_diff);
yield_constr.constraint_transition(builder, address_unchanged_check_2);
let address_unchanged_check_2_valid =
builder.mul_extension(valid_row, address_unchanged_check_2);
yield_constr.constraint_transition(builder, address_unchanged_check_2_valid);
let address_unchanged_check_3 = builder.mul_extension(address_unchanged, addr_virtual_diff);
yield_constr.constraint_transition(builder, address_unchanged_check_3);
let address_unchanged_check_3_valid =
builder.mul_extension(valid_row, address_unchanged_check_3);
yield_constr.constraint_transition(builder, address_unchanged_check_3_valid);
// Third set of ordering constraints: range-check difference in the column that should be increasing.
let context_diff = {
@ -525,6 +558,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.sub_extension(diff, one)
};
let timestamp_range_check = builder.mul_extension(address_unchanged, timestamp_diff);
let timestamp_range_check = builder.mul_extension(valid_row, timestamp_range_check);
let computed_range_check = {
let mut sum = builder.add_extension(context_range_check, segment_range_check);
@ -556,7 +590,23 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
}
fn permutation_pairs(&self) -> Vec<PermutationPair> {
let mut unsorted_cols = vec![TIMESTAMP, IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL];
unsorted_cols.extend((0..VALUE_LIMBS).map(value_limb));
let mut sorted_cols = vec![
SORTED_TIMESTAMP,
SORTED_IS_READ,
SORTED_ADDR_CONTEXT,
SORTED_ADDR_SEGMENT,
SORTED_ADDR_VIRTUAL,
];
sorted_cols.extend((0..VALUE_LIMBS).map(sorted_value_limb));
let column_pairs: Vec<_> = unsorted_cols
.into_iter()
.zip(sorted_cols.iter().cloned())
.collect();
vec![
PermutationPair { column_pairs },
PermutationPair::singletons(RANGE_CHECK, RANGE_CHECK_PERMUTED),
PermutationPair::singletons(COUNTER, COUNTER_PERMUTED),
]