Merge pull request #617 from mir-protocol/segment_enum

Organize segments in an enum
Daniel Lubarov 2022-07-17 07:46:24 -07:00 committed by GitHub
commit 934bf757dd
7 changed files with 165 additions and 113 deletions

View File

@@ -146,7 +146,8 @@ mod tests {
use crate::cross_table_lookup::testutils::check_ctls;
use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS};
use crate::logic::{self, LogicStark, Operation};
use crate::memory::memory_stark::{generate_random_memory_ops, MemoryStark};
use crate::memory::memory_stark::tests::generate_random_memory_ops;
use crate::memory::memory_stark::MemoryStark;
use crate::memory::NUM_CHANNELS;
use crate::proof::AllProof;
use crate::prover::prove;
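Note on the import change above: `generate_random_memory_ops` now lives in a test-gated module of `memory_stark`, so other modules' test code can still reach it. A minimal sketch of that visibility pattern (hypothetical `helper` name, not code from this PR):

```rust
// A helper shared across modules' tests: the module is compiled only under
// `cfg(test)`, but `pub(crate)` items inside it are visible crate-wide there.
#[cfg(test)]
pub(crate) mod tests {
    pub(crate) fn helper() -> u32 {
        42
    }
}

// Another test-gated module (a sibling here, for brevity) can import it.
#[cfg(test)]
mod consumer_tests {
    use super::tests::helper;

    #[test]
    fn uses_shared_helper() {
        assert_eq!(helper(), 42);
    }
}
```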

View File

@@ -18,7 +18,8 @@ use crate::cpu::kernel::keccak_util::keccakf_u32s;
use crate::cpu::public_inputs::NUM_PUBLIC_INPUTS;
use crate::generation::state::GenerationState;
use crate::memory;
use crate::memory::{segments, NUM_CHANNELS};
use crate::memory::segments::Segment;
use crate::memory::NUM_CHANNELS;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
/// The Keccak rate (1088 bits), measured in bytes.
@@ -53,7 +54,7 @@ pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>
value[0] = F::from_canonical_u8(byte);
let channel = addr % NUM_CHANNELS;
state.set_mem_current(channel, segments::CODE, addr, value);
state.set_mem_current(channel, Segment::Code, addr, value);
packed_bytes = (packed_bytes << 8) | byte as u32;
}
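The packing step above accumulates bytes most-significant-first into a `u32`. A tiny standalone illustration of that shift-and-or line (example values only):

```rust
fn main() {
    // Mirrors `packed_bytes = (packed_bytes << 8) | byte as u32;` from above.
    let bytes = [0x60u8, 0x01, 0x60, 0x02];
    let mut packed: u32 = 0;
    for byte in bytes {
        packed = (packed << 8) | byte as u32;
    }
    assert_eq!(packed, 0x6001_6002);
}
```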

View File

@@ -8,12 +8,15 @@ use once_cell::sync::Lazy;
use super::assembler::{assemble, Kernel};
use crate::cpu::kernel::parser::parse;
use crate::memory::segments::Segment;
pub static KERNEL: Lazy<Kernel> = Lazy::new(combined_kernel);
pub fn evm_constants() -> HashMap<String, U256> {
let mut c = HashMap::new();
c.insert("SEGMENT_ID_TXN_DATA".into(), 0.into()); // TODO: Replace with actual segment ID.
for segment in Segment::all() {
c.insert(segment.var_name().into(), (segment as u32).into());
}
c
}
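With the loop above, every kernel-assembly segment constant is derived from the enum rather than hard-coded, and the old `SEGMENT_ID_TXN_DATA` placeholder disappears. A hedged sketch of the property this buys (hypothetical test, not from this PR; assumes `evm_constants`, `Segment`, and an `ethereum_types`-style `U256` are in scope):

```rust
#[test]
fn segment_constants_match_discriminants() {
    let c = evm_constants();
    // Each assembly-visible constant name maps to the enum's discriminant.
    for segment in Segment::all() {
        assert_eq!(c[segment.var_name()], U256::from(segment as u32));
    }
    // In particular, transaction data now gets its real ID instead of 0.
    assert_eq!(c["SEGMENT_TXN_DATA"], U256::from(7u32));
}
```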

View File

@@ -1,7 +1,7 @@
use plonky2::field::types::Field;
use crate::memory::memory_stark::MemoryOp;
use crate::memory::segments::NUM_SEGMENTS;
use crate::memory::segments::Segment;
use crate::memory::VALUE_LIMBS;
#[allow(unused)] // TODO: Should be used soon.
@@ -26,7 +26,7 @@ impl<F: Field> Default for MemoryState<F> {
#[derive(Default, Debug)]
pub(crate) struct MemoryContextState<F: Field> {
/// The content of each memory segment.
pub segments: [MemorySegmentState<F>; NUM_SEGMENTS],
pub segments: [MemorySegmentState<F>; Segment::COUNT],
}
#[derive(Default, Debug)]

View File

@@ -6,6 +6,7 @@ use plonky2::field::types::Field;
use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS};
use crate::generation::memory::MemoryState;
use crate::memory::memory_stark::MemoryOp;
use crate::memory::segments::Segment;
use crate::{keccak, logic};
#[derive(Debug)]
@@ -52,12 +53,12 @@ impl<F: Field> GenerationState<F> {
pub(crate) fn get_mem_current(
&mut self,
channel_index: usize,
segment: usize,
segment: Segment,
virt: usize,
) -> [F; crate::memory::VALUE_LIMBS] {
let timestamp = self.cpu_rows.len();
let context = self.current_context;
let value = self.memory.contexts[context].segments[segment].get(virt);
let value = self.memory.contexts[context].segments[segment as usize].get(virt);
self.memory.log.push(MemoryOp {
channel_index: Some(channel_index),
timestamp,
@@ -74,7 +75,7 @@ impl<F: Field> GenerationState<F> {
pub(crate) fn set_mem_current(
&mut self,
channel_index: usize,
segment: usize,
segment: Segment,
virt: usize,
value: [F; crate::memory::VALUE_LIMBS],
) {
@@ -89,7 +90,7 @@ impl<F: Field> GenerationState<F> {
virt,
value,
});
self.memory.contexts[context].segments[segment].set(virt, value)
self.memory.contexts[context].segments[segment as usize].set(virt, value)
}
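`get_mem_current` and `set_mem_current` now take a typed `Segment`, so a raw index only appears at the `as usize` cast inside them. A small sketch of what that buys (illustrative function, assuming the `Segment` enum added in this PR is in scope):

```rust
// Callers must name a segment; they can no longer pass an arbitrary integer.
fn address_of(segment: Segment, virt: usize) -> (usize, usize) {
    // The cast to a raw index happens in exactly one place, at the use site.
    (segment as usize, virt)
}

fn main() {
    assert_eq!(address_of(Segment::MainMemory, 5), (2, 5));
    // address_of(3, 5); // no longer compiles: expected `Segment`, found integer
}
```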
pub(crate) fn commit_cpu_row(&mut self) {

View File

@@ -1,4 +1,3 @@
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use itertools::Itertools;
@@ -10,7 +9,6 @@ use plonky2::hash::hash_types::RichField;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2::util::transpose;
use rand::Rng;
use rayon::prelude::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
@@ -21,6 +19,7 @@ use crate::memory::columns::{
COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED,
SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
};
use crate::memory::segments::Segment;
use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::permutation::PermutationPair;
use crate::stark::Stark;
@@ -46,13 +45,13 @@ pub struct MemoryStark<F, const D: usize> {
}
#[derive(Clone, Debug)]
pub struct MemoryOp<F> {
pub(crate) struct MemoryOp<F> {
/// The channel this operation came from, or `None` if it's a dummy operation for padding.
pub channel_index: Option<usize>,
pub timestamp: usize,
pub is_read: bool,
pub context: usize,
pub segment: usize,
pub segment: Segment,
pub virt: usize,
pub value: [F; 8],
}
@@ -70,7 +69,7 @@ impl<F: Field> MemoryOp<F> {
row[TIMESTAMP] = F::from_canonical_usize(self.timestamp);
row[IS_READ] = F::from_bool(self.is_read);
row[ADDR_CONTEXT] = F::from_canonical_usize(self.context);
row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment);
row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment as usize);
row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt);
for j in 0..VALUE_LIMBS {
row[value_limb(j)] = self.value[j];
@@ -79,79 +78,6 @@ impl<F: Field> MemoryOp<F> {
}
}
pub fn generate_random_memory_ops<F: RichField, R: Rng>(
num_ops: usize,
rng: &mut R,
) -> Vec<MemoryOp<F>> {
let mut memory_ops = Vec::new();
let mut current_memory_values: HashMap<(usize, usize, usize), [F; 8]> = HashMap::new();
let num_cycles = num_ops / 2;
for clock in 0..num_cycles {
let mut used_indices = HashSet::new();
let mut new_writes_this_cycle = HashMap::new();
let mut has_read = false;
for _ in 0..2 {
let mut channel_index = rng.gen_range(0..NUM_CHANNELS);
while used_indices.contains(&channel_index) {
channel_index = rng.gen_range(0..NUM_CHANNELS);
}
used_indices.insert(channel_index);
let is_read = if clock == 0 {
false
} else {
!has_read && rng.gen()
};
has_read = has_read || is_read;
let (context, segment, virt, vals) = if is_read {
let written: Vec<_> = current_memory_values.keys().collect();
let &(context, segment, virt) = written[rng.gen_range(0..written.len())];
let &vals = current_memory_values
.get(&(context, segment, virt))
.unwrap();
(context, segment, virt, vals)
} else {
// TODO: with taller memory table or more padding (to enable range-checking bigger diffs),
// test larger address values.
let mut context = rng.gen_range(0..40);
let mut segment = rng.gen_range(0..8);
let mut virt = rng.gen_range(0..20);
while new_writes_this_cycle.contains_key(&(context, segment, virt)) {
context = rng.gen_range(0..40);
segment = rng.gen_range(0..8);
virt = rng.gen_range(0..20);
}
let val: [u32; 8] = rng.gen();
let vals: [F; 8] = val.map(F::from_canonical_u32);
new_writes_this_cycle.insert((context, segment, virt), vals);
(context, segment, virt, vals)
};
let timestamp = clock * NUM_CHANNELS + channel_index;
memory_ops.push(MemoryOp {
channel_index: Some(channel_index),
timestamp,
is_read,
context,
segment,
virt,
value: vals,
});
}
for (k, v) in new_writes_this_cycle {
current_memory_values.insert(k, v);
}
}
memory_ops
}
fn get_max_range_check<F: Field>(memory_ops: &[MemoryOp<F>]) -> usize {
memory_ops
.iter()
@@ -160,7 +86,7 @@ fn get_max_range_check<F: Field>(memory_ops: &[MemoryOp<F>]) -> usize {
if curr.context != next.context {
next.context - curr.context - 1
} else if curr.segment != next.segment {
next.segment - curr.segment - 1
next.segment as usize - curr.segment as usize - 1
} else if curr.virt != next.virt {
next.virt - curr.virt - 1
} else {
@@ -264,7 +190,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
}
}
pub fn generate_trace(&self, memory_ops: Vec<MemoryOp<F>>) -> Vec<PolynomialValues<F>> {
pub(crate) fn generate_trace(&self, memory_ops: Vec<MemoryOp<F>>) -> Vec<PolynomialValues<F>> {
let mut timing = TimingTree::new("generate trace", log::Level::Debug);
// Generate most of the trace in row-major form.
@@ -533,13 +459,94 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
pub(crate) mod tests {
use std::collections::{HashMap, HashSet};
use crate::memory::memory_stark::MemoryStark;
use anyhow::Result;
use plonky2::hash::hash_types::RichField;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use rand::prelude::SliceRandom;
use rand::Rng;
use crate::memory::memory_stark::{MemoryOp, MemoryStark};
use crate::memory::segments::Segment;
use crate::memory::NUM_CHANNELS;
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
pub(crate) fn generate_random_memory_ops<F: RichField, R: Rng>(
num_ops: usize,
rng: &mut R,
) -> Vec<MemoryOp<F>> {
let mut memory_ops = Vec::new();
let mut current_memory_values: HashMap<(usize, Segment, usize), [F; 8]> = HashMap::new();
let num_cycles = num_ops / 2;
for clock in 0..num_cycles {
let mut used_indices = HashSet::new();
let mut new_writes_this_cycle = HashMap::new();
let mut has_read = false;
for _ in 0..2 {
let mut channel_index = rng.gen_range(0..NUM_CHANNELS);
while used_indices.contains(&channel_index) {
channel_index = rng.gen_range(0..NUM_CHANNELS);
}
used_indices.insert(channel_index);
let is_read = if clock == 0 {
false
} else {
!has_read && rng.gen()
};
has_read = has_read || is_read;
let (context, segment, virt, vals) = if is_read {
let written: Vec<_> = current_memory_values.keys().collect();
let &(context, segment, virt) = written[rng.gen_range(0..written.len())];
let &vals = current_memory_values
.get(&(context, segment, virt))
.unwrap();
(context, segment, virt, vals)
} else {
// TODO: with taller memory table or more padding (to enable range-checking bigger diffs),
// test larger address values.
let mut context = rng.gen_range(0..40);
let segments = [Segment::Code, Segment::Stack, Segment::MainMemory];
let mut segment = *segments.choose(rng).unwrap();
let mut virt = rng.gen_range(0..20);
while new_writes_this_cycle.contains_key(&(context, segment, virt)) {
context = rng.gen_range(0..40);
segment = *segments.choose(rng).unwrap();
virt = rng.gen_range(0..20);
}
let val: [u32; 8] = rng.gen();
let vals: [F; 8] = val.map(F::from_canonical_u32);
new_writes_this_cycle.insert((context, segment, virt), vals);
(context, segment, virt, vals)
};
let timestamp = clock * NUM_CHANNELS + channel_index;
memory_ops.push(MemoryOp {
channel_index: Some(channel_index),
timestamp,
is_read,
context,
segment,
virt,
value: vals,
});
}
for (k, v) in new_writes_this_cycle {
current_memory_values.insert(k, v);
}
}
memory_ops
}
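The generator above now samples segments with `SliceRandom::choose` from an explicit list instead of casting a random integer, so only valid `Segment` values can appear. A standalone sketch of that `rand` pattern (example values, not code from this PR):

```rust
use rand::prelude::SliceRandom;

fn main() {
    let mut rng = rand::thread_rng();
    let options = ["code", "stack", "main_memory"];
    // `choose` returns `Option<&T>`; it is `None` only for an empty slice,
    // so the `unwrap` (and the copy via `*`) is safe here.
    let picked = *options.choose(&mut rng).unwrap();
    assert!(options.contains(&picked));
}
```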
#[test]
fn test_stark_degree() -> Result<()> {
const D: usize = 2;

View File

@@ -1,22 +1,61 @@
/// Contains EVM bytecode.
pub const CODE: usize = 0;
#[allow(dead_code)] // TODO: Not all segments are used yet.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, Debug)]
pub(crate) enum Segment {
/// Contains EVM bytecode.
Code = 0,
/// The program stack.
Stack = 1,
/// Main memory, owned by the contract code.
MainMemory = 2,
/// Data passed to the current context by its caller.
Calldata = 3,
/// Data returned to the current context by its latest callee.
Returndata = 4,
/// A segment which contains a few fixed-size metadata fields, such as the caller's context, or the
/// size of `CALLDATA` and `RETURNDATA`.
Metadata = 5,
/// General purpose kernel memory, used by various kernel functions.
/// In general, calling a helper function can result in this memory being clobbered.
KernelGeneral = 6,
/// Contains transaction data (after it's parsed and converted to a standard format).
TxnData = 7,
/// Raw RLP data.
RlpRaw = 8,
/// RLP data that has been parsed and converted to a more "friendly" format.
RlpParsed = 9,
}
pub const STACK: usize = 1;
impl Segment {
pub(crate) const COUNT: usize = 10;
/// Main memory, owned by the contract code.
pub const MAIN_MEM: usize = 2;
pub(crate) fn all() -> [Self; Self::COUNT] {
[
Self::Code,
Self::Stack,
Self::MainMemory,
Self::Calldata,
Self::Returndata,
Self::Metadata,
Self::KernelGeneral,
Self::TxnData,
Self::RlpRaw,
Self::RlpParsed,
]
}
/// Memory owned by the kernel.
pub const KERNEL_MEM: usize = 3;
/// Data passed to the current context by its caller.
pub const CALLDATA: usize = 4;
/// Data returned to the current context by its latest callee.
pub const RETURNDATA: usize = 5;
/// A segment which contains a few fixed-size metadata fields, such as the caller's context, or the
/// size of `CALLDATA` and `RETURNDATA`.
pub const METADATA: usize = 6;
pub const NUM_SEGMENTS: usize = 7;
/// The variable name that gets passed into kernel assembly code.
pub(crate) fn var_name(&self) -> &'static str {
match self {
Segment::Code => "SEGMENT_CODE",
Segment::Stack => "SEGMENT_STACK",
Segment::MainMemory => "SEGMENT_MAIN_MEMORY",
Segment::Calldata => "SEGMENT_CALLDATA",
Segment::Returndata => "SEGMENT_RETURNDATA",
Segment::Metadata => "SEGMENT_METADATA",
Segment::KernelGeneral => "SEGMENT_KERNEL_GENERAL",
Segment::TxnData => "SEGMENT_TXN_DATA",
Segment::RlpRaw => "SEGMENT_RLP_RAW",
Segment::RlpParsed => "SEGMENT_RLP_PARSED",
}
}
}
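Because memory tables are sized by `Segment::COUNT` and indexed with `segment as usize`, `all()` and the discriminants have to stay dense and in order. A hypothetical sanity check one could add alongside the enum (not part of this PR):

```rust
#[cfg(test)]
mod segment_tests {
    use super::Segment;

    #[test]
    fn discriminants_are_dense_and_ordered() {
        let all = Segment::all();
        assert_eq!(all.len(), Segment::COUNT);
        for (i, segment) in all.into_iter().enumerate() {
            // `all()` lists segments in discriminant order, with no gaps.
            assert_eq!(segment as usize, i);
        }
    }
}
```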