diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index ebad5025..e2794330 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -3,6 +3,7 @@ use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use num::bigint::BigUint; +use num::Integer; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -99,6 +100,15 @@ impl> Field for QuadraticExtension { )) } + fn from_biguint(n: BigUint) -> Self { + let (high, low) = n.div_rem(&F::order()); + Self([F::from_biguint(low), F::from_biguint(high)]) + } + + fn to_biguint(&self) -> BigUint { + self.0[0].to_biguint() + F::order() * self.0[1].to_biguint() + } + fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 001da821..01918ff3 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi use num::bigint::BigUint; use num::traits::Pow; +use num::Integer; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -104,6 +105,26 @@ impl> Field for QuarticExtension { )) } + fn from_biguint(n: BigUint) -> Self { + let (rest, first) = n.div_rem(&F::order()); + let (rest, second) = rest.div_rem(&F::order()); + let (rest, third) = rest.div_rem(&F::order()); + Self([ + F::from_biguint(first), + F::from_biguint(second), + F::from_biguint(third), + F::from_biguint(rest), + ]) + } + + fn to_biguint(&self) -> BigUint { + let mut result = self.0[3].to_biguint(); + result = result * F::order() + self.0[2].to_biguint(); + result = result * F::order() + self.0[1].to_biguint(); + result = result * F::order() + self.0[0].to_biguint(); + result + } + fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 4fe10b17..f5d06fdb 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -206,6 +206,11 @@ pub trait Field: subgroup.into_iter().map(|x| x * shift).collect() } + // TODO: move these to a new `PrimeField` trait (for all prime fields, not just 64-bit ones) + fn from_biguint(n: BigUint) -> Self; + + fn to_biguint(&self) -> BigUint; + fn from_canonical_u64(n: u64) -> Self; fn from_canonical_u32(n: u32) -> Self { diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 45164506..cb85d56d 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -4,7 +4,7 @@ use std::hash::{Hash, Hasher}; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use num::BigUint; +use num::{BigUint, Integer}; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -91,6 +91,14 @@ impl Field for GoldilocksField { try_inverse_u64(self) } + fn from_biguint(n: BigUint) -> Self { + Self(n.mod_floor(&Self::order()).to_u64_digits()[0]) + } + + fn to_biguint(&self) -> BigUint { + self.to_canonical_u64().into() + } + #[inline] fn from_canonical_u64(n: u64) -> Self { debug_assert!(n < Self::ORDER); diff --git a/src/field/secp256k1.rs b/src/field/secp256k1.rs index 56d506d6..5f8e1b4e 100644 --- a/src/field/secp256k1.rs +++ b/src/field/secp256k1.rs @@ -36,27 +36,6 @@ fn biguint_from_array(arr: [u64; 4]) -> BigUint { ]) } -impl Secp256K1Base { - fn to_canonical_biguint(&self) -> BigUint { - let mut result = biguint_from_array(self.0); - if result >= Self::order() { - result -= Self::order(); - } - result - } - - fn from_biguint(val: BigUint) -> Self { - Self( - val.to_u64_digits() - .into_iter() - .pad_using(4, |_| 0) - .collect::>()[..] - .try_into() - .expect("error converting to u64 array"), - ) - } -} - impl Default for Secp256K1Base { fn default() -> Self { Self::ZERO @@ -65,7 +44,7 @@ impl Default for Secp256K1Base { impl PartialEq for Secp256K1Base { fn eq(&self, other: &Self) -> bool { - self.to_canonical_biguint() == other.to_canonical_biguint() + self.to_biguint() == other.to_biguint() } } @@ -73,19 +52,19 @@ impl Eq for Secp256K1Base {} impl Hash for Secp256K1Base { fn hash(&self, state: &mut H) { - self.to_canonical_biguint().hash(state) + self.to_biguint().hash(state) } } impl Display for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(&self.to_canonical_biguint(), f) + Display::fmt(&self.to_biguint(), f) } } impl Debug for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.to_canonical_biguint(), f) + Debug::fmt(&self.to_biguint(), f) } } @@ -129,6 +108,25 @@ impl Field for Secp256K1Base { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } + fn to_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } + + fn from_biguint(val: BigUint) -> Self { + Self( + val.to_u64_digits() + .into_iter() + .pad_using(4, |_| 0) + .collect::>()[..] + .try_into() + .expect("error converting to u64 array"), + ) + } + #[inline] fn from_canonical_u64(n: u64) -> Self { Self([n, 0, 0, 0]) @@ -157,7 +155,7 @@ impl Neg for Secp256K1Base { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_canonical_biguint()) + Self::from_biguint(Self::order() - self.to_biguint()) } } } @@ -167,7 +165,7 @@ impl Add for Secp256K1Base { #[inline] fn add(self, rhs: Self) -> Self { - let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); + let mut result = self.to_biguint() + rhs.to_biguint(); if result >= Self::order() { result -= Self::order(); } @@ -210,9 +208,7 @@ impl Mul for Secp256K1Base { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint( - (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), - ) + Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) } } diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 24499760..e2654dcc 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -12,33 +12,6 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bits_u64; impl, const D: usize> CircuitBuilder { - /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - /// `g` and the gate's `i`-th operation is available. - fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { - let (gate, i) = self - .free_arithmetic - .get(&(const_0, const_1)) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - ArithmeticExtensionGate::new_from_config(&self.config), - vec![const_0, const_1], - ); - (gate, 0) - }); - - // Update `free_arithmetic` with new values. - if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { - self.free_arithmetic - .insert((const_0, const_1), (gate, i + 1)); - } else { - self.free_arithmetic.remove(&(const_0, const_1)); - } - - (gate, i) - } - pub fn arithmetic_extension( &mut self, const_0: F, diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs new file mode 100644 index 00000000..ce7aa121 --- /dev/null +++ b/src/gadgets/arithmetic_u32.rs @@ -0,0 +1,182 @@ +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gates::arithmetic_u32::U32ArithmeticGate; +use crate::gates::subtraction_u32::U32SubtractionGate; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +#[derive(Clone, Copy, Debug)] +pub struct U32Target(pub Target); + +impl, const D: usize> CircuitBuilder { + pub fn add_virtual_u32_target(&mut self) -> U32Target { + U32Target(self.add_virtual_target()) + } + + pub fn add_virtual_u32_targets(&mut self, n: usize) -> Vec { + self.add_virtual_targets(n) + .into_iter() + .map(U32Target) + .collect() + } + + pub fn zero_u32(&mut self) -> U32Target { + U32Target(self.zero()) + } + + pub fn one_u32(&mut self) -> U32Target { + U32Target(self.one()) + } + + pub fn connect_u32(&mut self, x: U32Target, y: U32Target) { + self.connect(x.0, y.0) + } + + pub fn assert_zero_u32(&mut self, x: U32Target) { + self.assert_zero(x.0) + } + + /// Checks for special cases where the value of + /// `x * y + z` + /// can be determined without adding a `U32ArithmeticGate`. + pub fn arithmetic_u32_special_cases( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> Option<(U32Target, U32Target)> { + let x_const = self.target_as_constant(x.0); + let y_const = self.target_as_constant(y.0); + let z_const = self.target_as_constant(z.0); + + // If both terms are constant, return their (constant) sum. + let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) { + Some(xx * yy) + } else { + None + }; + + if let (Some(a), Some(b)) = (first_term_const, z_const) { + let sum = (a + b).to_canonical_u64(); + let (low, high) = (sum as u32, (sum >> 32) as u32); + return Some((self.constant_u32(low), self.constant_u32(high))); + } + + None + } + + // Returns x * y + z. + pub fn mul_add_u32( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> (U32Target, U32Target) { + if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) { + return result; + } + + let (gate_index, copy) = self.find_u32_arithmetic_gate(); + + self.connect( + Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_0(copy), + ), + x.0, + ); + self.connect( + Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_1(copy), + ), + y.0, + ); + self.connect( + Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), + z.0, + ); + + let output_low = U32Target(Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_output_low_half(copy), + )); + let output_high = U32Target(Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_output_high_half(copy), + )); + + (output_low, output_high) + } + + pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + let one = self.one_u32(); + self.mul_add_u32(a, one, b) + } + + pub fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) { + match to_add.len() { + 0 => (self.zero_u32(), self.zero_u32()), + 1 => (to_add[0], self.zero_u32()), + 2 => self.add_u32(to_add[0], to_add[1]), + _ => { + let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]); + for i in 2..to_add.len() { + let (new_low, new_carry) = self.add_u32(to_add[i], low); + let (combined_carry, _zero) = self.add_u32(carry, new_carry); + low = new_low; + carry = combined_carry; + } + (low, carry) + } + } + } + + pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + let zero = self.zero_u32(); + self.mul_add_u32(a, b, zero) + } + + // Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x). + pub fn sub_u32( + &mut self, + x: U32Target, + y: U32Target, + borrow: U32Target, + ) -> (U32Target, U32Target) { + let (gate_index, copy) = self.find_u32_subtraction_gate(); + + self.connect( + Target::wire( + gate_index, + U32SubtractionGate::::wire_ith_input_x(copy), + ), + x.0, + ); + self.connect( + Target::wire( + gate_index, + U32SubtractionGate::::wire_ith_input_y(copy), + ), + y.0, + ); + self.connect( + Target::wire( + gate_index, + U32SubtractionGate::::wire_ith_input_borrow(copy), + ), + borrow.0, + ); + + let output_result = U32Target(Target::wire( + gate_index, + U32SubtractionGate::::wire_ith_output_result(copy), + )); + let output_borrow = U32Target(Target::wire( + gate_index, + U32SubtractionGate::::wire_ith_output_borrow(copy), + )); + + (output_result, output_borrow) + } +} diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs new file mode 100644 index 00000000..fff97a6e --- /dev/null +++ b/src/gadgets/biguint.rs @@ -0,0 +1,371 @@ +use std::marker::PhantomData; + +use num::{BigUint, Integer}; + +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; + +#[derive(Clone, Debug)] +pub struct BigUintTarget { + pub limbs: Vec, +} + +impl BigUintTarget { + pub fn num_limbs(&self) -> usize { + self.limbs.len() + } + + pub fn get_limb(&self, i: usize) -> U32Target { + self.limbs[i] + } +} + +impl, const D: usize> CircuitBuilder { + pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { + let limb_values = value.to_u32_digits(); + let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); + + BigUintTarget { limbs } + } + + pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { + let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); + for i in 0..min_limbs { + self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); + } + + for i in min_limbs..lhs.num_limbs() { + self.assert_zero_u32(lhs.get_limb(i)); + } + for i in min_limbs..rhs.num_limbs() { + self.assert_zero_u32(rhs.get_limb(i)); + } + } + + pub fn pad_biguints( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { + if a.num_limbs() > b.num_limbs() { + let mut padded_b = b.clone(); + for _ in b.num_limbs()..a.num_limbs() { + padded_b.limbs.push(self.zero_u32()); + } + + (a.clone(), padded_b) + } else { + let mut padded_a = a.clone(); + for _ in a.num_limbs()..b.num_limbs() { + padded_a.limbs.push(self.zero_u32()); + } + + (padded_a, b.clone()) + } + } + + pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { + let (a, b) = self.pad_biguints(a, b); + + self.list_le_u32(a.limbs, b.limbs) + } + + pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { + let limbs = (0..num_limbs) + .map(|_| self.add_virtual_u32_target()) + .collect(); + + BigUintTarget { limbs } + } + + // Add two `BigUintTarget`s. + pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let num_limbs = a.num_limbs().max(b.num_limbs()); + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..num_limbs { + let a_limb = (i < a.num_limbs()) + .then(|| a.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + let b_limb = (i < b.num_limbs()) + .then(|| b.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + + let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]); + carry = new_carry; + combined_limbs.push(new_limb); + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } + + // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let num_limbs = a.limbs.len(); + let (a, b) = self.pad_biguints(a, b); + + let mut result_limbs = vec![]; + + let mut borrow = self.zero_u32(); + for i in 0..num_limbs { + let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); + result_limbs.push(result); + borrow = new_borrow; + } + // Borrow should be zero here. + + BigUintTarget { + limbs: result_limbs, + } + } + + pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let total_limbs = a.limbs.len() + b.limbs.len(); + + let mut to_add = vec![vec![]; total_limbs]; + for i in 0..a.limbs.len() { + for j in 0..b.limbs.len() { + let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); + to_add[i + j].push(product); + to_add[i + j + 1].push(carry); + } + } + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..total_limbs { + to_add[i].push(carry); + let (new_result, new_carry) = self.add_many_u32(&to_add[i].clone()); + combined_limbs.push(new_result); + carry = new_carry; + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } + + pub fn div_rem_biguint( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { + let num_limbs = a.limbs.len(); + let div = self.add_virtual_biguint_target(num_limbs); + let rem = self.add_virtual_biguint_target(num_limbs); + + self.add_simple_generator(BigUintDivRemGenerator:: { + a: a.clone(), + b: b.clone(), + div: div.clone(), + rem: rem.clone(), + _phantom: PhantomData, + }); + + let div_b = self.mul_biguint(&div, &b); + let div_b_plus_rem = self.add_biguint(&div_b, &rem); + self.connect_biguint(&a, &div_b_plus_rem); + + let cmp_rem_b = self.cmp_biguint(&rem, b); + self.assert_one(cmp_rem_b.target); + + (div, rem) + } + + pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (div, _rem) = self.div_rem_biguint(a, b); + div + } + + pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (_div, rem) = self.div_rem_biguint(a, b); + rem + } +} + +#[derive(Debug)] +struct BigUintDivRemGenerator, const D: usize> { + a: BigUintTarget, + b: BigUintTarget, + div: BigUintTarget, + rem: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for BigUintDivRemGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .limbs + .iter() + .chain(&self.b.limbs) + .map(|&l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_biguint_target(self.a.clone()); + let b = witness.get_biguint_target(self.b.clone()); + let (div, rem) = a.div_rem(&b); + + out_buffer.set_biguint_target(self.div.clone(), div); + out_buffer.set_biguint_target(self.rem.clone(), rem); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use num::{BigUint, FromPrimitive, Integer}; + use rand::Rng; + + use crate::{ + field::goldilocks_field::GoldilocksField, + iop::witness::PartialWitness, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, + }; + + #[test] + fn test_biguint_add() -> Result<()> { + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + let expected_z_value = &x_value + &y_value; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.add_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); + + builder.connect_biguint(&z, &expected_z); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_sub() -> Result<()> { + let mut rng = rand::thread_rng(); + + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); + } + let expected_z_value = &x_value - &y_value; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.sub_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); + + builder.connect_biguint(&z, &expected_z); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_mul() -> Result<()> { + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + let expected_z_value = &x_value * &y_value; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.mul_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); + + builder.connect_biguint(&z, &expected_z); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_cmp() -> Result<()> { + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let cmp = builder.cmp_biguint(&x, &y); + let expected_cmp = builder.constant_bool(x_value <= y_value); + + builder.connect(cmp.target, expected_cmp.target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_biguint_div_rem() -> Result<()> { + let mut rng = rand::thread_rng(); + + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); + } + let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let (div, rem) = builder.div_rem_biguint(&x, &y); + + let expected_div = builder.constant_biguint(&expected_div_value); + let expected_rem = builder.constant_biguint(&expected_rem_value); + + builder.connect_biguint(&div, &expected_div); + builder.connect_biguint(&rem, &expected_rem); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index aa18fbeb..8b6e60f6 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,8 +1,12 @@ pub mod arithmetic; pub mod arithmetic_extension; +pub mod arithmetic_u32; +pub mod biguint; pub mod hash; pub mod insert; pub mod interpolation; +pub mod multiple_comparison; +pub mod nonnative; pub mod permutation; pub mod polynomial; pub mod random_access; diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs new file mode 100644 index 00000000..77e660e6 --- /dev/null +++ b/src/gadgets/multiple_comparison.rs @@ -0,0 +1,138 @@ +use super::arithmetic_u32::U32Target; +use crate::field::extension_field::Extendable; +use crate::field::field_types::RichField; +use crate::gates::comparison::ComparisonGate; +use crate::iop::target::{BoolTarget, Target}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; + +impl, const D: usize> CircuitBuilder { + /// Returns true if a is less than or equal to b, considered as base-`2^num_bits` limbs of a large value. + /// This range-checks its inputs. + pub fn list_le(&mut self, a: Vec, b: Vec, num_bits: usize) -> BoolTarget { + assert_eq!( + a.len(), + b.len(), + "Comparison must be between same number of inputs and outputs" + ); + let n = a.len(); + + let chunk_bits = 2; + let num_chunks = ceil_div_usize(num_bits, chunk_bits); + + let one = self.one(); + let mut result = one; + for i in 0..n { + let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks); + let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![]); + self.connect( + Target::wire(a_le_b_gate_index, a_le_b_gate.wire_first_input()), + a[i], + ); + self.connect( + Target::wire(a_le_b_gate_index, a_le_b_gate.wire_second_input()), + b[i], + ); + let a_le_b_result = Target::wire(a_le_b_gate_index, a_le_b_gate.wire_result_bool()); + + let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks); + let b_le_a_gate_index = self.add_gate(b_le_a_gate.clone(), vec![]); + self.connect( + Target::wire(b_le_a_gate_index, b_le_a_gate.wire_first_input()), + b[i], + ); + self.connect( + Target::wire(b_le_a_gate_index, b_le_a_gate.wire_second_input()), + a[i], + ); + let b_le_a_result = Target::wire(b_le_a_gate_index, b_le_a_gate.wire_result_bool()); + + let these_limbs_equal = self.mul(a_le_b_result, b_le_a_result); + let these_limbs_less_than = self.sub(one, b_le_a_result); + result = self.mul_add(these_limbs_equal, result, these_limbs_less_than); + } + + // `result` being boolean is an invariant, maintained because its new value is always + // `x * result + y`, where `x` and `y` are booleans that are not simultaneously true. + BoolTarget::new_unsafe(result) + } + + /// Helper function for comparing, specifically, lists of `U32Target`s. + pub fn list_le_u32(&mut self, a: Vec, b: Vec) -> BoolTarget { + let a_targets = a.iter().map(|&t| t.0).collect(); + let b_targets = b.iter().map(|&t| t.0).collect(); + self.list_le(a_targets, b_targets, 32) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use num::BigUint; + use rand::Rng; + + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + fn test_list_le(size: usize, num_bits: usize) -> Result<()> { + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = rand::thread_rng(); + + let lst1: Vec = (0..size) + .map(|_| rng.gen_range(0..(1 << num_bits))) + .collect(); + let lst2: Vec = (0..size) + .map(|_| rng.gen_range(0..(1 << num_bits))) + .collect(); + + let a_biguint = BigUint::from_slice( + &lst1 + .iter() + .flat_map(|&x| [x as u32, (x >> 32) as u32]) + .collect::>(), + ); + let b_biguint = BigUint::from_slice( + &lst2 + .iter() + .flat_map(|&x| [x as u32, (x >> 32) as u32]) + .collect::>(), + ); + + let a = lst1 + .iter() + .map(|&x| builder.constant(F::from_canonical_u64(x))) + .collect(); + let b = lst2 + .iter() + .map(|&x| builder.constant(F::from_canonical_u64(x))) + .collect(); + + let result = builder.list_le(a, b, num_bits); + + let expected_result = builder.constant_bool(a_biguint <= b_biguint); + builder.connect(result.target, expected_result.target); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_multiple_comparison() -> Result<()> { + for size in [1, 3, 6, 10] { + for num_bits in [20, 32, 40, 50] { + test_list_le(size, num_bits).unwrap(); + } + } + + Ok(()) + } +} diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs new file mode 100644 index 00000000..fd883e5d --- /dev/null +++ b/src/gadgets/nonnative.rs @@ -0,0 +1,216 @@ +use std::marker::PhantomData; + +use num::{BigUint, One}; + +use crate::field::field_types::RichField; +use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gadgets::biguint::BigUintTarget; +use crate::plonk::circuit_builder::CircuitBuilder; + +pub struct ForeignFieldTarget { + value: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize> CircuitBuilder { + pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { + ForeignFieldTarget { + value: x.clone(), + _phantom: PhantomData, + } + } + + pub fn nonnative_to_biguint(&mut self, x: &ForeignFieldTarget) -> BigUintTarget { + x.value.clone() + } + + pub fn constant_nonnative(&mut self, x: FF) -> ForeignFieldTarget { + let x_biguint = self.constant_biguint(&x.to_biguint()); + self.biguint_to_nonnative(&x_biguint) + } + + // Assert that two ForeignFieldTarget's, both assumed to be in reduced form, are equal. + pub fn connect_nonnative( + &mut self, + lhs: &ForeignFieldTarget, + rhs: &ForeignFieldTarget, + ) { + self.connect_biguint(&lhs.value, &rhs.value); + } + + // Add two `ForeignFieldTarget`s. + pub fn add_nonnative( + &mut self, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let result = self.add_biguint(&a.value, &b.value); + + self.reduce(&result) + } + + // Subtract two `ForeignFieldTarget`s. + pub fn sub_nonnative( + &mut self, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let order = self.constant_biguint(&FF::order()); + let a_plus_order = self.add_biguint(&order, &a.value); + let result = self.sub_biguint(&a_plus_order, &b.value); + + // TODO: reduce sub result with only one conditional addition? + self.reduce(&result) + } + + pub fn mul_nonnative( + &mut self, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let result = self.mul_biguint(&a.value, &b.value); + + self.reduce(&result) + } + + pub fn neg_nonnative( + &mut self, + x: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let neg_one = FF::order() - BigUint::one(); + let neg_one_target = self.constant_biguint(&neg_one); + let neg_one_ff = self.biguint_to_nonnative(&neg_one_target); + + self.mul_nonnative(&neg_one_ff, x) + } + + /// Returns `x % |FF|` as a `ForeignFieldTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> ForeignFieldTarget { + let modulus = FF::order(); + let order_target = self.constant_biguint(&modulus); + let value = self.rem_biguint(x, &order_target); + + ForeignFieldTarget { + value, + _phantom: PhantomData, + } + } + + fn reduce_nonnative( + &mut self, + x: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let x_biguint = self.nonnative_to_biguint(x); + self.reduce(&x_biguint) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::field::secp256k1::Secp256K1Base; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + #[test] + fn test_nonnative_add() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let sum_ff = x_ff + y_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let sum = builder.add_nonnative(&x, &y); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_nonnative_sub() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let mut y_ff = FF::rand(); + while y_ff.to_biguint() > x_ff.to_biguint() { + y_ff = FF::rand(); + } + let diff_ff = x_ff - y_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let diff = builder.sub_nonnative(&x, &y); + + let diff_expected = builder.constant_nonnative(diff_ff); + builder.connect_nonnative(&diff, &diff_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_nonnative_mul() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let product_ff = x_ff * y_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let product = builder.mul_nonnative(&x, &y); + + let product_expected = builder.constant_nonnative(product_ff); + builder.connect_nonnative(&product, &product_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_nonnative_neg() -> Result<()> { + type FF = Secp256K1Base; + let x_ff = FF::rand(); + let neg_x_ff = -x_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let neg_x = builder.neg_nonnative(&x); + + let neg_x_expected = builder.constant_nonnative(neg_x_ff); + builder.connect_nonnative(&neg_x, &neg_x_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index fd4a897f..c60eda7d 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -35,7 +35,6 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone()) } // For larger lists, we recursively use two smaller permutation networks. - //_ => self.assert_permutation_recursive(a, b) _ => self.assert_permutation_recursive(a, b), } } @@ -73,22 +72,7 @@ impl, const D: usize> CircuitBuilder { let chunk_size = a1.len(); - if self.current_switch_gates.len() < chunk_size { - self.current_switch_gates - .extend(vec![None; chunk_size - self.current_switch_gates.len()]); - } - - let (gate, gate_index, mut next_copy) = - match self.current_switch_gates[chunk_size - 1].clone() { - None => { - let gate = SwitchGate::::new_from_config(&self.config, chunk_size); - let gate_index = self.add_gate(gate.clone(), vec![]); - (gate, gate_index, 0) - } - Some((gate, idx, next_copy)) => (gate, idx, next_copy), - }; - - let num_copies = gate.num_copies; + let (gate, gate_index, next_copy) = self.find_switch_gate(chunk_size); let mut c = Vec::new(); let mut d = Vec::new(); @@ -113,13 +97,6 @@ impl, const D: usize> CircuitBuilder { let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy)); - next_copy += 1; - if next_copy == num_copies { - self.current_switch_gates[chunk_size - 1] = None; - } else { - self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy)); - } - (switch, c, d) } diff --git a/src/gadgets/random_access.rs b/src/gadgets/random_access.rs index 398c516f..58c827c1 100644 --- a/src/gadgets/random_access.rs +++ b/src/gadgets/random_access.rs @@ -6,37 +6,6 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { - /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. - /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index - /// `g` and the gate's `i`-th random access is available. - fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) { - let (gate, i) = self - .free_random_access - .get(&vec_size) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - RandomAccessGate::new_from_config(&self.config, vec_size), - vec![], - ); - (gate, 0) - }); - - // Update `free_random_access` with new values. - if i < RandomAccessGate::::max_num_copies( - self.config.num_routed_wires, - self.config.num_wires, - vec_size, - ) - 1 - { - self.free_random_access.insert(vec_size, (gate, i + 1)); - } else { - self.free_random_access.remove(&vec_size); - } - - (gate, i) - } - /// Checks that a `Target` matches a vector at a non-deterministic index. /// Note: `access_index` is not range-checked. pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec) { diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 72dcf273..2c52db23 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -4,7 +4,7 @@ use itertools::izip; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; -use crate::gates::comparison::ComparisonGate; +use crate::gates::assert_le::AssertLessThanGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; @@ -40,9 +40,9 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(a_chunks, b_chunks); } - /// Add a ComparisonGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits. + /// Add an AssertLessThanGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits. pub fn assert_le(&mut self, lhs: Target, rhs: Target, bits: usize, num_chunks: usize) { - let gate = ComparisonGate::new(bits, num_chunks); + let gate = AssertLessThanGate::new(bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index 6564a876..2bbbda6e 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -17,12 +17,18 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; pub const NUM_U32_ARITHMETIC_OPS: usize = 3; /// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand). -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct U32ArithmeticGate, const D: usize> { _phantom: PhantomData, } impl, const D: usize> U32ArithmeticGate { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } + pub fn wire_ith_multiplicand_0(i: usize) -> usize { debug_assert!(i < NUM_U32_ARITHMETIC_OPS); 5 * i @@ -309,8 +315,7 @@ impl, const D: usize> SimpleGenerator .take(num_limbs) .collect(); let output_limbs_f: Vec<_> = output_limbs_u64 - .iter() - .cloned() + .into_iter() .map(F::from_canonical_u64) .collect(); @@ -385,8 +390,7 @@ mod tests { output /= limb_base; } let mut output_limbs_f: Vec<_> = output_limbs - .iter() - .cloned() + .into_iter() .map(F::from_canonical_u64) .collect(); diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs new file mode 100644 index 00000000..4d33a867 --- /dev/null +++ b/src/gates/assert_le.rs @@ -0,0 +1,607 @@ +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::util::{bits_u64, ceil_div_usize}; + +// TODO: replace/merge this gate with `ComparisonGate`. + +/// A gate for checking that one value is less than or equal to another. +#[derive(Clone, Debug)] +pub struct AssertLessThanGate, const D: usize> { + pub(crate) num_bits: usize, + pub(crate) num_chunks: usize, + _phantom: PhantomData, +} + +impl, const D: usize> AssertLessThanGate { + pub fn new(num_bits: usize, num_chunks: usize) -> Self { + debug_assert!(num_bits < bits_u64(F::ORDER)); + Self { + num_bits, + num_chunks, + _phantom: PhantomData, + } + } + + pub fn chunk_bits(&self) -> usize { + ceil_div_usize(self.num_bits, self.num_chunks) + } + + pub fn wire_first_input(&self) -> usize { + 0 + } + + pub fn wire_second_input(&self) -> usize { + 1 + } + + pub fn wire_most_significant_diff(&self) -> usize { + 2 + } + + pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + chunk + } + + pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + self.num_chunks + chunk + } + + pub fn wire_equality_dummy(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 2 * self.num_chunks + chunk + } + + pub fn wire_chunks_equal(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 3 * self.num_chunks + chunk + } + + pub fn wire_intermediate_value(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 4 * self.num_chunks + chunk + } +} + +impl, const D: usize> Gate for AssertLessThanGate { + fn id(&self) -> String { + format!("{:?}", self, D) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::Extension::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::Extension::from_canonical_usize(1 << self.chunk_bits()), + ); + + constraints.push(first_chunks_combined - first_input); + constraints.push(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::Extension::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product = (0..chunk_size) + .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) + .product(); + let second_product = (0..chunk_size) + .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(first_product); + constraints.push(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); + constraints.push(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (F::Extension::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); + + // Range check `most_significant_diff` to be less than `chunk_size`. + let product = (0..chunk_size) + .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + + constraints.push(first_chunks_combined - first_input); + constraints.push(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product = (0..chunk_size) + .map(|x| first_chunks[i] - F::from_canonical_usize(x)) + .product(); + let second_product = (0..chunk_size) + .map(|x| second_chunks[i] - F::from_canonical_usize(x)) + .product(); + constraints.push(first_product); + constraints.push(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + constraints.push(difference * equality_dummy - (F::ONE - chunks_equal)); + constraints.push(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (F::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); + + // Range check `most_significant_diff` to be less than `chunk_size`. + let product = (0..chunk_size) + .map(|x| most_significant_diff - F::from_canonical_usize(x)) + .product(); + constraints.push(product); + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); + let first_chunks_combined = + reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base); + let second_chunks_combined = + reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base); + + constraints.push(builder.sub_extension(first_chunks_combined, first_input)); + constraints.push(builder.sub_extension(second_chunks_combined, second_input)); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = builder.zero_extension(); + + let one = builder.one_extension(); + // Find the chosen chunk. + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let mut first_product = one; + let mut second_product = one; + for x in 0..chunk_size { + let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); + let first_diff = builder.sub_extension(first_chunks[i], x_f); + let second_diff = builder.sub_extension(second_chunks[i], x_f); + first_product = builder.mul_extension(first_product, first_diff); + second_product = builder.mul_extension(second_product, second_diff); + } + constraints.push(first_product); + constraints.push(second_product); + + let difference = builder.sub_extension(second_chunks[i], first_chunks[i]); + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + let diff_times_equal = builder.mul_extension(difference, equality_dummy); + let not_equal = builder.sub_extension(one, chunks_equal); + constraints.push(builder.sub_extension(diff_times_equal, not_equal)); + constraints.push(builder.mul_extension(chunks_equal, difference)); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far); + constraints.push(builder.sub_extension(intermediate_value, old_diff)); + + let not_equal = builder.sub_extension(one, chunks_equal); + let new_diff = builder.mul_extension(not_equal, difference); + most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff); + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints + .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); + + // Range check `most_significant_diff` to be less than `chunk_size`. + let mut product = builder.one_extension(); + for x in 0..chunk_size { + let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(most_significant_diff, x_f); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = AssertLessThanGenerator:: { + gate_index, + gate: self.clone(), + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.wire_intermediate_value(self.num_chunks - 1) + 1 + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << self.chunk_bits() + } + + fn num_constraints(&self) -> usize { + 4 + 5 * self.num_chunks + } +} + +#[derive(Debug)] +struct AssertLessThanGenerator, const D: usize> { + gate_index: usize, + gate: AssertLessThanGate, +} + +impl, const D: usize> SimpleGenerator + for AssertLessThanGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + let mut deps = Vec::new(); + deps.push(local_target(self.gate.wire_first_input())); + deps.push(local_target(self.gate.wire_second_input())); + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let first_input = get_local_wire(self.gate.wire_first_input()); + let second_input = get_local_wire(self.gate.wire_second_input()); + + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); + + debug_assert!(first_input_u64 < second_input_u64); + + let chunk_size = 1 << self.gate.chunk_bits(); + let first_input_chunks: Vec = (0..self.gate.num_chunks) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + let second_input_chunks: Vec = (0..self.gate.num_chunks) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + let chunks_equal: Vec = (0..self.gate.num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); + + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..self.gate.num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); + } + } + let most_significant_diff = most_significant_diff_so_far; + + out_buffer.set_wire( + local_wire(self.gate.wire_most_significant_diff()), + most_significant_diff, + ); + for i in 0..self.gate.num_chunks { + out_buffer.set_wire( + local_wire(self.gate.wire_first_chunk_val(i)), + first_input_chunks[i], + ); + out_buffer.set_wire( + local_wire(self.gate.wire_second_chunk_val(i)), + second_input_chunks[i], + ); + out_buffer.set_wire( + local_wire(self.gate.wire_equality_dummy(i)), + equality_dummies[i], + ); + out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); + out_buffer.set_wire( + local_wire(self.gate.wire_intermediate_value(i)), + intermediate_values[i], + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::{Field, PrimeField}; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::assert_le::AssertLessThanGate; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn wire_indices() { + type AG = AssertLessThanGate; + let num_bits = 40; + let num_chunks = 5; + + let gate = AG { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + + assert_eq!(gate.wire_first_input(), 0); + assert_eq!(gate.wire_second_input(), 1); + assert_eq!(gate.wire_most_significant_diff(), 2); + assert_eq!(gate.wire_first_chunk_val(0), 3); + assert_eq!(gate.wire_first_chunk_val(4), 7); + assert_eq!(gate.wire_second_chunk_val(0), 8); + assert_eq!(gate.wire_second_chunk_val(4), 12); + assert_eq!(gate.wire_equality_dummy(0), 13); + assert_eq!(gate.wire_equality_dummy(4), 17); + assert_eq!(gate.wire_chunks_equal(0), 18); + assert_eq!(gate.wire_chunks_equal(4), 22); + assert_eq!(gate.wire_intermediate_value(0), 23); + assert_eq!(gate.wire_intermediate_value(4), 27); + } + + #[test] + fn low_degree() { + let num_bits = 40; + let num_chunks = 5; + + test_low_degree::(AssertLessThanGate::<_, 4>::new( + num_bits, num_chunks, + )) + } + + #[test] + fn eval_fns() -> Result<()> { + let num_bits = 40; + let num_chunks = 5; + + test_eval_fns::(AssertLessThanGate::<_, 4>::new( + num_bits, num_chunks, + )) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + + let num_bits = 40; + let num_chunks = 5; + let chunk_bits = num_bits / num_chunks; + + // Returns the local wires for an AssertLessThanGate given the two inputs. + let get_wires = |first_input: F, second_input: F| -> Vec { + let mut v = Vec::new(); + + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); + + let chunk_size = 1 << chunk_bits; + let mut first_input_chunks: Vec = (0..num_chunks) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + let mut second_input_chunks: Vec = (0..num_chunks) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + let mut chunks_equal: Vec = (0..num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let mut equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); + + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); + } + } + let most_significant_diff = most_significant_diff_so_far; + + v.push(first_input); + v.push(second_input); + v.push(most_significant_diff); + v.append(&mut first_input_chunks); + v.append(&mut second_input_chunks); + v.append(&mut equality_dummies); + v.append(&mut chunks_equal); + v.append(&mut intermediate_values); + + v.iter().map(|&x| x.into()).collect::>() + }; + + let mut rng = rand::thread_rng(); + let max: u64 = 1 << num_bits - 1; + let first_input_u64 = rng.gen_range(0..max); + let second_input_u64 = { + let mut val = rng.gen_range(0..max); + while val < first_input_u64 { + val = rng.gen_range(0..max); + } + val + }; + + let first_input = F::from_canonical_u64(first_input_u64); + let second_input = F::from_canonical_u64(second_input_u64); + + let less_than_gate = AssertLessThanGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let less_than_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, second_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + less_than_gate + .eval_unfiltered(less_than_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + + let equal_gate = AssertLessThanGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let equal_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, first_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + equal_gate + .eval_unfiltered(equal_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 988086d0..a610c5e2 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -43,33 +43,42 @@ impl, const D: usize> ComparisonGate { 1 } - pub fn wire_most_significant_diff(&self) -> usize { + pub fn wire_result_bool(&self) -> usize { 2 } + pub fn wire_most_significant_diff(&self) -> usize { + 3 + } + pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + chunk + 4 + chunk } pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + self.num_chunks + chunk + 4 + self.num_chunks + chunk } pub fn wire_equality_dummy(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 2 * self.num_chunks + chunk + 4 + 2 * self.num_chunks + chunk } pub fn wire_chunks_equal(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 3 * self.num_chunks + chunk + 4 + 3 * self.num_chunks + chunk } pub fn wire_intermediate_value(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + 4 * self.num_chunks + chunk + 4 + 4 * self.num_chunks + chunk + } + + /// The `bit_index`th bit of 2^n - 1 + most_significant_diff. + pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize { + 4 + 5 * self.num_chunks + bit_index } } @@ -110,10 +119,10 @@ impl, const D: usize> Gate for ComparisonGate for i in 0..self.num_chunks { // Range-check the chunks to be less than `chunk_size`. - let first_product = (0..chunk_size) + let first_product: F::Extension = (0..chunk_size) .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) .product(); - let second_product = (0..chunk_size) + let second_product: F::Extension = (0..chunk_size) .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) .product(); constraints.push(first_product); @@ -137,11 +146,24 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; constraints.push(most_significant_diff - most_significant_diff_so_far); - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(product); + let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. + for i in 0..most_significant_diff_bits.len() { + constraints.push( + most_significant_diff_bits[i] * (F::Extension::ONE - most_significant_diff_bits[i]), + ); + } + + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); + let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); + + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); constraints } @@ -178,10 +200,10 @@ impl, const D: usize> Gate for ComparisonGate for i in 0..self.num_chunks { // Range-check the chunks to be less than `chunk_size`. - let first_product = (0..chunk_size) + let first_product: F = (0..chunk_size) .map(|x| first_chunks[i] - F::from_canonical_usize(x)) .product(); - let second_product = (0..chunk_size) + let second_product: F = (0..chunk_size) .map(|x| second_chunks[i] - F::from_canonical_usize(x)) .product(); constraints.push(first_product); @@ -205,11 +227,23 @@ impl, const D: usize> Gate for ComparisonGate let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; constraints.push(most_significant_diff - most_significant_diff_so_far); - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::from_canonical_usize(x)) - .product(); - constraints.push(product); + let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. + for i in 0..most_significant_diff_bits.len() { + constraints + .push(most_significant_diff_bits[i] * (F::ONE - most_significant_diff_bits[i])); + } + + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); + let two_n = F::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); + + // Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); constraints } @@ -285,14 +319,30 @@ impl, const D: usize> Gate for ComparisonGate constraints .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); - // Range check `most_significant_diff` to be less than `chunk_size`. - let mut product = builder.one_extension(); - for x in 0..chunk_size { - let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); - let diff = builder.sub_extension(most_significant_diff, x_f); - product = builder.mul_extension(product, diff); + let most_significant_diff_bits: Vec> = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. + for i in 0..most_significant_diff_bits.len() { + let this_bit = most_significant_diff_bits[i]; + let inverse = builder.sub_extension(one, this_bit); + constraints.push(builder.mul_extension(this_bit, inverse)); } - constraints.push(product); + + let two = builder.two(); + let bits_combined = + reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two); + let two_n = + builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits())); + let sum = builder.add_extension(two_n, most_significant_diff); + constraints.push(builder.sub_extension(sum, bits_combined)); + + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push( + builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]), + ); constraints } @@ -310,7 +360,7 @@ impl, const D: usize> Gate for ComparisonGate } fn num_wires(&self) -> usize { - self.wire_intermediate_value(self.num_chunks - 1) + 1 + 4 + 5 * self.num_chunks + (self.chunk_bits() + 1) } fn num_constants(&self) -> usize { @@ -322,7 +372,7 @@ impl, const D: usize> Gate for ComparisonGate } fn num_constraints(&self) -> usize { - 4 + 5 * self.num_chunks + 6 + 5 * self.num_chunks + self.chunk_bits() } } @@ -358,7 +408,7 @@ impl, const D: usize> SimpleGenerator let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); - debug_assert!(first_input_u64 < second_input_u64); + let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize); let chunk_size = 1 << self.gate.chunk_bits(); let first_input_chunks: Vec = (0..self.gate.num_chunks) @@ -397,6 +447,22 @@ impl, const D: usize> SimpleGenerator } let most_significant_diff = most_significant_diff_so_far; + let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits()); + let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64(); + + let msd_bits_u64: Vec = (0..self.gate.chunk_bits() + 1) + .scan(two_n_plus_msd, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(tmp) + }) + .collect(); + let msd_bits: Vec = msd_bits_u64 + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect(); + + out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result); out_buffer.set_wire( local_wire(self.gate.wire_most_significant_diff()), most_significant_diff, @@ -420,6 +486,12 @@ impl, const D: usize> SimpleGenerator intermediate_values[i], ); } + for i in 0..self.gate.chunk_bits() + 1 { + out_buffer.set_wire( + local_wire(self.gate.wire_most_significant_diff_bit(i)), + msd_bits[i], + ); + } } } @@ -453,17 +525,20 @@ mod tests { assert_eq!(gate.wire_first_input(), 0); assert_eq!(gate.wire_second_input(), 1); - assert_eq!(gate.wire_most_significant_diff(), 2); - assert_eq!(gate.wire_first_chunk_val(0), 3); - assert_eq!(gate.wire_first_chunk_val(4), 7); - assert_eq!(gate.wire_second_chunk_val(0), 8); - assert_eq!(gate.wire_second_chunk_val(4), 12); - assert_eq!(gate.wire_equality_dummy(0), 13); - assert_eq!(gate.wire_equality_dummy(4), 17); - assert_eq!(gate.wire_chunks_equal(0), 18); - assert_eq!(gate.wire_chunks_equal(4), 22); - assert_eq!(gate.wire_intermediate_value(0), 23); - assert_eq!(gate.wire_intermediate_value(4), 27); + assert_eq!(gate.wire_result_bool(), 2); + assert_eq!(gate.wire_most_significant_diff(), 3); + assert_eq!(gate.wire_first_chunk_val(0), 4); + assert_eq!(gate.wire_first_chunk_val(4), 8); + assert_eq!(gate.wire_second_chunk_val(0), 9); + assert_eq!(gate.wire_second_chunk_val(4), 13); + assert_eq!(gate.wire_equality_dummy(0), 14); + assert_eq!(gate.wire_equality_dummy(4), 18); + assert_eq!(gate.wire_chunks_equal(0), 19); + assert_eq!(gate.wire_chunks_equal(4), 23); + assert_eq!(gate.wire_intermediate_value(0), 24); + assert_eq!(gate.wire_intermediate_value(4), 28); + assert_eq!(gate.wire_most_significant_diff_bit(0), 29); + assert_eq!(gate.wire_most_significant_diff_bit(8), 37); } #[test] @@ -499,6 +574,8 @@ mod tests { let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); + let result_bool = F::from_bool(first_input_u64 <= second_input_u64); + let chunk_size = 1 << chunk_bits; let mut first_input_chunks: Vec = (0..num_chunks) .scan(first_input_u64, |acc, _| { @@ -536,14 +613,26 @@ mod tests { } let most_significant_diff = most_significant_diff_so_far; + let two_n_plus_msd = + (1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64(); + let mut msd_bits: Vec = (0..chunk_bits + 1) + .scan(two_n_plus_msd, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + v.push(first_input); v.push(second_input); + v.push(result_bool); v.push(most_significant_diff); v.append(&mut first_input_chunks); v.append(&mut second_input_chunks); v.append(&mut equality_dummies); v.append(&mut chunks_equal); v.append(&mut intermediate_values); + v.append(&mut msd_bits); v.iter().map(|&x| x.into()).collect::>() }; diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 76066285..b1a6028e 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic; pub mod arithmetic_u32; +pub mod assert_le; pub mod base_sum; pub mod comparison; pub mod constant; @@ -18,6 +19,7 @@ pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; pub mod reducing; +pub mod subtraction_u32; pub mod switch; #[cfg(test)] diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs new file mode 100644 index 00000000..225c09e4 --- /dev/null +++ b/src/gates/subtraction_u32.rs @@ -0,0 +1,422 @@ +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Maximum number of subtractions operations performed by a single gate. +pub const NUM_U32_SUBTRACTION_OPS: usize = 3; + +/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns +/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked. +#[derive(Clone, Debug)] +pub struct U32SubtractionGate, const D: usize> { + _phantom: PhantomData, +} + +impl, const D: usize> U32SubtractionGate { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } + + pub fn wire_ith_input_x(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + } + pub fn wire_ith_input_y(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + 1 + } + pub fn wire_ith_input_borrow(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + 2 + } + + pub fn wire_ith_output_result(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + 3 + } + pub fn wire_ith_output_borrow(i: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + 5 * i + 4 + } + + pub fn limb_bits() -> usize { + 2 + } + // We have limbs for the 32 bits of `output_result`. + pub fn num_limbs() -> usize { + 32 / Self::limb_bits() + } + + pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize { + debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + debug_assert!(j < Self::num_limbs()); + 5 * NUM_U32_SUBTRACTION_OPS + Self::num_limbs() * i + j + } +} + +impl, const D: usize> Gate for U32SubtractionGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; + let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::Extension::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + + constraints.push(output_result - (result_initial + base * output_borrow)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = F::Extension::ZERO; + let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + combined_limbs = limb_base * combined_limbs + this_limb; + } + constraints.push(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. + constraints.push(output_borrow * (F::Extension::ONE - output_borrow)); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; + let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + + constraints.push(output_result - (result_initial + base * output_borrow)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = F::ZERO; + let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + constraints.push(product); + + combined_limbs = limb_base * combined_limbs + this_limb; + } + constraints.push(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. + constraints.push(output_borrow * (F::ONE - output_borrow)); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; + let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + + let diff = builder.sub_extension(input_x, input_y); + let result_initial = builder.sub_extension(diff, input_borrow); + let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64)); + + let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + + let computed_output = builder.mul_add_extension(base, output_borrow, result_initial); + constraints.push(builder.sub_extension(output_result, computed_output)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = builder.zero_extension(); + let limb_base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let mut product = builder.one_extension(); + for x in 0..max_limb { + let x_target = + builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(this_limb, x_target); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + combined_limbs = builder.mul_add_extension(limb_base, combined_limbs, this_limb); + } + constraints.push(builder.sub_extension(combined_limbs, output_result)); + + // Range-check output_borrow to be one bit. + let one = builder.one_extension(); + let not_borrow = builder.sub_extension(one, output_borrow); + constraints.push(builder.mul_extension(output_borrow, not_borrow)); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + (0..NUM_U32_SUBTRACTION_OPS) + .map(|i| { + let g: Box> = Box::new( + U32SubtractionGenerator { + gate_index, + i, + _phantom: PhantomData, + } + .adapter(), + ); + g + }) + .collect() + } + + fn num_wires(&self) -> usize { + NUM_U32_SUBTRACTION_OPS * (5 + Self::num_limbs()) + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << Self::limb_bits() + } + + fn num_constraints(&self) -> usize { + NUM_U32_SUBTRACTION_OPS * (3 + Self::num_limbs()) + } +} + +#[derive(Clone, Debug)] +struct U32SubtractionGenerator, const D: usize> { + gate_index: usize, + i: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for U32SubtractionGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + let mut deps = Vec::with_capacity(3); + deps.push(local_target(U32SubtractionGate::::wire_ith_input_x( + self.i, + ))); + deps.push(local_target(U32SubtractionGate::::wire_ith_input_y( + self.i, + ))); + deps.push(local_target( + U32SubtractionGate::::wire_ith_input_borrow(self.i), + )); + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let input_x = get_local_wire(U32SubtractionGate::::wire_ith_input_x(self.i)); + let input_y = get_local_wire(U32SubtractionGate::::wire_ith_input_y(self.i)); + let input_borrow = + get_local_wire(U32SubtractionGate::::wire_ith_input_borrow(self.i)); + + let result_initial = input_x - input_y - input_borrow; + let result_initial_u64 = result_initial.to_canonical_u64(); + let output_borrow = if result_initial_u64 > 1 << 32u64 { + F::ONE + } else { + F::ZERO + }; + + let base = F::from_canonical_u64(1 << 32u64); + let output_result = result_initial + base * output_borrow; + + let output_result_wire = + local_wire(U32SubtractionGate::::wire_ith_output_result(self.i)); + let output_borrow_wire = + local_wire(U32SubtractionGate::::wire_ith_output_borrow(self.i)); + + out_buffer.set_wire(output_result_wire, output_result); + out_buffer.set_wire(output_borrow_wire, output_borrow); + + let output_result_u64 = output_result.to_canonical_u64(); + + let num_limbs = U32SubtractionGate::::num_limbs(); + let limb_base = 1 << U32SubtractionGate::::limb_bits(); + let output_limbs: Vec<_> = (0..num_limbs) + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + for j in 0..num_limbs { + let wire = local_wire(U32SubtractionGate::::wire_ith_output_jth_limb( + self.i, j, + )); + out_buffer.set_wire(wire, output_limbs[j]); + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::{Field, PrimeField}; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn low_degree() { + test_low_degree::(U32SubtractionGate:: { + _phantom: PhantomData, + }) + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(U32SubtractionGate:: { + _phantom: PhantomData, + }) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + + fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let limb_bits = U32SubtractionGate::::limb_bits(); + let num_limbs = U32SubtractionGate::::num_limbs(); + let limb_base = 1 << limb_bits; + for c in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = F::from_canonical_u64(inputs_x[c]); + let input_y = F::from_canonical_u64(inputs_y[c]); + let input_borrow = F::from_canonical_u64(borrows[c]); + + let result_initial = input_x - input_y - input_borrow; + let result_initial_u64 = result_initial.to_canonical_u64(); + let output_borrow = if result_initial_u64 > 1 << 32u64 { + F::ONE + } else { + F::ZERO + }; + + let base = F::from_canonical_u64(1 << 32u64); + let output_result = result_initial + base * output_borrow; + + let output_result_u64 = output_result.to_canonical_u64(); + + let mut output_limbs: Vec<_> = (0..num_limbs) + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + v0.push(input_x); + v0.push(input_y); + v0.push(input_borrow); + v0.push(output_result); + v0.push(output_borrow); + v1.append(&mut output_limbs); + } + + v0.iter() + .chain(v1.iter()) + .map(|&x| x.into()) + .collect::>() + } + + let mut rng = rand::thread_rng(); + let inputs_x = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let inputs_y = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let borrows = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| (rng.gen::() % 2) as u64) + .collect(); + + let gate = U32SubtractionGate:: { + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(inputs_x, inputs_y, borrows), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/iop/generator.rs b/src/iop/generator.rs index eb2c95f7..8c6cb294 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,9 +1,13 @@ use std::fmt::Debug; use std::marker::PhantomData; +use num::BigUint; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -150,6 +154,20 @@ impl GeneratedValues { self.target_values.push((target, value)) } + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } + + pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.get_limb(i), limbs[i]); + } + } + pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { ht.elements .iter() diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 858bacd9..0388a6cb 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -1,9 +1,12 @@ use std::collections::HashMap; use std::convert::TryInto; +use num::{BigUint, FromPrimitive, Zero}; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; +use crate::gadgets::biguint::BigUintTarget; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; @@ -53,6 +56,19 @@ pub trait Witness { panic!("not a bool") } + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + let mut result = BigUint::zero(); + + let limb_base = BigUint::from_u64(1 << 32u64).unwrap(); + for i in (0..target.num_limbs()).rev() { + let limb = target.get_limb(i); + result *= &limb_base; + result += self.get_target(limb.0).to_biguint(); + } + + result + } + fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 5dcde1e0..c6c49826 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -13,13 +13,16 @@ use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::{FriConfig, FriParams}; use crate::gadgets::arithmetic_extension::ArithmeticOperation; +use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; +use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::hash_n_to_hash; @@ -74,21 +77,7 @@ pub struct CircuitBuilder, const D: usize> { /// Memoized results of `arithmetic_extension` calls. pub(crate) arithmetic_results: HashMap, ExtensionTarget>, - /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using - /// these constants with gate index `g` and already using `i` arithmetic operations. - pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, - - /// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using - /// these constants with gate index `g` and already using `i` random accesses. - pub(crate) free_random_access: HashMap, - - // `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value - // chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies - // of switches - pub(crate) current_switch_gates: Vec, usize, usize)>>, - - /// An available `ConstantGate` instance, if any. - free_constant: Option<(usize, usize)>, + batched_gates: BatchedGates, } impl, const D: usize> CircuitBuilder { @@ -106,10 +95,7 @@ impl, const D: usize> CircuitBuilder { constants_to_targets: HashMap::new(), arithmetic_results: HashMap::new(), targets_to_constants: HashMap::new(), - free_arithmetic: HashMap::new(), - free_random_access: HashMap::new(), - current_switch_gates: Vec::new(), - free_constant: None, + batched_gates: BatchedGates::new(), }; builder.check_config(); builder @@ -216,6 +202,7 @@ impl, const D: usize> CircuitBuilder { gate_ref, constants, }); + index } @@ -260,6 +247,11 @@ impl, const D: usize> CircuitBuilder { self.connect(x, zero); } + pub fn assert_one(&mut self, x: Target) { + let one = self.one(); + self.connect(x, one); + } + pub fn add_generators(&mut self, generators: Vec>>) { self.generators.extend(generators); } @@ -313,26 +305,6 @@ impl, const D: usize> CircuitBuilder { target } - /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a - /// new `ConstantGate` if needed. - fn constant_gate_instance(&mut self) -> (usize, usize) { - if self.free_constant.is_none() { - let num_consts = self.config.constant_gate_size; - // We will fill this `ConstantGate` with zero constants initially. - // These will be overwritten by `constant` as the gate instances are filled. - let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); - self.free_constant = Some((gate, 0)); - } - - let (gate, instance) = self.free_constant.unwrap(); - if instance + 1 < self.config.constant_gate_size { - self.free_constant = Some((gate, instance + 1)); - } else { - self.free_constant = None; - } - (gate, instance) - } - pub fn constants(&mut self, constants: &[F]) -> Vec { constants.iter().map(|&c| self.constant(c)).collect() } @@ -345,6 +317,11 @@ impl, const D: usize> CircuitBuilder { } } + /// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits. + pub fn constant_u32(&mut self, c: u32) -> U32Target { + U32Target(self.constant(F::from_canonical_u32(c))) + } + /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns /// its constant value. Otherwise, returns `None`. pub fn target_as_constant(&self, target: Target) -> Option { @@ -566,76 +543,6 @@ impl, const D: usize> CircuitBuilder { ) } - /// Fill the remaining unused arithmetic operations with zeros, so that all - /// `ArithmeticExtensionGenerator` are run. - fn fill_arithmetic_gates(&mut self) { - let zero = self.zero_extension(); - let remaining_arithmetic_gates = self.free_arithmetic.values().copied().collect::>(); - for (gate, i) in remaining_arithmetic_gates { - for j in i..ArithmeticExtensionGate::::num_ops(&self.config) { - let wires_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_0(j), - ); - let wires_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_multiplicand_1(j), - ); - let wires_addend = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_ith_addend(j), - ); - - self.connect_extension(zero, wires_multiplicand_0); - self.connect_extension(zero, wires_multiplicand_1); - self.connect_extension(zero, wires_addend); - } - } - } - - /// Fill the remaining unused random access operations with zeros, so that all - /// `RandomAccessGenerator`s are run. - fn fill_random_access_gates(&mut self) { - let zero = self.zero(); - for (vec_size, (_, i)) in self.free_random_access.clone() { - let max_copies = RandomAccessGate::::max_num_copies( - self.config.num_routed_wires, - self.config.num_wires, - vec_size, - ); - for _ in i..max_copies { - self.random_access(zero, zero, vec![zero; vec_size]); - } - } - } - - /// Fill the remaining unused switch gates with dummy values, so that all - /// `SwitchGenerator` are run. - fn fill_switch_gates(&mut self) { - let zero = self.zero(); - - for chunk_size in 1..=self.current_switch_gates.len() { - if let Some((gate, gate_index, mut copy)) = - self.current_switch_gates[chunk_size - 1].clone() - { - while copy < gate.num_copies { - for element in 0..chunk_size { - let wire_first_input = - Target::wire(gate_index, gate.wire_first_input(copy, element)); - let wire_second_input = - Target::wire(gate_index, gate.wire_second_input(copy, element)); - let wire_switch_bool = - Target::wire(gate_index, gate.wire_switch_bool(copy)); - self.connect(zero, wire_first_input); - self.connect(zero, wire_second_input); - self.connect(zero, wire_switch_bool); - } - copy += 1; - } - } - } - } - pub fn print_gate_counts(&self, min_delta: usize) { // Print gate counts for each context. self.context_log @@ -659,9 +566,7 @@ impl, const D: usize> CircuitBuilder { let mut timing = TimingTree::new("preprocess", Level::Trace); let start = Instant::now(); - self.fill_arithmetic_gates(); - self.fill_random_access_gates(); - self.fill_switch_gates(); + self.fill_batched_gates(); // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that // those hash wires match the claimed public inputs. @@ -836,3 +741,332 @@ impl, const D: usize> CircuitBuilder { } } } + +/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a CircuitBuilder track such gates that are currently being "filled up." +pub struct BatchedGates, const D: usize> { + /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using + /// these constants with gate index `g` and already using `i` arithmetic operations. + pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, + + /// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using + /// these constants with gate index `g` and already using `i` random accesses. + pub(crate) free_random_access: HashMap, + + /// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value + /// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies + /// of switches + pub(crate) current_switch_gates: Vec, usize, usize)>>, + + /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) + pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, + + /// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one) + pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>, + + /// An available `ConstantGate` instance, if any. + pub(crate) free_constant: Option<(usize, usize)>, +} + +impl, const D: usize> BatchedGates { + pub fn new() -> Self { + Self { + free_arithmetic: HashMap::new(), + free_random_access: HashMap::new(), + current_switch_gates: Vec::new(), + current_u32_arithmetic_gate: None, + current_u32_subtraction_gate: None, + free_constant: None, + } + } +} + +impl, const D: usize> CircuitBuilder { + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + ArithmeticExtensionGate::new_from_config(&self.config), + vec![const_0, const_1], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates + .free_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.batched_gates + .free_arithmetic + .remove(&(const_0, const_1)); + } + + (gate, i) + } + + /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. + /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index + /// `g` and the gate's `i`-th random access is available. + pub(crate) fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_random_access + .get(&vec_size) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + RandomAccessGate::new_from_config(&self.config, vec_size), + vec![], + ); + (gate, 0) + }); + + // Update `free_random_access` with new values. + if i < RandomAccessGate::::max_num_copies( + self.config.num_routed_wires, + self.config.num_wires, + vec_size, + ) - 1 + { + self.batched_gates + .free_random_access + .insert(vec_size, (gate, i + 1)); + } else { + self.batched_gates.free_random_access.remove(&vec_size); + } + + (gate, i) + } + + pub(crate) fn find_switch_gate( + &mut self, + chunk_size: usize, + ) -> (SwitchGate, usize, usize) { + if self.batched_gates.current_switch_gates.len() < chunk_size { + self.batched_gates.current_switch_gates.extend(vec![ + None; + chunk_size + - self + .batched_gates + .current_switch_gates + .len() + ]); + } + + let (gate, gate_index, next_copy) = + match self.batched_gates.current_switch_gates[chunk_size - 1].clone() { + None => { + let gate = SwitchGate::::new_from_config(&self.config, chunk_size); + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate, gate_index, 0) + } + Some((gate, idx, next_copy)) => (gate, idx, next_copy), + }; + + let num_copies = gate.num_copies; + + if next_copy == num_copies - 1 { + self.batched_gates.current_switch_gates[chunk_size - 1] = None; + } else { + self.batched_gates.current_switch_gates[chunk_size - 1] = + Some((gate.clone(), gate_index, next_copy + 1)); + } + + (gate, gate_index, next_copy) + } + + pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { + let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { + None => { + let gate = U32ArithmeticGate::new(); + let gate_index = self.add_gate(gate, vec![]); + (gate_index, 0) + } + Some((gate_index, copy)) => (gate_index, copy), + }; + + if copy == NUM_U32_ARITHMETIC_OPS - 1 { + self.batched_gates.current_u32_arithmetic_gate = None; + } else { + self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); + } + + (gate_index, copy) + } + + pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { + let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { + None => { + let gate = U32SubtractionGate::new(); + let gate_index = self.add_gate(gate, vec![]); + (gate_index, 0) + } + Some((gate_index, copy)) => (gate_index, copy), + }; + + if copy == NUM_U32_SUBTRACTION_OPS - 1 { + self.batched_gates.current_u32_subtraction_gate = None; + } else { + self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1)); + } + + (gate_index, copy) + } + + /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a + /// new `ConstantGate` if needed. + fn constant_gate_instance(&mut self) -> (usize, usize) { + if self.batched_gates.free_constant.is_none() { + let num_consts = self.config.constant_gate_size; + // We will fill this `ConstantGate` with zero constants initially. + // These will be overwritten by `constant` as the gate instances are filled. + let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); + self.batched_gates.free_constant = Some((gate, 0)); + } + + let (gate, instance) = self.batched_gates.free_constant.unwrap(); + if instance + 1 < self.config.constant_gate_size { + self.batched_gates.free_constant = Some((gate, instance + 1)); + } else { + self.batched_gates.free_constant = None; + } + (gate, instance) + } + + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. + fn fill_arithmetic_gates(&mut self) { + let zero = self.zero_extension(); + let remaining_arithmetic_gates = self + .batched_gates + .free_arithmetic + .values() + .copied() + .collect::>(); + for (gate, i) in remaining_arithmetic_gates { + for j in i..ArithmeticExtensionGate::::num_ops(&self.config) { + let wires_multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_ith_multiplicand_0(j), + ); + let wires_multiplicand_1 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_ith_multiplicand_1(j), + ); + let wires_addend = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_ith_addend(j), + ); + + self.connect_extension(zero, wires_multiplicand_0); + self.connect_extension(zero, wires_multiplicand_1); + self.connect_extension(zero, wires_addend); + } + } + } + + /// Fill the remaining unused random access operations with zeros, so that all + /// `RandomAccessGenerator`s are run. + fn fill_random_access_gates(&mut self) { + let zero = self.zero(); + for (vec_size, (_, i)) in self.batched_gates.free_random_access.clone() { + let max_copies = RandomAccessGate::::max_num_copies( + self.config.num_routed_wires, + self.config.num_wires, + vec_size, + ); + for _ in i..max_copies { + self.random_access(zero, zero, vec![zero; vec_size]); + } + } + } + + /// Fill the remaining unused switch gates with dummy values, so that all + /// `SwitchGenerator`s are run. + fn fill_switch_gates(&mut self) { + let zero = self.zero(); + + for chunk_size in 1..=self.batched_gates.current_switch_gates.len() { + if let Some((gate, gate_index, mut copy)) = + self.batched_gates.current_switch_gates[chunk_size - 1].clone() + { + while copy < gate.num_copies { + for element in 0..chunk_size { + let wire_first_input = + Target::wire(gate_index, gate.wire_first_input(copy, element)); + let wire_second_input = + Target::wire(gate_index, gate.wire_second_input(copy, element)); + let wire_switch_bool = + Target::wire(gate_index, gate.wire_switch_bool(copy)); + self.connect(zero, wire_first_input); + self.connect(zero, wire_second_input); + self.connect(zero, wire_switch_bool); + } + copy += 1; + } + } + } + } + + /// Fill the remaining unused U32 arithmetic operations with zeros, so that all + /// `U32ArithmeticGenerator`s are run. + fn fill_u32_arithmetic_gates(&mut self) { + let zero = self.zero(); + if let Some((gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { + for i in copy..NUM_U32_ARITHMETIC_OPS { + let wire_multiplicand_0 = Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_0(i), + ); + let wire_multiplicand_1 = Target::wire( + gate_index, + U32ArithmeticGate::::wire_ith_multiplicand_1(i), + ); + let wire_addend = + Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(i)); + + self.connect(zero, wire_multiplicand_0); + self.connect(zero, wire_multiplicand_1); + self.connect(zero, wire_addend); + } + } + } + + /// Fill the remaining unused U32 subtraction operations with zeros, so that all + /// `U32SubtractionGenerator`s are run. + fn fill_u32_subtraction_gates(&mut self) { + let zero = self.zero(); + if let Some((gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { + for i in copy..NUM_U32_ARITHMETIC_OPS { + let wire_input_x = + Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_x(i)); + let wire_input_y = + Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_y(i)); + let wire_input_borrow = Target::wire( + gate_index, + U32SubtractionGate::::wire_ith_input_borrow(i), + ); + + self.connect(zero, wire_input_x); + self.connect(zero, wire_input_y); + self.connect(zero, wire_input_borrow); + } + } + } + + fn fill_batched_gates(&mut self) { + self.fill_arithmetic_gates(); + self.fill_random_access_gates(); + self.fill_switch_gates(); + self.fill_u32_arithmetic_gates(); + self.fill_u32_subtraction_gates(); + } +}