Store memory values as U256s

Ultimately they're encoded as `[F; 8]`s in the table, but I don't anticipate that we'll have any use cases where we want to store more than 256 bits. Might as well store `U256` until we actually build the table since they're more compact.
This commit is contained in:
Daniel Lubarov 2022-07-16 09:14:51 -07:00
parent 934bf757dd
commit 997453237f
5 changed files with 36 additions and 44 deletions

View File

@ -17,7 +17,6 @@ use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::keccak_util::keccakf_u32s; use crate::cpu::kernel::keccak_util::keccakf_u32s;
use crate::cpu::public_inputs::NUM_PUBLIC_INPUTS; use crate::cpu::public_inputs::NUM_PUBLIC_INPUTS;
use crate::generation::state::GenerationState; use crate::generation::state::GenerationState;
use crate::memory;
use crate::memory::segments::Segment; use crate::memory::segments::Segment;
use crate::memory::NUM_CHANNELS; use crate::memory::NUM_CHANNELS;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
@ -50,11 +49,8 @@ pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>
// Write this chunk to memory, while simultaneously packing its bytes into a u32 word. // Write this chunk to memory, while simultaneously packing its bytes into a u32 word.
let mut packed_bytes: u32 = 0; let mut packed_bytes: u32 = 0;
for (addr, byte) in chunk { for (addr, byte) in chunk {
let mut value = [F::ZERO; memory::VALUE_LIMBS];
value[0] = F::from_canonical_u8(byte);
let channel = addr % NUM_CHANNELS; let channel = addr % NUM_CHANNELS;
state.set_mem_current(channel, Segment::Code, addr, value); state.set_mem_current(channel, Segment::Code, addr, byte.into());
packed_bytes = (packed_bytes << 8) | byte as u32; packed_bytes = (packed_bytes << 8) | byte as u32;
} }

View File

@ -1,19 +1,18 @@
use plonky2::field::types::Field; use ethereum_types::U256;
use crate::memory::memory_stark::MemoryOp; use crate::memory::memory_stark::MemoryOp;
use crate::memory::segments::Segment; use crate::memory::segments::Segment;
use crate::memory::VALUE_LIMBS;
#[allow(unused)] // TODO: Should be used soon. #[allow(unused)] // TODO: Should be used soon.
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct MemoryState<F: Field> { pub(crate) struct MemoryState {
/// A log of each memory operation, in the order that it occurred. /// A log of each memory operation, in the order that it occurred.
pub log: Vec<MemoryOp<F>>, pub log: Vec<MemoryOp>,
pub contexts: Vec<MemoryContextState<F>>, pub contexts: Vec<MemoryContextState>,
} }
impl<F: Field> Default for MemoryState<F> { impl Default for MemoryState {
fn default() -> Self { fn default() -> Self {
Self { Self {
log: vec![], log: vec![],
@ -24,28 +23,27 @@ impl<F: Field> Default for MemoryState<F> {
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub(crate) struct MemoryContextState<F: Field> { pub(crate) struct MemoryContextState {
/// The content of each memory segment. /// The content of each memory segment.
pub segments: [MemorySegmentState<F>; Segment::COUNT], pub segments: [MemorySegmentState; Segment::COUNT],
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub(crate) struct MemorySegmentState<F: Field> { pub(crate) struct MemorySegmentState {
pub content: Vec<[F; VALUE_LIMBS]>, pub content: Vec<U256>,
} }
impl<F: Field> MemorySegmentState<F> { impl MemorySegmentState {
pub(super) fn get(&self, virtual_addr: usize) -> [F; VALUE_LIMBS] { pub(super) fn get(&self, virtual_addr: usize) -> U256 {
self.content self.content
.get(virtual_addr) .get(virtual_addr)
.copied() .copied()
.unwrap_or([F::ZERO; VALUE_LIMBS]) .unwrap_or(U256::zero())
} }
pub(super) fn set(&mut self, virtual_addr: usize, value: [F; VALUE_LIMBS]) { pub(super) fn set(&mut self, virtual_addr: usize, value: U256) {
if virtual_addr + 1 > self.content.len() { if virtual_addr + 1 > self.content.len() {
self.content self.content.resize(virtual_addr + 1, U256::zero());
.resize(virtual_addr + 1, [F::ZERO; VALUE_LIMBS]);
} }
self.content[virtual_addr] = value; self.content[virtual_addr] = value;
} }

View File

@ -15,7 +15,7 @@ pub(crate) struct GenerationState<F: Field> {
pub(crate) current_cpu_row: CpuColumnsView<F>, pub(crate) current_cpu_row: CpuColumnsView<F>,
pub(crate) current_context: usize, pub(crate) current_context: usize,
pub(crate) memory: MemoryState<F>, pub(crate) memory: MemoryState,
pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>,
pub(crate) logic_ops: Vec<logic::Operation>, pub(crate) logic_ops: Vec<logic::Operation>,
@ -55,7 +55,7 @@ impl<F: Field> GenerationState<F> {
channel_index: usize, channel_index: usize,
segment: Segment, segment: Segment,
virt: usize, virt: usize,
) -> [F; crate::memory::VALUE_LIMBS] { ) -> U256 {
let timestamp = self.cpu_rows.len(); let timestamp = self.cpu_rows.len();
let context = self.current_context; let context = self.current_context;
let value = self.memory.contexts[context].segments[segment as usize].get(virt); let value = self.memory.contexts[context].segments[segment as usize].get(virt);
@ -77,7 +77,7 @@ impl<F: Field> GenerationState<F> {
channel_index: usize, channel_index: usize,
segment: Segment, segment: Segment,
virt: usize, virt: usize,
value: [F; crate::memory::VALUE_LIMBS], value: U256,
) { ) {
let timestamp = self.cpu_rows.len(); let timestamp = self.cpu_rows.len();
let context = self.current_context; let context = self.current_context;

View File

@ -9,7 +9,8 @@ pub(crate) const ADDR_CONTEXT: usize = IS_READ + 1;
pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1;
pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1; pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1;
// Eight limbs to hold up to a 256-bit value. // Eight 32-bit limbs hold a total of 256 bits.
// If a value represents an integer, it is little-endian encoded.
const VALUE_START: usize = ADDR_VIRTUAL + 1; const VALUE_START: usize = ADDR_VIRTUAL + 1;
pub(crate) const fn value_limb(i: usize) -> usize { pub(crate) const fn value_limb(i: usize) -> usize {
debug_assert!(i < VALUE_LIMBS); debug_assert!(i < VALUE_LIMBS);

View File

@ -1,5 +1,6 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use ethereum_types::U256;
use itertools::Itertools; use itertools::Itertools;
use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField; use plonky2::field::packed::PackedField;
@ -45,7 +46,7 @@ pub struct MemoryStark<F, const D: usize> {
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct MemoryOp<F> { pub(crate) struct MemoryOp {
/// The channel this operation came from, or `None` if it's a dummy operation for padding. /// The channel this operation came from, or `None` if it's a dummy operation for padding.
pub channel_index: Option<usize>, pub channel_index: Option<usize>,
pub timestamp: usize, pub timestamp: usize,
@ -53,15 +54,15 @@ pub(crate) struct MemoryOp<F> {
pub context: usize, pub context: usize,
pub segment: Segment, pub segment: Segment,
pub virt: usize, pub virt: usize,
pub value: [F; 8], pub value: U256,
} }
impl<F: Field> MemoryOp<F> { impl MemoryOp {
/// Generate a row for a given memory operation. Note that this does not generate columns which /// Generate a row for a given memory operation. Note that this does not generate columns which
/// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later. /// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later.
/// It also does not generate columns such as `COUNTER`, which are generated later, after the /// It also does not generate columns such as `COUNTER`, which are generated later, after the
/// trace has been transposed into column-major form. /// trace has been transposed into column-major form.
fn to_row(&self) -> [F; NUM_COLUMNS] { fn to_row<F: Field>(&self) -> [F; NUM_COLUMNS] {
let mut row = [F::ZERO; NUM_COLUMNS]; let mut row = [F::ZERO; NUM_COLUMNS];
if let Some(channel) = self.channel_index { if let Some(channel) = self.channel_index {
row[is_channel(channel)] = F::ONE; row[is_channel(channel)] = F::ONE;
@ -72,13 +73,13 @@ impl<F: Field> MemoryOp<F> {
row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment as usize); row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment as usize);
row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt); row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt);
for j in 0..VALUE_LIMBS { for j in 0..VALUE_LIMBS {
row[value_limb(j)] = self.value[j]; row[value_limb(j)] = F::from_canonical_u32((self.value >> (j * 32)).low_u32());
} }
row row
} }
} }
fn get_max_range_check<F: Field>(memory_ops: &[MemoryOp<F>]) -> usize { fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize {
memory_ops memory_ops
.iter() .iter()
.tuple_windows() .tuple_windows()
@ -142,7 +143,7 @@ pub fn generate_first_change_flags_and_rc<F: RichField>(trace_rows: &mut [[F; NU
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> { impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
/// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated /// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated
/// later, after transposing to column-major form. /// later, after transposing to column-major form.
fn generate_trace_row_major(&self, mut memory_ops: Vec<MemoryOp<F>>) -> Vec<[F; NUM_COLUMNS]> { fn generate_trace_row_major(&self, mut memory_ops: Vec<MemoryOp>) -> Vec<[F; NUM_COLUMNS]> {
memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp)); memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp));
Self::pad_memory_ops(&mut memory_ops); Self::pad_memory_ops(&mut memory_ops);
@ -167,7 +168,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_col_vecs[COUNTER_PERMUTED] = permuted_table; trace_col_vecs[COUNTER_PERMUTED] = permuted_table;
} }
fn pad_memory_ops(memory_ops: &mut Vec<MemoryOp<F>>) { fn pad_memory_ops(memory_ops: &mut Vec<MemoryOp>) {
let num_ops = memory_ops.len(); let num_ops = memory_ops.len();
let max_range_check = get_max_range_check(memory_ops); let max_range_check = get_max_range_check(memory_ops);
let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two(); let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two();
@ -190,7 +191,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
} }
} }
pub(crate) fn generate_trace(&self, memory_ops: Vec<MemoryOp<F>>) -> Vec<PolynomialValues<F>> { pub(crate) fn generate_trace(&self, memory_ops: Vec<MemoryOp>) -> Vec<PolynomialValues<F>> {
let mut timing = TimingTree::new("generate trace", log::Level::Debug); let mut timing = TimingTree::new("generate trace", log::Level::Debug);
// Generate most of the trace in row-major form. // Generate most of the trace in row-major form.
@ -463,7 +464,7 @@ pub(crate) mod tests {
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use anyhow::Result; use anyhow::Result;
use plonky2::hash::hash_types::RichField; use ethereum_types::U256;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use rand::prelude::SliceRandom; use rand::prelude::SliceRandom;
use rand::Rng; use rand::Rng;
@ -473,13 +474,10 @@ pub(crate) mod tests {
use crate::memory::NUM_CHANNELS; use crate::memory::NUM_CHANNELS;
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
pub(crate) fn generate_random_memory_ops<F: RichField, R: Rng>( pub(crate) fn generate_random_memory_ops<R: Rng>(num_ops: usize, rng: &mut R) -> Vec<MemoryOp> {
num_ops: usize,
rng: &mut R,
) -> Vec<MemoryOp<F>> {
let mut memory_ops = Vec::new(); let mut memory_ops = Vec::new();
let mut current_memory_values: HashMap<(usize, Segment, usize), [F; 8]> = HashMap::new(); let mut current_memory_values: HashMap<(usize, Segment, usize), U256> = HashMap::new();
let num_cycles = num_ops / 2; let num_cycles = num_ops / 2;
for clock in 0..num_cycles { for clock in 0..num_cycles {
let mut used_indices = HashSet::new(); let mut used_indices = HashSet::new();
@ -520,12 +518,11 @@ pub(crate) mod tests {
virt = rng.gen_range(0..20); virt = rng.gen_range(0..20);
} }
let val: [u32; 8] = rng.gen(); let val = U256(rng.gen());
let vals: [F; 8] = val.map(F::from_canonical_u32);
new_writes_this_cycle.insert((context, segment, virt), vals); new_writes_this_cycle.insert((context, segment, virt), val);
(context, segment, virt, vals) (context, segment, virt, val)
}; };
let timestamp = clock * NUM_CHANNELS + channel_index; let timestamp = clock * NUM_CHANNELS + channel_index;