plonky2/src/gadgets/arithmetic.rs

411 lines
13 KiB
Rust
Raw Normal View History

2021-06-25 15:11:49 +02:00
use std::ops::Range;
use crate::circuit_builder::CircuitBuilder;
2021-06-04 17:36:48 +02:00
use crate::field::extension_field::target::ExtensionTarget;
2021-06-07 17:09:53 +02:00
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::generator::SimpleGenerator;
2021-04-21 22:31:45 +02:00
use crate::target::Target;
2021-06-25 16:27:20 +02:00
use crate::util::bits_u64;
2021-04-02 15:29:21 -07:00
use crate::wire::Wire;
2021-04-21 11:47:18 -07:00
use crate::witness::PartialWitness;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `-x`.
2021-04-02 15:29:21 -07:00
pub fn neg(&mut self, x: Target) -> Target {
let neg_one = self.neg_one();
self.mul(x, neg_one)
}
/// Computes `x^2`.
pub fn square(&mut self, x: Target) -> Target {
self.mul(x, x)
}
2021-06-25 16:27:20 +02:00
/// Computes `x^2`.
pub fn square_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
self.mul_extension(x, x)
}
/// Computes `x^3`.
pub fn cube(&mut self, x: Target) -> Target {
self.mul_many(&[x, x, x])
}
2021-04-21 11:47:18 -07:00
/// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`.
pub fn arithmetic(
&mut self,
const_0: F,
multiplicand_0: Target,
multiplicand_1: Target,
const_1: F,
addend: Target,
) -> Target {
// See if we can determine the result without adding an `ArithmeticGate`.
if let Some(result) =
self.arithmetic_special_cases(const_0, multiplicand_0, multiplicand_1, const_1, addend)
{
2021-04-21 11:47:18 -07:00
return result;
2021-04-02 15:29:21 -07:00
}
2021-06-25 13:53:14 +02:00
let multiplicand_0_ext = self.convert_to_ext(multiplicand_0);
let multiplicand_1_ext = self.convert_to_ext(multiplicand_1);
let addend_ext = self.convert_to_ext(addend);
self.arithmetic_extension(
const_0,
const_1,
multiplicand_0_ext,
multiplicand_1_ext,
addend_ext,
)
.0[0]
2021-04-02 15:29:21 -07:00
}
2021-04-21 11:47:18 -07:00
/// Checks for special cases where the value of
/// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`
/// can be determined without adding an `ArithmeticGate`.
fn arithmetic_special_cases(
&mut self,
const_0: F,
multiplicand_0: Target,
multiplicand_1: Target,
const_1: F,
addend: Target,
) -> Option<Target> {
let zero = self.zero();
let mul_0_const = self.target_as_constant(multiplicand_0);
let mul_1_const = self.target_as_constant(multiplicand_1);
let addend_const = self.target_as_constant(addend);
let first_term_zero =
const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero;
2021-04-21 11:47:18 -07:00
let second_term_zero = const_1 == F::ZERO || addend == zero;
// If both terms are constant, return their (constant) sum.
let first_term_const = if first_term_zero {
Some(F::ZERO)
} else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) {
Some(const_0 * x * y)
} else {
None
};
let second_term_const = if second_term_zero {
Some(F::ZERO)
} else {
addend_const.map(|x| const_1 * x)
};
if let (Some(x), Some(y)) = (first_term_const, second_term_const) {
return Some(self.constant(x + y));
}
2021-04-23 12:35:19 -07:00
if first_term_zero && const_1.is_one() {
return Some(addend);
2021-04-21 11:47:18 -07:00
}
if second_term_zero {
if let Some(x) = mul_0_const {
if (const_0 * x).is_one() {
return Some(multiplicand_1);
}
}
if let Some(x) = mul_1_const {
if (const_1 * x).is_one() {
return Some(multiplicand_0);
}
}
}
None
}
/// Computes `x * y + z`.
pub fn mul_add(&mut self, x: Target, y: Target, z: Target) -> Target {
self.arithmetic(F::ONE, x, y, F::ONE, z)
}
/// Computes `x * y - z`.
pub fn mul_sub(&mut self, x: Target, y: Target, z: Target) -> Target {
self.arithmetic(F::ONE, x, y, F::NEG_ONE, z)
}
/// Computes `x + y`.
2021-04-21 11:47:18 -07:00
pub fn add(&mut self, x: Target, y: Target) -> Target {
let one = self.one();
// x + y = 1 * x * 1 + 1 * y
self.arithmetic(F::ONE, x, one, F::ONE, y)
}
2021-04-02 15:29:21 -07:00
pub fn add_many(&mut self, terms: &[Target]) -> Target {
let mut sum = self.zero();
for term in terms {
sum = self.add(sum, *term);
}
sum
}
/// Computes `x - y`.
pub fn sub(&mut self, x: Target, y: Target) -> Target {
2021-04-21 11:47:18 -07:00
let one = self.one();
// x - y = 1 * x * 1 + (-1) * y
self.arithmetic(F::ONE, x, one, F::NEG_ONE, y)
}
/// Computes `x * y`.
pub fn mul(&mut self, x: Target, y: Target) -> Target {
2021-04-21 11:47:18 -07:00
// x * y = 1 * x * y + 0 * x
self.arithmetic(F::ONE, x, y, F::ZERO, x)
}
2021-04-02 15:29:21 -07:00
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
let mut product = self.one();
for term in terms {
product = self.mul(product, *term);
}
product
}
// TODO: Optimize this, maybe with a new gate.
2021-06-25 16:27:20 +02:00
// TODO: Test
2021-06-16 08:56:58 +02:00
/// Exponentiate `base` to the power of `exponent`, where `exponent < 2^num_bits`.
pub fn exp(&mut self, base: Target, exponent: Target, num_bits: usize) -> Target {
let mut current = base;
2021-06-25 16:27:20 +02:00
let one_ext = self.one_extension();
let mut product = self.one();
2021-06-16 08:56:58 +02:00
let exponent_bits = self.split_le(exponent, num_bits);
for bit in exponent_bits.into_iter() {
2021-06-25 16:27:20 +02:00
let current_ext = self.convert_to_ext(current);
let multiplicand = self.select(bit, current_ext, one_ext);
product = self.mul(product, multiplicand.0[0]);
current = self.mul(current, current);
}
product
}
2021-06-25 16:27:20 +02:00
/// Exponentiate `base` to the power of a known `exponent`.
// TODO: Test
pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target {
let mut current = base;
let mut product = self.one();
for j in 0..bits_u64(exponent as u64) {
if (exponent >> j & 1) != 0 {
product = self.mul(product, current);
}
current = self.square(current);
}
product
}
/// Exponentiate `base` to the power of a known `exponent`.
// TODO: Test
pub fn exp_u64_extension(
&mut self,
base: ExtensionTarget<D>,
exponent: u64,
) -> ExtensionTarget<D> {
let mut current = base;
let mut product = self.one_extension();
for j in 0..bits_u64(exponent as u64) {
if (exponent >> j & 1) != 0 {
product = self.mul_extension(product, current);
}
current = self.square_extension(current);
}
product
}
2021-04-21 11:47:18 -07:00
/// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in
/// some cases, as it allows `0 / 0 = <anything>`.
pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target {
// Check for special cases where we can determine the result without an `ArithmeticGate`.
let zero = self.zero();
let one = self.one();
if x == zero {
return zero;
}
if y == one {
return x;
}
if let (Some(x_const), Some(y_const)) =
(self.target_as_constant(x), self.target_as_constant(y))
{
2021-04-21 11:47:18 -07:00
return self.constant(x_const / y_const);
}
2021-06-25 13:53:14 +02:00
let x_ext = self.convert_to_ext(x);
let y_ext = self.convert_to_ext(y);
self.div_unsafe_extension(x_ext, y_ext).0[0]
2021-04-21 11:47:18 -07:00
}
2021-06-07 11:19:54 +02:00
/// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in
/// some cases, as it allows `0 / 0 = <anything>`.
pub fn div_unsafe_extension(
&mut self,
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
2021-06-25 13:53:14 +02:00
// Add an `ArithmeticExtensionGate` to compute `q * y`.
let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ZERO]);
2021-06-07 17:55:27 +02:00
2021-06-25 15:11:49 +02:00
let multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand(),
);
let multiplicand_1 =
2021-06-25 15:11:49 +02:00
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_multiplicand_0());
let output =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_0());
2021-06-07 17:55:27 +02:00
self.add_generator(QuotientGeneratorExtension {
numerator: x,
denominator: y,
quotient: multiplicand_0,
});
2021-06-25 15:11:49 +02:00
self.add_generator(ZeroOutGenerator {
gate_index: gate,
ranges: vec![
ArithmeticExtensionGate::<D>::wires_addend_0(),
ArithmeticExtensionGate::<D>::wires_multiplicand_1(),
ArithmeticExtensionGate::<D>::wires_addend_1(),
],
});
2021-06-04 17:36:48 +02:00
2021-06-07 17:55:27 +02:00
self.route_extension(y, multiplicand_1);
self.assert_equal_extension(output, x);
multiplicand_0
2021-06-04 17:36:48 +02:00
}
}
2021-06-07 11:19:54 +02:00
struct QuotientGeneratorExtension<const D: usize> {
numerator: ExtensionTarget<D>,
denominator: ExtensionTarget<D>,
quotient: ExtensionTarget<D>,
}
2021-06-07 17:09:53 +02:00
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for QuotientGeneratorExtension<D> {
2021-06-07 11:19:54 +02:00
fn dependencies(&self) -> Vec<Target> {
let mut deps = self.numerator.to_target_array().to_vec();
deps.extend(&self.denominator.to_target_array());
deps
}
fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let num = witness.get_extension_target(self.numerator);
let dem = witness.get_extension_target(self.denominator);
let quotient = num / dem;
let mut pw = PartialWitness::new();
2021-06-07 17:09:53 +02:00
for i in 0..D {
pw.set_target(
self.quotient.to_target_array()[i],
quotient.to_basefield_array()[i],
);
}
2021-06-25 15:11:49 +02:00
pw
}
}
/// Witness generator that assigns zero to every wire of gate `gate_index` whose wire index lies
/// in one of `ranges`.
pub struct ZeroOutGenerator {
    gate_index: usize,
    ranges: Vec<Range<usize>>,
}

impl<F: Field> SimpleGenerator<F> for ZeroOutGenerator {
    /// No inputs are needed: the assigned value is always zero.
    fn dependencies(&self) -> Vec<Target> {
        vec![]
    }

    fn run_once(&self, _witness: &PartialWitness<F>) -> PartialWitness<F> {
        let gate = self.gate_index;
        let mut zeroed = PartialWitness::new();
        self.ranges
            .iter()
            .flat_map(|range| Target::wires_from_range(gate, range.clone()))
            .for_each(|wire| {
                zeroed.set_target(wire, F::ZERO);
            });
        zeroed
    }
}
/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`.
///
/// Not a `std::iter::Iterator`: each call to `next` takes the circuit builder and invokes
/// `builder.mul_extension` to produce the following power.
#[derive(Clone)]
pub struct PowersTarget<const D: usize> {
    base: ExtensionTarget<D>,
    current: ExtensionTarget<D>,
}

impl<const D: usize> PowersTarget<D> {
    /// Returns the current power, and advances the state to `current * base`.
    pub fn next<F: Extendable<D>>(
        &mut self,
        builder: &mut CircuitBuilder<F, D>,
    ) -> ExtensionTarget<D> {
        let advanced = builder.mul_extension(self.base, self.current);
        std::mem::replace(&mut self.current, advanced)
    }

    /// Applies `repeated_frobenius(k, builder)` to both the base and the current power
    /// (base first, then current).
    pub fn repeated_frobenius<F: Extendable<D>>(
        self,
        k: usize,
        builder: &mut CircuitBuilder<F, D>,
    ) -> Self {
        let base = self.base.repeated_frobenius(k, builder);
        let current = self.current.repeated_frobenius(k, builder);
        Self { base, current }
    }
}
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
    /// Starts a `PowersTarget` over `base^0, base^1, ...`; the first power yielded is
    /// `one_extension()`.
    pub fn powers(&mut self, base: ExtensionTarget<D>) -> PowersTarget<D> {
        let one = self.one_extension();
        PowersTarget { base, current: one }
    }
}
2021-06-07 17:55:27 +02:00
#[cfg(test)]
mod tests {
    use crate::circuit_builder::CircuitBuilder;
    use crate::circuit_data::CircuitConfig;
    use crate::field::crandall_field::CrandallField;
    use crate::field::extension_field::quartic::QuarticCrandallField;
    use crate::field::field::Field;
    use crate::witness::PartialWitness;

    /// Builds a circuit asserting that `div_unsafe_extension(x, y)` equals the natively computed
    /// `x / y` for random constant inputs, then generates a proof for it.
    #[test]
    fn test_div_extension() {
        type F = CrandallField;
        type FF = QuarticCrandallField;
        const D: usize = 4;

        let config = CircuitConfig::large_config();
        let mut builder = CircuitBuilder::<F, D>::new(config);

        let x = FF::rand();
        let y = FF::rand();
        let z = x / y;
        let xt = builder.constant_extension(x);
        let yt = builder.constant_extension(y);
        let zt = builder.constant_extension(z);
        let comp_zt = builder.div_unsafe_extension(xt, yt);
        builder.assert_equal_extension(zt, comp_zt);

        let data = builder.build();
        // Only checks that proof generation runs; the proof itself is not verified here.
        let _proof = data.prove(PartialWitness::new());
    }
}