Working GLV with MSM

This commit is contained in:
wborgeaud 2022-03-02 13:19:31 +01:00
parent 850df4dfb1
commit 7c70c46ca7
7 changed files with 152 additions and 73 deletions

View File

@ -78,15 +78,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
) -> AffinePointTarget<C> {
let not_b = self.not(b);
let neg = self.curve_neg(p);
let x_if_true = self.mul_nonnative_by_bool(&neg.x, b);
let y_if_true = self.mul_nonnative_by_bool(&neg.y, b);
let x_if_false = self.mul_nonnative_by_bool(&p.x, not_b);
let y_if_false = self.mul_nonnative_by_bool(&p.y, not_b);
let x = self.add_nonnative(&x_if_true, &x_if_false);
let y = self.add_nonnative(&y_if_true, &y_if_false);
AffinePointTarget { x, y }
AffinePointTarget { x: p.x.clone(), y }
}
pub fn curve_double<C: Curve>(&mut self, p: &AffinePointTarget<C>) -> AffinePointTarget<C> {

View File

@ -11,6 +11,8 @@ use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::config::{GenericHashOut, Hasher};
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Do windowed fixed-base scalar multiplication, using a 4-bit window.
// TODO: Benchmark other window sizes.
pub fn fixed_base_curve_mul<C: Curve>(
&mut self,
base: &AffinePoint<C>,

View File

@ -3,8 +3,8 @@ use plonky2_field::extension_field::Extendable;
use crate::curve::curve_types::{Curve, CurveScalar};
use crate::field::field_types::Field;
use crate::gadgets::biguint::BigUintTarget;
use crate::gadgets::curve::AffinePointTarget;
use crate::gadgets::nonnative::NonNativeTarget;
use crate::hash::hash_types::RichField;
use crate::hash::keccak::KeccakHash;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -16,12 +16,13 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&mut self,
p: &AffinePointTarget<C>,
q: &AffinePointTarget<C>,
n: &NonNativeTarget<C::ScalarField>,
m: &NonNativeTarget<C::ScalarField>,
n: &BigUintTarget,
m: &BigUintTarget,
) -> AffinePointTarget<C> {
let limbs_n = self.split_nonnative_to_2_bit_limbs(n);
let limbs_m = self.split_nonnative_to_2_bit_limbs(m);
let limbs_n = self.split_biguint_to_2_bit_limbs(n);
let limbs_m = self.split_biguint_to_2_bit_limbs(m);
assert_eq!(limbs_n.len(), limbs_m.len());
let num_limbs = limbs_n.len();
let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]);
let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le(
@ -63,8 +64,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let should_add = self.not(is_zero);
result = self.curve_conditional_add(&result, &r, should_add);
}
let starting_point_multiplied =
(0..C::ScalarField::BITS).fold(rando, |acc, _| acc.double());
let starting_point_multiplied = (0..2 * num_limbs).fold(rando, |acc, _| acc.double());
let to_add = self.constant_affine_point(-starting_point_multiplied);
result = self.curve_add(&result, &to_add);
@ -74,11 +74,14 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
#[cfg(test)]
mod tests {
use std::str::FromStr;
use anyhow::Result;
use num::BigUint;
use plonky2_field::secp256k1_base::Secp256K1Base;
use plonky2_field::secp256k1_scalar::Secp256K1Scalar;
use crate::curve::curve_types::{Curve, CurveScalar};
use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar};
use crate::curve::secp256k1::Secp256K1;
use crate::field::field_types::Field;
use crate::iop::witness::PartialWitness;
@ -115,7 +118,7 @@ mod tests {
let n_target = builder.constant_nonnative(n);
let m_target = builder.constant_nonnative(m);
let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target);
let res_target = builder.curve_msm(&p_target, &q_target, &n_target.value, &m_target.value);
builder.curve_assert_valid(&res_target);
builder.connect_affine_point(&res_target, &res_expected);
@ -168,4 +171,72 @@ mod tests {
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_curve_lul() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let p = AffinePoint::<Secp256K1> {
x: Secp256K1Base::from_biguint(
BigUint::from_str(
"95702873347299649035220040874584348285675823985309557645567012532974768144045",
)
.unwrap(),
),
y: Secp256K1Base::from_biguint(
BigUint::from_str(
"34849299245821426255020320369755722155634282348110887335812955146294938249053",
)
.unwrap(),
),
zero: false,
};
let q = AffinePoint::<Secp256K1> {
x: Secp256K1Base::from_biguint(
BigUint::from_str(
"66037057977021147605301350925941983227524093291368248236634649161657340356645",
)
.unwrap(),
),
y: Secp256K1Base::from_biguint(
BigUint::from_str(
"80942789991494769168550664638932185697635702317529676703644628861613896422610",
)
.unwrap(),
),
zero: false,
};
let n = BigUint::from_str("89874493710619023150462632713212469930").unwrap();
let m = BigUint::from_str("76073901947022186525975758425319149118").unwrap();
let res = (CurveScalar(Secp256K1Scalar::from_biguint(n.clone())) * p.to_projective()
+ CurveScalar(Secp256K1Scalar::from_biguint(m.clone())) * q.to_projective())
.to_affine();
let res_expected = builder.constant_affine_point(res);
builder.curve_assert_valid(&res_expected);
let p_target = builder.constant_affine_point(p);
let q_target = builder.constant_affine_point(q);
let n_target = builder.constant_biguint(&n);
let m_target = builder.constant_biguint(&m);
let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target);
builder.curve_assert_valid(&res_target);
builder.connect_affine_point(&res_target, &res_expected);
dbg!(builder.num_gates());
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -43,6 +43,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
_phantom: PhantomData,
});
// debug_assert!(k1_raw + S * k2_raw == k);
(k1, k2, k1_neg, k2_neg)
}
@ -60,24 +62,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
y: p.y.clone(),
};
// let part1 = self.curve_scalar_mul_windowed(p, &k1);
// let part1_neg = self.curve_conditional_neg(&part1, k1_neg);
// let part2 = self.curve_scalar_mul_windowed(&sp, &k2);
// let part2_neg = self.curve_conditional_neg(&part2, k2_neg);
//
// self.curve_add(&part1_neg, &part2_neg)
// dbg!(k1.value.limbs.len());
// dbg!(k2.value.limbs.len());
let p_neg = self.curve_conditional_neg(&p, k1_neg);
let sp_neg = self.curve_conditional_neg(&sp, k2_neg);
// let yo = self.curve_scalar_mul_windowed(&p_neg, &k1);
// let ya = self.curve_scalar_mul_windowed(&sp_neg, &k2);
// dbg!(&yo);
// dbg!(&ya);
// self.connect_affine_point(&part1_neg, &yo);
// self.connect_affine_point(&part2_neg, &ya);
self.curve_msm(&p_neg, &sp_neg, &k1, &k2)
// self.curve_add(&yo, &ya)
self.curve_msm(&p_neg, &sp_neg, &k1.value, &k2.value)
}
}
@ -118,7 +105,7 @@ mod tests {
use crate::curve::curve_types::{Curve, CurveScalar};
use crate::curve::glv::glv_mul;
use crate::curve::secp256k1::Secp256K1;
use crate::iop::witness::{PartialWitness, Witness};
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
@ -153,40 +140,4 @@ mod tests {
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_wtf() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_ecc_config();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let rando =
(CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine();
let randot = builder.constant_affine_point(rando);
let scalar = Secp256K1Scalar::rand();
let scalar_target = builder.constant_nonnative(scalar);
let tr = builder.add_virtual_bool_target();
pw.set_bool_target(tr, false);
let randotneg = builder.curve_conditional_neg(&randot, tr);
let y = builder.curve_scalar_mul_windowed(&randotneg, &scalar_target);
let yy = builder.curve_scalar_mul_windowed(&randot, &scalar_target);
let yy = builder.curve_conditional_neg(&yy, tr);
builder.connect_affine_point(&y, &yy);
dbg!(builder.num_gates());
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -454,7 +454,7 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: PrimeField> SimpleGenerat
let b_biguint = b.to_canonical_biguint();
let modulus = FF::order();
let (diff_biguint, overflow) = if a_biguint > b_biguint {
let (diff_biguint, overflow) = if a_biguint >= b_biguint {
(a_biguint - b_biguint, false)
} else {
(modulus + a_biguint - b_biguint, true)

View File

@ -35,12 +35,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.collect()
}
pub fn split_nonnative_to_2_bit_limbs<FF: Field>(
&mut self,
val: &NonNativeTarget<FF>,
) -> Vec<Target> {
val.value
.limbs
pub fn split_biguint_to_2_bit_limbs(&mut self, val: &BigUintTarget) -> Vec<Target> {
val.limbs
.iter()
.flat_map(|&l| self.split_le_base::<4>(l.0, 16))
.collect()

View File

@ -89,6 +89,68 @@ pub(crate) fn generate_partial_witness<
}
pending_generator_indices = next_pending_generator_indices;
// for t in [
// Target::VirtualTarget { index: 57934 },
// Target::VirtualTarget { index: 57935 },
// Target::VirtualTarget { index: 57936 },
// Target::VirtualTarget { index: 57937 },
// Target::VirtualTarget { index: 57938 },
// Target::VirtualTarget { index: 57939 },
// Target::VirtualTarget { index: 57940 },
// Target::VirtualTarget { index: 57941 },
// ] {
// if let Some(v) = witness.try_get_target(t) {
// println!("a {}", v);
// }
// }
// for t in [
// Target::VirtualTarget { index: 57952 },
// Target::VirtualTarget { index: 57953 },
// Target::VirtualTarget { index: 57954 },
// Target::VirtualTarget { index: 57955 },
// Target::VirtualTarget { index: 57956 },
// Target::VirtualTarget { index: 57957 },
// Target::VirtualTarget { index: 57958 },
// Target::VirtualTarget { index: 57959 },
// ] {
// if let Some(v) = witness.try_get_target(t) {
// println!("b {}", v);
// }
// }
//
// let t = Target::Wire(Wire {
// gate: 141_857,
// input: 8,
// });
// if let Some(v) = witness.try_get_target(t) {
// println!("prod_exp {}", v);
// }
// let t = Target::Wire(Wire {
// gate: 141_863,
// input: 22,
// });
// if let Some(v) = witness.try_get_target(t) {
// println!("prod act {}", v);
// }
// let t = Target::Wire(Wire { gate: 9, input: 3 });
// if let Some(v) = witness.try_get_target(t) {
// println!("modulus {}", v);
// }
// let t = Target::VirtualTarget { index: 57_976 };
// if let Some(v) = witness.try_get_target(t) {
// println!("overflow {}", v);
// }
// let t = Target::Wire(Wire {
// gate: 141_885,
// input: 8,
// });
// if let Some(v) = witness.try_get_target(t) {
// println!("mod time ov {}", v);
// }
// let t = Target::VirtualTarget { index: 57_968 };
// if let Some(v) = witness.try_get_target(t) {
// println!("prod {}", v);
// }
}
assert_eq!(