diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 0761c0ab..df7fb895 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -3,7 +3,9 @@ use std::borrow::Borrow; use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; use crate::gates::exponentiation::ExponentiationGate; +use crate::plonk_common::reduce_with_powers_recursive; use crate::target::Target; +use crate::util::log2_ceil; impl, const D: usize> CircuitBuilder { /// Computes `-x`. @@ -168,7 +170,6 @@ impl, const D: usize> CircuitBuilder { base } - // TODO: Optimize this, maybe with a new gate. // TODO: Test /// Exponentiate `base` to the power of `exponent`, given by its little-endian bits. pub fn exp_from_bits( @@ -176,61 +177,36 @@ impl, const D: usize> CircuitBuilder { base: Target, exponent_bits: impl Iterator>, ) -> Target { - let mut current = base; - let one = self.one(); - let mut product = one; + let exp_bits_vec: Vec = exponent_bits.map(|b| *b.borrow()).collect::>(); + let gate = ExponentiationGate::new(exp_bits_vec.len()); + let gate_index = self.add_gate(gate.clone(), vec![]); - for bit in exponent_bits { - let multiplicand = self.select(*bit.borrow(), current, one); - product = self.mul(product, multiplicand); - current = self.mul(current, current); - } + let two = self.constant(F::TWO); + let exponent = reduce_with_powers_recursive(self, &exp_bits_vec[..], two); - product + self.route(exponent, Target::wire(gate_index, gate.wire_base())); + self.route(exponent, Target::wire(gate_index, gate.wire_power())); + exp_bits_vec.iter().enumerate().for_each(|(i, bit)| { + self.route(*bit, Target::wire(gate_index, gate.wire_power_bit(i))); + }); + + Target::wire(gate_index, gate.wire_output()) } - // TODO: Optimize this, maybe with a new gate. // TODO: Test /// Exponentiate `base` to the power of `exponent`, where `exponent < 2^num_bits`. pub fn exp(&mut self, base: Target, exponent: Target, num_bits: usize) -> Target { let exponent_bits = self.split_le(exponent, num_bits); - let gate = ExponentiationGate::new(exponent_bits.len()); - let gate_index = self.add_gate(gate.clone(), vec![]); - - self.route(exponent, Target::wire(gate_index, gate.wire_power())); - exponent_bits.iter().enumerate().for_each(|(i, &bit)| { - self.route(bit, Target::wire(gate_index, gate.wire_power_bit(i))); - }); - - Target::wire(gate_index, gate.wire_output()) + self.exp_from_bits(base, exponent_bits.iter()) } /// Exponentiate `base` to the power of a known `exponent`. // TODO: Test pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target { - let mut exp_bits = Vec::new(); - let mut cur_exp = exponent; - while cur_exp > 0 { - exp_bits.push(cur_exp % 2); - cur_exp /= 2; - } - let exp_target = self.constant(F::from_canonical_u64(exponent)); - let exp_bits_targets: Vec<_> = exp_bits - .iter() - .map(|b| self.constant(F::from_canonical_u64(*b))) - .collect(); - - let gate = ExponentiationGate::new(exp_bits.len()); - let gate_index = self.add_gate(gate.clone(), vec![]); - - self.route(exp_target, Target::wire(gate_index, gate.wire_power())); - exp_bits_targets.iter().enumerate().for_each(|(i, &bit)| { - self.route(bit, Target::wire(gate_index, gate.wire_power_bit(i))); - }); - - Target::wire(gate_index, gate.wire_output()) + let num_bits = log2_ceil(exponent as usize + 1); + self.exp(base, exp_target, num_bits) } /// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`. diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index cd1694e4..3ec017f3 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -6,7 +6,7 @@ use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use crate::plonk_common::{reduce_with_powers, reduce_with_powers_recursive}; +use crate::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::witness::PartialWitness; @@ -80,9 +80,9 @@ impl, const D: usize, const B: usize> Gate for BaseSumGat let sum = vars.local_wires[Self::WIRE_SUM]; let reversed_sum = vars.local_wires[Self::WIRE_REVERSED_SUM]; let mut limbs = vars.local_wires[self.limbs()].to_vec(); - let computed_sum = reduce_with_powers_recursive(builder, &limbs, base); + let computed_sum = reduce_with_powers_ext_recursive(builder, &limbs, base); limbs.reverse(); - let computed_reversed_sum = reduce_with_powers_recursive(builder, &limbs, base); + let computed_reversed_sum = reduce_with_powers_ext_recursive(builder, &limbs, base); let mut constraints = vec![ builder.sub_extension(computed_sum, sum), builder.sub_extension(computed_reversed_sum, reversed_sum), diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index baaabc2b..900559e6 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -6,7 +6,7 @@ use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::Gate; use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use crate::plonk_common::{reduce_with_powers, reduce_with_powers_recursive}; +use crate::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::wire::Wire; @@ -160,7 +160,7 @@ impl, const D: usize> Gate for ExponentiationGate { let mut constraints = Vec::new(); let two = builder.constant(F::TWO); - let computed_power = reduce_with_powers_recursive(builder, &power_bits, two); + let computed_power = reduce_with_powers_ext_recursive(builder, &power_bits, two); let power_diff = builder.sub_extension(power, computed_power); constraints.push(power_diff); diff --git a/src/plonk_common.rs b/src/plonk_common.rs index 62552c0a..595fed91 100644 --- a/src/plonk_common.rs +++ b/src/plonk_common.rs @@ -158,6 +158,18 @@ pub(crate) fn reduce_with_powers(terms: &[F], alpha: F) -> F { } pub(crate) fn reduce_with_powers_recursive, const D: usize>( + builder: &mut CircuitBuilder, + terms: &[Target], + alpha: Target, +) -> Target { + let mut sum = builder.zero(); + for &term in terms.iter().rev() { + sum = builder.mul_add(sum, alpha, term); + } + sum +} + +pub(crate) fn reduce_with_powers_ext_recursive, const D: usize>( builder: &mut CircuitBuilder, terms: &[ExtensionTarget], alpha: Target,