diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 740d73dd..06e29376 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -40,25 +40,34 @@ impl, const D: usize> CircuitBuilder { } } - /// Assert that [a, b] is a permutation of [c, d]. + /// Assert that [a1, a2] is a permutation of [b1, b2]. fn assert_permutation_2x2( &mut self, - a: [Target; CHUNK_SIZE], - b: [Target; CHUNK_SIZE], - c: [Target; CHUNK_SIZE], - d: [Target; CHUNK_SIZE], + a1: [Target; CHUNK_SIZE], + a2: [Target; CHUNK_SIZE], + b1: [Target; CHUNK_SIZE], + b2: [Target; CHUNK_SIZE], ) { - let (_, gate_c, gate_d) = self.create_switch(a, b); + let (switch, gate_c, gate_d) = self.create_switch(a1, a2); for e in 0..CHUNK_SIZE { - self.route(c[e], gate_c[e]); - self.route(d[e], gate_d[e]); + self.route(b1[e], gate_c[e]); + self.route(b2[e], gate_d[e]); } + + self.add_generator(TwoByTwoPermutationGenerator:: { + a1, + a2, + b1, + b2, + switch, + _phantom: PhantomData, + }); } fn create_switch( &mut self, - a: [Target; CHUNK_SIZE], - b: [Target; CHUNK_SIZE], + a1: [Target; CHUNK_SIZE], + a2: [Target; CHUNK_SIZE], ) -> (Target, [Target; CHUNK_SIZE], [Target; CHUNK_SIZE]) { if self.current_switch_gates.len() < CHUNK_SIZE { self.current_switch_gates @@ -81,14 +90,14 @@ impl, const D: usize> CircuitBuilder { let mut d = Vec::new(); for e in 0..CHUNK_SIZE { self.route( - a[e], + a1[e], Target::wire( gate_index, SwitchGate::::wire_first_input(next_copy, e), ), ); self.route( - b[e], + a2[e], Target::wire( gate_index, SwitchGate::::wire_second_input(next_copy, e), @@ -321,6 +330,68 @@ fn route( } } +struct TwoByTwoPermutationGenerator { + a1: [Target; CHUNK_SIZE], + a2: [Target; CHUNK_SIZE], + b1: [Target; CHUNK_SIZE], + b2: [Target; CHUNK_SIZE], + switch: Target, + _phantom: PhantomData, +} + +impl SimpleGenerator + for TwoByTwoPermutationGenerator +{ + fn dependencies(&self) -> Vec { + [self.a1, self.a2, self.b1, self.b2] + .to_vec() + .iter() + .map(|arr| arr.to_vec()) + .flatten() + .collect() + } + + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { + let a1_values: Vec<_> = self + .a1 + .to_vec() + .iter() + .map(|x| witness.get_target(*x)) + .collect(); + let a2_values: Vec<_> = self + .a2 + .to_vec() + .iter() + .map(|x| witness.get_target(*x)) + .collect(); + let b1_values: Vec<_> = self + .b1 + .to_vec() + .iter() + .map(|x| witness.get_target(*x)) + .collect(); + let b2_values: Vec<_> = self + .b2 + .to_vec() + .iter() + .map(|x| witness.get_target(*x)) + .collect(); + + let no_switch = a1_values.iter().zip(b1_values.iter()).all(|(a, b)| a == b) + && a2_values.iter().zip(b2_values.iter()).all(|(a, b)| a == b); + let switch = a1_values.iter().zip(b2_values.iter()).all(|(a, b)| a == b) + && a2_values.iter().zip(b1_values.iter()).all(|(a, b)| a == b); + + if no_switch { + out_buffer.set_target(self.switch, F::ZERO); + } else if switch { + out_buffer.set_target(self.switch, F::ONE); + } else { + panic!("No permutation"); + } + } +} + struct PermutationGenerator { a: Vec<[Target; CHUNK_SIZE]>, b: Vec<[Target; CHUNK_SIZE]>, @@ -331,9 +402,12 @@ struct PermutationGenerator { impl SimpleGenerator for PermutationGenerator { fn dependencies(&self) -> Vec { - let mut deps = self.a_switches.clone(); - deps.extend(self.b_switches.clone()); - deps + self.a + .iter() + .map(|arr| arr.to_vec()) + .flatten() + .chain(self.b.iter().map(|arr| arr.to_vec()).flatten()) + .collect() } fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) {