addressed comments

This commit is contained in:
Nicholas Ward 2021-09-06 08:38:47 -07:00
parent 0e24719908
commit 1818e69ce3

View File

@ -1,11 +1,10 @@
use std::collections::BTreeMap;
use std::convert::TryInto;
use std::marker::PhantomData;
use crate::field::{extension_field::Extendable, field_types::Field};
use crate::gates::switch::SwitchGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::bimap::bimap_from_lists;
@ -62,12 +61,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
/// Given two input wire chunks, add a new switch to the circuit (by adding one copy to a switch
/// gate). Returns the wire for the switch boolean, and the two output wire chunks.
fn create_switch(
&mut self,
a1: Vec<Target>,
a2: Vec<Target>,
) -> (Target, Vec<Target>, Vec<Target>) {
assert!(a1.len() == a2.len(), "Chunk size must be the same");
assert_eq!(a1.len(), a2.len(), "Chunk size must be the same");
let chunk_size = a1.len();
@ -113,9 +114,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
next_copy += 1;
if next_copy == num_copies {
let new_gate = SwitchGate::<F, D>::new_from_config(self.config.clone(), chunk_size);
let new_gate_index = self.add_gate(new_gate.clone(), vec![]);
self.current_switch_gates[chunk_size - 1] = Some((new_gate, new_gate_index, 0));
self.current_switch_gates[chunk_size - 1] = None;
} else {
self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy));
}
@ -131,8 +130,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
);
assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same");
let chunk_size = a[0].len();
let n = a.len();
let even = n % 2 == 0;
@ -197,9 +194,15 @@ fn route<F: Field>(
assert_eq!(a_values.len(), b_values.len());
let n = a_values.len();
let even = n % 2 == 0;
// Bimap: maps indices of values in a to indices of the same values in b
// We use a bimap to match indices of values in a to indices of the same values in b.
// This means that given a wire on one side, we can easily find the matching wire on the other side.
let ab_map = bimap_from_lists(a_values, b_values);
let switches = [a_switches, b_switches];
// We keep track of the new wires we've routed (after routing some wires, we need to check `witness`
// and `newly_set` instead of just `witness`.
let mut newly_set = [vec![false; n], vec![false; n]];
// Given a side and an index, returns the index in the other side that corresponds to the same value.
@ -352,12 +355,7 @@ struct PermutationGenerator<F: Field> {
impl<F: Field> SimpleGenerator<F> for PermutationGenerator<F> {
fn dependencies(&self) -> Vec<Target> {
self.a
.clone()
.into_iter()
.flatten()
.chain(self.b.clone().into_iter().flatten())
.collect()
self.a.iter().chain(&self.b).flatten().cloned().collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
@ -405,7 +403,7 @@ mod tests {
let lst: Vec<F> = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect();
let a: Vec<Vec<Target>> = lst[..]
.windows(2)
.chunks(2)
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
.collect();
let mut b = a.clone();
@ -428,29 +426,29 @@ mod tests {
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let lst1: Vec<F> = (0..size * 2).map(|_| F::rand()).collect();
let lst2: Vec<F> = (0..size * 2).map(|_| F::rand()).collect();
let lst1: Vec<F> = F::rand_vec(size * 2);
let lst2: Vec<F> = F::rand_vec(size * 2);
let a: Vec<Vec<Target>> = lst1[..]
.windows(2)
.chunks(2)
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
.collect();
let b: Vec<Vec<Target>> = lst2[..]
.windows(2)
.chunks(2)
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
.collect();
builder.assert_permutation(a, b);
let data = builder.build();
let proof = data.prove(pw).unwrap();
data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
Ok(())
}
#[test]
fn test_permutations_good() -> Result<()> {
for n in 2..9 {
test_permutation_good(n).unwrap()
test_permutation_good(n)?;
}
Ok(())
@ -458,9 +456,25 @@ mod tests {
#[test]
#[should_panic]
fn test_permutations_bad() -> () {
for n in 2..9 {
test_permutation_bad(n).unwrap()
}
fn test_permutation_bad_small() {
let size = 2;
test_permutation_bad(size).unwrap()
}
#[test]
#[should_panic]
fn test_permutation_bad_medium() {
let size = 6;
test_permutation_bad(size).unwrap()
}
#[test]
#[should_panic]
fn test_permutation_bad_large() {
let size = 10;
test_permutation_bad(size).unwrap()
}
}