diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 7387653d..244af688 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::marker::PhantomData; use itertools::{izip, multiunzip, Itertools}; @@ -52,41 +52,48 @@ pub fn generate_random_memory_ops( let is_read = if i == 0 { false } else { rng.gen() }; let is_read_field = F::from_bool(is_read); - let (context, segment, virt, vals) = if is_read { - let written: Vec<_> = current_memory_values.keys().collect(); - let &(context, segment, virt) = written[rng.gen_range(0..written.len())]; - let &vals = current_memory_values - .get(&(context, segment, virt)) - .unwrap(); - - (context, segment, virt, vals) - } else { - // TODO: with taller memory table or more padding (to enable range-checking bigger diffs), - // test larger address values. - let context = F::from_canonical_usize(rng.gen_range(0..40)); - let segment = F::from_canonical_usize(rng.gen_range(0..8)); - let virt = F::from_canonical_usize(rng.gen_range(0..20)); - - let val: [u32; 8] = rng.gen(); - let vals: [F; 8] = val.map(F::from_canonical_u32); - - current_memory_values.insert((context, segment, virt), vals); - - (context, segment, virt, vals) - }; - let timestamp = F::from_canonical_usize(i); - let channel_index = rng.gen_range(0..4); + let mut used_indices = HashSet::new(); + for _ in 0..2 { + let mut channel_index = rng.gen_range(0..4); + while used_indices.contains(&channel_index) { + channel_index = rng.gen_range(0..4); + } + used_indices.insert(channel_index); - memory_ops.push(MemoryOp { - channel_index, - timestamp, - is_read: is_read_field, - context, - segment, - virt, - value: vals, - }); + let (context, segment, virt, vals) = if is_read { + let written: Vec<_> = current_memory_values.keys().collect(); + let &(context, segment, virt) = written[rng.gen_range(0..written.len())]; + let &vals = current_memory_values + .get(&(context, segment, virt)) + .unwrap(); + + (context, segment, virt, vals) + } else { + // TODO: with taller memory table or more padding (to enable range-checking bigger diffs), + // test larger address values. + let context = F::from_canonical_usize(rng.gen_range(0..40)); + let segment = F::from_canonical_usize(rng.gen_range(0..8)); + let virt = F::from_canonical_usize(rng.gen_range(0..20)); + + let val: [u32; 8] = rng.gen(); + let vals: [F; 8] = val.map(F::from_canonical_u32); + + current_memory_values.insert((context, segment, virt), vals); + + (context, segment, virt, vals) + }; + + memory_ops.push(MemoryOp { + channel_index, + timestamp, + is_read: is_read_field, + context, + segment, + virt, + value: vals, + }); + } } memory_ops