diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index e6f5c38b..8eb86043 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -7,7 +7,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; -use crate::gates::mul_extension::ArithmeticExtensionGate; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::target::Target; /// `Target`s representing an element of an extension field. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 9cdeded4..99e260c7 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -4,7 +4,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; -use crate::gates::mul_extension::ArithmeticExtensionGate; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::target::Target; use crate::util::bits_u64; diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 0d0fdd7c..586a0cc2 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -1,7 +1,8 @@ +use std::ops::Range; + use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; -use crate::field::field::Field; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::target::Target; @@ -9,26 +10,39 @@ use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; use crate::witness::PartialWitness; -/// A gate which can be configured to perform various arithmetic. In particular, it computes -/// -/// ```text -/// output := const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend -/// ``` +/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. #[derive(Debug)] -pub struct ArithmeticGate; +pub struct ArithmeticExtensionGate; -impl ArithmeticGate { - pub fn new, const D: usize>() -> GateRef { - GateRef::new(ArithmeticGate) +impl ArithmeticExtensionGate { + pub fn new>() -> GateRef { + GateRef::new(ArithmeticExtensionGate) } - pub const WIRE_MULTIPLICAND_0: usize = 0; - pub const WIRE_MULTIPLICAND_1: usize = 1; - pub const WIRE_ADDEND: usize = 2; - pub const WIRE_OUTPUT: usize = 3; + pub fn wires_fixed_multiplicand() -> Range { + 0..D + } + pub fn wires_multiplicand_0() -> Range { + D..2 * D + } + pub fn wires_addend_0() -> Range { + 2 * D..3 * D + } + pub fn wires_multiplicand_1() -> Range { + 3 * D..4 * D + } + pub fn wires_addend_1() -> Range { + 4 * D..5 * D + } + pub fn wires_output_0() -> Range { + 5 * D..6 * D + } + pub fn wires_output_1() -> Range { + 6 * D..7 * D + } } -impl, const D: usize> Gate for ArithmeticGate { +impl, const D: usize> Gate for ArithmeticExtensionGate { fn id(&self) -> String { format!("{:?}", self) } @@ -36,12 +50,23 @@ impl, const D: usize> Gate for ArithmeticGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; - let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; - let addend = vars.local_wires[Self::WIRE_ADDEND]; - let output = vars.local_wires[Self::WIRE_OUTPUT]; - let computed_output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend; - vec![computed_output - output] + + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = + fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); + let computed_output_1 = + fixed_multiplicand * multiplicand_1 * const_1.into() + addend_1 * const_1.into(); + + let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); + constraints.extend((output_1 - computed_output_1).to_basefield_array()); + constraints } fn eval_unfiltered_recursively( @@ -51,15 +76,30 @@ impl, const D: usize> Gate for ArithmeticGate { ) -> Vec> { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; - let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; - let addend = vars.local_wires[Self::WIRE_ADDEND]; - let output = vars.local_wires[Self::WIRE_OUTPUT]; - let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); - let addend_term = builder.mul_extension(const_1, addend); - let computed_output = builder.add_many_extension(&[product_term, addend_term]); - vec![builder.sub_extension(computed_output, output)] + let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); + let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); + let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); + let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); + let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); + + let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); + let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); + let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); + let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); + + let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); + let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); + let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); + let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); + + let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); + let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); + let mut constraints = diff_0.to_ext_target_array().to_vec(); + constraints.extend(diff_1.to_ext_target_array()); + constraints } fn generators( @@ -67,7 +107,7 @@ impl, const D: usize> Gate for ArithmeticGate { gate_index: usize, local_constants: &[F], ) -> Vec>> { - let gen = ArithmeticGenerator { + let gen = MulExtensionGenerator { gate_index, const_0: local_constants[0], const_1: local_constants[1], @@ -76,7 +116,7 @@ impl, const D: usize> Gate for ArithmeticGate { } fn num_wires(&self) -> usize { - 4 + 28 } fn num_constants(&self) -> usize { @@ -88,70 +128,71 @@ impl, const D: usize> Gate for ArithmeticGate { } fn num_constraints(&self) -> usize { - 1 + 2 * D } } -struct ArithmeticGenerator { +struct MulExtensionGenerator, const D: usize> { gate_index: usize, const_0: F, const_1: F, } -impl SimpleGenerator for ArithmeticGenerator { +impl, const D: usize> SimpleGenerator for MulExtensionGenerator { fn dependencies(&self) -> Vec { - vec![ - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }), - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }), - Target::Wire(Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_ADDEND, - }), - ] + ArithmeticExtensionGate::::wires_fixed_multiplicand() + .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) + .chain(ArithmeticExtensionGate::::wires_addend_0()) + .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) + .chain(ArithmeticExtensionGate::::wires_addend_1()) + .map(|i| Target::wire(self.gate_index, i)) + .collect() } fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let multiplicand_0_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let multiplicand_1_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let addend_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_ADDEND, - }; - let output_target = Wire { - gate: self.gate_index, - input: ArithmeticGate::WIRE_OUTPUT, + let extract_extension = |range: Range| -> F::Extension { + let t = ExtensionTarget::from_range(self.gate_index, range); + witness.get_extension_target(t) }; - let multiplicand_0 = witness.get_wire(multiplicand_0_target); - let multiplicand_1 = witness.get_wire(multiplicand_1_target); - let addend = witness.get_wire(addend_target); + let fixed_multiplicand = + extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); + let multiplicand_0 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); + let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); + let multiplicand_1 = + extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); + let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); - let output = self.const_0 * multiplicand_0 * multiplicand_1 + self.const_1 * addend; + let output_target_0 = ExtensionTarget::from_range( + self.gate_index, + ArithmeticExtensionGate::::wires_output_0(), + ); + let output_target_1 = ExtensionTarget::from_range( + self.gate_index, + ArithmeticExtensionGate::::wires_output_1(), + ); - PartialWitness::singleton_wire(output_target, output) + let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() + + addend_0 * self.const_1.into(); + let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() + + addend_1 * self.const_1.into(); + + let mut pw = PartialWitness::new(); + pw.set_extension_target(output_target_0, computed_output_0); + pw.set_extension_target(output_target_1, computed_output_1); + pw } } #[cfg(test)] mod tests { use crate::field::crandall_field::CrandallField; - use crate::gates::arithmetic::ArithmeticGate; + use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::gate_testing::test_low_degree; #[test] fn low_degree() { - test_low_degree(ArithmeticGate::new::()) + test_low_degree(ArithmeticExtensionGate::<4>::new::()) } } diff --git a/src/gates/gate_tree.rs b/src/gates/gate_tree.rs index 9dc188a0..2471747c 100644 --- a/src/gates/gate_tree.rs +++ b/src/gates/gate_tree.rs @@ -186,11 +186,11 @@ impl, const D: usize> Tree> { mod tests { use super::*; use crate::field::crandall_field::CrandallField; + use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::gmimc::GMiMCGate; use crate::gates::interpolation::InterpolationGate; - use crate::gates::mul_extension::ArithmeticExtensionGate; use crate::gates::noop::NoopGate; use crate::hash::GMIMC_ROUNDS; diff --git a/src/gates/mod.rs b/src/gates/mod.rs index bb8b178b..fa23b273 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -1,11 +1,10 @@ -pub(crate) mod arithmetic; +pub mod arithmetic; pub mod base_sum; pub mod constant; pub(crate) mod gate; pub mod gate_tree; pub mod gmimc; pub mod interpolation; -pub mod mul_extension; pub(crate) mod noop; #[cfg(test)] diff --git a/src/gates/mul_extension.rs b/src/gates/mul_extension.rs deleted file mode 100644 index a3006837..00000000 --- a/src/gates/mul_extension.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::ops::Range; - -use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; -use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; -use crate::wire::Wire; -use crate::witness::PartialWitness; - -/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. -#[derive(Debug)] -pub struct ArithmeticExtensionGate; - -impl ArithmeticExtensionGate { - pub fn new>() -> GateRef { - GateRef::new(ArithmeticExtensionGate) - } - - pub fn wires_fixed_multiplicand() -> Range { - 0..D - } - pub fn wires_multiplicand_0() -> Range { - D..2 * D - } - pub fn wires_addend_0() -> Range { - 2 * D..3 * D - } - pub fn wires_multiplicand_1() -> Range { - 3 * D..4 * D - } - pub fn wires_addend_1() -> Range { - 4 * D..5 * D - } - pub fn wires_output_0() -> Range { - 5 * D..6 * D - } - pub fn wires_output_1() -> Range { - 6 * D..7 * D - } -} - -impl, const D: usize> Gate for ArithmeticExtensionGate { - fn id(&self) -> String { - format!("{:?}", self) - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let const_0 = vars.local_constants[0]; - let const_1 = vars.local_constants[1]; - - let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); - let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); - let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); - - let computed_output_0 = - fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into(); - let computed_output_1 = - fixed_multiplicand * multiplicand_1 * const_1.into() + addend_1 * const_1.into(); - - let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec(); - constraints.extend((output_1 - computed_output_1).to_basefield_array()); - constraints - } - - fn eval_unfiltered_recursively( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let const_0 = vars.local_constants[0]; - let const_1 = vars.local_constants[1]; - - let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand()); - let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0()); - let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0()); - let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1()); - let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1()); - let output_0 = vars.get_local_ext_algebra(Self::wires_output_0()); - let output_1 = vars.get_local_ext_algebra(Self::wires_output_1()); - - let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0); - let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0); - let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0); - let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0); - - let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1); - let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1); - let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1); - let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1); - - let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0); - let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1); - let mut constraints = diff_0.to_ext_target_array().to_vec(); - constraints.extend(diff_1.to_ext_target_array()); - constraints - } - - fn generators( - &self, - gate_index: usize, - local_constants: &[F], - ) -> Vec>> { - let gen = MulExtensionGenerator { - gate_index, - const_0: local_constants[0], - const_1: local_constants[1], - }; - vec![Box::new(gen)] - } - - fn num_wires(&self) -> usize { - 28 - } - - fn num_constants(&self) -> usize { - 2 - } - - fn degree(&self) -> usize { - 3 - } - - fn num_constraints(&self) -> usize { - 2 * D - } -} - -struct MulExtensionGenerator, const D: usize> { - gate_index: usize, - const_0: F, - const_1: F, -} - -impl, const D: usize> SimpleGenerator for MulExtensionGenerator { - fn dependencies(&self) -> Vec { - ArithmeticExtensionGate::::wires_fixed_multiplicand() - .chain(ArithmeticExtensionGate::::wires_multiplicand_0()) - .chain(ArithmeticExtensionGate::::wires_addend_0()) - .chain(ArithmeticExtensionGate::::wires_multiplicand_1()) - .chain(ArithmeticExtensionGate::::wires_addend_1()) - .map(|i| Target::wire(self.gate_index, i)) - .collect() - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let extract_extension = |range: Range| -> F::Extension { - let t = ExtensionTarget::from_range(self.gate_index, range); - witness.get_extension_target(t) - }; - - let fixed_multiplicand = - extract_extension(ArithmeticExtensionGate::::wires_fixed_multiplicand()); - let multiplicand_0 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_0()); - let addend_0 = extract_extension(ArithmeticExtensionGate::::wires_addend_0()); - let multiplicand_1 = - extract_extension(ArithmeticExtensionGate::::wires_multiplicand_1()); - let addend_1 = extract_extension(ArithmeticExtensionGate::::wires_addend_1()); - - let output_target_0 = ExtensionTarget::from_range( - self.gate_index, - ArithmeticExtensionGate::::wires_output_0(), - ); - let output_target_1 = ExtensionTarget::from_range( - self.gate_index, - ArithmeticExtensionGate::::wires_output_1(), - ); - - let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into() - + addend_0 * self.const_1.into(); - let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into() - + addend_1 * self.const_1.into(); - - let mut pw = PartialWitness::new(); - pw.set_extension_target(output_target_0, computed_output_0); - pw.set_extension_target(output_target_1, computed_output_1); - pw - } -} - -#[cfg(test)] -mod tests { - use crate::field::crandall_field::CrandallField; - use crate::gates::gate_testing::test_low_degree; - use crate::gates::mul_extension::ArithmeticExtensionGate; - - #[test] - fn low_degree() { - test_low_degree(ArithmeticExtensionGate::<4>::new::()) - } -} diff --git a/src/util/scaling.rs b/src/util/scaling.rs index 057e8467..0bc61840 100644 --- a/src/util/scaling.rs +++ b/src/util/scaling.rs @@ -4,7 +4,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, Frobenius}; use crate::field::field::Field; -use crate::gates::mul_extension::ArithmeticExtensionGate; +use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::generator::SimpleGenerator; use crate::polynomial::polynomial::PolynomialCoeffs; use crate::target::Target;