cleaner witness extension

This commit is contained in:
Sladuca 2022-08-26 16:10:34 -04:00
parent 70971aee2d
commit 8aa3ed0997
6 changed files with 93 additions and 89 deletions

View File

@ -7,10 +7,10 @@ use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartitionWitness, Witness}; use plonky2::iop::witness::{PartitionWitness, Witness};
use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2_field::extension::Extendable; use plonky2_field::extension::Extendable;
use plonky2_field::types::PrimeField; use plonky2_field::types::{PrimeField, PrimeField64};
use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target};
use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit; use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit;
use plonky2_u32::witness::{generated_values_set_u32_target, witness_set_u32_target}; use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct BigUintTarget { pub struct BigUintTarget {
@ -270,41 +270,43 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderBiguint<F, D>
} }
} }
pub fn witness_get_biguint_target<W: Witness<F>, F: PrimeField>( pub trait WitnessBigUint<F: PrimeField64>: Witness<F> {
witness: &W, fn get_biguint_target(&self, target: BigUintTarget) -> BigUint;
bt: BigUintTarget, fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint);
) -> BigUint { }
bt.limbs
impl<T: Witness<F>, F: PrimeField64> WitnessBigUint<F> for T {
fn get_biguint_target(&self, target: BigUintTarget) -> BigUint {
target.limbs
.into_iter() .into_iter()
.rev() .rev()
.fold(BigUint::zero(), |acc, limb| { .fold(BigUint::zero(), |acc, limb| {
(acc << 32) + witness.get_target(limb.0).to_canonical_biguint() (acc << 32) + self.get_target(limb.0).to_canonical_biguint()
}) })
} }
pub fn witness_set_biguint_target<W: Witness<F>, F: PrimeField>( fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
witness: &mut W, let mut limbs = value.to_u32_digits();
target: &BigUintTarget, assert!(target.num_limbs() >= limbs.len());
value: &BigUint, limbs.resize(target.num_limbs(), 0);
) { for i in 0..target.num_limbs() {
let mut limbs = value.to_u32_digits(); self.set_u32_target(target.limbs[i], limbs[i]);
assert!(target.num_limbs() >= limbs.len()); }
limbs.resize(target.num_limbs(), 0);
for i in 0..target.num_limbs() {
witness_set_u32_target(witness, target.limbs[i], limbs[i]);
} }
} }
pub fn buffer_set_biguint_target<F: PrimeField>( pub trait GeneratedValuesBigUint<F: PrimeField> {
buffer: &mut GeneratedValues<F>, fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint);
target: &BigUintTarget, }
value: &BigUint,
) { impl<F: PrimeField> GeneratedValuesBigUint<F> for GeneratedValues<F> {
let mut limbs = value.to_u32_digits(); fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
assert!(target.num_limbs() >= limbs.len()); let mut limbs = value.to_u32_digits();
limbs.resize(target.num_limbs(), 0); assert!(target.num_limbs() >= limbs.len());
for i in 0..target.num_limbs() { limbs.resize(target.num_limbs(), 0);
generated_values_set_u32_target(buffer, target.get_limb(i), limbs[i]); for i in 0..target.num_limbs() {
self.set_u32_target(target.get_limb(i), limbs[i]);
}
} }
} }
@ -330,12 +332,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
} }
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) { fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness_get_biguint_target(witness, self.a.clone()); let a = witness.get_biguint_target(self.a.clone());
let b = witness_get_biguint_target(witness, self.b.clone()); let b = witness.get_biguint_target(self.b.clone());
let (div, rem) = a.div_rem(&b); let (div, rem) = a.div_rem(&b);
buffer_set_biguint_target(out_buffer, &self.div, &div); out_buffer.set_biguint_target(&self.div, &div);
buffer_set_biguint_target(out_buffer, &self.rem, &rem); out_buffer.set_biguint_target(&self.rem, &rem);
} }
} }
@ -350,7 +352,7 @@ mod tests {
}; };
use rand::Rng; use rand::Rng;
use crate::gadgets::biguint::{witness_set_biguint_target, CircuitBuilderBiguint}; use crate::gadgets::biguint::{WitnessBigUint, CircuitBuilderBiguint};
#[test] #[test]
fn test_biguint_add() -> Result<()> { fn test_biguint_add() -> Result<()> {
@ -373,9 +375,9 @@ mod tests {
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z); builder.connect_biguint(&z, &expected_z);
witness_set_biguint_target(&mut pw, &x, &x_value); pw.set_biguint_target(&x, &x_value);
witness_set_biguint_target(&mut pw, &y, &y_value); pw.set_biguint_target(&y, &y_value);
witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build::<C>(); let data = builder.build::<C>();
let proof = data.prove(pw).unwrap(); let proof = data.prove(pw).unwrap();
@ -433,9 +435,9 @@ mod tests {
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z); builder.connect_biguint(&z, &expected_z);
witness_set_biguint_target(&mut pw, &x, &x_value); pw.set_biguint_target(&x, &x_value);
witness_set_biguint_target(&mut pw, &y, &y_value); pw.set_biguint_target(&y, &y_value);
witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build::<C>(); let data = builder.build::<C>();
let proof = data.prove(pw).unwrap(); let proof = data.prove(pw).unwrap();

View File

@ -76,7 +76,7 @@ mod tests {
use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::curve_types::{Curve, CurveScalar};
use crate::curve::secp256k1::Secp256K1; use crate::curve::secp256k1::Secp256K1;
use crate::gadgets::biguint::witness_set_biguint_target; use crate::gadgets::biguint::WitnessBigUint;
use crate::gadgets::curve::CircuitBuilderCurve; use crate::gadgets::curve::CircuitBuilderCurve;
use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit;
use crate::gadgets::nonnative::CircuitBuilderNonNative; use crate::gadgets::nonnative::CircuitBuilderNonNative;
@ -101,7 +101,7 @@ mod tests {
builder.curve_assert_valid(&res_expected); builder.curve_assert_valid(&res_expected);
let n_target = builder.add_virtual_nonnative_target::<Secp256K1Scalar>(); let n_target = builder.add_virtual_nonnative_target::<Secp256K1Scalar>();
witness_set_biguint_target(&mut pw, &n_target.value, &n.to_canonical_biguint()); pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint());
let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target); let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target);
builder.curve_assert_valid(&res_target); builder.curve_assert_valid(&res_target);

View File

@ -12,7 +12,7 @@ use plonky2_field::types::{Field, PrimeField};
use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S};
use crate::curve::secp256k1::Secp256K1; use crate::curve::secp256k1::Secp256K1;
use crate::gadgets::biguint::{buffer_set_biguint_target, witness_get_biguint_target}; use crate::gadgets::biguint::{GeneratedValuesBigUint, WitnessBigUint};
use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve};
use crate::gadgets::curve_msm::curve_msm_circuit; use crate::gadgets::curve_msm::curve_msm_circuit;
use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget};
@ -116,15 +116,14 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
} }
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) { fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let k = Secp256K1Scalar::from_noncanonical_biguint(witness_get_biguint_target( let k = Secp256K1Scalar::from_noncanonical_biguint(witness.get_biguint_target(
witness,
self.k.value.clone(), self.k.value.clone(),
)); ));
let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k);
buffer_set_biguint_target(out_buffer, &self.k1.value, &k1.to_canonical_biguint()); out_buffer.set_biguint_target(&self.k1.value, &k1.to_canonical_biguint());
buffer_set_biguint_target(out_buffer, &self.k2.value, &k2.to_canonical_biguint()); out_buffer.set_biguint_target(&self.k2.value, &k2.to_canonical_biguint());
out_buffer.set_bool_target(self.k1_neg, k1_neg); out_buffer.set_bool_target(self.k1_neg, k1_neg);
out_buffer.set_bool_target(self.k2_neg, k2_neg); out_buffer.set_bool_target(self.k2_neg, k2_neg);
} }

View File

@ -10,11 +10,11 @@ use plonky2_field::types::PrimeField;
use plonky2_field::{extension::Extendable, types::Field}; use plonky2_field::{extension::Extendable, types::Field};
use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target};
use plonky2_u32::gadgets::range_check::range_check_u32_circuit; use plonky2_u32::gadgets::range_check::range_check_u32_circuit;
use plonky2_u32::witness::generated_values_set_u32_target; use plonky2_u32::witness::GeneratedValuesU32;
use plonky2_util::ceil_div_usize; use plonky2_util::ceil_div_usize;
use crate::gadgets::biguint::{ use crate::gadgets::biguint::{
buffer_set_biguint_target, witness_get_biguint_target, BigUintTarget, CircuitBuilderBiguint, GeneratedValuesBigUint, WitnessBigUint, BigUintTarget, CircuitBuilderBiguint,
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -467,12 +467,10 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
} }
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) { fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = FF::from_noncanonical_biguint(witness_get_biguint_target( let a = FF::from_noncanonical_biguint(witness.get_biguint_target(
witness,
self.a.value.clone(), self.a.value.clone(),
)); ));
let b = FF::from_noncanonical_biguint(witness_get_biguint_target( let b = FF::from_noncanonical_biguint(witness.get_biguint_target(
witness,
self.b.value.clone(), self.b.value.clone(),
)); ));
let a_biguint = a.to_canonical_biguint(); let a_biguint = a.to_canonical_biguint();
@ -485,7 +483,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
(false, sum_biguint) (false, sum_biguint)
}; };
buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); out_buffer.set_biguint_target(&self.sum.value, &sum_reduced);
out_buffer.set_bool_target(self.overflow, overflow); out_buffer.set_bool_target(self.overflow, overflow);
} }
} }
@ -514,8 +512,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
.summands .summands
.iter() .iter()
.map(|summand| { .map(|summand| {
FF::from_noncanonical_biguint(witness_get_biguint_target( FF::from_noncanonical_biguint(witness.get_biguint_target(
witness,
summand.value.clone(), summand.value.clone(),
)) ))
}) })
@ -533,8 +530,8 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus);
let overflow = overflow_biguint.to_u64_digits()[0] as u32; let overflow = overflow_biguint.to_u64_digits()[0] as u32;
buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); out_buffer.set_biguint_target(&self.sum.value, &sum_reduced);
generated_values_set_u32_target(out_buffer, self.overflow, overflow); out_buffer.set_u32_target(self.overflow, overflow);
} }
} }
@ -562,12 +559,10 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
} }
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) { fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = FF::from_noncanonical_biguint(witness_get_biguint_target( let a = FF::from_noncanonical_biguint(witness.get_biguint_target(
witness,
self.a.value.clone(), self.a.value.clone(),
)); ));
let b = FF::from_noncanonical_biguint(witness_get_biguint_target( let b = FF::from_noncanonical_biguint(witness.get_biguint_target(
witness,
self.b.value.clone(), self.b.value.clone(),
)); ));
let a_biguint = a.to_canonical_biguint(); let a_biguint = a.to_canonical_biguint();
@ -580,7 +575,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
(modulus + a_biguint - b_biguint, true) (modulus + a_biguint - b_biguint, true)
}; };
buffer_set_biguint_target(out_buffer, &self.diff.value, &diff_biguint); out_buffer.set_biguint_target(&self.diff.value, &diff_biguint);
out_buffer.set_bool_target(self.overflow, overflow); out_buffer.set_bool_target(self.overflow, overflow);
} }
} }
@ -609,12 +604,10 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
} }
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) { fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = FF::from_noncanonical_biguint(witness_get_biguint_target( let a = FF::from_noncanonical_biguint(witness.get_biguint_target(
witness,
self.a.value.clone(), self.a.value.clone(),
)); ));
let b = FF::from_noncanonical_biguint(witness_get_biguint_target( let b = FF::from_noncanonical_biguint(witness.get_biguint_target(
witness,
self.b.value.clone(), self.b.value.clone(),
)); ));
let a_biguint = a.to_canonical_biguint(); let a_biguint = a.to_canonical_biguint();
@ -625,8 +618,8 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
let modulus = FF::order(); let modulus = FF::order();
let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus);
buffer_set_biguint_target(out_buffer, &self.prod.value, &prod_reduced); out_buffer.set_biguint_target(&self.prod.value, &prod_reduced);
buffer_set_biguint_target(out_buffer, &self.overflow, &overflow_biguint); out_buffer.set_biguint_target(&self.overflow, &overflow_biguint);
} }
} }
@ -646,8 +639,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
} }
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) { fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let x = FF::from_noncanonical_biguint(witness_get_biguint_target( let x = FF::from_noncanonical_biguint(witness.get_biguint_target(
witness,
self.x.value.clone(), self.x.value.clone(),
)); ));
let inv = x.inverse(); let inv = x.inverse();
@ -658,8 +650,8 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
let modulus = FF::order(); let modulus = FF::order();
let (div, _rem) = prod.div_rem(&modulus); let (div, _rem) = prod.div_rem(&modulus);
buffer_set_biguint_target(out_buffer, &self.div, &div); out_buffer.set_biguint_target(&self.div, &div);
buffer_set_biguint_target(out_buffer, &self.inv, &inv_biguint); out_buffer.set_biguint_target(&self.inv, &inv_biguint);
} }
} }

View File

@ -10,7 +10,7 @@ use plonky2_field::extension::Extendable;
use crate::gates::add_many_u32::U32AddManyGate; use crate::gates::add_many_u32::U32AddManyGate;
use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::subtraction_u32::U32SubtractionGate; use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::witness::generated_values_set_u32_target; use crate::witness::GeneratedValuesU32;
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct U32Target(pub Target); pub struct U32Target(pub Target);
@ -249,8 +249,8 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let low = x_u64 as u32; let low = x_u64 as u32;
let high = (x_u64 >> 32) as u32; let high = (x_u64 >> 32) as u32;
generated_values_set_u32_target(out_buffer, self.low, low); out_buffer.set_u32_target(self.low, low);
generated_values_set_u32_target(out_buffer, self.high, high); out_buffer.set_u32_target(self.high, high);
} }
} }

View File

@ -1,21 +1,32 @@
use plonky2::iop::generator::GeneratedValues; use plonky2::iop::generator::GeneratedValues;
use plonky2::iop::witness::Witness; use plonky2::iop::witness::Witness;
use plonky2_field::types::Field; use plonky2_field::types::{Field, PrimeField64};
use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::arithmetic_u32::U32Target;
pub fn generated_values_set_u32_target<F: Field>( pub trait WitnessU32<F: PrimeField64>: Witness<F> {
buffer: &mut GeneratedValues<F>, fn set_u32_target(&mut self, target: U32Target, value: u32);
target: U32Target, fn get_u32_target(&self, target: U32Target) -> (u32, u32);
value: u32,
) {
buffer.set_target(target.0, F::from_canonical_u32(value))
} }
pub fn witness_set_u32_target<W: Witness<F>, F: Field>( impl<T: Witness<F>, F: PrimeField64> WitnessU32<F> for T {
witness: &mut W, fn set_u32_target(&mut self, target: U32Target, value: u32) {
target: U32Target, self.set_target(target.0, F::from_canonical_u32(value));
value: u32, }
) { fn get_u32_target(&self, target: U32Target) -> (u32, u32) {
witness.set_target(target.0, F::from_canonical_u32(value)) let x_u64 = self.get_target(target.0).to_canonical_u64();
let low = x_u64 as u32;
let high = (x_u64 >> 32) as u32;
(low, high)
}
}
pub trait GeneratedValuesU32<F: Field> {
fn set_u32_target(&mut self, target: U32Target, value: u32);
}
impl<F: Field> GeneratedValuesU32<F> for GeneratedValues<F> {
fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value))
}
} }