Division related changes (#99)

* Division related changes

- Simplify `div_unsafe_extension` using virtual targets
- Add methods for inversion and safe division

As a followup I'll switch some calls to safe division.

* Test safe division also

* add_virtual_extension_target
This commit is contained in:
Daniel Lubarov 2021-07-18 23:05:57 -07:00 committed by GitHub
parent b937679292
commit 9c17a00c00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 53 deletions

View File

@ -179,6 +179,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.exp_u64_extension(base_ext, exponent).0[0] self.exp_u64_extension(base_ext, exponent).0[0]
} }
/// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`.
pub fn div(&mut self, x: Target, y: Target) -> Target {
let y_inv = self.inverse(y);
self.mul(x, y_inv)
}
/// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in
/// some cases, as it allows `0 / 0 = <anything>`. /// some cases, as it allows `0 / 0 = <anything>`.
pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target {
@ -201,4 +207,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let y_ext = self.convert_to_ext(y); let y_ext = self.convert_to_ext(y);
self.div_unsafe_extension(x_ext, y_ext).0[0] self.div_unsafe_extension(x_ext, y_ext).0[0]
} }
/// Computes `1 / x`. Results in an unsatisfiable instance if `x = 0`.
pub fn inverse(&mut self, x: Target) -> Target {
let x_ext = self.convert_to_ext(x);
self.inverse_extension(x_ext).0[0]
}
} }

View File

@ -1,4 +1,4 @@
use std::convert::TryInto; use std::convert::{TryFrom, TryInto};
use std::ops::Range; use std::ops::Range;
use itertools::Itertools; use itertools::Itertools;
@ -324,6 +324,16 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
product product
} }
/// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`.
pub fn div_extension(
&mut self,
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let y_inv = self.inverse_extension(y);
self.mul_extension(x, y_inv)
}
/// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in
/// some cases, as it allows `0 / 0 = <anything>`. /// some cases, as it allows `0 / 0 = <anything>`.
pub fn div_unsafe_extension( pub fn div_unsafe_extension(
@ -331,62 +341,35 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
x: ExtensionTarget<D>, x: ExtensionTarget<D>,
y: ExtensionTarget<D>, y: ExtensionTarget<D>,
) -> ExtensionTarget<D> { ) -> ExtensionTarget<D> {
// Add an `ArithmeticExtensionGate` to compute `q * y`. let quotient = self.add_virtual_extension_target();
let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]);
let multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand(),
);
let multiplicand_1 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_multiplicand_0());
let output =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_0());
self.add_generator(QuotientGeneratorExtension { self.add_generator(QuotientGeneratorExtension {
numerator: x, numerator: x,
denominator: y, denominator: y,
quotient: multiplicand_0, quotient,
});
// We need to zero out the other wires for the `ArithmeticExtensionGenerator` to hit.
self.add_generator(ZeroOutGenerator {
gate_index: gate,
ranges: vec![
ArithmeticExtensionGate::<D>::wires_addend_0(),
ArithmeticExtensionGate::<D>::wires_multiplicand_1(),
ArithmeticExtensionGate::<D>::wires_addend_1(),
],
}); });
self.route_extension(y, multiplicand_1); // Enforce that q y = x.
self.assert_equal_extension(output, x); let q_y = self.mul_extension(quotient, y);
self.assert_equal_extension(q_y, x);
multiplicand_0 quotient
}
}
/// Generator used to zero out wires at a given gate index and ranges.
pub struct ZeroOutGenerator {
gate_index: usize,
ranges: Vec<Range<usize>>,
}
impl<F: Field> SimpleGenerator<F> for ZeroOutGenerator {
fn dependencies(&self) -> Vec<Target> {
Vec::new()
} }
fn run_once(&self, _witness: &PartialWitness<F>) -> PartialWitness<F> { /// Computes `1 / x`. Results in an unsatisfiable instance if `x = 0`.
let mut pw = PartialWitness::new(); pub fn inverse_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
for t in self let inv = self.add_virtual_extension_target();
.ranges let one = self.one_extension();
.iter() self.add_generator(QuotientGeneratorExtension {
.flat_map(|r| Target::wires_from_range(self.gate_index, r.clone())) numerator: one,
{ denominator: x,
pw.set_target(t, F::ZERO); quotient: inv,
} });
pw // Enforce that x times its purported inverse equals 1.
let x_inv = self.mul_extension(x, inv);
self.assert_equal_extension(x_inv, one);
inv
} }
} }
@ -407,10 +390,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for QuotientGeneratorE
let num = witness.get_extension_target(self.numerator); let num = witness.get_extension_target(self.numerator);
let dem = witness.get_extension_target(self.denominator); let dem = witness.get_extension_target(self.denominator);
let quotient = num / dem; let quotient = num / dem;
let mut pw = PartialWitness::new(); PartialWitness::singleton_extension_target(self.quotient, quotient)
pw.set_extension_target(self.quotient, quotient);
pw
} }
} }
@ -479,8 +459,10 @@ mod tests {
let xt = builder.constant_extension(x); let xt = builder.constant_extension(x);
let yt = builder.constant_extension(y); let yt = builder.constant_extension(y);
let zt = builder.constant_extension(z); let zt = builder.constant_extension(z);
let comp_zt = builder.div_unsafe_extension(xt, yt); let comp_zt = builder.div_extension(xt, yt);
let comp_zt_unsafe = builder.div_unsafe_extension(xt, yt);
builder.assert_equal_extension(zt, comp_zt); builder.assert_equal_extension(zt, comp_zt);
builder.assert_equal_extension(zt, comp_zt_unsafe);
let data = builder.build(); let data = builder.build();
let proof = data.prove(PartialWitness::new()); let proof = data.prove(PartialWitness::new());