diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 21601988..42dc8541 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,7 +1,7 @@ use itertools::izip; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; @@ -9,6 +9,13 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; +pub struct MemoryOp { + is_write: bool, + address: F, + timestamp: F, + value: F, +} + #[derive(Clone, Debug)] pub struct MemoryOpTarget { is_write: BoolTarget, @@ -148,61 +155,43 @@ impl, const D: usize> SimpleGenerator let n = self.input_ops.len(); debug_assert!(self.output_ops.len() == n); - let (timestamp_values, address_values): (Vec<_>, Vec<_>) = self + let mut ops: Vec<_> = self .input_ops .iter() .map(|op| { - ( - witness.get_target(op.timestamp), - witness.get_target(op.address), - ) - }) - .unzip(); - - let combined_values: Vec<_> = timestamp_values - .iter() - .zip(&address_values) - .map(|(&t, &a)| { - F::from_canonical_u64( - (a.to_canonical_u64() << self.timestamp_bits as u64) + t.to_canonical_u64(), - ) + let is_write = witness.get_bool_target(op.is_write); + let address = witness.get_target(op.address); + let timestamp = witness.get_target(op.timestamp); + let value = witness.get_target(op.value); + MemoryOp { + is_write, + address, + timestamp, + value, + } }) .collect(); - let mut input_ops_and_keys: Vec<_> = self - .input_ops - .iter() - .zip(combined_values) - .collect::>(); - input_ops_and_keys.sort_by_key(|(_, val)| val.to_canonical_u64()); + ops.sort_unstable_by_key(|op| { + ( + op.address.to_canonical_u64(), + op.timestamp.to_canonical_u64(), + ) + }); - for i in 0..n { - out_buffer.set_target( - self.output_ops[i].is_write.target, - witness.get_target(input_ops_and_keys[i].0.is_write.target), - ); - out_buffer.set_target( - self.output_ops[i].address, - witness.get_target(input_ops_and_keys[i].0.address), - ); - out_buffer.set_target( - self.output_ops[i].timestamp, - witness.get_target(input_ops_and_keys[i].0.timestamp), - ); - out_buffer.set_target( - self.output_ops[i].value, - witness.get_target(input_ops_and_keys[i].0.value), - ); + for (op, out_op) in ops.iter().zip(&self.output_ops) { + out_buffer.set_target(out_op.is_write.target, F::from_bool(op.is_write)); + out_buffer.set_target(out_op.address, op.address); + out_buffer.set_target(out_op.timestamp, op.timestamp); + out_buffer.set_target(out_op.value, op.value); } } } #[cfg(test)] mod tests { - use std::collections::HashSet; - use anyhow::Result; - use rand::{seq::SliceRandom, thread_rng, Rng}; + use rand::{thread_rng, Rng}; use super::*; use crate::field::crandall_field::CrandallField;