From 8796c73362a32283157fd43ca36793f6ba19122f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 23 Jun 2021 18:04:43 +0200 Subject: [PATCH 01/22] Change `MulExtensionGate` to `ArithmeticExtensionGate` and change gadgets to use the new wires in this gate. --- src/circuit_builder.rs | 5 + src/circuit_data.rs | 2 +- src/field/extension_field/target.rs | 200 +++++++++++++++++++++------- src/gadgets/arithmetic.rs | 16 +-- src/gadgets/interpolation.rs | 7 +- src/gates/mul_extension.rs | 139 +++++++++++++------ src/lib.rs | 2 + src/recursive_verifier.rs | 2 +- 8 files changed, 266 insertions(+), 107 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 3a26497e..ec6f60be 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -95,6 +95,11 @@ impl, const D: usize> CircuitBuilder { /// Adds a gate to the circuit, and returns its index. pub fn add_gate(&mut self, gate_type: GateRef, constants: Vec) -> usize { + assert_eq!( + gate_type.0.num_constants(), + constants.len(), + "Number of constants doesn't match." + ); // If we haven't seen a gate of this type before, check that it's compatible with our // circuit configuration, then register it. if !self.gates.contains(&gate_type) { diff --git a/src/circuit_data.rs b/src/circuit_data.rs index c8477464..c6c27d35 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -51,7 +51,7 @@ impl CircuitConfig { pub(crate) fn large_config() -> Self { Self { num_wires: 134, - num_routed_wires: 12, + num_routed_wires: 28, security_bits: 128, rate_bits: 3, num_challenges: 3, diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 0316b04f..f6cec119 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -1,11 +1,13 @@ use std::convert::{TryFrom, TryInto}; use std::ops::Range; +use itertools::Itertools; + use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; -use crate::gates::mul_extension::MulExtensionGate; +use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::target::Target; /// `Target`s representing an element of an extension field. @@ -110,57 +112,155 @@ impl, const D: usize> CircuitBuilder { self.constant_ext_algebra(ExtensionAlgebra::ZERO) } + pub fn double_arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + fixed_multiplicand: ExtensionTarget, + multiplicand_0: ExtensionTarget, + addend_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend_1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); + + let wire_fixed_multiplicand = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); + let wire_multiplicand_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let wire_addend_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); + let wire_multiplicand_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); + let wire_addend_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); + let wire_output_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + let wire_output_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); + + self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); + self.route_extension(multiplicand_0, wire_multiplicand_0); + self.route_extension(addend_0, wire_addend_0); + self.route_extension(multiplicand_1, wire_multiplicand_1); + self.route_extension(addend_1, wire_addend_1); + self.route_extension(multiplicand_1, wire_multiplicand_1); + (wire_output_0, wire_output_1) + } + + pub fn arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + const_1, + multiplicand_0, + multiplicand_1, + addend, + zero, + zero, + ) + .0 + } + pub fn add_extension( &mut self, - mut a: ExtensionTarget, + a: ExtensionTarget, b: ExtensionTarget, ) -> ExtensionTarget { - for i in 0..D { - a.0[i] = self.add(a.0[i], b.0[i]); - } - a + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::ONE, one, a, b) + } + + pub fn add_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) } pub fn add_ext_algebra( &mut self, - mut a: ExtensionAlgebraTarget, + a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - for i in 0..D { - a.0[i] = self.add_extension(a.0[i], b.0[i]); + let mut res = vec![]; + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); } - a + if D % 2 == 1 { + res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) } pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut sum = self.zero_extension(); - for term in terms { - sum = self.add_extension(sum, *term); + let zero = self.zero_extension(); + let mut terms = terms.to_vec(); + if terms.len() % 2 == 1 { + terms.push(zero); } - sum + let mut acc0 = zero; + let mut acc1 = zero; + for chunk in terms.chunks_exact(2) { + (acc0, acc1) = self.add_two_extension(acc0, acc1, chunk[0], chunk[1]); + } + self.add_extension(acc0, acc1) } - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. pub fn sub_extension( &mut self, - mut a: ExtensionTarget, + a: ExtensionTarget, b: ExtensionTarget, ) -> ExtensionTarget { - for i in 0..D { - a.0[i] = self.sub(a.0[i], b.0[i]); - } - a + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) + } + + pub fn sub_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) } pub fn sub_ext_algebra( &mut self, - mut a: ExtensionAlgebraTarget, + a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - for i in 0..D { - a.0[i] = self.sub_extension(a.0[i], b.0[i]); + let mut res = vec![]; + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); } - a + if D % 2 == 1 { + res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) } pub fn mul_extension_with_const( @@ -169,17 +269,17 @@ impl, const D: usize> CircuitBuilder { multiplicand_0: ExtensionTarget, multiplicand_1: ExtensionTarget, ) -> ExtensionTarget { - let gate = self.add_gate(MulExtensionGate::new(), vec![const_0]); - - let wire_multiplicand_0 = - ExtensionTarget::from_range(gate, MulExtensionGate::::wires_multiplicand_0()); - let wire_multiplicand_1 = - ExtensionTarget::from_range(gate, MulExtensionGate::::wires_multiplicand_1()); - let wire_output = ExtensionTarget::from_range(gate, MulExtensionGate::::wires_output()); - - self.route_extension(multiplicand_0, wire_multiplicand_0); - self.route_extension(multiplicand_1, wire_multiplicand_1); - wire_output + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + F::ZERO, + multiplicand_0, + multiplicand_1, + zero, + zero, + zero, + ) + .0 } pub fn mul_extension( @@ -199,12 +299,11 @@ impl, const D: usize> CircuitBuilder { let w = self.constant(F::Extension::W); for i in 0..D { for j in 0..D { - let ai_bi = self.mul_extension(a.0[i], b.0[j]); res[(i + j) % D] = if i + j < D { - self.add_extension(ai_bi, res[(i + j) % D]) + self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) } else { - let w_ai_bi = self.scalar_mul_ext(w, ai_bi); - self.add_extension(w_ai_bi, res[(i + j) % D]) + let ai_bi = self.mul_extension(a.0[i], b.0[j]); + self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) } } } @@ -221,28 +320,35 @@ impl, const D: usize> CircuitBuilder { /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no /// performance benefit over separate muls and adds. - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. pub fn mul_add_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { - let product = self.mul_extension(a, b); - self.add_extension(product, c) + self.arithmetic_extension(F::ONE, F::ONE, a, b, c) } - /// Like `mul_sub`, but for `ExtensionTarget`s. Note that, unlike `mul_sub`, this has no - /// performance benefit over separate muls and subs. - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. + /// Like `mul_add`, but for `ExtensionTarget`s. + pub fn scalar_mul_add_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::ONE, a_ext, b, c) + } + + /// Like `mul_sub`, but for `ExtensionTarget`s. pub fn scalar_mul_sub_extension( &mut self, a: Target, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { - let product = self.scalar_mul_ext(a, b); - self.sub_extension(product, c) + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::NEG_ONE, a_ext, b, c) } /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 520f8cd9..7fc7e08d 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,11 +1,9 @@ -use std::convert::TryInto; - use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticGate; -use crate::gates::mul_extension::MulExtensionGate; +use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::target::Target; use crate::wire::Wire; @@ -253,16 +251,14 @@ impl, const D: usize> CircuitBuilder { y: ExtensionTarget, ) -> ExtensionTarget { // Add an `ArithmeticGate` to compute `q * y`. - let gate = self.add_gate(MulExtensionGate::new(), vec![F::ONE]); + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); let multiplicand_0 = - Target::wires_from_range(gate, MulExtensionGate::::wires_multiplicand_0()); - let multiplicand_0 = ExtensionTarget(multiplicand_0.try_into().unwrap()); + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); let multiplicand_1 = - Target::wires_from_range(gate, MulExtensionGate::::wires_multiplicand_1()); - let multiplicand_1 = ExtensionTarget(multiplicand_1.try_into().unwrap()); - let output = Target::wires_from_range(gate, MulExtensionGate::::wires_output()); - let output = ExtensionTarget(output.try_into().unwrap()); + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); + let output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); self.add_generator(QuotientGeneratorExtension { numerator: x, diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 40b91c1e..37746685 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -72,13 +72,10 @@ mod tests { fn test_interpolate() { type F = CrandallField; type FF = QuarticCrandallField; - let config = CircuitConfig { - num_routed_wires: 18, - ..CircuitConfig::large_config() - }; + let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); - let len = 2; + let len = 4; let points = (0..len) .map(|_| (F::rand(), FF::rand())) .collect::>(); diff --git a/src/gates/mul_extension.rs b/src/gates/mul_extension.rs index e378e2b1..a3006837 100644 --- a/src/gates/mul_extension.rs +++ b/src/gates/mul_extension.rs @@ -10,39 +10,63 @@ use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; use crate::witness::PartialWitness; -/// A gate which can multiply two field extension elements. -/// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough. +/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. #[derive(Debug)] -pub struct MulExtensionGate; +pub struct ArithmeticExtensionGate; -impl MulExtensionGate { +impl ArithmeticExtensionGate { pub fn new>() -> GateRef { - GateRef::new(MulExtensionGate) + GateRef::new(ArithmeticExtensionGate) } - pub fn wires_multiplicand_0() -> Range { + pub fn wires_fixed_multiplicand() -> Range { 0..D } - pub fn wires_multiplicand_1() -> Range { + pub fn wires_multiplicand_0() -> Range { D..2 * D } - pub fn wires_output() -> Range { + pub fn wires_addend_0() -> Range { 2 * D..3 * D } + pub fn wires_multiplicand_1() -> Range { + 3 * D..4 * D + } + pub fn wires_addend_1() -> Range { + 4 * D..5 * D + } + pub fn wires_output_0() -> Range { + 5 * D..6 * D + } + pub fn wires_output_1() -> Range { + 6 * D..7 * D + } } -impl, const D: usize> Gate for MulExtensionGate { +impl, const D: usize> Gate for ArithmeticExtensionGate { fn id(&self) -> String { format!("{:?}", self) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let output = vars.get_local_ext_algebra(Self::wires_output()); - let computed_output = multiplicand_0 * multiplicand_1 * const_0.into(); - (output - computed_output).to_basefield_array().to_vec() + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = + fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); + let computed_output_1 = + fixed_multiplicand * multiplicand_1 * const_1.into() + addend_1 * const_1.into(); + + let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); + constraints.extend((output_1 - computed_output_1).to_basefield_array()); + constraints } fn eval_unfiltered_recursively( @@ -51,13 +75,31 @@ impl, const D: usize> Gate for MulExtensionGate { vars: EvaluationTargets, ) -> Vec> { let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let output = vars.get_local_ext_algebra(Self::wires_output()); - let computed_output = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); - let computed_output = builder.scalar_mul_ext_algebra(const_0, computed_output); - let diff = builder.sub_ext_algebra(output, computed_output); - diff.to_ext_target_array().to_vec() + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); + let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); + let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); + let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); + + let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); + let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); + let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); + let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); + + let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); + let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); + let mut constraints = diff_0.to_ext_target_array().to_vec(); + constraints.extend(diff_1.to_ext_target_array()); + constraints } fn generators( @@ -68,16 +110,17 @@ impl, const D: usize> Gate for MulExtensionGate { let gen = MulExtensionGenerator { gate_index, const_0: local_constants[0], + const_1: local_constants[1], }; vec![Box::new(gen)] } fn num_wires(&self) -> usize { - 12 + 28 } fn num_constants(&self) -> usize { - 1 + 2 } fn degree(&self) -> usize { @@ -85,49 +128,59 @@ impl, const D: usize> Gate for MulExtensionGate { } fn num_constraints(&self) -> usize { - D + 2 * D } } struct MulExtensionGenerator, const D: usize> { gate_index: usize, const_0: F, + const_1: F, } impl, const D: usize> SimpleGenerator for MulExtensionGenerator { fn dependencies(&self) -> Vec { - MulExtensionGate::::wires_multiplicand_0() - .chain(MulExtensionGate::::wires_multiplicand_1()) - .map(|i| { - Target::Wire(Wire { - gate: self.gate_index, - input: i, - }) - }) + ArithmeticExtensionGate::::wires_fixed_multiplicand() + .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) + .chain(ArithmeticExtensionGate::::wires_addend_0()) + .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_addend_1()) + .map(|i| Target::wire(self.gate_index, i)) .collect() } fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let multiplicand_0_target = ExtensionTarget::from_range( + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let fixed_multiplicand = + extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_0 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); + let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); + let multiplicand_1 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); + let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); + + let output_target_0 = ExtensionTarget::from_range( self.gate_index, - MulExtensionGate::::wires_multiplicand_0(), + ArithmeticExtensionGate::::wires_output_0(), ); - let multiplicand_0 = witness.get_extension_target(multiplicand_0_target); - - let multiplicand_1_target = ExtensionTarget::from_range( + let output_target_1 = ExtensionTarget::from_range( self.gate_index, - MulExtensionGate::::wires_multiplicand_1(), + ArithmeticExtensionGate::::wires_output_1(), ); - let multiplicand_1 = witness.get_extension_target(multiplicand_1_target); - let output_target = - ExtensionTarget::from_range(self.gate_index, MulExtensionGate::::wires_output()); - - let computed_output = - F::Extension::from_basefield(self.const_0) * multiplicand_0 * multiplicand_1; + let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() + + addend_0 * self.const_1.into(); + let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() + + addend_1 * self.const_1.into(); let mut pw = PartialWitness::new(); - pw.set_extension_target(output_target, computed_output); + pw.set_extension_target(output_target_0, computed_output_0); + pw.set_extension_target(output_target_1, computed_output_1); pw } } @@ -136,10 +189,10 @@ impl, const D: usize> SimpleGenerator for MulExtensionGenera mod tests { use crate::field::crandall_field::CrandallField; use crate::gates::gate_testing::test_low_degree; - use crate::gates::mul_extension::MulExtensionGate; + use crate::gates::mul_extension::ArithmeticExtensionGate; #[test] fn low_degree() { - test_low_degree(MulExtensionGate::<4>::new::()) + test_low_degree(ArithmeticExtensionGate::<4>::new::()) } } diff --git a/src/lib.rs b/src/lib.rs index 09f90f81..adfdf2cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(destructuring_assignment)] + pub mod circuit_builder; pub mod circuit_data; pub mod field; diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index c1005a67..bb9d724e 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -5,7 +5,7 @@ use crate::gates::gate::GateRef; use crate::proof::ProofTarget; const MIN_WIRES: usize = 120; // TODO: Double check. -const MIN_ROUTED_WIRES: usize = 8; // TODO: Double check. +const MIN_ROUTED_WIRES: usize = 28; // TODO: Double check. /// Recursively verifies an inner proof. pub fn add_recursive_verifier, const D: usize>( From ff74887ab981fd07861b20f78360ce00f8b70150 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 23 Jun 2021 18:06:53 +0200 Subject: [PATCH 02/22] Use `with_capacity` when length is known --- src/field/extension_field/target.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index f6cec119..14a5c4c7 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -196,7 +196,7 @@ impl, const D: usize> CircuitBuilder { a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - let mut res = vec![]; + let mut res = Vec::with_capacity(D); let d_even = D & (D ^ 1); // = 2 * (D/2) for mut chunk in &(0..d_even).chunks(2) { let i = chunk.next().unwrap(); @@ -249,7 +249,7 @@ impl, const D: usize> CircuitBuilder { a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - let mut res = vec![]; + let mut res = Vec::with_capacity(D); let d_even = D & (D ^ 1); // = 2 * (D/2) for mut chunk in &(0..d_even).chunks(2) { let i = chunk.next().unwrap(); From 6652b38b998cc0372882e96b1f12c4483bdb58de Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 13:53:14 +0200 Subject: [PATCH 03/22] Remove `ArithmeticGate` --- src/gadgets/arithmetic.rs | 78 ++++++++------------------------------- 1 file changed, 15 insertions(+), 63 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 7fc7e08d..0b2645e9 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -2,7 +2,6 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; -use crate::gates::arithmetic::ArithmeticGate; use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::target::Target; @@ -41,30 +40,18 @@ impl, const D: usize> CircuitBuilder { { return result; } + let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); + let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); + let addend_ext = self.convert_to_ext(addend); - let gate = self.add_gate(ArithmeticGate::new(), vec![const_0, const_1]); - - let wire_multiplicand_0 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let wire_multiplicand_1 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let wire_addend = Wire { - gate, - input: ArithmeticGate::WIRE_ADDEND, - }; - let wire_output = Wire { - gate, - input: ArithmeticGate::WIRE_OUTPUT, - }; - - self.route(multiplicand_0, Target::Wire(wire_multiplicand_0)); - self.route(multiplicand_1, Target::Wire(wire_multiplicand_1)); - self.route(addend, Target::Wire(wire_addend)); - Target::Wire(wire_output) + self.arithmetic_extension( + const_0, + const_1, + multiplicand_0_ext, + multiplicand_1_ext, + addend_ext, + ) + .0[0] } /// Checks for special cases where the value of @@ -205,42 +192,9 @@ impl, const D: usize> CircuitBuilder { return self.constant(x_const / y_const); } - // Add an `ArithmeticGate` to compute `q * y`. - let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ZERO]); - - let wire_multiplicand_0 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let wire_multiplicand_1 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let wire_addend = Wire { - gate, - input: ArithmeticGate::WIRE_ADDEND, - }; - let wire_output = Wire { - gate, - input: ArithmeticGate::WIRE_OUTPUT, - }; - - let q = Target::Wire(wire_multiplicand_0); - self.add_generator(QuotientGenerator { - numerator: x, - denominator: y, - quotient: q, - }); - - self.route(y, Target::Wire(wire_multiplicand_1)); - - // This can be anything, since the whole second term has a weight of zero. - self.route(zero, Target::Wire(wire_addend)); - - let q_y = Target::Wire(wire_output); - self.assert_equal(q_y, x); - - q + let x_ext = self.convert_to_ext(x); + let y_ext = self.convert_to_ext(y); + self.div_unsafe_extension(x_ext, y_ext).0[0] } /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in @@ -250,7 +204,7 @@ impl, const D: usize> CircuitBuilder { x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { - // Add an `ArithmeticGate` to compute `q * y`. + // Add an `ArithmeticExtensionGate` to compute `q * y`. let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); let multiplicand_0 = @@ -381,8 +335,6 @@ mod tests { let x = FF::rand(); let y = FF::rand(); - let x = FF::TWO; - let y = FF::ONE; let z = x / y; let xt = builder.constant_extension(x); let yt = builder.constant_extension(y); From fd3fa739a6d23159e301c531cc67d3a4ce99f67e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 13:56:43 +0200 Subject: [PATCH 04/22] Fix test relying on `ArithmeticGate` --- src/gates/gate_tree.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 1423de76..9dc188a0 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -186,12 +186,11 @@ impl, const D: usize> Tree> { mod tests { use super::*; use crate::field::crandall_field::CrandallField; - use crate::gates::arithmetic::ArithmeticGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::gmimc::GMiMCGate; use crate::gates::interpolation::InterpolationGate; - use crate::gates::mul_extension::MulExtensionGate; + use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::gates::noop::NoopGate; use crate::hash::GMIMC_ROUNDS; @@ -204,11 +203,10 @@ mod tests { let gates = vec![ NoopGate::get::(), ConstantGate::get(), - ArithmeticGate::new(), + ArithmeticExtensionGate::new(), BaseSumGate::<4>::new(4), GMiMCGate::::with_automatic_constants(), InterpolationGate::new(4), - MulExtensionGate::new(), ]; let len = gates.len(); From beadce72fcae1e4cf990fa90b541705e54080238 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 15:11:49 +0200 Subject: [PATCH 05/22] Add `ZeroOutGenerator` --- src/field/extension_field/target.rs | 1 - src/gadgets/arithmetic.rs | 63 +++++++++------ src/util/scaling.rs | 118 +++++++++++++++++++++++++++- 3 files changed, 158 insertions(+), 24 deletions(-) diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 14a5c4c7..e6f5c38b 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -146,7 +146,6 @@ impl, const D: usize> CircuitBuilder { self.route_extension(addend_0, wire_addend_0); self.route_extension(multiplicand_1, wire_multiplicand_1); self.route_extension(addend_1, wire_addend_1); - self.route_extension(multiplicand_1, wire_multiplicand_1); (wire_output_0, wire_output_1) } diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 0b2645e9..69eaea48 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,3 +1,5 @@ +use std::ops::Range; + use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; @@ -207,10 +209,12 @@ impl, const D: usize> CircuitBuilder { // Add an `ArithmeticExtensionGate` to compute `q * y`. let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); - let multiplicand_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); let multiplicand_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); let output = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); @@ -219,33 +223,22 @@ impl, const D: usize> CircuitBuilder { denominator: y, quotient: multiplicand_0, }); + self.add_generator(ZeroOutGenerator { + gate_index: gate, + ranges: vec![ + ArithmeticExtensionGate::::wires_addend_0(), + ArithmeticExtensionGate::::wires_multiplicand_1(), + ArithmeticExtensionGate::::wires_addend_1(), + ], + }); self.route_extension(y, multiplicand_1); - self.assert_equal_extension(output, x); multiplicand_0 } } -struct QuotientGenerator { - numerator: Target, - denominator: Target, - quotient: Target, -} - -impl SimpleGenerator for QuotientGenerator { - fn dependencies(&self) -> Vec { - vec![self.numerator, self.denominator] - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let num = witness.get_target(self.numerator); - let den = witness.get_target(self.denominator); - PartialWitness::singleton_target(self.quotient, num / den) - } -} - struct QuotientGeneratorExtension { numerator: ExtensionTarget, denominator: ExtensionTarget, @@ -270,6 +263,32 @@ impl, const D: usize> SimpleGenerator for QuotientGeneratorE quotient.to_basefield_array()[i], ); } + + pw + } +} + +/// Generator used to zero out wires at a given gate index and ranges. +pub struct ZeroOutGenerator { + gate_index: usize, + ranges: Vec>, +} + +impl SimpleGenerator for ZeroOutGenerator { + fn dependencies(&self) -> Vec { + Vec::new() + } + + fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + let mut pw = PartialWitness::new(); + for t in self + .ranges + .iter() + .flat_map(|r| Target::wires_from_range(self.gate_index, r.clone())) + { + pw.set_target(t, F::ZERO); + } + pw } } diff --git a/src/util/scaling.rs b/src/util/scaling.rs index cea86195..1d422e8a 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -1,8 +1,14 @@ use std::borrow::Borrow; -use crate::field::extension_field::Frobenius; +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, Frobenius}; use crate::field::field::Field; +use crate::gates::mul_extension::ArithmeticExtensionGate; +use crate::generator::SimpleGenerator; use crate::polynomial::polynomial::PolynomialCoeffs; +use crate::target::Target; +use crate::witness::PartialWitness; /// When verifying the composition polynomial in FRI we have to compute sums of the form /// `(sum_0^k a^i * x_i)/d_0 + (sum_k^r a^i * y_i)/d_1` @@ -73,3 +79,113 @@ impl ReducingFactor { } } } + +// #[derive(Debug, Copy, Clone)] +// pub struct ReducingFactorTarget { +// base: ExtensionTarget, +// count: u64, +// } +// +// impl, const D: usize> ReducingFactorTarget { +// pub fn new(base: ExtensionTarget) -> Self { +// Self { base, count: 0 } +// } +// +// fn mul( +// &mut self, +// x: ExtensionTarget, +// builder: &mut CircuitBuilder, +// ) -> ExtensionTarget { +// self.count += 1; +// builder.mul_extension(self.base, x) +// } +// +// pub fn reduce( +// &mut self, +// iter: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. +// builder: &mut CircuitBuilder, +// ) -> ExtensionTarget { +// let l = iter.len(); +// let padded_iter = if l % 2 == 0 { +// iter.to_vec() +// } else { +// [iter, &[builder.zero_extension()]].concat() +// }; +// let half_length = padded_iter.len() / 2; +// let gates = (0..half_length) +// .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE])) +// .collect::>(); +// +// struct ParallelReductionGenerator<'a, const D: usize> { +// base: ExtensionTarget, +// padded_iter: &'a [ExtensionTarget], +// gates: &'a [usize], +// half_length: usize, +// } +// +// impl<'a, F: Extendable, const D: usize> SimpleGenerator +// for ParallelReductionGenerator<'a, D> +// { +// fn dependencies(&self) -> Vec { +// self.padded_iter +// .iter() +// .flat_map(|ext| ext.to_target_array()) +// .chain(self.base.to_target_array()) +// .collect() +// } +// +// fn run_once(&self, witness: &PartialWitness) -> PartialWitness { +// let mut pw = PartialWitness::new(); +// let base = witness.get_extension_target(self.base); +// let vs = self +// .padded_iter +// .iter() +// .map(|&ext| witness.get_extension_target(ext)) +// .collect::>(); +// let first_half = &vs[..self.half_length]; +// let intermediate_acc = base.reduce(first_half); +// } +// } +// } +// +// pub fn reduce_parallel( +// &mut self, +// iter0: impl DoubleEndedIterator>>, +// iter1: impl DoubleEndedIterator>>, +// builder: &mut CircuitBuilder, +// ) -> (ExtensionTarget, ExtensionTarget) { +// iter.rev().fold(builder.zero_extension(), |acc, x| { +// builder.arithmetic_extension(F::ONE, F::ONE, self.base, acc, x) +// }) +// } +// +// pub fn shift( +// &mut self, +// x: ExtensionTarget, +// builder: &mut CircuitBuilder, +// ) -> ExtensionTarget { +// let tmp = self.base.exp(self.count) * x; +// self.count = 0; +// tmp +// } +// +// pub fn shift_poly( +// &mut self, +// p: &mut PolynomialCoeffs>, +// builder: &mut CircuitBuilder, +// ) { +// *p *= self.base.exp(self.count); +// self.count = 0; +// } +// +// pub fn reset(&mut self) { +// self.count = 0; +// } +// +// pub fn repeated_frobenius(&self, count: usize, builder: &mut CircuitBuilder) -> Self { +// Self { +// base: self.base.repeated_frobenius(count), +// count: self.count, +// } +// } +// } From 8a119f035d2e93070302adeba62645d29d78068a Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 16:27:20 +0200 Subject: [PATCH 06/22] Working ReducingFactorTarget --- src/gadgets/arithmetic.rs | 49 +++++- src/util/scaling.rs | 358 ++++++++++++++++++++++++++------------ 2 files changed, 295 insertions(+), 112 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 69eaea48..9cdeded4 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -7,6 +7,7 @@ use crate::field::field::Field; use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::target::Target; +use crate::util::bits_u64; use crate::wire::Wire; use crate::witness::PartialWitness; @@ -22,6 +23,11 @@ impl, const D: usize> CircuitBuilder { self.mul(x, x) } + /// Computes `x^2`. + pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { + self.mul_extension(x, x) + } + /// Computes `x^3`. pub fn cube(&mut self, x: Target) -> Target { self.mul_many(&[x, x, x]) @@ -161,21 +167,58 @@ impl, const D: usize> CircuitBuilder { } // TODO: Optimize this, maybe with a new gate. + // TODO: Test /// Exponentiate `base` to the power of `exponent`, where `exponent < 2^num_bits`. pub fn exp(&mut self, base: Target, exponent: Target, num_bits: usize) -> Target { let mut current = base; - let one = self.one(); - let mut product = one; + let one_ext = self.one_extension(); + let mut product = self.one(); let exponent_bits = self.split_le(exponent, num_bits); for bit in exponent_bits.into_iter() { - product = self.mul_many(&[bit, current, product]); + let current_ext = self.convert_to_ext(current); + let multiplicand = self.select(bit, current_ext, one_ext); + product = self.mul(product, multiplicand.0[0]); current = self.mul(current, current); } product } + /// Exponentiate `base` to the power of a known `exponent`. + // TODO: Test + pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target { + let mut current = base; + let mut product = self.one(); + + for j in 0..bits_u64(exponent as u64) { + if (exponent >> j & 1) != 0 { + product = self.mul(product, current); + } + current = self.square(current); + } + product + } + + /// Exponentiate `base` to the power of a known `exponent`. + // TODO: Test + pub fn exp_u64_extension( + &mut self, + base: ExtensionTarget, + exponent: u64, + ) -> ExtensionTarget { + let mut current = base; + let mut product = self.one_extension(); + + for j in 0..bits_u64(exponent as u64) { + if (exponent >> j & 1) != 0 { + product = self.mul_extension(product, current); + } + current = self.square_extension(current); + } + product + } + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// some cases, as it allows `0 / 0 = `. pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { diff --git a/src/util/scaling.rs b/src/util/scaling.rs index 1d422e8a..057e8467 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -80,112 +80,252 @@ impl ReducingFactor { } } -// #[derive(Debug, Copy, Clone)] -// pub struct ReducingFactorTarget { -// base: ExtensionTarget, -// count: u64, -// } -// -// impl, const D: usize> ReducingFactorTarget { -// pub fn new(base: ExtensionTarget) -> Self { -// Self { base, count: 0 } -// } -// -// fn mul( -// &mut self, -// x: ExtensionTarget, -// builder: &mut CircuitBuilder, -// ) -> ExtensionTarget { -// self.count += 1; -// builder.mul_extension(self.base, x) -// } -// -// pub fn reduce( -// &mut self, -// iter: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. -// builder: &mut CircuitBuilder, -// ) -> ExtensionTarget { -// let l = iter.len(); -// let padded_iter = if l % 2 == 0 { -// iter.to_vec() -// } else { -// [iter, &[builder.zero_extension()]].concat() -// }; -// let half_length = padded_iter.len() / 2; -// let gates = (0..half_length) -// .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE])) -// .collect::>(); -// -// struct ParallelReductionGenerator<'a, const D: usize> { -// base: ExtensionTarget, -// padded_iter: &'a [ExtensionTarget], -// gates: &'a [usize], -// half_length: usize, -// } -// -// impl<'a, F: Extendable, const D: usize> SimpleGenerator -// for ParallelReductionGenerator<'a, D> -// { -// fn dependencies(&self) -> Vec { -// self.padded_iter -// .iter() -// .flat_map(|ext| ext.to_target_array()) -// .chain(self.base.to_target_array()) -// .collect() -// } -// -// fn run_once(&self, witness: &PartialWitness) -> PartialWitness { -// let mut pw = PartialWitness::new(); -// let base = witness.get_extension_target(self.base); -// let vs = self -// .padded_iter -// .iter() -// .map(|&ext| witness.get_extension_target(ext)) -// .collect::>(); -// let first_half = &vs[..self.half_length]; -// let intermediate_acc = base.reduce(first_half); -// } -// } -// } -// -// pub fn reduce_parallel( -// &mut self, -// iter0: impl DoubleEndedIterator>>, -// iter1: impl DoubleEndedIterator>>, -// builder: &mut CircuitBuilder, -// ) -> (ExtensionTarget, ExtensionTarget) { -// iter.rev().fold(builder.zero_extension(), |acc, x| { -// builder.arithmetic_extension(F::ONE, F::ONE, self.base, acc, x) -// }) -// } -// -// pub fn shift( -// &mut self, -// x: ExtensionTarget, -// builder: &mut CircuitBuilder, -// ) -> ExtensionTarget { -// let tmp = self.base.exp(self.count) * x; -// self.count = 0; -// tmp -// } -// -// pub fn shift_poly( -// &mut self, -// p: &mut PolynomialCoeffs>, -// builder: &mut CircuitBuilder, -// ) { -// *p *= self.base.exp(self.count); -// self.count = 0; -// } -// -// pub fn reset(&mut self) { -// self.count = 0; -// } -// -// pub fn repeated_frobenius(&self, count: usize, builder: &mut CircuitBuilder) -> Self { -// Self { -// base: self.base.repeated_frobenius(count), -// count: self.count, -// } -// } -// } +#[derive(Debug, Copy, Clone)] +pub struct ReducingFactorTarget { + base: ExtensionTarget, + count: u64, +} + +impl ReducingFactorTarget { + pub fn new(base: ExtensionTarget) -> Self { + Self { base, count: 0 } + } + + pub fn reduce( + &mut self, + iter: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: Extendable, + { + let l = iter.len(); + self.count += l as u64; + let padded_iter = if l % 2 == 0 { + iter.to_vec() + } else { + [iter, &[builder.zero_extension()]].concat() + }; + let half_length = padded_iter.len() / 2; + let gates = (0..half_length) + .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE])) + .collect::>(); + + builder.add_generator(ParallelReductionGenerator { + base: self.base, + padded_iter: padded_iter.clone(), + gates: gates.clone(), + half_length, + }); + + for i in 0..half_length { + builder.route_extension( + ExtensionTarget::from_range( + gates[i], + ArithmeticExtensionGate::::wires_addend_0(), + ), + padded_iter[2 * half_length - i - 1], + ); + } + for i in 0..half_length { + builder.route_extension( + ExtensionTarget::from_range( + gates[i], + ArithmeticExtensionGate::::wires_addend_1(), + ), + padded_iter[half_length - i - 1], + ); + } + for gate_pair in gates[..half_length].windows(2) { + builder.assert_equal_extension( + ExtensionTarget::from_range( + gate_pair[0], + ArithmeticExtensionGate::::wires_output_0(), + ), + ExtensionTarget::from_range( + gate_pair[1], + ArithmeticExtensionGate::::wires_multiplicand_0(), + ), + ); + } + for gate_pair in gates[half_length..].windows(2) { + builder.assert_equal_extension( + ExtensionTarget::from_range( + gate_pair[0], + ArithmeticExtensionGate::::wires_output_1(), + ), + ExtensionTarget::from_range( + gate_pair[1], + ArithmeticExtensionGate::::wires_multiplicand_1(), + ), + ); + } + builder.assert_equal_extension( + ExtensionTarget::from_range( + gates[half_length - 1], + ArithmeticExtensionGate::::wires_output_0(), + ), + ExtensionTarget::from_range( + gates[0], + ArithmeticExtensionGate::::wires_multiplicand_1(), + ), + ); + + ExtensionTarget::from_range( + gates[half_length - 1], + ArithmeticExtensionGate::::wires_output_1(), + ) + } + + pub fn shift( + &mut self, + x: ExtensionTarget, + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: Extendable, + { + let exp = builder.exp_u64_extension(self.base, self.count); + let tmp = builder.mul_extension(exp, x); + self.count = 0; + tmp + } + + pub fn reset(&mut self) { + self.count = 0; + } + + pub fn repeated_frobenius(&self, count: usize, builder: &mut CircuitBuilder) -> Self + where + F: Extendable, + { + Self { + base: self.base.repeated_frobenius(count, builder), + count: self.count, + } + } +} + +struct ParallelReductionGenerator { + base: ExtensionTarget, + padded_iter: Vec>, + gates: Vec, + half_length: usize, +} + +impl, const D: usize> SimpleGenerator for ParallelReductionGenerator { + fn dependencies(&self) -> Vec { + self.padded_iter + .iter() + .flat_map(|ext| ext.to_target_array()) + .chain(self.base.to_target_array()) + .collect() + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let mut pw = PartialWitness::new(); + let base = witness.get_extension_target(self.base); + let vs = self + .padded_iter + .iter() + .map(|&ext| witness.get_extension_target(ext)) + .collect::>(); + let intermediate_accs = vs + .iter() + .rev() + .scan(F::Extension::ZERO, |acc, &x| { + let tmp = *acc; + *acc = *acc * base + x; + Some(tmp) + }) + .collect::>(); + for i in 0..self.half_length { + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ), + base, + ); + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_multiplicand_0(), + ), + intermediate_accs[i], + ); + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_addend_0(), + ), + vs[2 * self.half_length - i - 1], + ); + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_multiplicand_1(), + ), + intermediate_accs[self.half_length + i], + ); + pw.set_extension_target( + ExtensionTarget::from_range( + self.gates[i], + ArithmeticExtensionGate::::wires_addend_1(), + ), + vs[self.half_length - i - 1], + ); + } + + pw + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + + fn test_reduce_gadget(n: usize) { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + + let alpha = FF::rand(); + let alpha = FF::ONE; + let vs = (0..n) + .map(|i| FF::from_canonical_usize(i)) + .collect::>(); + + let manual_reduce = ReducingFactor::new(alpha).reduce(vs.iter()); + let manual_reduce = builder.constant_extension(manual_reduce); + + let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); + let vs_t = vs + .iter() + .map(|&v| builder.constant_extension(v)) + .collect::>(); + let circuit_reduce = alpha_t.reduce(&vs_t, &mut builder); + + builder.assert_equal_extension(manual_reduce, circuit_reduce); + + let data = builder.build(); + let proof = data.prove(PartialWitness::new()); + } + + #[test] + fn test_reduce_gadget_even() { + test_reduce_gadget(10); + } + + #[test] + fn test_reduce_gadget_odd() { + test_reduce_gadget(11); + } +} From b62c2e699034e94cc9ea39991420667500f8080c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 16:31:10 +0200 Subject: [PATCH 07/22] Supplant ArithmeticGate with ArithmeticExtensionGate --- src/field/extension_field/target.rs | 2 +- src/gadgets/arithmetic.rs | 2 +- src/gates/arithmetic.rs | 183 +++++++++++++++---------- src/gates/gate_tree.rs | 2 +- src/gates/mod.rs | 3 +- src/gates/mul_extension.rs | 198 ---------------------------- src/util/scaling.rs | 2 +- 7 files changed, 117 insertions(+), 275 deletions(-) delete mode 100644 src/gates/mul_extension.rs diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index e6f5c38b..8eb86043 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -7,7 +7,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; -use crate::gates::mul_extension::ArithmeticExtensionGate; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::target::Target; /// `Target`s representing an element of an extension field. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 9cdeded4..99e260c7 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -4,7 +4,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; -use crate::gates::mul_extension::ArithmeticExtensionGate; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::target::Target; use crate::util::bits_u64; diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 0d0fdd7c..586a0cc2 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -1,7 +1,8 @@ +use std::ops::Range; + use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; -use crate::field::field::Field; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::target::Target; @@ -9,26 +10,39 @@ use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; use crate::witness::PartialWitness; -/// A gate which can be configured to perform various arithmetic. In particular, it computes -/// -/// ```text -/// output := const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend -/// ``` +/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. #[derive(Debug)] -pub struct ArithmeticGate; +pub struct ArithmeticExtensionGate; -impl ArithmeticGate { - pub fn new, const D: usize>() -> GateRef { - GateRef::new(ArithmeticGate) +impl ArithmeticExtensionGate { + pub fn new>() -> GateRef { + GateRef::new(ArithmeticExtensionGate) } - pub const WIRE_MULTIPLICAND_0: usize = 0; - pub const WIRE_MULTIPLICAND_1: usize = 1; - pub const WIRE_ADDEND: usize = 2; - pub const WIRE_OUTPUT: usize = 3; + pub fn wires_fixed_multiplicand() -> Range { + 0..D + } + pub fn wires_multiplicand_0() -> Range { + D..2 * D + } + pub fn wires_addend_0() -> Range { + 2 * D..3 * D + } + pub fn wires_multiplicand_1() -> Range { + 3 * D..4 * D + } + pub fn wires_addend_1() -> Range { + 4 * D..5 * D + } + pub fn wires_output_0() -> Range { + 5 * D..6 * D + } + pub fn wires_output_1() -> Range { + 6 * D..7 * D + } } -impl, const D: usize> Gate for ArithmeticGate { +impl, const D: usize> Gate for ArithmeticExtensionGate { fn id(&self) -> String { format!("{:?}", self) } @@ -36,12 +50,23 @@ impl, const D: usize> Gate for ArithmeticGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; - let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; - let addend = vars.local_wires[Self::WIRE_ADDEND]; - let output = vars.local_wires[Self::WIRE_OUTPUT]; - let computed_output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend; - vec![computed_output - output] + + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = + fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); + let computed_output_1 = + fixed_multiplicand * multiplicand_1 * const_1.into() + addend_1 * const_1.into(); + + let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); + constraints.extend((output_1 - computed_output_1).to_basefield_array()); + constraints } fn eval_unfiltered_recursively( @@ -51,15 +76,30 @@ impl, const D: usize> Gate for ArithmeticGate { ) -> Vec> { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; - let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; - let addend = vars.local_wires[Self::WIRE_ADDEND]; - let output = vars.local_wires[Self::WIRE_OUTPUT]; - let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); - let addend_term = builder.mul_extension(const_1, addend); - let computed_output = builder.add_many_extension(&[product_term, addend_term]); - vec![builder.sub_extension(computed_output, output)] + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); + let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); + let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); + let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); + + let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); + let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); + let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); + let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); + + let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); + let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); + let mut constraints = diff_0.to_ext_target_array().to_vec(); + constraints.extend(diff_1.to_ext_target_array()); + constraints } fn generators( @@ -67,7 +107,7 @@ impl, const D: usize> Gate for ArithmeticGate { gate_index: usize, local_constants: &[F], ) -> Vec>> { - let gen = ArithmeticGenerator { + let gen = MulExtensionGenerator { gate_index, const_0: local_constants[0], const_1: local_constants[1], @@ -76,7 +116,7 @@ impl, const D: usize> Gate for ArithmeticGate { } fn num_wires(&self) -> usize { - 4 + 28 } fn num_constants(&self) -> usize { @@ -88,70 +128,71 @@ impl, const D: usize> Gate for ArithmeticGate { } fn num_constraints(&self) -> usize { - 1 + 2 * D } } -struct ArithmeticGenerator { +struct MulExtensionGenerator, const D: usize> { gate_index: usize, const_0: F, const_1: F, } -impl SimpleGenerator for ArithmeticGenerator { +impl, const D: usize> SimpleGenerator for MulExtensionGenerator { fn dependencies(&self) -> Vec { - vec![ - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }), - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }), - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_ADDEND, - }), - ] + ArithmeticExtensionGate::::wires_fixed_multiplicand() + .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) + .chain(ArithmeticExtensionGate::::wires_addend_0()) + .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_addend_1()) + .map(|i| Target::wire(self.gate_index, i)) + .collect() } fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let multiplicand_0_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let multiplicand_1_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let addend_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_ADDEND, - }; - let output_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_OUTPUT, + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) }; - let multiplicand_0 = witness.get_wire(multiplicand_0_target); - let multiplicand_1 = witness.get_wire(multiplicand_1_target); - let addend = witness.get_wire(addend_target); + let fixed_multiplicand = + extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_0 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); + let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); + let multiplicand_1 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); + let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); - let output = self.const_0 * multiplicand_0 * multiplicand_1 + self.const_1 * addend; + let output_target_0 = ExtensionTarget::from_range( + self.gate_index, + ArithmeticExtensionGate::::wires_output_0(), + ); + let output_target_1 = ExtensionTarget::from_range( + self.gate_index, + ArithmeticExtensionGate::::wires_output_1(), + ); - PartialWitness::singleton_wire(output_target, output) + let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() + + addend_0 * self.const_1.into(); + let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() + + addend_1 * self.const_1.into(); + + let mut pw = PartialWitness::new(); + pw.set_extension_target(output_target_0, computed_output_0); + pw.set_extension_target(output_target_1, computed_output_1); + pw } } #[cfg(test)] mod tests { use crate::field::crandall_field::CrandallField; - use crate::gates::arithmetic::ArithmeticGate; + use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::gate_testing::test_low_degree; #[test] fn low_degree() { - test_low_degree(ArithmeticGate::new::()) + test_low_degree(ArithmeticExtensionGate::<4>::new::()) } } diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 9dc188a0..2471747c 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -186,11 +186,11 @@ impl, const D: usize> Tree> { mod tests { use super::*; use crate::field::crandall_field::CrandallField; + use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::gmimc::GMiMCGate; use crate::gates::interpolation::InterpolationGate; - use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::gates::noop::NoopGate; use crate::hash::GMIMC_ROUNDS; diff --git a/src/gates/mod.rs b/src/gates/mod.rs index bb8b178b..fa23b273 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -1,11 +1,10 @@ -pub(crate) mod arithmetic; +pub mod arithmetic; pub mod base_sum; pub mod constant; pub(crate) mod gate; pub mod gate_tree; pub mod gmimc; pub mod interpolation; -pub mod mul_extension; pub(crate) mod noop; #[cfg(test)] diff --git a/src/gates/mul_extension.rs b/src/gates/mul_extension.rs deleted file mode 100644 index a3006837..00000000 --- a/src/gates/mul_extension.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::ops::Range; - -use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; -use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; -use crate::wire::Wire; -use crate::witness::PartialWitness; - -/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. -#[derive(Debug)] -pub struct ArithmeticExtensionGate; - -impl ArithmeticExtensionGate { - pub fn new>() -> GateRef { - GateRef::new(ArithmeticExtensionGate) - } - - pub fn wires_fixed_multiplicand() -> Range { - 0..D - } - pub fn wires_multiplicand_0() -> Range { - D..2 * D - } - pub fn wires_addend_0() -> Range { - 2 * D..3 * D - } - pub fn wires_multiplicand_1() -> Range { - 3 * D..4 * D - } - pub fn wires_addend_1() -> Range { - 4 * D..5 * D - } - pub fn wires_output_0() -> Range { - 5 * D..6 * D - } - pub fn wires_output_1() -> Range { - 6 * D..7 * D - } -} - -impl, const D: usize> Gate for ArithmeticExtensionGate { - fn id(&self) -> String { - format!("{:?}", self) - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let const_0 = vars.local_constants[0]; - let const_1 = vars.local_constants[1]; - - let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); - let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); - let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); - - let computed_output_0 = - fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); - let computed_output_1 = - fixed_multiplicand * multiplicand_1 * const_1.into() + addend_1 * const_1.into(); - - let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); - constraints.extend((output_1 - computed_output_1).to_basefield_array()); - constraints - } - - fn eval_unfiltered_recursively( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let const_0 = vars.local_constants[0]; - let const_1 = vars.local_constants[1]; - - let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); - let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); - let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); - - let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); - let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); - let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); - let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); - - let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); - let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); - let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); - let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); - - let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); - let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); - let mut constraints = diff_0.to_ext_target_array().to_vec(); - constraints.extend(diff_1.to_ext_target_array()); - constraints - } - - fn generators( - &self, - gate_index: usize, - local_constants: &[F], - ) -> Vec>> { - let gen = MulExtensionGenerator { - gate_index, - const_0: local_constants[0], - const_1: local_constants[1], - }; - vec![Box::new(gen)] - } - - fn num_wires(&self) -> usize { - 28 - } - - fn num_constants(&self) -> usize { - 2 - } - - fn degree(&self) -> usize { - 3 - } - - fn num_constraints(&self) -> usize { - 2 * D - } -} - -struct MulExtensionGenerator, const D: usize> { - gate_index: usize, - const_0: F, - const_1: F, -} - -impl, const D: usize> SimpleGenerator for MulExtensionGenerator { - fn dependencies(&self) -> Vec { - ArithmeticExtensionGate::::wires_fixed_multiplicand() - .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) - .chain(ArithmeticExtensionGate::::wires_addend_0()) - .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) - .chain(ArithmeticExtensionGate::::wires_addend_1()) - .map(|i| Target::wire(self.gate_index, i)) - .collect() - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let extract_extension = |range: Range| -> F::Extension { - let t = ExtensionTarget::from_range(self.gate_index, range); - witness.get_extension_target(t) - }; - - let fixed_multiplicand = - extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); - let multiplicand_0 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); - let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); - let multiplicand_1 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); - let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); - - let output_target_0 = ExtensionTarget::from_range( - self.gate_index, - ArithmeticExtensionGate::::wires_output_0(), - ); - let output_target_1 = ExtensionTarget::from_range( - self.gate_index, - ArithmeticExtensionGate::::wires_output_1(), - ); - - let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() - + addend_0 * self.const_1.into(); - let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() - + addend_1 * self.const_1.into(); - - let mut pw = PartialWitness::new(); - pw.set_extension_target(output_target_0, computed_output_0); - pw.set_extension_target(output_target_1, computed_output_1); - pw - } -} - -#[cfg(test)] -mod tests { - use crate::field::crandall_field::CrandallField; - use crate::gates::gate_testing::test_low_degree; - use crate::gates::mul_extension::ArithmeticExtensionGate; - - #[test] - fn low_degree() { - test_low_degree(ArithmeticExtensionGate::<4>::new::()) - } -} diff --git a/src/util/scaling.rs b/src/util/scaling.rs index 057e8467..0bc61840 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -4,7 +4,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, Frobenius}; use crate::field::field::Field; -use crate::gates::mul_extension::ArithmeticExtensionGate; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::polynomial::polynomial::PolynomialCoeffs; use crate::target::Target; From 8602ae154936cecc49cd8eda0a863ea0e8be40d0 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 16:35:58 +0200 Subject: [PATCH 08/22] Typo --- src/field/extension_field/target.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 8eb86043..6a5ab2be 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -218,7 +218,7 @@ impl, const D: usize> CircuitBuilder { let mut acc0 = zero; let mut acc1 = zero; for chunk in terms.chunks_exact(2) { - (acc0, acc1) = self.add_two_extension(acc0, acc1, chunk[0], chunk[1]); + (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); } self.add_extension(acc0, acc1) } @@ -257,7 +257,7 @@ impl, const D: usize> CircuitBuilder { res.extend([o0, o1]); } if D % 2 == 1 { - res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); + res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); } ExtensionAlgebraTarget(res.try_into().unwrap()) } From fc4738869d3fed5de883256f33cdfcacb82ae9b8 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 16:45:02 +0200 Subject: [PATCH 09/22] Rearrange files --- src/field/extension_field/target.rs | 257 --------------- src/gadgets/arithmetic.rs | 191 +----------- src/gadgets/arithmetic_extension.rs | 464 ++++++++++++++++++++++++++++ src/gadgets/mod.rs | 1 + 4 files changed, 466 insertions(+), 447 deletions(-) create mode 100644 src/gadgets/arithmetic_extension.rs diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 6a5ab2be..e7dead21 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -112,263 +112,6 @@ impl, const D: usize> CircuitBuilder { self.constant_ext_algebra(ExtensionAlgebra::ZERO) } - pub fn double_arithmetic_extension( - &mut self, - const_0: F, - const_1: F, - fixed_multiplicand: ExtensionTarget, - multiplicand_0: ExtensionTarget, - addend_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - addend_1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); - - let wire_fixed_multiplicand = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_fixed_multiplicand(), - ); - let wire_multiplicand_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); - let wire_addend_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); - let wire_multiplicand_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); - let wire_addend_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); - let wire_output_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); - let wire_output_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); - - self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); - self.route_extension(multiplicand_0, wire_multiplicand_0); - self.route_extension(addend_0, wire_addend_0); - self.route_extension(multiplicand_1, wire_multiplicand_1); - self.route_extension(addend_1, wire_addend_1); - (wire_output_0, wire_output_1) - } - - pub fn arithmetic_extension( - &mut self, - const_0: F, - const_1: F, - multiplicand_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - addend: ExtensionTarget, - ) -> ExtensionTarget { - let zero = self.zero_extension(); - self.double_arithmetic_extension( - const_0, - const_1, - multiplicand_0, - multiplicand_1, - addend, - zero, - zero, - ) - .0 - } - - pub fn add_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - ) -> ExtensionTarget { - let one = self.one_extension(); - self.arithmetic_extension(F::ONE, F::ONE, one, a, b) - } - - pub fn add_two_extension( - &mut self, - a0: ExtensionTarget, - b0: ExtensionTarget, - a1: ExtensionTarget, - b1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) - } - - pub fn add_ext_algebra( - &mut self, - a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - let mut res = Vec::with_capacity(D); - let d_even = D & (D ^ 1); // = 2 * (D/2) - for mut chunk in &(0..d_even).chunks(2) { - let i = chunk.next().unwrap(); - let j = chunk.next().unwrap(); - let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); - res.extend([o0, o1]); - } - if D % 2 == 1 { - res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); - } - ExtensionAlgebraTarget(res.try_into().unwrap()) - } - - pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let zero = self.zero_extension(); - let mut terms = terms.to_vec(); - if terms.len() % 2 == 1 { - terms.push(zero); - } - let mut acc0 = zero; - let mut acc1 = zero; - for chunk in terms.chunks_exact(2) { - (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); - } - self.add_extension(acc0, acc1) - } - - pub fn sub_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - ) -> ExtensionTarget { - let one = self.one_extension(); - self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) - } - - pub fn sub_two_extension( - &mut self, - a0: ExtensionTarget, - b0: ExtensionTarget, - a1: ExtensionTarget, - b1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) - } - - pub fn sub_ext_algebra( - &mut self, - a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - let mut res = Vec::with_capacity(D); - let d_even = D & (D ^ 1); // = 2 * (D/2) - for mut chunk in &(0..d_even).chunks(2) { - let i = chunk.next().unwrap(); - let j = chunk.next().unwrap(); - let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); - res.extend([o0, o1]); - } - if D % 2 == 1 { - res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); - } - ExtensionAlgebraTarget(res.try_into().unwrap()) - } - - pub fn mul_extension_with_const( - &mut self, - const_0: F, - multiplicand_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - ) -> ExtensionTarget { - let zero = self.zero_extension(); - self.double_arithmetic_extension( - const_0, - F::ZERO, - multiplicand_0, - multiplicand_1, - zero, - zero, - zero, - ) - .0 - } - - pub fn mul_extension( - &mut self, - multiplicand_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - ) -> ExtensionTarget { - self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) - } - - pub fn mul_ext_algebra( - &mut self, - a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - let mut res = [self.zero_extension(); D]; - let w = self.constant(F::Extension::W); - for i in 0..D { - for j in 0..D { - res[(i + j) % D] = if i + j < D { - self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) - } else { - let ai_bi = self.mul_extension(a.0[i], b.0[j]); - self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) - } - } - } - ExtensionAlgebraTarget(res) - } - - pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut product = self.one_extension(); - for term in terms { - product = self.mul_extension(product, *term); - } - product - } - - /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no - /// performance benefit over separate muls and adds. - pub fn mul_add_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - self.arithmetic_extension(F::ONE, F::ONE, a, b, c) - } - - /// Like `mul_add`, but for `ExtensionTarget`s. - pub fn scalar_mul_add_extension( - &mut self, - a: Target, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - let a_ext = self.convert_to_ext(a); - self.arithmetic_extension(F::ONE, F::ONE, a_ext, b, c) - } - - /// Like `mul_sub`, but for `ExtensionTarget`s. - pub fn scalar_mul_sub_extension( - &mut self, - a: Target, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - let a_ext = self.convert_to_ext(a); - self.arithmetic_extension(F::ONE, F::NEG_ONE, a_ext, b, c) - } - - /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. - pub fn scalar_mul_ext(&mut self, a: Target, b: ExtensionTarget) -> ExtensionTarget { - let a_ext = self.convert_to_ext(a); - self.mul_extension(a_ext, b) - } - - /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the - /// extension field. - pub fn scalar_mul_ext_algebra( - &mut self, - a: ExtensionTarget, - mut b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - for i in 0..D { - b.0[i] = self.mul_extension(a, b.0[i]); - } - b - } - pub fn convert_to_ext(&mut self, t: Target) -> ExtensionTarget { let zero = self.zero(); let mut arr = [zero; D]; diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 99e260c7..b478f621 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -23,11 +23,6 @@ impl, const D: usize> CircuitBuilder { self.mul(x, x) } - /// Computes `x^2`. - pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { - self.mul_extension(x, x) - } - /// Computes `x^3`. pub fn cube(&mut self, x: Target) -> Target { self.mul_many(&[x, x, x]) @@ -137,6 +132,7 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, one, F::ONE, y) } + // TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`. pub fn add_many(&mut self, terms: &[Target]) -> Target { let mut sum = self.zero(); for term in terms { @@ -200,25 +196,6 @@ impl, const D: usize> CircuitBuilder { product } - /// Exponentiate `base` to the power of a known `exponent`. - // TODO: Test - pub fn exp_u64_extension( - &mut self, - base: ExtensionTarget, - exponent: u64, - ) -> ExtensionTarget { - let mut current = base; - let mut product = self.one_extension(); - - for j in 0..bits_u64(exponent as u64) { - if (exponent >> j & 1) != 0 { - product = self.mul_extension(product, current); - } - current = self.square_extension(current); - } - product - } - /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// some cases, as it allows `0 / 0 = `. pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { @@ -241,170 +218,4 @@ impl, const D: usize> CircuitBuilder { let y_ext = self.convert_to_ext(y); self.div_unsafe_extension(x_ext, y_ext).0[0] } - - /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in - /// some cases, as it allows `0 / 0 = `. - pub fn div_unsafe_extension( - &mut self, - x: ExtensionTarget, - y: ExtensionTarget, - ) -> ExtensionTarget { - // Add an `ArithmeticExtensionGate` to compute `q * y`. - let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); - - let multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_fixed_multiplicand(), - ); - let multiplicand_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); - let output = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); - - self.add_generator(QuotientGeneratorExtension { - numerator: x, - denominator: y, - quotient: multiplicand_0, - }); - self.add_generator(ZeroOutGenerator { - gate_index: gate, - ranges: vec![ - ArithmeticExtensionGate::::wires_addend_0(), - ArithmeticExtensionGate::::wires_multiplicand_1(), - ArithmeticExtensionGate::::wires_addend_1(), - ], - }); - - self.route_extension(y, multiplicand_1); - self.assert_equal_extension(output, x); - - multiplicand_0 - } -} - -struct QuotientGeneratorExtension { - numerator: ExtensionTarget, - denominator: ExtensionTarget, - quotient: ExtensionTarget, -} - -impl, const D: usize> SimpleGenerator for QuotientGeneratorExtension { - fn dependencies(&self) -> Vec { - let mut deps = self.numerator.to_target_array().to_vec(); - deps.extend(&self.denominator.to_target_array()); - deps - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let num = witness.get_extension_target(self.numerator); - let dem = witness.get_extension_target(self.denominator); - let quotient = num / dem; - let mut pw = PartialWitness::new(); - for i in 0..D { - pw.set_target( - self.quotient.to_target_array()[i], - quotient.to_basefield_array()[i], - ); - } - - pw - } -} - -/// Generator used to zero out wires at a given gate index and ranges. -pub struct ZeroOutGenerator { - gate_index: usize, - ranges: Vec>, -} - -impl SimpleGenerator for ZeroOutGenerator { - fn dependencies(&self) -> Vec { - Vec::new() - } - - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { - let mut pw = PartialWitness::new(); - for t in self - .ranges - .iter() - .flat_map(|r| Target::wires_from_range(self.gate_index, r.clone())) - { - pw.set_target(t, F::ZERO); - } - - pw - } -} - -/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. -#[derive(Clone)] -pub struct PowersTarget { - base: ExtensionTarget, - current: ExtensionTarget, -} - -impl PowersTarget { - pub fn next>( - &mut self, - builder: &mut CircuitBuilder, - ) -> ExtensionTarget { - let result = self.current; - self.current = builder.mul_extension(self.base, self.current); - result - } - - pub fn repeated_frobenius>( - self, - k: usize, - builder: &mut CircuitBuilder, - ) -> Self { - let Self { base, current } = self; - Self { - base: base.repeated_frobenius(k, builder), - current: current.repeated_frobenius(k, builder), - } - } -} - -impl, const D: usize> CircuitBuilder { - pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { - PowersTarget { - base, - current: self.one_extension(), - } - } -} - -#[cfg(test)] -mod tests { - use crate::circuit_builder::CircuitBuilder; - use crate::circuit_data::CircuitConfig; - use crate::field::crandall_field::CrandallField; - use crate::field::extension_field::quartic::QuarticCrandallField; - use crate::field::field::Field; - use crate::fri::FriConfig; - use crate::witness::PartialWitness; - - #[test] - fn test_div_extension() { - type F = CrandallField; - type FF = QuarticCrandallField; - const D: usize = 4; - - let config = CircuitConfig::large_config(); - - let mut builder = CircuitBuilder::::new(config); - - let x = FF::rand(); - let y = FF::rand(); - let z = x / y; - let xt = builder.constant_extension(x); - let yt = builder.constant_extension(y); - let zt = builder.constant_extension(z); - let comp_zt = builder.div_unsafe_extension(xt, yt); - builder.assert_equal_extension(zt, comp_zt); - - let data = builder.build(); - let proof = data.prove(PartialWitness::new()); - } } diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs new file mode 100644 index 00000000..d29e5079 --- /dev/null +++ b/src/gadgets/arithmetic_extension.rs @@ -0,0 +1,464 @@ +use std::convert::TryInto; +use std::ops::Range; + +use itertools::Itertools; + +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::algebra::ExtensionAlgebra; +use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; +use crate::field::extension_field::{Extendable, FieldExtension, OEF}; +use crate::field::field::Field; +use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::generator::SimpleGenerator; +use crate::target::Target; +use crate::util::bits_u64; +use crate::witness::PartialWitness; + +impl, const D: usize> CircuitBuilder { + pub fn double_arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + fixed_multiplicand: ExtensionTarget, + multiplicand_0: ExtensionTarget, + addend_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend_1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); + + let wire_fixed_multiplicand = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); + let wire_multiplicand_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let wire_addend_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); + let wire_multiplicand_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); + let wire_addend_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); + let wire_output_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + let wire_output_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); + + self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); + self.route_extension(multiplicand_0, wire_multiplicand_0); + self.route_extension(addend_0, wire_addend_0); + self.route_extension(multiplicand_1, wire_multiplicand_1); + self.route_extension(addend_1, wire_addend_1); + (wire_output_0, wire_output_1) + } + + pub fn arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + const_1, + multiplicand_0, + multiplicand_1, + addend, + zero, + zero, + ) + .0 + } + + pub fn add_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::ONE, one, a, b) + } + + pub fn add_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) + } + + pub fn add_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let mut res = Vec::with_capacity(D); + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); + } + if D % 2 == 1 { + res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) + } + + pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let zero = self.zero_extension(); + let mut terms = terms.to_vec(); + if terms.len() % 2 == 1 { + terms.push(zero); + } + let mut acc0 = zero; + let mut acc1 = zero; + for chunk in terms.chunks_exact(2) { + (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); + } + self.add_extension(acc0, acc1) + } + + pub fn sub_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) + } + + pub fn sub_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) + } + + pub fn sub_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let mut res = Vec::with_capacity(D); + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); + } + if D % 2 == 1 { + res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) + } + + pub fn mul_extension_with_const( + &mut self, + const_0: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + F::ZERO, + multiplicand_0, + multiplicand_1, + zero, + zero, + zero, + ) + .0 + } + + pub fn mul_extension( + &mut self, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + ) -> ExtensionTarget { + self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) + } + + /// Computes `x^2`. + pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { + self.mul_extension(x, x) + } + + pub fn mul_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let mut res = [self.zero_extension(); D]; + let w = self.constant(F::Extension::W); + for i in 0..D { + for j in 0..D { + res[(i + j) % D] = if i + j < D { + self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) + } else { + let ai_bi = self.mul_extension(a.0[i], b.0[j]); + self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) + } + } + } + ExtensionAlgebraTarget(res) + } + + pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let mut product = self.one_extension(); + for term in terms { + product = self.mul_extension(product, *term); + } + product + } + + /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no + /// performance benefit over separate muls and adds. + pub fn mul_add_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + self.arithmetic_extension(F::ONE, F::ONE, a, b, c) + } + + /// Like `mul_add`, but for `ExtensionTarget`s. + pub fn scalar_mul_add_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::ONE, a_ext, b, c) + } + + /// Like `mul_sub`, but for `ExtensionTarget`s. + pub fn scalar_mul_sub_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::NEG_ONE, a_ext, b, c) + } + + /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. + pub fn scalar_mul_ext(&mut self, a: Target, b: ExtensionTarget) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.mul_extension(a_ext, b) + } + + /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the + /// extension field. + pub fn scalar_mul_ext_algebra( + &mut self, + a: ExtensionTarget, + mut b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + for i in 0..D { + b.0[i] = self.mul_extension(a, b.0[i]); + } + b + } + + /// Exponentiate `base` to the power of a known `exponent`. + // TODO: Test + pub fn exp_u64_extension( + &mut self, + base: ExtensionTarget, + exponent: u64, + ) -> ExtensionTarget { + let mut current = base; + let mut product = self.one_extension(); + + for j in 0..bits_u64(exponent as u64) { + if (exponent >> j & 1) != 0 { + product = self.mul_extension(product, current); + } + current = self.square_extension(current); + } + product + } + + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in + /// some cases, as it allows `0 / 0 = `. + pub fn div_unsafe_extension( + &mut self, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + // Add an `ArithmeticExtensionGate` to compute `q * y`. + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); + + let multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); + let multiplicand_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + + self.add_generator(QuotientGeneratorExtension { + numerator: x, + denominator: y, + quotient: multiplicand_0, + }); + self.add_generator(ZeroOutGenerator { + gate_index: gate, + ranges: vec![ + ArithmeticExtensionGate::::wires_addend_0(), + ArithmeticExtensionGate::::wires_multiplicand_1(), + ArithmeticExtensionGate::::wires_addend_1(), + ], + }); + + self.route_extension(y, multiplicand_1); + self.assert_equal_extension(output, x); + + multiplicand_0 + } +} + +/// Generator used to zero out wires at a given gate index and ranges. +pub struct ZeroOutGenerator { + gate_index: usize, + ranges: Vec>, +} + +impl SimpleGenerator for ZeroOutGenerator { + fn dependencies(&self) -> Vec { + Vec::new() + } + + fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + let mut pw = PartialWitness::new(); + for t in self + .ranges + .iter() + .flat_map(|r| Target::wires_from_range(self.gate_index, r.clone())) + { + pw.set_target(t, F::ZERO); + } + + pw + } +} + +struct QuotientGeneratorExtension { + numerator: ExtensionTarget, + denominator: ExtensionTarget, + quotient: ExtensionTarget, +} + +impl, const D: usize> SimpleGenerator for QuotientGeneratorExtension { + fn dependencies(&self) -> Vec { + let mut deps = self.numerator.to_target_array().to_vec(); + deps.extend(&self.denominator.to_target_array()); + deps + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let num = witness.get_extension_target(self.numerator); + let dem = witness.get_extension_target(self.denominator); + let quotient = num / dem; + let mut pw = PartialWitness::new(); + for i in 0..D { + pw.set_target( + self.quotient.to_target_array()[i], + quotient.to_basefield_array()[i], + ); + } + + pw + } +} + +/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. +#[derive(Clone)] +pub struct PowersTarget { + base: ExtensionTarget, + current: ExtensionTarget, +} + +impl PowersTarget { + pub fn next>( + &mut self, + builder: &mut CircuitBuilder, + ) -> ExtensionTarget { + let result = self.current; + self.current = builder.mul_extension(self.base, self.current); + result + } + + pub fn repeated_frobenius>( + self, + k: usize, + builder: &mut CircuitBuilder, + ) -> Self { + let Self { base, current } = self; + Self { + base: base.repeated_frobenius(k, builder), + current: current.repeated_frobenius(k, builder), + } + } +} + +impl, const D: usize> CircuitBuilder { + pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { + PowersTarget { + base, + current: self.one_extension(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::circuit_builder::CircuitBuilder; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::field::Field; + use crate::fri::FriConfig; + use crate::witness::PartialWitness; + + #[test] + fn test_div_extension() { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let y = FF::rand(); + let z = x / y; + let xt = builder.constant_extension(x); + let yt = builder.constant_extension(y); + let zt = builder.constant_extension(z); + let comp_zt = builder.div_unsafe_extension(xt, yt); + builder.assert_equal_extension(zt, comp_zt); + + let data = builder.build(); + let proof = data.prove(PartialWitness::new()); + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index a1e041fc..2f216870 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,4 +1,5 @@ pub mod arithmetic; +pub mod arithmetic_extension; pub mod hash; pub mod insert; pub mod interpolation; From 42db0a31c12bcf93aa8008957e60ce5b0bf64d38 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 16:49:29 +0200 Subject: [PATCH 10/22] Clippy --- src/field/extension_field/target.rs | 3 --- src/gadgets/arithmetic.rs | 10 +--------- src/gadgets/arithmetic_extension.rs | 1 - src/gates/arithmetic.rs | 3 +-- src/generator.rs | 1 - 5 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index e7dead21..9d60847e 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -1,13 +1,10 @@ use std::convert::{TryFrom, TryInto}; use std::ops::Range; -use itertools::Itertools; - use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; -use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::target::Target; /// `Target`s representing an element of an extension field. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index b478f621..73455797 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,15 +1,7 @@ -use std::ops::Range; - use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field::Field; -use crate::gates::arithmetic::ArithmeticExtensionGate; -use crate::generator::SimpleGenerator; +use crate::field::extension_field::Extendable; use crate::target::Target; use crate::util::bits_u64; -use crate::wire::Wire; -use crate::witness::PartialWitness; impl, const D: usize> CircuitBuilder { /// Computes `-x`. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index d29e5079..af97a866 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -4,7 +4,6 @@ use std::ops::Range; use itertools::Itertools; use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 586a0cc2..6751889e 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -2,12 +2,11 @@ use std::ops::Range; use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::extension_field::Extendable; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; -use crate::wire::Wire; use crate::witness::PartialWitness; /// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. diff --git a/src/generator.rs b/src/generator.rs index db81172f..d760df05 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -2,7 +2,6 @@ use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use crate::field::field::Field; -use crate::permutation_argument::TargetPartitions; use crate::target::Target; use crate::witness::PartialWitness; From 2f06a78cb14c6767a134de3c1a6374d7f3f5346c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 16:53:11 +0200 Subject: [PATCH 11/22] Simplify exp_u64 --- src/gadgets/arithmetic.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 73455797..bda70624 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -176,16 +176,8 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of a known `exponent`. // TODO: Test pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target { - let mut current = base; - let mut product = self.one(); - - for j in 0..bits_u64(exponent as u64) { - if (exponent >> j & 1) != 0 { - product = self.mul(product, current); - } - current = self.square(current); - } - product + let base_ext = self.convert_to_ext(base); + self.exp_u64_extension(base_ext, exponent).0[0] } /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in From 636d8bef0776d493e51a5e8c9795c531f7138c70 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Jun 2021 17:24:22 +0200 Subject: [PATCH 12/22] Comments --- src/gadgets/arithmetic_extension.rs | 7 +++ src/gates/arithmetic.rs | 6 +-- src/util/scaling.rs | 66 +++++++++++++++++------------ 3 files changed, 50 insertions(+), 29 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index af97a866..7dcd88cd 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -97,7 +97,10 @@ impl, const D: usize> CircuitBuilder { a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { + // We run two additions in parallel. So `[a0,a1,a2,a3] + [b0,b1,b2,b3]` is computed with two + // `add_two_extension`, first `[a0,a1]+[b0,b1]` then `[a2,a3]+[b2,b3]`. let mut res = Vec::with_capacity(D); + // We need some extra logic if D is odd. let d_even = D & (D ^ 1); // = 2 * (D/2) for mut chunk in &(0..d_even).chunks(2) { let i = chunk.next().unwrap(); @@ -117,11 +120,13 @@ impl, const D: usize> CircuitBuilder { if terms.len() % 2 == 1 { terms.push(zero); } + // We maintain two accumulators, one for the sum of even elements, and one for odd elements. let mut acc0 = zero; let mut acc1 = zero; for chunk in terms.chunks_exact(2) { (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); } + // We sum both accumulators to get the final result. self.add_extension(acc0, acc1) } @@ -150,6 +155,7 @@ impl, const D: usize> CircuitBuilder { a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { + // See `add_ext_algebra`. let mut res = Vec::with_capacity(D); let d_even = D & (D ^ 1); // = 2 * (D/2) for mut chunk in &(0..d_even).chunks(2) { @@ -319,6 +325,7 @@ impl, const D: usize> CircuitBuilder { denominator: y, quotient: multiplicand_0, }); + // We need to zero out the other wires for the `ArithmeticExtensionGenerator` to hit. self.add_generator(ZeroOutGenerator { gate_index: gate, ranges: vec![ diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 6751889e..31ae5caa 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -106,7 +106,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate gate_index: usize, local_constants: &[F], ) -> Vec>> { - let gen = MulExtensionGenerator { + let gen = ArithmeticExtensionGenerator { gate_index, const_0: local_constants[0], const_1: local_constants[1], @@ -131,13 +131,13 @@ impl, const D: usize> Gate for ArithmeticExtensionGate } } -struct MulExtensionGenerator, const D: usize> { +struct ArithmeticExtensionGenerator, const D: usize> { gate_index: usize, const_0: F, const_1: F, } -impl, const D: usize> SimpleGenerator for MulExtensionGenerator { +impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator { fn dependencies(&self) -> Vec { ArithmeticExtensionGate::::wires_fixed_multiplicand() .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) diff --git a/src/util/scaling.rs b/src/util/scaling.rs index 0bc61840..87158649 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -91,6 +91,11 @@ impl ReducingFactorTarget { Self { base, count: 0 } } + /// Reduces a length `n` vector of `ExtensionTarget`s using `n/2` `ArithmeticExtensionGate`s. + /// It does this by running two accumulators in parallel. Here's an example with `n=4, alpha=2, D=1`: + /// 1st gate: 2 0 4 11 2 4 24 <- 2*0+4= 4, 2*11+2=24 + /// 2nd gate: 2 4 3 24 1 11 49 <- 2*4+3=11, 2*24+1=49 + /// which verifies that `2.reduce([1,2,3,4]) = 49`. pub fn reduce( &mut self, iter: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. @@ -99,18 +104,22 @@ impl ReducingFactorTarget { where F: Extendable, { + let zero = builder.zero_extension(); let l = iter.len(); self.count += l as u64; + // If needed we pad the original vector so that it has even length. let padded_iter = if l % 2 == 0 { iter.to_vec() } else { - [iter, &[builder.zero_extension()]].concat() + [iter, &[zero]].concat() }; let half_length = padded_iter.len() / 2; + // Add `n/2` `ArithmeticExtensionGate`s that will perform the accumulation. let gates = (0..half_length) .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE])) .collect::>(); + // Add a generator that will fill the accumulation wires. builder.add_generator(ParallelReductionGenerator { base: self.base, padded_iter: padded_iter.clone(), @@ -119,24 +128,33 @@ impl ReducingFactorTarget { }); for i in 0..half_length { + // The fixed multiplicand is always `base`. builder.route_extension( + self.base, + ExtensionTarget::from_range( + gates[i], + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ), + ); + // Set the addends for the first half of the accumulation. + builder.route_extension( + padded_iter[2 * half_length - i - 1], ExtensionTarget::from_range( gates[i], ArithmeticExtensionGate::::wires_addend_0(), ), - padded_iter[2 * half_length - i - 1], ); - } - for i in 0..half_length { + // Set the addends for the second half of the accumulation. builder.route_extension( + padded_iter[half_length - i - 1], ExtensionTarget::from_range( gates[i], ArithmeticExtensionGate::::wires_addend_1(), ), - padded_iter[half_length - i - 1], ); } for gate_pair in gates[..half_length].windows(2) { + // Verifies that the accumulator is passed between gates for the first half of the accumulation. builder.assert_equal_extension( ExtensionTarget::from_range( gate_pair[0], @@ -149,6 +167,7 @@ impl ReducingFactorTarget { ); } for gate_pair in gates[half_length..].windows(2) { + // Verifies that the accumulator is passed between gates for the second half of the accumulation. builder.assert_equal_extension( ExtensionTarget::from_range( gate_pair[0], @@ -160,6 +179,16 @@ impl ReducingFactorTarget { ), ); } + // Verifies that the starting accumulator for the first half is zero. + builder.assert_equal_extension( + ExtensionTarget::from_range( + gates[0], + ArithmeticExtensionGate::::wires_multiplicand_0(), + ), + zero, + ); + // Verifies that the final accumulator for the first half is passed as a starting + // accumulator for the second half. builder.assert_equal_extension( ExtensionTarget::from_range( gates[half_length - 1], @@ -171,6 +200,7 @@ impl ReducingFactorTarget { ), ); + // Return the final accumulator for the second half. ExtensionTarget::from_range( gates[half_length - 1], ArithmeticExtensionGate::::wires_output_1(), @@ -206,6 +236,7 @@ impl ReducingFactorTarget { } } +/// Fills the intermediate accumulator in `ReducingFactorTarget::reduce`. struct ParallelReductionGenerator { base: ExtensionTarget, padded_iter: Vec>, @@ -215,6 +246,7 @@ struct ParallelReductionGenerator { impl, const D: usize> SimpleGenerator for ParallelReductionGenerator { fn dependencies(&self) -> Vec { + // Need only the values and the base. self.padded_iter .iter() .flat_map(|ext| ext.to_target_array()) @@ -230,6 +262,7 @@ impl, const D: usize> SimpleGenerator for ParallelReductionG .iter() .map(|&ext| witness.get_extension_target(ext)) .collect::>(); + // Computed the intermediate accumulators. let intermediate_accs = vs .iter() .rev() @@ -240,13 +273,7 @@ impl, const D: usize> SimpleGenerator for ParallelReductionG }) .collect::>(); for i in 0..self.half_length { - pw.set_extension_target( - ExtensionTarget::from_range( - self.gates[i], - ArithmeticExtensionGate::::wires_fixed_multiplicand(), - ), - base, - ); + // Fill the accumulators for the first half. pw.set_extension_target( ExtensionTarget::from_range( self.gates[i], @@ -254,13 +281,7 @@ impl, const D: usize> SimpleGenerator for ParallelReductionG ), intermediate_accs[i], ); - pw.set_extension_target( - ExtensionTarget::from_range( - self.gates[i], - ArithmeticExtensionGate::::wires_addend_0(), - ), - vs[2 * self.half_length - i - 1], - ); + // Fill the accumulators for the second half. pw.set_extension_target( ExtensionTarget::from_range( self.gates[i], @@ -268,13 +289,6 @@ impl, const D: usize> SimpleGenerator for ParallelReductionG ), intermediate_accs[self.half_length + i], ); - pw.set_extension_target( - ExtensionTarget::from_range( - self.gates[i], - ArithmeticExtensionGate::::wires_addend_1(), - ), - vs[self.half_length - i - 1], - ); } pw From 12e81acccfb30d080e32c612806f3f6167f5aa47 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 28 Jun 2021 11:27:43 +0200 Subject: [PATCH 13/22] Optimize the degree of the tree returned by `Tree::from_gates` to allow non-power of 2 degree. --- src/circuit_builder.rs | 15 ++++++------ src/gates/gate_tree.rs | 53 +++++++++++++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 1fe5a5e1..27e58336 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -230,12 +230,11 @@ impl, const D: usize> CircuitBuilder { } } - fn constant_polys(&self, gates: &[PrefixedGate]) -> Vec> { - let num_constants = gates - .iter() - .map(|gate| gate.gate.0.num_constants() + gate.prefix.len()) - .max() - .unwrap(); + fn constant_polys( + &self, + gates: &[PrefixedGate], + num_constants: usize, + ) -> Vec> { let constants_per_gate = self .gate_instances .iter() @@ -294,10 +293,10 @@ impl, const D: usize> CircuitBuilder { info!("degree after blinding & padding: {}", degree); let gates = self.gates.iter().cloned().collect(); - let gate_tree = Tree::from_gates(gates); + let (gate_tree, max_filtered_constraint_degree, num_constants) = Tree::from_gates(gates); let prefixed_gates = PrefixedGate::from_tree(gate_tree); - let constant_vecs = self.constant_polys(&prefixed_gates); + let constant_vecs = self.constant_polys(&prefixed_gates, num_constants); let constants_commitment = ListPolynomialCommitment::new( constant_vecs.into_iter().map(|v| v.ifft()).collect(), self.config.fri_config.rate_bits, diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 1423de76..8610f648 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -57,8 +57,9 @@ impl, const D: usize> Tree> { /// For this construction, we use the greedy algorithm in `Self::find_tree`. /// This latter function greedily adds gates at the depth where /// `filtered_deg(gate)=D, constant_wires(gate)=C` to ensure no space is wasted. - /// We return the first tree found in this manner. - pub fn from_gates(mut gates: Vec>) -> Self { + /// We return the first tree found in this manner, along with it's maximum filtered degree + /// and the number of constant wires needed when using this tree. + pub fn from_gates(mut gates: Vec>) -> (Self, usize, usize) { let timer = std::time::Instant::now(); gates.sort_unstable_by_key(|g| (-(g.0.degree() as isize), -(g.0.num_constants() as isize))); @@ -67,14 +68,32 @@ impl, const D: usize> Tree> { // So we can restrict our search space by setting `max_degree` to a power of 2. let max_degree = 1 << max_degree_bits; for max_constants in 1..100 { - if let Some(mut tree) = Self::find_tree(&gates, max_degree, max_constants) { - tree.shorten(); + if let Some(mut best_tree) = Self::find_tree(&gates, max_degree, max_constants) { + best_tree.shorten(); + let mut best_num_constants = best_tree.num_constants(); + let mut best_degree = max_degree; + // Iterate backwards from `max_degree` to try to find a tree with a lower degree + // but the same number of constants. + 'optdegree: for degree in (0..max_degree).rev() { + if let Some(mut tree) = Self::find_tree(&gates, degree, max_constants) { + tree.shorten(); + let num_constants = tree.num_constants(); + if num_constants > best_num_constants { + break 'optdegree; + } else { + best_degree = degree; + best_num_constants = num_constants; + best_tree = tree; + } + } + } info!( - "Found tree with max degree {} in {}s.", - max_degree, + "Found tree with max degree {} and {} constants in {}s.", + best_degree, + best_num_constants, timer.elapsed().as_secs_f32() ); - return tree; + return (best_tree, best_degree, best_num_constants); } } } @@ -180,6 +199,24 @@ impl, const D: usize> Tree> { } } } + + /// Returns the tree's maximum filtered constraint degree. + fn max_filtered_degree(&self) -> usize { + self.traversal() + .into_iter() + .map(|(g, p)| g.0.degree() + p.len()) + .max() + .expect("Empty tree.") + } + + /// Returns the number of constant wires needed to fit all prefixes and gate constants. + fn num_constants(&self) -> usize { + self.traversal() + .into_iter() + .map(|(g, p)| g.0.num_constants() + p.len()) + .max() + .expect("Empty tree.") + } } #[cfg(test)] @@ -212,7 +249,7 @@ mod tests { ]; let len = gates.len(); - let tree = Tree::from_gates(gates.clone()); + let (tree, _, _) = Tree::from_gates(gates.clone()); let mut gates_with_prefix = tree.traversal(); for (g, p) in &gates_with_prefix { info!( From 9a352193ed2e679ff51633829f6e908a833aa78f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 29 Jun 2021 09:49:05 +0200 Subject: [PATCH 14/22] PR feedback --- src/gadgets/arithmetic_extension.rs | 10 ++-------- src/gates/arithmetic.rs | 4 ++-- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 7dcd88cd..77e0a2c0 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -230,8 +230,7 @@ impl, const D: usize> CircuitBuilder { product } - /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no - /// performance benefit over separate muls and adds. + /// Like `mul_add`, but for `ExtensionTarget`s. pub fn mul_add_extension( &mut self, a: ExtensionTarget, @@ -385,12 +384,7 @@ impl, const D: usize> SimpleGenerator for QuotientGeneratorE let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; let mut pw = PartialWitness::new(); - for i in 0..D { - pw.set_target( - self.quotient.to_target_array()[i], - quotient.to_basefield_array()[i], - ); - } + pw.set_extension_target(self.quotient, quotient); pw } diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 31ae5caa..5be18085 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -61,7 +61,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let computed_output_0 = fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); let computed_output_1 = - fixed_multiplicand * multiplicand_1 * const_1.into() + addend_1 * const_1.into(); + fixed_multiplicand * multiplicand_1 * const_0.into() + addend_1 * const_1.into(); let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); constraints.extend((output_1 - computed_output_1).to_basefield_array()); @@ -115,7 +115,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate } fn num_wires(&self) -> usize { - 28 + 7 * D } fn num_constants(&self) -> usize { From bae3777bcdf9f34f95f43417daf9f3ffb4fd8bc6 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 29 Jun 2021 14:00:34 +0200 Subject: [PATCH 15/22] Use max filtered degree found with the tree method in `CircuitBuilder::build` --- src/circuit_builder.rs | 9 +-------- src/circuit_data.rs | 4 ++-- src/gates/gate_tree.rs | 2 +- src/polynomial/commitment.rs | 2 +- src/prover.rs | 21 +++++++++------------ 5 files changed, 14 insertions(+), 24 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 7ecf625f..9a5626b9 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -300,14 +300,7 @@ impl, const D: usize> CircuitBuilder { let degree_bits = log2_strict(degree); let subgroup = F::two_adic_subgroup(degree_bits); - let constant_vecs = self.constant_polys(&prefixed_gates); - let num_constants = constant_vecs.len(); let constant_vecs = self.constant_polys(&prefixed_gates, num_constants); - let constants_commitment = ListPolynomialCommitment::new( - constant_vecs.into_iter().map(|v| v.ifft()).collect(), - self.config.fri_config.rate_bits, - false, - ); let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires); let sigma_vecs = self.sigma_vecs(&k_is, &subgroup); @@ -355,7 +348,7 @@ impl, const D: usize> CircuitBuilder { config: self.config, degree_bits, gates: prefixed_gates, - max_filtered_constraint_degree_bits: 3, // TODO: compute this correctly once filters land. + max_filtered_constraint_degree, num_gate_constraints, num_constants, k_is, diff --git a/src/circuit_data.rs b/src/circuit_data.rs index afb37628..fa575dcc 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -146,7 +146,7 @@ pub struct CommonCircuitData, const D: usize> { pub(crate) gates: Vec>, /// The maximum degree of a filter times a constraint by any gate. - pub(crate) max_filtered_constraint_degree_bits: usize, + pub(crate) max_filtered_constraint_degree: usize, /// The largest number of constraints imposed by any gate. pub(crate) num_gate_constraints: usize, @@ -184,7 +184,7 @@ impl, const D: usize> CommonCircuitData { } pub fn quotient_degree(&self) -> usize { - ((1 << self.max_filtered_constraint_degree_bits) - 1) * self.degree() + (self.max_filtered_constraint_degree - 1) * self.degree() } pub fn total_constraints(&self) -> usize { diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 8610f648..8a858fd2 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -88,7 +88,7 @@ impl, const D: usize> Tree> { } } info!( - "Found tree with max degree {} and {} constants in {}s.", + "Found tree with max degree {} and {} constants wires in {}s.", best_degree, best_num_constants, timer.elapsed().as_secs_f32() diff --git a/src/polynomial/commitment.rs b/src/polynomial/commitment.rs index da647ffa..9bd3905b 100644 --- a/src/polynomial/commitment.rs +++ b/src/polynomial/commitment.rs @@ -327,7 +327,7 @@ mod tests { }, degree_bits: 0, gates: vec![], - max_filtered_constraint_degree_bits: 0, + max_filtered_constraint_degree: 0, num_gate_constraints: 0, num_constants: 4, k_is: vec![F::ONE; 6], diff --git a/src/prover.rs b/src/prover.rs index e8cd7f46..67b5a185 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -12,7 +12,7 @@ use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::proof::Proof; use crate::timed; -use crate::util::transpose; +use crate::util::{log2_ceil, transpose}; use crate::vars::EvaluationVarsBase; use crate::witness::{PartialWitness, Witness}; @@ -219,23 +219,22 @@ fn compute_quotient_polys<'a, F: Extendable, const D: usize>( alphas: &[F], ) -> Vec> { let num_challenges = common_data.config.num_challenges; + let max_filtered_constraint_degree_bits = log2_ceil(common_data.max_filtered_constraint_degree); assert!( - common_data.max_filtered_constraint_degree_bits <= common_data.config.rate_bits, + max_filtered_constraint_degree_bits <= common_data.config.rate_bits, "Having constraints of degree higher than the rate is not supported yet. \ If we need this in the future, we can precompute the larger LDE before computing the `ListPolynomialCommitment`s." ); // We reuse the LDE computed in `ListPolynomialCommitment` and extract every `step` points to get // an LDE matching `max_filtered_constraint_degree`. - let step = - 1 << (common_data.config.rate_bits - common_data.max_filtered_constraint_degree_bits); + let step = 1 << (common_data.config.rate_bits - max_filtered_constraint_degree_bits); // When opening the `Z`s polys at the "next" point in Plonk, need to look at the point `next_step` // steps away since we work on an LDE of degree `max_filtered_constraint_degree`. - let next_step = 1 << common_data.max_filtered_constraint_degree_bits; + let next_step = 1 << max_filtered_constraint_degree_bits; - let points = F::two_adic_subgroup( - common_data.degree_bits + common_data.max_filtered_constraint_degree_bits, - ); + let points = + F::two_adic_subgroup(common_data.degree_bits + max_filtered_constraint_degree_bits); let lde_size = points.len(); // Retrieve the LDE values at index `i`. @@ -243,10 +242,8 @@ fn compute_quotient_polys<'a, F: Extendable, const D: usize>( comm.get_lde_values(i * step) }; - let z_h_on_coset = ZeroPolyOnCoset::new( - common_data.degree_bits, - common_data.max_filtered_constraint_degree_bits, - ); + let z_h_on_coset = + ZeroPolyOnCoset::new(common_data.degree_bits, max_filtered_constraint_degree_bits); let quotient_values: Vec> = points .into_par_iter() From f1e3474fcb02f95f0788fdc349af040e9b4f311a Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 29 Jun 2021 12:33:11 -0700 Subject: [PATCH 16/22] Simple reduce (#78) * Simple reduce * Fix bug causing test failure --- src/circuit_builder.rs | 4 + src/gadgets/arithmetic_extension.rs | 7 +- src/gates/arithmetic.rs | 60 ++++++--- src/util/scaling.rs | 191 +++++----------------------- src/witness.rs | 12 ++ 5 files changed, 97 insertions(+), 177 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 1c42c4c1..6f110c18 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -62,6 +62,10 @@ impl, const D: usize> CircuitBuilder { } } + pub fn num_gates(&self) -> usize { + self.gate_instances.len() + } + pub fn add_public_input(&mut self) -> Target { let index = self.public_input_index; self.public_input_index += 1; diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 77e0a2c0..ce53c59f 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -2,6 +2,7 @@ use std::convert::TryInto; use std::ops::Range; use itertools::Itertools; +use num::Integer; use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; @@ -108,7 +109,7 @@ impl, const D: usize> CircuitBuilder { let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); res.extend([o0, o1]); } - if D % 2 == 1 { + if D.is_odd() { res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); } ExtensionAlgebraTarget(res.try_into().unwrap()) @@ -117,7 +118,7 @@ impl, const D: usize> CircuitBuilder { pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let zero = self.zero_extension(); let mut terms = terms.to_vec(); - if terms.len() % 2 == 1 { + if terms.len().is_odd() { terms.push(zero); } // We maintain two accumulators, one for the sum of even elements, and one for odd elements. @@ -164,7 +165,7 @@ impl, const D: usize> CircuitBuilder { let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); res.extend([o0, o1]); } - if D % 2 == 1 { + if D.is_odd() { res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); } ExtensionAlgebraTarget(res.try_into().unwrap()) diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 5be18085..39baa226 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -106,12 +106,17 @@ impl, const D: usize> Gate for ArithmeticExtensionGate gate_index: usize, local_constants: &[F], ) -> Vec>> { - let gen = ArithmeticExtensionGenerator { + let gen0 = ArithmeticExtensionGenerator0 { gate_index, const_0: local_constants[0], const_1: local_constants[1], }; - vec![Box::new(gen)] + let gen1 = ArithmeticExtensionGenerator1 { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + }; + vec![Box::new(gen0), Box::new(gen1)] } fn num_wires(&self) -> usize { @@ -131,19 +136,23 @@ impl, const D: usize> Gate for ArithmeticExtensionGate } } -struct ArithmeticExtensionGenerator, const D: usize> { +struct ArithmeticExtensionGenerator0, const D: usize> { gate_index: usize, const_0: F, const_1: F, } -impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator { +struct ArithmeticExtensionGenerator1, const D: usize> { + gate_index: usize, + const_0: F, + const_1: F, +} + +impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator0 { fn dependencies(&self) -> Vec { ArithmeticExtensionGate::::wires_fixed_multiplicand() .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) .chain(ArithmeticExtensionGate::::wires_addend_0()) - .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) - .chain(ArithmeticExtensionGate::::wires_addend_1()) .map(|i| Target::wire(self.gate_index, i)) .collect() } @@ -159,28 +168,49 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio let multiplicand_0 = extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); - let multiplicand_1 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); - let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); let output_target_0 = ExtensionTarget::from_range( self.gate_index, ArithmeticExtensionGate::::wires_output_0(), ); + + let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() + + addend_0 * self.const_1.into(); + + PartialWitness::singleton_extension_target(output_target_0, computed_output_0) + } +} + +impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator1 { + fn dependencies(&self) -> Vec { + ArithmeticExtensionGate::::wires_fixed_multiplicand() + .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_addend_1()) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let fixed_multiplicand = + extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_1 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); + let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); + let output_target_1 = ExtensionTarget::from_range( self.gate_index, ArithmeticExtensionGate::::wires_output_1(), ); - let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() - + addend_0 * self.const_1.into(); let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() + addend_1 * self.const_1.into(); - let mut pw = PartialWitness::new(); - pw.set_extension_target(output_target_0, computed_output_0); - pw.set_extension_target(output_target_1, computed_output_1); - pw + PartialWitness::singleton_extension_target(output_target_1, computed_output_1) } } diff --git a/src/util/scaling.rs b/src/util/scaling.rs index 87158649..4941d03f 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -1,14 +1,13 @@ use std::borrow::Borrow; +use num::Integer; + use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, Frobenius}; use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticExtensionGate; -use crate::generator::SimpleGenerator; use crate::polynomial::polynomial::PolynomialCoeffs; -use crate::target::Target; -use crate::witness::PartialWitness; /// When verifying the composition polynomial in FRI we have to compute sums of the form /// `(sum_0^k a^i * x_i)/d_0 + (sum_k^r a^i * y_i)/d_1` @@ -98,113 +97,45 @@ impl ReducingFactorTarget { /// which verifies that `2.reduce([1,2,3,4]) = 49`. pub fn reduce( &mut self, - iter: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. + terms: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. builder: &mut CircuitBuilder, ) -> ExtensionTarget where F: Extendable, { let zero = builder.zero_extension(); - let l = iter.len(); + let l = terms.len(); self.count += l as u64; - // If needed we pad the original vector so that it has even length. - let padded_iter = if l % 2 == 0 { - iter.to_vec() - } else { - [iter, &[zero]].concat() - }; - let half_length = padded_iter.len() / 2; - // Add `n/2` `ArithmeticExtensionGate`s that will perform the accumulation. - let gates = (0..half_length) - .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE])) - .collect::>(); - // Add a generator that will fill the accumulation wires. - builder.add_generator(ParallelReductionGenerator { - base: self.base, - padded_iter: padded_iter.clone(), - gates: gates.clone(), - half_length, - }); + let mut terms_vec = terms.to_vec(); + // If needed, we pad the original vector so that it has even length. + if terms_vec.len().is_odd() { + terms_vec.push(zero); + } + terms_vec.reverse(); - for i in 0..half_length { - // The fixed multiplicand is always `base`. - builder.route_extension( - self.base, - ExtensionTarget::from_range( - gates[i], - ArithmeticExtensionGate::::wires_fixed_multiplicand(), - ), - ); - // Set the addends for the first half of the accumulation. - builder.route_extension( - padded_iter[2 * half_length - i - 1], - ExtensionTarget::from_range( - gates[i], - ArithmeticExtensionGate::::wires_addend_0(), - ), - ); - // Set the addends for the second half of the accumulation. - builder.route_extension( - padded_iter[half_length - i - 1], - ExtensionTarget::from_range( - gates[i], - ArithmeticExtensionGate::::wires_addend_1(), - ), - ); + let mut acc = zero; + for pair in terms_vec.chunks(2) { + // We will route the output of the first arithmetic operation to the multiplicand of the + // second, i.e. we compute the following: + // out_0 = alpha acc + pair[0] + // acc' = out_1 = alpha out_0 + pair[1] + let gate = builder.num_gates(); + let out_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + acc = builder + .double_arithmetic_extension( + F::ONE, + F::ONE, + self.base, + acc, + pair[0], + out_0, + pair[1], + ) + .1; } - for gate_pair in gates[..half_length].windows(2) { - // Verifies that the accumulator is passed between gates for the first half of the accumulation. - builder.assert_equal_extension( - ExtensionTarget::from_range( - gate_pair[0], - ArithmeticExtensionGate::::wires_output_0(), - ), - ExtensionTarget::from_range( - gate_pair[1], - ArithmeticExtensionGate::::wires_multiplicand_0(), - ), - ); - } - for gate_pair in gates[half_length..].windows(2) { - // Verifies that the accumulator is passed between gates for the second half of the accumulation. - builder.assert_equal_extension( - ExtensionTarget::from_range( - gate_pair[0], - ArithmeticExtensionGate::::wires_output_1(), - ), - ExtensionTarget::from_range( - gate_pair[1], - ArithmeticExtensionGate::::wires_multiplicand_1(), - ), - ); - } - // Verifies that the starting accumulator for the first half is zero. - builder.assert_equal_extension( - ExtensionTarget::from_range( - gates[0], - ArithmeticExtensionGate::::wires_multiplicand_0(), - ), - zero, - ); - // Verifies that the final accumulator for the first half is passed as a starting - // accumulator for the second half. - builder.assert_equal_extension( - ExtensionTarget::from_range( - gates[half_length - 1], - ArithmeticExtensionGate::::wires_output_0(), - ), - ExtensionTarget::from_range( - gates[0], - ArithmeticExtensionGate::::wires_multiplicand_1(), - ), - ); - - // Return the final accumulator for the second half. - ExtensionTarget::from_range( - gates[half_length - 1], - ArithmeticExtensionGate::::wires_output_1(), - ) + acc } pub fn shift( @@ -236,71 +167,13 @@ impl ReducingFactorTarget { } } -/// Fills the intermediate accumulator in `ReducingFactorTarget::reduce`. -struct ParallelReductionGenerator { - base: ExtensionTarget, - padded_iter: Vec>, - gates: Vec, - half_length: usize, -} - -impl, const D: usize> SimpleGenerator for ParallelReductionGenerator { - fn dependencies(&self) -> Vec { - // Need only the values and the base. - self.padded_iter - .iter() - .flat_map(|ext| ext.to_target_array()) - .chain(self.base.to_target_array()) - .collect() - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let mut pw = PartialWitness::new(); - let base = witness.get_extension_target(self.base); - let vs = self - .padded_iter - .iter() - .map(|&ext| witness.get_extension_target(ext)) - .collect::>(); - // Computed the intermediate accumulators. - let intermediate_accs = vs - .iter() - .rev() - .scan(F::Extension::ZERO, |acc, &x| { - let tmp = *acc; - *acc = *acc * base + x; - Some(tmp) - }) - .collect::>(); - for i in 0..self.half_length { - // Fill the accumulators for the first half. - pw.set_extension_target( - ExtensionTarget::from_range( - self.gates[i], - ArithmeticExtensionGate::::wires_multiplicand_0(), - ), - intermediate_accs[i], - ); - // Fill the accumulators for the second half. - pw.set_extension_target( - ExtensionTarget::from_range( - self.gates[i], - ArithmeticExtensionGate::::wires_multiplicand_1(), - ), - intermediate_accs[self.half_length + i], - ); - } - - pw - } -} - #[cfg(test)] mod tests { use super::*; use crate::circuit_data::CircuitConfig; use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::witness::PartialWitness; fn test_reduce_gadget(n: usize) { type F = CrandallField; diff --git a/src/witness.rs b/src/witness.rs index 989ec2bc..e2426aaa 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -32,6 +32,18 @@ impl PartialWitness { witness } + pub fn singleton_extension_target( + et: ExtensionTarget, + value: F::Extension, + ) -> Self + where + F: Extendable, + { + let mut witness = PartialWitness::new(); + witness.set_extension_target(et, value); + witness + } + pub fn is_empty(&self) -> bool { self.target_values.is_empty() } From eee3026eeee36009fb38f736c58a68bdd497eb0e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 30 Jun 2021 08:15:56 +0200 Subject: [PATCH 17/22] Move `shorten` in `find_tree` --- src/gates/gate_tree.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 8a858fd2..d1b58aa3 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -69,14 +69,12 @@ impl, const D: usize> Tree> { let max_degree = 1 << max_degree_bits; for max_constants in 1..100 { if let Some(mut best_tree) = Self::find_tree(&gates, max_degree, max_constants) { - best_tree.shorten(); let mut best_num_constants = best_tree.num_constants(); let mut best_degree = max_degree; // Iterate backwards from `max_degree` to try to find a tree with a lower degree // but the same number of constants. 'optdegree: for degree in (0..max_degree).rev() { if let Some(mut tree) = Self::find_tree(&gates, degree, max_constants) { - tree.shorten(); let num_constants = tree.num_constants(); if num_constants > best_num_constants { break 'optdegree; @@ -108,6 +106,7 @@ impl, const D: usize> Tree> { for g in gates { tree.try_add_gate(g, max_degree, max_constants)?; } + tree.shorten(); Some(tree) } From b7f0352cd8224a62afa6496d5c41f77fe9c3a8c7 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 30 Jun 2021 08:25:36 +0200 Subject: [PATCH 18/22] Update comment on `reduce` --- src/util/scaling.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/util/scaling.rs b/src/util/scaling.rs index 4941d03f..3449e0b9 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -91,9 +91,10 @@ impl ReducingFactorTarget { } /// Reduces a length `n` vector of `ExtensionTarget`s using `n/2` `ArithmeticExtensionGate`s. - /// It does this by running two accumulators in parallel. Here's an example with `n=4, alpha=2, D=1`: - /// 1st gate: 2 0 4 11 2 4 24 <- 2*0+4= 4, 2*11+2=24 - /// 2nd gate: 2 4 3 24 1 11 49 <- 2*4+3=11, 2*24+1=49 + /// It does this by batching two steps of Horner's method in each gate. + /// Here's an example with `n=4, alpha=2, D=1`: + /// 1st gate: 2 0 4 4 3 4 11 <- 2*0+4=4, 2*4+3=11 + /// 2nd gate: 2 11 2 24 1 24 49 <- 2*11+2=24, 2*24+1=49 /// which verifies that `2.reduce([1,2,3,4]) = 49`. pub fn reduce( &mut self, From 03179e5674590dad2102854dd380e6bed035c850 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 30 Jun 2021 12:54:45 -0700 Subject: [PATCH 19/22] Couple fixes related to blinding - `self.gates` -> `self.gate_instances` - Some tests were using a single binary FRI reduction, which doesn't provide enough succinctness for our blinding scheme to work. This caused `blinding_counts` to continue until it overflowed. --- src/bin/bench_recursion.rs | 2 +- src/circuit_builder.rs | 2 +- src/circuit_data.rs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 59b65e51..0f1b9783 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -32,7 +32,7 @@ fn bench_prove, const D: usize>() { fri_config: FriConfig { proof_of_work_bits: 1, rate_bits: 3, - reduction_arity_bits: vec![1], + reduction_arity_bits: vec![1, 1, 1, 1], num_query_rounds: 1, }, }; diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 1bf938a0..7ce4adb3 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -259,7 +259,7 @@ impl, const D: usize> CircuitBuilder { /// polynomials (which are opened at only one location) and for the Z polynomials (which are /// opened at two). fn blinding_counts(&self) -> (usize, usize) { - let num_gates = self.gates.len(); + let num_gates = self.gate_instances.len(); let mut degree_estimate = 1 << log2_ceil(num_gates); loop { diff --git a/src/circuit_data.rs b/src/circuit_data.rs index 6f352832..d72a0be3 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -39,7 +39,7 @@ impl Default for CircuitConfig { fri_config: FriConfig { proof_of_work_bits: 1, rate_bits: 1, - reduction_arity_bits: vec![1], + reduction_arity_bits: vec![1, 1, 1, 1], num_query_rounds: 1, }, } @@ -61,7 +61,7 @@ impl CircuitConfig { fri_config: FriConfig { proof_of_work_bits: 1, rate_bits: 3, - reduction_arity_bits: vec![1], + reduction_arity_bits: vec![1, 1, 1, 1], num_query_rounds: 1, }, } From 574a3d4847fc3ca8506849c611bde55e639f9adf Mon Sep 17 00:00:00 2001 From: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com> Date: Thu, 1 Jul 2021 14:55:41 +1000 Subject: [PATCH 20/22] FFT improvements (#81) * Use built-in `reverse_bits`; remove duplicate `reverse_index_bits`. * Reduce precomputation time/space complexity from quadratic to linear. * Several working cache-friendly FFTs. * Fix to allow FFT of constant polynomial. * Simplify FFT strategy choice. * Add PrimeField and CHARACTERISTIC properties to Fields. * Add faster method for inverse of 2^m. * Pre-compute some of the roots; tidy up loop iteration. * Precomputation for both FFT variants. * Refactor precomputation; add optional parameters; rename some things. * Unrolled version with zero tail. * Iterative version of Unrolled precomputation. * Test zero tail algo. * Restore default degree. * Address comments from @dlubarov and @wborgeaud. --- src/field/crandall_field.rs | 3 + src/field/extension_field/quadratic.rs | 3 + src/field/extension_field/quartic.rs | 3 + src/field/fft.rs | 358 ++++++++++++++++++------- src/field/field.rs | 28 ++ src/field/field_testing.rs | 14 + src/polynomial/polynomial.rs | 6 +- src/util/mod.rs | 15 +- 8 files changed, 325 insertions(+), 105 deletions(-) diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index dbd29cb2..7a1d18d6 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -136,6 +136,8 @@ impl Debug for CrandallField { } impl Field for CrandallField { + type PrimeField = Self; + const ZERO: Self = Self(0); const ONE: Self = Self(1); const TWO: Self = Self(2); @@ -143,6 +145,7 @@ impl Field for CrandallField { const ORDER: u64 = 18446744071293632513; const TWO_ADICITY: usize = 28; + const CHARACTERISTIC: u64 = Self::ORDER; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(5); const POWER_OF_TWO_GENERATOR: Self = Self(10281950781551402419); diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index af21ad60..ede2ef26 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -43,11 +43,14 @@ impl From<>::BaseField> for QuadraticCrandallField { } impl Field for QuadraticCrandallField { + type PrimeField = CrandallField; + const ZERO: Self = Self([CrandallField::ZERO; 2]); const ONE: Self = Self([CrandallField::ONE, CrandallField::ZERO]); const TWO: Self = Self([CrandallField::TWO, CrandallField::ZERO]); const NEG_ONE: Self = Self([CrandallField::NEG_ONE, CrandallField::ZERO]); + const CHARACTERISTIC: u64 = CrandallField::ORDER; // Does not fit in 64-bits. const ORDER: u64 = 0; const TWO_ADICITY: usize = 29; diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index b93cbb56..f609eeb7 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -50,6 +50,8 @@ impl From<>::BaseField> for QuarticCrandallField { } impl Field for QuarticCrandallField { + type PrimeField = CrandallField; + const ZERO: Self = Self([CrandallField::ZERO; 4]); const ONE: Self = Self([ CrandallField::ONE, @@ -70,6 +72,7 @@ impl Field for QuarticCrandallField { CrandallField::ZERO, ]); + const CHARACTERISTIC: u64 = CrandallField::ORDER; // Does not fit in 64-bits. const ORDER: u64 = 0; const TWO_ADICITY: usize = 30; diff --git a/src/field/fft.rs b/src/field/fft.rs index 56764b47..af5c05a7 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -1,142 +1,304 @@ +use std::option::Option; + use crate::field::field::Field; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; -use crate::util::{log2_ceil, log2_strict}; +use crate::util::{log2_strict, reverse_index_bits}; -/// Permutes `arr` such that each index is mapped to its reverse in binary. -fn reverse_index_bits(arr: Vec) -> Vec { - let n = arr.len(); - let n_power = log2_strict(n); +// TODO: Should really do some "dynamic" dispatch to handle the +// different FFT algos rather than C-style enum dispatch. +enum FftStrategy { Classic, Unrolled } - let mut result = Vec::with_capacity(n); - for i in 0..n { - result.push(arr[reverse_bits(i, n_power)]); +const FFT_STRATEGY: FftStrategy = FftStrategy::Classic; + +type FftRootTable = Vec>; + +fn fft_classic_root_table(n: usize) -> FftRootTable { + let lg_n = log2_strict(n); + // bases[i] = g^2^i, for i = 0, ..., lg_n - 1 + let mut bases = Vec::with_capacity(lg_n); + let mut base = F::primitive_root_of_unity(lg_n); + bases.push(base); + for _ in 1..lg_n { + base = base.square(); // base = g^2^_ + bases.push(base); } - result -} -fn reverse_bits(n: usize, num_bits: usize) -> usize { - let mut result = 0; - for i in 0..num_bits { - let i_rev = num_bits - i - 1; - result |= (n >> i & 1) << i_rev; + let mut root_table = Vec::with_capacity(lg_n); + for lg_m in 1..=lg_n { + let half_m = 1 << (lg_m - 1); + let base = bases[lg_n - lg_m]; + let root_row = base.powers().take(half_m.max(2)).collect(); + root_table.push(root_row); } - result + root_table } -pub(crate) struct FftPrecomputation { - /// For each layer index i, stores the cyclic subgroup corresponding to the evaluation domain of - /// layer i. The indices within these subgroup vectors are bit-reversed. - subgroups_rev: Vec>, + +fn fft_unrolled_root_table(n: usize) -> FftRootTable { + // Precompute a table of the roots of unity used in the main + // loops. + + // Suppose n is the size of the outer vector and g is a primitive nth + // root of unity. Then the [lg(m) - 1][j] element of the table is + // g^{ n/2m * j } for j = 0..m-1 + + let lg_n = log2_strict(n); + // bases[i] = g^2^i, for i = 0, ..., lg_n - 2 + let mut bases = Vec::with_capacity(lg_n); + let mut base = F::primitive_root_of_unity(lg_n); + bases.push(base); + // NB: If n = 1, then lg_n is zero, so we can't do 1..(lg_n-1) here + for _ in 2..lg_n { + base = base.square(); // base = g^2^(_-1) + bases.push(base); + } + + let mut root_table = Vec::with_capacity(lg_n); + for lg_m in 1..lg_n { + let m = 1 << lg_m; + let base = bases[lg_n - lg_m - 1]; + let root_row = base.powers().take(m.max(2)).collect(); + root_table.push(root_row); + } + root_table } -impl FftPrecomputation { - pub fn size(&self) -> usize { - self.subgroups_rev.last().unwrap().len() +#[inline] +fn fft_dispatch( + input: Vec, + zero_factor: Option, + root_table: Option> +) -> Vec { + let n = input.len(); + match FFT_STRATEGY { + FftStrategy::Classic + => fft_classic(input, + zero_factor.unwrap_or(0), + root_table.unwrap_or_else(|| fft_classic_root_table(n))), + FftStrategy::Unrolled + => fft_unrolled(input, + zero_factor.unwrap_or(0), + root_table.unwrap_or_else(|| fft_unrolled_root_table(n))) } } +#[inline] pub fn fft(poly: PolynomialCoeffs) -> PolynomialValues { - let precomputation = fft_precompute(poly.len()); - fft_with_precomputation_power_of_2(poly, &precomputation) + fft_with_options(poly, None, None) } -pub(crate) fn fft_precompute(degree: usize) -> FftPrecomputation { - let degree_log = log2_ceil(degree); - - let mut subgroups_rev = Vec::new(); - let mut subgroup = F::two_adic_subgroup(degree_log); - for _i in 0..=degree_log { - let subsubgroup = subgroup.iter().step_by(2).copied().collect(); - let subgroup_rev = reverse_index_bits(subgroup); - subgroups_rev.push(subgroup_rev); - subgroup = subsubgroup; - } - subgroups_rev.reverse(); - - FftPrecomputation { subgroups_rev } +#[inline] +pub fn fft_with_options( + poly: PolynomialCoeffs, + zero_factor: Option, + root_table: Option> +) -> PolynomialValues { + let PolynomialCoeffs { coeffs } = poly; + PolynomialValues { values: fft_dispatch(coeffs, zero_factor, root_table) } } -pub(crate) fn ifft_with_precomputation_power_of_2( +#[inline] +pub fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { + ifft_with_options(poly, None, None) +} + +pub fn ifft_with_options( poly: PolynomialValues, - precomputation: &FftPrecomputation, + zero_factor: Option, + root_table: Option> ) -> PolynomialCoeffs { let n = poly.len(); - let n_inv = F::from_canonical_usize(n).try_inverse().unwrap(); + let lg_n = log2_strict(n); + let n_inv = F::inverse_2exp(lg_n); let PolynomialValues { values } = poly; - let PolynomialValues { values: mut result } = - fft_with_precomputation_power_of_2(PolynomialCoeffs { coeffs: values }, precomputation); + let mut coeffs = fft_dispatch(values, zero_factor, root_table); // We reverse all values except the first, and divide each by n. - result[0] *= n_inv; - result[n / 2] *= n_inv; + coeffs[0] *= n_inv; + coeffs[n / 2] *= n_inv; for i in 1..(n / 2) { let j = n - i; - let result_i = result[j] * n_inv; - let result_j = result[i] * n_inv; - result[i] = result_i; - result[j] = result_j; + let coeffs_i = coeffs[j] * n_inv; + let coeffs_j = coeffs[i] * n_inv; + coeffs[i] = coeffs_i; + coeffs[j] = coeffs_j; } - PolynomialCoeffs { coeffs: result } + PolynomialCoeffs { coeffs } } -pub(crate) fn fft_with_precomputation_power_of_2( - poly: PolynomialCoeffs, - precomputation: &FftPrecomputation, -) -> PolynomialValues { - debug_assert_eq!( - poly.len(), - precomputation.subgroups_rev.last().unwrap().len(), - "Number of coefficients does not match size of subgroup in precomputation" - ); +/// FFT implementation based on Section 32.3 of "Introduction to +/// Algorithms" by Cormen et al. +/// +/// The parameter r signifies that the first 1/2^r of the entries of +/// input may be non-zero, but the last 1 - 1/2^r entries are +/// definitely zero. +pub(crate) fn fft_classic( + input: Vec, + r: usize, + root_table: FftRootTable +) -> Vec { + let mut values = reverse_index_bits(input); - let half_degree = poly.len() >> 1; - let degree_log = poly.log_len(); + let n = values.len(); + let lg_n = log2_strict(n); - // In the base layer, we're just evaluating "degree 0 polynomials", i.e. the coefficients - // themselves. - let PolynomialCoeffs { coeffs } = poly; - let mut evaluations = reverse_index_bits(coeffs); + if root_table.len() != lg_n { + panic!("Expected root table of length {}, but it was {}.", lg_n, root_table.len()); + } - for i in 1..=degree_log { - // In layer i, we're evaluating a series of polynomials, each at 2^i points. In practice - // we evaluate a pair of points together, so we have 2^(i - 1) pairs. - let points_per_poly = 1 << i; - let pairs_per_poly = 1 << (i - 1); - - let mut new_evaluations = Vec::new(); - for pair_index in 0..half_degree { - let poly_index = pair_index / pairs_per_poly; - let pair_index_within_poly = pair_index % pairs_per_poly; - - let child_index_0 = poly_index * points_per_poly + pair_index_within_poly; - let child_index_1 = child_index_0 + pairs_per_poly; - - let even = evaluations[child_index_0]; - let odd = evaluations[child_index_1]; - - let point_0 = precomputation.subgroups_rev[i][pair_index_within_poly * 2]; - let product = point_0 * odd; - new_evaluations.push(even + product); - new_evaluations.push(even - product); + // After reverse_index_bits, the only non-zero elements of values + // are at indices i*2^r for i = 0..n/2^r. The loop below copies + // the value at i*2^r to the positions [i*2^r + 1, i*2^r + 2, ..., + // (i+1)*2^r - 1]; i.e. it replaces the 2^r - 1 zeros following + // element i*2^r with the value at i*2^r. This corresponds to the + // first r rounds of the FFT when there are 2^r zeros at the end + // of the original input. + if r > 0 { // if r == 0 then this loop is a noop. + let mask = !((1 << r) - 1); + for i in 0..n { + values[i] = values[i & mask]; } - evaluations = new_evaluations; } - // Reorder so that evaluations' indices correspond to (g_0, g_1, g_2, ...) - let values = reverse_index_bits(evaluations); - PolynomialValues { values } + let mut m = 1 << (r + 1); + for lg_m in (r+1)..=lg_n { + let half_m = m / 2; + for k in (0..n).step_by(m) { + for j in 0..half_m { + let omega = root_table[lg_m - 1][j]; + let t = omega * values[k + half_m + j]; + let u = values[k + j]; + values[k + j] = u + t; + values[k + half_m + j] = u - t; + } + } + m *= 2; + } + values } -pub(crate) fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { - let precomputation = fft_precompute(poly.len()); - ifft_with_precomputation_power_of_2(poly, &precomputation) +/// FFT implementation inspired by Barretenberg's (but with extra unrolling): +/// https://github.com/AztecProtocol/barretenberg/blob/master/barretenberg/src/aztec/polynomials/polynomial_arithmetic.cpp#L58 +/// https://github.com/AztecProtocol/barretenberg/blob/master/barretenberg/src/aztec/polynomials/evaluation_domain.cpp#L30 +/// +/// The parameter r signifies that the first 1/2^r of the entries of +/// input may be non-zero, but the last 1 - 1/2^r entries are +/// definitely zero. +fn fft_unrolled( + input: Vec, + r_orig: usize, + root_table: FftRootTable +) -> Vec { + let n = input.len(); + let lg_n = log2_strict(input.len()); + + let mut values = reverse_index_bits(input); + + // FFT of a constant polynomial (including zero) is itself. + if n < 2 { + return values + } + + // The 'm' corresponds to the specialisation from the 'm' in the + // main loop (m >= 4) below. + + // (See comment in fft_classic near same code.) + let mut r = r_orig; + let mut m = 1 << r; + if r > 0 { // if r == 0 then this loop is a noop. + let mask = !((1 << r) - 1); + for i in 0..n { + values[i] = values[i & mask]; + } + } + + // m = 1 + if m == 1 { + for k in (0..n).step_by(2) { + let t = values[k + 1]; + values[k + 1] = values[k] - t; + values[k] += t; + } + r += 1; + m *= 2; + } + + if n == 2 { + return values + } + + if root_table.len() != (lg_n - 1) { + panic!("Expected root table of length {}, but it was {}.", lg_n, root_table.len()); + } + + // m = 2 + if m <= 2 { + for k in (0..n).step_by(4) { + // NB: Grouping statements as is done in the main loop below + // does not seem to help here (worse by a few millis). + let omega_0 = root_table[0][0]; + let tmp_0 = omega_0 * values[k + 2 + 0]; + values[k + 2 + 0] = values[k + 0] - tmp_0; + values[k + 0] += tmp_0; + + let omega_1 = root_table[0][1]; + let tmp_1 = omega_1 * values[k + 2 + 1]; + values[k + 2 + 1] = values[k + 1] - tmp_1; + values[k + 1] += tmp_1; + } + r += 1; + m *= 2; + } + + // m >= 4 + for lg_m in r..lg_n { + for k in (0..n).step_by(2*m) { + // Unrolled the commented loop by groups of 4 and + // rearranged the lines. Improves runtime by about + // 10%. + /* + for j in (0..m) { + let omega = root_table[lg_m - 1][j]; + let tmp = omega * values[k + m + j]; + values[k + m + j] = values[k + j] - tmp; + values[k + j] += tmp; + } + */ + for j in (0..m).step_by(4) { + let off1 = k + j; + let off2 = k + m + j; + + let omega_0 = root_table[lg_m - 1][j + 0]; + let omega_1 = root_table[lg_m - 1][j + 1]; + let omega_2 = root_table[lg_m - 1][j + 2]; + let omega_3 = root_table[lg_m - 1][j + 3]; + + let tmp_0 = omega_0 * values[off2 + 0]; + let tmp_1 = omega_1 * values[off2 + 1]; + let tmp_2 = omega_2 * values[off2 + 2]; + let tmp_3 = omega_3 * values[off2 + 3]; + + values[off2 + 0] = values[off1 + 0] - tmp_0; + values[off2 + 1] = values[off1 + 1] - tmp_1; + values[off2 + 2] = values[off1 + 2] - tmp_2; + values[off2 + 3] = values[off1 + 3] - tmp_3; + values[off1 + 0] += tmp_0; + values[off1 + 1] += tmp_1; + values[off1 + 2] += tmp_2; + values[off1 + 3] += tmp_3; + } + } + m *= 2; + } + values } + #[cfg(test)] mod tests { use crate::field::crandall_field::CrandallField; - use crate::field::fft::{fft, ifft}; + use crate::field::fft::{fft, ifft, fft_with_options}; use crate::field::field::Field; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_ceil, log2_strict}; @@ -162,6 +324,12 @@ mod tests { for i in degree..degree_padded { assert_eq!(interpolated_coefficients.coeffs[i], F::ZERO); } + + for r in 0..4 { + // expand ceofficients by factor 2^r by filling with zeros + let zero_tail = coefficients.clone().lde(r); + assert_eq!(fft(zero_tail.clone()), fft_with_options(zero_tail, Some(r), None)); + } } fn evaluate_naive(coefficients: &PolynomialCoeffs) -> PolynomialValues { diff --git a/src/field/field.rs b/src/field/field.rs index 516012d2..f4ef5990 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -32,11 +32,14 @@ pub trait Field: + Send + Sync { + type PrimeField: Field; + const ZERO: Self; const ONE: Self; const TWO: Self; const NEG_ONE: Self; + const CHARACTERISTIC: u64; const ORDER: u64; const TWO_ADICITY: usize; @@ -101,6 +104,31 @@ pub trait Field: x_inv } + /// Compute the inverse of 2^exp in this field. + #[inline] + fn inverse_2exp(exp: usize) -> Self { + let p = Self::CHARACTERISTIC; + + if exp <= Self::PrimeField::TWO_ADICITY { + // The inverse of 2^exp is p-(p-1)/2^exp when char(F) = p and exp is + // at most the TWO_ADICITY of the prime field. + // + // NB: PrimeFields fit in 64 bits => TWO_ADICITY < 64 => + // exp < 64 => this shift amount is legal. + Self::from_canonical_u64(p - ((p - 1) >> exp)) + } else { + // In the general case we compute 1/2 = (p+1)/2 and then exponentiate + // by exp to get 1/2^exp. Costs about log_2(exp) operations. + let half = Self::from_canonical_u64((p + 1) >> 1); + half.exp(exp as u64) + + // TODO: Faster to combine several high powers of 1/2 using multiple + // applications of the trick above. E.g. if the 2-adicity is v, then + // compute 1/2^(v^2 + v + 13) with 1/2^((v + 1) * v + 13), etc. + // (using the v-adic expansion of m). Costs about log_v(exp) operations. + } + } + fn primitive_root_of_unity(n_log: usize) -> Self { assert!(n_log <= Self::TWO_ADICITY); let mut base = Self::POWER_OF_TWO_GENERATOR; diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index 53e9c63c..7190684f 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -315,6 +315,20 @@ macro_rules! test_arithmetic { assert_eq!(x, F::ONE); assert_eq!(F::ZERO - x, F::NEG_ONE); } + + #[test] + fn inverse_2exp() { + // Just check consistency with try_inverse() + type F = $field; + + let v = ::PrimeField::TWO_ADICITY; + + for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123*v] { + let x = F::TWO.exp(e as u64).inverse(); + let y = F::inverse_2exp(e); + assert_eq!(x, y); + } + } } }; } diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 888d7af0..5f295030 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -1,3 +1,5 @@ +use std::time::Instant; + use std::cmp::max; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; @@ -5,7 +7,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; use anyhow::{ensure, Result}; use crate::field::extension_field::Extendable; -use crate::field::fft::{fft, ifft}; +use crate::field::fft::{fft, ifft, fft_with_options}; use crate::field::field::Field; use crate::util::log2_strict; @@ -55,7 +57,7 @@ impl PolynomialValues { pub fn lde(self, rate_bits: usize) -> Self { let coeffs = ifft(self).lde(rate_bits); - fft(coeffs) + fft_with_options(coeffs, Some(rate_bits), None) } pub fn degree(&self) -> usize { diff --git a/src/util/mod.rs b/src/util/mod.rs index f901b0af..8fd60d53 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -49,13 +49,13 @@ pub(crate) fn transpose(matrix: &[Vec]) -> Vec> { } /// Permutes `arr` such that each index is mapped to its reverse in binary. -pub(crate) fn reverse_index_bits(arr: Vec) -> Vec { +pub(crate) fn reverse_index_bits(arr: Vec) -> Vec { let n = arr.len(); let n_power = log2_strict(n); let mut result = Vec::with_capacity(n); for i in 0..n { - result.push(arr[reverse_bits(i, n_power)].clone()); + result.push(arr[reverse_bits(i, n_power)]); } result } @@ -73,12 +73,11 @@ pub(crate) fn reverse_index_bits_in_place(arr: &mut Vec) { } pub(crate) fn reverse_bits(n: usize, num_bits: usize) -> usize { - let mut result = 0; - for i in 0..num_bits { - let i_rev = num_bits - i - 1; - result |= (n >> i & 1) << i_rev; - } - result + // NB: The only reason we need overflowing_shr() here as opposed + // to plain '>>' is to accommodate the case n == num_bits == 0, + // which would become `0 >> 64`. Rust thinks that any shift of 64 + // bits causes overflow, even when the argument is zero. + n.reverse_bits().overflowing_shr(usize::BITS - num_bits as u32).0 } #[cfg(test)] From 95a875e28d213741d9b263ac7c4efe8cb6c8340b Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Thu, 1 Jul 2021 08:12:12 -0700 Subject: [PATCH 21/22] Allow virtual targets to be routed (#84) As in plonky1. The semantics of virtual targets in plonky1 were rather weird, but I think it's somewhat better here, since we already separate `generate_copy` and `assert_equal` methods. Users now make more of an explicit choice -- they can use a `VirtualTarget` for the witness generation only using `generate_copy`, or they can involve it in copy constraints. --- src/circuit_builder.rs | 28 ++++++++++++++-------------- src/gadgets/split_join.rs | 6 +++--- src/permutation_argument.rs | 2 +- src/target.rs | 7 +++++-- src/witness.rs | 2 +- 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 7ce4adb3..06c554e3 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -36,7 +36,7 @@ pub struct CircuitBuilder, const D: usize> { /// The next available index for a public input. public_input_index: usize, - /// The next available index for a VirtualAdviceTarget. + /// The next available index for a `VirtualTarget`. virtual_target_index: usize, copy_constraints: Vec<(Target, Target)>, @@ -77,22 +77,18 @@ impl, const D: usize> CircuitBuilder { (0..n).map(|_i| self.add_public_input()).collect() } - /// Adds a new "virtual" advice target. This is not an actual wire in the witness, but just a - /// target that help facilitate witness generation. In particular, a generator can assign a - /// values to a virtual target, which can then be copied to other (virtual or concrete) targets - /// via `generate_copy`. When we generate the final witness (a grid of wire values), these - /// virtual targets will go away. - /// - /// Since virtual targets are not part of the actual permutation argument, they cannot be used - /// with `assert_equal`. - pub fn add_virtual_advice_target(&mut self) -> Target { + /// Adds a new "virtual" target. This is not an actual wire in the witness, but just a target + /// that help facilitate witness generation. In particular, a generator can assign a values to a + /// virtual target, which can then be copied to other (virtual or concrete) targets. When we + /// generate the final witness (a grid of wire values), these virtual targets will go away. + pub fn add_virtual_target(&mut self) -> Target { let index = self.virtual_target_index; self.virtual_target_index += 1; - Target::VirtualAdviceTarget { index } + Target::VirtualTarget { index } } - pub fn add_virtual_advice_targets(&mut self, n: usize) -> Vec { - (0..n).map(|_i| self.add_virtual_advice_target()).collect() + pub fn add_virtual_targets(&mut self, n: usize) -> Vec { + (0..n).map(|_i| self.add_virtual_target()).collect() } pub fn add_gate_no_constants(&mut self, gate_type: GateRef) -> usize { @@ -368,7 +364,11 @@ impl, const D: usize> CircuitBuilder { } for index in 0..self.public_input_index { - target_partitions.add_partition(Target::PublicInput { index }) + target_partitions.add_partition(Target::PublicInput { index }); + } + + for index in 0..self.virtual_target_index { + target_partitions.add_partition(Target::VirtualTarget { index }); } for &(a, b) in &self.copy_constraints { diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 5eb60148..3a2c27f4 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -9,14 +9,14 @@ use crate::wire::Wire; use crate::witness::PartialWitness; impl, const D: usize> CircuitBuilder { - /// Split the given integer into a list of virtual advice targets, where each one represents a - /// bit of the integer, with little-endian ordering. + /// Split the given integer into a list of virtual targets, where each one represents a bit of + /// the integer, with little-endian ordering. /// /// Note that this only handles witness generation; it does not enforce that the decomposition /// is correct. The output should be treated as a "purported" decomposition which must be /// enforced elsewhere. pub(crate) fn split_le_virtual(&mut self, integer: Target, num_bits: usize) -> Vec { - let bit_targets = self.add_virtual_advice_targets(num_bits); + let bit_targets = self.add_virtual_targets(num_bits); self.add_generator(SplitGenerator { integer, bits: bit_targets.clone(), diff --git a/src/permutation_argument.rs b/src/permutation_argument.rs index 54436ecb..cc202a95 100644 --- a/src/permutation_argument.rs +++ b/src/permutation_argument.rs @@ -57,7 +57,7 @@ impl TargetPartitions { } pub fn to_wire_partitions(&self) -> WirePartitions { - // Here we just drop all CircuitInputs, leaving all GateInputs. + // Here we keep just the Wire targets, filtering out everything else. let mut partitions = Vec::new(); let mut indices = HashMap::new(); diff --git a/src/target.rs b/src/target.rs index 8aec0b5a..e765f7eb 100644 --- a/src/target.rs +++ b/src/target.rs @@ -8,7 +8,10 @@ use crate::wire::Wire; pub enum Target { Wire(Wire), PublicInput { index: usize }, - VirtualAdviceTarget { index: usize }, + /// A target that doesn't have any inherent location in the witness (but it can be copied to + /// another target that does). This is useful for representing intermediate values in witness + /// generation. + VirtualTarget { index: usize }, } impl Target { @@ -20,7 +23,7 @@ impl Target { match self { Target::Wire(wire) => wire.is_routable(config), Target::PublicInput { .. } => true, - Target::VirtualAdviceTarget { .. } => false, + Target::VirtualTarget { .. } => true, } } diff --git a/src/witness.rs b/src/witness.rs index 681049d0..863ee7dd 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -30,7 +30,7 @@ impl Witness { F: Extendable, { for &(a, b) in copy_constraints { - // TODO: Take care of public inputs once they land. + // TODO: Take care of public inputs once they land, and virtual targets. if let ( Target::Wire(Wire { gate: a_gate, From 519533d4b775b8f1a7cf4ccb1e4fa6a39b26aadf Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Thu, 1 Jul 2021 10:53:42 -0700 Subject: [PATCH 22/22] Benchmark tweaks (#83) - Configure FRI with a list of arities that's more appropriate for a 2^14 instance. The previous config resulted in a huge final polynomial. - Log the blinding factors, and other logging tweaks. --- src/bin/bench_recursion.rs | 4 +--- src/circuit_builder.rs | 8 ++++++-- src/gadgets/arithmetic.rs | 1 - src/gadgets/arithmetic_extension.rs | 2 +- src/gates/gate_tree.rs | 2 +- src/polynomial/polynomial.rs | 4 +--- src/util/timing.rs | 2 +- 7 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 0f1b9783..b55ee841 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -21,8 +21,6 @@ fn main() { } fn bench_prove, const D: usize>() { - let gmimc_gate = GMiMCGate::::with_automatic_constants(); - let config = CircuitConfig { num_wires: 134, num_routed_wires: 27, @@ -32,7 +30,7 @@ fn bench_prove, const D: usize>() { fri_config: FriConfig { proof_of_work_bits: 1, rate_bits: 3, - reduction_arity_bits: vec![1, 1, 1, 1], + reduction_arity_bits: vec![2, 2, 2, 2, 2], num_query_rounds: 1, }, }; diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 06c554e3..2d00b403 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -277,6 +277,10 @@ impl, const D: usize> CircuitBuilder { fn blind_and_pad(&mut self) { let (regular_poly_openings, z_openings) = self.blinding_counts(); + info!( + "Adding {} blinding terms for witness polynomials, and {}*2 for Z polynomials", + regular_poly_openings, z_openings + ); let num_routed_wires = self.config.num_routed_wires; let num_wires = self.config.num_wires; @@ -383,12 +387,12 @@ impl, const D: usize> CircuitBuilder { pub fn build(mut self) -> CircuitData { let start = Instant::now(); info!( - "degree before blinding & padding: {}", + "Degree before blinding & padding: {}", self.gate_instances.len() ); self.blind_and_pad(); let degree = self.gate_instances.len(); - info!("degree after blinding & padding: {}", degree); + info!("Degree after blinding & padding: {}", degree); let gates = self.gates.iter().cloned().collect(); let (gate_tree, max_filtered_constraint_degree, num_constants) = Tree::from_gates(gates); diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index bda70624..5c328362 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,7 +1,6 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; use crate::target::Target; -use crate::util::bits_u64; impl, const D: usize> CircuitBuilder { /// Computes `-x`. diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index ce53c59f..4f7b1e14 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -6,7 +6,7 @@ use num::Integer; use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; -use crate::field::extension_field::{Extendable, FieldExtension, OEF}; +use crate::field::extension_field::{Extendable, OEF}; use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index bf56a690..5c2e084e 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -86,7 +86,7 @@ impl, const D: usize> Tree> { } } info!( - "Found tree with max degree {} and {} constants wires in {}s.", + "Found tree with max degree {} and {} constants wires in {:.4}s.", best_degree, best_num_constants, timer.elapsed().as_secs_f32() diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 5f295030..81d07b8f 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -1,5 +1,3 @@ -use std::time::Instant; - use std::cmp::max; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; @@ -7,7 +5,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; use anyhow::{ensure, Result}; use crate::field::extension_field::Extendable; -use crate::field::fft::{fft, ifft, fft_with_options}; +use crate::field::fft::{fft, fft_with_options, ifft}; use crate::field::field::Field; use crate::util::log2_strict; diff --git a/src/util/timing.rs b/src/util/timing.rs index 17136f36..6f27a9d1 100644 --- a/src/util/timing.rs +++ b/src/util/timing.rs @@ -7,7 +7,7 @@ macro_rules! timed { let timer = Instant::now(); let res = $a; - info!("{:.3}s {}", timer.elapsed().as_secs_f32(), $msg); + info!("{:.4}s {}", timer.elapsed().as_secs_f32(), $msg); res }}; }