diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index ccca4e2b..4a86ef48 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -8,6 +8,7 @@ use crate::circuit_data::{ VerifierCircuitData, VerifierOnlyCircuitData, }; use crate::field::cosets::get_unique_coset_shifts; +use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::constant::ConstantGate; @@ -130,6 +131,12 @@ impl, const D: usize> CircuitBuilder { self.assert_equal(src, dst); } + pub fn route_extension(&mut self, src: ExtensionTarget, dst: ExtensionTarget) { + for i in 0..D { + self.route(src.0[i], dst.0[i]); + } + } + /// Adds a generator which will copy `src` to `dst`. pub fn generate_copy(&mut self, src: Target, dst: Target) { self.add_generator(CopyGenerator { src, dst }); @@ -154,6 +161,12 @@ impl, const D: usize> CircuitBuilder { self.assert_equal(x, zero); } + pub fn assert_equal_extension(&mut self, x: ExtensionTarget, y: ExtensionTarget) { + for i in 0..D { + self.assert_equal(x.0[i], y.0[i]); + } + } + pub fn add_generators(&mut self, generators: Vec>>) { self.generators.extend(generators); } diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index fd32c83d..4a2e4bd3 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -3,10 +3,12 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticGate; +use crate::gates::mul_extension::MulExtensionGate; use crate::generator::SimpleGenerator; use crate::target::Target; use crate::wire::Wire; use crate::witness::PartialWitness; +use std::convert::TryInto; impl, const D: usize> CircuitBuilder { /// Computes `-x`. @@ -234,32 +236,32 @@ impl, const D: usize> CircuitBuilder { y: ExtensionTarget, ) -> ExtensionTarget { // Add an `ArithmeticGate` to compute `q * y`. - let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ZERO]); + let gate = self.add_gate(MulExtensionGate::new(), vec![F::ONE]); - let wire_multiplicand_0 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_0, - }; - let wire_multiplicand_1 = Wire { - gate, - input: ArithmeticGate::WIRE_MULTIPLICAND_1, - }; - let wire_addend = Wire { - gate, - input: ArithmeticGate::WIRE_ADDEND, - }; - let wire_output = Wire { - gate, - input: ArithmeticGate::WIRE_OUTPUT, - }; + let multiplicand_0 = MulExtensionGate::::wires_multiplicand_0() + .map(|i| Target::Wire(Wire { gate, input: i })) + .collect::>(); + let multiplicand_0 = ExtensionTarget(multiplicand_0.try_into().unwrap()); + let multiplicand_1 = MulExtensionGate::::wires_multiplicand_1() + .map(|i| Target::Wire(Wire { gate, input: i })) + .collect::>(); + let multiplicand_1 = ExtensionTarget(multiplicand_1.try_into().unwrap()); + let output = MulExtensionGate::::wires_output() + .map(|i| Target::Wire(Wire { gate, input: i })) + .collect::>(); + let output = ExtensionTarget(output.try_into().unwrap()); - let q = Target::Wire(wire_multiplicand_0); - todo!() - // self.add_generator(QuotientGeneratorExtension { - // numerator: x, - // denominator: ExtensionTarget(), - // quotient: ExtensionTarget(), - // }) + self.add_generator(QuotientGeneratorExtension { + numerator: x, + denominator: y, + quotient: multiplicand_0, + }); + + self.route_extension(y, multiplicand_1); + + self.assert_equal_extension(output, x); + + multiplicand_0 } } @@ -335,3 +337,53 @@ impl, const D: usize> CircuitBuilder { } } } + +#[cfg(test)] +mod tests { + use crate::circuit_builder::CircuitBuilder; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::field::Field; + use crate::fri::FriConfig; + use crate::prover::PLONK_BLINDING; + use crate::witness::PartialWitness; + + #[test] + fn test_div_extension() { + type F = CrandallField; + type FF = QuarticCrandallField; + const D: usize = 4; + + let config = CircuitConfig { + num_wires: 134, + num_routed_wires: 12, + security_bits: 128, + rate_bits: 0, + num_challenges: 3, + fri_config: FriConfig { + proof_of_work_bits: 1, + rate_bits: 0, + reduction_arity_bits: vec![1], + num_query_rounds: 1, + blinding: PLONK_BLINDING.to_vec(), + }, + }; + + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let y = FF::rand(); + let x = FF::TWO; + let y = FF::ONE; + let z = x / y; + let xt = builder.constant_extension(x); + let yt = builder.constant_extension(y); + let zt = builder.constant_extension(z); + let comp_zt = builder.div_unsafe_extension(xt, yt); + builder.assert_equal_extension(zt, comp_zt); + + let data = builder.build(); + let proof = data.prove(PartialWitness::new()); + } +} diff --git a/src/gates/mod.rs b/src/gates/mod.rs index f131b754..3ac7bd74 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -4,8 +4,8 @@ pub mod constant; pub(crate) mod gate; pub mod gmimc; mod interpolation; +pub mod mul_extension; pub(crate) mod noop; #[cfg(test)] mod gate_testing; -mod mul_extension; diff --git a/src/witness.rs b/src/witness.rs index ebd85d9e..c5748c3f 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -81,6 +81,26 @@ impl PartialWitness { } } + pub fn set_extension_target( + &mut self, + et: ExtensionTarget, + value: F::Extension, + ) where + F: Extendable, + { + let limbs = value.to_basefield_array(); + for i in 0..D { + let opt_old_value = self.target_values.insert(et.0[i], limbs[i]); + if let Some(old_value) = opt_old_value { + assert_eq!( + old_value, limbs[i], + "Target was set twice with different values: {:?}", + et.0[i] + ); + } + } + } + pub fn set_wire(&mut self, wire: Wire, value: F) { self.set_target(Target::Wire(wire), value) }