From 013c8bb612354d5e9f1872d123902826c6644d09 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 26 Aug 2021 14:23:30 -0700 Subject: [PATCH] progress --- Cargo.toml | 1 + src/gadgets/permutation.rs | 192 ++++++++++++++++++++++++++++++++++--- src/gates/switch.rs | 6 +- src/util/bimap.rs | 76 +++++++++++++++ src/util/mod.rs | 1 + 5 files changed, 260 insertions(+), 16 deletions(-) create mode 100644 src/util/bimap.rs diff --git a/Cargo.toml b/Cargo.toml index d8b84356..790ff6b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ edition = "2018" default-run = "bench_recursion" [dependencies] +bimap = "0.4.0" env_logger = "0.9.0" log = "0.4.14" itertools = "0.10.0" diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 4b301c04..64e18f93 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::convert::TryInto; use crate::field::{extension_field::Extendable, field_types::Field}; @@ -6,6 +7,7 @@ use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::bimap::bimap_from_lists; impl, const D: usize> CircuitBuilder { /// Assert that two lists of expressions evaluate to permutations of one another. @@ -162,18 +164,150 @@ impl, const D: usize> CircuitBuilder { } } -fn route_one_layer( - a_values: Vec, - b_values: Vec, - a_wires: Vec<[Target; CHUNK_SIZE]>, - b_wires: Vec<[Target; CHUNK_SIZE]>, +fn route( + a_values: Vec<[F; CHUNK_SIZE]>, + b_values: Vec<[F; CHUNK_SIZE]>, + a_switches: Vec<[Target; CHUNK_SIZE]>, + b_switches: Vec<[Target; CHUNK_SIZE]>, + witness: &PartialWitness, + out_buffer: &mut GeneratedValues, ) { - todo!() + assert_eq!(a_values.len(), b_values.len()); + let n = a_values.len(); + let even = n % 2 == 0; + // Bimap: maps indices of values in a to indices of the same values in b + let ab_map = bimap_from_lists(a_values, b_values); + let switches = [a_switches, b_switches]; + + // Given a side and an index, returns the index in the other side that corresponds to the same value. + let ab_map_by_side = |side: usize, index: usize| -> usize { + *match side { + 0 => ab_map.get_by_left(&index), + 1 => ab_map.get_by_right(&index), + _ => panic!("Expected side to be 0 or 1"), + } + .unwrap() + }; + + // We maintain two maps for wires which have been routed to a particular subnetwork on one side + // of the network (left or right) but not the other. The keys are wire indices, and the values + // are subnetwork indices. + let mut partial_routes = [BTreeMap::new(), BTreeMap::new()]; + + // After we route a wire on one side, we find the corresponding wire on the other side and check + // if it still needs to be routed. If so, we add it to partial_routes. + let enqueue_other_side = |partial_routes: &mut [BTreeMap], + witness: &PartialWitness, + out_buffer: &mut GeneratedValues, + side: usize, + this_i: usize, + subnet: bool| { + let other_side = 1 - side; + let other_i = ab_map_by_side(side, this_i); + let other_switch_i = other_i / 2; + + if other_switch_i >= switches[other_side].len() { + // The other wire doesn't go through a switch, so there's no routing to be done. + // This happens in the case of the very last wire. + return; + } + + if witness.contains_all(&switches[other_side][other_switch_i]) { + // The other switch has already been routed. + return; + } + + let other_i_sibling = 4 * other_switch_i + 1 - other_i; + if let Some(&sibling_subnet) = partial_routes[other_side].get(&other_i_sibling) { + // The other switch's sibling is already pending routing. + assert_ne!(subnet, sibling_subnet); + } else { + let opt_old_subnet = partial_routes[other_side].insert(other_i, subnet); + if let Some(old_subnet) = opt_old_subnet { + assert_eq!(subnet, old_subnet, "Routing conflict (should never happen)"); + } + } + }; + + // See Figure 8 in the AS-Waksman paper. + if even { + enqueue_other_side(&mut partial_routes, witness, out_buffer, 1, n - 2, false); + enqueue_other_side(&mut partial_routes, witness, out_buffer, 1, n - 1, true); + } else { + enqueue_other_side(&mut partial_routes, witness, out_buffer, 0, n - 1, true); + enqueue_other_side(&mut partial_routes, witness, out_buffer, 1, n - 1, true); + } + + let route_switch = |partial_routes: &mut [BTreeMap], + witness: &PartialWitness, + out_buffer: &mut GeneratedValues, + side: usize, + switch_index: usize, + swap: bool| { + // First, we actually set the switch configuration. + for e in 0..CHUNK_SIZE { + out_buffer.set_target(switches[side][switch_index][e], F::from_bool(swap)); + } + + // Then, we enqueue the two corresponding wires on the other side of the network, to ensure + // that they get routed in the next step. + let this_i_1 = switch_index * 2; + let this_i_2 = this_i_1 + 1; + enqueue_other_side(partial_routes, witness, out_buffer, side, this_i_1, swap); + enqueue_other_side(partial_routes, witness, out_buffer, side, this_i_2, !swap); + }; + + // If {a,b}_only_routes is empty, then we can route any switch next. For efficiency, we will + // simply do top-down scans (one on the left side, one on the right side) for switches which + // have not yet been routed. These variables represent the positions of those two scans. + let mut scan_index = [0, 0]; + + // Until both scans complete, we alternate back and worth between the left and right switch + // layers. We process any partially routed wires for that side, or if there aren't any, we route + // the next switch in our scan. + while scan_index[0] < switches[0].len() || scan_index[1] < switches[1].len() { + for side in 0..=1 { + if !partial_routes[side].is_empty() { + for (this_i, subnet) in partial_routes[side].clone().into_iter() { + let this_first_switch_input = this_i % 2 == 0; + let swap = this_first_switch_input == subnet; + let this_switch_i = this_i / 2; + route_switch( + &mut partial_routes, + witness, + out_buffer, + side, + this_switch_i, + swap, + ); + } + partial_routes[side].clear(); + } else { + // We can route any switch next. Continue our scan for pending switches. + while scan_index[side] < switches[side].len() + && witness.contains_all(&switches[side][scan_index[side]]) + { + scan_index[side] += 1; + } + if scan_index[side] < switches[side].len() { + // Either switch configuration would work; we arbitrarily choose to not swap. + route_switch( + &mut partial_routes, + witness, + out_buffer, + side, + scan_index[side], + false, + ); + } + } + } + } } struct PermutationGenerator { - a_values: Vec, - b_values: Vec, + a_values: Vec<[F; CHUNK_SIZE]>, + b_values: Vec<[F; CHUNK_SIZE]>, a_wires: Vec<[Target; CHUNK_SIZE]>, b_wires: Vec<[Target; CHUNK_SIZE]>, } @@ -185,11 +319,17 @@ impl SimpleGenerator for PermutationGenera .map(|arr| arr.to_vec()) .flatten() .collect() - //.chain(self.b_wires.iter()).collect() } fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { - todo!() + route( + self.a_values.clone(), + self.b_values.clone(), + self.a_wires.clone(), + self.b_wires.clone(), + witness, + out_buffer, + ); } } @@ -205,10 +345,36 @@ mod tests { use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; - fn test_permutation(size: usize) -> Result<()> { + #[test] + fn route_2x2() -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(config.num_wires); + let mut builder = CircuitBuilder::::new(config); + + let one = F::ONE; + let two = F::from_canonical_usize(2); + let seven = F::from_canonical_usize(7); + let eight = F::from_canonical_usize(8); + + let one_two = [builder.constant(one), builder.constant(two)]; + let seven_eight = [builder.constant(seven), builder.constant(eight)]; + + let a = vec![one_two, seven_eight]; + let b = vec![seven_eight, one_two]; + + builder.assert_permutation(a, b); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + /*fn test_permutation(size: usize) -> Result<()> { type F = CrandallField; type FF = QuarticCrandallField; - let len = 1 << len_log; let config = CircuitConfig::large_config(); let pw = PartialWitness::new(config.num_wires); let mut builder = CircuitBuilder::::new(config); @@ -225,5 +391,5 @@ mod tests { let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) - } + }*/ } diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 70bfd45d..5ffc6525 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -169,7 +169,7 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gate } fn num_wires(&self) -> usize { - Self::wire_second_output(self.num_copies - 1, CHUNK_SIZE - 1) + 1 + Self::wire_switch_bool(self.num_copies, self.num_copies - 1) + 1 } fn num_constants(&self) -> usize { @@ -294,14 +294,14 @@ mod tests { #[test] fn low_degree() { - test_low_degree::(SwitchGate::<_, 4, 3>::new( + test_low_degree::(SwitchGate::<_, 4, 3>::new_from_config( CircuitConfig::large_config(), )); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(SwitchGate::<_, 4, 3>::new( + test_eval_fns::(SwitchGate::<_, 4, 3>::new_from_config( CircuitConfig::large_config(), )) } diff --git a/src/util/bimap.rs b/src/util/bimap.rs new file mode 100644 index 00000000..6fb46db7 --- /dev/null +++ b/src/util/bimap.rs @@ -0,0 +1,76 @@ +use std::collections::HashMap; +use std::hash::Hash; + +use bimap::BiMap; +use itertools::enumerate; + +/// Given two lists which are permutations of one another, creates a BiMap which maps an index in +/// one list to an index in the other list with the same associated value. +/// +/// If the lists contain duplicates, then multiple permutations with this property exist, and an +/// arbitrary one of them will be returned. +pub fn bimap_from_lists(a: Vec, b: Vec) -> BiMap { + assert_eq!(a.len(), b.len(), "Vectors differ in length"); + + let mut b_values_to_indices = HashMap::new(); + for (i, value) in enumerate(b) { + b_values_to_indices + .entry(value) + .or_insert_with(Vec::new) + .push(i); + } + + let mut bimap = BiMap::new(); + for (i, value) in enumerate(a) { + if let Some(j) = b_values_to_indices.get_mut(&value).and_then(Vec::pop) { + bimap.insert(i, j); + } else { + panic!("Value in first list not found in second list"); + } + } + + bimap +} + +#[cfg(test)] +mod tests { + use crate::util::bimap::bimap_from_lists; + + #[test] + fn empty_lists() { + let empty: Vec = Vec::new(); + let bimap = bimap_from_lists(empty.clone(), empty); + assert!(bimap.is_empty()); + } + + #[test] + fn without_duplicates() { + let bimap = bimap_from_lists(vec!['a', 'b', 'c'], vec!['b', 'c', 'a']); + assert_eq!(bimap.get_by_left(&0), Some(&2)); + assert_eq!(bimap.get_by_left(&1), Some(&0)); + assert_eq!(bimap.get_by_left(&2), Some(&1)); + } + + #[test] + fn with_duplicates() { + let first = vec!['a', 'a', 'b']; + let second = vec!['a', 'b', 'a']; + let bimap = bimap_from_lists(first.clone(), second.clone()); + for i in 0..3 { + let j = *bimap.get_by_left(&i).unwrap(); + assert_eq!(first[i], second[j]); + } + } + + #[test] + #[should_panic] + fn lengths_differ() { + bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b']); + } + + #[test] + #[should_panic] + fn not_a_permutation() { + bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b', 'b']); + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index cdd26ef8..daa6716b 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,6 +1,7 @@ use crate::field::field_types::Field; use crate::polynomial::polynomial::PolynomialValues; +pub(crate) mod bimap; pub(crate) mod context_tree; pub(crate) mod marking; pub(crate) mod partial_products;