From d541e251ee95975638fb07db2b5d34f5e446bb01 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 22 Sep 2021 18:10:38 -0700 Subject: [PATCH 1/2] Add a MemoryOp to simplify MemoryOpSortGenerator --- src/gadgets/sorting.rs | 73 ++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 42 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 21601988..42dc8541 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,7 +1,7 @@ use itertools::izip; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; @@ -9,6 +9,13 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; +pub struct MemoryOp { + is_write: bool, + address: F, + timestamp: F, + value: F, +} + #[derive(Clone, Debug)] pub struct MemoryOpTarget { is_write: BoolTarget, @@ -148,61 +155,43 @@ impl, const D: usize> SimpleGenerator let n = self.input_ops.len(); debug_assert!(self.output_ops.len() == n); - let (timestamp_values, address_values): (Vec<_>, Vec<_>) = self + let mut ops: Vec<_> = self .input_ops .iter() .map(|op| { - ( - witness.get_target(op.timestamp), - witness.get_target(op.address), - ) - }) - .unzip(); - - let combined_values: Vec<_> = timestamp_values - .iter() - .zip(&address_values) - .map(|(&t, &a)| { - F::from_canonical_u64( - (a.to_canonical_u64() << self.timestamp_bits as u64) + t.to_canonical_u64(), - ) + let is_write = witness.get_bool_target(op.is_write); + let address = witness.get_target(op.address); + let timestamp = witness.get_target(op.timestamp); + let value = witness.get_target(op.value); + MemoryOp { + is_write, + address, + timestamp, + value, + } }) .collect(); - let mut input_ops_and_keys: Vec<_> = self - .input_ops - .iter() - .zip(combined_values) - .collect::>(); - input_ops_and_keys.sort_by_key(|(_, val)| val.to_canonical_u64()); + ops.sort_unstable_by_key(|op| { + ( + op.address.to_canonical_u64(), + op.timestamp.to_canonical_u64(), + ) + }); - for i in 0..n { - out_buffer.set_target( - self.output_ops[i].is_write.target, - witness.get_target(input_ops_and_keys[i].0.is_write.target), - ); - out_buffer.set_target( - self.output_ops[i].address, - witness.get_target(input_ops_and_keys[i].0.address), - ); - out_buffer.set_target( - self.output_ops[i].timestamp, - witness.get_target(input_ops_and_keys[i].0.timestamp), - ); - out_buffer.set_target( - self.output_ops[i].value, - witness.get_target(input_ops_and_keys[i].0.value), - ); + for (op, out_op) in ops.iter().zip(&self.output_ops) { + out_buffer.set_target(out_op.is_write.target, F::from_bool(op.is_write)); + out_buffer.set_target(out_op.address, op.address); + out_buffer.set_target(out_op.timestamp, op.timestamp); + out_buffer.set_target(out_op.value, op.value); } } } #[cfg(test)] mod tests { - use std::collections::HashSet; - use anyhow::Result; - use rand::{seq::SliceRandom, thread_rng, Rng}; + use rand::{thread_rng, Rng}; use super::*; use crate::field::crandall_field::CrandallField; From 202967a40bee6bec7e3e2dbbffc4ae43d488a009 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 22 Sep 2021 18:14:58 -0700 Subject: [PATCH 2/2] Other tweaks --- src/gadgets/sorting.rs | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 42dc8541..5f074f94 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -8,6 +8,7 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; +use std::marker::PhantomData; pub struct MemoryOp { is_write: bool, @@ -46,14 +47,12 @@ impl, const D: usize> CircuitBuilder { rhs: Target, bits: usize, num_chunks: usize, - ) -> (ComparisonGate, usize) { + ) { let gate = ComparisonGate::new(bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs); - - (gate, gate_index) } /// Sort memory operations by address value, then by timestamp value. @@ -101,29 +100,21 @@ impl, const D: usize> CircuitBuilder { .map(|op| self.mul_add(op.address, two_n, op.timestamp)) .collect(); - let mut gates = Vec::new(); - let mut gate_indices = Vec::new(); for i in 1..n { - let (gate, gate_index) = self.assert_le( + self.assert_le( address_timestamp_combined[i - 1], address_timestamp_combined[i], combined_bits, num_chunks, ); - - gate_indices.push(gate_index); - gates.push(gate); } - self.assert_permutation_memory_ops(ops, output_targets.as_slice()); + self.assert_permutation_memory_ops(ops, &output_targets); self.add_simple_generator(MemoryOpSortGenerator:: { input_ops: ops.to_vec(), - gate_indices, - gates: gates.clone(), output_ops: output_targets.clone(), - address_bits, - timestamp_bits, + _phantom: PhantomData, }); output_targets @@ -133,11 +124,8 @@ impl, const D: usize> CircuitBuilder { #[derive(Debug)] struct MemoryOpSortGenerator, const D: usize> { input_ops: Vec, - gate_indices: Vec, - gates: Vec>, output_ops: Vec, - address_bits: usize, - timestamp_bits: usize, + _phantom: PhantomData, } impl, const D: usize> SimpleGenerator