diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 6a5ab2be..e7dead21 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -112,263 +112,6 @@ impl, const D: usize> CircuitBuilder { self.constant_ext_algebra(ExtensionAlgebra::ZERO) } - pub fn double_arithmetic_extension( - &mut self, - const_0: F, - const_1: F, - fixed_multiplicand: ExtensionTarget, - multiplicand_0: ExtensionTarget, - addend_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - addend_1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); - - let wire_fixed_multiplicand = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_fixed_multiplicand(), - ); - let wire_multiplicand_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); - let wire_addend_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); - let wire_multiplicand_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); - let wire_addend_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); - let wire_output_0 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); - let wire_output_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); - - self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); - self.route_extension(multiplicand_0, wire_multiplicand_0); - self.route_extension(addend_0, wire_addend_0); - self.route_extension(multiplicand_1, wire_multiplicand_1); - self.route_extension(addend_1, wire_addend_1); - (wire_output_0, wire_output_1) - } - - pub fn arithmetic_extension( - &mut self, - const_0: F, - const_1: F, - multiplicand_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - addend: ExtensionTarget, - ) -> ExtensionTarget { - let zero = self.zero_extension(); - self.double_arithmetic_extension( - const_0, - const_1, - multiplicand_0, - multiplicand_1, - addend, - zero, - zero, - ) - .0 - } - - pub fn add_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - ) -> ExtensionTarget { - let one = self.one_extension(); - self.arithmetic_extension(F::ONE, F::ONE, one, a, b) - } - - pub fn add_two_extension( - &mut self, - a0: ExtensionTarget, - b0: ExtensionTarget, - a1: ExtensionTarget, - b1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) - } - - pub fn add_ext_algebra( - &mut self, - a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - let mut res = Vec::with_capacity(D); - let d_even = D & (D ^ 1); // = 2 * (D/2) - for mut chunk in &(0..d_even).chunks(2) { - let i = chunk.next().unwrap(); - let j = chunk.next().unwrap(); - let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); - res.extend([o0, o1]); - } - if D % 2 == 1 { - res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); - } - ExtensionAlgebraTarget(res.try_into().unwrap()) - } - - pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let zero = self.zero_extension(); - let mut terms = terms.to_vec(); - if terms.len() % 2 == 1 { - terms.push(zero); - } - let mut acc0 = zero; - let mut acc1 = zero; - for chunk in terms.chunks_exact(2) { - (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); - } - self.add_extension(acc0, acc1) - } - - pub fn sub_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - ) -> ExtensionTarget { - let one = self.one_extension(); - self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) - } - - pub fn sub_two_extension( - &mut self, - a0: ExtensionTarget, - b0: ExtensionTarget, - a1: ExtensionTarget, - b1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) - } - - pub fn sub_ext_algebra( - &mut self, - a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - let mut res = Vec::with_capacity(D); - let d_even = D & (D ^ 1); // = 2 * (D/2) - for mut chunk in &(0..d_even).chunks(2) { - let i = chunk.next().unwrap(); - let j = chunk.next().unwrap(); - let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); - res.extend([o0, o1]); - } - if D % 2 == 1 { - res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); - } - ExtensionAlgebraTarget(res.try_into().unwrap()) - } - - pub fn mul_extension_with_const( - &mut self, - const_0: F, - multiplicand_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - ) -> ExtensionTarget { - let zero = self.zero_extension(); - self.double_arithmetic_extension( - const_0, - F::ZERO, - multiplicand_0, - multiplicand_1, - zero, - zero, - zero, - ) - .0 - } - - pub fn mul_extension( - &mut self, - multiplicand_0: ExtensionTarget, - multiplicand_1: ExtensionTarget, - ) -> ExtensionTarget { - self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) - } - - pub fn mul_ext_algebra( - &mut self, - a: ExtensionAlgebraTarget, - b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - let mut res = [self.zero_extension(); D]; - let w = self.constant(F::Extension::W); - for i in 0..D { - for j in 0..D { - res[(i + j) % D] = if i + j < D { - self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) - } else { - let ai_bi = self.mul_extension(a.0[i], b.0[j]); - self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) - } - } - } - ExtensionAlgebraTarget(res) - } - - pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut product = self.one_extension(); - for term in terms { - product = self.mul_extension(product, *term); - } - product - } - - /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no - /// performance benefit over separate muls and adds. - pub fn mul_add_extension( - &mut self, - a: ExtensionTarget, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - self.arithmetic_extension(F::ONE, F::ONE, a, b, c) - } - - /// Like `mul_add`, but for `ExtensionTarget`s. - pub fn scalar_mul_add_extension( - &mut self, - a: Target, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - let a_ext = self.convert_to_ext(a); - self.arithmetic_extension(F::ONE, F::ONE, a_ext, b, c) - } - - /// Like `mul_sub`, but for `ExtensionTarget`s. - pub fn scalar_mul_sub_extension( - &mut self, - a: Target, - b: ExtensionTarget, - c: ExtensionTarget, - ) -> ExtensionTarget { - let a_ext = self.convert_to_ext(a); - self.arithmetic_extension(F::ONE, F::NEG_ONE, a_ext, b, c) - } - - /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. - pub fn scalar_mul_ext(&mut self, a: Target, b: ExtensionTarget) -> ExtensionTarget { - let a_ext = self.convert_to_ext(a); - self.mul_extension(a_ext, b) - } - - /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the - /// extension field. - pub fn scalar_mul_ext_algebra( - &mut self, - a: ExtensionTarget, - mut b: ExtensionAlgebraTarget, - ) -> ExtensionAlgebraTarget { - for i in 0..D { - b.0[i] = self.mul_extension(a, b.0[i]); - } - b - } - pub fn convert_to_ext(&mut self, t: Target) -> ExtensionTarget { let zero = self.zero(); let mut arr = [zero; D]; diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 99e260c7..b478f621 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -23,11 +23,6 @@ impl, const D: usize> CircuitBuilder { self.mul(x, x) } - /// Computes `x^2`. - pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { - self.mul_extension(x, x) - } - /// Computes `x^3`. pub fn cube(&mut self, x: Target) -> Target { self.mul_many(&[x, x, x]) @@ -137,6 +132,7 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, one, F::ONE, y) } + // TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`. pub fn add_many(&mut self, terms: &[Target]) -> Target { let mut sum = self.zero(); for term in terms { @@ -200,25 +196,6 @@ impl, const D: usize> CircuitBuilder { product } - /// Exponentiate `base` to the power of a known `exponent`. - // TODO: Test - pub fn exp_u64_extension( - &mut self, - base: ExtensionTarget, - exponent: u64, - ) -> ExtensionTarget { - let mut current = base; - let mut product = self.one_extension(); - - for j in 0..bits_u64(exponent as u64) { - if (exponent >> j & 1) != 0 { - product = self.mul_extension(product, current); - } - current = self.square_extension(current); - } - product - } - /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// some cases, as it allows `0 / 0 = `. pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { @@ -241,170 +218,4 @@ impl, const D: usize> CircuitBuilder { let y_ext = self.convert_to_ext(y); self.div_unsafe_extension(x_ext, y_ext).0[0] } - - /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in - /// some cases, as it allows `0 / 0 = `. - pub fn div_unsafe_extension( - &mut self, - x: ExtensionTarget, - y: ExtensionTarget, - ) -> ExtensionTarget { - // Add an `ArithmeticExtensionGate` to compute `q * y`. - let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); - - let multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_fixed_multiplicand(), - ); - let multiplicand_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); - let output = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); - - self.add_generator(QuotientGeneratorExtension { - numerator: x, - denominator: y, - quotient: multiplicand_0, - }); - self.add_generator(ZeroOutGenerator { - gate_index: gate, - ranges: vec![ - ArithmeticExtensionGate::::wires_addend_0(), - ArithmeticExtensionGate::::wires_multiplicand_1(), - ArithmeticExtensionGate::::wires_addend_1(), - ], - }); - - self.route_extension(y, multiplicand_1); - self.assert_equal_extension(output, x); - - multiplicand_0 - } -} - -struct QuotientGeneratorExtension { - numerator: ExtensionTarget, - denominator: ExtensionTarget, - quotient: ExtensionTarget, -} - -impl, const D: usize> SimpleGenerator for QuotientGeneratorExtension { - fn dependencies(&self) -> Vec { - let mut deps = self.numerator.to_target_array().to_vec(); - deps.extend(&self.denominator.to_target_array()); - deps - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let num = witness.get_extension_target(self.numerator); - let dem = witness.get_extension_target(self.denominator); - let quotient = num / dem; - let mut pw = PartialWitness::new(); - for i in 0..D { - pw.set_target( - self.quotient.to_target_array()[i], - quotient.to_basefield_array()[i], - ); - } - - pw - } -} - -/// Generator used to zero out wires at a given gate index and ranges. -pub struct ZeroOutGenerator { - gate_index: usize, - ranges: Vec>, -} - -impl SimpleGenerator for ZeroOutGenerator { - fn dependencies(&self) -> Vec { - Vec::new() - } - - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { - let mut pw = PartialWitness::new(); - for t in self - .ranges - .iter() - .flat_map(|r| Target::wires_from_range(self.gate_index, r.clone())) - { - pw.set_target(t, F::ZERO); - } - - pw - } -} - -/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. -#[derive(Clone)] -pub struct PowersTarget { - base: ExtensionTarget, - current: ExtensionTarget, -} - -impl PowersTarget { - pub fn next>( - &mut self, - builder: &mut CircuitBuilder, - ) -> ExtensionTarget { - let result = self.current; - self.current = builder.mul_extension(self.base, self.current); - result - } - - pub fn repeated_frobenius>( - self, - k: usize, - builder: &mut CircuitBuilder, - ) -> Self { - let Self { base, current } = self; - Self { - base: base.repeated_frobenius(k, builder), - current: current.repeated_frobenius(k, builder), - } - } -} - -impl, const D: usize> CircuitBuilder { - pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { - PowersTarget { - base, - current: self.one_extension(), - } - } -} - -#[cfg(test)] -mod tests { - use crate::circuit_builder::CircuitBuilder; - use crate::circuit_data::CircuitConfig; - use crate::field::crandall_field::CrandallField; - use crate::field::extension_field::quartic::QuarticCrandallField; - use crate::field::field::Field; - use crate::fri::FriConfig; - use crate::witness::PartialWitness; - - #[test] - fn test_div_extension() { - type F = CrandallField; - type FF = QuarticCrandallField; - const D: usize = 4; - - let config = CircuitConfig::large_config(); - - let mut builder = CircuitBuilder::::new(config); - - let x = FF::rand(); - let y = FF::rand(); - let z = x / y; - let xt = builder.constant_extension(x); - let yt = builder.constant_extension(y); - let zt = builder.constant_extension(z); - let comp_zt = builder.div_unsafe_extension(xt, yt); - builder.assert_equal_extension(zt, comp_zt); - - let data = builder.build(); - let proof = data.prove(PartialWitness::new()); - } } diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs new file mode 100644 index 00000000..d29e5079 --- /dev/null +++ b/src/gadgets/arithmetic_extension.rs @@ -0,0 +1,464 @@ +use std::convert::TryInto; +use std::ops::Range; + +use itertools::Itertools; + +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::algebra::ExtensionAlgebra; +use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; +use crate::field::extension_field::{Extendable, FieldExtension, OEF}; +use crate::field::field::Field; +use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::generator::SimpleGenerator; +use crate::target::Target; +use crate::util::bits_u64; +use crate::witness::PartialWitness; + +impl, const D: usize> CircuitBuilder { + pub fn double_arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + fixed_multiplicand: ExtensionTarget, + multiplicand_0: ExtensionTarget, + addend_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend_1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]); + + let wire_fixed_multiplicand = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); + let wire_multiplicand_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let wire_addend_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_0()); + let wire_multiplicand_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_1()); + let wire_addend_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_addend_1()); + let wire_output_0 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + let wire_output_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_1()); + + self.route_extension(fixed_multiplicand, wire_fixed_multiplicand); + self.route_extension(multiplicand_0, wire_multiplicand_0); + self.route_extension(addend_0, wire_addend_0); + self.route_extension(multiplicand_1, wire_multiplicand_1); + self.route_extension(addend_1, wire_addend_1); + (wire_output_0, wire_output_1) + } + + pub fn arithmetic_extension( + &mut self, + const_0: F, + const_1: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + addend: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + const_1, + multiplicand_0, + multiplicand_1, + addend, + zero, + zero, + ) + .0 + } + + pub fn add_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::ONE, one, a, b) + } + + pub fn add_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1) + } + + pub fn add_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let mut res = Vec::with_capacity(D); + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); + } + if D % 2 == 1 { + res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) + } + + pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let zero = self.zero_extension(); + let mut terms = terms.to_vec(); + if terms.len() % 2 == 1 { + terms.push(zero); + } + let mut acc0 = zero; + let mut acc1 = zero; + for chunk in terms.chunks_exact(2) { + (acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]); + } + self.add_extension(acc0, acc1) + } + + pub fn sub_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) + } + + pub fn sub_two_extension( + &mut self, + a0: ExtensionTarget, + b0: ExtensionTarget, + a1: ExtensionTarget, + b1: ExtensionTarget, + ) -> (ExtensionTarget, ExtensionTarget) { + let one = self.one_extension(); + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1) + } + + pub fn sub_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let mut res = Vec::with_capacity(D); + let d_even = D & (D ^ 1); // = 2 * (D/2) + for mut chunk in &(0..d_even).chunks(2) { + let i = chunk.next().unwrap(); + let j = chunk.next().unwrap(); + let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); + res.extend([o0, o1]); + } + if D % 2 == 1 { + res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); + } + ExtensionAlgebraTarget(res.try_into().unwrap()) + } + + pub fn mul_extension_with_const( + &mut self, + const_0: F, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + self.double_arithmetic_extension( + const_0, + F::ZERO, + multiplicand_0, + multiplicand_1, + zero, + zero, + zero, + ) + .0 + } + + pub fn mul_extension( + &mut self, + multiplicand_0: ExtensionTarget, + multiplicand_1: ExtensionTarget, + ) -> ExtensionTarget { + self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) + } + + /// Computes `x^2`. + pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { + self.mul_extension(x, x) + } + + pub fn mul_ext_algebra( + &mut self, + a: ExtensionAlgebraTarget, + b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + let mut res = [self.zero_extension(); D]; + let w = self.constant(F::Extension::W); + for i in 0..D { + for j in 0..D { + res[(i + j) % D] = if i + j < D { + self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) + } else { + let ai_bi = self.mul_extension(a.0[i], b.0[j]); + self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) + } + } + } + ExtensionAlgebraTarget(res) + } + + pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let mut product = self.one_extension(); + for term in terms { + product = self.mul_extension(product, *term); + } + product + } + + /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no + /// performance benefit over separate muls and adds. + pub fn mul_add_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + self.arithmetic_extension(F::ONE, F::ONE, a, b, c) + } + + /// Like `mul_add`, but for `ExtensionTarget`s. + pub fn scalar_mul_add_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::ONE, a_ext, b, c) + } + + /// Like `mul_sub`, but for `ExtensionTarget`s. + pub fn scalar_mul_sub_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.arithmetic_extension(F::ONE, F::NEG_ONE, a_ext, b, c) + } + + /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. + pub fn scalar_mul_ext(&mut self, a: Target, b: ExtensionTarget) -> ExtensionTarget { + let a_ext = self.convert_to_ext(a); + self.mul_extension(a_ext, b) + } + + /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the + /// extension field. + pub fn scalar_mul_ext_algebra( + &mut self, + a: ExtensionTarget, + mut b: ExtensionAlgebraTarget, + ) -> ExtensionAlgebraTarget { + for i in 0..D { + b.0[i] = self.mul_extension(a, b.0[i]); + } + b + } + + /// Exponentiate `base` to the power of a known `exponent`. + // TODO: Test + pub fn exp_u64_extension( + &mut self, + base: ExtensionTarget, + exponent: u64, + ) -> ExtensionTarget { + let mut current = base; + let mut product = self.one_extension(); + + for j in 0..bits_u64(exponent as u64) { + if (exponent >> j & 1) != 0 { + product = self.mul_extension(product, current); + } + current = self.square_extension(current); + } + product + } + + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in + /// some cases, as it allows `0 / 0 = `. + pub fn div_unsafe_extension( + &mut self, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + // Add an `ArithmeticExtensionGate` to compute `q * y`. + let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); + + let multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_fixed_multiplicand(), + ); + let multiplicand_1 = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); + let output = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); + + self.add_generator(QuotientGeneratorExtension { + numerator: x, + denominator: y, + quotient: multiplicand_0, + }); + self.add_generator(ZeroOutGenerator { + gate_index: gate, + ranges: vec![ + ArithmeticExtensionGate::::wires_addend_0(), + ArithmeticExtensionGate::::wires_multiplicand_1(), + ArithmeticExtensionGate::::wires_addend_1(), + ], + }); + + self.route_extension(y, multiplicand_1); + self.assert_equal_extension(output, x); + + multiplicand_0 + } +} + +/// Generator used to zero out wires at a given gate index and ranges. +pub struct ZeroOutGenerator { + gate_index: usize, + ranges: Vec>, +} + +impl SimpleGenerator for ZeroOutGenerator { + fn dependencies(&self) -> Vec { + Vec::new() + } + + fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { + let mut pw = PartialWitness::new(); + for t in self + .ranges + .iter() + .flat_map(|r| Target::wires_from_range(self.gate_index, r.clone())) + { + pw.set_target(t, F::ZERO); + } + + pw + } +} + +struct QuotientGeneratorExtension { + numerator: ExtensionTarget, + denominator: ExtensionTarget, + quotient: ExtensionTarget, +} + +impl, const D: usize> SimpleGenerator for QuotientGeneratorExtension { + fn dependencies(&self) -> Vec { + let mut deps = self.numerator.to_target_array().to_vec(); + deps.extend(&self.denominator.to_target_array()); + deps + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let num = witness.get_extension_target(self.numerator); + let dem = witness.get_extension_target(self.denominator); + let quotient = num / dem; + let mut pw = PartialWitness::new(); + for i in 0..D { + pw.set_target( + self.quotient.to_target_array()[i], + quotient.to_basefield_array()[i], + ); + } + + pw + } +} + +/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. +#[derive(Clone)] +pub struct PowersTarget { + base: ExtensionTarget, + current: ExtensionTarget, +} + +impl PowersTarget { + pub fn next>( + &mut self, + builder: &mut CircuitBuilder, + ) -> ExtensionTarget { + let result = self.current; + self.current = builder.mul_extension(self.base, self.current); + result + } + + pub fn repeated_frobenius>( + self, + k: usize, + builder: &mut CircuitBuilder, + ) -> Self { + let Self { base, current } = self; + Self { + base: base.repeated_frobenius(k, builder), + current: current.repeated_frobenius(k, builder), + } + } +} + +impl, const D: usize> CircuitBuilder { + pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { + PowersTarget { + base, + current: self.one_extension(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::circuit_builder::CircuitBuilder; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::field::Field; + use crate::fri::FriConfig; + use crate::witness::PartialWitness; + + #[test] + fn test_div_extension() { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let y = FF::rand(); + let z = x / y; + let xt = builder.constant_extension(x); + let yt = builder.constant_extension(y); + let zt = builder.constant_extension(z); + let comp_zt = builder.div_unsafe_extension(xt, yt); + builder.assert_equal_extension(zt, comp_zt); + + let data = builder.build(); + let proof = data.prove(PartialWitness::new()); + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index a1e041fc..2f216870 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,4 +1,5 @@ pub mod arithmetic; +pub mod arithmetic_extension; pub mod hash; pub mod insert; pub mod interpolation;