diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 21601988..5f074f94 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,13 +1,21 @@ use itertools::izip; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; +use std::marker::PhantomData; + +pub struct MemoryOp { + is_write: bool, + address: F, + timestamp: F, + value: F, +} #[derive(Clone, Debug)] pub struct MemoryOpTarget { @@ -39,14 +47,12 @@ impl, const D: usize> CircuitBuilder { rhs: Target, bits: usize, num_chunks: usize, - ) -> (ComparisonGate, usize) { + ) { let gate = ComparisonGate::new(bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs); - - (gate, gate_index) } /// Sort memory operations by address value, then by timestamp value. @@ -94,29 +100,21 @@ impl, const D: usize> CircuitBuilder { .map(|op| self.mul_add(op.address, two_n, op.timestamp)) .collect(); - let mut gates = Vec::new(); - let mut gate_indices = Vec::new(); for i in 1..n { - let (gate, gate_index) = self.assert_le( + self.assert_le( address_timestamp_combined[i - 1], address_timestamp_combined[i], combined_bits, num_chunks, ); - - gate_indices.push(gate_index); - gates.push(gate); } - self.assert_permutation_memory_ops(ops, output_targets.as_slice()); + self.assert_permutation_memory_ops(ops, &output_targets); self.add_simple_generator(MemoryOpSortGenerator:: { input_ops: ops.to_vec(), - gate_indices, - gates: gates.clone(), output_ops: output_targets.clone(), - address_bits, - timestamp_bits, + _phantom: PhantomData, }); output_targets @@ -126,11 +124,8 @@ impl, const D: usize> CircuitBuilder { #[derive(Debug)] struct MemoryOpSortGenerator, const D: usize> { input_ops: Vec, - gate_indices: Vec, - gates: Vec>, output_ops: Vec, - address_bits: usize, - timestamp_bits: usize, + _phantom: PhantomData, } impl, const D: usize> SimpleGenerator @@ -148,61 +143,43 @@ impl, const D: usize> SimpleGenerator let n = self.input_ops.len(); debug_assert!(self.output_ops.len() == n); - let (timestamp_values, address_values): (Vec<_>, Vec<_>) = self + let mut ops: Vec<_> = self .input_ops .iter() .map(|op| { - ( - witness.get_target(op.timestamp), - witness.get_target(op.address), - ) - }) - .unzip(); - - let combined_values: Vec<_> = timestamp_values - .iter() - .zip(&address_values) - .map(|(&t, &a)| { - F::from_canonical_u64( - (a.to_canonical_u64() << self.timestamp_bits as u64) + t.to_canonical_u64(), - ) + let is_write = witness.get_bool_target(op.is_write); + let address = witness.get_target(op.address); + let timestamp = witness.get_target(op.timestamp); + let value = witness.get_target(op.value); + MemoryOp { + is_write, + address, + timestamp, + value, + } }) .collect(); - let mut input_ops_and_keys: Vec<_> = self - .input_ops - .iter() - .zip(combined_values) - .collect::>(); - input_ops_and_keys.sort_by_key(|(_, val)| val.to_canonical_u64()); + ops.sort_unstable_by_key(|op| { + ( + op.address.to_canonical_u64(), + op.timestamp.to_canonical_u64(), + ) + }); - for i in 0..n { - out_buffer.set_target( - self.output_ops[i].is_write.target, - witness.get_target(input_ops_and_keys[i].0.is_write.target), - ); - out_buffer.set_target( - self.output_ops[i].address, - witness.get_target(input_ops_and_keys[i].0.address), - ); - out_buffer.set_target( - self.output_ops[i].timestamp, - witness.get_target(input_ops_and_keys[i].0.timestamp), - ); - out_buffer.set_target( - self.output_ops[i].value, - witness.get_target(input_ops_and_keys[i].0.value), - ); + for (op, out_op) in ops.iter().zip(&self.output_ops) { + out_buffer.set_target(out_op.is_write.target, F::from_bool(op.is_write)); + out_buffer.set_target(out_op.address, op.address); + out_buffer.set_target(out_op.timestamp, op.timestamp); + out_buffer.set_target(out_op.value, op.value); } } } #[cfg(test)] mod tests { - use std::collections::HashSet; - use anyhow::Result; - use rand::{seq::SliceRandom, thread_rng, Rng}; + use rand::{thread_rng, Rng}; use super::*; use crate::field::crandall_field::CrandallField;