diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index b2fe7dc9..566415ca 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::marker::PhantomData; -use itertools::{izip, multiunzip}; +use itertools::{izip, multiunzip, Itertools}; use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::packed_field::PackedField; use plonky2::field::polynomial::PolynomialValues; @@ -72,12 +72,7 @@ pub fn generate_random_memory_ops(num_ops: usize) -> Vec<(F, F, F, let virt = F::from_canonical_usize(rng.gen_range(0..20)); let val: [u32; 8] = rng.gen(); - let vals: [F; 8] = val - .iter() - .map(|&x| F::from_canonical_u32(x)) - .collect::>() - .try_into() - .unwrap(); + let vals: [F; 8] = val.map(F::from_canonical_u32); current_memory_values.insert((context, segment, virt), vals); @@ -97,11 +92,11 @@ pub fn sort_memory_ops( context: &[F], segment: &[F], virtuals: &[F], - values: &[Vec], + values: &Vec<[F; 8]>, is_read: &[F], timestamp: &[F], -) -> (Vec, Vec, Vec, Vec>, Vec, Vec) { - let mut ops: Vec<(F, F, F, Vec, F, F)> = izip!( +) -> (Vec, Vec, Vec, Vec<[F; 8]>, Vec, Vec) { + let mut ops: Vec<(F, F, F, [F; 8], F, F)> = izip!( context.iter().cloned(), segment.iter().cloned(), virtuals.iter().cloned(), @@ -111,19 +106,13 @@ pub fn sort_memory_ops( ) .collect(); - ops.sort_by(|&(c1, s1, v1, _, _, t1), &(c2, s2, v2, _, _, t2)| { + ops.sort_by_key(|&(c, s, v, _, _, t)| { ( - c1.to_noncanonical_u64(), - s1.to_noncanonical_u64(), - v1.to_noncanonical_u64(), - t1.to_noncanonical_u64(), + c.to_noncanonical_u64(), + s.to_noncanonical_u64(), + v.to_noncanonical_u64(), + t.to_noncanonical_u64(), ) - .cmp(&( - c2.to_noncanonical_u64(), - s2.to_noncanonical_u64(), - v2.to_noncanonical_u64(), - t2.to_noncanonical_u64(), - )) }); multiunzip(ops) @@ -236,12 +225,21 @@ impl, const D: usize> MemoryStark { } fn generate_memory(&self, trace_cols: &mut [Vec]) { + let num_trace_rows = trace_cols[0].len(); + let context = &trace_cols[MEMORY_ADDR_CONTEXT]; let segment = &trace_cols[MEMORY_ADDR_SEGMENT]; let virtuals = &trace_cols[MEMORY_ADDR_VIRTUAL]; - let values: Vec> = (0..8) - .map(|i| &trace_cols[memory_value_limb(i)]) - .cloned() + let values: Vec<[F; 8]> = (0..num_trace_rows) + .map(|i| { + let arr: [F; 8] = (0..8) + .map(|j| &trace_cols[memory_value_limb(j)][i]) + .cloned() + .collect_vec() + .try_into() + .unwrap(); + arr + }) .collect(); let is_read = &trace_cols[MEMORY_IS_READ]; let timestamp = &trace_cols[MEMORY_TIMESTAMP]; @@ -271,8 +269,10 @@ impl, const D: usize> MemoryStark { trace_cols[SORTED_MEMORY_ADDR_CONTEXT] = sorted_context; trace_cols[SORTED_MEMORY_ADDR_SEGMENT] = sorted_segment; trace_cols[SORTED_MEMORY_ADDR_VIRTUAL] = sorted_virtual; - for i in 0..8 { - trace_cols[sorted_memory_value_limb(i)] = sorted_values[i].clone(); + for i in 0..num_trace_rows { + for j in 0..8 { + trace_cols[sorted_memory_value_limb(j)][i] = sorted_values[i][j]; + } } trace_cols[SORTED_MEMORY_IS_READ] = sorted_is_read; trace_cols[SORTED_MEMORY_TIMESTAMP] = sorted_timestamp; @@ -367,19 +367,23 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark