From 8642a10fde28af930c80f174d724499034c6cc94 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 21 Jul 2021 15:58:15 +0200 Subject: [PATCH 1/5] Start of optimization --- src/gadgets/arithmetic_extension.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 10b60dcd..0cf57ed6 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -80,6 +80,7 @@ impl, const D: usize> CircuitBuilder { self.arithmetic_extension(F::ONE, F::ONE, one, a, b) } + /// Returns `(a0+b0, a1+b1)`. pub fn add_two_extension( &mut self, a0: ExtensionTarget, @@ -196,6 +197,17 @@ impl, const D: usize> CircuitBuilder { self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) } + /// Returns `(a0*b0, a1*b1)`. + pub fn mul_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + todo!() + } + /// Computes `x^2`. pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { self.mul_extension(x, x) @@ -222,6 +234,19 @@ impl, const D: usize> CircuitBuilder { } pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let one = self.one_extension(); + let mut terms = terms.to_vec(); + if terms.len().is_odd() { + terms.push(one); + } + // We maintain two accumulators, one for the sum of even elements, and one for odd elements. + let mut acc0 = one; + let mut acc1 = one; + for chunk in terms.chunks_exact(2) { + (acc0, acc1) = self.mul_two_extension(acc0, chunk[0], acc1, chunk[1]); + } + // We sum both accumulators to get the final result. + self.add_extension(acc0, acc1) let mut product = self.one_extension(); for term in terms { product = self.mul_extension(product, *term); From b59d4979641d078c68e3e17998abbff8410761d7 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 21 Jul 2021 17:20:08 +0200 Subject: [PATCH 2/5] Modify ArithmeticExtensionGate to support 32 wires --- src/gadgets/arithmetic_extension.rs | 115 +++++++++++++++-------- src/gates/arithmetic.rs | 136 +++++++++++++++------------- src/util/scaling.rs | 7 +- 3 files changed, 158 insertions(+), 100 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 0cf57ed6..33165ce2 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -17,37 +17,47 @@ impl, const D: usize> CircuitBuilder { &mut self, const_0: F, const_1: F, - fixed_multiplicand: ExtensionTarget, - multiplicand_0: ExtensionTarget, - addend_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - addend_1: ExtensionTarget, + first_multiplicand_0: ExtensionTarget, + first_multiplicand_1: ExtensionTarget, + first_addend: ExtensionTarget, + second_multiplicand_0: ExtensionTarget, + second_multiplicand_1: ExtensionTarget, + second_addend: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); - let wire_fixed_multiplicand = ExtensionTarget::from_range( + let wire_first_multiplicand_0 = ExtensionTarget::from_range( gate, - ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ArithmeticExtensionGate::::wires_first_multiplicand_0(), ); - let wire_multiplicand_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); - let wire_addend_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); - let wire_multiplicand_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); - let wire_addend_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); - let wire_output_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); - let wire_output_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); + let wire_first_multiplicand_1 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_first_multiplicand_1(), + ); + let wire_first_addend = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_addend()); + let wire_second_multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_second_multiplicand_0(), + ); + let wire_second_multiplicand_1 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_second_multiplicand_1(), + ); + let wire_second_addend = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_second_addend()); + let wire_first_output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + let wire_second_output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_second_output()); - self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); - self.route_extension(multiplicand_0, wire_multiplicand_0); - self.route_extension(addend_0, wire_addend_0); - self.route_extension(multiplicand_1, wire_multiplicand_1); - self.route_extension(addend_1, wire_addend_1); - (wire_output_0, wire_output_1) + self.route_extension(first_multiplicand_0, wire_first_multiplicand_0); + self.route_extension(first_multiplicand_1, wire_first_multiplicand_1); + self.route_extension(first_addend, wire_first_addend); + self.route_extension(second_multiplicand_0, wire_second_multiplicand_0); + self.route_extension(second_multiplicand_1, wire_second_multiplicand_1); + self.route_extension(second_addend, wire_second_addend); + (wire_first_output, wire_second_output) } pub fn arithmetic_extension( @@ -67,6 +77,7 @@ impl, const D: usize> CircuitBuilder { addend, zero, zero, + zero, ) .0 } @@ -89,7 +100,7 @@ impl, const D: usize> CircuitBuilder { b1: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) + self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, one, a1, b1) } pub fn add_ext_algebra( @@ -147,7 +158,7 @@ impl, const D: usize> CircuitBuilder { b1: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, one, a1, b1) } pub fn sub_ext_algebra( @@ -185,6 +196,7 @@ impl, const D: usize> CircuitBuilder { zero, zero, zero, + zero, ) .0 } @@ -205,7 +217,8 @@ impl, const D: usize> CircuitBuilder { a1: ExtensionTarget, b1: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { - todo!() + let zero = self.zero_extension(); + self.double_arithmetic_extension(F::ONE, F::ZERO, a0, b0, zero, a1, b1, zero) } /// Computes `x^2`. @@ -239,19 +252,14 @@ impl, const D: usize> CircuitBuilder { if terms.len().is_odd() { terms.push(one); } - // We maintain two accumulators, one for the sum of even elements, and one for odd elements. + // We maintain two accumulators, one for the product of even elements, and one for odd elements. let mut acc0 = one; let mut acc1 = one; for chunk in terms.chunks_exact(2) { (acc0, acc1) = self.mul_two_extension(acc0, chunk[0], acc1, chunk[1]); } - // We sum both accumulators to get the final result. - self.add_extension(acc0, acc1) - let mut product = self.one_extension(); - for term in terms { - product = self.mul_extension(product, *term); - } - product + // We multiply both accumulators to get the final result. + self.mul_extension(acc0, acc1) } /// Like `mul_add`, but for `ExtensionTarget`s. @@ -468,6 +476,41 @@ mod tests { use crate::verifier::verify; use crate::witness::PartialWitness; + #[test] + fn test_mul_many() -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::new(); + + let vs = FF::rand_vec(20); + let ts = builder.add_virtual_extension_targets(20); + for (&v, &t) in vs.iter().zip(&ts) { + pw.set_extension_target(t, v); + } + let mul0 = builder.mul_many_extension(&ts); + let mul1 = { + let mut acc = builder.one_extension(); + for &t in &ts { + acc = builder.mul_extension(acc, t); + } + acc + }; + let mul2 = builder.constant_extension(vs.into_iter().product()); + + builder.assert_equal_extension(mul0, mul1); + builder.assert_equal_extension(mul1, mul2); + + let data = builder.build(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } + #[test] fn test_div_extension() -> Result<()> { type F = CrandallField; diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 39baa226..2548c498 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -18,27 +18,30 @@ impl ArithmeticExtensionGate { GateRef::new(ArithmeticExtensionGate) } - pub fn wires_fixed_multiplicand() -> Range { + pub fn wires_first_multiplicand_0() -> Range { 0..D } - pub fn wires_multiplicand_0() -> Range { + pub fn wires_first_multiplicand_1() -> Range { D..2 * D } - pub fn wires_addend_0() -> Range { + pub fn wires_first_addend() -> Range { 2 * D..3 * D } - pub fn wires_multiplicand_1() -> Range { + pub fn wires_second_multiplicand_0() -> Range { 3 * D..4 * D } - pub fn wires_addend_1() -> Range { + pub fn wires_second_multiplicand_1() -> Range { 4 * D..5 * D } - pub fn wires_output_0() -> Range { + pub fn wires_second_addend() -> Range { 5 * D..6 * D } - pub fn wires_output_1() -> Range { + pub fn wires_first_output() -> Range { 6 * D..7 * D } + pub fn wires_second_output() -> Range { + 7 * D..8 * D + } } impl, const D: usize> Gate for ArithmeticExtensionGate { @@ -50,21 +53,24 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); - let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); - let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0()); + let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1()); + let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend()); + let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0()); + let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1()); + let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend()); + let first_output = vars.get_local_ext_algebra(Self::wires_first_output()); + let second_output = vars.get_local_ext_algebra(Self::wires_second_output()); - let computed_output_0 = - fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); - let computed_output_1 = - fixed_multiplicand * multiplicand_1 * const_0.into() + addend_1 * const_1.into(); + let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into() + + first_addend * const_1.into(); + let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into() + + second_addend * const_1.into(); - let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); - constraints.extend((output_1 - computed_output_1).to_basefield_array()); + let mut constraints = (first_output - first_computed_output) + .to_basefield_array() + .to_vec(); + constraints.extend((second_output - second_computed_output).to_basefield_array()); constraints } @@ -76,26 +82,32 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); - let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); - let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0()); + let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1()); + let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend()); + let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0()); + let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1()); + let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend()); + let first_output = vars.get_local_ext_algebra(Self::wires_first_output()); + let second_output = vars.get_local_ext_algebra(Self::wires_second_output()); - let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); - let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); - let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); - let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); + let first_computed_output = + builder.mul_ext_algebra(first_multiplicand_0, first_multiplicand_1); + let first_computed_output = builder.scalar_mul_ext_algebra(const_0, first_computed_output); + let first_scaled_addend = builder.scalar_mul_ext_algebra(const_1, first_addend); + let first_computed_output = + builder.add_ext_algebra(first_computed_output, first_scaled_addend); - let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); - let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); - let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); - let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); + let second_computed_output = + builder.mul_ext_algebra(second_multiplicand_0, second_multiplicand_1); + let second_computed_output = + builder.scalar_mul_ext_algebra(const_0, second_computed_output); + let second_scaled_addend = builder.scalar_mul_ext_algebra(const_1, second_addend); + let second_computed_output = + builder.add_ext_algebra(second_computed_output, second_scaled_addend); - let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); - let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); + let diff_0 = builder.sub_ext_algebra(first_output, first_computed_output); + let diff_1 = builder.sub_ext_algebra(second_output, second_computed_output); let mut constraints = diff_0.to_ext_target_array().to_vec(); constraints.extend(diff_1.to_ext_target_array()); constraints @@ -120,7 +132,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate } fn num_wires(&self) -> usize { - 7 * D + 8 * D } fn num_constants(&self) -> usize { @@ -150,9 +162,9 @@ struct ArithmeticExtensionGenerator1, const D: usize> { impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator0 { fn dependencies(&self) -> Vec { - ArithmeticExtensionGate::::wires_fixed_multiplicand() - .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) - .chain(ArithmeticExtensionGate::::wires_addend_0()) + ArithmeticExtensionGate::::wires_first_multiplicand_0() + .chain(ArithmeticExtensionGate::::wires_first_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_first_addend()) .map(|i| Target::wire(self.gate_index, i)) .collect() } @@ -163,29 +175,29 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio witness.get_extension_target(t) }; - let fixed_multiplicand = - extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); let multiplicand_0 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); - let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); + extract_extension(ArithmeticExtensionGate::::wires_first_multiplicand_0()); + let multiplicand_1 = + extract_extension(ArithmeticExtensionGate::::wires_first_multiplicand_1()); + let addend = extract_extension(ArithmeticExtensionGate::::wires_first_addend()); - let output_target_0 = ExtensionTarget::from_range( + let output_target = ExtensionTarget::from_range( self.gate_index, - ArithmeticExtensionGate::::wires_output_0(), + ArithmeticExtensionGate::::wires_first_output(), ); - let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() - + addend_0 * self.const_1.into(); + let computed_output = + multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); - PartialWitness::singleton_extension_target(output_target_0, computed_output_0) + PartialWitness::singleton_extension_target(output_target, computed_output) } } impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator1 { fn dependencies(&self) -> Vec { - ArithmeticExtensionGate::::wires_fixed_multiplicand() - .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) - .chain(ArithmeticExtensionGate::::wires_addend_1()) + ArithmeticExtensionGate::::wires_second_multiplicand_0() + .chain(ArithmeticExtensionGate::::wires_second_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_second_addend()) .map(|i| Target::wire(self.gate_index, i)) .collect() } @@ -196,21 +208,21 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio witness.get_extension_target(t) }; - let fixed_multiplicand = - extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_0 = + extract_extension(ArithmeticExtensionGate::::wires_second_multiplicand_0()); let multiplicand_1 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); - let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); + extract_extension(ArithmeticExtensionGate::::wires_second_multiplicand_1()); + let addend = extract_extension(ArithmeticExtensionGate::::wires_second_addend()); - let output_target_1 = ExtensionTarget::from_range( + let output_target = ExtensionTarget::from_range( self.gate_index, - ArithmeticExtensionGate::::wires_output_1(), + ArithmeticExtensionGate::::wires_second_output(), ); - let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() - + addend_1 * self.const_1.into(); + let computed_output = + multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); - PartialWitness::singleton_extension_target(output_target_1, computed_output_1) + PartialWitness::singleton_extension_target(output_target, computed_output) } } diff --git a/src/util/scaling.rs b/src/util/scaling.rs index be339784..af32201d 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -122,8 +122,10 @@ impl ReducingFactorTarget { // out_0 = alpha acc + pair[0] // acc' = out_1 = alpha out_0 + pair[1] let gate = builder.num_gates(); - let out_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + let out_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_first_output(), + ); acc = builder .double_arithmetic_extension( F::ONE, @@ -131,6 +133,7 @@ impl ReducingFactorTarget { self.base, acc, pair[0], + self.base, out_0, pair[1], ) From d870a36deeb2f3bb56b7e90b3d9a140299a86e7a Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 21 Jul 2021 17:29:05 +0200 Subject: [PATCH 3/5] {add|mul}_three_extension --- src/gadgets/arithmetic_extension.rs | 40 ++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 33165ce2..fa6df6a9 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -125,6 +125,22 @@ impl, const D: usize> CircuitBuilder { ExtensionAlgebraTarget(res.try_into().unwrap()) } + /// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. + pub fn add_three_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + let gate = self.num_gates(); + let first_out = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out) + .1 + } + + /// Add `n` `ExtensionTarget`s with `n/2 + 1` `ArithmeticExtensionGate`s. pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let zero = self.zero_extension(); let mut terms = terms.to_vec(); @@ -246,6 +262,22 @@ impl, const D: usize> CircuitBuilder { ExtensionAlgebraTarget(res) } + /// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. + pub fn mul_three_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + let gate = self.num_gates(); + let first_out = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + self.double_arithmetic_extension(F::ONE, F::ZERO, a, b, zero, c, first_out, zero) + .1 + } + + /// Multiply `n` `ExtensionTarget`s with `n/2 + 1` `ArithmeticExtensionGate`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let one = self.one_extension(); let mut terms = terms.to_vec(); @@ -487,8 +519,8 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::new(); - let vs = FF::rand_vec(20); - let ts = builder.add_virtual_extension_targets(20); + let vs = FF::rand_vec(3); + let ts = builder.add_virtual_extension_targets(3); for (&v, &t) in vs.iter().zip(&ts) { pw.set_extension_target(t, v); } @@ -500,10 +532,12 @@ mod tests { } acc }; - let mul2 = builder.constant_extension(vs.into_iter().product()); + let mul2 = builder.mul_three_extension(ts[0], ts[1], ts[2]); + let mul3 = builder.constant_extension(vs.into_iter().product()); builder.assert_equal_extension(mul0, mul1); builder.assert_equal_extension(mul1, mul2); + builder.assert_equal_extension(mul2, mul3); let data = builder.build(); let proof = data.prove(pw)?; From 6e305f0a3ef45d091e2518ba3c83967376ca43c4 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 21 Jul 2021 17:41:22 +0200 Subject: [PATCH 4/5] Change `{add|mul}_many` and `cube` --- src/gadgets/arithmetic.rs | 25 ++++++++++++++----------- src/gadgets/arithmetic_extension.rs | 4 ++-- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 6f85cdcf..f20b2f01 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -16,7 +16,8 @@ impl, const D: usize> CircuitBuilder { /// Computes `x^3`. pub fn cube(&mut self, x: Target) -> Target { - self.mul_many(&[x, x, x]) + let xe = self.convert_to_ext(x); + self.mul_three_extension(xe, xe, xe).to_target_array()[0] } /// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`. @@ -123,13 +124,14 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, one, F::ONE, y) } + /// Add `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. // TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`. pub fn add_many(&mut self, terms: &[Target]) -> Target { - let mut sum = self.zero(); - for term in terms { - sum = self.add(sum, *term); - } - sum + let terms_ext = terms + .iter() + .map(|&t| self.convert_to_ext(t)) + .collect::>(); + self.add_many_extension(&terms_ext).to_target_array()[0] } /// Computes `x - y`. @@ -145,12 +147,13 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, y, F::ZERO, x) } + /// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. pub fn mul_many(&mut self, terms: &[Target]) -> Target { - let mut product = self.one(); - for term in terms { - product = self.mul(product, *term); - } - product + let terms_ext = terms + .iter() + .map(|&t| self.convert_to_ext(t)) + .collect::>(); + self.mul_many_extension(&terms_ext).to_target_array()[0] } // TODO: Optimize this, maybe with a new gate. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index fa6df6a9..22f70884 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -140,7 +140,7 @@ impl, const D: usize> CircuitBuilder { .1 } - /// Add `n` `ExtensionTarget`s with `n/2 + 1` `ArithmeticExtensionGate`s. + /// Add `n` `ExtensionTarget`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let zero = self.zero_extension(); let mut terms = terms.to_vec(); @@ -277,7 +277,7 @@ impl, const D: usize> CircuitBuilder { .1 } - /// Multiply `n` `ExtensionTarget`s with `n/2 + 1` `ArithmeticExtensionGate`s. + /// Multiply `n` `ExtensionTarget`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let one = self.one_extension(); let mut terms = terms.to_vec(); From 6cc871d30c2c5688462840693ee1129d9687f15b Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 21 Jul 2021 19:23:26 +0200 Subject: [PATCH 5/5] PR feedback --- src/gadgets/arithmetic_extension.rs | 37 ++++++++++++++++------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 9aedc2fe..e6efb451 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -140,21 +140,24 @@ impl, const D: usize> CircuitBuilder { .1 } - /// Add `n` `ExtensionTarget`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. + /// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let zero = self.zero_extension(); let mut terms = terms.to_vec(); - if terms.len().is_odd() { + if terms.is_empty() { + return zero; + } else if terms.len() < 3 { + terms.resize(3, zero); + } else if terms.len().is_even() { terms.push(zero); } - // We maintain two accumulators, one for the sum of even elements, and one for odd elements. - let mut acc0 = zero; - let mut acc1 = zero; + + let mut acc = self.add_three_extension(terms[0], terms[1], terms[2]); + terms.drain(0..3); for chunk in terms.chunks_exact(2) { - (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); + acc = self.add_three_extension(acc, chunk[0], chunk[1]); } - // We sum both accumulators to get the final result. - self.add_extension(acc0, acc1) + acc } pub fn sub_extension( @@ -277,21 +280,23 @@ impl, const D: usize> CircuitBuilder { .1 } - /// Multiply `n` `ExtensionTarget`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. + /// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let one = self.one_extension(); let mut terms = terms.to_vec(); - if terms.len().is_odd() { + if terms.is_empty() { + return one; + } else if terms.len() < 3 { + terms.resize(3, one); + } else if terms.len().is_even() { terms.push(one); } - // We maintain two accumulators, one for the product of even elements, and one for odd elements. - let mut acc0 = one; - let mut acc1 = one; + let mut acc = self.mul_three_extension(terms[0], terms[1], terms[2]); + terms.drain(0..3); for chunk in terms.chunks_exact(2) { - (acc0, acc1) = self.mul_two_extension(acc0, chunk[0], acc1, chunk[1]); + acc = self.mul_three_extension(acc, chunk[0], chunk[1]); } - // We multiply both accumulators to get the final result. - self.mul_extension(acc0, acc1) + acc } /// Like `mul_add`, but for `ExtensionTarget`s.