diff --git a/Cargo.toml b/Cargo.toml index d8b84356..ed5acab6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ edition = "2018" default-run = "bench_recursion" [dependencies] +array_tool = "1.0.3" +bimap = "0.4.0" env_logger = "0.9.0" log = "0.4.14" itertools = "0.10.0" diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 1ac6c6f2..b8f3e1f5 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -108,7 +108,7 @@ impl, const D: usize> CircuitBuilder { exponent_bits: impl IntoIterator>, ) -> Target { let _false = self._false(); - let gate = ExponentiationGate::new(self.config.clone()); + let gate = ExponentiationGate::new_from_config(self.config.clone()); let num_power_bits = gate.num_power_bits; let mut exp_bits_vec: Vec = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 9cf24c1d..a5c3d47d 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -413,7 +413,7 @@ impl, const D: usize> CircuitBuilder { ) -> ExtensionTarget { let inv = self.add_virtual_extension_target(); let one = self.one_extension(); - self.add_generator(QuotientGeneratorExtension { + self.add_simple_generator(QuotientGeneratorExtension { numerator: one, denominator: y, quotient: inv, diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 0eb42e27..4b3371ef 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic_extension; pub mod hash; pub mod insert; pub mod interpolation; +pub mod permutation; pub mod polynomial; pub mod random_access; pub mod range_check; diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs new file mode 100644 index 00000000..c47f5f23 --- /dev/null +++ b/src/gadgets/permutation.rs @@ -0,0 +1,483 @@ +use std::collections::BTreeMap; +use std::marker::PhantomData; + +use crate::field::{ + extension_field::Extendable, + field_types::{Field, PrimeField}, +}; +use crate::gates::switch::SwitchGate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::bimap::bimap_from_lists; + +impl, const D: usize> CircuitBuilder { + /// Assert that two lists of expressions evaluate to permutations of one another. + pub fn assert_permutation(&mut self, a: Vec>, b: Vec>) { + assert_eq!( + a.len(), + b.len(), + "Permutation must have same number of inputs and outputs" + ); + assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same"); + + let chunk_size = a[0].len(); + + match a.len() { + // Two empty lists are permutations of one another, trivially. + 0 => (), + // Two singleton lists are permutations of one another as long as their items are equal. + 1 => { + for e in 0..chunk_size { + self.connect(a[0][e], b[0][e]) + } + } + 2 => { + self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone()) + } + // For larger lists, we recursively use two smaller permutation networks. + //_ => self.assert_permutation_recursive(a, b) + _ => self.assert_permutation_recursive(a, b), + } + } + + /// Assert that [a1, a2] is a permutation of [b1, b2]. + fn assert_permutation_2x2( + &mut self, + a1: Vec, + a2: Vec, + b1: Vec, + b2: Vec, + ) { + assert!( + a1.len() == a2.len() && a2.len() == b1.len() && b1.len() == b2.len(), + "Chunk size must be the same" + ); + + let chunk_size = a1.len(); + + let (_switch, gate_out1, gate_out2) = self.create_switch(a1, a2); + for e in 0..chunk_size { + self.connect(b1[e], gate_out1[e]); + self.connect(b2[e], gate_out2[e]); + } + } + + /// Given two input wire chunks, add a new switch to the circuit (by adding one copy to a switch + /// gate). Returns the wire for the switch boolean, and the two output wire chunks. + fn create_switch( + &mut self, + a1: Vec, + a2: Vec, + ) -> (Target, Vec, Vec) { + assert_eq!(a1.len(), a2.len(), "Chunk size must be the same"); + + let chunk_size = a1.len(); + + if self.current_switch_gates.len() < chunk_size { + self.current_switch_gates + .extend(vec![None; chunk_size - self.current_switch_gates.len()]); + } + + let (gate, gate_index, mut next_copy) = + match self.current_switch_gates[chunk_size - 1].clone() { + None => { + let gate = SwitchGate::::new_from_config(self.config.clone(), chunk_size); + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate, gate_index, 0) + } + Some((gate, idx, next_copy)) => (gate, idx, next_copy), + }; + + let num_copies = gate.num_copies; + + let mut c = Vec::new(); + let mut d = Vec::new(); + for e in 0..chunk_size { + self.connect( + a1[e], + Target::wire(gate_index, gate.wire_first_input(next_copy, e)), + ); + self.connect( + a2[e], + Target::wire(gate_index, gate.wire_second_input(next_copy, e)), + ); + c.push(Target::wire( + gate_index, + gate.wire_first_output(next_copy, e), + )); + d.push(Target::wire( + gate_index, + gate.wire_second_output(next_copy, e), + )); + } + + let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy)); + + next_copy += 1; + if next_copy == num_copies { + self.current_switch_gates[chunk_size - 1] = None; + } else { + self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy)); + } + + (switch, c, d) + } + + fn assert_permutation_recursive(&mut self, a: Vec>, b: Vec>) { + assert_eq!( + a.len(), + b.len(), + "Permutation must have same number of inputs and outputs" + ); + assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same"); + + let n = a.len(); + let even = n % 2 == 0; + + let mut child_1_a = Vec::new(); + let mut child_1_b = Vec::new(); + let mut child_2_a = Vec::new(); + let mut child_2_b = Vec::new(); + + // See Figure 8 in the AS-Waksman paper. + let a_num_switches = n / 2; + let b_num_switches = if even { + a_num_switches - 1 + } else { + a_num_switches + }; + + let mut a_switches = Vec::new(); + let mut b_switches = Vec::new(); + for i in 0..a_num_switches { + let (switch, out_1, out_2) = self.create_switch(a[i * 2].clone(), a[i * 2 + 1].clone()); + a_switches.push(switch); + child_1_a.push(out_1); + child_2_a.push(out_2); + } + for i in 0..b_num_switches { + let (switch, out_1, out_2) = self.create_switch(b[i * 2].clone(), b[i * 2 + 1].clone()); + b_switches.push(switch); + child_1_b.push(out_1); + child_2_b.push(out_2); + } + + // See Figure 8 in the AS-Waksman paper. + if even { + child_1_b.push(b[n - 2].clone()); + child_2_b.push(b[n - 1].clone()); + } else { + child_2_a.push(a[n - 1].clone()); + child_2_b.push(b[n - 1].clone()); + } + + self.assert_permutation(child_1_a, child_1_b); + self.assert_permutation(child_2_a, child_2_b); + + self.add_simple_generator(PermutationGenerator:: { + a, + b, + a_switches, + b_switches, + _phantom: PhantomData, + }); + } +} + +fn route( + a_values: Vec>, + b_values: Vec>, + a_switches: Vec, + b_switches: Vec, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, +) { + assert_eq!(a_values.len(), b_values.len()); + let n = a_values.len(); + let even = n % 2 == 0; + + // We use a bimap to match indices of values in a to indices of the same values in b. + // This means that given a wire on one side, we can easily find the matching wire on the other side. + let ab_map = bimap_from_lists(a_values, b_values); + + let switches = [a_switches, b_switches]; + + // We keep track of the new wires we've routed (after routing some wires, we need to check `witness` + // and `newly_set` instead of just `witness`. + let mut newly_set = [vec![false; n], vec![false; n]]; + + // Given a side and an index, returns the index in the other side that corresponds to the same value. + let ab_map_by_side = |side: usize, index: usize| -> usize { + *match side { + 0 => ab_map.get_by_left(&index), + 1 => ab_map.get_by_right(&index), + _ => panic!("Expected side to be 0 or 1"), + } + .unwrap() + }; + + // We maintain two maps for wires which have been routed to a particular subnetwork on one side + // of the network (left or right) but not the other. The keys are wire indices, and the values + // are subnetwork indices. + let mut partial_routes = [BTreeMap::new(), BTreeMap::new()]; + + // After we route a wire on one side, we find the corresponding wire on the other side and check + // if it still needs to be routed. If so, we add it to partial_routes. + let enqueue_other_side = |partial_routes: &mut [BTreeMap], + witness: &PartitionWitness, + newly_set: &mut [Vec], + side: usize, + this_i: usize, + subnet: bool| { + let other_side = 1 - side; + let other_i = ab_map_by_side(side, this_i); + let other_switch_i = other_i / 2; + + if other_switch_i >= switches[other_side].len() { + // The other wire doesn't go through a switch, so there's no routing to be done. + // This happens in the case of the very last wire. + return; + } + + if witness.contains(switches[other_side][other_switch_i]) + || newly_set[other_side][other_switch_i] + { + // The other switch has already been routed. + return; + } + + let other_i_sibling = 4 * other_switch_i + 1 - other_i; + if let Some(&sibling_subnet) = partial_routes[other_side].get(&other_i_sibling) { + // The other switch's sibling is already pending routing. + assert_ne!(subnet, sibling_subnet); + } else { + let opt_old_subnet = partial_routes[other_side].insert(other_i, subnet); + if let Some(old_subnet) = opt_old_subnet { + assert_eq!(subnet, old_subnet, "Routing conflict (should never happen)"); + } + } + }; + + // See Figure 8 in the AS-Waksman paper. + if even { + enqueue_other_side( + &mut partial_routes, + witness, + &mut newly_set, + 1, + n - 2, + false, + ); + enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 1, n - 1, true); + } else { + enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 0, n - 1, true); + enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 1, n - 1, true); + } + + let route_switch = |partial_routes: &mut [BTreeMap], + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + newly_set: &mut [Vec], + side: usize, + switch_index: usize, + swap: bool| { + // First, we actually set the switch configuration. + out_buffer.set_target(switches[side][switch_index], F::from_bool(swap)); + newly_set[side][switch_index] = true; + + // Then, we enqueue the two corresponding wires on the other side of the network, to ensure + // that they get routed in the next step. + let this_i_1 = switch_index * 2; + let this_i_2 = this_i_1 + 1; + enqueue_other_side(partial_routes, witness, newly_set, side, this_i_1, swap); + enqueue_other_side(partial_routes, witness, newly_set, side, this_i_2, !swap); + }; + + // If {a,b}_only_routes is empty, then we can route any switch next. For efficiency, we will + // simply do top-down scans (one on the left side, one on the right side) for switches which + // have not yet been routed. These variables represent the positions of those two scans. + let mut scan_index = [0, 0]; + + // Until both scans complete, we alternate back and worth between the left and right switch + // layers. We process any partially routed wires for that side, or if there aren't any, we route + // the next switch in our scan. + while scan_index[0] < switches[0].len() || scan_index[1] < switches[1].len() { + for side in 0..=1 { + if !partial_routes[side].is_empty() { + for (this_i, subnet) in partial_routes[side].clone().into_iter() { + let this_first_switch_input = this_i % 2 == 0; + let swap = this_first_switch_input == subnet; + let this_switch_i = this_i / 2; + route_switch( + &mut partial_routes, + witness, + out_buffer, + &mut newly_set, + side, + this_switch_i, + swap, + ); + } + partial_routes[side].clear(); + } else { + // We can route any switch next. Continue our scan for pending switches. + while scan_index[side] < switches[side].len() + && (witness.contains(switches[side][scan_index[side]]) + || newly_set[side][scan_index[side]]) + { + scan_index[side] += 1; + } + if scan_index[side] < switches[side].len() { + // Either switch configuration would work; we arbitrarily choose to not swap. + route_switch( + &mut partial_routes, + witness, + out_buffer, + &mut newly_set, + side, + scan_index[side], + false, + ); + scan_index[side] += 1; + } + } + } + } +} + +#[derive(Debug)] +struct PermutationGenerator { + a: Vec>, + b: Vec>, + a_switches: Vec, + b_switches: Vec, + _phantom: PhantomData, +} + +impl SimpleGenerator for PermutationGenerator { + fn dependencies(&self) -> Vec { + self.a.iter().chain(&self.b).flatten().cloned().collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a_values = self + .a + .iter() + .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) + .collect(); + let b_values = self + .b + .iter() + .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) + .collect(); + route( + a_values, + b_values, + self.a_switches.clone(), + self.b_switches.clone(), + witness, + out_buffer, + ); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use rand::{seq::SliceRandom, thread_rng}; + + use super::*; + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + fn test_permutation_good(size: usize) -> Result<()> { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_zk_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); + let a: Vec> = lst[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + let mut b = a.clone(); + b.shuffle(&mut thread_rng()); + + builder.assert_permutation(a, b); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + fn test_permutation_bad(size: usize) -> Result<()> { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_zk_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let lst1: Vec = F::rand_vec(size * 2); + let lst2: Vec = F::rand_vec(size * 2); + let a: Vec> = lst1[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + let b: Vec> = lst2[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + + builder.assert_permutation(a, b); + + let data = builder.build(); + data.prove(pw).unwrap(); + + Ok(()) + } + + #[test] + fn test_permutations_good() -> Result<()> { + for n in 2..9 { + test_permutation_good(n)?; + } + + Ok(()) + } + + #[test] + #[should_panic] + fn test_permutation_bad_small() { + let size = 2; + + test_permutation_bad(size).unwrap() + } + + #[test] + #[should_panic] + fn test_permutation_bad_medium() { + let size = 6; + + test_permutation_bad(size).unwrap() + } + + #[test] + #[should_panic] + fn test_permutation_bad_large() { + let size = 10; + + test_permutation_bad(size).unwrap() + } +} diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs index e669e2da..f53b0a7b 100644 --- a/src/gadgets/range_check.rs +++ b/src/gadgets/range_check.rs @@ -28,7 +28,7 @@ impl, const D: usize> CircuitBuilder { let high_gate = self.add_gate(BaseSumGate::<2>::new(num_bits - n_log), vec![]); let low = Target::wire(low_gate, BaseSumGate::<2>::WIRE_SUM); let high = Target::wire(high_gate, BaseSumGate::<2>::WIRE_SUM); - self.add_generator(LowHighGenerator { + self.add_simple_generator(LowHighGenerator { integer: x, n_log, low, diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 5afbb882..4589bbc0 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -44,7 +44,7 @@ impl, const D: usize> CircuitBuilder { self.connect(limb.borrow().target, Target::wire(gate_index, wire)); } - self.add_generator(BaseSumGenerator::<2> { + self.add_simple_generator(BaseSumGenerator::<2> { gate_index, limbs: bits.map(|l| *l.borrow()).collect(), }); diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 3b4c3b6e..0b091721 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -47,7 +47,7 @@ impl, const D: usize> CircuitBuilder { } self.connect(acc, integer); - self.add_generator(WireSplitGenerator { + self.add_simple_generator(WireSplitGenerator { integer, gates, num_limbs: bits_per_gate, diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index da82961f..23765a55 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -111,12 +111,15 @@ impl, const D: usize> Gate for ArithmeticExt ) -> Vec>> { (0..NUM_ARITHMETIC_OPS) .map(|i| { - let g: Box> = Box::new(ArithmeticExtensionGenerator { - gate_index, - const_0: local_constants[0], - const_1: local_constants[1], - i, - }); + let g: Box> = Box::new( + ArithmeticExtensionGenerator { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + i, + } + .adapter(), + ); g }) .collect::>() diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index ef7dbbda..801f9e26 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -105,7 +105,7 @@ impl, const D: usize, const B: usize> Gate f gate_index, num_limbs: self.num_limbs, }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } // 1 for the sum then `num_limbs` for the limbs. diff --git a/src/gates/constant.rs b/src/gates/constant.rs index 3bf34536..1135e8a3 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -54,7 +54,7 @@ impl, const D: usize> Gate for ConstantGate gate_index, constant: local_constants[0], }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index eeeff2c5..c9a410c7 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -20,14 +20,18 @@ pub(crate) struct ExponentiationGate, const D: usi } impl, const D: usize> ExponentiationGate { - pub fn new(config: CircuitConfig) -> Self { - let num_power_bits = Self::max_power_bits(config.num_wires, config.num_routed_wires); + pub fn new(num_power_bits: usize) -> Self { Self { num_power_bits, _phantom: PhantomData, } } + pub fn new_from_config(config: CircuitConfig) -> Self { + let num_power_bits = Self::max_power_bits(config.num_wires, config.num_routed_wires); + Self::new(num_power_bits) + } + fn max_power_bits(num_wires: usize, num_routed_wires: usize) -> usize { // 2 wires are reserved for the base and output. let max_for_routed_wires = num_routed_wires - 2; @@ -180,7 +184,7 @@ impl, const D: usize> Gate for Exponentiatio gate_index, gate: self.clone(), }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { @@ -298,12 +302,14 @@ mod tests { ..CircuitConfig::large_config() }; - test_low_degree::(ExponentiationGate::new(config)); + test_low_degree::(ExponentiationGate::new_from_config(config)); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(ExponentiationGate::new(CircuitConfig::large_config())) + test_eval_fns::(ExponentiationGate::new_from_config( + CircuitConfig::large_config(), + )) } #[test] diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 43113c44..07f5e46c 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -219,7 +219,7 @@ impl, const D: usize, const R: usize> Gate gate_index, constants: self.constants.clone(), }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index 383ff69c..7c013499 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -220,7 +220,7 @@ impl, const D: usize> Gate for InsertionGate gate_index, gate: self.clone(), }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 52bacdac..37e94d89 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -190,7 +190,7 @@ impl, const D: usize> Gate for Interpolation gate: self.clone(), _phantom: PhantomData, }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index f463e487..5ea604d7 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -167,7 +167,7 @@ impl, const D: usize> Gate for RandomAccessG gate_index, gate: self.clone(), }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/reducing.rs b/src/gates/reducing.rs index 70f290c9..60a9f747 100644 --- a/src/gates/reducing.rs +++ b/src/gates/reducing.rs @@ -137,10 +137,13 @@ impl, const D: usize> Gate for ReducingGate< gate_index: usize, _local_constants: &[F], ) -> Vec>> { - vec![Box::new(ReducingGenerator { - gate_index, - gate: self.clone(), - })] + vec![Box::new( + ReducingGenerator { + gate_index, + gate: self.clone(), + } + .adapter(), + )] } fn num_wires(&self) -> usize { diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 939c11e0..b17ce1ab 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -1,10 +1,12 @@ use std::marker::PhantomData; +use array_tool::vec::Union; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, PrimeField}; use crate::gates::gate::Gate; -use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::generator::{GeneratedValues, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; @@ -14,62 +16,59 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// A gate for conditionally swapping input values based on a boolean. #[derive(Clone, Debug)] -pub(crate) struct SwitchGate, const D: usize, const CHUNK_SIZE: usize> -{ - num_copies: usize, +pub(crate) struct SwitchGate, const D: usize> { + pub(crate) chunk_size: usize, + pub(crate) num_copies: usize, _phantom: PhantomData, } -impl, const D: usize, const CHUNK_SIZE: usize> - SwitchGate -{ - pub fn new(config: CircuitConfig) -> Self { - let num_copies = Self::max_num_copies(config.num_routed_wires); +impl, const D: usize> SwitchGate { + pub fn new(num_copies: usize, chunk_size: usize) -> Self { Self { + chunk_size, num_copies, _phantom: PhantomData, } } - fn max_num_copies(num_routed_wires: usize) -> usize { - num_routed_wires / (4 * CHUNK_SIZE + 1) + pub fn new_from_config(config: CircuitConfig, chunk_size: usize) -> Self { + let num_copies = Self::max_num_copies(config.num_routed_wires, chunk_size); + Self::new(num_copies, chunk_size) + } + + pub fn max_num_copies(num_routed_wires: usize, chunk_size: usize) -> usize { + num_routed_wires / (4 * chunk_size + 1) + } + + pub fn wire_first_input(&self, copy: usize, element: usize) -> usize { + debug_assert!(element < self.chunk_size); + copy * (4 * self.chunk_size + 1) + element + } + + pub fn wire_second_input(&self, copy: usize, element: usize) -> usize { + debug_assert!(element < self.chunk_size); + copy * (4 * self.chunk_size + 1) + self.chunk_size + element + } + + pub fn wire_first_output(&self, copy: usize, element: usize) -> usize { + debug_assert!(element < self.chunk_size); + copy * (4 * self.chunk_size + 1) + 2 * self.chunk_size + element + } + + pub fn wire_second_output(&self, copy: usize, element: usize) -> usize { + debug_assert!(element < self.chunk_size); + copy * (4 * self.chunk_size + 1) + 3 * self.chunk_size + element } pub fn wire_switch_bool(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); - copy * (4 * CHUNK_SIZE + 1) - } - - pub fn wire_first_input(&self, copy: usize, element: usize) -> usize { - debug_assert!(copy < self.num_copies); - debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE + 1) + 1 + element - } - - pub fn wire_second_input(&self, copy: usize, element: usize) -> usize { - debug_assert!(copy < self.num_copies); - debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE + 1) + 1 + CHUNK_SIZE + element - } - - pub fn wire_first_output(&self, copy: usize, element: usize) -> usize { - debug_assert!(copy < self.num_copies); - debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE + 1) + 1 + 2 * CHUNK_SIZE + element - } - - pub fn wire_second_output(&self, copy: usize, element: usize) -> usize { - debug_assert!(copy < self.num_copies); - debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE + 1) + 1 + 3 * CHUNK_SIZE + element + copy * (4 * self.chunk_size + 1) + 4 * self.chunk_size } } -impl, const D: usize, const CHUNK_SIZE: usize> Gate - for SwitchGate -{ +impl, const D: usize> Gate for SwitchGate { fn id(&self) -> String { - format!("{:?}", self, D, CHUNK_SIZE) + format!("{:?}", self, D) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { @@ -79,7 +78,7 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gat let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; let not_switch = F::Extension::ONE - switch_bool; - for e in 0..CHUNK_SIZE { + for e in 0..self.chunk_size { let first_input = vars.local_wires[self.wire_first_input(c, e)]; let second_input = vars.local_wires[self.wire_second_input(c, e)]; let first_output = vars.local_wires[self.wire_first_output(c, e)]; @@ -102,7 +101,7 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gat let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; let not_switch = F::ONE - switch_bool; - for e in 0..CHUNK_SIZE { + for e in 0..self.chunk_size { let first_input = vars.local_wires[self.wire_first_input(c, e)]; let second_input = vars.local_wires[self.wire_second_input(c, e)]; let first_output = vars.local_wires[self.wire_first_output(c, e)]; @@ -130,7 +129,7 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gat let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; let not_switch = builder.sub_extension(one, switch_bool); - for e in 0..CHUNK_SIZE { + for e in 0..self.chunk_size { let first_input = vars.local_wires[self.wire_first_input(c, e)]; let second_input = vars.local_wires[self.wire_second_input(c, e)]; let first_output = vars.local_wires[self.wire_first_output(c, e)]; @@ -165,15 +164,20 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gat gate_index: usize, _local_constants: &[F], ) -> Vec>> { - let gen = SwitchGenerator:: { - gate_index, - gate: self.clone(), - }; - vec![Box::new(gen)] + (0..self.num_copies) + .map(|c| { + let g: Box> = Box::new(SwitchGenerator:: { + gate_index, + gate: self.clone(), + copy: c, + }); + g + }) + .collect() } fn num_wires(&self) -> usize { - self.wire_second_output(self.num_copies - 1, CHUNK_SIZE - 1) + 1 + self.wire_switch_bool(self.num_copies - 1) + 1 } fn num_constants(&self) -> usize { @@ -185,35 +189,46 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gat } fn num_constraints(&self) -> usize { - 4 * self.num_copies * CHUNK_SIZE + 4 * self.num_copies * self.chunk_size } } #[derive(Debug)] -struct SwitchGenerator, const D: usize, const CHUNK_SIZE: usize> { +struct SwitchGenerator, const D: usize> { gate_index: usize, - gate: SwitchGate, + gate: SwitchGate, + copy: usize, } -impl, const D: usize, const CHUNK_SIZE: usize> SimpleGenerator - for SwitchGenerator -{ - fn dependencies(&self) -> Vec { +impl, const D: usize> SwitchGenerator { + fn in_out_dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); let mut deps = Vec::new(); - for c in 0..self.gate.num_copies { - deps.push(local_target(self.gate.wire_switch_bool(c))); - for e in 0..CHUNK_SIZE { - deps.push(local_target(self.gate.wire_first_input(c, e))); - deps.push(local_target(self.gate.wire_second_input(c, e))); - } + for e in 0..self.gate.chunk_size { + deps.push(local_target(self.gate.wire_first_input(self.copy, e))); + deps.push(local_target(self.gate.wire_second_input(self.copy, e))); + deps.push(local_target(self.gate.wire_first_output(self.copy, e))); + deps.push(local_target(self.gate.wire_second_output(self.copy, e))); } deps } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn in_switch_dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + let mut deps = Vec::new(); + for e in 0..self.gate.chunk_size { + deps.push(local_target(self.gate.wire_first_input(self.copy, e))); + deps.push(local_target(self.gate.wire_second_input(self.copy, e))); + deps.push(local_target(self.gate.wire_switch_bool(self.copy))); + } + + deps + } + + fn run_in_out(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let local_wire = |input| Wire { gate: self.gate_index, input, @@ -221,24 +236,69 @@ impl, const D: usize, const CHUNK_SIZE: usize> Sim let get_local_wire = |input| witness.get_wire(local_wire(input)); - for c in 0..self.gate.num_copies { - let switch_bool = get_local_wire(self.gate.wire_switch_bool(c)); - for e in 0..CHUNK_SIZE { - let first_input = get_local_wire(self.gate.wire_first_input(c, e)); - let second_input = get_local_wire(self.gate.wire_second_input(c, e)); - let first_output_wire = local_wire(self.gate.wire_first_output(c, e)); - let second_output_wire = local_wire(self.gate.wire_second_output(c, e)); + for e in 0..self.gate.chunk_size { + let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy)); + let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e)); + let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e)); + let first_output = get_local_wire(self.gate.wire_first_output(self.copy, e)); + let second_output = get_local_wire(self.gate.wire_second_output(self.copy, e)); - if switch_bool == F::ONE { - out_buffer.set_wire(first_output_wire, second_input); - out_buffer.set_wire(second_output_wire, first_input); - } else { - out_buffer.set_wire(first_output_wire, first_input); - out_buffer.set_wire(second_output_wire, second_input); - } + if first_output == first_input && second_output == second_input { + out_buffer.set_wire(switch_bool_wire, F::ZERO); + } else if first_output == second_input && second_output == first_input { + out_buffer.set_wire(switch_bool_wire, F::ONE); + } else { + panic!("No permutation from given inputs to given outputs"); } } } + + fn run_in_switch(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + for e in 0..self.gate.chunk_size { + let first_output_wire = local_wire(self.gate.wire_first_output(self.copy, e)); + let second_output_wire = local_wire(self.gate.wire_second_output(self.copy, e)); + let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e)); + let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e)); + let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy)); + + let (first_output, second_output) = if switch_bool == F::ZERO { + (first_input, second_input) + } else if switch_bool == F::ONE { + (second_input, first_input) + } else { + panic!("Invalid switch bool value"); + }; + + out_buffer.set_wire(first_output_wire, first_output); + out_buffer.set_wire(second_output_wire, second_output); + } + } +} + +impl, const D: usize> WitnessGenerator for SwitchGenerator { + fn watch_list(&self) -> Vec { + self.in_out_dependencies() + .union(self.in_switch_dependencies()) + } + + fn run(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> bool { + if witness.contains_all(&self.in_out_dependencies()) { + self.run_in_out(witness, out_buffer); + true + } else if witness.contains_all(&self.in_switch_dependencies()) { + self.run_in_switch(witness, out_buffer); + true + } else { + false + } + } } #[cfg(test)] @@ -259,37 +319,44 @@ mod tests { #[test] fn wire_indices() { - let gate = SwitchGate:: { - num_copies: 3, + type SG = SwitchGate; + let num_copies = 3; + let chunk_size = 3; + + let gate = SG { + chunk_size, + num_copies, _phantom: PhantomData, }; - assert_eq!(gate.wire_switch_bool(0), 0); - assert_eq!(gate.wire_first_input(0, 0), 1); - assert_eq!(gate.wire_first_input(0, 2), 3); - assert_eq!(gate.wire_second_input(0, 0), 4); - assert_eq!(gate.wire_second_input(0, 2), 6); - assert_eq!(gate.wire_first_output(0, 0), 7); - assert_eq!(gate.wire_second_output(0, 2), 12); - assert_eq!(gate.wire_switch_bool(1), 13); - assert_eq!(gate.wire_first_input(1, 0), 14); - assert_eq!(gate.wire_second_output(1, 2), 25); - assert_eq!(gate.wire_switch_bool(2), 26); - assert_eq!(gate.wire_first_input(2, 0), 27); - assert_eq!(gate.wire_second_output(2, 2), 38); + assert_eq!(gate.wire_first_input(0, 0), 0); + assert_eq!(gate.wire_first_input(0, 2), 2); + assert_eq!(gate.wire_second_input(0, 0), 3); + assert_eq!(gate.wire_second_input(0, 2), 5); + assert_eq!(gate.wire_first_output(0, 0), 6); + assert_eq!(gate.wire_second_output(0, 2), 11); + assert_eq!(gate.wire_switch_bool(0), 12); + assert_eq!(gate.wire_first_input(1, 0), 13); + assert_eq!(gate.wire_second_output(1, 2), 24); + assert_eq!(gate.wire_switch_bool(1), 25); + assert_eq!(gate.wire_first_input(2, 0), 26); + assert_eq!(gate.wire_second_output(2, 2), 37); + assert_eq!(gate.wire_switch_bool(2), 38); } #[test] fn low_degree() { - test_low_degree::(SwitchGate::<_, 4, 3>::new( + test_low_degree::(SwitchGate::<_, 4>::new_from_config( CircuitConfig::large_config(), + 3, )); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(SwitchGate::<_, 4, 3>::new( + test_eval_fns::(SwitchGate::<_, 4>::new_from_config( CircuitConfig::large_config(), + 3, )) } @@ -312,7 +379,7 @@ mod tests { let mut v = Vec::new(); for c in 0..num_copies { let switch = switch_bools[c]; - v.push(F::from_bool(switch)); + let mut first_input_chunk = Vec::with_capacity(CHUNK_SIZE); let mut second_input_chunk = Vec::with_capacity(CHUNK_SIZE); let mut first_output_chunk = Vec::with_capacity(CHUNK_SIZE); @@ -331,6 +398,8 @@ mod tests { v.append(&mut second_input_chunk); v.append(&mut first_output_chunk); v.append(&mut second_output_chunk); + + v.push(F::from_bool(switch)); } v.iter().map(|&x| x.into()).collect::>() @@ -340,7 +409,8 @@ mod tests { let second_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect(); let switch_bools = vec![true, false, true]; - let gate = SwitchGate:: { + let gate = SwitchGate:: { + chunk_size: CHUNK_SIZE, num_copies, _phantom: PhantomData, }; diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 8174981c..483a7a62 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; @@ -186,16 +187,32 @@ pub trait SimpleGenerator: 'static + Send + Sync + Debug { fn dependencies(&self) -> Vec; fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues); + + fn adapter(self) -> SimpleGeneratorAdapter + where + Self: Sized, + { + SimpleGeneratorAdapter { + inner: self, + _phantom: PhantomData, + } + } } -impl> WitnessGenerator for SG { +#[derive(Debug)] +pub struct SimpleGeneratorAdapter + ?Sized> { + _phantom: PhantomData, + inner: SG, +} + +impl> WitnessGenerator for SimpleGeneratorAdapter { fn watch_list(&self) -> Vec { - self.dependencies() + self.inner.dependencies() } fn run(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> bool { - if witness.contains_all(&self.dependencies()) { - self.run_once(witness, out_buffer); + if witness.contains_all(&self.inner.dependencies()) { + self.inner.run_once(witness, out_buffer); true } else { false diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index bf70db64..529a85b0 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -15,9 +15,12 @@ use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; +use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::hash_n_to_hash; -use crate::iop::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; +use crate::iop::generator::{ + CopyGenerator, RandomValueGenerator, SimpleGenerator, WitnessGenerator, +}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::iop::witness::PartitionWitness; @@ -41,7 +44,7 @@ pub struct CircuitBuilder, const D: usize> { gates: HashSet>, /// The concrete placement of each gate. - gate_instances: Vec>, + pub(crate) gate_instances: Vec>, /// Targets to be made public. public_inputs: Vec, @@ -66,6 +69,11 @@ pub struct CircuitBuilder, const D: usize> { /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using /// these constants with gate index `g` and already using `i` arithmetic operations. pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, + + // `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value + // chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies + // of switches + pub(crate) current_switch_gates: Vec, usize, usize)>>, } impl, const D: usize> CircuitBuilder { @@ -83,6 +91,7 @@ impl, const D: usize> CircuitBuilder { constants_to_targets: HashMap::new(), targets_to_constants: HashMap::new(), free_arithmetic: HashMap::new(), + current_switch_gates: Vec::new(), } } @@ -182,7 +191,7 @@ impl, const D: usize> CircuitBuilder { /// Adds a generator which will copy `src` to `dst`. pub fn generate_copy(&mut self, src: Target, dst: Target) { - self.add_generator(CopyGenerator { src, dst }); + self.add_simple_generator(CopyGenerator { src, dst }); } /// Uses Plonk's permutation argument to require that two elements be equal. @@ -209,8 +218,8 @@ impl, const D: usize> CircuitBuilder { self.generators.extend(generators); } - pub fn add_generator>(&mut self, generator: G) { - self.generators.push(Box::new(generator)); + pub fn add_simple_generator>(&mut self, generator: G) { + self.generators.push(Box::new(generator.adapter())); } /// Returns a routable target with a value of 0. @@ -383,7 +392,7 @@ impl, const D: usize> CircuitBuilder { for _ in 0..regular_poly_openings { let gate = self.add_gate(NoopGate, vec![]); for w in 0..num_wires { - self.add_generator(RandomValueGenerator { + self.add_simple_generator(RandomValueGenerator { target: Target::Wire(Wire { gate, input: w }), }); } @@ -397,7 +406,7 @@ impl, const D: usize> CircuitBuilder { let gate_2 = self.add_gate(NoopGate, vec![]); for w in 0..num_routed_wires { - self.add_generator(RandomValueGenerator { + self.add_simple_generator(RandomValueGenerator { target: Target::Wire(Wire { gate: gate_1, input: w, @@ -507,6 +516,33 @@ impl, const D: usize> CircuitBuilder { } } + /// Fill the remaining unused switch gates with dummy values, so that all + /// `SwitchGenerator` are run. + fn fill_switch_gates(&mut self) { + let zero = self.zero(); + + for chunk_size in 1..=self.current_switch_gates.len() { + if let Some((gate, gate_index, mut copy)) = + self.current_switch_gates[chunk_size - 1].clone() + { + while copy < gate.num_copies { + for element in 0..chunk_size { + let wire_first_input = + Target::wire(gate_index, gate.wire_first_input(copy, element)); + let wire_second_input = + Target::wire(gate_index, gate.wire_second_input(copy, element)); + let wire_switch_bool = + Target::wire(gate_index, gate.wire_switch_bool(copy)); + self.connect(zero, wire_first_input); + self.connect(zero, wire_second_input); + self.connect(zero, wire_switch_bool); + } + copy += 1; + } + } + } + } + pub fn print_gate_counts(&self, min_delta: usize) { self.context_log .filter(self.num_gates(), min_delta) @@ -519,6 +555,7 @@ impl, const D: usize> CircuitBuilder { let start = Instant::now(); self.fill_arithmetic_gates(); + self.fill_switch_gates(); // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that // those hash wires match the claimed public inputs. diff --git a/src/util/bimap.rs b/src/util/bimap.rs new file mode 100644 index 00000000..6fb46db7 --- /dev/null +++ b/src/util/bimap.rs @@ -0,0 +1,76 @@ +use std::collections::HashMap; +use std::hash::Hash; + +use bimap::BiMap; +use itertools::enumerate; + +/// Given two lists which are permutations of one another, creates a BiMap which maps an index in +/// one list to an index in the other list with the same associated value. +/// +/// If the lists contain duplicates, then multiple permutations with this property exist, and an +/// arbitrary one of them will be returned. +pub fn bimap_from_lists(a: Vec, b: Vec) -> BiMap { + assert_eq!(a.len(), b.len(), "Vectors differ in length"); + + let mut b_values_to_indices = HashMap::new(); + for (i, value) in enumerate(b) { + b_values_to_indices + .entry(value) + .or_insert_with(Vec::new) + .push(i); + } + + let mut bimap = BiMap::new(); + for (i, value) in enumerate(a) { + if let Some(j) = b_values_to_indices.get_mut(&value).and_then(Vec::pop) { + bimap.insert(i, j); + } else { + panic!("Value in first list not found in second list"); + } + } + + bimap +} + +#[cfg(test)] +mod tests { + use crate::util::bimap::bimap_from_lists; + + #[test] + fn empty_lists() { + let empty: Vec = Vec::new(); + let bimap = bimap_from_lists(empty.clone(), empty); + assert!(bimap.is_empty()); + } + + #[test] + fn without_duplicates() { + let bimap = bimap_from_lists(vec!['a', 'b', 'c'], vec!['b', 'c', 'a']); + assert_eq!(bimap.get_by_left(&0), Some(&2)); + assert_eq!(bimap.get_by_left(&1), Some(&0)); + assert_eq!(bimap.get_by_left(&2), Some(&1)); + } + + #[test] + fn with_duplicates() { + let first = vec!['a', 'a', 'b']; + let second = vec!['a', 'b', 'a']; + let bimap = bimap_from_lists(first.clone(), second.clone()); + for i in 0..3 { + let j = *bimap.get_by_left(&i).unwrap(); + assert_eq!(first[i], second[j]); + } + } + + #[test] + #[should_panic] + fn lengths_differ() { + bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b']); + } + + #[test] + #[should_panic] + fn not_a_permutation() { + bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b', 'b']); + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index cdd26ef8..daa6716b 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,6 +1,7 @@ use crate::field::field_types::Field; use crate::polynomial::polynomial::PolynomialValues; +pub(crate) mod bimap; pub(crate) mod context_tree; pub(crate) mod marking; pub(crate) mod partial_products;