cleaner witness extension

This commit is contained in:
Sladuca 2022-08-26 16:10:34 -04:00
parent 70971aee2d
commit 8aa3ed0997
6 changed files with 93 additions and 89 deletions

View File

@ -7,10 +7,10 @@ use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartitionWitness, Witness};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2_field::extension::Extendable;
use plonky2_field::types::PrimeField;
use plonky2_field::types::{PrimeField, PrimeField64};
use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target};
use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit;
use plonky2_u32::witness::{generated_values_set_u32_target, witness_set_u32_target};
use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32};
#[derive(Clone, Debug)]
pub struct BigUintTarget {
@ -270,41 +270,43 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderBiguint<F, D>
}
}
pub fn witness_get_biguint_target<W: Witness<F>, F: PrimeField>(
witness: &W,
bt: BigUintTarget,
) -> BigUint {
bt.limbs
pub trait WitnessBigUint<F: PrimeField64>: Witness<F> {
fn get_biguint_target(&self, target: BigUintTarget) -> BigUint;
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint);
}
impl<T: Witness<F>, F: PrimeField64> WitnessBigUint<F> for T {
fn get_biguint_target(&self, target: BigUintTarget) -> BigUint {
target.limbs
.into_iter()
.rev()
.fold(BigUint::zero(), |acc, limb| {
(acc << 32) + witness.get_target(limb.0).to_canonical_biguint()
(acc << 32) + self.get_target(limb.0).to_canonical_biguint()
})
}
}
pub fn witness_set_biguint_target<W: Witness<F>, F: PrimeField>(
witness: &mut W,
target: &BigUintTarget,
value: &BigUint,
) {
let mut limbs = value.to_u32_digits();
assert!(target.num_limbs() >= limbs.len());
limbs.resize(target.num_limbs(), 0);
for i in 0..target.num_limbs() {
witness_set_u32_target(witness, target.limbs[i], limbs[i]);
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
let mut limbs = value.to_u32_digits();
assert!(target.num_limbs() >= limbs.len());
limbs.resize(target.num_limbs(), 0);
for i in 0..target.num_limbs() {
self.set_u32_target(target.limbs[i], limbs[i]);
}
}
}
pub fn buffer_set_biguint_target<F: PrimeField>(
buffer: &mut GeneratedValues<F>,
target: &BigUintTarget,
value: &BigUint,
) {
let mut limbs = value.to_u32_digits();
assert!(target.num_limbs() >= limbs.len());
limbs.resize(target.num_limbs(), 0);
for i in 0..target.num_limbs() {
generated_values_set_u32_target(buffer, target.get_limb(i), limbs[i]);
pub trait GeneratedValuesBigUint<F: PrimeField> {
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint);
}
impl<F: PrimeField> GeneratedValuesBigUint<F> for GeneratedValues<F> {
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
let mut limbs = value.to_u32_digits();
assert!(target.num_limbs() >= limbs.len());
limbs.resize(target.num_limbs(), 0);
for i in 0..target.num_limbs() {
self.set_u32_target(target.get_limb(i), limbs[i]);
}
}
}
@ -330,12 +332,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness_get_biguint_target(witness, self.a.clone());
let b = witness_get_biguint_target(witness, self.b.clone());
let a = witness.get_biguint_target(self.a.clone());
let b = witness.get_biguint_target(self.b.clone());
let (div, rem) = a.div_rem(&b);
buffer_set_biguint_target(out_buffer, &self.div, &div);
buffer_set_biguint_target(out_buffer, &self.rem, &rem);
out_buffer.set_biguint_target(&self.div, &div);
out_buffer.set_biguint_target(&self.rem, &rem);
}
}
@ -350,7 +352,7 @@ mod tests {
};
use rand::Rng;
use crate::gadgets::biguint::{witness_set_biguint_target, CircuitBuilderBiguint};
use crate::gadgets::biguint::{WitnessBigUint, CircuitBuilderBiguint};
#[test]
fn test_biguint_add() -> Result<()> {
@ -373,9 +375,9 @@ mod tests {
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z);
witness_set_biguint_target(&mut pw, &x, &x_value);
witness_set_biguint_target(&mut pw, &y, &y_value);
witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value);
pw.set_biguint_target(&x, &x_value);
pw.set_biguint_target(&y, &y_value);
pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
@ -433,9 +435,9 @@ mod tests {
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z);
witness_set_biguint_target(&mut pw, &x, &x_value);
witness_set_biguint_target(&mut pw, &y, &y_value);
witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value);
pw.set_biguint_target(&x, &x_value);
pw.set_biguint_target(&y, &y_value);
pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();

View File

@ -76,7 +76,7 @@ mod tests {
use crate::curve::curve_types::{Curve, CurveScalar};
use crate::curve::secp256k1::Secp256K1;
use crate::gadgets::biguint::witness_set_biguint_target;
use crate::gadgets::biguint::WitnessBigUint;
use crate::gadgets::curve::CircuitBuilderCurve;
use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit;
use crate::gadgets::nonnative::CircuitBuilderNonNative;
@ -101,7 +101,7 @@ mod tests {
builder.curve_assert_valid(&res_expected);
let n_target = builder.add_virtual_nonnative_target::<Secp256K1Scalar>();
witness_set_biguint_target(&mut pw, &n_target.value, &n.to_canonical_biguint());
pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint());
let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target);
builder.curve_assert_valid(&res_target);

View File

@ -12,7 +12,7 @@ use plonky2_field::types::{Field, PrimeField};
use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S};
use crate::curve::secp256k1::Secp256K1;
use crate::gadgets::biguint::{buffer_set_biguint_target, witness_get_biguint_target};
use crate::gadgets::biguint::{GeneratedValuesBigUint, WitnessBigUint};
use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve};
use crate::gadgets::curve_msm::curve_msm_circuit;
use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget};
@ -116,15 +116,14 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let k = Secp256K1Scalar::from_noncanonical_biguint(witness_get_biguint_target(
witness,
let k = Secp256K1Scalar::from_noncanonical_biguint(witness.get_biguint_target(
self.k.value.clone(),
));
let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k);
buffer_set_biguint_target(out_buffer, &self.k1.value, &k1.to_canonical_biguint());
buffer_set_biguint_target(out_buffer, &self.k2.value, &k2.to_canonical_biguint());
out_buffer.set_biguint_target(&self.k1.value, &k1.to_canonical_biguint());
out_buffer.set_biguint_target(&self.k2.value, &k2.to_canonical_biguint());
out_buffer.set_bool_target(self.k1_neg, k1_neg);
out_buffer.set_bool_target(self.k2_neg, k2_neg);
}

View File

@ -10,11 +10,11 @@ use plonky2_field::types::PrimeField;
use plonky2_field::{extension::Extendable, types::Field};
use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target};
use plonky2_u32::gadgets::range_check::range_check_u32_circuit;
use plonky2_u32::witness::generated_values_set_u32_target;
use plonky2_u32::witness::GeneratedValuesU32;
use plonky2_util::ceil_div_usize;
use crate::gadgets::biguint::{
buffer_set_biguint_target, witness_get_biguint_target, BigUintTarget, CircuitBuilderBiguint,
GeneratedValuesBigUint, WitnessBigUint, BigUintTarget, CircuitBuilderBiguint,
};
#[derive(Clone, Debug)]
@ -467,12 +467,10 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = FF::from_noncanonical_biguint(witness_get_biguint_target(
witness,
let a = FF::from_noncanonical_biguint(witness.get_biguint_target(
self.a.value.clone(),
));
let b = FF::from_noncanonical_biguint(witness_get_biguint_target(
witness,
let b = FF::from_noncanonical_biguint(witness.get_biguint_target(
self.b.value.clone(),
));
let a_biguint = a.to_canonical_biguint();
@ -485,7 +483,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
(false, sum_biguint)
};
buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced);
out_buffer.set_biguint_target(&self.sum.value, &sum_reduced);
out_buffer.set_bool_target(self.overflow, overflow);
}
}
@ -514,8 +512,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
.summands
.iter()
.map(|summand| {
FF::from_noncanonical_biguint(witness_get_biguint_target(
witness,
FF::from_noncanonical_biguint(witness.get_biguint_target(
summand.value.clone(),
))
})
@ -533,8 +530,8 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus);
let overflow = overflow_biguint.to_u64_digits()[0] as u32;
buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced);
generated_values_set_u32_target(out_buffer, self.overflow, overflow);
out_buffer.set_biguint_target(&self.sum.value, &sum_reduced);
out_buffer.set_u32_target(self.overflow, overflow);
}
}
@ -562,12 +559,10 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = FF::from_noncanonical_biguint(witness_get_biguint_target(
witness,
let a = FF::from_noncanonical_biguint(witness.get_biguint_target(
self.a.value.clone(),
));
let b = FF::from_noncanonical_biguint(witness_get_biguint_target(
witness,
let b = FF::from_noncanonical_biguint(witness.get_biguint_target(
self.b.value.clone(),
));
let a_biguint = a.to_canonical_biguint();
@ -580,7 +575,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
(modulus + a_biguint - b_biguint, true)
};
buffer_set_biguint_target(out_buffer, &self.diff.value, &diff_biguint);
out_buffer.set_biguint_target(&self.diff.value, &diff_biguint);
out_buffer.set_bool_target(self.overflow, overflow);
}
}
@ -609,12 +604,10 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = FF::from_noncanonical_biguint(witness_get_biguint_target(
witness,
let a = FF::from_noncanonical_biguint(witness.get_biguint_target(
self.a.value.clone(),
));
let b = FF::from_noncanonical_biguint(witness_get_biguint_target(
witness,
let b = FF::from_noncanonical_biguint(witness.get_biguint_target(
self.b.value.clone(),
));
let a_biguint = a.to_canonical_biguint();
@ -625,8 +618,8 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
let modulus = FF::order();
let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus);
buffer_set_biguint_target(out_buffer, &self.prod.value, &prod_reduced);
buffer_set_biguint_target(out_buffer, &self.overflow, &overflow_biguint);
out_buffer.set_biguint_target(&self.prod.value, &prod_reduced);
out_buffer.set_biguint_target(&self.overflow, &overflow_biguint);
}
}
@ -646,8 +639,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let x = FF::from_noncanonical_biguint(witness_get_biguint_target(
witness,
let x = FF::from_noncanonical_biguint(witness.get_biguint_target(
self.x.value.clone(),
));
let inv = x.inverse();
@ -658,8 +650,8 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
let modulus = FF::order();
let (div, _rem) = prod.div_rem(&modulus);
buffer_set_biguint_target(out_buffer, &self.div, &div);
buffer_set_biguint_target(out_buffer, &self.inv, &inv_biguint);
out_buffer.set_biguint_target(&self.div, &div);
out_buffer.set_biguint_target(&self.inv, &inv_biguint);
}
}

View File

@ -10,7 +10,7 @@ use plonky2_field::extension::Extendable;
use crate::gates::add_many_u32::U32AddManyGate;
use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::witness::generated_values_set_u32_target;
use crate::witness::GeneratedValuesU32;
#[derive(Clone, Copy, Debug)]
pub struct U32Target(pub Target);
@ -249,8 +249,8 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let low = x_u64 as u32;
let high = (x_u64 >> 32) as u32;
generated_values_set_u32_target(out_buffer, self.low, low);
generated_values_set_u32_target(out_buffer, self.high, high);
out_buffer.set_u32_target(self.low, low);
out_buffer.set_u32_target(self.high, high);
}
}

View File

@ -1,21 +1,32 @@
use plonky2::iop::generator::GeneratedValues;
use plonky2::iop::witness::Witness;
use plonky2_field::types::Field;
use plonky2_field::types::{Field, PrimeField64};
use crate::gadgets::arithmetic_u32::U32Target;
pub fn generated_values_set_u32_target<F: Field>(
buffer: &mut GeneratedValues<F>,
target: U32Target,
value: u32,
) {
buffer.set_target(target.0, F::from_canonical_u32(value))
pub trait WitnessU32<F: PrimeField64>: Witness<F> {
fn set_u32_target(&mut self, target: U32Target, value: u32);
fn get_u32_target(&self, target: U32Target) -> (u32, u32);
}
pub fn witness_set_u32_target<W: Witness<F>, F: Field>(
witness: &mut W,
target: U32Target,
value: u32,
) {
witness.set_target(target.0, F::from_canonical_u32(value))
impl<T: Witness<F>, F: PrimeField64> WitnessU32<F> for T {
fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value));
}
fn get_u32_target(&self, target: U32Target) -> (u32, u32) {
let x_u64 = self.get_target(target.0).to_canonical_u64();
let low = x_u64 as u32;
let high = (x_u64 >> 32) as u32;
(low, high)
}
}
pub trait GeneratedValuesU32<F: Field> {
fn set_u32_target(&mut self, target: U32Target, value: u32);
}
impl<F: Field> GeneratedValuesU32<F> for GeneratedValues<F> {
fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value))
}
}