diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 9f5d7141..bc9aad6f 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -413,7 +413,7 @@ impl, const D: usize> CircuitBuilder { ) -> ExtensionTarget { let inv = self.add_virtual_extension_target(); let one = self.one_extension(); - self.add_generator(QuotientGeneratorExtension { + self.add_simple_generator(QuotientGeneratorExtension { numerator: one, denominator: y, quotient: inv, diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 3016c2a1..0c51057a 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -6,7 +6,7 @@ use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gates::switch::SwitchGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::PartialWitness; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bimap::bimap_from_lists; @@ -28,7 +28,7 @@ impl, const D: usize> CircuitBuilder { // Two singleton lists are permutations of one another as long as their items are equal. 1 => { for e in 0..chunk_size { - self.assert_equal(a[0][e], b[0][e]) + self.connect(a[0][e], b[0][e]) } } 2 => { @@ -57,8 +57,8 @@ impl, const D: usize> CircuitBuilder { let (switch, gate_out1, gate_out2) = self.create_switch(a1, a2); for e in 0..chunk_size { - self.route(b1[e], gate_out1[e]); - self.route(b2[e], gate_out2[e]); + self.connect(b1[e], gate_out1[e]); + self.connect(b2[e], gate_out2[e]); } } @@ -91,11 +91,11 @@ impl, const D: usize> CircuitBuilder { let mut c = Vec::new(); let mut d = Vec::new(); for e in 0..chunk_size { - self.route( + self.connect( a1[e], Target::wire(gate_index, gate.wire_first_input(next_copy, e)), ); - self.route( + self.connect( a2[e], Target::wire(gate_index, gate.wire_second_input(next_copy, e)), ); @@ -176,7 +176,7 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(child_1_a, child_1_b); self.assert_permutation(child_2_a, child_2_b); - self.add_generator(PermutationGenerator:: { + self.add_simple_generator(PermutationGenerator:: { chunk_size, a, b, @@ -192,7 +192,7 @@ fn route( b_values: Vec>, a_switches: Vec, b_switches: Vec, - witness: &PartialWitness, + witness: &PartitionWitness, out_buffer: &mut GeneratedValues, ) { assert_eq!(a_values.len(), b_values.len()); @@ -221,7 +221,7 @@ fn route( // After we route a wire on one side, we find the corresponding wire on the other side and check // if it still needs to be routed. If so, we add it to partial_routes. let enqueue_other_side = |partial_routes: &mut [BTreeMap], - witness: &PartialWitness, + witness: &PartitionWitness, newly_set: &mut [Vec], side: usize, this_i: usize, @@ -272,7 +272,7 @@ fn route( } let mut route_switch = |partial_routes: &mut [BTreeMap], - witness: &PartialWitness, + witness: &PartitionWitness, out_buffer: &mut GeneratedValues, side: usize, switch_index: usize, @@ -351,6 +351,8 @@ fn route( } } } + +#[derive(Debug)] struct PermutationGenerator { chunk_size: usize, a: Vec>, @@ -370,7 +372,7 @@ impl SimpleGenerator for PermutationGenerator { .collect() } - fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let a_values = self .a .iter() @@ -406,9 +408,10 @@ mod tests { #[test] fn test_permutation_2x2() -> Result<()> { type F = CrandallField; - let config = CircuitConfig::large_config(); - let pw = PartialWitness::new(config.num_wires); - let mut builder = CircuitBuilder::::new(config); + let config = CircuitConfig::large_zk_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); let one = F::ONE; let two = F::from_canonical_usize(2); @@ -432,9 +435,10 @@ mod tests { #[test] fn test_permutation_4x4() -> Result<()> { type F = CrandallField; - let config = CircuitConfig::large_config(); - let pw = PartialWitness::new(config.num_wires); - let mut builder = CircuitBuilder::::new(config); + let config = CircuitConfig::large_zk_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); let one = F::ONE; let two = F::from_canonical_usize(2); diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs index 4f3b7d0d..32a3608d 100644 --- a/src/gadgets/range_check.rs +++ b/src/gadgets/range_check.rs @@ -28,7 +28,7 @@ impl, const D: usize> CircuitBuilder { let high_gate = self.add_gate(BaseSumGate::<2>::new(num_bits - n_log), vec![]); let low = Target::wire(low_gate, BaseSumGate::<2>::WIRE_SUM); let high = Target::wire(high_gate, BaseSumGate::<2>::WIRE_SUM); - self.add_generator(LowHighGenerator { + self.add_simple_generator(LowHighGenerator { integer: x, n_log, low, diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 914269d5..02d24833 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -44,7 +44,7 @@ impl, const D: usize> CircuitBuilder { self.connect(limb.borrow().target, Target::wire(gate_index, wire)); } - self.add_generator(BaseSumGenerator::<2> { + self.add_simple_generator(BaseSumGenerator::<2> { gate_index, limbs: bits.map(|l| *l.borrow()).collect(), }); diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 4b6527ca..143e63e3 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -47,7 +47,7 @@ impl, const D: usize> CircuitBuilder { } self.connect(acc, integer); - self.add_generator(WireSplitGenerator { + self.add_simple_generator(WireSplitGenerator { integer, gates, num_limbs: bits_per_gate, diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 3c9dcb18..41dee743 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -110,12 +110,15 @@ impl, const D: usize> Gate for ArithmeticExtensionGate ) -> Vec>> { (0..NUM_ARITHMETIC_OPS) .map(|i| { - let g: Box> = Box::new(ArithmeticExtensionGenerator { - gate_index, - const_0: local_constants[0], - const_1: local_constants[1], - i, - }); + let g: Box> = Box::new( + ArithmeticExtensionGenerator { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + i, + } + .adapter(), + ); g }) .collect::>() diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 341d0fc1..9102bf07 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -105,7 +105,7 @@ impl, const D: usize, const B: usize> Gate for BaseSumGat gate_index, num_limbs: self.num_limbs, }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } // 1 for the sum then `num_limbs` for the limbs. diff --git a/src/gates/constant.rs b/src/gates/constant.rs index 894016d6..40493e81 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -54,7 +54,7 @@ impl, const D: usize> Gate for ConstantGate { gate_index, constant: local_constants[0], }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index 1a3a6ea3..2c1f70bc 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -184,7 +184,7 @@ impl, const D: usize> Gate for ExponentiationGate { gate_index, gate: self.clone(), }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 682bd1fa..9e645fcc 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -217,7 +217,7 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< gate_index, constants: self.constants.clone(), }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index 4c1e6fdf..030dfbff 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -220,7 +220,7 @@ impl, const D: usize> Gate for InsertionGate { gate_index, gate: self.clone(), }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 56924742..53571b71 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -189,7 +189,7 @@ impl, const D: usize> Gate for InterpolationGate { gate: self.clone(), _phantom: PhantomData, }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index 3bba5ad2..1e87fc10 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -167,7 +167,7 @@ impl, const D: usize> Gate for RandomAccessGate { gate_index, gate: self.clone(), }; - vec![Box::new(gen)] + vec![Box::new(gen.adapter())] } fn num_wires(&self) -> usize { diff --git a/src/gates/reducing.rs b/src/gates/reducing.rs index ee3232b2..7545e549 100644 --- a/src/gates/reducing.rs +++ b/src/gates/reducing.rs @@ -136,10 +136,13 @@ impl, const D: usize> Gate for ReducingGate { gate_index: usize, _local_constants: &[F], ) -> Vec>> { - vec![Box::new(ReducingGenerator { - gate_index, - gate: self.clone(), - })] + vec![Box::new( + ReducingGenerator { + gate_index, + gate: self.clone(), + } + .adapter(), + )] } fn num_wires(&self) -> usize { diff --git a/src/gates/switch.rs b/src/gates/switch.rs index b520ef50..954a4997 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -228,7 +228,7 @@ impl, const D: usize> SwitchGenerator { deps } - fn run_in_out(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { + fn run_in_out(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let local_wire = |input| Wire { gate: self.gate_index, input, @@ -253,7 +253,7 @@ impl, const D: usize> SwitchGenerator { } } - fn run_in_switch(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { + fn run_in_switch(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let local_wire = |input| Wire { gate: self.gate_index, input, @@ -288,7 +288,7 @@ impl, const D: usize> WitnessGenerator for SwitchGenerator, out_buffer: &mut GeneratedValues) -> bool { + fn run(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> bool { if witness.contains_all(&self.in_out_dependencies()) { self.run_in_out(witness, out_buffer); true diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 8174981c..483a7a62 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; @@ -186,16 +187,32 @@ pub trait SimpleGenerator: 'static + Send + Sync + Debug { fn dependencies(&self) -> Vec; fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues); + + fn adapter(self) -> SimpleGeneratorAdapter + where + Self: Sized, + { + SimpleGeneratorAdapter { + inner: self, + _phantom: PhantomData, + } + } } -impl> WitnessGenerator for SG { +#[derive(Debug)] +pub struct SimpleGeneratorAdapter + ?Sized> { + _phantom: PhantomData, + inner: SG, +} + +impl> WitnessGenerator for SimpleGeneratorAdapter { fn watch_list(&self) -> Vec { - self.dependencies() + self.inner.dependencies() } fn run(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> bool { - if witness.contains_all(&self.dependencies()) { - self.run_once(witness, out_buffer); + if witness.contains_all(&self.inner.dependencies()) { + self.inner.run_once(witness, out_buffer); true } else { false diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index ed19a89f..7252cebf 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -17,7 +17,9 @@ use crate::gates::public_input::PublicInputGate; use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::hash_n_to_hash; -use crate::iop::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; +use crate::iop::generator::{ + CopyGenerator, RandomValueGenerator, SimpleGenerator, WitnessGenerator, +}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::iop::witness::PartitionWitness; @@ -188,7 +190,7 @@ impl, const D: usize> CircuitBuilder { /// Adds a generator which will copy `src` to `dst`. pub fn generate_copy(&mut self, src: Target, dst: Target) { - self.add_generator(CopyGenerator { src, dst }); + self.add_simple_generator(CopyGenerator { src, dst }); } /// Uses Plonk's permutation argument to require that two elements be equal. @@ -215,8 +217,8 @@ impl, const D: usize> CircuitBuilder { self.generators.extend(generators); } - pub fn add_generator>(&mut self, generator: G) { - self.generators.push(Box::new(generator)); + pub fn add_simple_generator>(&mut self, generator: G) { + self.generators.push(Box::new(generator.adapter())); } /// Returns a routable target with a value of 0. @@ -389,7 +391,7 @@ impl, const D: usize> CircuitBuilder { for _ in 0..regular_poly_openings { let gate = self.add_gate(NoopGate, vec![]); for w in 0..num_wires { - self.add_generator(RandomValueGenerator { + self.add_simple_generator(RandomValueGenerator { target: Target::Wire(Wire { gate, input: w }), }); } @@ -403,7 +405,7 @@ impl, const D: usize> CircuitBuilder { let gate_2 = self.add_gate(NoopGate, vec![]); for w in 0..num_routed_wires { - self.add_generator(RandomValueGenerator { + self.add_simple_generator(RandomValueGenerator { target: Target::Wire(Wire { gate: gate_1, input: w, @@ -528,9 +530,9 @@ impl, const D: usize> CircuitBuilder { let wire_second_input = Target::wire(gate_index, gate.wire_second_input(copy, element)); let wire_switch_bool = Target::wire(gate_index, gate.wire_switch_bool(copy)); - self.route(zero, wire_first_input); - self.route(zero, wire_second_input); - self.route(zero, wire_switch_bool); + self.connect(zero, wire_first_input); + self.connect(zero, wire_second_input); + self.connect(zero, wire_switch_bool); } } }