From 8aa3ed0997b36df4f1f80cd6c348ba98d99da153 Mon Sep 17 00:00:00 2001 From: Sladuca Date: Fri, 26 Aug 2022 16:10:34 -0400 Subject: [PATCH 1/2] cleaner witness extension --- ecdsa/src/gadgets/biguint.rs | 82 ++++++++++++++------------- ecdsa/src/gadgets/curve_fixed_base.rs | 4 +- ecdsa/src/gadgets/glv.rs | 9 ++- ecdsa/src/gadgets/nonnative.rs | 44 ++++++-------- u32/src/gadgets/arithmetic_u32.rs | 6 +- u32/src/witness.rs | 37 +++++++----- 6 files changed, 93 insertions(+), 89 deletions(-) diff --git a/ecdsa/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs index 1dbe4657..188b04ba 100644 --- a/ecdsa/src/gadgets/biguint.rs +++ b/ecdsa/src/gadgets/biguint.rs @@ -7,10 +7,10 @@ use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartitionWitness, Witness}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_field::extension::Extendable; -use plonky2_field::types::PrimeField; +use plonky2_field::types::{PrimeField, PrimeField64}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit; -use plonky2_u32::witness::{generated_values_set_u32_target, witness_set_u32_target}; +use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32}; #[derive(Clone, Debug)] pub struct BigUintTarget { @@ -270,41 +270,43 @@ impl, const D: usize> CircuitBuilderBiguint } } -pub fn witness_get_biguint_target, F: PrimeField>( - witness: &W, - bt: BigUintTarget, -) -> BigUint { - bt.limbs +pub trait WitnessBigUint: Witness { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint; + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); +} + +impl, F: PrimeField64> WitnessBigUint for T { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + target.limbs .into_iter() .rev() .fold(BigUint::zero(), |acc, limb| { - (acc << 32) + witness.get_target(limb.0).to_canonical_biguint() + (acc << 32) + self.get_target(limb.0).to_canonical_biguint() }) -} + } -pub fn witness_set_biguint_target, F: PrimeField>( - witness: &mut W, - target: &BigUintTarget, - value: &BigUint, -) { - let mut limbs = value.to_u32_digits(); - assert!(target.num_limbs() >= limbs.len()); - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - witness_set_u32_target(witness, target.limbs[i], limbs[i]); + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.limbs[i], limbs[i]); + } } } -pub fn buffer_set_biguint_target( - buffer: &mut GeneratedValues, - target: &BigUintTarget, - value: &BigUint, -) { - let mut limbs = value.to_u32_digits(); - assert!(target.num_limbs() >= limbs.len()); - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - generated_values_set_u32_target(buffer, target.get_limb(i), limbs[i]); +pub trait GeneratedValuesBigUint { + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); +} + +impl GeneratedValuesBigUint for GeneratedValues { + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.get_limb(i), limbs[i]); + } } } @@ -330,12 +332,12 @@ impl, const D: usize> SimpleGenerator } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness_get_biguint_target(witness, self.a.clone()); - let b = witness_get_biguint_target(witness, self.b.clone()); + let a = witness.get_biguint_target(self.a.clone()); + let b = witness.get_biguint_target(self.b.clone()); let (div, rem) = a.div_rem(&b); - buffer_set_biguint_target(out_buffer, &self.div, &div); - buffer_set_biguint_target(out_buffer, &self.rem, &rem); + out_buffer.set_biguint_target(&self.div, &div); + out_buffer.set_biguint_target(&self.rem, &rem); } } @@ -350,7 +352,7 @@ mod tests { }; use rand::Rng; - use crate::gadgets::biguint::{witness_set_biguint_target, CircuitBuilderBiguint}; + use crate::gadgets::biguint::{WitnessBigUint, CircuitBuilderBiguint}; #[test] fn test_biguint_add() -> Result<()> { @@ -373,9 +375,9 @@ mod tests { let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); - witness_set_biguint_target(&mut pw, &x, &x_value); - witness_set_biguint_target(&mut pw, &y, &y_value); - witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); @@ -433,9 +435,9 @@ mod tests { let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); - witness_set_biguint_target(&mut pw, &x, &x_value); - witness_set_biguint_target(&mut pw, &y, &y_value); - witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); diff --git a/ecdsa/src/gadgets/curve_fixed_base.rs b/ecdsa/src/gadgets/curve_fixed_base.rs index 44dc9488..0fd8e841 100644 --- a/ecdsa/src/gadgets/curve_fixed_base.rs +++ b/ecdsa/src/gadgets/curve_fixed_base.rs @@ -76,7 +76,7 @@ mod tests { use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; - use crate::gadgets::biguint::witness_set_biguint_target; + use crate::gadgets::biguint::WitnessBigUint; use crate::gadgets::curve::CircuitBuilderCurve; use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; use crate::gadgets::nonnative::CircuitBuilderNonNative; @@ -101,7 +101,7 @@ mod tests { builder.curve_assert_valid(&res_expected); let n_target = builder.add_virtual_nonnative_target::(); - witness_set_biguint_target(&mut pw, &n_target.value, &n.to_canonical_biguint()); + pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target); builder.curve_assert_valid(&res_target); diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs index 8e62e906..2d86652c 100644 --- a/ecdsa/src/gadgets/glv.rs +++ b/ecdsa/src/gadgets/glv.rs @@ -12,7 +12,7 @@ use plonky2_field::types::{Field, PrimeField}; use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; use crate::curve::secp256k1::Secp256K1; -use crate::gadgets::biguint::{buffer_set_biguint_target, witness_get_biguint_target}; +use crate::gadgets::biguint::{GeneratedValuesBigUint, WitnessBigUint}; use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; use crate::gadgets::curve_msm::curve_msm_circuit; use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; @@ -116,15 +116,14 @@ impl, const D: usize> SimpleGenerator } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let k = Secp256K1Scalar::from_noncanonical_biguint(witness_get_biguint_target( - witness, + let k = Secp256K1Scalar::from_noncanonical_biguint(witness.get_biguint_target( self.k.value.clone(), )); let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - buffer_set_biguint_target(out_buffer, &self.k1.value, &k1.to_canonical_biguint()); - buffer_set_biguint_target(out_buffer, &self.k2.value, &k2.to_canonical_biguint()); + out_buffer.set_biguint_target(&self.k1.value, &k1.to_canonical_biguint()); + out_buffer.set_biguint_target(&self.k2.value, &k2.to_canonical_biguint()); out_buffer.set_bool_target(self.k1_neg, k1_neg); out_buffer.set_bool_target(self.k2_neg, k2_neg); } diff --git a/ecdsa/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs index 393aac75..db1231bc 100644 --- a/ecdsa/src/gadgets/nonnative.rs +++ b/ecdsa/src/gadgets/nonnative.rs @@ -10,11 +10,11 @@ use plonky2_field::types::PrimeField; use plonky2_field::{extension::Extendable, types::Field}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use plonky2_u32::gadgets::range_check::range_check_u32_circuit; -use plonky2_u32::witness::generated_values_set_u32_target; +use plonky2_u32::witness::GeneratedValuesU32; use plonky2_util::ceil_div_usize; use crate::gadgets::biguint::{ - buffer_set_biguint_target, witness_get_biguint_target, BigUintTarget, CircuitBuilderBiguint, + GeneratedValuesBigUint, WitnessBigUint, BigUintTarget, CircuitBuilderBiguint, }; #[derive(Clone, Debug)] @@ -467,12 +467,10 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, + let a = FF::from_noncanonical_biguint(witness.get_biguint_target( self.a.value.clone(), )); - let b = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, + let b = FF::from_noncanonical_biguint(witness.get_biguint_target( self.b.value.clone(), )); let a_biguint = a.to_canonical_biguint(); @@ -485,7 +483,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat (false, sum_biguint) }; - buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); + out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); out_buffer.set_bool_target(self.overflow, overflow); } } @@ -514,8 +512,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat .summands .iter() .map(|summand| { - FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, + FF::from_noncanonical_biguint(witness.get_biguint_target( summand.value.clone(), )) }) @@ -533,8 +530,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); let overflow = overflow_biguint.to_u64_digits()[0] as u32; - buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); - generated_values_set_u32_target(out_buffer, self.overflow, overflow); + out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); + out_buffer.set_u32_target(self.overflow, overflow); } } @@ -562,12 +559,10 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, + let a = FF::from_noncanonical_biguint(witness.get_biguint_target( self.a.value.clone(), )); - let b = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, + let b = FF::from_noncanonical_biguint(witness.get_biguint_target( self.b.value.clone(), )); let a_biguint = a.to_canonical_biguint(); @@ -580,7 +575,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat (modulus + a_biguint - b_biguint, true) }; - buffer_set_biguint_target(out_buffer, &self.diff.value, &diff_biguint); + out_buffer.set_biguint_target(&self.diff.value, &diff_biguint); out_buffer.set_bool_target(self.overflow, overflow); } } @@ -609,12 +604,10 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, + let a = FF::from_noncanonical_biguint(witness.get_biguint_target( self.a.value.clone(), )); - let b = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, + let b = FF::from_noncanonical_biguint(witness.get_biguint_target( self.b.value.clone(), )); let a_biguint = a.to_canonical_biguint(); @@ -625,8 +618,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let modulus = FF::order(); let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); - buffer_set_biguint_target(out_buffer, &self.prod.value, &prod_reduced); - buffer_set_biguint_target(out_buffer, &self.overflow, &overflow_biguint); + out_buffer.set_biguint_target(&self.prod.value, &prod_reduced); + out_buffer.set_biguint_target(&self.overflow, &overflow_biguint); } } @@ -646,8 +639,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let x = FF::from_noncanonical_biguint(witness_get_biguint_target( - witness, + let x = FF::from_noncanonical_biguint(witness.get_biguint_target( self.x.value.clone(), )); let inv = x.inverse(); @@ -658,8 +650,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let modulus = FF::order(); let (div, _rem) = prod.div_rem(&modulus); - buffer_set_biguint_target(out_buffer, &self.div, &div); - buffer_set_biguint_target(out_buffer, &self.inv, &inv_biguint); + out_buffer.set_biguint_target(&self.div, &div); + out_buffer.set_biguint_target(&self.inv, &inv_biguint); } } diff --git a/u32/src/gadgets/arithmetic_u32.rs b/u32/src/gadgets/arithmetic_u32.rs index 7a7731b1..7475681c 100644 --- a/u32/src/gadgets/arithmetic_u32.rs +++ b/u32/src/gadgets/arithmetic_u32.rs @@ -10,7 +10,7 @@ use plonky2_field::extension::Extendable; use crate::gates::add_many_u32::U32AddManyGate; use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::subtraction_u32::U32SubtractionGate; -use crate::witness::generated_values_set_u32_target; +use crate::witness::GeneratedValuesU32; #[derive(Clone, Copy, Debug)] pub struct U32Target(pub Target); @@ -249,8 +249,8 @@ impl, const D: usize> SimpleGenerator let low = x_u64 as u32; let high = (x_u64 >> 32) as u32; - generated_values_set_u32_target(out_buffer, self.low, low); - generated_values_set_u32_target(out_buffer, self.high, high); + out_buffer.set_u32_target(self.low, low); + out_buffer.set_u32_target(self.high, high); } } diff --git a/u32/src/witness.rs b/u32/src/witness.rs index 1b88d60d..38aa2238 100644 --- a/u32/src/witness.rs +++ b/u32/src/witness.rs @@ -1,21 +1,32 @@ use plonky2::iop::generator::GeneratedValues; use plonky2::iop::witness::Witness; -use plonky2_field::types::Field; +use plonky2_field::types::{Field, PrimeField64}; use crate::gadgets::arithmetic_u32::U32Target; -pub fn generated_values_set_u32_target( - buffer: &mut GeneratedValues, - target: U32Target, - value: u32, -) { - buffer.set_target(target.0, F::from_canonical_u32(value)) +pub trait WitnessU32: Witness { + fn set_u32_target(&mut self, target: U32Target, value: u32); + fn get_u32_target(&self, target: U32Target) -> (u32, u32); } -pub fn witness_set_u32_target, F: Field>( - witness: &mut W, - target: U32Target, - value: u32, -) { - witness.set_target(target.0, F::from_canonical_u32(value)) +impl, F: PrimeField64> WitnessU32 for T { + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)); + } + fn get_u32_target(&self, target: U32Target) -> (u32, u32) { + let x_u64 = self.get_target(target.0).to_canonical_u64(); + let low = x_u64 as u32; + let high = (x_u64 >> 32) as u32; + (low, high) + } +} + +pub trait GeneratedValuesU32 { + fn set_u32_target(&mut self, target: U32Target, value: u32); +} + +impl GeneratedValuesU32 for GeneratedValues { + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } } From 356c7cd9255bdc6ab32e3c8dd4f1d7a35d45aeb2 Mon Sep 17 00:00:00 2001 From: Sladuca Date: Fri, 26 Aug 2022 16:10:44 -0400 Subject: [PATCH 2/2] fmt --- ecdsa/src/gadgets/biguint.rs | 15 ++++++++------- ecdsa/src/gadgets/glv.rs | 6 +++--- ecdsa/src/gadgets/nonnative.rs | 34 +++++++++------------------------- 3 files changed, 20 insertions(+), 35 deletions(-) diff --git a/ecdsa/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs index 188b04ba..faae365c 100644 --- a/ecdsa/src/gadgets/biguint.rs +++ b/ecdsa/src/gadgets/biguint.rs @@ -277,12 +277,13 @@ pub trait WitnessBigUint: Witness { impl, F: PrimeField64> WitnessBigUint for T { fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { - target.limbs - .into_iter() - .rev() - .fold(BigUint::zero(), |acc, limb| { - (acc << 32) + self.get_target(limb.0).to_canonical_biguint() - }) + target + .limbs + .into_iter() + .rev() + .fold(BigUint::zero(), |acc, limb| { + (acc << 32) + self.get_target(limb.0).to_canonical_biguint() + }) } fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { @@ -352,7 +353,7 @@ mod tests { }; use rand::Rng; - use crate::gadgets::biguint::{WitnessBigUint, CircuitBuilderBiguint}; + use crate::gadgets::biguint::{CircuitBuilderBiguint, WitnessBigUint}; #[test] fn test_biguint_add() -> Result<()> { diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs index 2d86652c..4302023e 100644 --- a/ecdsa/src/gadgets/glv.rs +++ b/ecdsa/src/gadgets/glv.rs @@ -116,9 +116,9 @@ impl, const D: usize> SimpleGenerator } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let k = Secp256K1Scalar::from_noncanonical_biguint(witness.get_biguint_target( - self.k.value.clone(), - )); + let k = Secp256K1Scalar::from_noncanonical_biguint( + witness.get_biguint_target(self.k.value.clone()), + ); let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); diff --git a/ecdsa/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs index db1231bc..c6ff4753 100644 --- a/ecdsa/src/gadgets/nonnative.rs +++ b/ecdsa/src/gadgets/nonnative.rs @@ -14,7 +14,7 @@ use plonky2_u32::witness::GeneratedValuesU32; use plonky2_util::ceil_div_usize; use crate::gadgets::biguint::{ - GeneratedValuesBigUint, WitnessBigUint, BigUintTarget, CircuitBuilderBiguint, + BigUintTarget, CircuitBuilderBiguint, GeneratedValuesBigUint, WitnessBigUint, }; #[derive(Clone, Debug)] @@ -467,12 +467,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness.get_biguint_target( - self.a.value.clone(), - )); - let b = FF::from_noncanonical_biguint(witness.get_biguint_target( - self.b.value.clone(), - )); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); let sum_biguint = a_biguint + b_biguint; @@ -512,9 +508,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat .summands .iter() .map(|summand| { - FF::from_noncanonical_biguint(witness.get_biguint_target( - summand.value.clone(), - )) + FF::from_noncanonical_biguint(witness.get_biguint_target(summand.value.clone())) }) .collect(); let summand_biguints: Vec<_> = summands @@ -559,12 +553,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness.get_biguint_target( - self.a.value.clone(), - )); - let b = FF::from_noncanonical_biguint(witness.get_biguint_target( - self.b.value.clone(), - )); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); @@ -604,12 +594,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness.get_biguint_target( - self.a.value.clone(), - )); - let b = FF::from_noncanonical_biguint(witness.get_biguint_target( - self.b.value.clone(), - )); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); @@ -639,9 +625,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let x = FF::from_noncanonical_biguint(witness.get_biguint_target( - self.x.value.clone(), - )); + let x = FF::from_noncanonical_biguint(witness.get_biguint_target(self.x.value.clone())); let inv = x.inverse(); let x_biguint = x.to_canonical_biguint();