From 9c17a00c008fb0c0f182f63b797680c780721061 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 18 Jul 2021 23:05:57 -0700 Subject: [PATCH] Division related changes (#99) * Division related changes - Simplify `div_unsafe_extension` using virtual targets - Add methods for inversion and safe division As a followup I'll switch some calls to safe division. * Test safe division also * add_virtual_extension_target --- src/gadgets/arithmetic.rs | 12 ++++ src/gadgets/arithmetic_extension.rs | 88 ++++++++++++----------------- 2 files changed, 47 insertions(+), 53 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 5c328362..6f85cdcf 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -179,6 +179,12 @@ impl, const D: usize> CircuitBuilder { self.exp_u64_extension(base_ext, exponent).0[0] } + /// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`. + pub fn div(&mut self, x: Target, y: Target) -> Target { + let y_inv = self.inverse(y); + self.mul(x, y_inv) + } + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// some cases, as it allows `0 / 0 = `. pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { @@ -201,4 +207,10 @@ impl, const D: usize> CircuitBuilder { let y_ext = self.convert_to_ext(y); self.div_unsafe_extension(x_ext, y_ext).0[0] } + + /// Computes `1 / x`. Results in an unsatisfiable instance if `x = 0`. + pub fn inverse(&mut self, x: Target) -> Target { + let x_ext = self.convert_to_ext(x); + self.inverse_extension(x_ext).0[0] + } } diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 71af783c..d8a38ea9 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -1,4 +1,4 @@ -use std::convert::TryInto; +use std::convert::{TryFrom, TryInto}; use std::ops::Range; use itertools::Itertools; @@ -324,6 +324,16 @@ impl, const D: usize> CircuitBuilder { product } + /// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`. + pub fn div_extension( + &mut self, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + let y_inv = self.inverse_extension(y); + self.mul_extension(x, y_inv) + } + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// some cases, as it allows `0 / 0 = `. pub fn div_unsafe_extension( @@ -331,62 +341,35 @@ impl, const D: usize> CircuitBuilder { x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { - // Add an `ArithmeticExtensionGate` to compute `q * y`. - let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]); - - let multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_fixed_multiplicand(), - ); - let multiplicand_1 = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_multiplicand_0()); - let output = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_output_0()); - + let quotient = self.add_virtual_extension_target(); self.add_generator(QuotientGeneratorExtension { numerator: x, denominator: y, - quotient: multiplicand_0, - }); - // We need to zero out the other wires for the `ArithmeticExtensionGenerator` to hit. - self.add_generator(ZeroOutGenerator { - gate_index: gate, - ranges: vec![ - ArithmeticExtensionGate::::wires_addend_0(), - ArithmeticExtensionGate::::wires_multiplicand_1(), - ArithmeticExtensionGate::::wires_addend_1(), - ], + quotient, }); - self.route_extension(y, multiplicand_1); - self.assert_equal_extension(output, x); + // Enforce that q y = x. + let q_y = self.mul_extension(quotient, y); + self.assert_equal_extension(q_y, x); - multiplicand_0 - } -} - -/// Generator used to zero out wires at a given gate index and ranges. -pub struct ZeroOutGenerator { - gate_index: usize, - ranges: Vec>, -} - -impl SimpleGenerator for ZeroOutGenerator { - fn dependencies(&self) -> Vec { - Vec::new() + quotient } - fn run_once(&self, _witness: &PartialWitness) -> PartialWitness { - let mut pw = PartialWitness::new(); - for t in self - .ranges - .iter() - .flat_map(|r| Target::wires_from_range(self.gate_index, r.clone())) - { - pw.set_target(t, F::ZERO); - } + /// Computes `1 / x`. Results in an unsatisfiable instance if `x = 0`. + pub fn inverse_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { + let inv = self.add_virtual_extension_target(); + let one = self.one_extension(); + self.add_generator(QuotientGeneratorExtension { + numerator: one, + denominator: x, + quotient: inv, + }); - pw + // Enforce that x times its purported inverse equals 1. + let x_inv = self.mul_extension(x, inv); + self.assert_equal_extension(x_inv, one); + + inv } } @@ -407,10 +390,7 @@ impl, const D: usize> SimpleGenerator for QuotientGeneratorE let num = witness.get_extension_target(self.numerator); let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; - let mut pw = PartialWitness::new(); - pw.set_extension_target(self.quotient, quotient); - - pw + PartialWitness::singleton_extension_target(self.quotient, quotient) } } @@ -479,8 +459,10 @@ mod tests { let xt = builder.constant_extension(x); let yt = builder.constant_extension(y); let zt = builder.constant_extension(z); - let comp_zt = builder.div_unsafe_extension(xt, yt); + let comp_zt = builder.div_extension(xt, yt); + let comp_zt_unsafe = builder.div_unsafe_extension(xt, yt); builder.assert_equal_extension(zt, comp_zt); + builder.assert_equal_extension(zt, comp_zt_unsafe); let data = builder.build(); let proof = data.prove(PartialWitness::new());