diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 7b6535f9..7489ed4c 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -3,6 +3,7 @@ use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::{Field, PrimeField, RichField}; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; +use crate::gates::multiplication_extension::MulExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -41,8 +42,12 @@ impl, const D: usize> CircuitBuilder { return result; } + let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) { + self.add_mul_extension_operation(operation) + } else { + self.add_arithmetic_extension_operation(operation) + }; // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. - let result = self.add_arithmetic_extension_operation(operation); self.arithmetic_results.insert(operation, result); result } @@ -70,6 +75,22 @@ impl, const D: usize> CircuitBuilder { ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_output(i)) } + fn add_mul_extension_operation( + &mut self, + operation: ExtensionArithmeticOperation, + ) -> ExtensionTarget { + let (gate, i) = self.find_mul_gate(operation.const_0); + let wires_multiplicand_0 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_0(i)); + let wires_multiplicand_1 = + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_multiplicand_1(i)); + + self.connect_extension(operation.multiplicand_0, wires_multiplicand_0); + self.connect_extension(operation.multiplicand_1, wires_multiplicand_1); + + ExtensionTarget::from_range(gate, MulExtensionGate::::wires_ith_output(i)) + } + /// Checks for special cases where the value of /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` /// can be determined without adding an `ArithmeticGate`. @@ -556,7 +577,7 @@ mod tests { for (&v, &t) in vs.iter().zip(&ts) { pw.set_extension_target(t, v); } - let mul0 = builder.mul_many_extension(&ts); + // let mul0 = builder.mul_many_extension(&ts); let mul1 = { let mut acc = builder.one_extension(); for &t in &ts { @@ -566,7 +587,7 @@ mod tests { }; let mul2 = builder.constant_extension(vs.into_iter().product()); - builder.connect_extension(mul0, mul1); + // builder.connect_extension(mul0, mul1); builder.connect_extension(mul1, mul2); let data = builder.build(); @@ -633,4 +654,31 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_mul_ext() -> Result<()> { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let y = FF::rand(); + let z = x * y; + + let xt = builder.constant_extension(x); + let yt = builder.constant_extension(y); + let zt = builder.constant_extension(z); + let comp_zt = builder.mul_extension(xt, yt); + builder.connect_extension(zt, comp_zt); + + let data = builder.build(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 93de5e97..369c9ea5 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -14,6 +14,7 @@ pub mod gate_tree; pub mod gmimc; pub mod insertion; pub mod interpolation; +pub mod multiplication_extension; pub mod noop; pub mod poseidon; pub(crate) mod poseidon_mds; diff --git a/src/gates/multiplication_extension.rs b/src/gates/multiplication_extension.rs new file mode 100644 index 00000000..16cc6315 --- /dev/null +++ b/src/gates/multiplication_extension.rs @@ -0,0 +1,204 @@ +use std::ops::Range; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::extension_field::FieldExtension; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// supports enough routed wires, it can support several such operations in one gate. +#[derive(Debug)] +pub struct MulExtensionGate { + /// Number of arithmetic operations performed by an arithmetic gate. + pub num_ops: usize, +} + +impl MulExtensionGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + } + } + + /// Determine the maximum number of operations that can fit in one gate for the given config. + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 3 * D; + config.num_routed_wires / wires_per_op + } + + pub fn wires_ith_multiplicand_0(i: usize) -> Range { + 3 * D * i..3 * D * i + D + } + pub fn wires_ith_multiplicand_1(i: usize) -> Range { + 3 * D * i + D..3 * D * i + 2 * D + } + pub fn wires_ith_output(i: usize) -> Range { + 3 * D * i + 2 * D..3 * D * i + 3 * D + } +} + +impl, const D: usize> Gate for MulExtensionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); + + constraints.extend((output - computed_output).to_basefield_array()); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext(Self::wires_ith_output(i)); + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); + + constraints.extend((output - computed_output).to_basefield_array()); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let const_0 = vars.local_constants[0]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = { + let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); + builder.scalar_mul_ext_algebra(const_0, mul) + }; + + let diff = builder.sub_ext_algebra(output, computed_output); + constraints.extend(diff.to_ext_target_array()); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + MulExtensionGenerator { + gate_index, + const_0: local_constants[0], + i, + } + .adapter(), + ); + g + }) + .collect::>() + } + + fn num_wires(&self) -> usize { + self.num_ops * 3 * D + } + + fn num_constants(&self) -> usize { + 1 + } + + fn degree(&self) -> usize { + 3 + } + + fn num_constraints(&self) -> usize { + self.num_ops * D + } +} + +#[derive(Clone, Debug)] +struct MulExtensionGenerator, const D: usize> { + gate_index: usize, + const_0: F, + i: usize, +} + +impl, const D: usize> SimpleGenerator + for MulExtensionGenerator +{ + fn dependencies(&self) -> Vec { + MulExtensionGate::::wires_ith_multiplicand_0(self.i) + .chain(MulExtensionGate::::wires_ith_multiplicand_1(self.i)) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let multiplicand_0 = + extract_extension(MulExtensionGate::::wires_ith_multiplicand_0(self.i)); + let multiplicand_1 = + extract_extension(MulExtensionGate::::wires_ith_multiplicand_1(self.i)); + + let output_target = ExtensionTarget::from_range( + self.gate_index, + MulExtensionGate::::wires_ith_output(self.i), + ); + + let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0); + + out_buffer.set_extension_target(output_target, computed_output) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::multiplication_extension::MulExtensionGate; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn low_degree() { + let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_low_degree::(gate); + } + + #[test] + fn eval_fns() -> Result<()> { + let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_eval_fns::(gate) + } +} diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 32ea59b6..eda31ece 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -20,6 +20,7 @@ use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; +use crate::gates::multiplication_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; @@ -70,7 +71,7 @@ pub struct CircuitBuilder, const D: usize> { marked_targets: Vec>, /// Generators used to generate the witness. - generators: Vec>>, + pub generators: Vec>>, constants_to_targets: HashMap, targets_to_constants: HashMap, @@ -769,6 +770,8 @@ pub struct BatchedGates, const D: usize> { pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, + pub(crate) free_mul: HashMap, + /// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate /// index `g` and already using `i` random accesses. pub(crate) free_random_access: HashMap, @@ -793,6 +796,7 @@ impl, const D: usize> BatchedGates { Self { free_arithmetic: HashMap::new(), free_base_arithmetic: HashMap::new(), + free_mul: HashMap::new(), free_random_access: HashMap::new(), current_switch_gates: Vec::new(), current_u32_arithmetic_gate: None, @@ -865,6 +869,33 @@ impl, const D: usize> CircuitBuilder { (gate, i) } + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_mul + .get(&const_0) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + MulExtensionGate::new_from_config(&self.config), + vec![const_0], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < MulExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates.free_mul.insert(const_0, (gate, i + 1)); + } else { + self.batched_gates.free_mul.remove(&const_0); + } + + (gate, i) + } + /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index /// `g` and the gate's `i`-th random access is available. @@ -1021,6 +1052,22 @@ impl, const D: usize> CircuitBuilder { assert!(self.batched_gates.free_arithmetic.is_empty()); } + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. + fn fill_mul_gates(&mut self) { + let zero = self.zero_extension(); + for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() { + for _ in i..MulExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero); + self.connect_extension(dummy, zero); + } + } + assert!(self.batched_gates.free_mul.is_empty()); + } + /// Fill the remaining unused random access operations with zeros, so that all /// `RandomAccessGenerator`s are run. fn fill_random_access_gates(&mut self) { @@ -1110,6 +1157,7 @@ impl, const D: usize> CircuitBuilder { fn fill_batched_gates(&mut self) { self.fill_arithmetic_gates(); self.fill_base_arithmetic_gates(); + self.fill_mul_gates(); self.fill_random_access_gates(); self.fill_switch_gates(); self.fill_u32_arithmetic_gates();