From 1818e69ce3e20b1be88bbedffde251af40d6a3c9 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 6 Sep 2021 08:38:47 -0700 Subject: [PATCH] addressed comments --- src/gadgets/permutation.rs | 68 +++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 945f58fa..66d00bcd 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -1,11 +1,10 @@ use std::collections::BTreeMap; -use std::convert::TryInto; use std::marker::PhantomData; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gates::switch::SwitchGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::{BoolTarget, Target}; +use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bimap::bimap_from_lists; @@ -62,12 +61,14 @@ impl, const D: usize> CircuitBuilder { } } + /// Given two input wire chunks, add a new switch to the circuit (by adding one copy to a switch + /// gate). Returns the wire for the switch boolean, and the two output wire chunks. fn create_switch( &mut self, a1: Vec, a2: Vec, ) -> (Target, Vec, Vec) { - assert!(a1.len() == a2.len(), "Chunk size must be the same"); + assert_eq!(a1.len(), a2.len(), "Chunk size must be the same"); let chunk_size = a1.len(); @@ -113,9 +114,7 @@ impl, const D: usize> CircuitBuilder { next_copy += 1; if next_copy == num_copies { - let new_gate = SwitchGate::::new_from_config(self.config.clone(), chunk_size); - let new_gate_index = self.add_gate(new_gate.clone(), vec![]); - self.current_switch_gates[chunk_size - 1] = Some((new_gate, new_gate_index, 0)); + self.current_switch_gates[chunk_size - 1] = None; } else { self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy)); } @@ -131,8 +130,6 @@ impl, const D: usize> CircuitBuilder { ); assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same"); - let chunk_size = a[0].len(); - let n = a.len(); let even = n % 2 == 0; @@ -197,9 +194,15 @@ fn route( assert_eq!(a_values.len(), b_values.len()); let n = a_values.len(); let even = n % 2 == 0; - // Bimap: maps indices of values in a to indices of the same values in b + + // We use a bimap to match indices of values in a to indices of the same values in b. + // This means that given a wire on one side, we can easily find the matching wire on the other side. let ab_map = bimap_from_lists(a_values, b_values); + let switches = [a_switches, b_switches]; + + // We keep track of the new wires we've routed (after routing some wires, we need to check `witness` + // and `newly_set` instead of just `witness`. let mut newly_set = [vec![false; n], vec![false; n]]; // Given a side and an index, returns the index in the other side that corresponds to the same value. @@ -352,12 +355,7 @@ struct PermutationGenerator { impl SimpleGenerator for PermutationGenerator { fn dependencies(&self) -> Vec { - self.a - .clone() - .into_iter() - .flatten() - .chain(self.b.clone().into_iter().flatten()) - .collect() + self.a.iter().chain(&self.b).flatten().cloned().collect() } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { @@ -405,7 +403,7 @@ mod tests { let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); let a: Vec> = lst[..] - .windows(2) + .chunks(2) .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) .collect(); let mut b = a.clone(); @@ -428,29 +426,29 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let lst1: Vec = (0..size * 2).map(|_| F::rand()).collect(); - let lst2: Vec = (0..size * 2).map(|_| F::rand()).collect(); + let lst1: Vec = F::rand_vec(size * 2); + let lst2: Vec = F::rand_vec(size * 2); let a: Vec> = lst1[..] - .windows(2) + .chunks(2) .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) .collect(); let b: Vec> = lst2[..] - .windows(2) + .chunks(2) .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) .collect(); builder.assert_permutation(a, b); let data = builder.build(); - let proof = data.prove(pw).unwrap(); + data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + Ok(()) } #[test] fn test_permutations_good() -> Result<()> { for n in 2..9 { - test_permutation_good(n).unwrap() + test_permutation_good(n)?; } Ok(()) @@ -458,9 +456,25 @@ mod tests { #[test] #[should_panic] - fn test_permutations_bad() -> () { - for n in 2..9 { - test_permutation_bad(n).unwrap() - } + fn test_permutation_bad_small() { + let size = 2; + + test_permutation_bad(size).unwrap() + } + + #[test] + #[should_panic] + fn test_permutation_bad_medium() { + let size = 6; + + test_permutation_bad(size).unwrap() + } + + #[test] + #[should_panic] + fn test_permutation_bad_large() { + let size = 10; + + test_permutation_bad(size).unwrap() } }