diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 69eaea48..9cdeded4 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -7,6 +7,7 @@ use crate::field::field::Field; use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::target::Target; +use crate::util::bits_u64; use crate::wire::Wire; use crate::witness::PartialWitness; @@ -22,6 +23,11 @@ impl, const D: usize> CircuitBuilder { self.mul(x, x) } + /// Computes `x^2`. + pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { + self.mul_extension(x, x) + } + /// Computes `x^3`. pub fn cube(&mut self, x: Target) -> Target { self.mul_many(&[x, x, x]) @@ -161,21 +167,58 @@ impl, const D: usize> CircuitBuilder { } // TODO: Optimize this, maybe with a new gate. + // TODO: Test /// Exponentiate `base` to the power of `exponent`, where `exponent < 2^num_bits`. pub fn exp(&mut self, base: Target, exponent: Target, num_bits: usize) -> Target { let mut current = base; - let one = self.one(); - let mut product = one; + let one_ext = self.one_extension(); + let mut product = self.one(); let exponent_bits = self.split_le(exponent, num_bits); for bit in exponent_bits.into_iter() { - product = self.mul_many(&[bit, current, product]); + let current_ext = self.convert_to_ext(current); + let multiplicand = self.select(bit, current_ext, one_ext); + product = self.mul(product, multiplicand.0[0]); current = self.mul(current, current); } product } + /// Exponentiate `base` to the power of a known `exponent`. + // TODO: Test + pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target { + let mut current = base; + let mut product = self.one(); + + for j in 0..bits_u64(exponent as u64) { + if (exponent >> j & 1) != 0 { + product = self.mul(product, current); + } + current = self.square(current); + } + product + } + + /// Exponentiate `base` to the power of a known `exponent`. + // TODO: Test + pub fn exp_u64_extension( + &mut self, + base: ExtensionTarget, + exponent: u64, + ) -> ExtensionTarget { + let mut current = base; + let mut product = self.one_extension(); + + for j in 0..bits_u64(exponent as u64) { + if (exponent >> j & 1) != 0 { + product = self.mul_extension(product, current); + } + current = self.square_extension(current); + } + product + } + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// some cases, as it allows `0 / 0 = `. pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { diff --git a/src/util/scaling.rs b/src/util/scaling.rs index 1d422e8a..057e8467 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -80,112 +80,252 @@ impl ReducingFactor { } } -// #[derive(Debug, Copy, Clone)] -// pub struct ReducingFactorTarget { -// base: ExtensionTarget, -// count: u64, -// } -// -// impl, const D: usize> ReducingFactorTarget { -// pub fn new(base: ExtensionTarget) -> Self { -// Self { base, count: 0 } -// } -// -// fn mul( -// &mut self, -// x: ExtensionTarget, -// builder: &mut CircuitBuilder, -// ) -> ExtensionTarget { -// self.count += 1; -// builder.mul_extension(self.base, x) -// } -// -// pub fn reduce( -// &mut self, -// iter: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. -// builder: &mut CircuitBuilder, -// ) -> ExtensionTarget { -// let l = iter.len(); -// let padded_iter = if l % 2 == 0 { -// iter.to_vec() -// } else { -// [iter, &[builder.zero_extension()]].concat() -// }; -// let half_length = padded_iter.len() / 2; -// let gates = (0..half_length) -// .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE])) -// .collect::>(); -// -// struct ParallelReductionGenerator<'a, const D: usize> { -// base: ExtensionTarget, -// padded_iter: &'a [ExtensionTarget], -// gates: &'a [usize], -// half_length: usize, -// } -// -// impl<'a, F: Extendable, const D: usize> SimpleGenerator -// for ParallelReductionGenerator<'a, D> -// { -// fn dependencies(&self) -> Vec { -// self.padded_iter -// .iter() -// .flat_map(|ext| ext.to_target_array()) -// .chain(self.base.to_target_array()) -// .collect() -// } -// -// fn run_once(&self, witness: &PartialWitness) -> PartialWitness { -// let mut pw = PartialWitness::new(); -// let base = witness.get_extension_target(self.base); -// let vs = self -// .padded_iter -// .iter() -// .map(|&ext| witness.get_extension_target(ext)) -// .collect::>(); -// let first_half = &vs[..self.half_length]; -// let intermediate_acc = base.reduce(first_half); -// } -// } -// } -// -// pub fn reduce_parallel( -// &mut self, -// iter0: impl DoubleEndedIterator>>, -// iter1: impl DoubleEndedIterator>>, -// builder: &mut CircuitBuilder, -// ) -> (ExtensionTarget, ExtensionTarget) { -// iter.rev().fold(builder.zero_extension(), |acc, x| { -// builder.arithmetic_extension(F::ONE, F::ONE, self.base, acc, x) -// }) -// } -// -// pub fn shift( -// &mut self, -// x: ExtensionTarget, -// builder: &mut CircuitBuilder, -// ) -> ExtensionTarget { -// let tmp = self.base.exp(self.count) * x; -// self.count = 0; -// tmp -// } -// -// pub fn shift_poly( -// &mut self, -// p: &mut PolynomialCoeffs>, -// builder: &mut CircuitBuilder, -// ) { -// *p *= self.base.exp(self.count); -// self.count = 0; -// } -// -// pub fn reset(&mut self) { -// self.count = 0; -// } -// -// pub fn repeated_frobenius(&self, count: usize, builder: &mut CircuitBuilder) -> Self { -// Self { -// base: self.base.repeated_frobenius(count), -// count: self.count, -// } -// } -// } +#[derive(Debug, Copy, Clone)] +pub struct ReducingFactorTarget { + base: ExtensionTarget, + count: u64, +} + +impl ReducingFactorTarget { + pub fn new(base: ExtensionTarget) -> Self { + Self { base, count: 0 } + } + + pub fn reduce( + &mut self, + iter: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: Extendable, + { + let l = iter.len(); + self.count += l as u64; + let padded_iter = if l % 2 == 0 { + iter.to_vec() + } else { + [iter, &[builder.zero_extension()]].concat() + }; + let half_length = padded_iter.len() / 2; + let gates = (0..half_length) + .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE])) + .collect::>(); + + builder.add_generator(ParallelReductionGenerator { + base: self.base, + padded_iter: padded_iter.clone(), + gates: gates.clone(), + half_length, + }); + + for i in 0..half_length { + builder.route_extension( + ExtensionTarget::from_range( + gates[i], + ArithmeticExtensionGate::::wires_addend_0(), + ), + padded_iter[2 * half_length - i - 1], + ); + } + for i in 0..half_length { + builder.route_extension( + ExtensionTarget::from_range( + gates[i], + ArithmeticExtensionGate::::wires_addend_1(), + ), + padded_iter[half_length - i - 1], + ); + } + for gate_pair in gates[..half_length].windows(2) { + builder.assert_equal_extension( + ExtensionTarget::from_range( + gate_pair[0], + ArithmeticExtensionGate::::wires_output_0(), + ), + ExtensionTarget::from_range( + gate_pair[1], + ArithmeticExtensionGate::::wires_multiplicand_0(), + ), + ); + } + for gate_pair in gates[half_length..].windows(2) { + builder.assert_equal_extension( + ExtensionTarget::from_range( + gate_pair[0], + ArithmeticExtensionGate::::wires_output_1(), + ), + ExtensionTarget::from_range( + gate_pair[1], + ArithmeticExtensionGate::::wires_multiplicand_1(), + ), + ); + } + builder.assert_equal_extension( + ExtensionTarget::from_range( + gates[half_length - 1], + ArithmeticExtensionGate::::wires_output_0(), + ), + ExtensionTarget::from_range( + gates[0], + ArithmeticExtensionGate::::wires_multiplicand_1(), + ), + ); + + ExtensionTarget::from_range( + gates[half_length - 1], + ArithmeticExtensionGate::::wires_output_1(), + ) + } + + pub fn shift( + &mut self, + x: ExtensionTarget, + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: Extendable, + { + let exp = builder.exp_u64_extension(self.base, self.count); + let tmp = builder.mul_extension(exp, x); + self.count = 0; + tmp + } + + pub fn reset(&mut self) { + self.count = 0; + } + + pub fn repeated_frobenius(&self, count: usize, builder: &mut CircuitBuilder) -> Self + where + F: Extendable, + { + Self { + base: self.base.repeated_frobenius(count, builder), + count: self.count, + } + } +} + +struct ParallelReductionGenerator { + base: ExtensionTarget, + padded_iter: Vec>, + gates: Vec, + half_length: usize, +} + +impl, const D: usize> SimpleGenerator for ParallelReductionGenerator { + fn dependencies(&self) -> Vec { + self.padded_iter + .iter() + .flat_map(|ext| ext.to_target_array()) + .chain(self.base.to_target_array()) + .collect() + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let mut pw = PartialWitness::new(); + let base = witness.get_extension_target(self.base); + let vs = self + .padded_iter + .iter() + .map(|&ext| witness.get_extension_target(ext)) + .collect::>(); + let intermediate_accs = vs + .iter() + .rev() + .scan(F::Extension::ZERO, |acc, &x| { + let tmp = *acc; + *acc = *acc * base + x; + Some(tmp) + }) + .collect::>(); + for i in 0..self.half_length { + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ), + base, + ); + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_multiplicand_0(), + ), + intermediate_accs[i], + ); + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_addend_0(), + ), + vs[2 * self.half_length - i - 1], + ); + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_multiplicand_1(), + ), + intermediate_accs[self.half_length + i], + ); + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_addend_1(), + ), + vs[self.half_length - i - 1], + ); + } + + pw + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + + fn test_reduce_gadget(n: usize) { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + + let alpha = FF::rand(); + let alpha = FF::ONE; + let vs = (0..n) + .map(|i| FF::from_canonical_usize(i)) + .collect::>(); + + let manual_reduce = ReducingFactor::new(alpha).reduce(vs.iter()); + let manual_reduce = builder.constant_extension(manual_reduce); + + let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); + let vs_t = vs + .iter() + .map(|&v| builder.constant_extension(v)) + .collect::>(); + let circuit_reduce = alpha_t.reduce(&vs_t, &mut builder); + + builder.assert_equal_extension(manual_reduce, circuit_reduce); + + let data = builder.build(); + let proof = data.prove(PartialWitness::new()); + } + + #[test] + fn test_reduce_gadget_even() { + test_reduce_gadget(10); + } + + #[test] + fn test_reduce_gadget_odd() { + test_reduce_gadget(11); + } +}