diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index f0a2338e..6fe38b45 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -157,16 +157,20 @@ impl, const D: usize> CircuitBuilder { let evals = [0, 1, 4] .iter() .flat_map(|&i| proof.unsalted_evals(i, config)) - .map(|&e| F::Extension::from_basefield(e)); + .map(|&e| self.convert_to_ext(e)); let openings = os .constants .iter() .chain(&os.plonk_sigmas) .chain(&os.quotient_polys); - let numerator = izip!(evals, openings, &mut alpha_powers) - .map(|(e, &o, a)| a * (e - o)) - .sum::(); - let denominator = subgroup_x - zeta; + let mut numerator = self.zero_extension(); + for (e, &o) in izip!(evals, openings) { + let a = alpha_powers.next(self); + let diff = self.sub_extension(e, o); + numerator = self.mul_add_extension(a, diff, numerator); + } + let denominator = self.sub_extension(subgroup_x, zeta); + // let quotient = self.div_unsafe() sum += numerator / denominator; let ev: F::Extension = proof diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index f152f77d..a84caf03 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -225,29 +225,40 @@ impl, const D: usize> CircuitBuilder { q } -} -/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. -#[derive(Clone)] -pub struct PowersTarget { - base: ExtensionTarget, - current: ExtensionTarget, -} + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in + /// some cases, as it allows `0 / 0 = `. + pub fn div_unsafe_extension( + &mut self, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + // Add an `ArithmeticGate` to compute `q * y`. + let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ZERO]); -impl, const D: usize> PowersTarget { - fn next(&mut self, builder: &mut CircuitBuilder) -> Option> { - let result = self.current; - self.current = builder.mul_extension(self.base, self.current); - Some(result) - } -} + let wire_multiplicand_0 = Wire { + gate, + input: ArithmeticGate::WIRE_MULTIPLICAND_0, + }; + let wire_multiplicand_1 = Wire { + gate, + input: ArithmeticGate::WIRE_MULTIPLICAND_1, + }; + let wire_addend = Wire { + gate, + input: ArithmeticGate::WIRE_ADDEND, + }; + let wire_output = Wire { + gate, + input: ArithmeticGate::WIRE_OUTPUT, + }; -impl, const D: usize> CircuitBuilder { - pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { - PowersTarget { - base, - current: self.one_extension(), - } + let q = Target::Wire(wire_multiplicand_0); + self.add_generator(QuotientGeneratorExtension { + numerator: x, + denominator: ExtensionTarget(), + quotient: ExtensionTarget(), + }) } } @@ -268,3 +279,53 @@ impl SimpleGenerator for QuotientGenerator { PartialWitness::singleton_target(self.quotient, num / den) } } + +struct QuotientGeneratorExtension { + numerator: ExtensionTarget, + denominator: ExtensionTarget, + quotient: ExtensionTarget, +} + +impl SimpleGenerator for QuotientGeneratorExtension { + fn dependencies(&self) -> Vec { + let mut deps = self.numerator.to_target_array().to_vec(); + deps.extend(&self.denominator.to_target_array()); + deps + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let num = witness.get_extension_target(self.numerator); + let dem = witness.get_extension_target(self.denominator); + let quotient = num / dem; + let mut pw = PartialWitness::new(); + pw.set_ext_wires(self.quotient.to_target_array(), quotient); + pw + } +} + +/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. +#[derive(Clone)] +pub struct PowersTarget { + base: ExtensionTarget, + current: ExtensionTarget, +} + +impl PowersTarget { + pub fn next>( + &mut self, + builder: &mut CircuitBuilder, + ) -> ExtensionTarget { + let result = self.current; + self.current = builder.mul_extension(self.base, self.current); + result + } +} + +impl, const D: usize> CircuitBuilder { + pub fn powers(&mut self, base: ExtensionTarget) -> PowersTarget { + PowersTarget { + base, + current: self.one_extension(), + } + } +} diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 49c37bae..f131b754 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -8,3 +8,4 @@ pub(crate) mod noop; #[cfg(test)] mod gate_testing; +mod mul_extension; diff --git a/src/gates/mul_extension.rs b/src/gates/mul_extension.rs new file mode 100644 index 00000000..ec5ba671 --- /dev/null +++ b/src/gates/mul_extension.rs @@ -0,0 +1,160 @@ +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field::Field; +use crate::gates::gate::{Gate, GateRef}; +use crate::generator::{SimpleGenerator, WitnessGenerator}; +use crate::target::Target; +use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::wire::Wire; +use crate::witness::PartialWitness; +use std::ops::Range; + +/// A gate which can multiply to field extension elements. +/// TODO: Add an addend if `NUM_ROUTED_WIRES` is large enough. +#[derive(Debug)] +pub struct MulExtensionGate; + +impl MulExtensionGate { + pub fn new>() -> GateRef { + GateRef::new(MulExtensionGate) + } + + pub fn wires_multiplicand_0() -> Range { + 0..D + } + pub fn wires_multiplicand_1() -> Range { + D..2 * D + } + pub fn wires_output() -> Range { + 2 * D..3 * D + } +} + +impl, const D: usize> Gate for MulExtensionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; + let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; + let addend = vars.local_wires[Self::WIRE_ADDEND]; + let output = vars.local_wires[Self::WIRE_OUTPUT]; + let computed_output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend; + vec![computed_output - output] + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; + let multiplicand_1 = vars.local_wires[Self::WIRE_MULTIPLICAND_1]; + let addend = vars.local_wires[Self::WIRE_ADDEND]; + let output = vars.local_wires[Self::WIRE_OUTPUT]; + + let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); + let addend_term = builder.mul_extension(const_1, addend); + let computed_output = builder.add_many_extension(&[product_term, addend_term]); + vec![builder.sub_extension(computed_output, output)] + } + + fn generators( + &self, + gate_index: usize, + local_constants: &[F], + ) -> Vec>> { + let gen = ArithmeticGenerator { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + }; + vec![Box::new(gen)] + } + + fn num_wires(&self) -> usize { + 4 + } + + fn num_constants(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + 3 + } + + fn num_constraints(&self) -> usize { + 1 + } +} + +struct ArithmeticGenerator { + gate_index: usize, + const_0: F, + const_1: F, +} + +impl SimpleGenerator for ArithmeticGenerator { + fn dependencies(&self) -> Vec { + vec![ + Target::Wire(Wire { + gate: self.gate_index, + input: ArithmeticGate::WIRE_MULTIPLICAND_0, + }), + Target::Wire(Wire { + gate: self.gate_index, + input: ArithmeticGate::WIRE_MULTIPLICAND_1, + }), + Target::Wire(Wire { + gate: self.gate_index, + input: ArithmeticGate::WIRE_ADDEND, + }), + ] + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let multiplicand_0_target = Wire { + gate: self.gate_index, + input: ArithmeticGate::WIRE_MULTIPLICAND_0, + }; + let multiplicand_1_target = Wire { + gate: self.gate_index, + input: ArithmeticGate::WIRE_MULTIPLICAND_1, + }; + let addend_target = Wire { + gate: self.gate_index, + input: ArithmeticGate::WIRE_ADDEND, + }; + let output_target = Wire { + gate: self.gate_index, + input: ArithmeticGate::WIRE_OUTPUT, + }; + + let multiplicand_0 = witness.get_wire(multiplicand_0_target); + let multiplicand_1 = witness.get_wire(multiplicand_1_target); + let addend = witness.get_wire(addend_target); + + let output = self.const_0 * multiplicand_0 * multiplicand_1 + self.const_1 * addend; + + PartialWitness::singleton_wire(output_target, output) + } +} + +#[cfg(test)] +mod tests { + use crate::field::crandall_field::CrandallField; + use crate::gates::arithmetic::ArithmeticGate; + use crate::gates::gate_testing::test_low_degree; + + #[test] + fn low_degree() { + test_low_degree(ArithmeticGate::new::()) + } +} diff --git a/src/witness.rs b/src/witness.rs index a0b4b2a4..ebd85d9e 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -1,9 +1,11 @@ use std::collections::HashMap; +use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::target::Target; use crate::wire::Wire; +use std::convert::TryInto; #[derive(Clone, Debug)] pub struct PartialWitness { @@ -39,6 +41,15 @@ impl PartialWitness { targets.iter().map(|&t| self.get_target(t)).collect() } + pub fn get_extension_target(&self, et: ExtensionTarget) -> F::Extension + where + F: Extendable, + { + F::Extension::from_basefield_array( + self.get_targets(&et.to_target_array()).try_into().unwrap(), + ) + } + pub fn try_get_target(&self, target: Target) -> Option { self.target_values.get(&target).cloned() }