use std::marker::PhantomData; use num::{BigUint, Integer, Zero}; use plonky2::hash::hash_types::RichField; use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartitionWitness, Witness}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_field::extension::Extendable; use plonky2_field::types::{PrimeField, PrimeField64}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit; use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32}; #[derive(Clone, Debug)] pub struct BigUintTarget { pub limbs: Vec, } impl BigUintTarget { pub fn num_limbs(&self) -> usize { self.limbs.len() } pub fn get_limb(&self, i: usize) -> U32Target { self.limbs[i] } } pub trait CircuitBuilderBiguint, const D: usize> { fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget; fn zero_biguint(&mut self) -> BigUintTarget; fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget); fn pad_biguints( &mut self, a: &BigUintTarget, b: &BigUintTarget, ) -> (BigUintTarget, BigUintTarget); fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget; fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget; /// Add two `BigUintTarget`s. fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; /// Subtract two `BigUintTarget`s. We assume that the first is larger than the second. fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget; /// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). fn mul_add_biguint( &mut self, x: &BigUintTarget, y: &BigUintTarget, z: &BigUintTarget, ) -> BigUintTarget; fn div_rem_biguint( &mut self, a: &BigUintTarget, b: &BigUintTarget, ) -> (BigUintTarget, BigUintTarget); fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; } impl, const D: usize> CircuitBuilderBiguint for CircuitBuilder { fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { let limb_values = value.to_u32_digits(); let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); BigUintTarget { limbs } } fn zero_biguint(&mut self) -> BigUintTarget { self.constant_biguint(&BigUint::zero()) } fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); } for i in min_limbs..lhs.num_limbs() { self.assert_zero_u32(lhs.get_limb(i)); } for i in min_limbs..rhs.num_limbs() { self.assert_zero_u32(rhs.get_limb(i)); } } fn pad_biguints( &mut self, a: &BigUintTarget, b: &BigUintTarget, ) -> (BigUintTarget, BigUintTarget) { if a.num_limbs() > b.num_limbs() { let mut padded_b = b.clone(); for _ in b.num_limbs()..a.num_limbs() { padded_b.limbs.push(self.zero_u32()); } (a.clone(), padded_b) } else { let mut padded_a = a.clone(); for _ in a.num_limbs()..b.num_limbs() { padded_a.limbs.push(self.zero_u32()); } (padded_a, b.clone()) } } fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { let (a, b) = self.pad_biguints(a, b); list_le_u32_circuit(self, a.limbs, b.limbs) } fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { let limbs = self.add_virtual_u32_targets(num_limbs); BigUintTarget { limbs } } fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.num_limbs().max(b.num_limbs()); let mut combined_limbs = vec![]; let mut carry = self.zero_u32(); for i in 0..num_limbs { let a_limb = (i < a.num_limbs()) .then(|| a.limbs[i]) .unwrap_or_else(|| self.zero_u32()); let b_limb = (i < b.num_limbs()) .then(|| b.limbs[i]) .unwrap_or_else(|| self.zero_u32()); let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]); carry = new_carry; combined_limbs.push(new_limb); } combined_limbs.push(carry); BigUintTarget { limbs: combined_limbs, } } fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (a, b) = self.pad_biguints(a, b); let num_limbs = a.limbs.len(); let mut result_limbs = vec![]; let mut borrow = self.zero_u32(); for i in 0..num_limbs { let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); result_limbs.push(result); borrow = new_borrow; } // Borrow should be zero here. BigUintTarget { limbs: result_limbs, } } fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let total_limbs = a.limbs.len() + b.limbs.len(); let mut to_add = vec![vec![]; total_limbs]; for i in 0..a.limbs.len() { for j in 0..b.limbs.len() { let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); to_add[i + j].push(product); to_add[i + j + 1].push(carry); } } let mut combined_limbs = vec![]; let mut carry = self.zero_u32(); for summands in &mut to_add { let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry); combined_limbs.push(new_result); carry = new_carry; } combined_limbs.push(carry); BigUintTarget { limbs: combined_limbs, } } fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { let t = b.target; BigUintTarget { limbs: a .limbs .iter() .map(|&l| U32Target(self.mul(l.0, t))) .collect(), } } fn mul_add_biguint( &mut self, x: &BigUintTarget, y: &BigUintTarget, z: &BigUintTarget, ) -> BigUintTarget { let prod = self.mul_biguint(x, y); self.add_biguint(&prod, z) } fn div_rem_biguint( &mut self, a: &BigUintTarget, b: &BigUintTarget, ) -> (BigUintTarget, BigUintTarget) { let a_len = a.limbs.len(); let b_len = b.limbs.len(); let div_num_limbs = if b_len > a_len + 1 { 0 } else { a_len - b_len + 1 }; let div = self.add_virtual_biguint_target(div_num_limbs); let rem = self.add_virtual_biguint_target(b_len); self.add_simple_generator(BigUintDivRemGenerator:: { a: a.clone(), b: b.clone(), div: div.clone(), rem: rem.clone(), _phantom: PhantomData, }); let div_b = self.mul_biguint(&div, b); let div_b_plus_rem = self.add_biguint(&div_b, &rem); self.connect_biguint(a, &div_b_plus_rem); let cmp_rem_b = self.cmp_biguint(&rem, b); self.assert_one(cmp_rem_b.target); (div, rem) } fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (div, _rem) = self.div_rem_biguint(a, b); div } fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (_div, rem) = self.div_rem_biguint(a, b); rem } } pub trait WitnessBigUint: Witness { fn get_biguint_target(&self, target: BigUintTarget) -> BigUint; fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); } impl, F: PrimeField64> WitnessBigUint for T { fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { target .limbs .into_iter() .rev() .fold(BigUint::zero(), |acc, limb| { (acc << 32) + self.get_target(limb.0).to_canonical_biguint() }) } fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { let mut limbs = value.to_u32_digits(); assert!(target.num_limbs() >= limbs.len()); limbs.resize(target.num_limbs(), 0); for i in 0..target.num_limbs() { self.set_u32_target(target.limbs[i], limbs[i]); } } } pub trait GeneratedValuesBigUint { fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); } impl GeneratedValuesBigUint for GeneratedValues { fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { let mut limbs = value.to_u32_digits(); assert!(target.num_limbs() >= limbs.len()); limbs.resize(target.num_limbs(), 0); for i in 0..target.num_limbs() { self.set_u32_target(target.get_limb(i), limbs[i]); } } } #[derive(Debug)] struct BigUintDivRemGenerator, const D: usize> { a: BigUintTarget, b: BigUintTarget, div: BigUintTarget, rem: BigUintTarget, _phantom: PhantomData, } impl, const D: usize> SimpleGenerator for BigUintDivRemGenerator { fn dependencies(&self) -> Vec { self.a .limbs .iter() .chain(&self.b.limbs) .map(|&l| l.0) .collect() } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let a = witness.get_biguint_target(self.a.clone()); let b = witness.get_biguint_target(self.b.clone()); let (div, rem) = a.div_rem(&b); out_buffer.set_biguint_target(&self.div, &div); out_buffer.set_biguint_target(&self.rem, &rem); } } #[cfg(test)] mod tests { use anyhow::Result; use num::{BigUint, FromPrimitive, Integer}; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::{ iop::witness::PartialWitness, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, }; use rand::Rng; use crate::gadgets::biguint::{CircuitBuilderBiguint, WitnessBigUint}; #[test] fn test_biguint_add() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; let mut rng = rand::thread_rng(); let x_value = BigUint::from_u128(rng.gen()).unwrap(); let y_value = BigUint::from_u128(rng.gen()).unwrap(); let expected_z_value = &x_value + &y_value; let config = CircuitConfig::standard_recursion_config(); let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); let z = builder.add_biguint(&x, &y); let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); pw.set_biguint_target(&x, &x_value); pw.set_biguint_target(&y, &y_value); pw.set_biguint_target(&expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); data.verify(proof) } #[test] fn test_biguint_sub() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; let mut rng = rand::thread_rng(); let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); if y_value > x_value { (x_value, y_value) = (y_value, x_value); } let expected_z_value = &x_value - &y_value; let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let x = builder.constant_biguint(&x_value); let y = builder.constant_biguint(&y_value); let z = builder.sub_biguint(&x, &y); let expected_z = builder.constant_biguint(&expected_z_value); builder.connect_biguint(&z, &expected_z); let data = builder.build::(); let proof = data.prove(pw).unwrap(); data.verify(proof) } #[test] fn test_biguint_mul() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; let mut rng = rand::thread_rng(); let x_value = BigUint::from_u128(rng.gen()).unwrap(); let y_value = BigUint::from_u128(rng.gen()).unwrap(); let expected_z_value = &x_value * &y_value; let config = CircuitConfig::standard_recursion_config(); let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); let z = builder.mul_biguint(&x, &y); let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); pw.set_biguint_target(&x, &x_value); pw.set_biguint_target(&y, &y_value); pw.set_biguint_target(&expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); data.verify(proof) } #[test] fn test_biguint_cmp() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; let mut rng = rand::thread_rng(); let x_value = BigUint::from_u128(rng.gen()).unwrap(); let y_value = BigUint::from_u128(rng.gen()).unwrap(); let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let x = builder.constant_biguint(&x_value); let y = builder.constant_biguint(&y_value); let cmp = builder.cmp_biguint(&x, &y); let expected_cmp = builder.constant_bool(x_value <= y_value); builder.connect(cmp.target, expected_cmp.target); let data = builder.build::(); let proof = data.prove(pw).unwrap(); data.verify(proof) } #[test] fn test_biguint_div_rem() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; let mut rng = rand::thread_rng(); let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); if y_value > x_value { (x_value, y_value) = (y_value, x_value); } let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let x = builder.constant_biguint(&x_value); let y = builder.constant_biguint(&y_value); let (div, rem) = builder.div_rem_biguint(&x, &y); let expected_div = builder.constant_biguint(&expected_div_value); let expected_rem = builder.constant_biguint(&expected_rem_value); builder.connect_biguint(&div, &expected_div); builder.connect_biguint(&rem, &expected_rem); let data = builder.build::(); let proof = data.prove(pw).unwrap(); data.verify(proof) } }