diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 3fe90019..a053b761 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -215,7 +215,7 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^power_log`. pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target { - if power_log > ArithmeticGate::new_from_config(&self.config).num_ops { + if power_log > self.num_base_arithmetic_ops_per_gate() { // Cheaper to just use `ExponentiateGate`. return self.exp_u64(base, 1 << power_log); } @@ -269,7 +269,7 @@ impl, const D: usize> CircuitBuilder { let base_t = self.constant(base); let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); - if exponent_bits.len() > ArithmeticGate::new_from_config(&self.config).num_ops { + if exponent_bits.len() > self.num_base_arithmetic_ops_per_gate() { // Cheaper to just use `ExponentiateGate`. return self.exp_from_bits(base_t, exponent_bits); } diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 30bdea6a..d60324ce 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -1,5 +1,7 @@ use std::borrow::Borrow; +use itertools::Itertools; + use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gates::base_sum::BaseSumGate; @@ -29,21 +31,26 @@ impl, const D: usize> CircuitBuilder { /// the number with little-endian bit representation given by `bits`. pub(crate) fn le_sum( &mut self, - bits: impl ExactSizeIterator> + Clone, + mut bits: impl Iterator>, ) -> Target { + let bits = bits.map(|b| *b.borrow()).collect_vec(); let num_bits = bits.len(); if num_bits == 0 { return self.zero(); - } else if num_bits == 1 { - let mut bits = bits; - return bits.next().unwrap().borrow().target; - } else if num_bits == 2 { - let two = self.two(); - let mut bits = bits; - let b0 = bits.next().unwrap().borrow().target; - let b1 = bits.next().unwrap().borrow().target; - return self.mul_add(two, b1, b0); } + + // Check if it's cheaper to just do this with arithmetic operations. + let arithmetic_ops = num_bits - 1; + if arithmetic_ops <= self.num_base_arithmetic_ops_per_gate() { + let two = self.two(); + let mut rev_bits = bits.iter().rev(); + let mut sum = rev_bits.next().unwrap().target; + for &bit in rev_bits { + sum = self.mul_add(two, sum, bit.target); + } + return sum; + } + debug_assert!( BaseSumGate::<2>::START_LIMBS + num_bits <= self.config.num_routed_wires, "Not enough routed wires." @@ -51,10 +58,10 @@ impl, const D: usize> CircuitBuilder { let gate_type = BaseSumGate::<2>::new_from_config::(&self.config); let gate_index = self.add_gate(gate_type, vec![]); for (limb, wire) in bits - .clone() + .iter() .zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits) { - self.connect(limb.borrow().target, Target::wire(gate_index, wire)); + self.connect(limb.target, Target::wire(gate_index, wire)); } for l in gate_type.limbs().skip(num_bits) { self.assert_zero(Target::wire(gate_index, l)); @@ -62,7 +69,7 @@ impl, const D: usize> CircuitBuilder { self.add_simple_generator(BaseSumGenerator::<2> { gate_index, - limbs: bits.map(|l| *l.borrow()).collect(), + limbs: bits, }); Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM) diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index aac9d42e..4425f193 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -379,6 +379,20 @@ impl, const D: usize> CircuitBuilder { } } + /// The number of (base field) `arithmetic` operations that can be performed in a single gate. + pub(crate) fn num_base_arithmetic_ops_per_gate(&self) -> usize { + if self.config.use_base_arithmetic_gate { + ArithmeticGate::new_from_config(&self.config).num_ops + } else { + self.num_ext_arithmetic_ops_per_gate() + } + } + + /// The number of `arithmetic_extension` operations that can be performed in a single gate. + pub(crate) fn num_ext_arithmetic_ops_per_gate(&self) -> usize { + ArithmeticExtensionGate::::new_from_config(&self.config).num_ops + } + /// The number of polynomial values that will be revealed per opening, both for the "regular" /// polynomials and for the Z polynomials. Because calculating these values involves a recursive /// dependence (the amount of blinding depends on the degree, which depends on the blinding),