diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index e94315b0..6d8e7d6d 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; use std::convert::TryInto; +use std::marker::PhantomData; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gates::switch::SwitchGate; @@ -64,13 +65,13 @@ impl, const D: usize> CircuitBuilder { .extend(vec![None; CHUNK_SIZE - self.current_switch_gates.len()]); } - let (gate, gate_index, mut next_copy) = match self.current_switch_gates[CHUNK_SIZE - 1] { + let (gate_index, mut next_copy) = match self.current_switch_gates[CHUNK_SIZE - 1] { None => { let gate = SwitchGate::::new_from_config(self.config.clone()); let gate_index = self.add_gate(gate.clone(), vec![]); - (gate, gate_index, 0) + (gate_index, 0) } - Some((idx, next_copy)) => (self.gate_instances[idx], idx, next_copy), + Some((idx, next_copy)) => (idx, next_copy), }; let num_copies = @@ -105,7 +106,7 @@ impl, const D: usize> CircuitBuilder { let switch = Target::wire( gate_index, - SwitchGate::::wire_switch_bool(gate.num_copies, next_copy), + SwitchGate::::wire_switch_bool(num_copies, next_copy), ); let c_arr: [Target; CHUNK_SIZE] = c.try_into().unwrap(); @@ -144,13 +145,17 @@ impl, const D: usize> CircuitBuilder { a_num_switches }; + let mut a_switches = Vec::new(); + let mut b_switches = Vec::new(); for i in 0..a_num_switches { - let (a_switch, out_1, out_2) = self.create_switch(a[i * 2], a[i * 2 + 1]); + let (switch, out_1, out_2) = self.create_switch(a[i * 2], a[i * 2 + 1]); + a_switches.push(switch); child_1_a.push(out_1); child_2_a.push(out_2); } for i in 0..b_num_switches { - let (b_switch, out_1, out_2) = self.create_switch(b[i * 2], b[i * 2 + 1]); + let (switch, out_1, out_2) = self.create_switch(b[i * 2], b[i * 2 + 1]); + b_switches.push(switch); child_1_b.push(out_1); child_2_b.push(out_2); } @@ -167,15 +172,21 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(child_1_a, child_1_b); self.assert_permutation(child_2_a, child_2_b); - self.add_generator(PermutationGenerator {}); + self.add_generator(PermutationGenerator:: { + a, + b, + a_switches, + b_switches, + _phantom: PhantomData, + }); } } fn route( a_values: Vec<[F; CHUNK_SIZE]>, b_values: Vec<[F; CHUNK_SIZE]>, - a_switches: Vec<[Target; CHUNK_SIZE]>, - b_switches: Vec<[Target; CHUNK_SIZE]>, + a_switches: Vec, + b_switches: Vec, witness: &PartialWitness, out_buffer: &mut GeneratedValues, ) { @@ -219,7 +230,7 @@ fn route( return; } - if witness.contains_all(&switches[other_side][other_switch_i]) { + if witness.contains(switches[other_side][other_switch_i]) { // The other switch has already been routed. return; } @@ -252,9 +263,7 @@ fn route( switch_index: usize, swap: bool| { // First, we actually set the switch configuration. - for e in 0..CHUNK_SIZE { - out_buffer.set_target(switches[side][switch_index][e], F::from_bool(swap)); - } + out_buffer.set_target(switches[side][switch_index], F::from_bool(swap)); // Then, we enqueue the two corresponding wires on the other side of the network, to ensure // that they get routed in the next step. @@ -292,7 +301,7 @@ fn route( } else { // We can route any switch next. Continue our scan for pending switches. while scan_index[side] < switches[side].len() - && witness.contains_all(&switches[side][scan_index[side]]) + && witness.contains(switches[side][scan_index[side]]) { scan_index[side] += 1; } @@ -313,21 +322,22 @@ fn route( } struct PermutationGenerator { - a_wires: Vec<[Target; CHUNK_SIZE]>, - b_wires: Vec<[Target; CHUNK_SIZE]>, + a: Vec<[Target; CHUNK_SIZE]>, + b: Vec<[Target; CHUNK_SIZE]>, + a_switches: Vec, + b_switches: Vec, + _phantom: PhantomData, } impl SimpleGenerator for PermutationGenerator { fn dependencies(&self) -> Vec { - self.a_wires - .iter() - .map(|arr| arr.to_vec()) - .flatten() - .collect() + let mut deps = self.a_switches.clone(); + deps.extend(self.b_switches.clone()); + deps } fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { - let wire_chunk_to_vals = |wire| { + let wire_chunk_to_vals = |wire: [Target; CHUNK_SIZE]| { let mut vals = [F::ZERO; CHUNK_SIZE]; for e in 0..CHUNK_SIZE { vals[e] = witness.get_target(wire[e]); @@ -335,13 +345,13 @@ impl SimpleGenerator for PermutationGenera vals }; - let a_values = self.a_wires.iter().map(wire_chunk_to_vals).collect(); - let b_values = self.b_wires.iter().map(wire_chunk_to_vals).collect(); + let a_values = self.a.iter().map(|chunk| wire_chunk_to_vals(*chunk)).collect(); + let b_values = self.b.iter().map(|chunk| wire_chunk_to_vals(*chunk)).collect(); route( - a_values.clone(), - b_values.clone(), - self.a_wires.clone(), - self.b_wires.clone(), + a_values, + b_values, + self.a_switches.clone(), + self.b_switches.clone(), witness, out_buffer, );