diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 00dbbe21..0931cc88 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,5 +1,7 @@ use std::borrow::Borrow; +use itertools::Itertools; + use crate::field::extension_field::Extendable; use crate::field::field_types::{PrimeField, RichField}; use crate::gates::arithmetic_base::ArithmeticGate; @@ -206,11 +208,11 @@ impl, const D: usize> CircuitBuilder { /// Multiply `n` `Target`s. pub fn mul_many(&mut self, terms: &[Target]) -> Target { - let terms_ext = terms + terms .iter() - .map(|&t| self.convert_to_ext(t)) - .collect::>(); - self.mul_many_extension(&terms_ext).to_target_array()[0] + .copied() + .fold1(|acc, t| self.mul(acc, t)) + .unwrap_or_else(|| self.one()) } /// Exponentiate `base` to the power of `2^power_log`. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 7b6535f9..d81943ab 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -1,8 +1,11 @@ +use itertools::Itertools; + use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::{Field, PrimeField, RichField}; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; +use crate::gates::multiplication_extension::MulExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -41,13 +44,19 @@ impl, const D: usize> CircuitBuilder { return result; } + let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) { + // If the addend is zero, we use a multiplication gate. + self.compute_mul_extension_operation(operation) + } else { + // Otherwise, we use an arithmetic gate. + self.compute_arithmetic_extension_operation(operation) + }; // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. - let result = self.add_arithmetic_extension_operation(operation); self.arithmetic_results.insert(operation, result); result } - fn add_arithmetic_extension_operation( + fn compute_arithmetic_extension_operation( &mut self, operation: ExtensionArithmeticOperation, ) -> ExtensionTarget { @@ -70,6 +79,22 @@ impl, const D: usize> CircuitBuilder { ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_output(i)) } + fn compute_mul_extension_operation( + &mut self, + operation: ExtensionArithmeticOperation, + ) -> ExtensionTarget { + let (gate, i) = self.find_mul_gate(operation.const_0); + let wires_multiplicand_0 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_0(i)); + let wires_multiplicand_1 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_1(i)); + + self.connect_extension(operation.multiplicand_0, wires_multiplicand_0); + self.connect_extension(operation.multiplicand_1, wires_multiplicand_1); + + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_output(i)) + } + /// Checks for special cases where the value of /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` /// can be determined without adding an `ArithmeticGate`. @@ -273,11 +298,11 @@ impl, const D: usize> CircuitBuilder { /// Multiply `n` `ExtensionTarget`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut product = self.one_extension(); - for &term in terms { - product = self.mul_extension(product, term); - } - product + terms + .iter() + .copied() + .fold1(|acc, t| self.mul_extension(acc, t)) + .unwrap_or_else(|| self.one_extension()) } /// Like `mul_add`, but for `ExtensionTarget`s. diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 93de5e97..369c9ea5 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -14,6 +14,7 @@ pub mod gate_tree; pub mod gmimc; pub mod insertion; pub mod interpolation; +pub mod multiplication_extension; pub mod noop; pub mod poseidon; pub(crate) mod poseidon_mds; diff --git a/src/gates/multiplication_extension.rs b/src/gates/multiplication_extension.rs new file mode 100644 index 00000000..4c385b79 --- /dev/null +++ b/src/gates/multiplication_extension.rs @@ -0,0 +1,204 @@ +use std::ops::Range; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which can perform a weighted multiplication, i.e. `result = c0 x y`. If the config +/// supports enough routed wires, it can support several such operations in one gate. +#[derive(Debug)] +pub struct MulExtensionGate { + /// Number of multiplications performed by the gate. + pub num_ops: usize, +} + +impl MulExtensionGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + } + } + + /// Determine the maximum number of operations that can fit in one gate for the given config. + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 3 * D; + config.num_routed_wires / wires_per_op + } + + pub fn wires_ith_multiplicand_0(i: usize) -> Range { + 3 * D * i..3 * D * i + D + } + pub fn wires_ith_multiplicand_1(i: usize) -> Range { + 3 * D * i + D..3 * D * i + 2 * D + } + pub fn wires_ith_output(i: usize) -> Range { + 3 * D * i + 2 * D..3 * D * i + 3 * D + } +} + +impl, const D: usize> Gate for MulExtensionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); + + constraints.extend((output - computed_output).to_basefield_array()); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext(Self::wires_ith_output(i)); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); + + constraints.extend((output - computed_output).to_basefield_array()); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = { + let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); + builder.scalar_mul_ext_algebra(const_0, mul) + }; + + let diff = builder.sub_ext_algebra(output, computed_output); + constraints.extend(diff.to_ext_target_array()); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + MulExtensionGenerator { + gate_index, + const_0: local_constants[0], + i, + } + .adapter(), + ); + g + }) + .collect::>() + } + + fn num_wires(&self) -> usize { + self.num_ops * 3 * D + } + + fn num_constants(&self) -> usize { + 1 + } + + fn degree(&self) -> usize { + 3 + } + + fn num_constraints(&self) -> usize { + self.num_ops * D + } +} + +#[derive(Clone, Debug)] +struct MulExtensionGenerator, const D: usize> { + gate_index: usize, + const_0: F, + i: usize, +} + +impl, const D: usize> SimpleGenerator + for MulExtensionGenerator +{ + fn dependencies(&self) -> Vec { + MulExtensionGate::::wires_ith_multiplicand_0(self.i) + .chain(MulExtensionGate::::wires_ith_multiplicand_1(self.i)) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let multiplicand_0 = + extract_extension(MulExtensionGate::::wires_ith_multiplicand_0(self.i)); + let multiplicand_1 = + extract_extension(MulExtensionGate::::wires_ith_multiplicand_1(self.i)); + + let output_target = ExtensionTarget::from_range( + self.gate_index, + MulExtensionGate::::wires_ith_output(self.i), + ); + + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0); + + out_buffer.set_extension_target(output_target, computed_output) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::multiplication_extension::MulExtensionGate; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn low_degree() { + let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_low_degree::(gate); + } + + #[test] + fn eval_fns() -> Result<()> { + let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_eval_fns::(gate) + } +} diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 32ea59b6..730699b9 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -20,6 +20,7 @@ use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; +use crate::gates::multiplication_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; @@ -769,6 +770,8 @@ pub struct BatchedGates, const D: usize> { pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, + pub(crate) free_mul: HashMap, + /// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate /// index `g` and already using `i` random accesses. pub(crate) free_random_access: HashMap, @@ -793,6 +796,7 @@ impl, const D: usize> BatchedGates { Self { free_arithmetic: HashMap::new(), free_base_arithmetic: HashMap::new(), + free_mul: HashMap::new(), free_random_access: HashMap::new(), current_switch_gates: Vec::new(), current_u32_arithmetic_gate: None, @@ -865,6 +869,33 @@ impl, const D: usize> CircuitBuilder { (gate, i) } + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_mul + .get(&const_0) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + MulExtensionGate::new_from_config(&self.config), + vec![const_0], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < MulExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates.free_mul.insert(const_0, (gate, i + 1)); + } else { + self.batched_gates.free_mul.remove(&const_0); + } + + (gate, i) + } + /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index /// `g` and the gate's `i`-th random access is available. @@ -1021,6 +1052,22 @@ impl, const D: usize> CircuitBuilder { assert!(self.batched_gates.free_arithmetic.is_empty()); } + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. + fn fill_mul_gates(&mut self) { + let zero = self.zero_extension(); + for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() { + for _ in i..MulExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero); + self.connect_extension(dummy, zero); + } + } + assert!(self.batched_gates.free_mul.is_empty()); + } + /// Fill the remaining unused random access operations with zeros, so that all /// `RandomAccessGenerator`s are run. fn fill_random_access_gates(&mut self) { @@ -1110,6 +1157,7 @@ impl, const D: usize> CircuitBuilder { fn fill_batched_gates(&mut self) { self.fill_arithmetic_gates(); self.fill_base_arithmetic_gates(); + self.fill_mul_gates(); self.fill_random_access_gates(); self.fill_switch_gates(); self.fill_u32_arithmetic_gates();