diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 65c42246..1bf938a0 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -63,6 +63,10 @@ impl, const D: usize> CircuitBuilder { } } + pub fn num_gates(&self) -> usize { + self.gate_instances.len() + } + pub fn add_public_input(&mut self) -> Target { let index = self.public_input_index; self.public_input_index += 1; @@ -97,6 +101,11 @@ impl, const D: usize> CircuitBuilder { /// Adds a gate to the circuit, and returns its index. pub fn add_gate(&mut self, gate_type: GateRef, constants: Vec) -> usize { + assert_eq!( + gate_type.0.num_constants(), + constants.len(), + "Number of constants doesn't match." + ); // If we haven't seen a gate of this type before, check that it's compatible with our // circuit configuration, then register it. if !self.gates.contains(&gate_type) { diff --git a/src/circuit_data.rs b/src/circuit_data.rs index fa575dcc..6f352832 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -54,7 +54,7 @@ impl CircuitConfig { pub(crate) fn large_config() -> Self { Self { num_wires: 134, - num_routed_wires: 12, + num_routed_wires: 28, security_bits: 128, rate_bits: 3, num_challenges: 3, diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 0316b04f..9d60847e 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -5,7 +5,6 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; -use crate::gates::mul_extension::MulExtensionGate; use crate::target::Target; /// `Target`s representing an element of an extension field. @@ -110,160 +109,6 @@ impl, const D: usize> CircuitBuilder { self.constant_ext_algebra(ExtensionAlgebra::ZERO) } - pub fn add_extension( - &mut self, - mut a: ExtensionTarget, - b: ExtensionTarget, - ) -> ExtensionTarget { - for i in 0..D { - a.0[i] = self.add(a.0[i], b.0[i]); - } - a - } - - pub fn add_ext_algebra( - &mut self, - mut a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - for i in 0..D { - a.0[i] = self.add_extension(a.0[i], b.0[i]); - } - a - } - - pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut sum = self.zero_extension(); - for term in terms { - sum = self.add_extension(sum, *term); - } - sum - } - - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. - pub fn sub_extension( - &mut self, - mut a: ExtensionTarget, - b: ExtensionTarget, - ) -> ExtensionTarget { - for i in 0..D { - a.0[i] = self.sub(a.0[i], b.0[i]); - } - a - } - - pub fn sub_ext_algebra( - &mut self, - mut a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - for i in 0..D { - a.0[i] = self.sub_extension(a.0[i], b.0[i]); - } - a - } - - pub fn mul_extension_with_const( - &mut self, - const_0: F, - multiplicand_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - ) -> ExtensionTarget { - let gate = self.add_gate(MulExtensionGate::new(), vec![const_0]); - - let wire_multiplicand_0 = - ExtensionTarget::from_range(gate, MulExtensionGate::::wires_multiplicand_0()); - let wire_multiplicand_1 = - ExtensionTarget::from_range(gate, MulExtensionGate::::wires_multiplicand_1()); - let wire_output = ExtensionTarget::from_range(gate, MulExtensionGate::::wires_output()); - - self.route_extension(multiplicand_0, wire_multiplicand_0); - self.route_extension(multiplicand_1, wire_multiplicand_1); - wire_output - } - - pub fn mul_extension( - &mut self, - multiplicand_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - ) -> ExtensionTarget { - self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) - } - - pub fn mul_ext_algebra( - &mut self, - a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - let mut res = [self.zero_extension(); D]; - let w = self.constant(F::Extension::W); - for i in 0..D { - for j in 0..D { - let ai_bi = self.mul_extension(a.0[i], b.0[j]); - res[(i + j) % D] = if i + j < D { - self.add_extension(ai_bi, res[(i + j) % D]) - } else { - let w_ai_bi = self.scalar_mul_ext(w, ai_bi); - self.add_extension(w_ai_bi, res[(i + j) % D]) - } - } - } - ExtensionAlgebraTarget(res) - } - - pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut product = self.one_extension(); - for term in terms { - product = self.mul_extension(product, *term); - } - product - } - - /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no - /// performance benefit over separate muls and adds. - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. - pub fn mul_add_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - let product = self.mul_extension(a, b); - self.add_extension(product, c) - } - - /// Like `mul_sub`, but for `ExtensionTarget`s. Note that, unlike `mul_sub`, this has no - /// performance benefit over separate muls and subs. - /// TODO: Change this to using an `arithmetic_extension` function once `MulExtensionGate` supports addend. - pub fn scalar_mul_sub_extension( - &mut self, - a: Target, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - let product = self.scalar_mul_ext(a, b); - self.sub_extension(product, c) - } - - /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. - pub fn scalar_mul_ext(&mut self, a: Target, b: ExtensionTarget) -> ExtensionTarget { - let a_ext = self.convert_to_ext(a); - self.mul_extension(a_ext, b) - } - - /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the - /// extension field. - pub fn scalar_mul_ext_algebra( - &mut self, - a: ExtensionTarget, - mut b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - for i in 0..D { - b.0[i] = self.mul_extension(a, b.0[i]); - } - b - } - pub fn convert_to_ext(&mut self, t: Target) -> ExtensionTarget { let zero = self.zero(); let mut arr = [zero; D]; diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 2f8b1559..bda70624 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,15 +1,7 @@ -use std::convert::TryInto; - use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field::Field; -use crate::gates::arithmetic::ArithmeticGate; -use crate::gates::mul_extension::MulExtensionGate; -use crate::generator::SimpleGenerator; +use crate::field::extension_field::Extendable; use crate::target::Target; -use crate::wire::Wire; -use crate::witness::PartialWitness; +use crate::util::bits_u64; impl, const D: usize> CircuitBuilder { /// Computes `-x`. @@ -43,30 +35,18 @@ impl, const D: usize> CircuitBuilder { { return result; } + let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); + let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); + let addend_ext = self.convert_to_ext(addend); - let gate = self.add_gate(ArithmeticGate::new(), vec![const_0, const_1]); - - let wire_multiplicand_0 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let wire_multiplicand_1 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let wire_addend = Wire { - gate, - input: ArithmeticGate::WIRE_ADDEND, - }; - let wire_output = Wire { - gate, - input: ArithmeticGate::WIRE_OUTPUT, - }; - - self.route(multiplicand_0, Target::Wire(wire_multiplicand_0)); - self.route(multiplicand_1, Target::Wire(wire_multiplicand_1)); - self.route(addend, Target::Wire(wire_addend)); - Target::Wire(wire_output) + self.arithmetic_extension( + const_0, + const_1, + multiplicand_0_ext, + multiplicand_1_ext, + addend_ext, + ) + .0[0] } /// Checks for special cases where the value of @@ -144,6 +124,7 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, one, F::ONE, y) } + // TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`. pub fn add_many(&mut self, terms: &[Target]) -> Target { let mut sum = self.zero(); for term in terms { @@ -174,21 +155,31 @@ impl, const D: usize> CircuitBuilder { } // TODO: Optimize this, maybe with a new gate. + // TODO: Test /// Exponentiate `base` to the power of `exponent`, where `exponent < 2^num_bits`. pub fn exp(&mut self, base: Target, exponent: Target, num_bits: usize) -> Target { let mut current = base; - let one = self.one(); - let mut product = one; + let one_ext = self.one_extension(); + let mut product = self.one(); let exponent_bits = self.split_le(exponent, num_bits); for bit in exponent_bits.into_iter() { - product = self.mul_many(&[bit, current, product]); + let current_ext = self.convert_to_ext(current); + let multiplicand = self.select(bit, current_ext, one_ext); + product = self.mul(product, multiplicand.0[0]); current = self.mul(current, current); } product } + /// Exponentiate `base` to the power of a known `exponent`. + // TODO: Test + pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target { + let base_ext = self.convert_to_ext(base); + self.exp_u64_extension(base_ext, exponent).0[0] + } + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// some cases, as it allows `0 / 0 = `. pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { @@ -207,224 +198,8 @@ impl, const D: usize> CircuitBuilder { return self.constant(x_const / y_const); } - // Add an `ArithmeticGate` to compute `q * y`. - let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ZERO]); - - let wire_multiplicand_0 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let wire_multiplicand_1 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let wire_addend = Wire { - gate, - input: ArithmeticGate::WIRE_ADDEND, - }; - let wire_output = Wire { - gate, - input: ArithmeticGate::WIRE_OUTPUT, - }; - - let q = Target::Wire(wire_multiplicand_0); - self.add_generator(QuotientGenerator { - numerator: x, - denominator: y, - quotient: q, - }); - - self.route(y, Target::Wire(wire_multiplicand_1)); - - // This can be anything, since the whole second term has a weight of zero. - self.route(zero, Target::Wire(wire_addend)); - - let q_y = Target::Wire(wire_output); - self.assert_equal(q_y, x); - - q - } - - /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in - /// some cases, as it allows `0 / 0 = `. - pub fn div_unsafe_extension( - &mut self, - x: ExtensionTarget, - y: ExtensionTarget, - ) -> ExtensionTarget { - // Add an `ArithmeticGate` to compute `q * y`. - let gate = self.add_gate(MulExtensionGate::new(), vec![F::ONE]); - - let multiplicand_0 = - Target::wires_from_range(gate, MulExtensionGate::::wires_multiplicand_0()); - let multiplicand_0 = ExtensionTarget(multiplicand_0.try_into().unwrap()); - let multiplicand_1 = - Target::wires_from_range(gate, MulExtensionGate::::wires_multiplicand_1()); - let multiplicand_1 = ExtensionTarget(multiplicand_1.try_into().unwrap()); - let output = Target::wires_from_range(gate, MulExtensionGate::::wires_output()); - let output = ExtensionTarget(output.try_into().unwrap()); - - self.add_generator(QuotientGeneratorExtension { - numerator: x, - denominator: y, - quotient: multiplicand_0, - }); - - self.route_extension(y, multiplicand_1); - - self.assert_equal_extension(output, x); - - multiplicand_0 - } -} - -struct QuotientGenerator { - numerator: Target, - denominator: Target, - quotient: Target, -} - -impl SimpleGenerator for QuotientGenerator { - fn dependencies(&self) -> Vec { - vec![self.numerator, self.denominator] - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let num = witness.get_target(self.numerator); - let den = witness.get_target(self.denominator); - PartialWitness::singleton_target(self.quotient, num / den) - } -} - -struct QuotientGeneratorExtension { - numerator: ExtensionTarget, - denominator: ExtensionTarget, - quotient: ExtensionTarget, -} - -impl, const D: usize> SimpleGenerator for QuotientGeneratorExtension { - fn dependencies(&self) -> Vec { - let mut deps = self.numerator.to_target_array().to_vec(); - deps.extend(&self.denominator.to_target_array()); - deps - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let num = witness.get_extension_target(self.numerator); - let dem = witness.get_extension_target(self.denominator); - let quotient = num / dem; - let mut pw = PartialWitness::new(); - for i in 0..D { - pw.set_target( - self.quotient.to_target_array()[i], - quotient.to_basefield_array()[i], - ); - } - pw - } -} - -/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. -#[derive(Clone)] -pub struct PowersTarget { - base: ExtensionTarget, - current: ExtensionTarget, -} - -impl PowersTarget { - pub fn next>( - &mut self, - builder: &mut CircuitBuilder, - ) -> ExtensionTarget { - let result = self.current; - self.current = builder.mul_extension(self.base, self.current); - result - } - - pub fn repeated_frobenius>( - self, - k: usize, - builder: &mut CircuitBuilder, - ) -> Self { - let Self { base, current } = self; - Self { - base: base.repeated_frobenius(k, builder), - current: current.repeated_frobenius(k, builder), - } - } -} - -impl, const D: usize> CircuitBuilder { - pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { - PowersTarget { - base, - current: self.one_extension(), - } - } -} - -#[cfg(test)] -mod tests { - use crate::circuit_builder::CircuitBuilder; - use crate::circuit_data::CircuitConfig; - use crate::field::crandall_field::CrandallField; - use crate::field::extension_field::quartic::QuarticCrandallField; - use crate::field::field::Field; - use crate::fri::FriConfig; - use crate::gates::arithmetic::ArithmeticGate; - use crate::target::Target; - use crate::witness::PartialWitness; - - #[test] - fn test_div() { - type F = CrandallField; - type FF = QuarticCrandallField; - const D: usize = 4; - - let config = CircuitConfig::large_config(); - - let mut builder = CircuitBuilder::::new(config); - - let x = F::rand(); - let y = F::rand(); - let mut pw = PartialWitness::new(); - /// Computes x*x + 0*y = x^2. - let square_gate = builder.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ZERO]); - pw.set_target(Target::wire(square_gate, 0), x); - pw.set_target(Target::wire(square_gate, 1), x); - let x2t = Target::wire(square_gate, ArithmeticGate::WIRE_OUTPUT); - let yt = Target::wire(square_gate, ArithmeticGate::WIRE_ADDEND); - pw.set_target(yt, y); - // Constant for x*x/y. - let zt = builder.constant(x * x / y); - // Computed division for x*x/y using the division gadget. - let comp_zt = builder.div_unsafe(x2t, yt); - builder.assert_equal(zt, comp_zt); - - let data = builder.build(); - let proof = data.prove(pw); - } - - #[test] - fn test_div_extension() { - type F = CrandallField; - type FF = QuarticCrandallField; - const D: usize = 4; - - let config = CircuitConfig::large_config(); - - let mut builder = CircuitBuilder::::new(config); - - let x = FF::rand(); - let y = FF::rand(); - let z = x / y; - let xt = builder.constant_extension(x); - let yt = builder.constant_extension(y); - let zt = builder.constant_extension(z); - let comp_zt = builder.div_unsafe_extension(xt, yt); - builder.assert_equal_extension(zt, comp_zt); - - let data = builder.build(); - let proof = data.prove(PartialWitness::new()); + let x_ext = self.convert_to_ext(x); + let y_ext = self.convert_to_ext(y); + self.div_unsafe_extension(x_ext, y_ext).0[0] } } diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs new file mode 100644 index 00000000..ce53c59f --- /dev/null +++ b/src/gadgets/arithmetic_extension.rs @@ -0,0 +1,465 @@ +use std::convert::TryInto; +use std::ops::Range; + +use itertools::Itertools; +use num::Integer; + +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; +use crate::field::extension_field::{Extendable, FieldExtension, OEF}; +use crate::field::field::Field; +use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::generator::SimpleGenerator; +use crate::target::Target; +use crate::util::bits_u64; +use crate::witness::PartialWitness; + +impl, const D: usize> CircuitBuilder { + pub fn double_arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + fixed_multiplicand: ExtensionTarget, + multiplicand_0: ExtensionTarget, + addend_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend_1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); + + let wire_fixed_multiplicand = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); + let wire_multiplicand_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let wire_addend_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); + let wire_multiplicand_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); + let wire_addend_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); + let wire_output_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + let wire_output_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); + + self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); + self.route_extension(multiplicand_0, wire_multiplicand_0); + self.route_extension(addend_0, wire_addend_0); + self.route_extension(multiplicand_1, wire_multiplicand_1); + self.route_extension(addend_1, wire_addend_1); + (wire_output_0, wire_output_1) + } + + pub fn arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + const_1, + multiplicand_0, + multiplicand_1, + addend, + zero, + zero, + ) + .0 + } + + pub fn add_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::ONE, one, a, b) + } + + pub fn add_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) + } + + pub fn add_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + // We run two additions in parallel. So `[a0,a1,a2,a3] + [b0,b1,b2,b3]` is computed with two + // `add_two_extension`, first `[a0,a1]+[b0,b1]` then `[a2,a3]+[b2,b3]`. + let mut res = Vec::with_capacity(D); + // We need some extra logic if D is odd. + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); + } + if D.is_odd() { + res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) + } + + pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let zero = self.zero_extension(); + let mut terms = terms.to_vec(); + if terms.len().is_odd() { + terms.push(zero); + } + // We maintain two accumulators, one for the sum of even elements, and one for odd elements. + let mut acc0 = zero; + let mut acc1 = zero; + for chunk in terms.chunks_exact(2) { + (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); + } + // We sum both accumulators to get the final result. + self.add_extension(acc0, acc1) + } + + pub fn sub_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) + } + + pub fn sub_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) + } + + pub fn sub_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + // See `add_ext_algebra`. + let mut res = Vec::with_capacity(D); + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); + } + if D.is_odd() { + res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) + } + + pub fn mul_extension_with_const( + &mut self, + const_0: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + F::ZERO, + multiplicand_0, + multiplicand_1, + zero, + zero, + zero, + ) + .0 + } + + pub fn mul_extension( + &mut self, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + ) -> ExtensionTarget { + self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) + } + + /// Computes `x^2`. + pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { + self.mul_extension(x, x) + } + + pub fn mul_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let mut res = [self.zero_extension(); D]; + let w = self.constant(F::Extension::W); + for i in 0..D { + for j in 0..D { + res[(i + j) % D] = if i + j < D { + self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) + } else { + let ai_bi = self.mul_extension(a.0[i], b.0[j]); + self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) + } + } + } + ExtensionAlgebraTarget(res) + } + + pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let mut product = self.one_extension(); + for term in terms { + product = self.mul_extension(product, *term); + } + product + } + + /// Like `mul_add`, but for `ExtensionTarget`s. + pub fn mul_add_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + self.arithmetic_extension(F::ONE, F::ONE, a, b, c) + } + + /// Like `mul_add`, but for `ExtensionTarget`s. + pub fn scalar_mul_add_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::ONE, a_ext, b, c) + } + + /// Like `mul_sub`, but for `ExtensionTarget`s. + pub fn scalar_mul_sub_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::NEG_ONE, a_ext, b, c) + } + + /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. + pub fn scalar_mul_ext(&mut self, a: Target, b: ExtensionTarget) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.mul_extension(a_ext, b) + } + + /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the + /// extension field. + pub fn scalar_mul_ext_algebra( + &mut self, + a: ExtensionTarget, + mut b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + for i in 0..D { + b.0[i] = self.mul_extension(a, b.0[i]); + } + b + } + + /// Exponentiate `base` to the power of a known `exponent`. + // TODO: Test + pub fn exp_u64_extension( + &mut self, + base: ExtensionTarget, + exponent: u64, + ) -> ExtensionTarget { + let mut current = base; + let mut product = self.one_extension(); + + for j in 0..bits_u64(exponent as u64) { + if (exponent >> j & 1) != 0 { + product = self.mul_extension(product, current); + } + current = self.square_extension(current); + } + product + } + + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in + /// some cases, as it allows `0 / 0 = `. + pub fn div_unsafe_extension( + &mut self, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + // Add an `ArithmeticExtensionGate` to compute `q * y`. + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); + + let multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); + let multiplicand_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + + self.add_generator(QuotientGeneratorExtension { + numerator: x, + denominator: y, + quotient: multiplicand_0, + }); + // We need to zero out the other wires for the `ArithmeticExtensionGenerator` to hit. + self.add_generator(ZeroOutGenerator { + gate_index: gate, + ranges: vec![ + ArithmeticExtensionGate::::wires_addend_0(), + ArithmeticExtensionGate::::wires_multiplicand_1(), + ArithmeticExtensionGate::::wires_addend_1(), + ], + }); + + self.route_extension(y, multiplicand_1); + self.assert_equal_extension(output, x); + + multiplicand_0 + } +} + +/// Generator used to zero out wires at a given gate index and ranges. +pub struct ZeroOutGenerator { + gate_index: usize, + ranges: Vec>, +} + +impl SimpleGenerator for ZeroOutGenerator { + fn dependencies(&self) -> Vec { + Vec::new() + } + + fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + let mut pw = PartialWitness::new(); + for t in self + .ranges + .iter() + .flat_map(|r| Target::wires_from_range(self.gate_index, r.clone())) + { + pw.set_target(t, F::ZERO); + } + + pw + } +} + +struct QuotientGeneratorExtension { + numerator: ExtensionTarget, + denominator: ExtensionTarget, + quotient: ExtensionTarget, +} + +impl, const D: usize> SimpleGenerator for QuotientGeneratorExtension { + fn dependencies(&self) -> Vec { + let mut deps = self.numerator.to_target_array().to_vec(); + deps.extend(&self.denominator.to_target_array()); + deps + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let num = witness.get_extension_target(self.numerator); + let dem = witness.get_extension_target(self.denominator); + let quotient = num / dem; + let mut pw = PartialWitness::new(); + pw.set_extension_target(self.quotient, quotient); + + pw + } +} + +/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. +#[derive(Clone)] +pub struct PowersTarget { + base: ExtensionTarget, + current: ExtensionTarget, +} + +impl PowersTarget { + pub fn next>( + &mut self, + builder: &mut CircuitBuilder, + ) -> ExtensionTarget { + let result = self.current; + self.current = builder.mul_extension(self.base, self.current); + result + } + + pub fn repeated_frobenius>( + self, + k: usize, + builder: &mut CircuitBuilder, + ) -> Self { + let Self { base, current } = self; + Self { + base: base.repeated_frobenius(k, builder), + current: current.repeated_frobenius(k, builder), + } + } +} + +impl, const D: usize> CircuitBuilder { + pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { + PowersTarget { + base, + current: self.one_extension(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::circuit_builder::CircuitBuilder; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::field::Field; + use crate::fri::FriConfig; + use crate::witness::PartialWitness; + + #[test] + fn test_div_extension() { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let y = FF::rand(); + let z = x / y; + let xt = builder.constant_extension(x); + let yt = builder.constant_extension(y); + let zt = builder.constant_extension(z); + let comp_zt = builder.div_unsafe_extension(xt, yt); + builder.assert_equal_extension(zt, comp_zt); + + let data = builder.build(); + let proof = data.prove(PartialWitness::new()); + } +} diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 40b91c1e..37746685 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -72,13 +72,10 @@ mod tests { fn test_interpolate() { type F = CrandallField; type FF = QuarticCrandallField; - let config = CircuitConfig { - num_routed_wires: 18, - ..CircuitConfig::large_config() - }; + let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); - let len = 2; + let len = 4; let points = (0..len) .map(|_| (F::rand(), FF::rand())) .collect::>(); diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index a1e041fc..2f216870 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,4 +1,5 @@ pub mod arithmetic; +pub mod arithmetic_extension; pub mod hash; pub mod insert; pub mod interpolation; diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 0d0fdd7c..39baa226 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -1,34 +1,47 @@ +use std::ops::Range; + use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; -use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; -use crate::wire::Wire; use crate::witness::PartialWitness; -/// A gate which can be configured to perform various arithmetic. In particular, it computes -/// -/// ```text -/// output := const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend -/// ``` +/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. #[derive(Debug)] -pub struct ArithmeticGate; +pub struct ArithmeticExtensionGate; -impl ArithmeticGate { - pub fn new, const D: usize>() -> GateRef { - GateRef::new(ArithmeticGate) +impl ArithmeticExtensionGate { + pub fn new>() -> GateRef { + GateRef::new(ArithmeticExtensionGate) } - pub const WIRE_MULTIPLICAND_0: usize = 0; - pub const WIRE_MULTIPLICAND_1: usize = 1; - pub const WIRE_ADDEND: usize = 2; - pub const WIRE_OUTPUT: usize = 3; + pub fn wires_fixed_multiplicand() -> Range { + 0..D + } + pub fn wires_multiplicand_0() -> Range { + D..2 * D + } + pub fn wires_addend_0() -> Range { + 2 * D..3 * D + } + pub fn wires_multiplicand_1() -> Range { + 3 * D..4 * D + } + pub fn wires_addend_1() -> Range { + 4 * D..5 * D + } + pub fn wires_output_0() -> Range { + 5 * D..6 * D + } + pub fn wires_output_1() -> Range { + 6 * D..7 * D + } } -impl, const D: usize> Gate for ArithmeticGate { +impl, const D: usize> Gate for ArithmeticExtensionGate { fn id(&self) -> String { format!("{:?}", self) } @@ -36,12 +49,23 @@ impl, const D: usize> Gate for ArithmeticGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; - let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; - let addend = vars.local_wires[Self::WIRE_ADDEND]; - let output = vars.local_wires[Self::WIRE_OUTPUT]; - let computed_output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend; - vec![computed_output - output] + + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = + fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); + let computed_output_1 = + fixed_multiplicand * multiplicand_1 * const_0.into() + addend_1 * const_1.into(); + + let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); + constraints.extend((output_1 - computed_output_1).to_basefield_array()); + constraints } fn eval_unfiltered_recursively( @@ -51,15 +75,30 @@ impl, const D: usize> Gate for ArithmeticGate { ) -> Vec> { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; - let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; - let addend = vars.local_wires[Self::WIRE_ADDEND]; - let output = vars.local_wires[Self::WIRE_OUTPUT]; - let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); - let addend_term = builder.mul_extension(const_1, addend); - let computed_output = builder.add_many_extension(&[product_term, addend_term]); - vec![builder.sub_extension(computed_output, output)] + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); + let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); + let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); + let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); + + let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); + let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); + let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); + let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); + + let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); + let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); + let mut constraints = diff_0.to_ext_target_array().to_vec(); + constraints.extend(diff_1.to_ext_target_array()); + constraints } fn generators( @@ -67,16 +106,21 @@ impl, const D: usize> Gate for ArithmeticGate { gate_index: usize, local_constants: &[F], ) -> Vec>> { - let gen = ArithmeticGenerator { + let gen0 = ArithmeticExtensionGenerator0 { gate_index, const_0: local_constants[0], const_1: local_constants[1], }; - vec![Box::new(gen)] + let gen1 = ArithmeticExtensionGenerator1 { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + }; + vec![Box::new(gen0), Box::new(gen1)] } fn num_wires(&self) -> usize { - 4 + 7 * D } fn num_constants(&self) -> usize { @@ -88,70 +132,96 @@ impl, const D: usize> Gate for ArithmeticGate { } fn num_constraints(&self) -> usize { - 1 + 2 * D } } -struct ArithmeticGenerator { +struct ArithmeticExtensionGenerator0, const D: usize> { gate_index: usize, const_0: F, const_1: F, } -impl SimpleGenerator for ArithmeticGenerator { +struct ArithmeticExtensionGenerator1, const D: usize> { + gate_index: usize, + const_0: F, + const_1: F, +} + +impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator0 { fn dependencies(&self) -> Vec { - vec![ - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }), - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }), - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_ADDEND, - }), - ] + ArithmeticExtensionGate::::wires_fixed_multiplicand() + .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) + .chain(ArithmeticExtensionGate::::wires_addend_0()) + .map(|i| Target::wire(self.gate_index, i)) + .collect() } fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let multiplicand_0_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let multiplicand_1_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let addend_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_ADDEND, - }; - let output_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_OUTPUT, + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) }; - let multiplicand_0 = witness.get_wire(multiplicand_0_target); - let multiplicand_1 = witness.get_wire(multiplicand_1_target); - let addend = witness.get_wire(addend_target); + let fixed_multiplicand = + extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_0 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); + let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); - let output = self.const_0 * multiplicand_0 * multiplicand_1 + self.const_1 * addend; + let output_target_0 = ExtensionTarget::from_range( + self.gate_index, + ArithmeticExtensionGate::::wires_output_0(), + ); - PartialWitness::singleton_wire(output_target, output) + let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() + + addend_0 * self.const_1.into(); + + PartialWitness::singleton_extension_target(output_target_0, computed_output_0) + } +} + +impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator1 { + fn dependencies(&self) -> Vec { + ArithmeticExtensionGate::::wires_fixed_multiplicand() + .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_addend_1()) + .map(|i| Target::wire(self.gate_index, i)) + .collect() + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) + }; + + let fixed_multiplicand = + extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_1 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); + let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); + + let output_target_1 = ExtensionTarget::from_range( + self.gate_index, + ArithmeticExtensionGate::::wires_output_1(), + ); + + let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() + + addend_1 * self.const_1.into(); + + PartialWitness::singleton_extension_target(output_target_1, computed_output_1) } } #[cfg(test)] mod tests { use crate::field::crandall_field::CrandallField; - use crate::gates::arithmetic::ArithmeticGate; + use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::gate_testing::test_low_degree; #[test] fn low_degree() { - test_low_degree(ArithmeticGate::new::()) + test_low_degree(ArithmeticExtensionGate::<4>::new::()) } } diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index d1b58aa3..bf56a690 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -222,12 +222,11 @@ impl, const D: usize> Tree> { mod tests { use super::*; use crate::field::crandall_field::CrandallField; - use crate::gates::arithmetic::ArithmeticGate; + use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::gmimc::GMiMCGate; use crate::gates::interpolation::InterpolationGate; - use crate::gates::mul_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::hash::GMIMC_ROUNDS; @@ -240,11 +239,10 @@ mod tests { let gates = vec![ NoopGate::get::(), ConstantGate::get(), - ArithmeticGate::new(), + ArithmeticExtensionGate::new(), BaseSumGate::<4>::new(4), GMiMCGate::::with_automatic_constants(), InterpolationGate::new(4), - MulExtensionGate::new(), ]; let len = gates.len(); diff --git a/src/gates/mod.rs b/src/gates/mod.rs index bb8b178b..fa23b273 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -1,11 +1,10 @@ -pub(crate) mod arithmetic; +pub mod arithmetic; pub mod base_sum; pub mod constant; pub(crate) mod gate; pub mod gate_tree; pub mod gmimc; pub mod interpolation; -pub mod mul_extension; pub(crate) mod noop; #[cfg(test)] diff --git a/src/gates/mul_extension.rs b/src/gates/mul_extension.rs deleted file mode 100644 index e378e2b1..00000000 --- a/src/gates/mul_extension.rs +++ /dev/null @@ -1,145 +0,0 @@ -use std::ops::Range; - -use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; -use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; -use crate::wire::Wire; -use crate::witness::PartialWitness; - -/// A gate which can multiply two field extension elements. -/// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough. -#[derive(Debug)] -pub struct MulExtensionGate; - -impl MulExtensionGate { - pub fn new>() -> GateRef { - GateRef::new(MulExtensionGate) - } - - pub fn wires_multiplicand_0() -> Range { - 0..D - } - pub fn wires_multiplicand_1() -> Range { - D..2 * D - } - pub fn wires_output() -> Range { - 2 * D..3 * D - } -} - -impl, const D: usize> Gate for MulExtensionGate { - fn id(&self) -> String { - format!("{:?}", self) - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let const_0 = vars.local_constants[0]; - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let output = vars.get_local_ext_algebra(Self::wires_output()); - let computed_output = multiplicand_0 * multiplicand_1 * const_0.into(); - (output - computed_output).to_basefield_array().to_vec() - } - - fn eval_unfiltered_recursively( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let const_0 = vars.local_constants[0]; - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let output = vars.get_local_ext_algebra(Self::wires_output()); - let computed_output = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); - let computed_output = builder.scalar_mul_ext_algebra(const_0, computed_output); - let diff = builder.sub_ext_algebra(output, computed_output); - diff.to_ext_target_array().to_vec() - } - - fn generators( - &self, - gate_index: usize, - local_constants: &[F], - ) -> Vec>> { - let gen = MulExtensionGenerator { - gate_index, - const_0: local_constants[0], - }; - vec![Box::new(gen)] - } - - fn num_wires(&self) -> usize { - 12 - } - - fn num_constants(&self) -> usize { - 1 - } - - fn degree(&self) -> usize { - 3 - } - - fn num_constraints(&self) -> usize { - D - } -} - -struct MulExtensionGenerator, const D: usize> { - gate_index: usize, - const_0: F, -} - -impl, const D: usize> SimpleGenerator for MulExtensionGenerator { - fn dependencies(&self) -> Vec { - MulExtensionGate::::wires_multiplicand_0() - .chain(MulExtensionGate::::wires_multiplicand_1()) - .map(|i| { - Target::Wire(Wire { - gate: self.gate_index, - input: i, - }) - }) - .collect() - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let multiplicand_0_target = ExtensionTarget::from_range( - self.gate_index, - MulExtensionGate::::wires_multiplicand_0(), - ); - let multiplicand_0 = witness.get_extension_target(multiplicand_0_target); - - let multiplicand_1_target = ExtensionTarget::from_range( - self.gate_index, - MulExtensionGate::::wires_multiplicand_1(), - ); - let multiplicand_1 = witness.get_extension_target(multiplicand_1_target); - - let output_target = - ExtensionTarget::from_range(self.gate_index, MulExtensionGate::::wires_output()); - - let computed_output = - F::Extension::from_basefield(self.const_0) * multiplicand_0 * multiplicand_1; - - let mut pw = PartialWitness::new(); - pw.set_extension_target(output_target, computed_output); - pw - } -} - -#[cfg(test)] -mod tests { - use crate::field::crandall_field::CrandallField; - use crate::gates::gate_testing::test_low_degree; - use crate::gates::mul_extension::MulExtensionGate; - - #[test] - fn low_degree() { - test_low_degree(MulExtensionGate::<4>::new::()) - } -} diff --git a/src/lib.rs b/src/lib.rs index 09f90f81..adfdf2cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(destructuring_assignment)] + pub mod circuit_builder; pub mod circuit_data; pub mod field; diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index c1005a67..bb9d724e 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -5,7 +5,7 @@ use crate::gates::gate::GateRef; use crate::proof::ProofTarget; const MIN_WIRES: usize = 120; // TODO: Double check. -const MIN_ROUTED_WIRES: usize = 8; // TODO: Double check. +const MIN_ROUTED_WIRES: usize = 28; // TODO: Double check. /// Recursively verifies an inner proof. pub fn add_recursive_verifier, const D: usize>( diff --git a/src/util/scaling.rs b/src/util/scaling.rs index cea86195..3449e0b9 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -1,7 +1,12 @@ use std::borrow::Borrow; -use crate::field::extension_field::Frobenius; +use num::Integer; + +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, Frobenius}; use crate::field::field::Field; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::polynomial::polynomial::PolynomialCoeffs; /// When verifying the composition polynomial in FRI we have to compute sums of the form @@ -73,3 +78,142 @@ impl ReducingFactor { } } } + +#[derive(Debug, Copy, Clone)] +pub struct ReducingFactorTarget { + base: ExtensionTarget, + count: u64, +} + +impl ReducingFactorTarget { + pub fn new(base: ExtensionTarget) -> Self { + Self { base, count: 0 } + } + + /// Reduces a length `n` vector of `ExtensionTarget`s using `n/2` `ArithmeticExtensionGate`s. + /// It does this by batching two steps of Horner's method in each gate. + /// Here's an example with `n=4, alpha=2, D=1`: + /// 1st gate: 2 0 4 4 3 4 11 <- 2*0+4=4, 2*4+3=11 + /// 2nd gate: 2 11 2 24 1 24 49 <- 2*11+2=24, 2*24+1=49 + /// which verifies that `2.reduce([1,2,3,4]) = 49`. + pub fn reduce( + &mut self, + terms: &[ExtensionTarget], // Could probably work with a `DoubleEndedIterator` too. + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: Extendable, + { + let zero = builder.zero_extension(); + let l = terms.len(); + self.count += l as u64; + + let mut terms_vec = terms.to_vec(); + // If needed, we pad the original vector so that it has even length. + if terms_vec.len().is_odd() { + terms_vec.push(zero); + } + terms_vec.reverse(); + + let mut acc = zero; + for pair in terms_vec.chunks(2) { + // We will route the output of the first arithmetic operation to the multiplicand of the + // second, i.e. we compute the following: + // out_0 = alpha acc + pair[0] + // acc' = out_1 = alpha out_0 + pair[1] + let gate = builder.num_gates(); + let out_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + acc = builder + .double_arithmetic_extension( + F::ONE, + F::ONE, + self.base, + acc, + pair[0], + out_0, + pair[1], + ) + .1; + } + acc + } + + pub fn shift( + &mut self, + x: ExtensionTarget, + builder: &mut CircuitBuilder, + ) -> ExtensionTarget + where + F: Extendable, + { + let exp = builder.exp_u64_extension(self.base, self.count); + let tmp = builder.mul_extension(exp, x); + self.count = 0; + tmp + } + + pub fn reset(&mut self) { + self.count = 0; + } + + pub fn repeated_frobenius(&self, count: usize, builder: &mut CircuitBuilder) -> Self + where + F: Extendable, + { + Self { + base: self.base.repeated_frobenius(count, builder), + count: self.count, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::witness::PartialWitness; + + fn test_reduce_gadget(n: usize) { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + + let alpha = FF::rand(); + let alpha = FF::ONE; + let vs = (0..n) + .map(|i| FF::from_canonical_usize(i)) + .collect::>(); + + let manual_reduce = ReducingFactor::new(alpha).reduce(vs.iter()); + let manual_reduce = builder.constant_extension(manual_reduce); + + let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); + let vs_t = vs + .iter() + .map(|&v| builder.constant_extension(v)) + .collect::>(); + let circuit_reduce = alpha_t.reduce(&vs_t, &mut builder); + + builder.assert_equal_extension(manual_reduce, circuit_reduce); + + let data = builder.build(); + let proof = data.prove(PartialWitness::new()); + } + + #[test] + fn test_reduce_gadget_even() { + test_reduce_gadget(10); + } + + #[test] + fn test_reduce_gadget_odd() { + test_reduce_gadget(11); + } +} diff --git a/src/witness.rs b/src/witness.rs index 7294f6e4..681049d0 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -78,6 +78,18 @@ impl PartialWitness { witness } + pub fn singleton_extension_target( + et: ExtensionTarget, + value: F::Extension, + ) -> Self + where + F: Extendable, + { + let mut witness = PartialWitness::new(); + witness.set_extension_target(et, value); + witness + } + pub fn is_empty(&self) -> bool { self.target_values.is_empty() }