use std::convert::TryInto; use itertools::Itertools; use num::Integer; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::Field; use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bits_u64; impl, const D: usize> CircuitBuilder { pub fn double_arithmetic_extension( &mut self, const_0: F, const_1: F, first_multiplicand_0: ExtensionTarget, first_multiplicand_1: ExtensionTarget, first_addend: ExtensionTarget, second_multiplicand_0: ExtensionTarget, second_multiplicand_1: ExtensionTarget, second_addend: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let gate = self.add_gate(ArithmeticExtensionGate, vec![const_0, const_1]); let wire_first_multiplicand_0 = ExtensionTarget::from_range( gate, ArithmeticExtensionGate::::wires_first_multiplicand_0(), ); let wire_first_multiplicand_1 = ExtensionTarget::from_range( gate, ArithmeticExtensionGate::::wires_first_multiplicand_1(), ); let wire_first_addend = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_addend()); let wire_second_multiplicand_0 = ExtensionTarget::from_range( gate, ArithmeticExtensionGate::::wires_second_multiplicand_0(), ); let wire_second_multiplicand_1 = ExtensionTarget::from_range( gate, ArithmeticExtensionGate::::wires_second_multiplicand_1(), ); let wire_second_addend = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_second_addend()); let wire_first_output = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); let wire_second_output = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_second_output()); self.route_extension(first_multiplicand_0, wire_first_multiplicand_0); self.route_extension(first_multiplicand_1, wire_first_multiplicand_1); self.route_extension(first_addend, wire_first_addend); self.route_extension(second_multiplicand_0, wire_second_multiplicand_0); self.route_extension(second_multiplicand_1, wire_second_multiplicand_1); self.route_extension(second_addend, wire_second_addend); (wire_first_output, wire_second_output) } pub fn arithmetic_extension( &mut self, const_0: F, const_1: F, multiplicand_0: ExtensionTarget, multiplicand_1: ExtensionTarget, addend: ExtensionTarget, ) -> ExtensionTarget { // See if we can determine the result without adding an `ArithmeticGate`. if let Some(result) = self.arithmetic_extension_special_cases( const_0, const_1, multiplicand_0, multiplicand_1, addend, ) { return result; } let zero = self.zero_extension(); self.double_arithmetic_extension( const_0, const_1, multiplicand_0, multiplicand_1, addend, zero, zero, zero, ) .0 } /// Checks for special cases where the value of /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` /// can be determined without adding an `ArithmeticGate`. fn arithmetic_extension_special_cases( &mut self, const_0: F, const_1: F, multiplicand_0: ExtensionTarget, multiplicand_1: ExtensionTarget, addend: ExtensionTarget, ) -> Option> { let zero = self.zero_extension(); let mul_0_const = self.target_as_constant_ext(multiplicand_0); let mul_1_const = self.target_as_constant_ext(multiplicand_1); let addend_const = self.target_as_constant_ext(addend); let first_term_zero = const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; let second_term_zero = const_1 == F::ZERO || addend == zero; // If both terms are constant, return their (constant) sum. let first_term_const = if first_term_zero { Some(F::Extension::ZERO) } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { Some(x * y * const_0.into()) } else { None }; let second_term_const = if second_term_zero { Some(F::Extension::ZERO) } else { addend_const.map(|x| x * const_1.into()) }; if let (Some(x), Some(y)) = (first_term_const, second_term_const) { return Some(self.constant_extension(x + y)); } if first_term_zero && const_1.is_one() { return Some(addend); } if second_term_zero { if let Some(x) = mul_0_const { if (x * const_0.into()).is_one() { return Some(multiplicand_1); } } if let Some(x) = mul_1_const { if (x * const_0.into()).is_one() { return Some(multiplicand_0); } } } None } pub fn add_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, ) -> ExtensionTarget { let one = self.one_extension(); self.arithmetic_extension(F::ONE, F::ONE, one, a, b) } /// Returns `(a0+b0, a1+b1)`. pub fn add_two_extension( &mut self, a0: ExtensionTarget, b0: ExtensionTarget, a1: ExtensionTarget, b1: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let one = self.one_extension(); self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, one, a1, b1) } pub fn add_ext_algebra( &mut self, a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { // We run two additions in parallel. So `[a0,a1,a2,a3] + [b0,b1,b2,b3]` is computed with two // `add_two_extension`, first `[a0,a1]+[b0,b1]` then `[a2,a3]+[b2,b3]`. let mut res = Vec::with_capacity(D); // We need some extra logic if D is odd. let d_even = D & (D ^ 1); // = 2 * (D/2) for mut chunk in &(0..d_even).chunks(2) { let i = chunk.next().unwrap(); let j = chunk.next().unwrap(); let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); res.extend([o0, o1]); } if D.is_odd() { res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); } ExtensionAlgebraTarget(res.try_into().unwrap()) } /// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. pub fn add_three_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { let one = self.one_extension(); let gate = self.num_gates(); let first_out = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out) .1 } /// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let zero = self.zero_extension(); let mut terms = terms.to_vec(); if terms.is_empty() { return zero; } else if terms.len() < 3 { terms.resize(3, zero); } else if terms.len().is_even() { terms.push(zero); } let mut acc = self.add_three_extension(terms[0], terms[1], terms[2]); terms.drain(0..3); for chunk in terms.chunks_exact(2) { acc = self.add_three_extension(acc, chunk[0], chunk[1]); } acc } pub fn sub_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, ) -> ExtensionTarget { let one = self.one_extension(); self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) } pub fn sub_two_extension( &mut self, a0: ExtensionTarget, b0: ExtensionTarget, a1: ExtensionTarget, b1: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let one = self.one_extension(); self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, one, a1, b1) } pub fn sub_ext_algebra( &mut self, a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { // See `add_ext_algebra`. let mut res = Vec::with_capacity(D); let d_even = D & (D ^ 1); // = 2 * (D/2) for mut chunk in &(0..d_even).chunks(2) { let i = chunk.next().unwrap(); let j = chunk.next().unwrap(); let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); res.extend([o0, o1]); } if D.is_odd() { res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); } ExtensionAlgebraTarget(res.try_into().unwrap()) } pub fn mul_extension_with_const( &mut self, const_0: F, multiplicand_0: ExtensionTarget, multiplicand_1: ExtensionTarget, ) -> ExtensionTarget { let zero = self.zero_extension(); self.double_arithmetic_extension( const_0, F::ZERO, multiplicand_0, multiplicand_1, zero, zero, zero, zero, ) .0 } pub fn mul_extension( &mut self, multiplicand_0: ExtensionTarget, multiplicand_1: ExtensionTarget, ) -> ExtensionTarget { self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) } /// Returns `(a0*b0, a1*b1)`. pub fn mul_two_extension( &mut self, a0: ExtensionTarget, b0: ExtensionTarget, a1: ExtensionTarget, b1: ExtensionTarget, ) -> (ExtensionTarget, ExtensionTarget) { let zero = self.zero_extension(); self.double_arithmetic_extension(F::ONE, F::ZERO, a0, b0, zero, a1, b1, zero) } /// Computes `x^2`. pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { self.mul_extension(x, x) } /// Computes `x^3`. pub fn cube_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { self.mul_three_extension(x, x, x) } pub fn mul_ext_algebra( &mut self, a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { let mut res = [self.zero_extension(); D]; let w = self.constant(F::Extension::W); for i in 0..D { for j in 0..D { res[(i + j) % D] = if i + j < D { self.mul_add_extension(a.0[i], b.0[j], res[(i + j) % D]) } else { let ai_bi = self.mul_extension(a.0[i], b.0[j]); self.scalar_mul_add_extension(w, ai_bi, res[(i + j) % D]) } } } ExtensionAlgebraTarget(res) } /// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. pub fn mul_three_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { let zero = self.zero_extension(); let gate = self.num_gates(); let first_out = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); self.double_arithmetic_extension(F::ONE, F::ZERO, a, b, zero, c, first_out, zero) .1 } /// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let one = self.one_extension(); let mut terms = terms.to_vec(); if terms.is_empty() { return one; } else if terms.len() < 3 { terms.resize(3, one); } else if terms.len().is_even() { terms.push(one); } let mut acc = self.mul_three_extension(terms[0], terms[1], terms[2]); terms.drain(0..3); for chunk in terms.chunks_exact(2) { acc = self.mul_three_extension(acc, chunk[0], chunk[1]); } acc } /// Like `mul_add`, but for `ExtensionTarget`s. pub fn mul_add_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { self.arithmetic_extension(F::ONE, F::ONE, a, b, c) } /// Like `mul_add`, but for `ExtensionTarget`s. pub fn scalar_mul_add_extension( &mut self, a: Target, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { let a_ext = self.convert_to_ext(a); self.arithmetic_extension(F::ONE, F::ONE, a_ext, b, c) } /// Like `mul_sub`, but for `ExtensionTarget`s. pub fn mul_sub_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { self.arithmetic_extension(F::ONE, F::NEG_ONE, a, b, c) } /// Like `mul_sub`, but for `ExtensionTarget`s. pub fn scalar_mul_sub_extension( &mut self, a: Target, b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { let a_ext = self.convert_to_ext(a); self.arithmetic_extension(F::ONE, F::NEG_ONE, a_ext, b, c) } /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. pub fn scalar_mul_ext(&mut self, a: Target, b: ExtensionTarget) -> ExtensionTarget { let a_ext = self.convert_to_ext(a); self.mul_extension(a_ext, b) } /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the /// extension field. pub fn scalar_mul_ext_algebra( &mut self, a: ExtensionTarget, mut b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { for i in 0..D { b.0[i] = self.mul_extension(a, b.0[i]); } b } /// Exponentiate `base` to the power of `2^power_log`. // TODO: Test pub fn exp_power_of_2_extension( &mut self, mut base: ExtensionTarget, power_log: usize, ) -> ExtensionTarget { for _ in 0..power_log { base = self.square_extension(base); } base } /// Exponentiate `base` to the power of a known `exponent`. // TODO: Test pub fn exp_u64_extension( &mut self, base: ExtensionTarget, exponent: u64, ) -> ExtensionTarget { match exponent { 0 => return self.one_extension(), 1 => return base, 2 => return self.square_extension(base), 3 => return self.cube_extension(base), _ => (), } let mut current = base; let mut product = self.one_extension(); for j in 0..bits_u64(exponent as u64) { if (exponent >> j & 1) != 0 { product = self.mul_extension(product, current); } current = self.square_extension(current); } product } /// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`. pub fn div_extension( &mut self, x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { let y_inv = self.inverse_extension(y); self.mul_extension(x, y_inv) } /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// some cases, as it allows `0 / 0 = `. pub fn div_unsafe_extension( &mut self, x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { let quotient = self.add_virtual_extension_target(); self.add_generator(QuotientGeneratorExtension { numerator: x, denominator: y, quotient, }); // Enforce that q y = x. let q_y = self.mul_extension(quotient, y); self.assert_equal_extension(q_y, x); quotient } /// Computes `1 / x`. Results in an unsatisfiable instance if `x = 0`. pub fn inverse_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { let inv = self.add_virtual_extension_target(); let one = self.one_extension(); self.add_generator(QuotientGeneratorExtension { numerator: one, denominator: x, quotient: inv, }); // Enforce that x times its purported inverse equals 1. let x_inv = self.mul_extension(x, inv); self.assert_equal_extension(x_inv, one); inv } } struct QuotientGeneratorExtension { numerator: ExtensionTarget, denominator: ExtensionTarget, quotient: ExtensionTarget, } impl, const D: usize> SimpleGenerator for QuotientGeneratorExtension { fn dependencies(&self) -> Vec { let mut deps = self.numerator.to_target_array().to_vec(); deps.extend(&self.denominator.to_target_array()); deps } fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let num = witness.get_extension_target(self.numerator); let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; out_buffer.set_extension_target(self.quotient, quotient) } } /// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. #[derive(Clone)] pub struct PowersTarget { base: ExtensionTarget, current: ExtensionTarget, } impl PowersTarget { pub fn next>( &mut self, builder: &mut CircuitBuilder, ) -> ExtensionTarget { let result = self.current; self.current = builder.mul_extension(self.base, self.current); result } pub fn repeated_frobenius>( self, k: usize, builder: &mut CircuitBuilder, ) -> Self { let Self { base, current } = self; Self { base: base.repeated_frobenius(k, builder), current: current.repeated_frobenius(k, builder), } } } impl, const D: usize> CircuitBuilder { pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { PowersTarget { base, current: self.one_extension(), } } } #[cfg(test)] mod tests { use anyhow::Result; use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticCrandallField; use crate::field::field_types::Field; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; #[test] fn test_mul_many() -> Result<()> { type F = CrandallField; type FF = QuarticCrandallField; const D: usize = 4; let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::new(); let vs = FF::rand_vec(3); let ts = builder.add_virtual_extension_targets(3); for (&v, &t) in vs.iter().zip(&ts) { pw.set_extension_target(t, v); } let mul0 = builder.mul_many_extension(&ts); let mul1 = { let mut acc = builder.one_extension(); for &t in &ts { acc = builder.mul_extension(acc, t); } acc }; let mul2 = builder.mul_three_extension(ts[0], ts[1], ts[2]); let mul3 = builder.constant_extension(vs.into_iter().product()); builder.assert_equal_extension(mul0, mul1); builder.assert_equal_extension(mul1, mul2); builder.assert_equal_extension(mul2, mul3); let data = builder.build(); let proof = data.prove(pw)?; verify(proof, &data.verifier_only, &data.common) } #[test] fn test_div_extension() -> Result<()> { type F = CrandallField; type FF = QuarticCrandallField; const D: usize = 4; let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); let x = FF::rand(); let y = FF::rand(); let z = x / y; let xt = builder.constant_extension(x); let yt = builder.constant_extension(y); let zt = builder.constant_extension(z); let comp_zt = builder.div_extension(xt, yt); let comp_zt_unsafe = builder.div_unsafe_extension(xt, yt); builder.assert_equal_extension(zt, comp_zt); builder.assert_equal_extension(zt, comp_zt_unsafe); let data = builder.build(); let proof = data.prove(PartialWitness::new())?; verify(proof, &data.verifier_only, &data.common) } }