diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 3a26497e..ec6f60be 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -95,6 +95,11 @@ impl, const D: usize> CircuitBuilder { /// Adds a gate to the circuit, and returns its index. pub fn add_gate(&mut self, gate_type: GateRef, constants: Vec) -> usize { + assert_eq!( + gate_type.0.num_constants(), + constants.len(), + "Number of constants doesn't match." + ); // If we haven't seen a gate of this type before, check that it's compatible with our // circuit configuration, then register it. if !self.gates.contains(&gate_type) { diff --git a/src/circuit_data.rs b/src/circuit_data.rs index c8477464..c6c27d35 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -51,7 +51,7 @@ impl CircuitConfig { pub(crate) fn large_config() -> Self { Self { num_wires: 134, - num_routed_wires: 12, + num_routed_wires: 28, security_bits: 128, rate_bits: 3, num_challenges: 3, diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 0316b04f..f6cec119 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -1,11 +1,13 @@ use std::convert::{TryFrom, TryInto}; use std::ops::Range; +use itertools::Itertools; + use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; -use crate::gates::mul_extension::MulExtensionGate; +use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::target::Target; /// `Target`s representing an element of an extension field. @@ -110,57 +112,155 @@ impl, const D: usize> CircuitBuilder { self.constant_ext_algebra(ExtensionAlgebra::ZERO) } + pub fn double_arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + fixed_multiplicand: ExtensionTarget, + multiplicand_0: ExtensionTarget, + addend_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend_1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); + + let wire_fixed_multiplicand = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); + let wire_multiplicand_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let wire_addend_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); + let wire_multiplicand_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); + let wire_addend_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); + let wire_output_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + let wire_output_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); + + self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); + self.route_extension(multiplicand_0, wire_multiplicand_0); + self.route_extension(addend_0, wire_addend_0); + self.route_extension(multiplicand_1, wire_multiplicand_1); + self.route_extension(addend_1, wire_addend_1); + self.route_extension(multiplicand_1, wire_multiplicand_1); + (wire_output_0, wire_output_1) + } + + pub fn arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + const_1, + multiplicand_0, + multiplicand_1, + addend, + zero, + zero, + ) + .0 + } + pub fn add_extension( &mut self, - mut a: ExtensionTarget, + a: ExtensionTarget, b: ExtensionTarget, ) -> ExtensionTarget { - for i in 0..D { - a.0[i] = self.add(a.0[i], b.0[i]); - } - a + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::ONE, one, a, b) + } + + pub fn add_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) } pub fn add_ext_algebra( &mut self, - mut a: ExtensionAlgebraTarget, + a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - for i in 0..D { - a.0[i] = self.add_extension(a.0[i], b.0[i]); + let mut res = vec![]; + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); } - a + if D % 2 == 1 { + res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) } pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut sum = self.zero_extension(); - for term in terms { - sum = self.add_extension(sum, *term); + let zero = self.zero_extension(); + let mut terms = terms.to_vec(); + if terms.len() % 2 == 1 { + terms.push(zero); } - sum + let mut acc0 = zero; + let mut acc1 = zero; + for chunk in terms.chunks_exact(2) { + (acc0, acc1) = self.add_two_extension(acc0, acc1, chunk[0], chunk[1]); + } + self.add_extension(acc0, acc1) } - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. pub fn sub_extension( &mut self, - mut a: ExtensionTarget, + a: ExtensionTarget, b: ExtensionTarget, ) -> ExtensionTarget { - for i in 0..D { - a.0[i] = self.sub(a.0[i], b.0[i]); - } - a + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) + } + + pub fn sub_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) } pub fn sub_ext_algebra( &mut self, - mut a: ExtensionAlgebraTarget, + a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - for i in 0..D { - a.0[i] = self.sub_extension(a.0[i], b.0[i]); + let mut res = vec![]; + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); } - a + if D % 2 == 1 { + res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) } pub fn mul_extension_with_const( @@ -169,17 +269,17 @@ impl, const D: usize> CircuitBuilder { multiplicand_0: ExtensionTarget, multiplicand_1: ExtensionTarget, ) -> ExtensionTarget { - let gate = self.add_gate(MulExtensionGate::new(), vec![const_0]); - - let wire_multiplicand_0 = - ExtensionTarget::from_range(gate, MulExtensionGate::::wires_multiplicand_0()); - let wire_multiplicand_1 = - ExtensionTarget::from_range(gate, MulExtensionGate::::wires_multiplicand_1()); - let wire_output = ExtensionTarget::from_range(gate, MulExtensionGate::::wires_output()); - - self.route_extension(multiplicand_0, wire_multiplicand_0); - self.route_extension(multiplicand_1, wire_multiplicand_1); - wire_output + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + F::ZERO, + multiplicand_0, + multiplicand_1, + zero, + zero, + zero, + ) + .0 } pub fn mul_extension( @@ -199,12 +299,11 @@ impl, const D: usize> CircuitBuilder { let w = self.constant(F::Extension::W); for i in 0..D { for j in 0..D { - let ai_bi = self.mul_extension(a.0[i], b.0[j]); res[(i + j) % D] = if i + j < D { - self.add_extension(ai_bi, res[(i + j) % D]) + self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) } else { - let w_ai_bi = self.scalar_mul_ext(w, ai_bi); - self.add_extension(w_ai_bi, res[(i + j) % D]) + let ai_bi = self.mul_extension(a.0[i], b.0[j]); + self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) } } } @@ -221,28 +320,35 @@ impl, const D: usize> CircuitBuilder { /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no /// performance benefit over separate muls and adds. - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. pub fn mul_add_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { - let product = self.mul_extension(a, b); - self.add_extension(product, c) + self.arithmetic_extension(F::ONE, F::ONE, a, b, c) } - /// Like `mul_sub`, but for `ExtensionTarget`s. Note that, unlike `mul_sub`, this has no - /// performance benefit over separate muls and subs. - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. + /// Like `mul_add`, but for `ExtensionTarget`s. + pub fn scalar_mul_add_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::ONE, a_ext, b, c) + } + + /// Like `mul_sub`, but for `ExtensionTarget`s. pub fn scalar_mul_sub_extension( &mut self, a: Target, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { - let product = self.scalar_mul_ext(a, b); - self.sub_extension(product, c) + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::NEG_ONE, a_ext, b, c) } /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 520f8cd9..7fc7e08d 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,11 +1,9 @@ -use std::convert::TryInto; - use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticGate; -use crate::gates::mul_extension::MulExtensionGate; +use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::target::Target; use crate::wire::Wire; @@ -253,16 +251,14 @@ impl, const D: usize> CircuitBuilder { y: ExtensionTarget, ) -> ExtensionTarget { // Add an `ArithmeticGate` to compute `q * y`. - let gate = self.add_gate(MulExtensionGate::new(), vec![F::ONE]); + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); let multiplicand_0 = - Target::wires_from_range(gate, MulExtensionGate::::wires_multiplicand_0()); - let multiplicand_0 = ExtensionTarget(multiplicand_0.try_into().unwrap()); + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); let multiplicand_1 = - Target::wires_from_range(gate, MulExtensionGate::::wires_multiplicand_1()); - let multiplicand_1 = ExtensionTarget(multiplicand_1.try_into().unwrap()); - let output = Target::wires_from_range(gate, MulExtensionGate::::wires_output()); - let output = ExtensionTarget(output.try_into().unwrap()); + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); + let output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); self.add_generator(QuotientGeneratorExtension { numerator: x, diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 40b91c1e..37746685 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -72,13 +72,10 @@ mod tests { fn test_interpolate() { type F = CrandallField; type FF = QuarticCrandallField; - let config = CircuitConfig { - num_routed_wires: 18, - ..CircuitConfig::large_config() - }; + let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); - let len = 2; + let len = 4; let points = (0..len) .map(|_| (F::rand(), FF::rand())) .collect::>(); diff --git a/src/gates/mul_extension.rs b/src/gates/mul_extension.rs index e378e2b1..a3006837 100644 --- a/src/gates/mul_extension.rs +++ b/src/gates/mul_extension.rs @@ -10,39 +10,63 @@ use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; use crate::witness::PartialWitness; -/// A gate which can multiply two field extension elements. -/// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough. +/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. #[derive(Debug)] -pub struct MulExtensionGate; +pub struct ArithmeticExtensionGate; -impl MulExtensionGate { +impl ArithmeticExtensionGate { pub fn new>() -> GateRef { - GateRef::new(MulExtensionGate) + GateRef::new(ArithmeticExtensionGate) } - pub fn wires_multiplicand_0() -> Range { + pub fn wires_fixed_multiplicand() -> Range { 0..D } - pub fn wires_multiplicand_1() -> Range { + pub fn wires_multiplicand_0() -> Range { D..2 * D } - pub fn wires_output() -> Range { + pub fn wires_addend_0() -> Range { 2 * D..3 * D } + pub fn wires_multiplicand_1() -> Range { + 3 * D..4 * D + } + pub fn wires_addend_1() -> Range { + 4 * D..5 * D + } + pub fn wires_output_0() -> Range { + 5 * D..6 * D + } + pub fn wires_output_1() -> Range { + 6 * D..7 * D + } } -impl, const D: usize> Gate for MulExtensionGate { +impl, const D: usize> Gate for ArithmeticExtensionGate { fn id(&self) -> String { format!("{:?}", self) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let output = vars.get_local_ext_algebra(Self::wires_output()); - let computed_output = multiplicand_0 * multiplicand_1 * const_0.into(); - (output - computed_output).to_basefield_array().to_vec() + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = + fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); + let computed_output_1 = + fixed_multiplicand * multiplicand_1 * const_1.into() + addend_1 * const_1.into(); + + let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); + constraints.extend((output_1 - computed_output_1).to_basefield_array()); + constraints } fn eval_unfiltered_recursively( @@ -51,13 +75,31 @@ impl, const D: usize> Gate for MulExtensionGate { vars: EvaluationTargets, ) -> Vec> { let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let output = vars.get_local_ext_algebra(Self::wires_output()); - let computed_output = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); - let computed_output = builder.scalar_mul_ext_algebra(const_0, computed_output); - let diff = builder.sub_ext_algebra(output, computed_output); - diff.to_ext_target_array().to_vec() + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); + let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); + let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); + let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); + + let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); + let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); + let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); + let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); + + let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); + let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); + let mut constraints = diff_0.to_ext_target_array().to_vec(); + constraints.extend(diff_1.to_ext_target_array()); + constraints } fn generators( @@ -68,16 +110,17 @@ impl, const D: usize> Gate for MulExtensionGate { let gen = MulExtensionGenerator { gate_index, const_0: local_constants[0], + const_1: local_constants[1], }; vec![Box::new(gen)] } fn num_wires(&self) -> usize { - 12 + 28 } fn num_constants(&self) -> usize { - 1 + 2 } fn degree(&self) -> usize { @@ -85,49 +128,59 @@ impl, const D: usize> Gate for MulExtensionGate { } fn num_constraints(&self) -> usize { - D + 2 * D } } struct MulExtensionGenerator, const D: usize> { gate_index: usize, const_0: F, + const_1: F, } impl, const D: usize> SimpleGenerator for MulExtensionGenerator { fn dependencies(&self) -> Vec { - MulExtensionGate::::wires_multiplicand_0() - .chain(MulExtensionGate::::wires_multiplicand_1()) - .map(|i| { - Target::Wire(Wire { - gate: self.gate_index, - input: i, - }) - }) + ArithmeticExtensionGate::::wires_fixed_multiplicand() + .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) + .chain(ArithmeticExtensionGate::::wires_addend_0()) + .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_addend_1()) + .map(|i| Target::wire(self.gate_index, i)) .collect() } fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let multiplicand_0_target = ExtensionTarget::from_range( + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let fixed_multiplicand = + extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_0 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); + let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); + let multiplicand_1 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); + let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); + + let output_target_0 = ExtensionTarget::from_range( self.gate_index, - MulExtensionGate::::wires_multiplicand_0(), + ArithmeticExtensionGate::::wires_output_0(), ); - let multiplicand_0 = witness.get_extension_target(multiplicand_0_target); - - let multiplicand_1_target = ExtensionTarget::from_range( + let output_target_1 = ExtensionTarget::from_range( self.gate_index, - MulExtensionGate::::wires_multiplicand_1(), + ArithmeticExtensionGate::::wires_output_1(), ); - let multiplicand_1 = witness.get_extension_target(multiplicand_1_target); - let output_target = - ExtensionTarget::from_range(self.gate_index, MulExtensionGate::::wires_output()); - - let computed_output = - F::Extension::from_basefield(self.const_0) * multiplicand_0 * multiplicand_1; + let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() + + addend_0 * self.const_1.into(); + let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() + + addend_1 * self.const_1.into(); let mut pw = PartialWitness::new(); - pw.set_extension_target(output_target, computed_output); + pw.set_extension_target(output_target_0, computed_output_0); + pw.set_extension_target(output_target_1, computed_output_1); pw } } @@ -136,10 +189,10 @@ impl, const D: usize> SimpleGenerator for MulExtensionGenera mod tests { use crate::field::crandall_field::CrandallField; use crate::gates::gate_testing::test_low_degree; - use crate::gates::mul_extension::MulExtensionGate; + use crate::gates::mul_extension::ArithmeticExtensionGate; #[test] fn low_degree() { - test_low_degree(MulExtensionGate::<4>::new::()) + test_low_degree(ArithmeticExtensionGate::<4>::new::()) } } diff --git a/src/lib.rs b/src/lib.rs index 09f90f81..adfdf2cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(destructuring_assignment)] + pub mod circuit_builder; pub mod circuit_data; pub mod field; diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index c1005a67..bb9d724e 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -5,7 +5,7 @@ use crate::gates::gate::GateRef; use crate::proof::ProofTarget; const MIN_WIRES: usize = 120; // TODO: Double check. -const MIN_ROUTED_WIRES: usize = 8; // TODO: Double check. +const MIN_ROUTED_WIRES: usize = 28; // TODO: Double check. /// Recursively verifies an inner proof. pub fn add_recursive_verifier, const D: usize>(