From d6211b8ab8811bc0617b2101b6db8356561bab4c Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 2 Aug 2021 10:55:10 -0700 Subject: [PATCH] Reuse a buffer of generated values (#142) * Reuse a buffer of generated values To avoid allocating `GeneratedValues` all the time. Saves ~60ms or so. * PR feedback --- src/gadgets/arithmetic_extension.rs | 4 +-- src/gadgets/range_check.rs | 9 +++---- src/gadgets/split_base.rs | 6 ++--- src/gadgets/split_join.rs | 14 +++------- src/gates/arithmetic.rs | 8 +++--- src/gates/base_sum.rs | 9 +++---- src/gates/constant.rs | 4 +-- src/gates/exponentiation.rs | 9 +++---- src/gates/gmimc.rs | 10 +++---- src/gates/insertion.rs | 10 +++---- src/gates/interpolation.rs | 8 +++--- src/gates/reducing.rs | 8 +++--- src/iop/generator.rs | 42 ++++++++++++++++------------- src/iop/witness.rs | 5 ++-- 14 files changed, 63 insertions(+), 83 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 812e7317..d76dab9d 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -536,11 +536,11 @@ impl, const D: usize> SimpleGenerator for QuotientGeneratorE deps } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let num = witness.get_extension_target(self.numerator); let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; - GeneratedValues::singleton_extension_target(self.quotient, quotient) + out_buffer.set_extension_target(self.quotient, quotient) } } diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs index 0fc38afe..53bbf55c 100644 --- a/src/gadgets/range_check.rs +++ b/src/gadgets/range_check.rs @@ -56,15 +56,12 @@ impl SimpleGenerator for LowHighGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let integer_value = witness.get_target(self.integer).to_canonical_u64(); let low = integer_value & ((1 << self.n_log) - 1); let high = integer_value >> self.n_log; - let mut result = GeneratedValues::with_capacity(2); - result.set_target(self.low, F::from_canonical_u64(low)); - result.set_target(self.high, F::from_canonical_u64(high)); - - result + out_buffer.set_target(self.low, F::from_canonical_u64(low)); + out_buffer.set_target(self.high, F::from_canonical_u64(high)); } } diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 901b4728..54fd8cf2 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -76,7 +76,7 @@ impl SimpleGenerator for BaseSumGenerator { self.limbs.clone() } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let sum = self .limbs .iter() @@ -84,10 +84,10 @@ impl SimpleGenerator for BaseSumGenerator { .rev() .fold(F::ZERO, |acc, limb| acc * F::from_canonical_usize(B) + limb); - GeneratedValues::singleton_target( + out_buffer.set_target( Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM), sum, - ) + ); } } diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 78518245..45422d5d 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -110,13 +110,12 @@ impl SimpleGenerator for SplitGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); - let mut result = GeneratedValues::with_capacity(self.bits.len()); for &b in &self.bits { let b_value = integer_value & 1; - result.set_target(b, F::from_canonical_u64(b_value)); + out_buffer.set_target(b, F::from_canonical_u64(b_value)); integer_value >>= 1; } @@ -124,8 +123,6 @@ impl SimpleGenerator for SplitGenerator { integer_value, 0, "Integer too large to fit in given number of bits" ); - - result } } @@ -141,13 +138,12 @@ impl SimpleGenerator for WireSplitGenerator { vec![self.integer] } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); - let mut result = GeneratedValues::with_capacity(self.gates.len()); for &gate in &self.gates { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); - result.set_target( + out_buffer.set_target( sum, F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)), ); @@ -160,7 +156,5 @@ impl SimpleGenerator for WireSplitGenerator { "Integer too large to fit in {} many `BaseSumGate`s", self.gates.len() ); - - result } } diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 73cf12bc..499bd4a7 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -191,7 +191,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio .collect() } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) @@ -211,7 +211,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio let computed_output = multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); - GeneratedValues::singleton_extension_target(output_target, computed_output) + out_buffer.set_extension_target(output_target, computed_output) } } @@ -224,7 +224,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio .collect() } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) @@ -244,7 +244,7 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio let computed_output = multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); - GeneratedValues::singleton_extension_target(output_target, computed_output) + out_buffer.set_extension_target(output_target, computed_output) } } diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 7da26095..f580bf16 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -144,7 +144,7 @@ impl SimpleGenerator for BaseSplitGenerator { vec![Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM)] } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let sum_value = witness .get_target(Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM)) .to_canonical_u64() as usize; @@ -169,16 +169,13 @@ impl SimpleGenerator for BaseSplitGenerator { .iter() .fold(F::ZERO, |acc, &x| acc * b_field + x); - let mut result = GeneratedValues::with_capacity(self.num_limbs + 1); - result.set_target( + out_buffer.set_target( Target::wire(self.gate_index, BaseSumGate::::WIRE_REVERSED_SUM), reversed_sum, ); for (b, b_value) in limbs.zip(limbs_value) { - result.set_target(b, b_value); + out_buffer.set_target(b, b_value); } - - result } } diff --git a/src/gates/constant.rs b/src/gates/constant.rs index 5fb98eb2..0cc22b22 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -85,12 +85,12 @@ impl SimpleGenerator for ConstantGenerator { Vec::new() } - fn run_once(&self, _witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, _witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let wire = Wire { gate: self.gate_index, input: ConstantGate::WIRE_OUTPUT, }; - GeneratedValues::singleton_target(Target::Wire(wire), self.constant) + out_buffer.set_target(Target::Wire(wire), self.constant); } } diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index f2f2d050..7bbb6eeb 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -220,7 +220,7 @@ impl, const D: usize> SimpleGenerator for ExponentiationGene deps } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let local_wire = |input| Wire { gate: self.gate_index, input, @@ -245,16 +245,13 @@ impl, const D: usize> SimpleGenerator for ExponentiationGene current_intermediate_value *= current_intermediate_value; } - let mut result = GeneratedValues::with_capacity(num_power_bits + 1); for i in 0..num_power_bits { let intermediate_value_wire = local_wire(self.gate.wire_intermediate_value(i)); - result.set_wire(intermediate_value_wire, intermediate_values[i]); + out_buffer.set_wire(intermediate_value_wire, intermediate_values[i]); } let output_wire = local_wire(self.gate.wire_output()); - result.set_wire(output_wire, intermediate_values[num_power_bits - 1]); - - result + out_buffer.set_wire(output_wire, intermediate_values[num_power_bits - 1]); } } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 73de7956..146f7b6f 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -265,9 +265,7 @@ impl, const D: usize, const R: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { - let mut result = GeneratedValues::with_capacity(R + W); - + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let mut state = (0..W) .map(|i| { witness.get_wire(Wire { @@ -295,7 +293,7 @@ impl, const D: usize, const R: usize> SimpleGenerator for r in 0..R { let active = r % W; let cubing_input = state[active] + addition_buffer + self.constants[r]; - result.set_wire( + out_buffer.set_wire( Wire { gate: self.gate_index, input: GMiMCGate::::wire_cubing_input(r), @@ -309,7 +307,7 @@ impl, const D: usize, const R: usize> SimpleGenerator for i in 0..W { state[i] += addition_buffer; - result.set_wire( + out_buffer.set_wire( Wire { gate: self.gate_index, input: GMiMCGate::::wire_output(i), @@ -317,8 +315,6 @@ impl, const D: usize, const R: usize> SimpleGenerator state[i], ); } - - result } } diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index c540b5c0..a2e03580 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -255,7 +255,7 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator deps } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let local_wire = |input| Wire { gate: self.gate_index, input, @@ -304,14 +304,12 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator let mut result = GeneratedValues::::with_capacity((vec_size + 1) * (D + 2)); for i in 0..=vec_size { let output_wires = self.gate.wires_output_list_item(i).map(local_wire); - result.set_ext_wires(output_wires, new_vec[i]); + out_buffer.set_ext_wires(output_wires, new_vec[i]); let equality_dummy_wire = local_wire(self.gate.wires_equality_dummy_for_round_r(i)); - result.set_wire(equality_dummy_wire, equality_dummy_vals[i]); + out_buffer.set_wire(equality_dummy_wire, equality_dummy_vals[i]); let insert_here_wire = local_wire(self.gate.wires_insert_here_for_round_r(i)); - result.set_wire(insert_here_wire, insert_here_vals[i]); + out_buffer.set_wire(insert_here_wire, insert_here_vals[i]); } - - result } } diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index cd3e74a2..0cb6ea98 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -239,7 +239,7 @@ impl, const D: usize> SimpleGenerator for InterpolationGener deps } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let n = self.gate.num_points; let local_wire = |input| Wire { @@ -270,15 +270,13 @@ impl, const D: usize> SimpleGenerator for InterpolationGener let mut result = GeneratedValues::::with_capacity(D * (self.gate.num_points + 1)); for (i, &coeff) in interpolant.coeffs.iter().enumerate() { let wires = self.gate.wires_coeff(i).map(local_wire); - result.set_ext_wires(wires, coeff); + out_buffer.set_ext_wires(wires, coeff); } let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); let evaluation_value = interpolant.eval(evaluation_point); let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); - result.set_ext_wires(evaluation_value_wires, evaluation_value); - - result + out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); } } diff --git a/src/gates/reducing.rs b/src/gates/reducing.rs index a87d8af9..92f00455 100644 --- a/src/gates/reducing.rs +++ b/src/gates/reducing.rs @@ -177,7 +177,7 @@ impl, const D: usize> SimpleGenerator for ReducingGenerator< .collect() } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.gate_index, range); witness.get_extension_target(t) @@ -202,12 +202,10 @@ impl, const D: usize> SimpleGenerator for ReducingGenerator< let mut acc = old_acc; for i in 0..self.gate.num_coeffs { let computed_acc = acc * alpha + coeffs[i].into(); - result.set_extension_target(accs[i], computed_acc); + out_buffer.set_extension_target(accs[i], computed_acc); acc = computed_acc; } - result.set_extension_target(output, acc); - - result + out_buffer.set_extension_target(output, acc); } } diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 248646c1..9f421b14 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -34,18 +34,20 @@ pub(crate) fn generate_partial_witness( // We also track a list of "expired" generators which have already returned false. let mut generator_is_expired = vec![false; generators.len()]; + let mut buffer = GeneratedValues::empty(); + // Keep running generators until no generators are queued. while !pending_generator_indices.is_empty() { let mut next_pending_generator_indices = HashSet::new(); for &generator_idx in &pending_generator_indices { - let (result, finished) = generators[generator_idx].run(&witness); + let finished = generators[generator_idx].run(&witness, &mut buffer); if finished { generator_is_expired[generator_idx] = true; } // Enqueue unfinished generators that were watching one of the newly populated targets. - for (watch, _) in &result.target_values { + for (watch, _) in &buffer.target_values { if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) { for &watching_generator_idx in watching_generator_indices { if !generator_is_expired[watching_generator_idx] { @@ -55,7 +57,7 @@ pub(crate) fn generate_partial_witness( } } - witness.extend(result); + witness.extend(buffer.target_values.drain(..)); } pending_generator_indices = next_pending_generator_indices; @@ -73,11 +75,10 @@ pub trait WitnessGenerator: 'static + Send + Sync { /// the generator will be queued to run. fn watch_list(&self) -> Vec; - /// Run this generator, returning a `PartialWitness` containing any new witness elements, and a - /// flag indicating whether the generator is finished. If the flag is true, the generator will - /// never be run again, otherwise it will be queued for another run next time a target in its - /// watch list is populated. - fn run(&self, witness: &PartialWitness) -> (GeneratedValues, bool); + /// Run this generator, returning a flag indicating whether the generator is finished. If the + /// flag is true, the generator will never be run again, otherwise it will be queued for another + /// run next time a target in its watch list is populated. + fn run(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) -> bool; } /// Values generated by a generator invocation. @@ -108,6 +109,10 @@ impl GeneratedValues { vec![(target, value)].into() } + pub fn clear(&mut self) { + self.target_values.clear(); + } + pub fn singleton_extension_target( et: ExtensionTarget, value: F::Extension, @@ -171,7 +176,7 @@ impl GeneratedValues { pub trait SimpleGenerator: 'static + Send + Sync { fn dependencies(&self) -> Vec; - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues; + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues); } impl> WitnessGenerator for SG { @@ -179,11 +184,12 @@ impl> WitnessGenerator for SG { self.dependencies() } - fn run(&self, witness: &PartialWitness) -> (GeneratedValues, bool) { + fn run(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) -> bool { if witness.contains_all(&self.dependencies()) { - (self.run_once(witness), true) + self.run_once(witness, out_buffer); + true } else { - (GeneratedValues::empty(), false) + false } } } @@ -200,9 +206,9 @@ impl SimpleGenerator for CopyGenerator { vec![self.src] } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let value = witness.get_target(self.src); - GeneratedValues::singleton_target(self.dst, value) + out_buffer.set_target(self.dst, value); } } @@ -216,10 +222,10 @@ impl SimpleGenerator for RandomValueGenerator { Vec::new() } - fn run_once(&self, _witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, _witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let random_value = F::rand(); - GeneratedValues::singleton_target(self.target, random_value) + out_buffer.set_target(self.target, random_value); } } @@ -234,7 +240,7 @@ impl SimpleGenerator for NonzeroTestGenerator { vec![self.to_test] } - fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let to_test_value = witness.get_target(self.to_test); let dummy_value = if to_test_value == F::ZERO { @@ -243,6 +249,6 @@ impl SimpleGenerator for NonzeroTestGenerator { to_test_value.inverse() }; - GeneratedValues::singleton_target(self.dummy, dummy_value) + out_buffer.set_target(self.dummy, dummy_value); } } diff --git a/src/iop/witness.rs b/src/iop/witness.rs index df8010d9..339f8cfb 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -9,7 +9,6 @@ use crate::field::field_types::Field; use crate::gates::gate::GateInstance; use crate::hash::hash_types::HashOut; use crate::hash::hash_types::HashOutTarget; -use crate::iop::generator::GeneratedValues; use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::plonk::copy_constraint::CopyConstraint; @@ -150,8 +149,8 @@ impl PartialWitness { self.set_wires(wires, &value.to_basefield_array()); } - pub fn extend(&mut self, other: GeneratedValues) { - for (target, value) in other.target_values { + pub fn extend>(&mut self, pairs: I) { + for (target, value) in pairs { self.set_target(target, value); } }