From f0f8320b932723a879ec98e0cab767ddc32dff3f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 9 Aug 2021 10:46:29 +0200 Subject: [PATCH] First pass --- src/gadgets/arithmetic_extension.rs | 56 ++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index d76dab9d..e5702ecf 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -12,6 +12,7 @@ use crate::iop::target::Target; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bits_u64; +use crate::with_context; impl, const D: usize> CircuitBuilder { pub fn double_arithmetic_extension( @@ -152,6 +153,29 @@ impl, const D: usize> CircuitBuilder { None } + pub fn arithmetic_many_extension( + &mut self, + const_0: F, + const_1: F, + operands: Vec<[ExtensionTarget; 3]>, + ) -> Vec> { + let mut res = Vec::new(); + for chunk in operands.chunks_exact(2) { + let [fm0, fm1, fa] = chunk[0]; + let [sm0, sm1, sa] = chunk[1]; + let arithm = + self.double_arithmetic_extension(const_0, const_1, fm0, fm1, fa, sm0, sm1, sa); + res.push(arithm.0); + res.push(arithm.1); + } + if operands.len().is_odd() { + let [m0, m1, a] = operands[operands.len() - 1]; + res.push(self.arithmetic_extension(const_0, const_1, m0, m1, a)); + } + + res + } + pub fn add_extension( &mut self, a: ExtensionTarget, @@ -325,17 +349,31 @@ impl, const D: usize> CircuitBuilder { a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - let mut res = [self.zero_extension(); D]; - let w = self.constant(F::Extension::W); + let zero = self.zero_extension(); + let mut ops = Vec::new(); + let mut opsw = Vec::new(); for i in 0..D { - for j in 0..D { - res[(i + j) % D] = if i + j < D { - self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) - } else { - let ai_bi = self.mul_extension(a.0[i], b.0[j]); - self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) - } + for j in 0..D - i { + ops.push([a.0[i], b.0[j], zero]); } + for j in D - i..D { + opsw.push([a.0[i], b.0[j], zero]); + } + } + let mut muls = self.arithmetic_many_extension(F::ONE, F::ONE, ops); + let mut mulsw = self.arithmetic_many_extension(F::Extension::W, F::ONE, opsw); + let mut toadd = vec![vec![]; D]; + for i in 0..D { + for j in 0..D - i { + toadd[(i + j) % D].push(muls.remove(0)); + } + for j in D - i..D { + toadd[(i + j) % D].push(mulsw.remove(0)); + } + } + let mut res = [zero; D]; + for i in 0..D { + res[i] = self.add_many_extension(&toadd[i]); } ExtensionAlgebraTarget(res) }