diff --git a/insertion/src/insert_gadget.rs b/insertion/src/insert_gadget.rs index 9f8aa4bf..76cb4410 100644 --- a/insertion/src/insert_gadget.rs +++ b/insertion/src/insert_gadget.rs @@ -27,7 +27,7 @@ impl, const D: usize> CircuitBuilderInsert v: Vec>, ) -> Vec> { let gate = InsertionGate::new(v.len()); - let gate_index = self.add_gate(gate.clone(), vec![]); + let gate_index = self.add_gate(gate.clone(), vec![], vec![]); v.iter().enumerate().for_each(|(i, &val)| { self.connect_extension( diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index c00220c3..734e2705 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -79,7 +79,8 @@ impl, const D: usize> CircuitBuilder { fn add_base_arithmetic_operation(&mut self, operation: BaseArithmeticOperation) -> Target { let gate = ArithmeticGate::new_from_config(&self.config); - let (gate, i) = self.find_slot(gate, vec![operation.const_0, operation.const_1]); + let constants = vec![operation.const_0, operation.const_1]; + let (gate, i) = self.find_slot(gate, &constants, &constants); let wires_multiplicand_0 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_0(i)); let wires_multiplicand_1 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_1(i)); let wires_addend = Target::wire(gate, ArithmeticGate::wire_ith_addend(i)); @@ -240,7 +241,7 @@ impl, const D: usize> CircuitBuilder { while exp_bits_vec.len() < num_power_bits { exp_bits_vec.push(_false); } - let gate_index = self.add_gate(gate.clone(), vec![]); + let gate_index = self.add_gate(gate.clone(), vec![], vec![]); self.connect(base, Target::wire(gate_index, gate.wire_base())); exp_bits_vec.iter().enumerate().for_each(|(i, bit)| { diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index e1a17988..d29f8f8f 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -61,7 +61,8 @@ impl, const D: usize> CircuitBuilder { operation: ExtensionArithmeticOperation, ) -> ExtensionTarget { let gate = ArithmeticExtensionGate::new_from_config(&self.config); - let (gate, i) = self.find_slot(gate, vec![operation.const_0, operation.const_1]); + let constants = vec![operation.const_0, operation.const_1]; + let (gate, i) = self.find_slot(gate, &constants, &constants); let wires_multiplicand_0 = ExtensionTarget::from_range( gate, ArithmeticExtensionGate::::wires_ith_multiplicand_0(i), @@ -85,7 +86,8 @@ impl, const D: usize> CircuitBuilder { operation: ExtensionArithmeticOperation, ) -> ExtensionTarget { let gate = MulExtensionGate::new_from_config(&self.config); - let (gate, i) = self.find_slot(gate, vec![operation.const_0]); + let constants = vec![operation.const_0]; + let (gate, i) = self.find_slot(gate, &constants, &constants); let wires_multiplicand_0 = ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_0(i)); let wires_multiplicand_1 = diff --git a/plonky2/src/gadgets/arithmetic_u32.rs b/plonky2/src/gadgets/arithmetic_u32.rs index 1c8df3b6..dfdbb5fb 100644 --- a/plonky2/src/gadgets/arithmetic_u32.rs +++ b/plonky2/src/gadgets/arithmetic_u32.rs @@ -78,7 +78,7 @@ impl, const D: usize> CircuitBuilder { } let gate = U32ArithmeticGate::::new_from_config(&self.config); - let (gate_index, copy) = self.find_slot(gate, vec![]); + let (gate_index, copy) = self.find_slot(gate, &[], &[]); self.connect( Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)), @@ -138,7 +138,7 @@ impl, const D: usize> CircuitBuilder { borrow: U32Target, ) -> (U32Target, U32Target) { let gate = U32SubtractionGate::::new_from_config(&self.config); - let (gate_index, copy) = self.find_slot(gate, vec![]); + let (gate_index, copy) = self.find_slot(gate, &[], &[]); self.connect(Target::wire(gate_index, gate.wire_ith_input_x(copy)), x.0); self.connect(Target::wire(gate_index, gate.wire_ith_input_y(copy)), y.0); diff --git a/plonky2/src/gadgets/interpolation.rs b/plonky2/src/gadgets/interpolation.rs index 2d2c2273..473361aa 100644 --- a/plonky2/src/gadgets/interpolation.rs +++ b/plonky2/src/gadgets/interpolation.rs @@ -88,7 +88,7 @@ impl, const D: usize> CircuitBuilder { evaluation_point: ExtensionTarget, ) -> ExtensionTarget { let gate = G::new(subgroup_bits); - let gate_index = self.add_gate(gate, vec![]); + let gate_index = self.add_gate(gate, vec![], vec![]); self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift())); for (i, &v) in values.iter().enumerate() { self.connect_extension( diff --git a/plonky2/src/gadgets/multiple_comparison.rs b/plonky2/src/gadgets/multiple_comparison.rs index 88b94f3f..7d637d87 100644 --- a/plonky2/src/gadgets/multiple_comparison.rs +++ b/plonky2/src/gadgets/multiple_comparison.rs @@ -25,7 +25,7 @@ impl, const D: usize> CircuitBuilder { let mut result = one; for i in 0..n { let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks); - let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![]); + let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![], vec![]); self.connect( Target::wire(a_le_b_gate_index, a_le_b_gate.wire_first_input()), a[i], @@ -37,7 +37,7 @@ impl, const D: usize> CircuitBuilder { let a_le_b_result = Target::wire(a_le_b_gate_index, a_le_b_gate.wire_result_bool()); let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks); - let b_le_a_gate_index = self.add_gate(b_le_a_gate.clone(), vec![]); + let b_le_a_gate_index = self.add_gate(b_le_a_gate.clone(), vec![], vec![]); self.connect( Target::wire(b_le_a_gate_index, b_le_a_gate.wire_first_input()), b[i], diff --git a/plonky2/src/gadgets/random_access.rs b/plonky2/src/gadgets/random_access.rs index 1a142844..9518e9fa 100644 --- a/plonky2/src/gadgets/random_access.rs +++ b/plonky2/src/gadgets/random_access.rs @@ -18,7 +18,7 @@ impl, const D: usize> CircuitBuilder { return self.connect(claimed_element, v[0]); } let dummy_gate = RandomAccessGate::::new_from_config(&self.config, bits); - let (gate_index, copy) = self.find_slot(dummy_gate, vec![]); + let (gate_index, copy) = self.find_slot(dummy_gate, &[], &[]); v.iter().enumerate().for_each(|(i, &val)| { self.connect( diff --git a/plonky2/src/gadgets/split_base.rs b/plonky2/src/gadgets/split_base.rs index d4589476..db6fe69d 100644 --- a/plonky2/src/gadgets/split_base.rs +++ b/plonky2/src/gadgets/split_base.rs @@ -16,7 +16,7 @@ impl, const D: usize> CircuitBuilder { /// base-B limb of the element, with little-endian ordering. pub fn split_le_base(&mut self, x: Target, num_limbs: usize) -> Vec { let gate_type = BaseSumGate::::new(num_limbs); - let gate = self.add_gate(gate_type, vec![]); + let gate = self.add_gate(gate_type, vec![], vec![]); let sum = Target::wire(gate, BaseSumGate::::WIRE_SUM); self.connect(x, sum); @@ -54,7 +54,7 @@ impl, const D: usize> CircuitBuilder { "Not enough routed wires." ); let gate_type = BaseSumGate::<2>::new_from_config::(&self.config); - let gate_index = self.add_gate(gate_type, vec![]); + let gate_index = self.add_gate(gate_type, vec![], vec![]); for (limb, wire) in bits .iter() .zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits) diff --git a/plonky2/src/gadgets/split_join.rs b/plonky2/src/gadgets/split_join.rs index 8c7c9c3f..dd7a0a39 100644 --- a/plonky2/src/gadgets/split_join.rs +++ b/plonky2/src/gadgets/split_join.rs @@ -20,7 +20,7 @@ impl, const D: usize> CircuitBuilder { let gate_type = BaseSumGate::<2>::new_from_config::(&self.config); let k = ceil_div_usize(num_bits, gate_type.num_limbs); let gates = (0..k) - .map(|_| self.add_gate(gate_type, vec![])) + .map(|_| self.add_gate(gate_type, vec![], vec![])) .collect::>(); let mut bits = Vec::with_capacity(num_bits); diff --git a/plonky2/src/gates/batchable.rs b/plonky2/src/gates/batchable.rs index d40fd73c..c37cd2c8 100644 --- a/plonky2/src/gates/batchable.rs +++ b/plonky2/src/gates/batchable.rs @@ -22,7 +22,7 @@ pub trait BatchableGate, const D: usize>: Gate, const D: usize> { pub current_slot: HashMap, (usize, usize)>, } @@ -88,6 +88,7 @@ impl, G: MultiOpsGate, const D: usize> Batcha current_slot: &CurrentSlot, builder: &mut CircuitBuilder, ) { + dbg!(self.id(), ¤t_slot, params); if let Some(&(gate_index, op)) = current_slot.current_slot.get(params) { let zero = builder.zero(); for i in op..self.num_ops() { diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 34684194..e5d3943f 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -175,6 +175,7 @@ pub trait Gate, const D: usize>: 'static + Send + S pub struct GateInstance, const D: usize> { pub gate_ref: GateRef, pub constants: Vec, + pub params: Vec, } /// Map each gate to a boolean prefix used to construct the gate's selector polynomial. diff --git a/plonky2/src/gates/gmimc.rs b/plonky2/src/gates/gmimc.rs index 47a82e62..b3ff8969 100644 --- a/plonky2/src/gates/gmimc.rs +++ b/plonky2/src/gates/gmimc.rs @@ -389,7 +389,7 @@ mod tests { let mut builder = CircuitBuilder::new(config); type Gate = GMiMCGate; let gate = Gate::new(); - let gate_index = builder.add_gate(gate, vec![]); + let gate_index = builder.add_gate(gate, vec![], vec![]); let circuit = builder.build_prover::(); let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); diff --git a/plonky2/src/gates/poseidon.rs b/plonky2/src/gates/poseidon.rs index 4a2ffee7..82fedcbb 100644 --- a/plonky2/src/gates/poseidon.rs +++ b/plonky2/src/gates/poseidon.rs @@ -569,7 +569,7 @@ mod tests { let mut builder = CircuitBuilder::new(config); type Gate = PoseidonGate; let gate = Gate::new(); - let gate_index = builder.add_gate(gate, vec![]); + let gate_index = builder.add_gate(gate, vec![], vec![]); let circuit = builder.build_prover::(); let permutation_inputs = (0..SPONGE_WIDTH) diff --git a/plonky2/src/gates/switch.rs b/plonky2/src/gates/switch.rs index e9048edd..62209720 100644 --- a/plonky2/src/gates/switch.rs +++ b/plonky2/src/gates/switch.rs @@ -23,7 +23,7 @@ use crate::plonk::vars::{ }; /// A gate for conditionally swapping input values based on a boolean. -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub struct SwitchGate, const D: usize> { pub(crate) chunk_size: usize, pub(crate) num_copies: usize, @@ -165,7 +165,7 @@ impl, const D: usize> Gate for SwitchGate> = Box::new(SwitchGenerator:: { gate_index, - gate: self.clone(), + gate: *self, copy: c, }); g @@ -195,8 +195,13 @@ impl, const D: usize> MultiOpsGate for Switch self.num_copies } - fn dependencies_ith_op(&self, _gate_index: usize, _i: usize) -> Vec { - todo!() + fn dependencies_ith_op(&self, gate_index: usize, i: usize) -> Vec { + SwitchGenerator:: { + gate_index, + gate: *self, + copy: i, + } + .watch_list() } } diff --git a/plonky2/src/hash/gmimc.rs b/plonky2/src/hash/gmimc.rs index 3492e08f..13f7807f 100644 --- a/plonky2/src/hash/gmimc.rs +++ b/plonky2/src/hash/gmimc.rs @@ -126,7 +126,7 @@ impl AlgebraicHasher for GMiMCHash { F: RichField + Extendable, { let gate_type = GMiMCGate::::new(); - let gate = builder.add_gate(gate_type, vec![]); + let gate = builder.add_gate(gate_type, vec![], vec![]); let swap_wire = GMiMCGate::::WIRE_SWAP; let swap_wire = Target::wire(gate, swap_wire); diff --git a/plonky2/src/hash/poseidon.rs b/plonky2/src/hash/poseidon.rs index 606dfd13..8e8d6ba5 100644 --- a/plonky2/src/hash/poseidon.rs +++ b/plonky2/src/hash/poseidon.rs @@ -269,7 +269,7 @@ pub trait Poseidon: PrimeField { // If we have enough routed wires, we will use PoseidonMdsGate. let mds_gate = PoseidonMdsGate::::new(); if builder.config.num_routed_wires >= mds_gate.num_wires() { - let index = builder.add_gate(mds_gate, vec![]); + let index = builder.add_gate(mds_gate, vec![], vec![]); for i in 0..WIDTH { let input_wire = PoseidonMdsGate::::wires_input(i); builder.connect_extension(state[i], ExtensionTarget::from_range(index, input_wire)); @@ -652,7 +652,7 @@ impl AlgebraicHasher for PoseidonHash { F: RichField + Extendable, { let gate_type = PoseidonGate::::new(); - let gate = builder.add_gate(gate_type, vec![]); + let gate = builder.add_gate(gate_type, vec![], vec![]); let swap_wire = PoseidonGate::::WIRE_SWAP; let swap_wire = Target::wire(gate, swap_wire); diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 368232fd..1c7779c6 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -91,6 +91,11 @@ pub(crate) fn generate_partial_witness< pending_generator_indices = next_pending_generator_indices; } + for i in 0..generator_is_expired.len() { + if !generator_is_expired[i] { + dbg!(i); + } + } assert_eq!( remaining_generators, 0, "{} generators weren't run", diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 80a951c7..b658e339 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -45,7 +45,7 @@ use crate::util::timing::TimingTree; use crate::util::{transpose, transpose_poly_values}; pub struct CircuitBuilder, const D: usize> { - pub(crate) config: CircuitConfig, + pub config: CircuitConfig, /// The types of gates used in this circuit. gates: HashSet>, @@ -183,7 +183,12 @@ impl, const D: usize> CircuitBuilder { } /// Adds a gate to the circuit, and returns its index. - pub fn add_gate>(&mut self, gate_type: G, constants: Vec) -> usize { + pub fn add_gate>( + &mut self, + gate_type: G, + constants: Vec, + params: Vec, + ) -> usize { // println!("{} {}", self.num_gates(), gate_type.id()); self.check_gate_compatibility(&gate_type); assert_eq!( @@ -205,6 +210,7 @@ impl, const D: usize> CircuitBuilder { self.gate_instances.push(GateInstance { gate_ref, constants, + params, }); index @@ -303,7 +309,7 @@ impl, const D: usize> CircuitBuilder { // We will fill this `ConstantGate` with zero constants initially. // These will be overwritten by `constant` as the gate instances are filled. let gate = ConstantGate { num_consts }; - let (gate, instance) = self.find_slot(gate, vec![F::ZERO; num_consts]); + let (gate, instance) = self.find_slot(gate, &[], &vec![F::ZERO; num_consts]); let target = Target::wire(gate, instance); self.gate_instances[gate].constants[instance] = c; @@ -372,7 +378,8 @@ impl, const D: usize> CircuitBuilder { pub fn find_slot + Clone>( &mut self, gate: G, - params: Vec, + params: &[F], + constants: &[F], ) -> (usize, usize) { let num_gates = self.num_gates(); let num_ops = gate.num_ops(); @@ -383,11 +390,11 @@ impl, const D: usize> CircuitBuilder { .or_insert(CurrentSlot { current_slot: HashMap::new(), }); - let slot = gate_slot.current_slot.get(¶ms); + let slot = gate_slot.current_slot.get(params); let res = if let Some(&s) = slot { s } else { - self.add_gate(gate, params.clone()); + self.add_gate(gate, constants.to_vec(), params.to_vec()); (num_gates, 0) }; if res.1 == num_ops - 1 { @@ -395,13 +402,13 @@ impl, const D: usize> CircuitBuilder { .get_mut(&gate_ref) .unwrap() .current_slot - .remove(¶ms); + .remove(params); } else { self.current_slots .get_mut(&gate_ref) .unwrap() .current_slot - .insert(params, (res.0, res.1 + 1)); + .insert(params.to_vec(), (res.0, res.1 + 1)); } res @@ -490,7 +497,7 @@ impl, const D: usize> CircuitBuilder { } while !self.gate_instances.len().is_power_of_two() { - self.add_gate(NoopGate, vec![]); + self.add_gate(NoopGate, vec![], vec![]); } } @@ -507,7 +514,7 @@ impl, const D: usize> CircuitBuilder { // For each "regular" blinding factor, we simply add a no-op gate, and insert a random value // for each wire. for _ in 0..regular_poly_openings { - let gate = self.add_gate(NoopGate, vec![]); + let gate = self.add_gate(NoopGate, vec![], vec![]); for w in 0..num_wires { self.add_simple_generator(RandomValueGenerator { target: Target::Wire(Wire { gate, input: w }), @@ -519,8 +526,8 @@ impl, const D: usize> CircuitBuilder { // enforce a copy constraint between them. // See https://mirprotocol.org/blog/Adding-zero-knowledge-to-Plonk-Halo for _ in 0..z_openings { - let gate_1 = self.add_gate(NoopGate, vec![]); - let gate_2 = self.add_gate(NoopGate, vec![]); + let gate_1 = self.add_gate(NoopGate, vec![], vec![]); + let gate_2 = self.add_gate(NoopGate, vec![], vec![]); for w in 0..num_routed_wires { self.add_simple_generator(RandomValueGenerator { @@ -634,7 +641,7 @@ impl, const D: usize> CircuitBuilder { // those hash wires match the claimed public inputs. let public_inputs_hash = self.hash_n_to_hash::(self.public_inputs.clone(), true); - let pi_gate = self.add_gate(PublicInputGate, vec![]); + let pi_gate = self.add_gate(PublicInputGate, vec![], vec![]); for (&hash_part, wire) in public_inputs_hash .elements .iter() @@ -694,6 +701,11 @@ impl, const D: usize> CircuitBuilder { constants_sigmas_cap: constants_sigmas_cap.clone(), }; + let mut gens = self.generators.len(); + for (i, g) in self.gate_instances.iter().enumerate() { + gens += g.gate_ref.0.generators(i, &g.constants).len(); + dbg!(g.gate_ref.0.id(), gens); + } // Add gate generators. self.add_generators( self.gate_instances @@ -1130,11 +1142,12 @@ impl, const D: usize> CircuitBuilder { // } // fn fill_batched_gates(&mut self) { + dbg!(&self.current_slots); let instances = self.gate_instances.clone(); for gate in instances { if let Some(slot) = self.current_slots.get(&gate.gate_ref) { let cloned = slot.clone(); - gate.gate_ref.0.fill_gate(&gate.constants, &cloned, self); + gate.gate_ref.0.fill_gate(&gate.params, &cloned, self); } } } diff --git a/plonky2/src/plonk/proof.rs b/plonky2/src/plonk/proof.rs index 91db2b25..b494f324 100644 --- a/plonky2/src/plonk/proof.rs +++ b/plonky2/src/plonk/proof.rs @@ -353,7 +353,7 @@ mod tests { let comp_zt = builder.mul(xt, yt); builder.connect(zt, comp_zt); for _ in 0..100 { - builder.add_gate(NoopGate, vec![]); + builder.add_gate(NoopGate, vec![], vec![]); } let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/plonky2/src/plonk/recursive_verifier.rs b/plonky2/src/plonk/recursive_verifier.rs index dc6f5039..a9af478c 100644 --- a/plonky2/src/plonk/recursive_verifier.rs +++ b/plonky2/src/plonk/recursive_verifier.rs @@ -546,7 +546,7 @@ mod tests { )> { let mut builder = CircuitBuilder::::new(config.clone()); for _ in 0..num_dummy_gates { - builder.add_gate(NoopGate, vec![]); + builder.add_gate(NoopGate, vec![], vec![]); } let data = builder.build::(); @@ -601,7 +601,7 @@ mod tests { // builder will pad to the next power of two, 2^min_degree_bits. let min_gates = (1 << (min_degree_bits - 1)) + 1; for _ in builder.num_gates()..min_gates { - builder.add_gate(NoopGate, vec![]); + builder.add_gate(NoopGate, vec![], vec![]); } } diff --git a/plonky2/src/util/reducing.rs b/plonky2/src/util/reducing.rs index 8bfe45d1..a2d4e4cf 100644 --- a/plonky2/src/util/reducing.rs +++ b/plonky2/src/util/reducing.rs @@ -132,7 +132,7 @@ impl ReducingFactorTarget { reversed_terms.reverse(); for chunk in reversed_terms.chunks_exact(max_coeffs_len) { let gate = ReducingGate::new(max_coeffs_len); - let gate_index = builder.add_gate(gate.clone(), Vec::new()); + let gate_index = builder.add_gate(gate.clone(), vec![], vec![]); builder.connect_extension( self.base, @@ -182,7 +182,7 @@ impl ReducingFactorTarget { reversed_terms.reverse(); for chunk in reversed_terms.chunks_exact(max_coeffs_len) { let gate = ReducingExtensionGate::new(max_coeffs_len); - let gate_index = builder.add_gate(gate.clone(), Vec::new()); + let gate_index = builder.add_gate(gate.clone(), vec![], vec![]); builder.connect_extension( self.base, diff --git a/waksman/src/permutation.rs b/waksman/src/permutation.rs index b0e725d2..b25c2980 100644 --- a/waksman/src/permutation.rs +++ b/waksman/src/permutation.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; use std::marker::PhantomData; use plonky2::field::{extension_field::Extendable, field_types::Field}; +use plonky2::gates::switch::SwitchGate; use plonky2::hash::hash_types::RichField; use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; use plonky2::iop::target::Target; @@ -71,41 +72,42 @@ fn assert_permutation_2x2, const D: usize>( /// Given two input wire chunks, add a new switch to the circuit (by adding one copy to a switch /// gate). Returns the wire for the switch boolean, and the two output wire chunks. fn create_switch, const D: usize>( - _builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, a1: Vec, a2: Vec, ) -> (Target, Vec, Vec) { assert_eq!(a1.len(), a2.len(), "Chunk size must be the same"); - let _chunk_size = a1.len(); + let chunk_size = a1.len(); - todo!() - // let (gate, gate_index, next_copy) = builder.find_switch_gate(chunk_size); - // - // let mut c = Vec::new(); - // let mut d = Vec::new(); - // for e in 0..chunk_size { - // builder.connect( - // a1[e], - // Target::wire(gate_index, gate.wire_first_input(next_copy, e)), - // ); - // builder.connect( - // a2[e], - // Target::wire(gate_index, gate.wire_second_input(next_copy, e)), - // ); - // c.push(Target::wire( - // gate_index, - // gate.wire_first_output(next_copy, e), - // )); - // d.push(Target::wire( - // gate_index, - // gate.wire_second_output(next_copy, e), - // )); - // } + let gate = SwitchGate::new_from_config(&builder.config, chunk_size); + let params = vec![F::from_canonical_usize(chunk_size)]; + let (gate_index, next_copy) = builder.find_slot(gate, ¶ms, &[]); - // let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy)); - // - // (switch, c, d) + let mut c = Vec::new(); + let mut d = Vec::new(); + for e in 0..chunk_size { + builder.connect( + a1[e], + Target::wire(gate_index, gate.wire_first_input(next_copy, e)), + ); + builder.connect( + a2[e], + Target::wire(gate_index, gate.wire_second_input(next_copy, e)), + ); + c.push(Target::wire( + gate_index, + gate.wire_first_output(next_copy, e), + )); + d.push(Target::wire( + gate_index, + gate.wire_second_output(next_copy, e), + )); + } + + let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy)); + + (switch, c, d) } fn assert_permutation_recursive, const D: usize>( diff --git a/waksman/src/sorting.rs b/waksman/src/sorting.rs index b3e616d5..4270ebc7 100644 --- a/waksman/src/sorting.rs +++ b/waksman/src/sorting.rs @@ -54,7 +54,7 @@ pub fn assert_le, const D: usize>( num_chunks: usize, ) { let gate = AssertLessThanGate::new(bits, num_chunks); - let gate_index = builder.add_gate(gate.clone(), vec![]); + let gate_index = builder.add_gate(gate.clone(), vec![], vec![]); builder.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); builder.connect(Target::wire(gate_index, gate.wire_second_input()), rhs);