From a00f2536ffb0b8186095329ff5d5264f4dd3a117 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 16 Sep 2021 20:44:09 -0700 Subject: [PATCH 01/12] initial memory sorting gadget --- src/gadgets/mod.rs | 1 + src/gadgets/sorting.rs | 222 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) create mode 100644 src/gadgets/sorting.rs diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 4b3371ef..aa18fbeb 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -8,5 +8,6 @@ pub mod polynomial; pub mod random_access; pub mod range_check; pub mod select; +pub mod sorting; pub mod split_base; pub(crate) mod split_join; diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs new file mode 100644 index 00000000..40ec510c --- /dev/null +++ b/src/gadgets/sorting.rs @@ -0,0 +1,222 @@ +use itertools::izip; +use std::marker::PhantomData; + +use crate::field::field_types::RichField; +use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gates::comparison::ComparisonGate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; + +pub struct MemoryOpTarget { + is_write: BoolTarget, + address: Target, + timestamp: Target, + value: Target, +} + +impl, const D: usize> CircuitBuilder { + pub fn assert_permutation_memory_ops(&mut self, a: &[MemoryOpTarget], b: &[MemoryOpTarget]) { + let a_chunks: Vec> = a.iter().map(|op| { + vec![op.address, op.timestamp, op.is_write.target, op.value] + }).collect(); + let b_chunks: Vec> = b.iter().map(|op| { + vec![op.address, op.timestamp, op.is_write.target, op.value] + }).collect(); + + self.assert_permutation(a_chunks, b_chunks); + } + + pub fn sort_memory_ops(&mut self, ops: &[MemoryOpTarget], address_bits: usize, timestamp_bits: usize) -> Vec { + let n = ops.len(); + + let address_chunk_size = (address_bits as f64).sqrt() as usize; + let timestamp_chunk_size = (timestamp_bits as f64).sqrt() as usize; + + let is_write_targets: Vec<_> = self.add_virtual_targets(n).iter().map(|&t| BoolTarget::new_unsafe(t)).collect(); + let address_targets = self.add_virtual_targets(n); + let timestamp_targets = self.add_virtual_targets(n); + let value_targets = self.add_virtual_targets(n); + + let output_targets: Vec<_> = izip!(is_write_targets, address_targets, timestamp_targets, value_targets).map(|(i, a, t, v)| { + MemoryOpTarget { + is_write: i, + address: a, + timestamp: t, + value: v, + } + }).collect(); + + for i in 1..n { + let (address_gate, address_gate_index) = { + let gate = ComparisonGate::new(address_bits, address_chunk_size); + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate, gate_index) + }; + + self.connect( + Target::wire(address_gate_index, address_gate.wire_first_input()), + output_targets[i-1].address, + ); + self.connect( + Target::wire(address_gate_index, address_gate.wire_second_input()), + output_targets[i].address, + ); + + let (timestamp_gate, timestamp_gate_index) = { + let gate = ComparisonGate::new(timestamp_bits, timestamp_chunk_size); + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate, gate_index) + }; + + self.connect( + Target::wire(timestamp_gate_index, timestamp_gate.wire_first_input()), + output_targets[i-1].timestamp, + ); + self.connect( + Target::wire(timestamp_gate_index, timestamp_gate.wire_second_input()), + output_targets[i].timestamp, + ); + } + + self.assert_permutation_memory_ops(ops, output_targets.as_slice()); + + output_targets + } +} + +/*#[derive(Debug)] +struct MemoryOpSortGenerator { + a: Vec>, + b: Vec>, + a_switches: Vec, + b_switches: Vec, + _phantom: PhantomData, +} + +impl SimpleGenerator for MemoryOpSortGenerator { + fn dependencies(&self) -> Vec { + self.a.iter().chain(&self.b).flatten().cloned().collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a_values = self + .a + .iter() + .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) + .collect(); + let b_values = self + .b + .iter() + .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) + .collect(); + route( + a_values, + b_values, + self.a_switches.clone(), + self.b_switches.clone(), + witness, + out_buffer, + ); + } +}*/ + +#[cfg(test)] +mod tests { + use anyhow::Result; + use rand::{seq::SliceRandom, thread_rng}; + + use super::*; + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + fn test_permutation_good(size: usize) -> Result<()> { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_zk_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); + let a: Vec> = lst[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + let mut b = a.clone(); + b.shuffle(&mut thread_rng()); + + builder.assert_permutation(a, b); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + fn test_permutation_bad(size: usize) -> Result<()> { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_zk_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let lst1: Vec = F::rand_vec(size * 2); + let lst2: Vec = F::rand_vec(size * 2); + let a: Vec> = lst1[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + let b: Vec> = lst2[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + + builder.assert_permutation(a, b); + + let data = builder.build(); + data.prove(pw).unwrap(); + + Ok(()) + } + + #[test] + fn test_permutations_good() -> Result<()> { + for n in 2..9 { + test_permutation_good(n)?; + } + + Ok(()) + } + + #[test] + #[should_panic] + fn test_permutation_bad_small() { + let size = 2; + + test_permutation_bad(size).unwrap() + } + + #[test] + #[should_panic] + fn test_permutation_bad_medium() { + let size = 6; + + test_permutation_bad(size).unwrap() + } + + #[test] + #[should_panic] + fn test_permutation_bad_large() { + let size = 10; + + test_permutation_bad(size).unwrap() + } +} From 14fd1dfa6b9f9b1c5bf951a0ee4bfc93160a876c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 16 Sep 2021 21:06:54 -0700 Subject: [PATCH 02/12] fmt --- src/gadgets/sorting.rs | 55 +++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 40ec510c..77d4d173 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,6 +1,7 @@ -use itertools::izip; use std::marker::PhantomData; +use itertools::izip; + use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gates::comparison::ComparisonGate; @@ -18,35 +19,51 @@ pub struct MemoryOpTarget { impl, const D: usize> CircuitBuilder { pub fn assert_permutation_memory_ops(&mut self, a: &[MemoryOpTarget], b: &[MemoryOpTarget]) { - let a_chunks: Vec> = a.iter().map(|op| { - vec![op.address, op.timestamp, op.is_write.target, op.value] - }).collect(); - let b_chunks: Vec> = b.iter().map(|op| { - vec![op.address, op.timestamp, op.is_write.target, op.value] - }).collect(); + let a_chunks: Vec> = a + .iter() + .map(|op| vec![op.address, op.timestamp, op.is_write.target, op.value]) + .collect(); + let b_chunks: Vec> = b + .iter() + .map(|op| vec![op.address, op.timestamp, op.is_write.target, op.value]) + .collect(); self.assert_permutation(a_chunks, b_chunks); } - pub fn sort_memory_ops(&mut self, ops: &[MemoryOpTarget], address_bits: usize, timestamp_bits: usize) -> Vec { + pub fn sort_memory_ops( + &mut self, + ops: &[MemoryOpTarget], + address_bits: usize, + timestamp_bits: usize, + ) -> Vec { let n = ops.len(); let address_chunk_size = (address_bits as f64).sqrt() as usize; let timestamp_chunk_size = (timestamp_bits as f64).sqrt() as usize; - let is_write_targets: Vec<_> = self.add_virtual_targets(n).iter().map(|&t| BoolTarget::new_unsafe(t)).collect(); + let is_write_targets: Vec<_> = self + .add_virtual_targets(n) + .iter() + .map(|&t| BoolTarget::new_unsafe(t)) + .collect(); let address_targets = self.add_virtual_targets(n); let timestamp_targets = self.add_virtual_targets(n); let value_targets = self.add_virtual_targets(n); - let output_targets: Vec<_> = izip!(is_write_targets, address_targets, timestamp_targets, value_targets).map(|(i, a, t, v)| { - MemoryOpTarget { - is_write: i, - address: a, - timestamp: t, - value: v, - } - }).collect(); + let output_targets: Vec<_> = izip!( + is_write_targets, + address_targets, + timestamp_targets, + value_targets + ) + .map(|(i, a, t, v)| MemoryOpTarget { + is_write: i, + address: a, + timestamp: t, + value: v, + }) + .collect(); for i in 1..n { let (address_gate, address_gate_index) = { @@ -57,7 +74,7 @@ impl, const D: usize> CircuitBuilder { self.connect( Target::wire(address_gate_index, address_gate.wire_first_input()), - output_targets[i-1].address, + output_targets[i - 1].address, ); self.connect( Target::wire(address_gate_index, address_gate.wire_second_input()), @@ -72,7 +89,7 @@ impl, const D: usize> CircuitBuilder { self.connect( Target::wire(timestamp_gate_index, timestamp_gate.wire_first_input()), - output_targets[i-1].timestamp, + output_targets[i - 1].timestamp, ); self.connect( Target::wire(timestamp_gate_index, timestamp_gate.wire_second_input()), From 2c1c116ead0e4a0b9d3fd19124b9a4abce543d7a Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 17 Sep 2021 13:09:24 -0700 Subject: [PATCH 03/12] fixes (addressed comments) --- src/gadgets/sorting.rs | 37 ++++++++++++++----------------------- src/gates/comparison.rs | 39 +++++++++++++++++++++++++++++---------- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 77d4d173..1547295b 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -39,8 +39,8 @@ impl, const D: usize> CircuitBuilder { ) -> Vec { let n = ops.len(); - let address_chunk_size = (address_bits as f64).sqrt() as usize; - let timestamp_chunk_size = (timestamp_bits as f64).sqrt() as usize; + let combined_bits = address_bits + timestamp_bits; + let chunk_size = 3; let is_write_targets: Vec<_> = self .add_virtual_targets(n) @@ -65,35 +65,26 @@ impl, const D: usize> CircuitBuilder { }) .collect(); + let two_n = self.constant(F::from_canonical_usize(1 << timestamp_bits)); + let address_timestamp_combined: Vec<_> = output_targets + .iter() + .map(|op| self.mul_add(op.timestamp, two_n, op.address)) + .collect(); + for i in 1..n { - let (address_gate, address_gate_index) = { - let gate = ComparisonGate::new(address_bits, address_chunk_size); + let (gate, gate_index) = { + let gate = ComparisonGate::new(combined_bits, chunk_size); let gate_index = self.add_gate(gate.clone(), vec![]); (gate, gate_index) }; self.connect( - Target::wire(address_gate_index, address_gate.wire_first_input()), - output_targets[i - 1].address, + Target::wire(gate_index, gate.wire_first_input()), + address_timestamp_combined[i - 1], ); self.connect( - Target::wire(address_gate_index, address_gate.wire_second_input()), - output_targets[i].address, - ); - - let (timestamp_gate, timestamp_gate_index) = { - let gate = ComparisonGate::new(timestamp_bits, timestamp_chunk_size); - let gate_index = self.add_gate(gate.clone(), vec![]); - (gate, gate_index) - }; - - self.connect( - Target::wire(timestamp_gate_index, timestamp_gate.wire_first_input()), - output_targets[i - 1].timestamp, - ); - self.connect( - Target::wire(timestamp_gate_index, timestamp_gate.wire_second_input()), - output_targets[i].timestamp, + Target::wire(gate_index, gate.wire_second_input()), + address_timestamp_combined[i], ); } diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index d928bd6f..86601fba 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -13,7 +13,7 @@ use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recu use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::util::ceil_div_usize; -/// A gate for checking that one value is less than another. +/// A gate for checking that one value is less than or equal to another. #[derive(Clone, Debug)] pub(crate) struct ComparisonGate, const D: usize> { pub(crate) num_bits: usize, @@ -137,7 +137,7 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(most_significant_diff - most_significant_diff_so_far); // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (1..chunk_size) + let product = (0..chunk_size) .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x)) .product(); constraints.push(product); @@ -205,7 +205,7 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(most_significant_diff - most_significant_diff_so_far); // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (1..chunk_size) + let product = (0..chunk_size) .map(|x| most_significant_diff - F::from_canonical_usize(x)) .product(); constraints.push(product); @@ -286,7 +286,7 @@ impl, const D: usize> Gate for ComparisonGate // Range check `most_significant_diff` to be less than `chunk_size`. let mut product = builder.one_extension(); - for x in 1..chunk_size { + for x in 0..chunk_size { let x_F = builder.constant_extension(F::Extension::from_canonical_usize(x)); let diff = builder.sub_extension(most_significant_diff, x_F); product = builder.mul_extension(product, diff); @@ -553,7 +553,7 @@ mod tests { let first_input_u64 = rng.gen_range(0..max); let second_input_u64 = { let mut val = rng.gen_range(0..max); - while val <= first_input_u64 { + while val < first_input_u64 { val = rng.gen_range(0..max); } val @@ -562,20 +562,39 @@ mod tests { let first_input = F::from_canonical_u64(first_input_u64); let second_input = F::from_canonical_u64(second_input_u64); - let gate = ComparisonGate:: { + let less_than_gate = ComparisonGate:: { num_bits, num_chunks, _phantom: PhantomData, }; - - let vars = EvaluationVars { + let less_than_vars = EvaluationVars { local_constants: &[], local_wires: &get_wires(first_input, second_input), public_inputs_hash: &HashOut::rand(), }; - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + less_than_gate + .eval_unfiltered(less_than_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + + let equal_gate = ComparisonGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let equal_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, first_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + equal_gate + .eval_unfiltered(equal_vars) + .iter() + .all(|x| x.is_zero()), "Gate constraints are not satisfied." ); } From 8dd00b8d41ed3b60d43b1fa17902d6f499571e71 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 17 Sep 2021 13:40:07 -0700 Subject: [PATCH 04/12] added generator --- src/gadgets/sorting.rs | 87 +++++++++++++++++++++++++++++------------- 1 file changed, 61 insertions(+), 26 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 1547295b..102f30cd 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; -use itertools::izip; +use itertools::{izip, Itertools}; -use crate::field::field_types::RichField; +use crate::field::field_types::{PrimeField, RichField}; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; @@ -10,6 +10,7 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +#[derive(Debug)] pub struct MemoryOpTarget { is_write: BoolTarget, address: Target, @@ -94,41 +95,75 @@ impl, const D: usize> CircuitBuilder { } } -/*#[derive(Debug)] -struct MemoryOpSortGenerator { - a: Vec>, - b: Vec>, - a_switches: Vec, - b_switches: Vec, +#[derive(Debug)] +struct MemoryOpSortGenerator { + input_ops: Vec, + output_ops: Vec, + address_bits: usize, + timestamp_bits: usize, _phantom: PhantomData, } -impl SimpleGenerator for MemoryOpSortGenerator { +impl SimpleGenerator for MemoryOpSortGenerator { fn dependencies(&self) -> Vec { - self.a.iter().chain(&self.b).flatten().cloned().collect() + self.input_ops + .iter() + .map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) + .flatten() + .collect() } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a_values = self - .a + let n = self.input_ops.len(); + debug_assert!(self.output_ops.len() == n); + + let (timestamp_values, address_values): (Vec<_>, Vec<_>) = self + .input_ops .iter() - .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) - .collect(); - let b_values = self - .b + .map(|op| { + ( + witness.get_target(op.timestamp), + witness.get_target(op.address), + ) + }) + .unzip(); + + let combined_values_u64: Vec<_> = timestamp_values .iter() - .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) + .zip(address_values.iter()) + .map(|(&t, &a)| { + a.to_canonical_u64() * (1 << self.timestamp_bits as u64) + t.to_canonical_u64() + }) .collect(); - route( - a_values, - b_values, - self.a_switches.clone(), - self.b_switches.clone(), - witness, - out_buffer, - ); + + let mut input_ops_and_keys: Vec<_> = self + .input_ops + .iter() + .zip(combined_values_u64) + .collect::>(); + input_ops_and_keys.sort_by(|(_, a_val), (_, b_val)| a_val.cmp(b_val)); + let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(op, _)| op).collect(); + + for i in 0..n { + out_buffer.set_target( + self.output_ops[i].is_write.target, + witness.get_target(input_ops_sorted[i].is_write.target), + ); + out_buffer.set_target( + self.output_ops[i].address, + witness.get_target(input_ops_sorted[i].address), + ); + out_buffer.set_target( + self.output_ops[i].timestamp, + witness.get_target(input_ops_sorted[i].timestamp), + ); + out_buffer.set_target( + self.output_ops[i].value, + witness.get_target(input_ops_sorted[i].value), + ); + } } -}*/ +} #[cfg(test)] mod tests { From 3d93766cc804899dc1a311794ce7f2eba800b5ac Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 17 Sep 2021 14:50:37 -0700 Subject: [PATCH 05/12] test (wip) --- src/gadgets/sorting.rs | 93 ++++++++++++------------------------------ 1 file changed, 27 insertions(+), 66 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 102f30cd..1d1938b4 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -168,7 +168,7 @@ impl SimpleGenerator for MemoryOpSortGenerator { #[cfg(test)] mod tests { use anyhow::Result; - use rand::{seq::SliceRandom, thread_rng}; + use rand::{seq::SliceRandom, thread_rng, Rng}; use super::*; use crate::field::crandall_field::CrandallField; @@ -177,7 +177,7 @@ mod tests { use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; - fn test_permutation_good(size: usize) -> Result<()> { + fn test_sorting(size: usize, address_bits: usize, timestamp_bits: usize) -> Result<()> { type F = CrandallField; const D: usize = 4; @@ -186,15 +186,28 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); - let a: Vec> = lst[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + let mut rng = thread_rng(); + let is_write_vals: Vec<_> = (0..size).map(|_| rng.gen_range(0..2) != 0).collect(); + let address_vals: Vec<_> = (0..size) + .map(|_| F::from_canonical_u64(rng.gen_range(0..1 << address_bits as u64))) .collect(); - let mut b = a.clone(); - b.shuffle(&mut thread_rng()); + let timestamp_vals: Vec<_> = (0..size) + .map(|_| F::from_canonical_u64(rng.gen_range(0..1 << timestamp_bits as u64))) + .collect(); + let value_vals: Vec<_> = (0..size).map(|_| F::rand()).collect(); - builder.assert_permutation(a, b); + let input_ops: Vec = + izip!(is_write_vals, address_vals, timestamp_vals, value_vals) + .map(|(is_write, address, timestamp, value)| MemoryOpTarget { + is_write: builder.constant_bool(is_write), + address: builder.constant(address), + timestamp: builder.constant(timestamp), + value: builder.constant(value), + }) + .collect(); + + let _output_ops = + builder.sort_memory_ops(input_ops.as_slice(), address_bits, timestamp_bits); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -202,64 +215,12 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - fn test_permutation_bad(size: usize) -> Result<()> { - type F = CrandallField; - const D: usize = 4; - - let config = CircuitConfig::large_zk_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let lst1: Vec = F::rand_vec(size * 2); - let lst2: Vec = F::rand_vec(size * 2); - let a: Vec> = lst1[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - let b: Vec> = lst2[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - - builder.assert_permutation(a, b); - - let data = builder.build(); - data.prove(pw).unwrap(); - - Ok(()) - } - #[test] - fn test_permutations_good() -> Result<()> { - for n in 2..9 { - test_permutation_good(n)?; - } + fn test_sorting_small() -> Result<()> { + let size = 5; + let address_bits = 20; + let timestamp_bits = 20; - Ok(()) - } - - #[test] - #[should_panic] - fn test_permutation_bad_small() { - let size = 2; - - test_permutation_bad(size).unwrap() - } - - #[test] - #[should_panic] - fn test_permutation_bad_medium() { - let size = 6; - - test_permutation_bad(size).unwrap() - } - - #[test] - #[should_panic] - fn test_permutation_bad_large() { - let size = 10; - - test_permutation_bad(size).unwrap() + test_sorting(size, address_bits, timestamp_bits) } } From 644d87e49549fa324fdea647c38d23b0198679c3 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 21 Sep 2021 18:01:21 -0700 Subject: [PATCH 06/12] fixes galore --- src/gadgets/permutation.rs | 40 +++++++++++++++++++- src/gadgets/sorting.rs | 75 ++++++++++++++++++++++++++++++++------ src/gates/comparison.rs | 2 +- src/gates/switch.rs | 51 +++++++++++++++++++------- 4 files changed, 140 insertions(+), 28 deletions(-) diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 126846ec..0f320dfd 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -384,7 +384,7 @@ impl SimpleGenerator for PermutationGenerator { #[cfg(test)] mod tests { use anyhow::Result; - use rand::{seq::SliceRandom, thread_rng}; + use rand::{seq::SliceRandom, thread_rng, Rng}; use super::*; use crate::field::crandall_field::CrandallField; @@ -418,6 +418,35 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + fn test_permutation_duplicates(size: usize) -> Result<()> { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_zk_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = thread_rng(); + let lst: Vec = (0..size * 2) + .map(|_| F::from_canonical_usize(rng.gen_range(0..2usize))) + .collect(); + let a: Vec> = lst[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + + let mut b = a.clone(); + b.shuffle(&mut thread_rng()); + + builder.assert_permutation(a, b); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + fn test_permutation_bad(size: usize) -> Result<()> { type F = CrandallField; const D: usize = 4; @@ -446,6 +475,15 @@ mod tests { Ok(()) } + #[test] + fn test_permutations_duplicates() -> Result<()> { + for n in 2..9 { + test_permutation_duplicates(n)?; + } + + Ok(()) + } + #[test] fn test_permutations_good() -> Result<()> { for n in 2..9 { diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 1d1938b4..099209c9 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,16 +1,15 @@ -use std::marker::PhantomData; +use itertools::izip; -use itertools::{izip, Itertools}; - -use crate::field::field_types::{PrimeField, RichField}; -use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::field::field_types::RichField; +use crate::field::extension_field::Extendable; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct MemoryOpTarget { is_write: BoolTarget, address: Target, @@ -41,7 +40,8 @@ impl, const D: usize> CircuitBuilder { let n = ops.len(); let combined_bits = address_bits + timestamp_bits; - let chunk_size = 3; + let chunk_bits = 3; + let num_chunks = ceil_div_usize(combined_bits, chunk_bits); let is_write_targets: Vec<_> = self .add_virtual_targets(n) @@ -69,12 +69,14 @@ impl, const D: usize> CircuitBuilder { let two_n = self.constant(F::from_canonical_usize(1 << timestamp_bits)); let address_timestamp_combined: Vec<_> = output_targets .iter() - .map(|op| self.mul_add(op.timestamp, two_n, op.address)) + .map(|op| self.mul_add(op.address, two_n, op.timestamp)) .collect(); + let mut gate_indices = Vec::new(); + let mut gates = Vec::new(); for i in 1..n { let (gate, gate_index) = { - let gate = ComparisonGate::new(combined_bits, chunk_size); + let gate = ComparisonGate::new(combined_bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); (gate, gate_index) }; @@ -87,24 +89,39 @@ impl, const D: usize> CircuitBuilder { Target::wire(gate_index, gate.wire_second_input()), address_timestamp_combined[i], ); + + gate_indices.push(gate_index); + gates.push(gate); } self.assert_permutation_memory_ops(ops, output_targets.as_slice()); + self.add_simple_generator(MemoryOpSortGenerator:: { + input_ops: ops.to_vec(), + gate_indices, + gates: gates.clone(), + output_ops: output_targets.clone(), + address_bits, + timestamp_bits, + }); + output_targets } } #[derive(Debug)] -struct MemoryOpSortGenerator { +struct MemoryOpSortGenerator, const D: usize> { input_ops: Vec, + gate_indices: Vec, + gates: Vec>, output_ops: Vec, address_bits: usize, timestamp_bits: usize, - _phantom: PhantomData, } -impl SimpleGenerator for MemoryOpSortGenerator { +impl, const D: usize> SimpleGenerator + for MemoryOpSortGenerator +{ fn dependencies(&self) -> Vec { self.input_ops .iter() @@ -136,6 +153,13 @@ impl SimpleGenerator for MemoryOpSortGenerator { }) .collect(); + let mut combined_values_sorted = combined_values_u64.clone(); + combined_values_sorted.sort(); + let combined_values: Vec<_> = combined_values_sorted + .iter() + .map(|&x| F::from_canonical_u64(x)) + .collect(); + let mut input_ops_and_keys: Vec<_> = self .input_ops .iter() @@ -161,12 +185,30 @@ impl SimpleGenerator for MemoryOpSortGenerator { self.output_ops[i].value, witness.get_target(input_ops_sorted[i].value), ); + + if i > 0 { + out_buffer.set_target( + Target::wire( + self.gate_indices[i - 1], + self.gates[i - 1].wire_second_input(), + ), + combined_values[i], + ); + } + if i < n - 1 { + out_buffer.set_target( + Target::wire(self.gate_indices[i], self.gates[i].wire_first_input()), + combined_values[i], + ); + } } } } #[cfg(test)] mod tests { + use std::collections::HashSet; + use anyhow::Result; use rand::{seq::SliceRandom, thread_rng, Rng}; @@ -223,4 +265,13 @@ mod tests { test_sorting(size, address_bits, timestamp_bits) } + + #[test] + fn test_sorting_large() -> Result<()> { + let size = 20; + let address_bits = 20; + let timestamp_bits = 20; + + test_sorting(size, address_bits, timestamp_bits) + } } diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 86601fba..b1d57929 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -386,7 +386,7 @@ impl, const D: usize> SimpleGenerator let mut most_significant_diff_so_far = F::ZERO; let mut intermediate_values = Vec::new(); - for i in 1..self.gate.num_chunks { + for i in 0..self.gate.num_chunks { if first_input_chunks[i] != second_input_chunks[i] { most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; intermediate_values.push(F::ZERO); diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 1df5fea9..26efac0e 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -236,20 +236,43 @@ impl, const D: usize> SwitchGenerator { let get_local_wire = |input| witness.get_wire(local_wire(input)); - for e in 0..self.gate.chunk_size { - let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy)); - let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e)); - let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e)); - let first_output = get_local_wire(self.gate.wire_first_output(self.copy, e)); - let second_output = get_local_wire(self.gate.wire_second_output(self.copy, e)); + let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy)); - if first_output == first_input && second_output == second_input { - out_buffer.set_wire(switch_bool_wire, F::ZERO); - } else if first_output == second_input && second_output == first_input { - out_buffer.set_wire(switch_bool_wire, F::ONE); - } else { - panic!("No permutation from given inputs to given outputs"); - } + let mut first_inputs = Vec::new(); + let mut second_inputs = Vec::new(); + let mut first_outputs = Vec::new(); + let mut second_outputs = Vec::new(); + for e in 0..self.gate.chunk_size { + first_inputs.push(get_local_wire(self.gate.wire_first_input(self.copy, e))); + second_inputs.push(get_local_wire(self.gate.wire_second_input(self.copy, e))); + first_outputs.push(get_local_wire(self.gate.wire_first_output(self.copy, e))); + second_outputs.push(get_local_wire(self.gate.wire_second_output(self.copy, e))); + } + + let first_keep = first_outputs + .iter() + .zip(first_inputs.iter()) + .all(|(x, y)| x == y); + let second_keep = second_outputs + .iter() + .zip(second_inputs.iter()) + .all(|(x, y)| x == y); + + let first_swap = first_outputs + .iter() + .zip(second_inputs.iter()) + .all(|(x, y)| x == y); + let second_swap = second_outputs + .iter() + .zip(first_inputs.iter()) + .all(|(x, y)| x == y); + + if first_keep && second_keep { + out_buffer.set_wire(switch_bool_wire, F::ZERO); + } else if first_swap && second_swap { + out_buffer.set_wire(switch_bool_wire, F::ONE); + } else { + panic!("No permutation from given inputs to given outputs"); } } @@ -261,12 +284,12 @@ impl, const D: usize> SwitchGenerator { let get_local_wire = |input| witness.get_wire(local_wire(input)); + let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy)); for e in 0..self.gate.chunk_size { let first_output_wire = local_wire(self.gate.wire_first_output(self.copy, e)); let second_output_wire = local_wire(self.gate.wire_second_output(self.copy, e)); let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e)); let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e)); - let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy)); let (first_output, second_output) = if switch_bool == F::ZERO { (first_input, second_input) From 6c4173d2eced0b2d937968761c9e6a9e03022b8a Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 21 Sep 2021 18:02:56 -0700 Subject: [PATCH 07/12] fmt --- src/gadgets/sorting.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 099209c9..f3b9cf0d 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,7 +1,7 @@ use itertools::izip; -use crate::field::field_types::RichField; use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; From 2ec3b2974157041d72682f22e1814bf78662d325 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 22 Sep 2021 11:49:28 -0700 Subject: [PATCH 08/12] addressed comments --- src/gadgets/sorting.rs | 38 ++++++++++++++++++++++++++------------ src/gates/comparison.rs | 3 +-- src/gates/switch.rs | 22 ++-------------------- 3 files changed, 29 insertions(+), 34 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index f3b9cf0d..57f65a51 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -31,6 +31,26 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(a_chunks, b_chunks); } + /// Add a ComparisonGate to + /// Returns the gate and its index + pub fn assert_le( + &mut self, + lhs: Target, + rhs: Target, + bits: usize, + num_chunks: usize, + ) -> (ComparisonGate, usize) { + let gate = ComparisonGate::new(bits, num_chunks); + let gate_index = self.add_gate(gate.clone(), vec![]); + + self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); + self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs); + + (gate, gate_index) + } + + /// Sort memory operations by address value, then by timestamp value. + /// This is done by combining address and timestamp into one field element (using their given bit lengths). pub fn sort_memory_ops( &mut self, ops: &[MemoryOpTarget], @@ -43,11 +63,13 @@ impl, const D: usize> CircuitBuilder { let chunk_bits = 3; let num_chunks = ceil_div_usize(combined_bits, chunk_bits); + // This is safe because `assert_permutation` will force these targets (in the output list) to match the boolean values from the input list. let is_write_targets: Vec<_> = self .add_virtual_targets(n) .iter() .map(|&t| BoolTarget::new_unsafe(t)) .collect(); + let address_targets = self.add_virtual_targets(n); let timestamp_targets = self.add_virtual_targets(n); let value_targets = self.add_virtual_targets(n); @@ -72,22 +94,14 @@ impl, const D: usize> CircuitBuilder { .map(|op| self.mul_add(op.address, two_n, op.timestamp)) .collect(); - let mut gate_indices = Vec::new(); let mut gates = Vec::new(); + let mut gate_indices = Vec::new(); for i in 1..n { - let (gate, gate_index) = { - let gate = ComparisonGate::new(combined_bits, num_chunks); - let gate_index = self.add_gate(gate.clone(), vec![]); - (gate, gate_index) - }; - - self.connect( - Target::wire(gate_index, gate.wire_first_input()), + let (gate, gate_index) = self.assert_le( address_timestamp_combined[i - 1], - ); - self.connect( - Target::wire(gate_index, gate.wire_second_input()), address_timestamp_combined[i], + combined_bits, + num_chunks, ); gate_indices.push(gate_index); diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index b1d57929..44f2923a 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -15,7 +15,7 @@ use crate::util::ceil_div_usize; /// A gate for checking that one value is less than or equal to another. #[derive(Clone, Debug)] -pub(crate) struct ComparisonGate, const D: usize> { +pub struct ComparisonGate, const D: usize> { pub(crate) num_bits: usize, pub(crate) num_chunks: usize, _phantom: PhantomData, @@ -436,7 +436,6 @@ mod tests { use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::hash::hash_types::HashOut; - use crate::plonk::plonk_common::reduce_with_powers; use crate::plonk::vars::EvaluationVars; #[test] diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 26efac0e..2f6b7122 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -249,27 +249,9 @@ impl, const D: usize> SwitchGenerator { second_outputs.push(get_local_wire(self.gate.wire_second_output(self.copy, e))); } - let first_keep = first_outputs - .iter() - .zip(first_inputs.iter()) - .all(|(x, y)| x == y); - let second_keep = second_outputs - .iter() - .zip(second_inputs.iter()) - .all(|(x, y)| x == y); - - let first_swap = first_outputs - .iter() - .zip(second_inputs.iter()) - .all(|(x, y)| x == y); - let second_swap = second_outputs - .iter() - .zip(first_inputs.iter()) - .all(|(x, y)| x == y); - - if first_keep && second_keep { + if first_outputs == first_inputs && second_outputs == second_inputs { out_buffer.set_wire(switch_bool_wire, F::ZERO); - } else if first_swap && second_swap { + } else if first_outputs == second_inputs && second_outputs == first_inputs { out_buffer.set_wire(switch_bool_wire, F::ONE); } else { panic!("No permutation from given inputs to given outputs"); From 8aa43763601a68bb84998c7b1d794292271e35b6 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 22 Sep 2021 14:03:27 -0700 Subject: [PATCH 09/12] addressed comments (set sorted values in partial witness; no more directly setting gate inputs) --- src/gadgets/sorting.rs | 91 +++++++++++++++++++++--------------------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 57f65a51..21601988 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -159,62 +159,40 @@ impl, const D: usize> SimpleGenerator }) .unzip(); - let combined_values_u64: Vec<_> = timestamp_values + let combined_values: Vec<_> = timestamp_values .iter() - .zip(address_values.iter()) + .zip(&address_values) .map(|(&t, &a)| { - a.to_canonical_u64() * (1 << self.timestamp_bits as u64) + t.to_canonical_u64() + F::from_canonical_u64( + (a.to_canonical_u64() << self.timestamp_bits as u64) + t.to_canonical_u64(), + ) }) .collect(); - let mut combined_values_sorted = combined_values_u64.clone(); - combined_values_sorted.sort(); - let combined_values: Vec<_> = combined_values_sorted - .iter() - .map(|&x| F::from_canonical_u64(x)) - .collect(); - let mut input_ops_and_keys: Vec<_> = self .input_ops .iter() - .zip(combined_values_u64) + .zip(combined_values) .collect::>(); - input_ops_and_keys.sort_by(|(_, a_val), (_, b_val)| a_val.cmp(b_val)); - let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(op, _)| op).collect(); + input_ops_and_keys.sort_by_key(|(_, val)| val.to_canonical_u64()); for i in 0..n { out_buffer.set_target( self.output_ops[i].is_write.target, - witness.get_target(input_ops_sorted[i].is_write.target), + witness.get_target(input_ops_and_keys[i].0.is_write.target), ); out_buffer.set_target( self.output_ops[i].address, - witness.get_target(input_ops_sorted[i].address), + witness.get_target(input_ops_and_keys[i].0.address), ); out_buffer.set_target( self.output_ops[i].timestamp, - witness.get_target(input_ops_sorted[i].timestamp), + witness.get_target(input_ops_and_keys[i].0.timestamp), ); out_buffer.set_target( self.output_ops[i].value, - witness.get_target(input_ops_sorted[i].value), + witness.get_target(input_ops_and_keys[i].0.value), ); - - if i > 0 { - out_buffer.set_target( - Target::wire( - self.gate_indices[i - 1], - self.gates[i - 1].wire_second_input(), - ), - combined_values[i], - ); - } - if i < n - 1 { - out_buffer.set_target( - Target::wire(self.gate_indices[i], self.gates[i].wire_first_input()), - combined_values[i], - ); - } } } } @@ -228,7 +206,7 @@ mod tests { use super::*; use crate::field::crandall_field::CrandallField; - use crate::field::field_types::Field; + use crate::field::field_types::{Field, PrimeField}; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; @@ -239,7 +217,7 @@ mod tests { let config = CircuitConfig::large_zk_config(); - let pw = PartialWitness::new(); + let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let mut rng = thread_rng(); @@ -252,19 +230,42 @@ mod tests { .collect(); let value_vals: Vec<_> = (0..size).map(|_| F::rand()).collect(); - let input_ops: Vec = - izip!(is_write_vals, address_vals, timestamp_vals, value_vals) - .map(|(is_write, address, timestamp, value)| MemoryOpTarget { - is_write: builder.constant_bool(is_write), - address: builder.constant(address), - timestamp: builder.constant(timestamp), - value: builder.constant(value), - }) - .collect(); + let input_ops: Vec = izip!( + is_write_vals.clone(), + address_vals.clone(), + timestamp_vals.clone(), + value_vals.clone() + ) + .map(|(is_write, address, timestamp, value)| MemoryOpTarget { + is_write: builder.constant_bool(is_write), + address: builder.constant(address), + timestamp: builder.constant(timestamp), + value: builder.constant(value), + }) + .collect(); - let _output_ops = + let combined_vals_u64: Vec<_> = timestamp_vals + .iter() + .zip(&address_vals) + .map(|(&t, &a)| (a.to_canonical_u64() << timestamp_bits as u64) + t.to_canonical_u64()) + .collect(); + let mut input_ops_and_keys: Vec<_> = + izip!(is_write_vals, address_vals, timestamp_vals, value_vals) + .zip(combined_vals_u64) + .collect::>(); + input_ops_and_keys.sort_by_key(|(_, val)| val.clone()); + let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect(); + + let output_ops = builder.sort_memory_ops(input_ops.as_slice(), address_bits, timestamp_bits); + for i in 0..size { + pw.set_bool_target(output_ops[i].is_write, input_ops_sorted[i].0); + pw.set_target(output_ops[i].address, input_ops_sorted[i].1); + pw.set_target(output_ops[i].timestamp, input_ops_sorted[i].2); + pw.set_target(output_ops[i].value, input_ops_sorted[i].3); + } + let data = builder.build(); let proof = data.prove(pw).unwrap(); From d541e251ee95975638fb07db2b5d34f5e446bb01 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 22 Sep 2021 18:10:38 -0700 Subject: [PATCH 10/12] Add a MemoryOp to simplify MemoryOpSortGenerator --- src/gadgets/sorting.rs | 73 ++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 42 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 21601988..42dc8541 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,7 +1,7 @@ use itertools::izip; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; @@ -9,6 +9,13 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; +pub struct MemoryOp { + is_write: bool, + address: F, + timestamp: F, + value: F, +} + #[derive(Clone, Debug)] pub struct MemoryOpTarget { is_write: BoolTarget, @@ -148,61 +155,43 @@ impl, const D: usize> SimpleGenerator let n = self.input_ops.len(); debug_assert!(self.output_ops.len() == n); - let (timestamp_values, address_values): (Vec<_>, Vec<_>) = self + let mut ops: Vec<_> = self .input_ops .iter() .map(|op| { - ( - witness.get_target(op.timestamp), - witness.get_target(op.address), - ) - }) - .unzip(); - - let combined_values: Vec<_> = timestamp_values - .iter() - .zip(&address_values) - .map(|(&t, &a)| { - F::from_canonical_u64( - (a.to_canonical_u64() << self.timestamp_bits as u64) + t.to_canonical_u64(), - ) + let is_write = witness.get_bool_target(op.is_write); + let address = witness.get_target(op.address); + let timestamp = witness.get_target(op.timestamp); + let value = witness.get_target(op.value); + MemoryOp { + is_write, + address, + timestamp, + value, + } }) .collect(); - let mut input_ops_and_keys: Vec<_> = self - .input_ops - .iter() - .zip(combined_values) - .collect::>(); - input_ops_and_keys.sort_by_key(|(_, val)| val.to_canonical_u64()); + ops.sort_unstable_by_key(|op| { + ( + op.address.to_canonical_u64(), + op.timestamp.to_canonical_u64(), + ) + }); - for i in 0..n { - out_buffer.set_target( - self.output_ops[i].is_write.target, - witness.get_target(input_ops_and_keys[i].0.is_write.target), - ); - out_buffer.set_target( - self.output_ops[i].address, - witness.get_target(input_ops_and_keys[i].0.address), - ); - out_buffer.set_target( - self.output_ops[i].timestamp, - witness.get_target(input_ops_and_keys[i].0.timestamp), - ); - out_buffer.set_target( - self.output_ops[i].value, - witness.get_target(input_ops_and_keys[i].0.value), - ); + for (op, out_op) in ops.iter().zip(&self.output_ops) { + out_buffer.set_target(out_op.is_write.target, F::from_bool(op.is_write)); + out_buffer.set_target(out_op.address, op.address); + out_buffer.set_target(out_op.timestamp, op.timestamp); + out_buffer.set_target(out_op.value, op.value); } } } #[cfg(test)] mod tests { - use std::collections::HashSet; - use anyhow::Result; - use rand::{seq::SliceRandom, thread_rng, Rng}; + use rand::{thread_rng, Rng}; use super::*; use crate::field::crandall_field::CrandallField; From 202967a40bee6bec7e3e2dbbffc4ae43d488a009 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 22 Sep 2021 18:14:58 -0700 Subject: [PATCH 11/12] Other tweaks --- src/gadgets/sorting.rs | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 42dc8541..5f074f94 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -8,6 +8,7 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; +use std::marker::PhantomData; pub struct MemoryOp { is_write: bool, @@ -46,14 +47,12 @@ impl, const D: usize> CircuitBuilder { rhs: Target, bits: usize, num_chunks: usize, - ) -> (ComparisonGate, usize) { + ) { let gate = ComparisonGate::new(bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs); - - (gate, gate_index) } /// Sort memory operations by address value, then by timestamp value. @@ -101,29 +100,21 @@ impl, const D: usize> CircuitBuilder { .map(|op| self.mul_add(op.address, two_n, op.timestamp)) .collect(); - let mut gates = Vec::new(); - let mut gate_indices = Vec::new(); for i in 1..n { - let (gate, gate_index) = self.assert_le( + self.assert_le( address_timestamp_combined[i - 1], address_timestamp_combined[i], combined_bits, num_chunks, ); - - gate_indices.push(gate_index); - gates.push(gate); } - self.assert_permutation_memory_ops(ops, output_targets.as_slice()); + self.assert_permutation_memory_ops(ops, &output_targets); self.add_simple_generator(MemoryOpSortGenerator:: { input_ops: ops.to_vec(), - gate_indices, - gates: gates.clone(), output_ops: output_targets.clone(), - address_bits, - timestamp_bits, + _phantom: PhantomData, }); output_targets @@ -133,11 +124,8 @@ impl, const D: usize> CircuitBuilder { #[derive(Debug)] struct MemoryOpSortGenerator, const D: usize> { input_ops: Vec, - gate_indices: Vec, - gates: Vec>, output_ops: Vec, - address_bits: usize, - timestamp_bits: usize, + _phantom: PhantomData, } impl, const D: usize> SimpleGenerator From 0c0a8fd862d4875679f9c1dd3232d50db36be66a Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 23 Sep 2021 09:16:38 -0700 Subject: [PATCH 12/12] tweaks --- src/gadgets/sorting.rs | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 5f074f94..f202620c 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,3 +1,5 @@ +use std::marker::PhantomData; + use itertools::izip; use crate::field::extension_field::Extendable; @@ -8,7 +10,6 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; -use std::marker::PhantomData; pub struct MemoryOp { is_write: bool, @@ -39,15 +40,8 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(a_chunks, b_chunks); } - /// Add a ComparisonGate to - /// Returns the gate and its index - pub fn assert_le( - &mut self, - lhs: Target, - rhs: Target, - bits: usize, - num_chunks: usize, - ) { + /// Add a ComparisonGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits. + pub fn assert_le(&mut self, lhs: Target, rhs: Target, bits: usize, num_chunks: usize) { let gate = ComparisonGate::new(bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); @@ -125,7 +119,7 @@ impl, const D: usize> CircuitBuilder { struct MemoryOpSortGenerator, const D: usize> { input_ops: Vec, output_ops: Vec, - _phantom: PhantomData, + _phantom: PhantomData, } impl, const D: usize> SimpleGenerator