optimizations and cleanup

This commit is contained in:
Nicholas Ward 2022-01-20 15:44:03 -08:00
parent 2ddfb03aea
commit c392606a9a
8 changed files with 87 additions and 63 deletions

View File

@ -315,6 +315,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let x_ext = self.convert_to_ext(x);
self.inverse_extension(x_ext).0[0]
}
pub fn not(&mut self, b: BoolTarget) -> BoolTarget {
let one = self.one();
let res = self.sub(one, b.target);
BoolTarget::new_unsafe(res)
}
}
/// Represents a base arithmetic operation in the circuit. Used to memoize results.

View File

@ -100,35 +100,21 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
p1: &AffinePointTarget<C>,
p2: &AffinePointTarget<C>,
) -> AffinePointTarget<C> {
let before = self.num_gates();
let AffinePointTarget { x: x1, y: y1 } = p1;
let AffinePointTarget { x: x2, y: y2 } = p2;
let u = self.sub_nonnative(y2, y1);
let uu = self.mul_nonnative(&u, &u);
let v = self.sub_nonnative(x2, x1);
let vv = self.mul_nonnative(&v, &v);
let vvv = self.mul_nonnative(&v, &vv);
let r = self.mul_nonnative(&vv, x1);
let diff = self.sub_nonnative(&uu, &vvv);
let r2 = self.add_nonnative(&r, &r);
let a = self.sub_nonnative(&diff, &r2);
let x3 = self.mul_nonnative(&v, &a);
let v_inv = self.inv_nonnative(&v);
let s = self.mul_nonnative(&u, &v_inv);
let s_squared = self.mul_nonnative(&s, &s);
let x_sum = self.add_nonnative(x2, x1);
let x3 = self.sub_nonnative(&s_squared, &x_sum);
let x_diff = self.sub_nonnative(&x1, &x3);
let prod = self.mul_nonnative(&s, &x_diff);
let y3 = self.sub_nonnative(&prod, &y1);
let r_a = self.sub_nonnative(&r, &a);
let y3_first = self.mul_nonnative(&u, &r_a);
let y3_second = self.mul_nonnative(&vvv, y1);
let y3 = self.sub_nonnative(&y3_first, &y3_second);
let z3_inv = self.inv_nonnative(&vvv);
let x3_norm = self.mul_nonnative(&x3, &z3_inv);
let y3_norm = self.mul_nonnative(&y3, &z3_inv);
println!("NUM GATES: {}", self.num_gates() - before);
AffinePointTarget {
x: x3_norm,
y: y3_norm,
}
AffinePointTarget { x: x3, y: y3 }
}
pub fn curve_scalar_mul<C: Curve>(
@ -136,11 +122,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
p: &AffinePointTarget<C>,
n: &NonNativeTarget<C::ScalarField>,
) -> AffinePointTarget<C> {
let one = self.constant_nonnative(C::BaseField::ONE);
let bits = self.split_nonnative_to_bits(n);
let bits_as_base: Vec<NonNativeTarget<C::BaseField>> =
bits.iter().map(|b| self.bool_to_nonnative(b)).collect();
let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine();
let randot = self.constant_affine_point(rando);
@ -151,15 +133,15 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut two_i_times_p = self.add_virtual_affine_point_target();
self.connect_affine_point(p, &two_i_times_p);
for bit in bits_as_base.iter() {
let not_bit = self.sub_nonnative(&one, bit);
for &bit in bits.iter() {
let not_bit = self.not(bit);
let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p);
let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x);
let new_x_if_not_bit = self.mul_nonnative(&not_bit, &result.x);
let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y);
let new_y_if_not_bit = self.mul_nonnative(&not_bit, &result.y);
let new_x_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.x, bit);
let new_x_if_not_bit = self.mul_nonnative_by_bool(&result.x, not_bit);
let new_y_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.y, bit);
let new_y_if_not_bit = self.mul_nonnative_by_bool(&result.y, not_bit);
let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit);
let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit);
@ -179,6 +161,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
#[cfg(test)]
mod tests {
use std::ops::Neg;
use anyhow::Result;
use plonky2_field::field_types::Field;
use plonky2_field::secp256k1_base::Secp256K1Base;
@ -198,7 +182,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -223,7 +207,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -250,7 +234,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -287,7 +271,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -317,27 +301,25 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig {
num_routed_wires: 33,
..CircuitConfig::standard_recursion_config()
};
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let five = Secp256K1Scalar::from_canonical_usize(5);
let five_scalar = CurveScalar::<Secp256K1>(five);
let five_g = (five_scalar * g.to_projective()).to_affine();
let five_g_expected = builder.constant_affine_point(five_g);
builder.curve_assert_valid(&five_g_expected);
let neg_five = five.neg();
let neg_five_scalar = CurveScalar::<Secp256K1>(neg_five);
let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine();
let neg_five_g_expected = builder.constant_affine_point(neg_five_g);
builder.curve_assert_valid(&neg_five_g_expected);
let g_target = builder.constant_affine_point(g);
let five_target = builder.constant_nonnative(five);
let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target);
builder.curve_assert_valid(&five_g_actual);
let neg_five_target = builder.constant_nonnative(neg_five);
let neg_five_g_actual = builder.curve_scalar_mul(&g_target, &neg_five_target);
builder.curve_assert_valid(&neg_five_g_actual);
builder.connect_affine_point(&five_g_expected, &five_g_actual);
builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
@ -351,10 +333,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig {
num_routed_wires: 33,
..CircuitConfig::standard_recursion_config()
};
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);

View File

@ -96,7 +96,7 @@ mod tests {
const D: usize = 4;
type C = Secp256K1;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -123,7 +123,6 @@ mod tests {
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -125,6 +125,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
_phantom: PhantomData,
});
self.range_check_u32(sum.value.limbs.clone());
self.range_check_u32(vec![overflow]);
let sum_expected = summands
.iter()
.fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value));
@ -157,6 +160,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
_phantom: PhantomData,
});
self.range_check_u32(diff.value.limbs.clone());
self.assert_bool(overflow);
let diff_plus_b = self.add_biguint(&diff.value, &b.value);
let modulus = self.constant_biguint(&FF::order());
let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow);
@ -185,6 +191,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
_phantom: PhantomData,
});
self.range_check_u32(prod.value.limbs.clone());
self.range_check_u32(overflow.limbs.clone());
let prod_expected = self.mul_biguint(&a.value, &b.value);
let mod_times_overflow = self.mul_biguint(&modulus, &overflow);
@ -202,12 +211,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
return to_mul[0].clone();
}
let mut result = self.mul_biguint(&to_mul[0].value, &to_mul[1].value);
let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]);
for i in 2..to_mul.len() {
result = self.mul_biguint(&result, &to_mul[i].value);
accumulator = self.mul_nonnative(&accumulator, &to_mul[i]);
}
self.reduce(&result)
accumulator
}
pub fn neg_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {

View File

@ -41,6 +41,25 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
(low, high)
}
pub fn range_check_u32(&mut self, vals: Vec<U32Target>) {
let num_input_limbs = vals.len();
let gate = U32RangeCheckGate::<F, D>::new(num_input_limbs);
let gate_index = self.add_gate(gate, vec![]);
for i in 0..num_input_limbs {
self.connect(
Target::wire(gate_index, gate.wire_ith_input_limb(i)),
vals[i].0,
);
}
}
pub fn assert_bool(&mut self, b: BoolTarget) {
let z = self.mul_sub(b.target, b.target, b.target);
let zero = self.zero();
self.connect(z, zero);
}
}
#[derive(Debug)]

View File

@ -27,7 +27,7 @@ impl<F: RichField + Extendable<D>, const D: usize> U32RangeCheckGate<F, D> {
}
}
pub const AUX_LIMB_BITS: usize = 3;
pub const AUX_LIMB_BITS: usize = 2;
pub const BASE: usize = 1 << Self::AUX_LIMB_BITS;
fn aux_limbs_per_input_limb(&self) -> usize {
@ -243,7 +243,7 @@ mod tests {
type F = GoldilocksField;
type FF = QuarticExtension<GoldilocksField>;
const D: usize = 4;
const AUX_LIMB_BITS: usize = 3;
const AUX_LIMB_BITS: usize = 2;
const BASE: usize = 1 << AUX_LIMB_BITS;
const AUX_LIMBS_PER_INPUT_LIMB: usize = ceil_div_usize(32, AUX_LIMB_BITS);

View File

@ -205,6 +205,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
BoolTarget::new_unsafe(self.add_virtual_target())
}
pub fn add_virtual_bool_target_safe(&mut self) -> BoolTarget {
let b = BoolTarget::new_unsafe(self.add_virtual_target());
self.assert_bool(b);
b
}
/// Adds a gate to the circuit, and returns its index.
pub fn add_gate<G: Gate<F, D>>(&mut self, gate_type: G, constants: Vec<F>) -> usize {
self.check_gate_compatibility(&gate_type);
@ -235,7 +241,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn check_gate_compatibility<G: Gate<F, D>>(&self, gate: &G) {
assert!(
gate.num_wires() <= self.config.num_wires,
"{:?} requires {} wires, but our GateConfig has only {}",
"{:?} requires {} wires, but our CircuitConfig has only {}",
gate.id(),
gate.num_wires(),
self.config.num_wires

View File

@ -49,7 +49,7 @@ pub struct CircuitConfig {
impl Default for CircuitConfig {
fn default() -> Self {
CircuitConfig::standard_recursion_config()
Self::standard_recursion_config()
}
}
@ -79,6 +79,13 @@ impl CircuitConfig {
}
}
pub fn standard_ecc_config() -> Self {
Self {
num_wires: 136,
..Self::standard_recursion_config()
}
}
pub fn standard_recursion_zk_config() -> Self {
CircuitConfig {
zero_knowledge: true,