plonky2/ecdsa/src/gadgets/biguint.rs
2022-08-26 16:10:44 -04:00

507 lines
16 KiB
Rust

use std::marker::PhantomData;
use num::{BigUint, Integer, Zero};
use plonky2::hash::hash_types::RichField;
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator};
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartitionWitness, Witness};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2_field::extension::Extendable;
use plonky2_field::types::{PrimeField, PrimeField64};
use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target};
use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit;
use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32};
#[derive(Clone, Debug)]
pub struct BigUintTarget {
pub limbs: Vec<U32Target>,
}
impl BigUintTarget {
pub fn num_limbs(&self) -> usize {
self.limbs.len()
}
pub fn get_limb(&self, i: usize) -> U32Target {
self.limbs[i]
}
}
pub trait CircuitBuilderBiguint<F: RichField + Extendable<D>, const D: usize> {
fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget;
fn zero_biguint(&mut self) -> BigUintTarget;
fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget);
fn pad_biguints(
&mut self,
a: &BigUintTarget,
b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget);
fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget;
fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget;
/// Add two `BigUintTarget`s.
fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget;
/// Subtract two `BigUintTarget`s. We assume that the first is larger than the second.
fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget;
fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget;
fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget;
/// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function).
fn mul_add_biguint(
&mut self,
x: &BigUintTarget,
y: &BigUintTarget,
z: &BigUintTarget,
) -> BigUintTarget;
fn div_rem_biguint(
&mut self,
a: &BigUintTarget,
b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget);
fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget;
fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget;
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderBiguint<F, D>
for CircuitBuilder<F, D>
{
fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget {
let limb_values = value.to_u32_digits();
let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect();
BigUintTarget { limbs }
}
fn zero_biguint(&mut self) -> BigUintTarget {
self.constant_biguint(&BigUint::zero())
}
fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) {
let min_limbs = lhs.num_limbs().min(rhs.num_limbs());
for i in 0..min_limbs {
self.connect_u32(lhs.get_limb(i), rhs.get_limb(i));
}
for i in min_limbs..lhs.num_limbs() {
self.assert_zero_u32(lhs.get_limb(i));
}
for i in min_limbs..rhs.num_limbs() {
self.assert_zero_u32(rhs.get_limb(i));
}
}
fn pad_biguints(
&mut self,
a: &BigUintTarget,
b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget) {
if a.num_limbs() > b.num_limbs() {
let mut padded_b = b.clone();
for _ in b.num_limbs()..a.num_limbs() {
padded_b.limbs.push(self.zero_u32());
}
(a.clone(), padded_b)
} else {
let mut padded_a = a.clone();
for _ in a.num_limbs()..b.num_limbs() {
padded_a.limbs.push(self.zero_u32());
}
(padded_a, b.clone())
}
}
fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget {
let (a, b) = self.pad_biguints(a, b);
list_le_u32_circuit(self, a.limbs, b.limbs)
}
fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget {
let limbs = self.add_virtual_u32_targets(num_limbs);
BigUintTarget { limbs }
}
fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let num_limbs = a.num_limbs().max(b.num_limbs());
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for i in 0..num_limbs {
let a_limb = (i < a.num_limbs())
.then(|| a.limbs[i])
.unwrap_or_else(|| self.zero_u32());
let b_limb = (i < b.num_limbs())
.then(|| b.limbs[i])
.unwrap_or_else(|| self.zero_u32());
let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]);
carry = new_carry;
combined_limbs.push(new_limb);
}
combined_limbs.push(carry);
BigUintTarget {
limbs: combined_limbs,
}
}
fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let (a, b) = self.pad_biguints(a, b);
let num_limbs = a.limbs.len();
let mut result_limbs = vec![];
let mut borrow = self.zero_u32();
for i in 0..num_limbs {
let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow);
result_limbs.push(result);
borrow = new_borrow;
}
// Borrow should be zero here.
BigUintTarget {
limbs: result_limbs,
}
}
fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let total_limbs = a.limbs.len() + b.limbs.len();
let mut to_add = vec![vec![]; total_limbs];
for i in 0..a.limbs.len() {
for j in 0..b.limbs.len() {
let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]);
to_add[i + j].push(product);
to_add[i + j + 1].push(carry);
}
}
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for summands in &mut to_add {
let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry);
combined_limbs.push(new_result);
carry = new_carry;
}
combined_limbs.push(carry);
BigUintTarget {
limbs: combined_limbs,
}
}
fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget {
let t = b.target;
BigUintTarget {
limbs: a
.limbs
.iter()
.map(|&l| U32Target(self.mul(l.0, t)))
.collect(),
}
}
fn mul_add_biguint(
&mut self,
x: &BigUintTarget,
y: &BigUintTarget,
z: &BigUintTarget,
) -> BigUintTarget {
let prod = self.mul_biguint(x, y);
self.add_biguint(&prod, z)
}
fn div_rem_biguint(
&mut self,
a: &BigUintTarget,
b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget) {
let a_len = a.limbs.len();
let b_len = b.limbs.len();
let div_num_limbs = if b_len > a_len + 1 {
0
} else {
a_len - b_len + 1
};
let div = self.add_virtual_biguint_target(div_num_limbs);
let rem = self.add_virtual_biguint_target(b_len);
self.add_simple_generator(BigUintDivRemGenerator::<F, D> {
a: a.clone(),
b: b.clone(),
div: div.clone(),
rem: rem.clone(),
_phantom: PhantomData,
});
let div_b = self.mul_biguint(&div, b);
let div_b_plus_rem = self.add_biguint(&div_b, &rem);
self.connect_biguint(a, &div_b_plus_rem);
let cmp_rem_b = self.cmp_biguint(&rem, b);
self.assert_one(cmp_rem_b.target);
(div, rem)
}
fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let (div, _rem) = self.div_rem_biguint(a, b);
div
}
fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let (_div, rem) = self.div_rem_biguint(a, b);
rem
}
}
pub trait WitnessBigUint<F: PrimeField64>: Witness<F> {
fn get_biguint_target(&self, target: BigUintTarget) -> BigUint;
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint);
}
impl<T: Witness<F>, F: PrimeField64> WitnessBigUint<F> for T {
fn get_biguint_target(&self, target: BigUintTarget) -> BigUint {
target
.limbs
.into_iter()
.rev()
.fold(BigUint::zero(), |acc, limb| {
(acc << 32) + self.get_target(limb.0).to_canonical_biguint()
})
}
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
let mut limbs = value.to_u32_digits();
assert!(target.num_limbs() >= limbs.len());
limbs.resize(target.num_limbs(), 0);
for i in 0..target.num_limbs() {
self.set_u32_target(target.limbs[i], limbs[i]);
}
}
}
pub trait GeneratedValuesBigUint<F: PrimeField> {
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint);
}
impl<F: PrimeField> GeneratedValuesBigUint<F> for GeneratedValues<F> {
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
let mut limbs = value.to_u32_digits();
assert!(target.num_limbs() >= limbs.len());
limbs.resize(target.num_limbs(), 0);
for i in 0..target.num_limbs() {
self.set_u32_target(target.get_limb(i), limbs[i]);
}
}
}
#[derive(Debug)]
struct BigUintDivRemGenerator<F: RichField + Extendable<D>, const D: usize> {
a: BigUintTarget,
b: BigUintTarget,
div: BigUintTarget,
rem: BigUintTarget,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for BigUintDivRemGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
self.a
.limbs
.iter()
.chain(&self.b.limbs)
.map(|&l| l.0)
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness.get_biguint_target(self.a.clone());
let b = witness.get_biguint_target(self.b.clone());
let (div, rem) = a.div_rem(&b);
out_buffer.set_biguint_target(&self.div, &div);
out_buffer.set_biguint_target(&self.rem, &rem);
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use num::{BigUint, FromPrimitive, Integer};
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::{
iop::witness::PartialWitness,
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
};
use rand::Rng;
use crate::gadgets::biguint::{CircuitBuilderBiguint, WitnessBigUint};
#[test]
fn test_biguint_add() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
let expected_z_value = &x_value + &y_value;
let config = CircuitConfig::standard_recursion_config();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len());
let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len());
let z = builder.add_biguint(&x, &y);
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z);
pw.set_biguint_target(&x, &x_value);
pw.set_biguint_target(&y, &y_value);
pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof)
}
#[test]
fn test_biguint_sub() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let mut rng = rand::thread_rng();
let mut x_value = BigUint::from_u128(rng.gen()).unwrap();
let mut y_value = BigUint::from_u128(rng.gen()).unwrap();
if y_value > x_value {
(x_value, y_value) = (y_value, x_value);
}
let expected_z_value = &x_value - &y_value;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let z = builder.sub_biguint(&x, &y);
let expected_z = builder.constant_biguint(&expected_z_value);
builder.connect_biguint(&z, &expected_z);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof)
}
#[test]
fn test_biguint_mul() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
let expected_z_value = &x_value * &y_value;
let config = CircuitConfig::standard_recursion_config();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len());
let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len());
let z = builder.mul_biguint(&x, &y);
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z);
pw.set_biguint_target(&x, &x_value);
pw.set_biguint_target(&y, &y_value);
pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof)
}
#[test]
fn test_biguint_cmp() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let cmp = builder.cmp_biguint(&x, &y);
let expected_cmp = builder.constant_bool(x_value <= y_value);
builder.connect(cmp.target, expected_cmp.target);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof)
}
#[test]
fn test_biguint_div_rem() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let mut rng = rand::thread_rng();
let mut x_value = BigUint::from_u128(rng.gen()).unwrap();
let mut y_value = BigUint::from_u128(rng.gen()).unwrap();
if y_value > x_value {
(x_value, y_value) = (y_value, x_value);
}
let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value);
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let (div, rem) = builder.div_rem_biguint(&x, &y);
let expected_div = builder.constant_biguint(&expected_div_value);
let expected_rem = builder.constant_biguint(&expected_rem_value);
builder.connect_biguint(&div, &expected_div);
builder.connect_biguint(&rem, &expected_rem);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof)
}
}