diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index e9fc25a4..fc1ad37e 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -26,6 +26,7 @@ fn bench_prove, const D: usize>() -> Result<()> { num_wires: 126, num_routed_wires: 33, constant_gate_size: 6, + use_base_arithmetic_gate: false, security_bits: 128, rate_bits: 3, num_challenges: 3, diff --git a/src/fri/commitment.rs b/src/fri/commitment.rs index 8233a293..c8a13cac 100644 --- a/src/fri/commitment.rs +++ b/src/fri/commitment.rs @@ -16,8 +16,8 @@ use crate::util::reducing::ReducingFactor; use crate::util::timing::TimingTree; use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose}; -/// Two (~64 bit) field elements gives ~128 bit security. -pub const SALT_SIZE: usize = 2; +/// Four (~64 bit) field elements gives ~128 bit security. +pub const SALT_SIZE: usize = 4; /// Represents a batch FRI based commitment to a list of polynomials. pub struct PolynomialBatchCommitment { diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 103857c5..3fe90019 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,8 +1,8 @@ use std::borrow::Borrow; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::field::field_types::{PrimeField, RichField}; +use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::exponentiation::ExponentiationGate; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -33,18 +33,117 @@ impl, const D: usize> CircuitBuilder { multiplicand_1: Target, addend: Target, ) -> Target { - let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); - let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); - let addend_ext = self.convert_to_ext(addend); + // If we're not configured to use the base arithmetic gate, just call arithmetic_extension. + if !self.config.use_base_arithmetic_gate { + let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); + let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); + let addend_ext = self.convert_to_ext(addend); - self.arithmetic_extension( + return self + .arithmetic_extension( + const_0, + const_1, + multiplicand_0_ext, + multiplicand_1_ext, + addend_ext, + ) + .0[0]; + } + + // See if we can determine the result without adding an `ArithmeticGate`. + if let Some(result) = + self.arithmetic_special_cases(const_0, const_1, multiplicand_0, multiplicand_1, addend) + { + return result; + } + + // See if we've already computed the same operation. + let operation = BaseArithmeticOperation { const_0, const_1, - multiplicand_0_ext, - multiplicand_1_ext, - addend_ext, - ) - .0[0] + multiplicand_0, + multiplicand_1, + addend, + }; + if let Some(&result) = self.base_arithmetic_results.get(&operation) { + return result; + } + + // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. + let result = self.add_base_arithmetic_operation(operation); + self.base_arithmetic_results.insert(operation, result); + result + } + + fn add_base_arithmetic_operation(&mut self, operation: BaseArithmeticOperation) -> Target { + let (gate, i) = self.find_base_arithmetic_gate(operation.const_0, operation.const_1); + let wires_multiplicand_0 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_0(i)); + let wires_multiplicand_1 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_1(i)); + let wires_addend = Target::wire(gate, ArithmeticGate::wire_ith_addend(i)); + + self.connect(operation.multiplicand_0, wires_multiplicand_0); + self.connect(operation.multiplicand_1, wires_multiplicand_1); + self.connect(operation.addend, wires_addend); + + Target::wire(gate, ArithmeticGate::wire_ith_output(i)) + } + + /// Checks for special cases where the value of + /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` + /// can be determined without adding an `ArithmeticGate`. + fn arithmetic_special_cases( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: Target, + multiplicand_1: Target, + addend: Target, + ) -> Option { + let zero = self.zero(); + + let mul_0_const = self.target_as_constant(multiplicand_0); + let mul_1_const = self.target_as_constant(multiplicand_1); + let addend_const = self.target_as_constant(addend); + + let first_term_zero = + const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; + let second_term_zero = const_1 == F::ZERO || addend == zero; + + // If both terms are constant, return their (constant) sum. + let first_term_const = if first_term_zero { + Some(F::ZERO) + } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { + Some(x * y * const_0) + } else { + None + }; + let second_term_const = if second_term_zero { + Some(F::ZERO) + } else { + addend_const.map(|x| x * const_1) + }; + if let (Some(x), Some(y)) = (first_term_const, second_term_const) { + return Some(self.constant(x + y)); + } + + if first_term_zero && const_1.is_one() { + return Some(addend); + } + + if second_term_zero { + if let Some(x) = mul_0_const { + if (x * const_0).is_one() { + return Some(multiplicand_1); + } + } + if let Some(x) = mul_1_const { + if (x * const_0).is_one() { + return Some(multiplicand_0); + } + } + } + + None } /// Computes `x * y + z`. @@ -116,7 +215,7 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^power_log`. pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target { - if power_log > ArithmeticExtensionGate::::new_from_config(&self.config).num_ops { + if power_log > ArithmeticGate::new_from_config(&self.config).num_ops { // Cheaper to just use `ExponentiateGate`. return self.exp_u64(base, 1 << power_log); } @@ -170,8 +269,7 @@ impl, const D: usize> CircuitBuilder { let base_t = self.constant(base); let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); - if exponent_bits.len() > ArithmeticExtensionGate::::new_from_config(&self.config).num_ops - { + if exponent_bits.len() > ArithmeticGate::new_from_config(&self.config).num_ops { // Cheaper to just use `ExponentiateGate`. return self.exp_from_bits(base_t, exponent_bits); } @@ -221,3 +319,13 @@ impl, const D: usize> CircuitBuilder { self.inverse_extension(x_ext).0[0] } } + +/// Represents a base arithmetic operation in the circuit. Used to memoize results. +#[derive(Copy, Clone, Eq, PartialEq, Hash)] +pub(crate) struct BaseArithmeticOperation { + const_0: F, + const_1: F, + multiplicand_0: Target, + multiplicand_1: Target, + addend: Target, +} diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index e2654dcc..9fbffad3 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -4,7 +4,7 @@ use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTar use crate::field::extension_field::FieldExtension; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::{Field, PrimeField, RichField}; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -32,7 +32,7 @@ impl, const D: usize> CircuitBuilder { } // See if we've already computed the same operation. - let operation = ArithmeticOperation { + let operation = ExtensionArithmeticOperation { const_0, const_1, multiplicand_0, @@ -51,7 +51,7 @@ impl, const D: usize> CircuitBuilder { fn add_arithmetic_extension_operation( &mut self, - operation: ArithmeticOperation, + operation: ExtensionArithmeticOperation, ) -> ExtensionTarget { let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1); let wires_multiplicand_0 = ExtensionTarget::from_range( @@ -519,9 +519,9 @@ impl, const D: usize> CircuitBuilder { } } -/// Represents an arithmetic operation in the circuit. Used to memoize results. +/// Represents an extension arithmetic operation in the circuit. Used to memoize results. #[derive(Copy, Clone, Eq, PartialEq, Hash)] -pub(crate) struct ArithmeticOperation, const D: usize> { +pub(crate) struct ExtensionArithmeticOperation, const D: usize> { const_0: F, const_1: F, multiplicand_0: ExtensionTarget, diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 77e660e6..3a5f2421 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -127,8 +127,8 @@ mod tests { #[test] fn test_multiple_comparison() -> Result<()> { - for size in [1, 3, 6, 10] { - for num_bits in [20, 32, 40, 50] { + for size in [1, 3, 6] { + for num_bits in [20, 32, 40, 44] { test_list_le(size, num_bits).unwrap(); } } diff --git a/src/gates/arithmetic_base.rs b/src/gates/arithmetic_base.rs new file mode 100644 index 00000000..d5c131a5 --- /dev/null +++ b/src/gates/arithmetic_base.rs @@ -0,0 +1,212 @@ +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// supports enough routed wires, it can support several such operations in one gate. +#[derive(Debug)] +pub struct ArithmeticGate { + /// Number of arithmetic operations performed by an arithmetic gate. + pub num_ops: usize, +} + +impl ArithmeticGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + } + } + + /// Determine the maximum number of operations that can fit in one gate for the given config. + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 4; + config.num_routed_wires / wires_per_op + } + + pub fn wire_ith_multiplicand_0(i: usize) -> usize { + 4 * i + } + pub fn wire_ith_multiplicand_1(i: usize) -> usize { + 4 * i + 1 + } + pub fn wire_ith_addend(i: usize) -> usize { + 4 * i + 2 + } + pub fn wire_ith_output(i: usize) -> usize { + 4 * i + 3 + } +} + +impl, const D: usize> Gate for ArithmeticGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; + + constraints.push(output - computed_output); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; + + constraints.push(output - computed_output); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let mut constraints = Vec::new(); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = { + let scaled_mul = + builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); + builder.mul_add_extension(const_1, addend, scaled_mul) + }; + + let diff = builder.sub_extension(output, computed_output); + constraints.push(diff); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + ArithmeticBaseGenerator { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + i, + } + .adapter(), + ); + g + }) + .collect::>() + } + + fn num_wires(&self) -> usize { + self.num_ops * 4 + } + + fn num_constants(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + 3 + } + + fn num_constraints(&self) -> usize { + self.num_ops + } +} + +#[derive(Clone, Debug)] +struct ArithmeticBaseGenerator, const D: usize> { + gate_index: usize, + const_0: F, + const_1: F, + i: usize, +} + +impl, const D: usize> SimpleGenerator + for ArithmeticBaseGenerator +{ + fn dependencies(&self) -> Vec { + [ + ArithmeticGate::wire_ith_multiplicand_0(self.i), + ArithmeticGate::wire_ith_multiplicand_1(self.i), + ArithmeticGate::wire_ith_addend(self.i), + ] + .iter() + .map(|&i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let get_wire = + |wire: usize| -> F { witness.get_target(Target::wire(self.gate_index, wire)) }; + + let multiplicand_0 = get_wire(ArithmeticGate::wire_ith_multiplicand_0(self.i)); + let multiplicand_1 = get_wire(ArithmeticGate::wire_ith_multiplicand_1(self.i)); + let addend = get_wire(ArithmeticGate::wire_ith_addend(self.i)); + + let output_target = Target::wire(self.gate_index, ArithmeticGate::wire_ith_output(self.i)); + + let computed_output = + multiplicand_0 * multiplicand_1 * self.const_0 + addend * self.const_1; + + out_buffer.set_target(output_target, computed_output) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::arithmetic_base::ArithmeticGate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn low_degree() { + let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_low_degree::(gate); + } + + #[test] + fn eval_fns() -> Result<()> { + let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config()); + test_eval_fns::(gate) + } +} diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic_extension.rs similarity index 96% rename from src/gates/arithmetic.rs rename to src/gates/arithmetic_extension.rs index 95b48e2f..dbde7535 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic_extension.rs @@ -12,7 +12,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. +/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config +/// supports enough routed wires, it can support several such operations in one gate. #[derive(Debug)] pub struct ArithmeticExtensionGate { /// Number of arithmetic operations performed by an arithmetic gate. @@ -206,7 +207,7 @@ mod tests { use anyhow::Result; use crate::field::goldilocks_field::GoldilocksField; - use crate::gates::arithmetic::ArithmeticExtensionGate; + use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::plonk::circuit_data::CircuitConfig; diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs index 4d33a867..ffbc043a 100644 --- a/src/gates/assert_le.rs +++ b/src/gates/assert_le.rs @@ -470,8 +470,8 @@ mod tests { #[test] fn low_degree() { - let num_bits = 40; - let num_chunks = 5; + let num_bits = 20; + let num_chunks = 4; test_low_degree::(AssertLessThanGate::<_, 4>::new( num_bits, num_chunks, @@ -480,8 +480,8 @@ mod tests { #[test] fn eval_fns() -> Result<()> { - let num_bits = 40; - let num_chunks = 5; + let num_bits = 20; + let num_chunks = 4; test_eval_fns::(AssertLessThanGate::<_, 4>::new( num_bits, num_chunks, diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index ed9a73ac..aaba41c7 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -223,7 +223,7 @@ impl, const D: usize> Tree> { mod tests { use super::*; use crate::field::goldilocks_field::GoldilocksField; - use crate::gates::arithmetic::ArithmeticExtensionGate; + use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::gmimc::GMiMCGate; diff --git a/src/gates/mod.rs b/src/gates/mod.rs index da301c62..93de5e97 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -1,7 +1,8 @@ // Gates have `new` methods that return `GateRef`s. #![allow(clippy::new_ret_no_self)] -pub mod arithmetic; +pub mod arithmetic_base; +pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod assert_le; pub mod base_sum; diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 6e1eb69a..59c23b44 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -56,35 +56,49 @@ where /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. pub const WIRE_SWAP: usize = 2 * WIDTH; + const START_DELTA: usize = 2 * WIDTH + 1; + + /// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute the swapped inputs. + fn wire_delta(i: usize) -> usize { + assert!(i < 4); + Self::START_DELTA + i + } + + const START_FULL_0: usize = Self::START_DELTA + 4; + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set /// of full rounds. fn wire_full_sbox_0(round: usize, i: usize) -> usize { + debug_assert!( + round != 0, + "First round S-box inputs are not stored as wires" + ); debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(i < WIDTH); - 2 * WIDTH + 1 + WIDTH * round + i + Self::START_FULL_0 + WIDTH * (round - 1) + i } + const START_PARTIAL: usize = Self::START_FULL_0 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS - 1); + /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. fn wire_partial_sbox(round: usize) -> usize { debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round + Self::START_PARTIAL + round } + const START_FULL_1: usize = Self::START_PARTIAL + poseidon::N_PARTIAL_ROUNDS; + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set /// of full rounds. fn wire_full_sbox_1(round: usize, i: usize) -> usize { debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(i < WIDTH); - 2 * WIDTH - + 1 - + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) - + poseidon::N_PARTIAL_ROUNDS - + i + Self::START_FULL_1 + WIDTH * round + i } /// End of wire indices, exclusive. fn end() -> usize { - 2 * WIDTH + 1 + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + Self::START_FULL_1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS } } @@ -104,31 +118,38 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); - let mut state = Vec::with_capacity(WIDTH); + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..WIDTH { - state.push(vars.local_wires[i]); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::Extension::ZERO; WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } } >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); @@ -183,31 +204,38 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * swap.sub_one()); - let mut state = Vec::with_capacity(WIDTH); + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..WIDTH { - state.push(vars.local_wires[i]); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::ZERO; WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [F; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer(&mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } } >::sbox_layer(&mut state); state = >::mds_layer(&state); @@ -267,38 +295,39 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); - let mut state = Vec::with_capacity(WIDTH); - // We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`. - // We will arithmetize them as - // swap (b - a) + a - // -swap (b - a) + b - // so that `b - a` can be used for both. - let mut state_first_4 = vec![]; - let mut state_next_4 = vec![]; + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - let delta = builder.sub_extension(b, a); - state_first_4.push(builder.mul_add_extension(swap, delta, a)); - state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b)); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let diff = builder.sub_extension(input_rhs, input_lhs); + constraints.push(builder.mul_sub_extension(swap, diff, delta_i)); } - state.extend(state_first_4); - state.extend(state_next_4); + // Compute the possibly-swapped input layer. + let mut state = [builder.zero_extension(); WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + state[i] = builder.add_extension(input_lhs, delta_i); + state[i + 4] = builder.sub_extension(input_rhs, delta_i); + } for i in 8..WIDTH { - state.push(vars.local_wires[i]); + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_recursive(builder, &mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(builder.sub_extension(state[i], sbox_in)); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } } >::sbox_layer_recursive(builder, &mut state); state = >::mds_layer_recursive(builder, &state); @@ -386,7 +415,7 @@ where } fn num_constraints(&self) -> usize { - WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + WIDTH * (poseidon::N_FULL_ROUNDS_TOTAL - 1) + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + 4 } } @@ -422,19 +451,20 @@ where }; let mut state = (0..WIDTH) - .map(|i| { - witness.get_wire(Wire { - gate: self.gate_index, - input: PoseidonGate::::wire_input(i), - }) - }) + .map(|i| witness.get_wire(local_wire(PoseidonGate::::wire_input(i)))) .collect::>(); - let swap_value = witness.get_wire(Wire { - gate: self.gate_index, - input: PoseidonGate::::WIRE_SWAP, - }); + let swap_value = witness.get_wire(local_wire(PoseidonGate::::WIRE_SWAP)); debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + + for i in 0..4 { + let delta_i = swap_value * (state[i + 4] - state[i]); + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_delta(i)), + delta_i, + ); + } + if swap_value == F::ONE { for i in 0..4 { state.swap(i, 4 + i); @@ -446,11 +476,13 @@ where for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { - out_buffer.set_wire( - local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), - state[i], - ); + if r != 0 { + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), + state[i], + ); + } } >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); @@ -522,6 +554,29 @@ mod tests { use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + #[test] + fn wire_indices() { + type F = GoldilocksField; + const WIDTH: usize = 12; + type Gate = PoseidonGate; + + assert_eq!(Gate::wire_input(0), 0); + assert_eq!(Gate::wire_input(11), 11); + assert_eq!(Gate::wire_output(0), 12); + assert_eq!(Gate::wire_output(11), 23); + assert_eq!(Gate::WIRE_SWAP, 24); + assert_eq!(Gate::wire_delta(0), 25); + assert_eq!(Gate::wire_delta(3), 28); + assert_eq!(Gate::wire_full_sbox_0(1, 0), 29); + assert_eq!(Gate::wire_full_sbox_0(3, 0), 53); + assert_eq!(Gate::wire_full_sbox_0(3, 11), 64); + assert_eq!(Gate::wire_partial_sbox(0), 65); + assert_eq!(Gate::wire_partial_sbox(21), 86); + assert_eq!(Gate::wire_full_sbox_1(0, 0), 87); + assert_eq!(Gate::wire_full_sbox_1(3, 0), 123); + assert_eq!(Gate::wire_full_sbox_1(3, 11), 134); + } + #[test] fn generated_output() { type F = GoldilocksField; diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 8c6cb294..ae973d7c 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -87,7 +87,7 @@ pub(crate) fn generate_partial_witness<'a, F: RichField + Extendable, const D assert_eq!( remaining_generators, 0, "{} generators weren't run", - remaining_generators + remaining_generators, ); witness diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index c6c49826..aac9d42e 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -12,9 +12,11 @@ use crate::field::fft::fft_root_table; use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::{FriConfig, FriParams}; -use crate::gadgets::arithmetic_extension::ArithmeticOperation; +use crate::gadgets::arithmetic::BaseArithmeticOperation; +use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic_base::ArithmeticGate; +use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; @@ -74,8 +76,11 @@ pub struct CircuitBuilder, const D: usize> { constants_to_targets: HashMap, targets_to_constants: HashMap, + /// Memoized results of `arithmetic` calls. + pub(crate) base_arithmetic_results: HashMap, Target>, + /// Memoized results of `arithmetic_extension` calls. - pub(crate) arithmetic_results: HashMap, ExtensionTarget>, + pub(crate) arithmetic_results: HashMap, ExtensionTarget>, batched_gates: BatchedGates, } @@ -93,6 +98,7 @@ impl, const D: usize> CircuitBuilder { marked_targets: Vec::new(), generators: Vec::new(), constants_to_targets: HashMap::new(), + base_arithmetic_results: HashMap::new(), arithmetic_results: HashMap::new(), targets_to_constants: HashMap::new(), batched_gates: BatchedGates::new(), @@ -742,11 +748,13 @@ impl, const D: usize> CircuitBuilder { } } -/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a CircuitBuilder track such gates that are currently being "filled up." +/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a +/// CircuitBuilder track such gates that are currently being "filled up." pub struct BatchedGates, const D: usize> { /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using /// these constants with gate index `g` and already using `i` arithmetic operations. pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, + pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, /// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using /// these constants with gate index `g` and already using `i` random accesses. @@ -771,6 +779,7 @@ impl, const D: usize> BatchedGates { pub fn new() -> Self { Self { free_arithmetic: HashMap::new(), + free_base_arithmetic: HashMap::new(), free_random_access: HashMap::new(), current_switch_gates: Vec::new(), current_u32_arithmetic_gate: None, @@ -781,6 +790,37 @@ impl, const D: usize> BatchedGates { } impl, const D: usize> CircuitBuilder { + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_base_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + ArithmeticGate::new_from_config(&self.config), + vec![const_0, const_1], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < ArithmeticGate::num_ops(&self.config) - 1 { + self.batched_gates + .free_base_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.batched_gates + .free_base_arithmetic + .remove(&(const_0, const_1)); + } + + (gate, i) + } + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index /// `g` and the gate's `i`-th operation is available. @@ -941,36 +981,36 @@ impl, const D: usize> CircuitBuilder { (gate, instance) } + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticGate` are run. + fn fill_base_arithmetic_gates(&mut self) { + let zero = self.zero(); + for ((c0, c1), (_gate, i)) in self.batched_gates.free_base_arithmetic.clone() { + for _ in i..ArithmeticGate::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_target(); + self.arithmetic(c0, c1, dummy, dummy, dummy); + self.connect(dummy, zero); + } + } + assert!(self.batched_gates.free_base_arithmetic.is_empty()); + } + /// Fill the remaining unused arithmetic operations with zeros, so that all /// `ArithmeticExtensionGenerator`s are run. fn fill_arithmetic_gates(&mut self) { let zero = self.zero_extension(); - let remaining_arithmetic_gates = self - .batched_gates - .free_arithmetic - .values() - .copied() - .collect::>(); - for (gate, i) in remaining_arithmetic_gates { - for j in i..ArithmeticExtensionGate::::num_ops(&self.config) { - let wires_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_0(j), - ); - let wires_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_1(j), - ); - let wires_addend = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_addend(j), - ); - - self.connect_extension(zero, wires_multiplicand_0); - self.connect_extension(zero, wires_multiplicand_1); - self.connect_extension(zero, wires_addend); + for ((c0, c1), (_gate, i)) in self.batched_gates.free_arithmetic.clone() { + for _ in i..ArithmeticExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, c1, dummy, dummy, dummy); + self.connect_extension(dummy, zero); } } + assert!(self.batched_gates.free_arithmetic.is_empty()); } /// Fill the remaining unused random access operations with zeros, so that all @@ -1064,6 +1104,7 @@ impl, const D: usize> CircuitBuilder { fn fill_batched_gates(&mut self) { self.fill_arithmetic_gates(); + self.fill_base_arithmetic_gates(); self.fill_random_access_gates(); self.fill_switch_gates(); self.fill_u32_arithmetic_gates(); diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 869543af..564d558d 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -26,6 +26,9 @@ pub struct CircuitConfig { pub num_wires: usize, pub num_routed_wires: usize, pub constant_gate_size: usize, + /// Whether to use a dedicated gate for base field arithmetic, rather than using a single gate + /// for both base field and extension field arithmetic. + pub use_base_arithmetic_gate: bool, pub security_bits: usize, pub rate_bits: usize, /// The number of challenge points to generate, for IOPs that have soundness errors of (roughly) @@ -52,9 +55,10 @@ impl CircuitConfig { /// A typical recursion config, without zero-knowledge, targeting ~100 bit security. pub(crate) fn standard_recursion_config() -> Self { Self { - num_wires: 143, + num_wires: 135, num_routed_wires: 25, constant_gate_size: 6, + use_base_arithmetic_gate: true, security_bits: 100, rate_bits: 3, num_challenges: 2, diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 2be91b40..ef322c9f 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -28,7 +28,7 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( alphas: &[F], ) -> Vec { let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let constraint_terms = evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); @@ -37,8 +37,6 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::new(); let l1_x = plonk_common::eval_l_1(common_data.degree(), x); @@ -71,24 +69,15 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - - let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); - let final_deno_product = denominator_values[final_num_prod..] - .iter() - .copied() - .product(); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; - vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, - vanishing_v_shift_terms, constraint_terms, ] .concat(); @@ -121,7 +110,7 @@ pub(crate) fn eval_vanishing_poly_base_batch, const assert_eq!(s_sigmas_batch.len(), n); let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let num_gate_constraints = common_data.num_gate_constraints; @@ -139,8 +128,6 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges); let mut res_batch: Vec> = Vec::with_capacity(n); for k in 0..n { @@ -181,19 +168,11 @@ pub(crate) fn eval_vanishing_poly_base_batch, const &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); - let final_deno_product = denominator_values[final_num_prod..] - .iter() - .copied() - .product(); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; - vanishing_v_shift_terms.push(v_shift_term); - numerator_values.clear(); denominator_values.clear(); } @@ -201,14 +180,12 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let vanishing_terms = vanishing_z_1_terms .iter() .chain(vanishing_partial_products_terms.iter()) - .chain(vanishing_v_shift_terms.iter()) .chain(constraint_terms); let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas); res_batch.push(res); vanishing_z_1_terms.clear(); vanishing_partial_products_terms.clear(); - vanishing_v_shift_terms.clear(); } res_batch } @@ -314,7 +291,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons alphas: &[Target], ) -> Vec> { let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let constraint_terms = with_context!( builder, @@ -331,8 +308,6 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::new(); let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg); @@ -377,23 +352,15 @@ pub(crate) fn eval_vanishing_poly_recursively, cons &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - - let final_nume_product = builder.mul_many_extension(&numerator_values[final_num_prod..]); - let final_deno_product = builder.mul_many_extension(&denominator_values[final_num_prod..]); - let z_gz_denominators = builder.mul_extension(z_gz, final_deno_product); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = - builder.mul_sub_extension(last_partial, final_nume_product, z_gz_denominators); - vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, - vanishing_v_shift_terms, constraint_terms, ] .concat(); diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index c4133b4d..0f3c9bfa 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,9 +1,12 @@ +use std::iter; + use itertools::Itertools; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; pub(crate) fn quotient_chunk_products( quotient_values: &[F], @@ -33,70 +36,74 @@ pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_product /// Returns a tuple `(a,b)`, where `a` is the length of the output of `partial_products()` on a /// vector of length `n`, and `b` is the number of original elements consumed in `partial_products()`. -pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { +pub(crate) fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); let chunk_size = max_degree; - let num_chunks = n / chunk_size; - + // We'll split the product into `ceil_div_usize(n, chunk_size)` chunks, but the last chunk will + // be associated with Z(gx) itself. Thus we subtract one to get the chunks associated with + // partial products. + let num_chunks = ceil_div_usize(n, chunk_size) - 1; (num_chunks, num_chunks * chunk_size) } -/// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing -/// products of size `max_degree` or less. +/// Checks the relationship between each pair of partial product accumulators. In particular, this +/// sequence of accumulators starts with `Z(x)`, then contains each partial product polynomials +/// `p_i(x)`, and finally `Z(g x)`. See the partial products section of the Plonky2 paper. pub(crate) fn check_partial_products( numerators: &[F], denominators: &[F], partials: &[F], z_x: F, + z_gx: F, max_degree: usize, ) -> Vec { debug_assert!(max_degree > 1); - let mut acc = z_x; - let mut partials = partials.iter(); - let mut res = Vec::new(); + let product_accs = iter::once(&z_x) + .chain(partials.iter()) + .chain(iter::once(&z_gx)); let chunk_size = max_degree; - for (nume_chunk, deno_chunk) in numerators - .chunks_exact(chunk_size) - .zip_eq(denominators.chunks_exact(chunk_size)) - { - let num_chunk_product = nume_chunk.iter().copied().product(); - let den_chunk_product = deno_chunk.iter().copied().product(); - let new_acc = *partials.next().unwrap(); - res.push(acc * num_chunk_product - new_acc * den_chunk_product); - acc = new_acc; - } - debug_assert!(partials.next().is_none()); - - res + numerators + .chunks(chunk_size) + .zip_eq(denominators.chunks(chunk_size)) + .zip_eq(product_accs.tuple_windows()) + .map(|((nume_chunk, deno_chunk), (&prev_acc, &next_acc))| { + let num_chunk_product = nume_chunk.iter().copied().product(); + let den_chunk_product = deno_chunk.iter().copied().product(); + // Assert that next_acc * deno_product = prev_acc * nume_product. + prev_acc * num_chunk_product - next_acc * den_chunk_product + }) + .collect() } +/// Checks the relationship between each pair of partial product accumulators. In particular, this +/// sequence of accumulators starts with `Z(x)`, then contains each partial product polynomials +/// `p_i(x)`, and finally `Z(g x)`. See the partial products section of the Plonky2 paper. pub(crate) fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, numerators: &[ExtensionTarget], denominators: &[ExtensionTarget], partials: &[ExtensionTarget], - mut acc: ExtensionTarget, + z_x: ExtensionTarget, + z_gx: ExtensionTarget, max_degree: usize, ) -> Vec> { debug_assert!(max_degree > 1); - let mut partials = partials.iter(); - let mut res = Vec::new(); + let product_accs = iter::once(&z_x) + .chain(partials.iter()) + .chain(iter::once(&z_gx)); let chunk_size = max_degree; - for (nume_chunk, deno_chunk) in numerators - .chunks_exact(chunk_size) - .zip(denominators.chunks_exact(chunk_size)) - { - let nume_product = builder.mul_many_extension(nume_chunk); - let deno_product = builder.mul_many_extension(deno_chunk); - let new_acc = *partials.next().unwrap(); - let new_acc_deno = builder.mul_extension(new_acc, deno_product); - // Assert that new_acc*deno_product = acc * nume_product. - res.push(builder.mul_sub_extension(acc, nume_product, new_acc_deno)); - acc = new_acc; - } - debug_assert!(partials.next().is_none()); - - res + numerators + .chunks(chunk_size) + .zip_eq(denominators.chunks(chunk_size)) + .zip_eq(product_accs.tuple_windows()) + .map(|((nume_chunk, deno_chunk), (&prev_acc, &next_acc))| { + let nume_product = builder.mul_many_extension(nume_chunk); + let deno_product = builder.mul_many_extension(deno_chunk); + let next_acc_deno = builder.mul_extension(next_acc, deno_product); + // Assert that next_acc * deno_product = prev_acc * nume_product. + builder.mul_sub_extension(prev_acc, nume_product, next_acc_deno) + }) + .collect() } #[cfg(test)] @@ -108,36 +115,31 @@ mod tests { fn test_partial_products() { type F = GoldilocksField; let denominators = vec![F::ONE; 6]; + let z_x = F::ONE; let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let z_gx = F::from_canonical_u64(720); let quotient_chunks_prods = quotient_chunk_products(&v, 2); assert_eq!(quotient_chunks_prods, field_vec(&[2, 12, 30])); - let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); - assert_eq!(p, field_vec(&[2, 24, 720])); + let pps_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunks_prods); + let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; + assert_eq!(pps_and_z_gx, field_vec(&[2, 24, 720])); let nums = num_partial_products(v.len(), 2); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, F::ONE, 2) + assert_eq!(pps.len(), nums.0); + assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 2) .iter() .all(|x| x.is_zero())); - assert_eq!( - *p.last().unwrap() * v[nums.1..].iter().copied().product::(), - v.into_iter().product::(), - ); - let v = field_vec(&[1, 2, 3, 4, 5, 6]); let quotient_chunks_prods = quotient_chunk_products(&v, 3); assert_eq!(quotient_chunks_prods, field_vec(&[6, 120])); - let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); - assert_eq!(p, field_vec(&[6, 720])); + let pps_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunks_prods); + let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; + assert_eq!(pps_and_z_gx, field_vec(&[6, 720])); let nums = num_partial_products(v.len(), 3); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, F::ONE, 3) + assert_eq!(pps.len(), nums.0); + assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 3) .iter() .all(|x| x.is_zero())); - assert_eq!( - *p.last().unwrap() * v[nums.1..].iter().copied().product::(), - v.into_iter().product::(), - ); } fn field_vec(xs: &[usize]) -> Vec {