Reuse a buffer of generated values (#142)

* Reuse a buffer of generated values

To avoid allocating `GeneratedValues` all the time. Saves ~60ms or so.

* PR feedback
This commit is contained in:
Daniel Lubarov 2021-08-02 10:55:10 -07:00 committed by GitHub
parent 730962ceac
commit d6211b8ab8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 63 additions and 83 deletions

View File

@ -536,11 +536,11 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for QuotientGeneratorE
deps
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let num = witness.get_extension_target(self.numerator);
let dem = witness.get_extension_target(self.denominator);
let quotient = num / dem;
GeneratedValues::singleton_extension_target(self.quotient, quotient)
out_buffer.set_extension_target(self.quotient, quotient)
}
}

View File

@ -56,15 +56,12 @@ impl<F: Field> SimpleGenerator<F> for LowHighGenerator {
vec![self.integer]
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let integer_value = witness.get_target(self.integer).to_canonical_u64();
let low = integer_value & ((1 << self.n_log) - 1);
let high = integer_value >> self.n_log;
let mut result = GeneratedValues::with_capacity(2);
result.set_target(self.low, F::from_canonical_u64(low));
result.set_target(self.high, F::from_canonical_u64(high));
result
out_buffer.set_target(self.low, F::from_canonical_u64(low));
out_buffer.set_target(self.high, F::from_canonical_u64(high));
}
}

View File

@ -76,7 +76,7 @@ impl<F: Field, const B: usize> SimpleGenerator<F> for BaseSumGenerator<B> {
self.limbs.clone()
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let sum = self
.limbs
.iter()
@ -84,10 +84,10 @@ impl<F: Field, const B: usize> SimpleGenerator<F> for BaseSumGenerator<B> {
.rev()
.fold(F::ZERO, |acc, limb| acc * F::from_canonical_usize(B) + limb);
GeneratedValues::singleton_target(
out_buffer.set_target(
Target::wire(self.gate_index, BaseSumGate::<B>::WIRE_SUM),
sum,
)
);
}
}

View File

@ -110,13 +110,12 @@ impl<F: Field> SimpleGenerator<F> for SplitGenerator {
vec![self.integer]
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let mut integer_value = witness.get_target(self.integer).to_canonical_u64();
let mut result = GeneratedValues::with_capacity(self.bits.len());
for &b in &self.bits {
let b_value = integer_value & 1;
result.set_target(b, F::from_canonical_u64(b_value));
out_buffer.set_target(b, F::from_canonical_u64(b_value));
integer_value >>= 1;
}
@ -124,8 +123,6 @@ impl<F: Field> SimpleGenerator<F> for SplitGenerator {
integer_value, 0,
"Integer too large to fit in given number of bits"
);
result
}
}
@ -141,13 +138,12 @@ impl<F: Field> SimpleGenerator<F> for WireSplitGenerator {
vec![self.integer]
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let mut integer_value = witness.get_target(self.integer).to_canonical_u64();
let mut result = GeneratedValues::with_capacity(self.gates.len());
for &gate in &self.gates {
let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
result.set_target(
out_buffer.set_target(
sum,
F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)),
);
@ -160,7 +156,5 @@ impl<F: Field> SimpleGenerator<F> for WireSplitGenerator {
"Integer too large to fit in {} many `BaseSumGate`s",
self.gates.len()
);
result
}
}

View File

@ -191,7 +191,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
.collect()
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let extract_extension = |range: Range<usize>| -> F::Extension {
let t = ExtensionTarget::from_range(self.gate_index, range);
witness.get_extension_target(t)
@ -211,7 +211,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
let computed_output =
multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into();
GeneratedValues::singleton_extension_target(output_target, computed_output)
out_buffer.set_extension_target(output_target, computed_output)
}
}
@ -224,7 +224,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
.collect()
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let extract_extension = |range: Range<usize>| -> F::Extension {
let t = ExtensionTarget::from_range(self.gate_index, range);
witness.get_extension_target(t)
@ -244,7 +244,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
let computed_output =
multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into();
GeneratedValues::singleton_extension_target(output_target, computed_output)
out_buffer.set_extension_target(output_target, computed_output)
}
}

View File

@ -144,7 +144,7 @@ impl<F: Field, const B: usize> SimpleGenerator<F> for BaseSplitGenerator<B> {
vec![Target::wire(self.gate_index, BaseSumGate::<B>::WIRE_SUM)]
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let sum_value = witness
.get_target(Target::wire(self.gate_index, BaseSumGate::<B>::WIRE_SUM))
.to_canonical_u64() as usize;
@ -169,16 +169,13 @@ impl<F: Field, const B: usize> SimpleGenerator<F> for BaseSplitGenerator<B> {
.iter()
.fold(F::ZERO, |acc, &x| acc * b_field + x);
let mut result = GeneratedValues::with_capacity(self.num_limbs + 1);
result.set_target(
out_buffer.set_target(
Target::wire(self.gate_index, BaseSumGate::<B>::WIRE_REVERSED_SUM),
reversed_sum,
);
for (b, b_value) in limbs.zip(limbs_value) {
result.set_target(b, b_value);
out_buffer.set_target(b, b_value);
}
result
}
}

View File

@ -85,12 +85,12 @@ impl<F: Field> SimpleGenerator<F> for ConstantGenerator<F> {
Vec::new()
}
fn run_once(&self, _witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, _witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let wire = Wire {
gate: self.gate_index,
input: ConstantGate::WIRE_OUTPUT,
};
GeneratedValues::singleton_target(Target::Wire(wire), self.constant)
out_buffer.set_target(Target::Wire(wire), self.constant);
}
}

View File

@ -220,7 +220,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ExponentiationGene
deps
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
@ -245,16 +245,13 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ExponentiationGene
current_intermediate_value *= current_intermediate_value;
}
let mut result = GeneratedValues::with_capacity(num_power_bits + 1);
for i in 0..num_power_bits {
let intermediate_value_wire = local_wire(self.gate.wire_intermediate_value(i));
result.set_wire(intermediate_value_wire, intermediate_values[i]);
out_buffer.set_wire(intermediate_value_wire, intermediate_values[i]);
}
let output_wire = local_wire(self.gate.wire_output());
result.set_wire(output_wire, intermediate_values[num_power_bits - 1]);
result
out_buffer.set_wire(output_wire, intermediate_values[num_power_bits - 1]);
}
}

View File

@ -265,9 +265,7 @@ impl<F: Extendable<D>, const D: usize, const R: usize> SimpleGenerator<F>
.collect()
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
let mut result = GeneratedValues::with_capacity(R + W);
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let mut state = (0..W)
.map(|i| {
witness.get_wire(Wire {
@ -295,7 +293,7 @@ impl<F: Extendable<D>, const D: usize, const R: usize> SimpleGenerator<F>
for r in 0..R {
let active = r % W;
let cubing_input = state[active] + addition_buffer + self.constants[r];
result.set_wire(
out_buffer.set_wire(
Wire {
gate: self.gate_index,
input: GMiMCGate::<F, D, R>::wire_cubing_input(r),
@ -309,7 +307,7 @@ impl<F: Extendable<D>, const D: usize, const R: usize> SimpleGenerator<F>
for i in 0..W {
state[i] += addition_buffer;
result.set_wire(
out_buffer.set_wire(
Wire {
gate: self.gate_index,
input: GMiMCGate::<F, D, R>::wire_output(i),
@ -317,8 +315,6 @@ impl<F: Extendable<D>, const D: usize, const R: usize> SimpleGenerator<F>
state[i],
);
}
result
}
}

View File

@ -255,7 +255,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for InsertionGenerator
deps
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
@ -304,14 +304,12 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for InsertionGenerator
let mut result = GeneratedValues::<F>::with_capacity((vec_size + 1) * (D + 2));
for i in 0..=vec_size {
let output_wires = self.gate.wires_output_list_item(i).map(local_wire);
result.set_ext_wires(output_wires, new_vec[i]);
out_buffer.set_ext_wires(output_wires, new_vec[i]);
let equality_dummy_wire = local_wire(self.gate.wires_equality_dummy_for_round_r(i));
result.set_wire(equality_dummy_wire, equality_dummy_vals[i]);
out_buffer.set_wire(equality_dummy_wire, equality_dummy_vals[i]);
let insert_here_wire = local_wire(self.gate.wires_insert_here_for_round_r(i));
result.set_wire(insert_here_wire, insert_here_vals[i]);
out_buffer.set_wire(insert_here_wire, insert_here_vals[i]);
}
result
}
}

View File

@ -239,7 +239,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for InterpolationGener
deps
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let n = self.gate.num_points;
let local_wire = |input| Wire {
@ -270,15 +270,13 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for InterpolationGener
let mut result = GeneratedValues::<F>::with_capacity(D * (self.gate.num_points + 1));
for (i, &coeff) in interpolant.coeffs.iter().enumerate() {
let wires = self.gate.wires_coeff(i).map(local_wire);
result.set_ext_wires(wires, coeff);
out_buffer.set_ext_wires(wires, coeff);
}
let evaluation_point = get_local_ext(self.gate.wires_evaluation_point());
let evaluation_value = interpolant.eval(evaluation_point);
let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire);
result.set_ext_wires(evaluation_value_wires, evaluation_value);
result
out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value);
}
}

View File

@ -177,7 +177,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ReducingGenerator<
.collect()
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let extract_extension = |range: Range<usize>| -> F::Extension {
let t = ExtensionTarget::from_range(self.gate_index, range);
witness.get_extension_target(t)
@ -202,12 +202,10 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ReducingGenerator<
let mut acc = old_acc;
for i in 0..self.gate.num_coeffs {
let computed_acc = acc * alpha + coeffs[i].into();
result.set_extension_target(accs[i], computed_acc);
out_buffer.set_extension_target(accs[i], computed_acc);
acc = computed_acc;
}
result.set_extension_target(output, acc);
result
out_buffer.set_extension_target(output, acc);
}
}

View File

@ -34,18 +34,20 @@ pub(crate) fn generate_partial_witness<F: Field>(
// We also track a list of "expired" generators which have already returned false.
let mut generator_is_expired = vec![false; generators.len()];
let mut buffer = GeneratedValues::empty();
// Keep running generators until no generators are queued.
while !pending_generator_indices.is_empty() {
let mut next_pending_generator_indices = HashSet::new();
for &generator_idx in &pending_generator_indices {
let (result, finished) = generators[generator_idx].run(&witness);
let finished = generators[generator_idx].run(&witness, &mut buffer);
if finished {
generator_is_expired[generator_idx] = true;
}
// Enqueue unfinished generators that were watching one of the newly populated targets.
for (watch, _) in &result.target_values {
for (watch, _) in &buffer.target_values {
if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) {
for &watching_generator_idx in watching_generator_indices {
if !generator_is_expired[watching_generator_idx] {
@ -55,7 +57,7 @@ pub(crate) fn generate_partial_witness<F: Field>(
}
}
witness.extend(result);
witness.extend(buffer.target_values.drain(..));
}
pending_generator_indices = next_pending_generator_indices;
@ -73,11 +75,10 @@ pub trait WitnessGenerator<F: Field>: 'static + Send + Sync {
/// the generator will be queued to run.
fn watch_list(&self) -> Vec<Target>;
/// Run this generator, returning a `PartialWitness` containing any new witness elements, and a
/// flag indicating whether the generator is finished. If the flag is true, the generator will
/// never be run again, otherwise it will be queued for another run next time a target in its
/// watch list is populated.
fn run(&self, witness: &PartialWitness<F>) -> (GeneratedValues<F>, bool);
/// Run this generator, returning a flag indicating whether the generator is finished. If the
/// flag is true, the generator will never be run again, otherwise it will be queued for another
/// run next time a target in its watch list is populated.
fn run(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) -> bool;
}
/// Values generated by a generator invocation.
@ -108,6 +109,10 @@ impl<F: Field> GeneratedValues<F> {
vec![(target, value)].into()
}
pub fn clear(&mut self) {
self.target_values.clear();
}
pub fn singleton_extension_target<const D: usize>(
et: ExtensionTarget<D>,
value: F::Extension,
@ -171,7 +176,7 @@ impl<F: Field> GeneratedValues<F> {
pub trait SimpleGenerator<F: Field>: 'static + Send + Sync {
fn dependencies(&self) -> Vec<Target>;
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F>;
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>);
}
impl<F: Field, SG: SimpleGenerator<F>> WitnessGenerator<F> for SG {
@ -179,11 +184,12 @@ impl<F: Field, SG: SimpleGenerator<F>> WitnessGenerator<F> for SG {
self.dependencies()
}
fn run(&self, witness: &PartialWitness<F>) -> (GeneratedValues<F>, bool) {
fn run(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) -> bool {
if witness.contains_all(&self.dependencies()) {
(self.run_once(witness), true)
self.run_once(witness, out_buffer);
true
} else {
(GeneratedValues::empty(), false)
false
}
}
}
@ -200,9 +206,9 @@ impl<F: Field> SimpleGenerator<F> for CopyGenerator {
vec![self.src]
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let value = witness.get_target(self.src);
GeneratedValues::singleton_target(self.dst, value)
out_buffer.set_target(self.dst, value);
}
}
@ -216,10 +222,10 @@ impl<F: Field> SimpleGenerator<F> for RandomValueGenerator {
Vec::new()
}
fn run_once(&self, _witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, _witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let random_value = F::rand();
GeneratedValues::singleton_target(self.target, random_value)
out_buffer.set_target(self.target, random_value);
}
}
@ -234,7 +240,7 @@ impl<F: Field> SimpleGenerator<F> for NonzeroTestGenerator {
vec![self.to_test]
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let to_test_value = witness.get_target(self.to_test);
let dummy_value = if to_test_value == F::ZERO {
@ -243,6 +249,6 @@ impl<F: Field> SimpleGenerator<F> for NonzeroTestGenerator {
to_test_value.inverse()
};
GeneratedValues::singleton_target(self.dummy, dummy_value)
out_buffer.set_target(self.dummy, dummy_value);
}
}

View File

@ -9,7 +9,6 @@ use crate::field::field_types::Field;
use crate::gates::gate::GateInstance;
use crate::hash::hash_types::HashOut;
use crate::hash::hash_types::HashOutTarget;
use crate::iop::generator::GeneratedValues;
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::plonk::copy_constraint::CopyConstraint;
@ -150,8 +149,8 @@ impl<F: Field> PartialWitness<F> {
self.set_wires(wires, &value.to_basefield_array());
}
pub fn extend(&mut self, other: GeneratedValues<F>) {
for (target, value) in other.target_values {
pub fn extend<I: Iterator<Item = (Target, F)>>(&mut self, pairs: I) {
for (target, value) in pairs {
self.set_target(target, value);
}
}