From 10d016a92c2fb30650a559417e5462a89f2c39f6 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 1 Sep 2021 16:38:10 -0700 Subject: [PATCH] chunk size as field --- src/gadgets/permutation.rs | 224 ++++++++++++++--------------------- src/gates/switch.rs | 207 +++++++++++++++----------------- src/plonk/circuit_builder.rs | 29 ++++- 3 files changed, 212 insertions(+), 248 deletions(-) diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 4e685d2f..3016c2a1 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -12,27 +12,28 @@ use crate::util::bimap::bimap_from_lists; impl, const D: usize> CircuitBuilder { /// Assert that two lists of expressions evaluate to permutations of one another. - pub fn assert_permutation( - &mut self, - a: Vec<[Target; CHUNK_SIZE]>, - b: Vec<[Target; CHUNK_SIZE]>, - ) { + pub fn assert_permutation(&mut self, a: Vec>, b: Vec>) { assert_eq!( a.len(), b.len(), "Permutation must have same number of inputs and outputs" ); + assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same"); + + let chunk_size = a[0].len(); match a.len() { // Two empty lists are permutations of one another, trivially. 0 => (), // Two singleton lists are permutations of one another as long as their items are equal. 1 => { - for e in 0..CHUNK_SIZE { + for e in 0..chunk_size { self.assert_equal(a[0][e], b[0][e]) } } - 2 => self.assert_permutation_2x2(a[0], a[1], b[0], b[1]), + 2 => { + self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone()) + } // For larger lists, we recursively use two smaller permutation networks. //_ => self.assert_permutation_recursive(a, b) _ => self.assert_permutation_recursive(a, b), @@ -40,103 +41,98 @@ impl, const D: usize> CircuitBuilder { } /// Assert that [a1, a2] is a permutation of [b1, b2]. - fn assert_permutation_2x2( + fn assert_permutation_2x2( &mut self, - a1: [Target; CHUNK_SIZE], - a2: [Target; CHUNK_SIZE], - b1: [Target; CHUNK_SIZE], - b2: [Target; CHUNK_SIZE], + a1: Vec, + a2: Vec, + b1: Vec, + b2: Vec, ) { + assert!( + a1.len() == a2.len() && a2.len() == b1.len() && b1.len() == b2.len(), + "Chunk size must be the same" + ); + + let chunk_size = a1.len(); + let (switch, gate_out1, gate_out2) = self.create_switch(a1, a2); - for e in 0..CHUNK_SIZE { + for e in 0..chunk_size { self.route(b1[e], gate_out1[e]); self.route(b2[e], gate_out2[e]); } - - self.add_generator(TwoByTwoPermutationGenerator:: { - a1, - a2, - b1, - b2, - switch, - _phantom: PhantomData, - }); } - fn create_switch( + fn create_switch( &mut self, - a1: [Target; CHUNK_SIZE], - a2: [Target; CHUNK_SIZE], - ) -> (Target, [Target; CHUNK_SIZE], [Target; CHUNK_SIZE]) { - if self.current_switch_gates.len() < CHUNK_SIZE { + a1: Vec, + a2: Vec, + ) -> (Target, Vec, Vec) { + assert!(a1.len() == a2.len(), "Chunk size must be the same"); + + let chunk_size = a1.len(); + + if self.current_switch_gates.len() < chunk_size { self.current_switch_gates - .extend(vec![None; CHUNK_SIZE - self.current_switch_gates.len()]); + .extend(vec![None; chunk_size - self.current_switch_gates.len()]); } - let (gate_index, mut next_copy) = match self.current_switch_gates[CHUNK_SIZE - 1] { - None => { - let gate = SwitchGate::::new_from_config(self.config.clone()); - let gate_index = self.add_gate(gate.clone(), vec![]); - (gate_index, 0) - } - Some((idx, next_copy)) => (idx, next_copy), - }; + let (gate, gate_index, mut next_copy) = + match self.current_switch_gates[chunk_size - 1].clone() { + None => { + let gate = SwitchGate::::new_from_config(self.config.clone(), chunk_size); + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate, gate_index, 0) + } + Some((gate, idx, next_copy)) => (gate, idx, next_copy), + }; - let num_copies = - SwitchGate::::max_num_copies(self.config.num_routed_wires); + let num_copies = gate.num_copies; let mut c = Vec::new(); let mut d = Vec::new(); - for e in 0..CHUNK_SIZE { + for e in 0..chunk_size { self.route( a1[e], - Target::wire( - gate_index, - SwitchGate::::wire_first_input(next_copy, e), - ), + Target::wire(gate_index, gate.wire_first_input(next_copy, e)), ); self.route( a2[e], - Target::wire( - gate_index, - SwitchGate::::wire_second_input(next_copy, e), - ), + Target::wire(gate_index, gate.wire_second_input(next_copy, e)), ); c.push(Target::wire( gate_index, - SwitchGate::::wire_first_output(next_copy, e), + gate.wire_first_output(next_copy, e), )); d.push(Target::wire( gate_index, - SwitchGate::::wire_second_output(next_copy, e), + gate.wire_second_output(next_copy, e), )); } - let switch = Target::wire( - gate_index, - SwitchGate::::wire_switch_bool(num_copies, next_copy), - ); - - let c_arr: [Target; CHUNK_SIZE] = c.try_into().unwrap(); - let d_arr: [Target; CHUNK_SIZE] = d.try_into().unwrap(); + let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy)); next_copy += 1; if next_copy == num_copies { - let new_gate = SwitchGate::::new_from_config(self.config.clone()); + let new_gate = SwitchGate::::new_from_config(self.config.clone(), chunk_size); let new_gate_index = self.add_gate(new_gate.clone(), vec![]); - self.current_switch_gates[CHUNK_SIZE - 1] = Some((new_gate_index, 0)); + self.current_switch_gates[chunk_size - 1] = Some((new_gate, new_gate_index, 0)); } else { - self.current_switch_gates[CHUNK_SIZE - 1] = Some((gate_index, next_copy)); + self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy)); } - (switch, c_arr, d_arr) + (switch, c, d) } - fn assert_permutation_recursive( - &mut self, - a: Vec<[Target; CHUNK_SIZE]>, - b: Vec<[Target; CHUNK_SIZE]>, - ) { + fn assert_permutation_recursive(&mut self, a: Vec>, b: Vec>) { + assert_eq!( + a.len(), + b.len(), + "Permutation must have same number of inputs and outputs" + ); + assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same"); + + let chunk_size = a[0].len(); + let n = a.len(); let even = n % 2 == 0; @@ -156,13 +152,13 @@ impl, const D: usize> CircuitBuilder { let mut a_switches = Vec::new(); let mut b_switches = Vec::new(); for i in 0..a_num_switches { - let (switch, out_1, out_2) = self.create_switch(a[i * 2], a[i * 2 + 1]); + let (switch, out_1, out_2) = self.create_switch(a[i * 2].clone(), a[i * 2 + 1].clone()); a_switches.push(switch); child_1_a.push(out_1); child_2_a.push(out_2); } for i in 0..b_num_switches { - let (switch, out_1, out_2) = self.create_switch(b[i * 2], b[i * 2 + 1]); + let (switch, out_1, out_2) = self.create_switch(b[i * 2].clone(), b[i * 2 + 1].clone()); b_switches.push(switch); child_1_b.push(out_1); child_2_b.push(out_2); @@ -180,7 +176,8 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(child_1_a, child_1_b); self.assert_permutation(child_2_a, child_2_b); - self.add_generator(PermutationGenerator:: { + self.add_generator(PermutationGenerator:: { + chunk_size, a, b, a_switches, @@ -190,9 +187,9 @@ impl, const D: usize> CircuitBuilder { } } -fn route( - a_values: Vec<[F; CHUNK_SIZE]>, - b_values: Vec<[F; CHUNK_SIZE]>, +fn route( + a_values: Vec>, + b_values: Vec>, a_switches: Vec, b_switches: Vec, witness: &PartialWitness, @@ -354,57 +351,16 @@ fn route( } } } - -struct TwoByTwoPermutationGenerator { - a1: [Target; CHUNK_SIZE], - a2: [Target; CHUNK_SIZE], - b1: [Target; CHUNK_SIZE], - b2: [Target; CHUNK_SIZE], - switch: Target, - _phantom: PhantomData, -} - -impl SimpleGenerator - for TwoByTwoPermutationGenerator -{ - fn dependencies(&self) -> Vec { - [self.a1, self.a2, self.b1, self.b2] - .iter() - .map(|arr| arr.to_vec()) - .flatten() - .collect() - } - - fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { - let a1_values: Vec<_> = self.a1.iter().map(|x| witness.get_target(*x)).collect(); - let a2_values: Vec<_> = self.a2.iter().map(|x| witness.get_target(*x)).collect(); - let b1_values: Vec<_> = self.b1.iter().map(|x| witness.get_target(*x)).collect(); - let b2_values: Vec<_> = self.b2.iter().map(|x| witness.get_target(*x)).collect(); - - let no_switch = a1_values.iter().zip(b1_values.iter()).all(|(a, b)| a == b) - && a2_values.iter().zip(b2_values.iter()).all(|(a, b)| a == b); - let switch = a1_values.iter().zip(b2_values.iter()).all(|(a, b)| a == b) - && a2_values.iter().zip(b1_values.iter()).all(|(a, b)| a == b); - - if no_switch { - out_buffer.set_target(self.switch, F::ZERO); - } else if switch { - out_buffer.set_target(self.switch, F::ONE); - } else { - panic!("No permutation"); - } - } -} - -struct PermutationGenerator { - a: Vec<[Target; CHUNK_SIZE]>, - b: Vec<[Target; CHUNK_SIZE]>, +struct PermutationGenerator { + chunk_size: usize, + a: Vec>, + b: Vec>, a_switches: Vec, b_switches: Vec, _phantom: PhantomData, } -impl SimpleGenerator for PermutationGenerator { +impl SimpleGenerator for PermutationGenerator { fn dependencies(&self) -> Vec { self.a .clone() @@ -415,23 +371,15 @@ impl SimpleGenerator for PermutationGenera } fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { - let wire_chunk_to_vals = |wire: [Target; CHUNK_SIZE]| { - let mut vals = [F::ZERO; CHUNK_SIZE]; - for e in 0..CHUNK_SIZE { - vals[e] = witness.get_target(wire[e]); - } - vals - }; - let a_values = self .a .iter() - .map(|chunk| wire_chunk_to_vals(*chunk)) + .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) .collect(); let b_values = self .b .iter() - .map(|chunk| wire_chunk_to_vals(*chunk)) + .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) .collect(); route( a_values, @@ -450,7 +398,6 @@ mod tests { use super::*; use crate::field::crandall_field::CrandallField; - use crate::field::extension_field::quartic::QuarticCrandallField; use crate::field::field_types::Field; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; @@ -468,10 +415,10 @@ mod tests { let seven = F::from_canonical_usize(7); let eight = F::from_canonical_usize(8); - let one_two = [builder.constant(one), builder.constant(two)]; - let seven_eight = [builder.constant(seven), builder.constant(eight)]; + let one_two = vec![builder.constant(one), builder.constant(two)]; + let seven_eight = vec![builder.constant(seven), builder.constant(eight)]; - let a = vec![one_two, seven_eight]; + let a = vec![one_two.clone(), seven_eight.clone()]; let b = vec![seven_eight, one_two]; builder.assert_permutation(a, b); @@ -498,12 +445,17 @@ mod tests { let seven = F::from_canonical_usize(7); let eight = F::from_canonical_usize(8); - let one_two = [builder.constant(one), builder.constant(two)]; - let three_four = [builder.constant(three), builder.constant(four)]; - let five_six = [builder.constant(five), builder.constant(six)]; - let seven_eight = [builder.constant(seven), builder.constant(eight)]; + let one_two = vec![builder.constant(one), builder.constant(two)]; + let three_four = vec![builder.constant(three), builder.constant(four)]; + let five_six = vec![builder.constant(five), builder.constant(six)]; + let seven_eight = vec![builder.constant(seven), builder.constant(eight)]; - let a = vec![one_two, three_four, five_six, seven_eight]; + let a = vec![ + one_two.clone(), + three_four.clone(), + five_six.clone(), + seven_eight.clone(), + ]; let b = vec![seven_eight, one_two, five_six, three_four]; builder.assert_permutation(a, b); diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 0939fc08..c854e49a 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -14,73 +14,73 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// A gate for conditionally swapping input values based on a boolean. #[derive(Clone, Debug)] -pub(crate) struct SwitchGate, const D: usize, const CHUNK_SIZE: usize> { - num_copies: usize, +pub(crate) struct SwitchGate, const D: usize> { + pub(crate) chunk_size: usize, + pub(crate) num_copies: usize, _phantom: PhantomData, } -impl, const D: usize, const CHUNK_SIZE: usize> SwitchGate { - pub fn new(num_copies: usize) -> Self { +impl, const D: usize> SwitchGate { + pub fn new(num_copies: usize, chunk_size: usize) -> Self { Self { + chunk_size, num_copies, _phantom: PhantomData, } } - pub fn new_from_config(config: CircuitConfig) -> Self { - let num_copies = Self::max_num_copies(config.num_routed_wires); - Self::new(num_copies) + pub fn new_from_config(config: CircuitConfig, chunk_size: usize) -> Self { + let num_copies = Self::max_num_copies(config.num_routed_wires, chunk_size); + Self::new(num_copies, chunk_size) } - pub fn max_num_copies(num_routed_wires: usize) -> usize { - num_routed_wires / (4 * CHUNK_SIZE) + pub fn max_num_copies(num_routed_wires: usize, chunk_size: usize) -> usize { + num_routed_wires / (4 * chunk_size) } - pub fn wire_first_input(copy: usize, element: usize) -> usize { - debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE) + element + pub fn wire_first_input(&self, copy: usize, element: usize) -> usize { + debug_assert!(element < self.chunk_size); + copy * (4 * self.chunk_size) + element } - pub fn wire_second_input(copy: usize, element: usize) -> usize { - debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE) + CHUNK_SIZE + element + pub fn wire_second_input(&self, copy: usize, element: usize) -> usize { + debug_assert!(element < self.chunk_size); + copy * (4 * self.chunk_size) + self.chunk_size + element } - pub fn wire_first_output(copy: usize, element: usize) -> usize { - debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE) + 2 * CHUNK_SIZE + element + pub fn wire_first_output(&self, copy: usize, element: usize) -> usize { + debug_assert!(element < self.chunk_size); + copy * (4 * self.chunk_size) + 2 * self.chunk_size + element } - pub fn wire_second_output(copy: usize, element: usize) -> usize { - debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE) + 3 * CHUNK_SIZE + element + pub fn wire_second_output(&self, copy: usize, element: usize) -> usize { + debug_assert!(element < self.chunk_size); + copy * (4 * self.chunk_size) + 3 * self.chunk_size + element } - pub fn wire_switch_bool(num_copies: usize, copy: usize) -> usize { - debug_assert!(copy < num_copies); - num_copies * (4 * CHUNK_SIZE) + copy + pub fn wire_switch_bool(&self, copy: usize) -> usize { + debug_assert!(copy < self.num_copies); + self.num_copies * (4 * self.chunk_size) + copy } } -impl, const D: usize, const CHUNK_SIZE: usize> Gate - for SwitchGate -{ +impl, const D: usize> Gate for SwitchGate { fn id(&self) -> String { - format!("{:?}", self, D, CHUNK_SIZE) + format!("{:?}", self, D) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); for c in 0..self.num_copies { - let switch_bool = vars.local_wires[Self::wire_switch_bool(self.num_copies, c)]; + let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; let not_switch = F::Extension::ONE - switch_bool; - for e in 0..CHUNK_SIZE { - let first_input = vars.local_wires[Self::wire_first_input(c, e)]; - let second_input = vars.local_wires[Self::wire_second_input(c, e)]; - let first_output = vars.local_wires[Self::wire_first_output(c, e)]; - let second_output = vars.local_wires[Self::wire_second_output(c, e)]; + for e in 0..self.chunk_size { + let first_input = vars.local_wires[self.wire_first_input(c, e)]; + let second_input = vars.local_wires[self.wire_second_input(c, e)]; + let first_output = vars.local_wires[self.wire_first_output(c, e)]; + let second_output = vars.local_wires[self.wire_second_output(c, e)]; constraints.push(switch_bool * (first_input - second_output)); constraints.push(switch_bool * (second_input - first_output)); @@ -96,14 +96,14 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gate let mut constraints = Vec::with_capacity(self.num_constraints()); for c in 0..self.num_copies { - let switch_bool = vars.local_wires[Self::wire_switch_bool(self.num_copies, c)]; + let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; let not_switch = F::ONE - switch_bool; - for e in 0..CHUNK_SIZE { - let first_input = vars.local_wires[Self::wire_first_input(c, e)]; - let second_input = vars.local_wires[Self::wire_second_input(c, e)]; - let first_output = vars.local_wires[Self::wire_first_output(c, e)]; - let second_output = vars.local_wires[Self::wire_second_output(c, e)]; + for e in 0..self.chunk_size { + let first_input = vars.local_wires[self.wire_first_input(c, e)]; + let second_input = vars.local_wires[self.wire_second_input(c, e)]; + let first_output = vars.local_wires[self.wire_first_output(c, e)]; + let second_output = vars.local_wires[self.wire_second_output(c, e)]; constraints.push(switch_bool * (first_input - second_output)); constraints.push(switch_bool * (second_input - first_output)); @@ -124,14 +124,14 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gate let one = builder.one_extension(); for c in 0..self.num_copies { - let switch_bool = vars.local_wires[Self::wire_switch_bool(self.num_copies, c)]; + let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; let not_switch = builder.sub_extension(one, switch_bool); - for e in 0..CHUNK_SIZE { - let first_input = vars.local_wires[Self::wire_first_input(c, e)]; - let second_input = vars.local_wires[Self::wire_second_input(c, e)]; - let first_output = vars.local_wires[Self::wire_first_output(c, e)]; - let second_output = vars.local_wires[Self::wire_second_output(c, e)]; + for e in 0..self.chunk_size { + let first_input = vars.local_wires[self.wire_first_input(c, e)]; + let second_input = vars.local_wires[self.wire_second_input(c, e)]; + let first_output = vars.local_wires[self.wire_first_output(c, e)]; + let second_output = vars.local_wires[self.wire_second_output(c, e)]; let first_switched = builder.sub_extension(first_input, second_output); let first_switched_constraint = builder.mul_extension(switch_bool, first_switched); @@ -164,19 +164,18 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gate ) -> Vec>> { (0..self.num_copies) .map(|c| { - let g: Box> = - Box::new(SwitchGenerator:: { - gate_index, - gate: self.clone(), - copy: c, - }); + let g: Box> = Box::new(SwitchGenerator:: { + gate_index, + gate: self.clone(), + copy: c, + }); g }) .collect() } fn num_wires(&self) -> usize { - Self::wire_switch_bool(self.num_copies, self.num_copies - 1) + 1 + self.wire_switch_bool(self.num_copies - 1) + 1 } fn num_constants(&self) -> usize { @@ -188,34 +187,26 @@ impl, const D: usize, const CHUNK_SIZE: usize> Gate } fn num_constraints(&self) -> usize { - 4 * self.num_copies * CHUNK_SIZE + 4 * self.num_copies * self.chunk_size } } #[derive(Debug)] -struct SwitchGenerator, const D: usize, const CHUNK_SIZE: usize> { +struct SwitchGenerator, const D: usize> { gate_index: usize, - gate: SwitchGate, + gate: SwitchGate, copy: usize, } -impl, const D: usize, const CHUNK_SIZE: usize> SimpleGenerator - for SwitchGenerator -{ +impl, const D: usize> SimpleGenerator for SwitchGenerator { fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); let mut deps = Vec::new(); - for e in 0..CHUNK_SIZE { - deps.push(local_target( - SwitchGate::::wire_first_input(self.copy, e), - )); - deps.push(local_target( - SwitchGate::::wire_second_input(self.copy, e), - )); - deps.push(local_target( - SwitchGate::::wire_switch_bool(self.gate.num_copies, self.copy), - )); + for e in 0..self.gate.chunk_size { + deps.push(local_target(self.gate.wire_first_input(self.copy, e))); + deps.push(local_target(self.gate.wire_second_input(self.copy, e))); + deps.push(local_target(self.gate.wire_switch_bool(self.copy))); } deps @@ -229,23 +220,12 @@ impl, const D: usize, const CHUNK_SIZE: usize> SimpleGenerator< let get_local_wire = |input| witness.get_wire(local_wire(input)); - for e in 0..CHUNK_SIZE { - let first_output_wire = local_wire(SwitchGate::::wire_first_output( - self.copy, e, - )); - let second_output_wire = local_wire( - SwitchGate::::wire_second_output(self.copy, e), - ); - let first_input = get_local_wire(SwitchGate::::wire_first_input( - self.copy, e, - )); - let second_input = get_local_wire(SwitchGate::::wire_second_input( - self.copy, e, - )); - let switch_bool = get_local_wire(SwitchGate::::wire_switch_bool( - self.gate.num_copies, - self.copy, - )); + for e in 0..self.gate.chunk_size { + let first_output_wire = local_wire(self.gate.wire_first_output(self.copy, e)); + let second_output_wire = local_wire(self.gate.wire_second_output(self.copy, e)); + let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e)); + let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e)); + let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy)); let (first_output, second_output) = if switch_bool == F::ZERO { (first_input, second_input) @@ -277,40 +257,44 @@ mod tests { #[test] fn wire_indices() { - type SG = SwitchGate; + type SG = SwitchGate; let num_copies = 3; + let chunk_size = 3; let gate = SG { + chunk_size, num_copies, _phantom: PhantomData, }; - assert_eq!(SG::wire_first_input(0, 0), 0); - assert_eq!(SG::wire_first_input(0, 2), 2); - assert_eq!(SG::wire_second_input(0, 0), 3); - assert_eq!(SG::wire_second_input(0, 2), 5); - assert_eq!(SG::wire_first_output(0, 0), 6); - assert_eq!(SG::wire_second_output(0, 2), 11); - assert_eq!(SG::wire_first_input(1, 0), 12); - assert_eq!(SG::wire_second_output(1, 2), 23); - assert_eq!(SG::wire_first_input(2, 0), 24); - assert_eq!(SG::wire_second_output(2, 2), 35); - assert_eq!(SG::wire_switch_bool(num_copies, 0), 36); - assert_eq!(SG::wire_switch_bool(num_copies, 1), 37); - assert_eq!(SG::wire_switch_bool(num_copies, 2), 38); + assert_eq!(gate.wire_first_input(0, 0), 0); + assert_eq!(gate.wire_first_input(0, 2), 2); + assert_eq!(gate.wire_second_input(0, 0), 3); + assert_eq!(gate.wire_second_input(0, 2), 5); + assert_eq!(gate.wire_first_output(0, 0), 6); + assert_eq!(gate.wire_second_output(0, 2), 11); + assert_eq!(gate.wire_first_input(1, 0), 12); + assert_eq!(gate.wire_second_output(1, 2), 23); + assert_eq!(gate.wire_first_input(2, 0), 24); + assert_eq!(gate.wire_second_output(2, 2), 35); + assert_eq!(gate.wire_switch_bool(0), 36); + assert_eq!(gate.wire_switch_bool(1), 37); + assert_eq!(gate.wire_switch_bool(2), 38); } #[test] fn low_degree() { - test_low_degree::(SwitchGate::<_, 4, 3>::new_from_config( + test_low_degree::(SwitchGate::<_, 4>::new_from_config( CircuitConfig::large_config(), + 3, )); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(SwitchGate::<_, 4, 3>::new_from_config( + test_eval_fns::(SwitchGate::<_, 4>::new_from_config( CircuitConfig::large_config(), + 3, )) } @@ -319,7 +303,7 @@ mod tests { type F = CrandallField; type FF = QuarticCrandallField; const D: usize = 4; - const CHUNK_SIZE: usize = 4; + const chunk_size: usize = 4; let num_copies = 3; /// Returns the local wires for a switch gate given the inputs and the switch booleans. @@ -336,11 +320,11 @@ mod tests { let switch = switch_bools[c]; switches.push(F::from_bool(switch)); - let mut first_input_chunk = Vec::with_capacity(CHUNK_SIZE); - let mut second_input_chunk = Vec::with_capacity(CHUNK_SIZE); - let mut first_output_chunk = Vec::with_capacity(CHUNK_SIZE); - let mut second_output_chunk = Vec::with_capacity(CHUNK_SIZE); - for e in 0..CHUNK_SIZE { + let mut first_input_chunk = Vec::with_capacity(chunk_size); + let mut second_input_chunk = Vec::with_capacity(chunk_size); + let mut first_output_chunk = Vec::with_capacity(chunk_size); + let mut second_output_chunk = Vec::with_capacity(chunk_size); + for e in 0..chunk_size { let first_input = first_inputs[c][e]; let second_input = second_inputs[c][e]; let first_output = if switch { second_input } else { first_input }; @@ -360,11 +344,12 @@ mod tests { v.iter().map(|&x| x.into()).collect::>() } - let first_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect(); - let second_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect(); + let first_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(chunk_size)).collect(); + let second_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(chunk_size)).collect(); let switch_bools = vec![true, false, true]; - let gate = SwitchGate:: { + let gate = SwitchGate:: { + chunk_size, num_copies, _phantom: PhantomData, }; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 0daa4943..ed19a89f 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -14,6 +14,7 @@ use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; +use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::hash_n_to_hash; use crate::iop::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; @@ -66,7 +67,10 @@ pub struct CircuitBuilder, const D: usize> { /// these constants with gate index `g` and already using `i` arithmetic operations. pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, - pub(crate) current_switch_gates: Vec>, + // `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value + // chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies + // of switches + pub(crate) current_switch_gates: Vec, usize, usize)>>, } impl, const D: usize> CircuitBuilder { @@ -509,6 +513,29 @@ impl, const D: usize> CircuitBuilder { } } + /// Fill the remaining unused switch gates with dummy values, so that all + /// `SwitchGenerator` are run. + fn fill_switch_gates(&mut self) { + let zero = self.zero(); + + for chunk_size in 1..=self.current_switch_gates.len() { + if let Some((gate, gate_index, copy)) = + self.current_switch_gates[chunk_size - 1].clone() + { + for element in 0..chunk_size { + let wire_first_input = + Target::wire(gate_index, gate.wire_first_input(copy, element)); + let wire_second_input = + Target::wire(gate_index, gate.wire_second_input(copy, element)); + let wire_switch_bool = Target::wire(gate_index, gate.wire_switch_bool(copy)); + self.route(zero, wire_first_input); + self.route(zero, wire_second_input); + self.route(zero, wire_switch_bool); + } + } + } + } + pub fn print_gate_counts(&self, min_delta: usize) { self.context_log .filter(self.num_gates(), min_delta)