diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index af97a866..7dcd88cd 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -97,7 +97,10 @@ impl, const D: usize> CircuitBuilder { a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { + // We run two additions in parallel. So `[a0,a1,a2,a3] + [b0,b1,b2,b3]` is computed with two + // `add_two_extension`, first `[a0,a1]+[b0,b1]` then `[a2,a3]+[b2,b3]`. let mut res = Vec::with_capacity(D); + // We need some extra logic if D is odd. let d_even = D & (D ^ 1); // = 2 * (D/2) for mut chunk in &(0..d_even).chunks(2) { let i = chunk.next().unwrap(); @@ -117,11 +120,13 @@ impl, const D: usize> CircuitBuilder { if terms.len() % 2 == 1 { terms.push(zero); } + // We maintain two accumulators, one for the sum of even elements, and one for odd elements. let mut acc0 = zero; let mut acc1 = zero; for chunk in terms.chunks_exact(2) { (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); } + // We sum both accumulators to get the final result. self.add_extension(acc0, acc1) } @@ -150,6 +155,7 @@ impl, const D: usize> CircuitBuilder { a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { + // See `add_ext_algebra`. let mut res = Vec::with_capacity(D); let d_even = D & (D ^ 1); // = 2 * (D/2) for mut chunk in &(0..d_even).chunks(2) { @@ -319,6 +325,7 @@ impl, const D: usize> CircuitBuilder { denominator: y, quotient: multiplicand_0, }); + // We need to zero out the other wires for the `ArithmeticExtensionGenerator` to hit. self.add_generator(ZeroOutGenerator { gate_index: gate, ranges: vec![ diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 6751889e..31ae5caa 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -106,7 +106,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate gate_index: usize, local_constants: &[F], ) -> Vec>> { - let gen = MulExtensionGenerator { + let gen = ArithmeticExtensionGenerator { gate_index, const_0: local_constants[0], const_1: local_constants[1], @@ -131,13 +131,13 @@ impl, const D: usize> Gate for ArithmeticExtensionGate } } -struct MulExtensionGenerator, const D: usize> { +struct ArithmeticExtensionGenerator, const D: usize> { gate_index: usize, const_0: F, const_1: F, } -impl, const D: usize> SimpleGenerator for MulExtensionGenerator { +impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator { fn dependencies(&self) -> Vec { ArithmeticExtensionGate::::wires_fixed_multiplicand() .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) diff --git a/src/util/scaling.rs b/src/util/scaling.rs index 0bc61840..87158649 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -91,6 +91,11 @@ impl ReducingFactorTarget { Self { base, count: 0 } } + /// Reduces a length `n` vector of `ExtensionTarget`s using `n/2` `ArithmeticExtensionGate`s. + /// It does this by running two accumulators in parallel. Here's an example with `n=4, alpha=2, D=1`: + /// 1st gate: 2 0 4 11 2 4 24 <- 2*0+4= 4, 2*11+2=24 + /// 2nd gate: 2 4 3 24 1 11 49 <- 2*4+3=11, 2*24+1=49 + /// which verifies that `2.reduce([1,2,3,4]) = 49`. pub fn reduce( &mut self, iter: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. @@ -99,18 +104,22 @@ impl ReducingFactorTarget { where F: Extendable, { + let zero = builder.zero_extension(); let l = iter.len(); self.count += l as u64; + // If needed we pad the original vector so that it has even length. let padded_iter = if l % 2 == 0 { iter.to_vec() } else { - [iter, &[builder.zero_extension()]].concat() + [iter, &[zero]].concat() }; let half_length = padded_iter.len() / 2; + // Add `n/2` `ArithmeticExtensionGate`s that will perform the accumulation. let gates = (0..half_length) .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE])) .collect::>(); + // Add a generator that will fill the accumulation wires. builder.add_generator(ParallelReductionGenerator { base: self.base, padded_iter: padded_iter.clone(), @@ -119,24 +128,33 @@ impl ReducingFactorTarget { }); for i in 0..half_length { + // The fixed multiplicand is always `base`. builder.route_extension( + self.base, + ExtensionTarget::from_range( + gates[i], + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ), + ); + // Set the addends for the first half of the accumulation. + builder.route_extension( + padded_iter[2 * half_length - i - 1], ExtensionTarget::from_range( gates[i], ArithmeticExtensionGate::::wires_addend_0(), ), - padded_iter[2 * half_length - i - 1], ); - } - for i in 0..half_length { + // Set the addends for the second half of the accumulation. builder.route_extension( + padded_iter[half_length - i - 1], ExtensionTarget::from_range( gates[i], ArithmeticExtensionGate::::wires_addend_1(), ), - padded_iter[half_length - i - 1], ); } for gate_pair in gates[..half_length].windows(2) { + // Verifies that the accumulator is passed between gates for the first half of the accumulation. builder.assert_equal_extension( ExtensionTarget::from_range( gate_pair[0], @@ -149,6 +167,7 @@ impl ReducingFactorTarget { ); } for gate_pair in gates[half_length..].windows(2) { + // Verifies that the accumulator is passed between gates for the second half of the accumulation. builder.assert_equal_extension( ExtensionTarget::from_range( gate_pair[0], @@ -160,6 +179,16 @@ impl ReducingFactorTarget { ), ); } + // Verifies that the starting accumulator for the first half is zero. + builder.assert_equal_extension( + ExtensionTarget::from_range( + gates[0], + ArithmeticExtensionGate::::wires_multiplicand_0(), + ), + zero, + ); + // Verifies that the final accumulator for the first half is passed as a starting + // accumulator for the second half. builder.assert_equal_extension( ExtensionTarget::from_range( gates[half_length - 1], @@ -171,6 +200,7 @@ impl ReducingFactorTarget { ), ); + // Return the final accumulator for the second half. ExtensionTarget::from_range( gates[half_length - 1], ArithmeticExtensionGate::::wires_output_1(), @@ -206,6 +236,7 @@ impl ReducingFactorTarget { } } +/// Fills the intermediate accumulator in `ReducingFactorTarget::reduce`. struct ParallelReductionGenerator { base: ExtensionTarget, padded_iter: Vec>, @@ -215,6 +246,7 @@ struct ParallelReductionGenerator { impl, const D: usize> SimpleGenerator for ParallelReductionGenerator { fn dependencies(&self) -> Vec { + // Need only the values and the base. self.padded_iter .iter() .flat_map(|ext| ext.to_target_array()) @@ -230,6 +262,7 @@ impl, const D: usize> SimpleGenerator for ParallelReductionG .iter() .map(|&ext| witness.get_extension_target(ext)) .collect::>(); + // Computed the intermediate accumulators. let intermediate_accs = vs .iter() .rev() @@ -240,13 +273,7 @@ impl, const D: usize> SimpleGenerator for ParallelReductionG }) .collect::>(); for i in 0..self.half_length { - pw.set_extension_target( - ExtensionTarget::from_range( - self.gates[i], - ArithmeticExtensionGate::::wires_fixed_multiplicand(), - ), - base, - ); + // Fill the accumulators for the first half. pw.set_extension_target( ExtensionTarget::from_range( self.gates[i], @@ -254,13 +281,7 @@ impl, const D: usize> SimpleGenerator for ParallelReductionG ), intermediate_accs[i], ); - pw.set_extension_target( - ExtensionTarget::from_range( - self.gates[i], - ArithmeticExtensionGate::::wires_addend_0(), - ), - vs[2 * self.half_length - i - 1], - ); + // Fill the accumulators for the second half. pw.set_extension_target( ExtensionTarget::from_range( self.gates[i], @@ -268,13 +289,6 @@ impl, const D: usize> SimpleGenerator for ParallelReductionG ), intermediate_accs[self.half_length + i], ); - pw.set_extension_target( - ExtensionTarget::from_range( - self.gates[i], - ArithmeticExtensionGate::::wires_addend_1(), - ), - vs[self.half_length - i - 1], - ); } pw