Merge pull request #377 from mir-protocol/ecdsa

ECDSA
Nicholas Ward 2022-01-31 12:23:18 -08:00 committed by GitHub
commit bcbc987a8e
31 changed files with 1645 additions and 141 deletions

View File

@ -404,7 +404,7 @@ mod tests {
v.extend(equality_dummy_vals);
v.extend(insert_here_vals);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
v.iter().map(|&x| x.into()).collect()
}
let orig_vec = vec![FF::rand(); 3];

View File

@ -259,3 +259,11 @@ impl<C: Curve> Neg for ProjectivePoint<C> {
ProjectivePoint { x, y: -y, z }
}
}
pub fn base_to_scalar<C: Curve>(x: C::BaseField) -> C::ScalarField {
C::ScalarField::from_biguint(x.to_biguint())
}
pub fn scalar_to_base<C: Curve>(x: C::ScalarField) -> C::BaseField {
C::BaseField::from_biguint(x.to_biguint())
}
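// Sketch of what these conversions do: the element is sent through its canonical integer
// representative and, assuming `from_biguint` reduces modulo the target field's order,
// reinterpreted in the other field. For secp256k1 the scalar order n is slightly smaller
// than the base-field order p, so `base_to_scalar` may wrap -- which is exactly the usual
// "x mod n" step of ECDSA. Toy analogue with word-sized (hypothetical) moduli:
//     fn base_to_scalar_toy(x: u64, scalar_order: u64) -> u64 { x % scalar_order }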

View File

@ -0,0 +1,72 @@
use crate::curve::curve_msm::msm_parallel;
use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar};
use crate::field::field_types::Field;
pub struct ECDSASignature<C: Curve> {
pub r: C::ScalarField,
pub s: C::ScalarField,
}
pub struct ECDSASecretKey<C: Curve>(pub C::ScalarField);
pub struct ECDSAPublicKey<C: Curve>(pub AffinePoint<C>);
pub fn sign_message<C: Curve>(msg: C::ScalarField, sk: ECDSASecretKey<C>) -> ECDSASignature<C> {
let (k, rr) = {
let mut k = C::ScalarField::rand();
let mut rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine();
while rr.x == C::BaseField::ZERO {
k = C::ScalarField::rand();
rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine();
}
(k, rr)
};
let r = base_to_scalar::<C>(rr.x);
let s = k.inverse() * (msg + r * sk.0);
ECDSASignature { r, s }
}
pub fn verify_message<C: Curve>(
msg: C::ScalarField,
sig: ECDSASignature<C>,
pk: ECDSAPublicKey<C>,
) -> bool {
let ECDSASignature { r, s } = sig;
assert!(pk.0.is_valid());
let c = s.inverse();
let u1 = msg * c;
let u2 = r * c;
let g = C::GENERATOR_PROJECTIVE;
let w = 5; // Experimentally fastest
let point_proj = msm_parallel(&[u1, u2], &[g, pk.0.to_projective()], w);
let point = point_proj.to_affine();
let x = base_to_scalar::<C>(point.x);
r == x
}
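// Why this check suffices (the standard ECDSA argument, sketched): the signer sets
//     r = base_to_scalar((k * G).x)   and   s = k^-1 * (msg + r * sk),
// so with c = s^-1, u1 = msg * c and u2 = r * c,
//     u1 * G + u2 * pk = (msg * c) * G + (r * c * sk) * G = c * (msg + r * sk) * G = k * G,
// and taking the x-coordinate of that point (as a scalar) recovers r.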
#[cfg(test)]
mod tests {
use crate::curve::curve_types::{Curve, CurveScalar};
use crate::curve::ecdsa::{sign_message, verify_message, ECDSAPublicKey, ECDSASecretKey};
use crate::curve::secp256k1::Secp256K1;
use crate::field::field_types::Field;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
#[test]
fn test_ecdsa_native() {
type C = Secp256K1;
let msg = Secp256K1Scalar::rand();
let sk = ECDSASecretKey(Secp256K1Scalar::rand());
let pk = ECDSAPublicKey((CurveScalar(sk.0) * C::GENERATOR_PROJECTIVE).to_affine());
let sig = sign_message(msg, sk);
let result = verify_message(msg, sig, pk);
assert!(result);
}
}

View File

@ -3,4 +3,5 @@ pub mod curve_msm;
pub mod curve_multiplication;
pub mod curve_summation;
pub mod curve_types;
pub mod ecdsa;
pub mod secp256k1;

View File

@ -315,6 +315,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let x_ext = self.convert_to_ext(x);
self.inverse_extension(x_ext).0[0]
}
pub fn not(&mut self, b: BoolTarget) -> BoolTarget {
let one = self.one();
let res = self.sub(one, b.target);
BoolTarget::new_unsafe(res)
}
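// Sketch of the identity used by `not`: if b is already constrained to {0, 1}, then
// 1 - b is again in {0, 1} and flips the bit, so no extra constraint is needed here
// beyond those attached to b (hence `new_unsafe`).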
}
/// Represents a base arithmetic operation in the circuit. Used to memoize results.

View File

@ -1,9 +1,14 @@
use std::marker::PhantomData;
use plonky2_field::extension_field::Extendable;
use crate::gates::add_many_u32::U32AddManyGate;
use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::hash::hash_types::RichField;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
#[derive(Clone, Copy, Debug)]
@ -113,18 +118,57 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
1 => (to_add[0], self.zero_u32()),
2 => self.add_u32(to_add[0], to_add[1]),
_ => {
let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]);
for i in 2..to_add.len() {
let (new_low, new_carry) = self.add_u32(to_add[i], low);
let (combined_carry, _zero) = self.add_u32(carry, new_carry);
low = new_low;
carry = combined_carry;
let num_addends = to_add.len();
let gate = U32AddManyGate::<F, D>::new_from_config(&self.config, num_addends);
let (gate_index, copy) = self.find_u32_add_many_gate(num_addends);
for j in 0..num_addends {
self.connect(
Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)),
to_add[j].0,
);
}
(low, carry)
let zero = self.zero();
self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), zero);
let output_low =
U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy)));
let output_high =
U32Target(Target::wire(gate_index, gate.wire_ith_output_carry(copy)));
(output_low, output_high)
}
}
}
pub fn add_u32s_with_carry(
&mut self,
to_add: &[U32Target],
carry: U32Target,
) -> (U32Target, U32Target) {
if to_add.len() == 1 {
return self.add_u32(to_add[0], carry);
}
let num_addends = to_add.len();
let gate = U32AddManyGate::<F, D>::new_from_config(&self.config, num_addends);
let (gate_index, copy) = self.find_u32_add_many_gate(num_addends);
for j in 0..num_addends {
self.connect(
Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)),
to_add[j].0,
);
}
self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), carry.0);
let output = U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy)));
let output_carry = U32Target(Target::wire(gate_index, gate.wire_ith_output_carry(copy)));
(output, output_carry)
}
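// Sketch of the relation both helpers above rely on: the sum of all addends plus the
// input carry equals output_carry * 2^32 + output, with output in [0, 2^32) and
// output_carry small (at most the number of addends, which is capped at 16). For example,
// 0xFFFF_FFFF + 0xFFFF_FFFF + carry 1 = 0x1_FFFF_FFFF, i.e. output = 0xFFFF_FFFF and
// output_carry = 1.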
pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
let zero = self.zero_u32();
self.mul_add_u32(a, b, zero)
@ -153,3 +197,75 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
(output_result, output_borrow)
}
}
#[derive(Debug)]
struct SplitToU32Generator<F: RichField + Extendable<D>, const D: usize> {
x: Target,
low: U32Target,
high: U32Target,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for SplitToU32Generator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
vec![self.x]
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let x = witness.get_target(self.x);
let x_u64 = x.to_canonical_u64();
let low = x_u64 as u32;
let high = (x_u64 >> 32) as u32;
out_buffer.set_u32_target(self.low, low);
out_buffer.set_u32_target(self.high, high);
}
}
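// Out-of-circuit sketch of the split this generator performs: a canonical 64-bit value
// x is written as high * 2^32 + low with both halves in [0, 2^32):
//     let low = (x & 0xFFFF_FFFF) as u32;   // x mod 2^32
//     let high = (x >> 32) as u32;          // floor(x / 2^32)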
#[cfg(test)]
mod tests {
use anyhow::Result;
use rand::{thread_rng, Rng};
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::plonk::verifier::verify;
#[test]
pub fn test_add_many_u32s() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
const NUM_ADDENDS: usize = 15;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut rng = thread_rng();
let mut to_add = Vec::new();
let mut sum = 0u64;
for _ in 0..NUM_ADDENDS {
let x: u32 = rng.gen();
sum += x as u64;
to_add.push(builder.constant_u32(x));
}
let carry = builder.zero_u32();
let (result_low, result_high) = builder.add_u32s_with_carry(&to_add, carry);
let expected_low = builder.constant_u32((sum % (1 << 32)) as u32);
let expected_high = builder.constant_u32((sum >> 32) as u32);
builder.connect_u32(result_low, expected_low);
builder.connect_u32(result_high, expected_high);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData;
use num::{BigUint, Integer};
use num::{BigUint, Integer, Zero};
use plonky2_field::extension_field::Extendable;
use crate::gadgets::arithmetic_u32::U32Target;
@ -33,6 +33,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
BigUintTarget { limbs }
}
pub fn zero_biguint(&mut self) -> BigUintTarget {
self.constant_biguint(&BigUint::zero())
}
pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) {
let min_limbs = lhs.num_limbs().min(rhs.num_limbs());
for i in 0..min_limbs {
@ -76,9 +80,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget {
let limbs = (0..num_limbs)
.map(|_| self.add_virtual_u32_target())
.collect();
let limbs = self.add_virtual_u32_targets(num_limbs);
BigUintTarget { limbs }
}
@ -143,8 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for summands in &mut to_add {
summands.push(carry);
let (new_result, new_carry) = self.add_many_u32(summands);
let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry);
combined_limbs.push(new_result);
carry = new_carry;
}
@ -155,6 +156,18 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
pub fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget {
let t = b.target;
BigUintTarget {
limbs: a
.limbs
.iter()
.map(|&l| U32Target(self.mul(l.0, t)))
.collect(),
}
}
// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function).
pub fn mul_add_biguint(
&mut self,

View File

@ -104,29 +104,17 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let AffinePointTarget { x: x2, y: y2 } = p2;
let u = self.sub_nonnative(y2, y1);
let uu = self.mul_nonnative(&u, &u);
let v = self.sub_nonnative(x2, x1);
let vv = self.mul_nonnative(&v, &v);
let vvv = self.mul_nonnative(&v, &vv);
let r = self.mul_nonnative(&vv, x1);
let diff = self.sub_nonnative(&uu, &vvv);
let r2 = self.add_nonnative(&r, &r);
let a = self.sub_nonnative(&diff, &r2);
let x3 = self.mul_nonnative(&v, &a);
let v_inv = self.inv_nonnative(&v);
let s = self.mul_nonnative(&u, &v_inv);
let s_squared = self.mul_nonnative(&s, &s);
let x_sum = self.add_nonnative(x2, x1);
let x3 = self.sub_nonnative(&s_squared, &x_sum);
let x_diff = self.sub_nonnative(x1, &x3);
let prod = self.mul_nonnative(&s, &x_diff);
let y3 = self.sub_nonnative(&prod, y1);
let r_a = self.sub_nonnative(&r, &a);
let y3_first = self.mul_nonnative(&u, &r_a);
let y3_second = self.mul_nonnative(&vvv, y1);
let y3 = self.sub_nonnative(&y3_first, &y3_second);
let z3_inv = self.inv_nonnative(&vvv);
let x3_norm = self.mul_nonnative(&x3, &z3_inv);
let y3_norm = self.mul_nonnative(&y3, &z3_inv);
AffinePointTarget {
x: x3_norm,
y: y3_norm,
}
AffinePointTarget { x: x3, y: y3 }
}
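// The rewritten body above uses the affine chord formulas directly (valid when x1 != x2):
//     s  = (y2 - y1) / (x2 - x1)
//     x3 = s^2 - x1 - x2
//     y3 = s * (x1 - x3) - y1
// Both the old projective version and this one cost a single nonnative inversion, but the
// affine form needs only three nonnative multiplications instead of nine.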
pub fn curve_scalar_mul<C: Curve>(
@ -134,11 +122,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
p: &AffinePointTarget<C>,
n: &NonNativeTarget<C::ScalarField>,
) -> AffinePointTarget<C> {
let one = self.constant_nonnative(C::BaseField::ONE);
let bits = self.split_nonnative_to_bits(n);
let bits_as_base: Vec<NonNativeTarget<C::BaseField>> =
bits.iter().map(|b| self.bool_to_nonnative(b)).collect();
let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine();
let randot = self.constant_affine_point(rando);
@ -149,15 +133,15 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut two_i_times_p = self.add_virtual_affine_point_target();
self.connect_affine_point(p, &two_i_times_p);
for bit in bits_as_base.iter() {
let not_bit = self.sub_nonnative(&one, bit);
for &bit in bits.iter() {
let not_bit = self.not(bit);
let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p);
let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x);
let new_x_if_not_bit = self.mul_nonnative(&not_bit, &result.x);
let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y);
let new_y_if_not_bit = self.mul_nonnative(&not_bit, &result.y);
let new_x_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.x, bit);
let new_x_if_not_bit = self.mul_nonnative_by_bool(&result.x, not_bit);
let new_y_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.y, bit);
let new_y_if_not_bit = self.mul_nonnative_by_bool(&result.y, not_bit);
let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit);
let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit);
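// The selection above is the usual arithmetic mux: with bit in {0, 1},
//     new = bit * a + (1 - bit) * b
// picks a when the scalar bit is set and b otherwise. Using mul_nonnative_by_bool keeps
// each branch to one native field multiplication per limb instead of a full nonnative
// product.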
@ -177,6 +161,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
#[cfg(test)]
mod tests {
use std::ops::Neg;
use anyhow::Result;
use plonky2_field::field_types::Field;
use plonky2_field::secp256k1_base::Secp256K1Base;
@ -196,7 +182,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -221,7 +207,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -248,7 +234,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -285,7 +271,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -310,33 +296,30 @@ mod tests {
}
#[test]
#[ignore]
fn test_curve_mul() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig {
num_routed_wires: 33,
..CircuitConfig::standard_recursion_config()
};
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let five = Secp256K1Scalar::from_canonical_usize(5);
let five_scalar = CurveScalar::<Secp256K1>(five);
let five_g = (five_scalar * g.to_projective()).to_affine();
let five_g_expected = builder.constant_affine_point(five_g);
builder.curve_assert_valid(&five_g_expected);
let neg_five = five.neg();
let neg_five_scalar = CurveScalar::<Secp256K1>(neg_five);
let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine();
let neg_five_g_expected = builder.constant_affine_point(neg_five_g);
builder.curve_assert_valid(&neg_five_g_expected);
let g_target = builder.constant_affine_point(g);
let five_target = builder.constant_nonnative(five);
let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target);
builder.curve_assert_valid(&five_g_actual);
let neg_five_target = builder.constant_nonnative(neg_five);
let neg_five_g_actual = builder.curve_scalar_mul(&g_target, &neg_five_target);
builder.curve_assert_valid(&neg_five_g_actual);
builder.connect_affine_point(&five_g_expected, &five_g_actual);
builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
@ -345,16 +328,12 @@ mod tests {
}
#[test]
#[ignore]
fn test_curve_random() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig {
num_routed_wires: 33,
..CircuitConfig::standard_recursion_config()
};
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);

View File

@ -0,0 +1,100 @@
use std::marker::PhantomData;
use crate::curve::curve_types::Curve;
use crate::field::extension_field::Extendable;
use crate::gadgets::curve::AffinePointTarget;
use crate::gadgets::nonnative::NonNativeTarget;
use crate::hash::hash_types::RichField;
use crate::plonk::circuit_builder::CircuitBuilder;
pub struct ECDSASecretKeyTarget<C: Curve>(NonNativeTarget<C::ScalarField>);
pub struct ECDSAPublicKeyTarget<C: Curve>(AffinePointTarget<C>);
pub struct ECDSASignatureTarget<C: Curve> {
pub r: NonNativeTarget<C::ScalarField>,
pub s: NonNativeTarget<C::ScalarField>,
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn verify_message<C: Curve>(
&mut self,
msg: NonNativeTarget<C::ScalarField>,
sig: ECDSASignatureTarget<C>,
pk: ECDSAPublicKeyTarget<C>,
) {
let ECDSASignatureTarget { r, s } = sig;
self.curve_assert_valid(&pk.0);
let c = self.inv_nonnative(&s);
let u1 = self.mul_nonnative(&msg, &c);
let u2 = self.mul_nonnative(&r, &c);
let g = self.constant_affine_point(C::GENERATOR_AFFINE);
let point1 = self.curve_scalar_mul(&g, &u1);
let point2 = self.curve_scalar_mul(&pk.0, &u2);
let point = self.curve_add(&point1, &point2);
let x = NonNativeTarget::<C::ScalarField> {
value: point.x.value,
_phantom: PhantomData,
};
self.connect_nonnative(&r, &x);
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::curve::curve_types::{Curve, CurveScalar};
use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature};
use crate::curve::secp256k1::Secp256K1;
use crate::field::field_types::Field;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
use crate::gadgets::ecdsa::{ECDSAPublicKeyTarget, ECDSASignatureTarget};
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::plonk::verifier::verify;
#[test]
#[ignore]
fn test_ecdsa_circuit() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type Curve = Secp256K1;
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let msg = Secp256K1Scalar::rand();
let msg_target = builder.constant_nonnative(msg);
let sk = ECDSASecretKey::<Curve>(Secp256K1Scalar::rand());
let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine());
let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0));
let sig = sign_message(msg, sk);
let ECDSASignature { r, s } = sig;
let r_target = builder.constant_nonnative(r);
let s_target = builder.constant_nonnative(s);
let sig_target = ECDSASignatureTarget {
r: r_target,
s: s_target,
};
builder.verify_message(msg_target, sig_target, pk_target);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -3,6 +3,7 @@ pub mod arithmetic_extension;
pub mod arithmetic_u32;
pub mod biguint;
pub mod curve;
pub mod ecdsa;
pub mod hash;
pub mod interpolation;
pub mod multiple_comparison;

View File

@ -60,8 +60,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Helper function for comparing, specifically, lists of `U32Target`s.
pub fn list_le_u32(&mut self, a: Vec<U32Target>, b: Vec<U32Target>) -> BoolTarget {
let a_targets = a.iter().map(|&t| t.0).collect();
let b_targets = b.iter().map(|&t| t.0).collect();
let a_targets: Vec<Target> = a.iter().map(|&t| t.0).collect();
let b_targets: Vec<Target> = b.iter().map(|&t| t.0).collect();
self.list_le(a_targets, b_targets, 32)
}
}

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData;
use num::{BigUint, Zero};
use num::{BigUint, Integer, One, Zero};
use plonky2_field::{extension_field::Extendable, field_types::Field};
use plonky2_util::ceil_div_usize;
@ -15,7 +15,7 @@ use crate::plonk::circuit_builder::CircuitBuilder;
#[derive(Clone, Debug)]
pub struct NonNativeTarget<FF: Field> {
pub(crate) value: BigUintTarget,
_phantom: PhantomData<FF>,
pub(crate) _phantom: PhantomData<FF>,
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
@ -39,6 +39,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.biguint_to_nonnative(&x_biguint)
}
pub fn zero_nonnative<FF: Field>(&mut self) -> NonNativeTarget<FF> {
self.constant_nonnative(FF::ZERO)
}
// Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal.
pub fn connect_nonnative<FF: Field>(
&mut self,
@ -58,16 +62,90 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
// Add two `NonNativeTarget`s.
pub fn add_nonnative<FF: Field>(
&mut self,
a: &NonNativeTarget<FF>,
b: &NonNativeTarget<FF>,
) -> NonNativeTarget<FF> {
let result = self.add_biguint(&a.value, &b.value);
let sum = self.add_virtual_nonnative_target::<FF>();
let overflow = self.add_virtual_bool_target();
// TODO: reduce add result with only one conditional subtraction
self.reduce(&result)
self.add_simple_generator(NonNativeAdditionGenerator::<F, D, FF> {
a: a.clone(),
b: b.clone(),
sum: sum.clone(),
overflow,
_phantom: PhantomData,
});
let sum_expected = self.add_biguint(&a.value, &b.value);
let modulus = self.constant_biguint(&FF::order());
let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow);
let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow);
self.connect_biguint(&sum_expected, &sum_actual);
// Range-check result.
// TODO: can potentially leave unreduced until necessary (e.g. when connecting values).
let cmp = self.cmp_biguint(&sum.value, &modulus);
let one = self.one();
self.connect(cmp.target, one);
sum
}
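// Sketch of the witness-then-check pattern introduced here: the generator supplies the
// reduced sum and an overflow bit out of circuit, and the circuit enforces (as biguints)
//     a + b = sum + overflow * p    and    sum < p,
// so for reduced inputs and a boolean overflow, sum is pinned to (a + b) mod p.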
pub fn mul_nonnative_by_bool<FF: Field>(
&mut self,
a: &NonNativeTarget<FF>,
b: BoolTarget,
) -> NonNativeTarget<FF> {
NonNativeTarget {
value: self.mul_biguint_by_bool(&a.value, b),
_phantom: PhantomData,
}
}
pub fn add_many_nonnative<FF: Field>(
&mut self,
to_add: &[NonNativeTarget<FF>],
) -> NonNativeTarget<FF> {
if to_add.len() == 1 {
return to_add[0].clone();
}
let sum = self.add_virtual_nonnative_target::<FF>();
let overflow = self.add_virtual_u32_target();
let summands = to_add.to_vec();
self.add_simple_generator(NonNativeMultipleAddsGenerator::<F, D, FF> {
summands: summands.clone(),
sum: sum.clone(),
overflow,
_phantom: PhantomData,
});
self.range_check_u32(sum.value.limbs.clone());
self.range_check_u32(vec![overflow]);
let sum_expected = summands
.iter()
.fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value));
let modulus = self.constant_biguint(&FF::order());
let overflow_biguint = BigUintTarget {
limbs: vec![overflow],
};
let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint);
let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow);
self.connect_biguint(&sum_expected, &sum_actual);
// Range-check result.
// TODO: can potentially leave unreduced until necessary (e.g. when connecting values).
let cmp = self.cmp_biguint(&sum.value, &modulus);
let one = self.one();
self.connect(cmp.target, one);
sum
}
// Subtract two `NonNativeTarget`s.
@ -76,12 +154,27 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a: &NonNativeTarget<FF>,
b: &NonNativeTarget<FF>,
) -> NonNativeTarget<FF> {
let order = self.constant_biguint(&FF::order());
let a_plus_order = self.add_biguint(&order, &a.value);
let result = self.sub_biguint(&a_plus_order, &b.value);
let diff = self.add_virtual_nonnative_target::<FF>();
let overflow = self.add_virtual_bool_target();
// TODO: reduce sub result with only one conditional addition?
self.reduce(&result)
self.add_simple_generator(NonNativeSubtractionGenerator::<F, D, FF> {
a: a.clone(),
b: b.clone(),
diff: diff.clone(),
overflow,
_phantom: PhantomData,
});
self.range_check_u32(diff.value.limbs.clone());
self.assert_bool(overflow);
let diff_plus_b = self.add_biguint(&diff.value, &b.value);
let modulus = self.constant_biguint(&FF::order());
let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow);
let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow);
self.connect_biguint(&a.value, &diff_plus_b_reduced);
diff
}
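// Subtraction follows the same pattern: the prover supplies diff and an overflow bit,
// and the circuit enforces  a = diff + b - overflow * p,  i.e. diff represents a - b
// mod p, with the borrow of p applied exactly when a < b.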
pub fn mul_nonnative<FF: Field>(
@ -89,9 +182,45 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a: &NonNativeTarget<FF>,
b: &NonNativeTarget<FF>,
) -> NonNativeTarget<FF> {
let result = self.mul_biguint(&a.value, &b.value);
let prod = self.add_virtual_nonnative_target::<FF>();
let modulus = self.constant_biguint(&FF::order());
let overflow = self.add_virtual_biguint_target(
a.value.num_limbs() + b.value.num_limbs() - modulus.num_limbs(),
);
self.reduce(&result)
self.add_simple_generator(NonNativeMultiplicationGenerator::<F, D, FF> {
a: a.clone(),
b: b.clone(),
prod: prod.clone(),
overflow: overflow.clone(),
_phantom: PhantomData,
});
self.range_check_u32(prod.value.limbs.clone());
self.range_check_u32(overflow.limbs.clone());
let prod_expected = self.mul_biguint(&a.value, &b.value);
let mod_times_overflow = self.mul_biguint(&modulus, &overflow);
let prod_actual = self.add_biguint(&prod.value, &mod_times_overflow);
self.connect_biguint(&prod_expected, &prod_actual);
prod
}
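// Multiplication, same shape: the circuit enforces  a * b = prod + overflow * p,  where
// prod has 32-bit range-checked limbs and overflow is a multi-limb quotient (it can be
// nearly as large as p, hence a BigUintTarget rather than a single bit).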
pub fn mul_many_nonnative<FF: Field>(
&mut self,
to_mul: &[NonNativeTarget<FF>],
) -> NonNativeTarget<FF> {
if to_mul.len() == 1 {
return to_mul[0].clone();
}
let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]);
for i in 2..to_mul.len() {
accumulator = self.mul_nonnative(&accumulator, &to_mul[i]);
}
accumulator
}
pub fn neg_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
@ -104,36 +233,27 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn inv_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
let num_limbs = x.value.num_limbs();
let inv_biguint = self.add_virtual_biguint_target(num_limbs);
let inv = NonNativeTarget::<FF> {
value: inv_biguint,
_phantom: PhantomData,
};
let div = self.add_virtual_biguint_target(num_limbs);
self.add_simple_generator(NonNativeInverseGenerator::<F, D, FF> {
x: x.clone(),
inv: inv.clone(),
inv: inv_biguint.clone(),
div: div.clone(),
_phantom: PhantomData,
});
let product = self.mul_nonnative(x, &inv);
let one = self.constant_nonnative(FF::ONE);
self.connect_nonnative(&product, &one);
let product = self.mul_biguint(&x.value, &inv_biguint);
inv
}
let modulus = self.constant_biguint(&FF::order());
let mod_times_div = self.mul_biguint(&modulus, &div);
let one = self.constant_biguint(&BigUint::one());
let expected_product = self.add_biguint(&mod_times_div, &one);
self.connect_biguint(&product, &expected_product);
pub fn div_rem_nonnative<FF: Field>(
&mut self,
x: &NonNativeTarget<FF>,
y: &NonNativeTarget<FF>,
) -> (NonNativeTarget<FF>, NonNativeTarget<FF>) {
let x_biguint = self.nonnative_to_biguint(x);
let y_biguint = self.nonnative_to_biguint(y);
let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint);
let div = self.biguint_to_nonnative(&div_biguint);
let rem = self.biguint_to_nonnative(&rem_biguint);
(div, rem)
NonNativeTarget::<FF> {
value: inv_biguint,
_phantom: PhantomData,
}
}
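// Inversion sketch: rather than constraining mul_nonnative(x, inv) == 1 as before (which
// itself spawns a multiplication generator), the circuit now witnesses both inv and the
// quotient div and checks the integer identity
//     x * inv = div * p + 1,
// i.e. x * inv = 1 (mod p).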
/// Returns `x % |FF|` as a `NonNativeTarget`.
@ -148,8 +268,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
#[allow(dead_code)]
fn reduce_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
pub fn reduce_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
let x_biguint = self.nonnative_to_biguint(x);
self.reduce(&x_biguint)
}
@ -187,10 +306,174 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
#[derive(Debug)]
struct NonNativeAdditionGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
a: NonNativeTarget<FF>,
b: NonNativeTarget<FF>,
sum: NonNativeTarget<FF>,
overflow: BoolTarget,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
for NonNativeAdditionGenerator<F, D, FF>
{
fn dependencies(&self) -> Vec<Target> {
self.a
.value
.limbs
.iter()
.cloned()
.chain(self.b.value.limbs.clone())
.map(|l| l.0)
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness.get_nonnative_target(self.a.clone());
let b = witness.get_nonnative_target(self.b.clone());
let a_biguint = a.to_biguint();
let b_biguint = b.to_biguint();
let sum_biguint = a_biguint + b_biguint;
let modulus = FF::order();
let (overflow, sum_reduced) = if sum_biguint > modulus {
(true, sum_biguint - modulus)
} else {
(false, sum_biguint)
};
out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced);
out_buffer.set_bool_target(self.overflow, overflow);
}
}
#[derive(Debug)]
struct NonNativeMultipleAddsGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
summands: Vec<NonNativeTarget<FF>>,
sum: NonNativeTarget<FF>,
overflow: U32Target,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
for NonNativeMultipleAddsGenerator<F, D, FF>
{
fn dependencies(&self) -> Vec<Target> {
self.summands
.iter()
.flat_map(|summand| summand.value.limbs.iter().map(|limb| limb.0))
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let summands: Vec<_> = self
.summands
.iter()
.map(|summand| witness.get_nonnative_target(summand.clone()))
.collect();
let summand_biguints: Vec<_> = summands
.iter()
.map(|summand| summand.to_biguint())
.collect();
let sum_biguint = summand_biguints
.iter()
.fold(BigUint::zero(), |a, b| a + b.clone());
let modulus = FF::order();
let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus);
let overflow = overflow_biguint.to_u64_digits()[0] as u32;
out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced);
out_buffer.set_u32_target(self.overflow, overflow);
}
}
#[derive(Debug)]
struct NonNativeSubtractionGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
a: NonNativeTarget<FF>,
b: NonNativeTarget<FF>,
diff: NonNativeTarget<FF>,
overflow: BoolTarget,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
for NonNativeSubtractionGenerator<F, D, FF>
{
fn dependencies(&self) -> Vec<Target> {
self.a
.value
.limbs
.iter()
.cloned()
.chain(self.b.value.limbs.clone())
.map(|l| l.0)
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness.get_nonnative_target(self.a.clone());
let b = witness.get_nonnative_target(self.b.clone());
let a_biguint = a.to_biguint();
let b_biguint = b.to_biguint();
let modulus = FF::order();
let (diff_biguint, overflow) = if a_biguint > b_biguint {
(a_biguint - b_biguint, false)
} else {
(modulus + a_biguint - b_biguint, true)
};
out_buffer.set_biguint_target(self.diff.value.clone(), diff_biguint);
out_buffer.set_bool_target(self.overflow, overflow);
}
}
#[derive(Debug)]
struct NonNativeMultiplicationGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
a: NonNativeTarget<FF>,
b: NonNativeTarget<FF>,
prod: NonNativeTarget<FF>,
overflow: BigUintTarget,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
for NonNativeMultiplicationGenerator<F, D, FF>
{
fn dependencies(&self) -> Vec<Target> {
self.a
.value
.limbs
.iter()
.cloned()
.chain(self.b.value.limbs.clone())
.map(|l| l.0)
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness.get_nonnative_target(self.a.clone());
let b = witness.get_nonnative_target(self.b.clone());
let a_biguint = a.to_biguint();
let b_biguint = b.to_biguint();
let prod_biguint = a_biguint * b_biguint;
let modulus = FF::order();
let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus);
out_buffer.set_biguint_target(self.prod.value.clone(), prod_reduced);
out_buffer.set_biguint_target(self.overflow.clone(), overflow_biguint);
}
}
#[derive(Debug)]
struct NonNativeInverseGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
x: NonNativeTarget<FF>,
inv: NonNativeTarget<FF>,
inv: BigUintTarget,
div: BigUintTarget,
_phantom: PhantomData<F>,
}
@ -205,7 +488,14 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
let x = witness.get_nonnative_target(self.x.clone());
let inv = x.inverse();
out_buffer.set_nonnative_target(self.inv.clone(), inv);
let x_biguint = x.to_biguint();
let inv_biguint = inv.to_biguint();
let prod = x_biguint * &inv_biguint;
let modulus = FF::order();
let (div, _rem) = prod.div_rem(&modulus);
out_buffer.set_biguint_target(self.div.clone(), div);
out_buffer.set_biguint_target(self.inv.clone(), inv_biguint);
}
}
@ -227,11 +517,12 @@ mod tests {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let x_ff = FF::rand();
let y_ff = FF::rand();
let sum_ff = x_ff + y_ff;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -247,12 +538,53 @@ mod tests {
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_nonnative_many_adds() -> Result<()> {
type FF = Secp256K1Base;
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let a_ff = FF::rand();
let b_ff = FF::rand();
let c_ff = FF::rand();
let d_ff = FF::rand();
let e_ff = FF::rand();
let f_ff = FF::rand();
let g_ff = FF::rand();
let h_ff = FF::rand();
let sum_ff = a_ff + b_ff + c_ff + d_ff + e_ff + f_ff + g_ff + h_ff;
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let a = builder.constant_nonnative(a_ff);
let b = builder.constant_nonnative(b_ff);
let c = builder.constant_nonnative(c_ff);
let d = builder.constant_nonnative(d_ff);
let e = builder.constant_nonnative(e_ff);
let f = builder.constant_nonnative(f_ff);
let g = builder.constant_nonnative(g_ff);
let h = builder.constant_nonnative(h_ff);
let all = [a, b, c, d, e, f, g, h];
let sum = builder.add_many_nonnative(&all);
let sum_expected = builder.constant_nonnative(sum_ff);
builder.connect_nonnative(&sum, &sum_expected);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_nonnative_sub() -> Result<()> {
type FF = Secp256K1Base;
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let x_ff = FF::rand();
let mut y_ff = FF::rand();
while y_ff.to_biguint() > x_ff.to_biguint() {
@ -260,7 +592,7 @@ mod tests {
}
let diff_ff = x_ff - y_ff;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -286,7 +618,7 @@ mod tests {
let y_ff = FF::rand();
let product_ff = x_ff * y_ff;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -311,7 +643,7 @@ mod tests {
let x_ff = FF::rand();
let neg_x_ff = -x_ff;
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -335,7 +667,7 @@ mod tests {
let x_ff = FF::rand();
let inv_x_ff = x_ff.inverse();
let config = CircuitConfig::standard_recursion_config();
let config = CircuitConfig::standard_ecc_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);

View File

@ -1,5 +1,7 @@
use plonky2_field::extension_field::Extendable;
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gates::range_check_u32::U32RangeCheckGate;
use crate::hash::hash_types::RichField;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
@ -41,6 +43,25 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
(low, high)
}
pub fn range_check_u32(&mut self, vals: Vec<U32Target>) {
let num_input_limbs = vals.len();
let gate = U32RangeCheckGate::<F, D>::new(num_input_limbs);
let gate_index = self.add_gate(gate, vec![]);
for i in 0..num_input_limbs {
self.connect(
Target::wire(gate_index, gate.wire_ith_input_limb(i)),
vals[i].0,
);
}
}
pub fn assert_bool(&mut self, b: BoolTarget) {
let z = self.mul_sub(b.target, b.target, b.target);
let zero = self.zero();
self.connect(z, zero);
}
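// assert_bool enforces b * b - b = 0 in a single mul_sub; the only roots of that
// polynomial in a field are 0 and 1.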
}
#[derive(Debug)]

View File

@ -0,0 +1,461 @@
use std::marker::PhantomData;
use itertools::unfold;
use plonky2_util::ceil_div_usize;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::hash::hash_types::RichField;
use crate::iop::ext_target::ExtensionTarget;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
const LOG2_MAX_NUM_ADDENDS: usize = 4;
const MAX_NUM_ADDENDS: usize = 16;
/// A gate to perform addition on `num_addends` different 32-bit values, plus a small carry
#[derive(Copy, Clone, Debug)]
pub struct U32AddManyGate<F: RichField + Extendable<D>, const D: usize> {
pub num_addends: usize,
pub num_ops: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> U32AddManyGate<F, D> {
pub fn new_from_config(config: &CircuitConfig, num_addends: usize) -> Self {
Self {
num_addends,
num_ops: Self::num_ops(num_addends, config),
_phantom: PhantomData,
}
}
pub(crate) fn num_ops(num_addends: usize, config: &CircuitConfig) -> usize {
debug_assert!(num_addends <= MAX_NUM_ADDENDS);
let wires_per_op = (num_addends + 3) + Self::num_limbs();
let routed_wires_per_op = num_addends + 3;
(config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
}
pub fn wire_ith_op_jth_addend(&self, i: usize, j: usize) -> usize {
debug_assert!(i < self.num_ops);
debug_assert!(j < self.num_addends);
(self.num_addends + 3) * i + j
}
pub fn wire_ith_carry(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
(self.num_addends + 3) * i + self.num_addends
}
pub fn wire_ith_output_result(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
(self.num_addends + 3) * i + self.num_addends + 1
}
pub fn wire_ith_output_carry(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
(self.num_addends + 3) * i + self.num_addends + 2
}
pub fn limb_bits() -> usize {
2
}
pub fn num_result_limbs() -> usize {
ceil_div_usize(32, Self::limb_bits())
}
pub fn num_carry_limbs() -> usize {
ceil_div_usize(LOG2_MAX_NUM_ADDENDS, Self::limb_bits())
}
pub fn num_limbs() -> usize {
Self::num_result_limbs() + Self::num_carry_limbs()
}
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
debug_assert!(i < self.num_ops);
debug_assert!(j < Self::num_limbs());
(self.num_addends + 3) * self.num_ops + Self::num_limbs() * i + j
}
}
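// Worked layout example (sketch): with num_addends = 4 and limb_bits = 2, each op needs
// num_addends + 3 = 7 routed wires (addends, carry, result, output carry) plus
// 16 result limbs + 2 carry limbs = 18 limb wires for the base-4 range check, 25 wires
// per op in total. The routed wires of all ops are laid out first, then the limb wires,
// and num_ops is the smaller of num_wires / 25 and num_routed_wires / 7.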
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32AddManyGate<F, D> {
fn id(&self) -> String {
format!("{:?}", self)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..self.num_ops {
let addends: Vec<F::Extension> = (0..self.num_addends)
.map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)])
.collect();
let carry = vars.local_wires[self.wire_ith_carry(i)];
let computed_output = addends.iter().fold(F::Extension::ZERO, |x, &y| x + y) + carry;
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_carry = vars.local_wires[self.wire_ith_output_carry(i)];
let base = F::Extension::from_canonical_u64(1 << 32u64);
let combined_output = output_carry * base + output_result;
constraints.push(combined_output - computed_output);
let mut combined_result_limbs = F::Extension::ZERO;
let mut combined_carry_limbs = F::Extension::ZERO;
let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
.product();
constraints.push(product);
if j < Self::num_result_limbs() {
combined_result_limbs = base * combined_result_limbs + this_limb;
} else {
combined_carry_limbs = base * combined_carry_limbs + this_limb;
}
}
constraints.push(combined_result_limbs - output_result);
constraints.push(combined_carry_limbs - output_carry);
}
constraints
}
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
for i in 0..self.num_ops {
let addends: Vec<F> = (0..self.num_addends)
.map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)])
.collect();
let carry = vars.local_wires[self.wire_ith_carry(i)];
let computed_output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry;
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_carry = vars.local_wires[self.wire_ith_output_carry(i)];
let base = F::from_canonical_u64(1 << 32u64);
let combined_output = output_carry * base + output_result;
yield_constr.one(combined_output - computed_output);
let mut combined_result_limbs = F::ZERO;
let mut combined_carry_limbs = F::ZERO;
let base = F::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
.product();
yield_constr.one(product);
if j < Self::num_result_limbs() {
combined_result_limbs = base * combined_result_limbs + this_limb;
} else {
combined_carry_limbs = base * combined_carry_limbs + this_limb;
}
}
yield_constr.one(combined_result_limbs - output_result);
yield_constr.one(combined_carry_limbs - output_carry);
}
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..self.num_ops {
let addends: Vec<ExtensionTarget<D>> = (0..self.num_addends)
.map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)])
.collect();
let carry = vars.local_wires[self.wire_ith_carry(i)];
let mut computed_output = carry;
for addend in addends {
computed_output = builder.add_extension(computed_output, addend);
}
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_carry = vars.local_wires[self.wire_ith_output_carry(i)];
let base: F::Extension = F::from_canonical_u64(1 << 32u64).into();
let base_target = builder.constant_extension(base);
let combined_output =
builder.mul_add_extension(output_carry, base_target, output_result);
constraints.push(builder.sub_extension(combined_output, computed_output));
let mut combined_result_limbs = builder.zero_extension();
let mut combined_carry_limbs = builder.zero_extension();
let base = builder
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let mut product = builder.one_extension();
for x in 0..max_limb {
let x_target =
builder.constant_extension(F::Extension::from_canonical_usize(x));
let diff = builder.sub_extension(this_limb, x_target);
product = builder.mul_extension(product, diff);
}
constraints.push(product);
if j < Self::num_result_limbs() {
combined_result_limbs =
builder.mul_add_extension(base, combined_result_limbs, this_limb);
} else {
combined_carry_limbs =
builder.mul_add_extension(base, combined_carry_limbs, this_limb);
}
}
constraints.push(builder.sub_extension(combined_result_limbs, output_result));
constraints.push(builder.sub_extension(combined_carry_limbs, output_carry));
}
constraints
}
fn generators(
&self,
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
(0..self.num_ops)
.map(|i| {
let g: Box<dyn WitnessGenerator<F>> = Box::new(
U32AddManyGenerator {
gate: *self,
gate_index,
i,
_phantom: PhantomData,
}
.adapter(),
);
g
})
.collect()
}
fn num_wires(&self) -> usize {
(self.num_addends + 3) * self.num_ops + Self::num_limbs() * self.num_ops
}
fn num_constants(&self) -> usize {
0
}
fn degree(&self) -> usize {
1 << Self::limb_bits()
}
fn num_constraints(&self) -> usize {
self.num_ops * (3 + Self::num_limbs())
}
}
#[derive(Clone, Debug)]
struct U32AddManyGenerator<F: RichField + Extendable<D>, const D: usize> {
gate: U32AddManyGate<F, D>,
gate_index: usize,
i: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for U32AddManyGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
(0..self.gate.num_addends)
.map(|j| local_target(self.gate.wire_ith_op_jth_addend(self.i, j)))
.chain([local_target(self.gate.wire_ith_carry(self.i))])
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
};
let get_local_wire = |input| witness.get_wire(local_wire(input));
let addends: Vec<_> = (0..self.gate.num_addends)
.map(|j| get_local_wire(self.gate.wire_ith_op_jth_addend(self.i, j)))
.collect();
let carry = get_local_wire(self.gate.wire_ith_carry(self.i));
let output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry;
let output_u64 = output.to_canonical_u64();
let output_carry_u64 = output_u64 >> 32;
let output_result_u64 = output_u64 & ((1 << 32) - 1);
let output_carry = F::from_canonical_u64(output_carry_u64);
let output_result = F::from_canonical_u64(output_result_u64);
let output_carry_wire = local_wire(self.gate.wire_ith_output_carry(self.i));
let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i));
out_buffer.set_wire(output_carry_wire, output_carry);
out_buffer.set_wire(output_result_wire, output_result);
let num_result_limbs = U32AddManyGate::<F, D>::num_result_limbs();
let num_carry_limbs = U32AddManyGate::<F, D>::num_carry_limbs();
let limb_base = 1 << U32AddManyGate::<F, D>::limb_bits();
let split_to_limbs = |mut val, num| {
unfold((), move |_| {
let ret = val % limb_base;
val /= limb_base;
Some(ret)
})
.take(num)
.map(F::from_canonical_u64)
};
let result_limbs = split_to_limbs(output_result_u64, num_result_limbs);
let carry_limbs = split_to_limbs(output_carry_u64, num_carry_limbs);
for (j, limb) in result_limbs.chain(carry_limbs).enumerate() {
let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
out_buffer.set_wire(wire, limb);
}
}
}
#[cfg(test)]
mod tests {
use std::marker::PhantomData;
use anyhow::Result;
use itertools::unfold;
use rand::Rng;
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::add_many_u32::U32AddManyGate;
use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::hash::hash_types::HashOut;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::plonk::vars::EvaluationVars;
#[test]
fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(U32AddManyGate::<GoldilocksField, 4> {
num_addends: 4,
num_ops: 3,
_phantom: PhantomData,
})
}
#[test]
fn eval_fns() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
test_eval_fns::<F, C, _, D>(U32AddManyGate::<GoldilocksField, D> {
num_addends: 4,
num_ops: 3,
_phantom: PhantomData,
})
}
#[test]
fn test_gate_constraint() {
type F = GoldilocksField;
type FF = QuarticExtension<GoldilocksField>;
const D: usize = 4;
const NUM_ADDENDS: usize = 10;
const NUM_U32_ADD_MANY_OPS: usize = 3;
fn get_wires(addends: Vec<Vec<u64>>, carries: Vec<u64>) -> Vec<FF> {
let mut v0 = Vec::new();
let mut v1 = Vec::new();
let num_result_limbs = U32AddManyGate::<F, D>::num_result_limbs();
let num_carry_limbs = U32AddManyGate::<F, D>::num_carry_limbs();
let limb_base = 1 << U32AddManyGate::<F, D>::limb_bits();
for op in 0..NUM_U32_ADD_MANY_OPS {
let adds = &addends[op];
let ca = carries[op];
let output = adds.iter().sum::<u64>() + ca;
let output_result = output & ((1 << 32) - 1);
let output_carry = output >> 32;
let split_to_limbs = |mut val, num| {
unfold((), move |_| {
let ret = val % limb_base;
val /= limb_base;
Some(ret)
})
.take(num)
.map(F::from_canonical_u64)
};
let mut result_limbs: Vec<_> =
split_to_limbs(output_result, num_result_limbs).collect();
let mut carry_limbs: Vec<_> =
split_to_limbs(output_carry, num_carry_limbs).collect();
for a in adds {
v0.push(F::from_canonical_u64(*a));
}
v0.push(F::from_canonical_u64(ca));
v0.push(F::from_canonical_u64(output_result));
v0.push(F::from_canonical_u64(output_carry));
v1.append(&mut result_limbs);
v1.append(&mut carry_limbs);
}
v0.iter().chain(v1.iter()).map(|&x| x.into()).collect()
}
let mut rng = rand::thread_rng();
let addends: Vec<Vec<_>> = (0..NUM_U32_ADD_MANY_OPS)
.map(|_| (0..NUM_ADDENDS).map(|_| rng.gen::<u32>() as u64).collect())
.collect();
let carries: Vec<_> = (0..NUM_U32_ADD_MANY_OPS)
.map(|_| rng.gen::<u32>() as u64)
.collect();
let gate = U32AddManyGate::<F, D> {
num_addends: NUM_ADDENDS,
num_ops: NUM_U32_ADD_MANY_OPS,
_phantom: PhantomData,
};
let vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(addends, carries),
public_inputs_hash: &HashOut::rand(),
};
assert!(
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
"Gate constraints are not satisfied."
);
}
}

View File

@ -131,7 +131,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
);
g
})
.collect::<Vec<_>>()
.collect()
}
fn num_wires(&self) -> usize {

View File

@ -138,7 +138,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExte
);
g
})
.collect::<Vec<_>>()
.collect()
}
fn num_wires(&self) -> usize {

View File

@ -212,7 +212,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticG
);
g
})
.collect::<Vec<_>>()
.collect()
}
fn num_wires(&self) -> usize {
@ -425,10 +425,7 @@ mod tests {
v1.append(&mut output_limbs_f);
}
v0.iter()
.chain(v1.iter())
.map(|&x| x.into())
.collect::<Vec<_>>()
v0.iter().chain(v1.iter()).map(|&x| x.into()).collect()
}
let mut rng = rand::thread_rng();

View File

@ -578,7 +578,7 @@ mod tests {
v.append(&mut chunks_equal);
v.append(&mut intermediate_values);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
v.iter().map(|&x| x.into()).collect()
};
let mut rng = rand::thread_rng();

View File

@ -658,7 +658,7 @@ mod tests {
v.append(&mut intermediate_values);
v.append(&mut msd_bits);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
v.iter().map(|&x| x.into()).collect()
};
let mut rng = rand::thread_rng();

View File

@ -113,7 +113,7 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
builder: &mut CircuitBuilder<F, D>,
mut vars: EvaluationTargets<D>,
prefix: &[bool],
combined_gate_constraints: &mut Vec<ExtensionTarget<D>>,
combined_gate_constraints: &mut [ExtensionTarget<D>],
) {
let filter = compute_filter_recursively(builder, prefix, vars.local_constants);
vars.remove_prefix(prefix);

View File

@ -343,7 +343,7 @@ mod tests {
for i in 0..coeffs.len() {
v.extend(coeffs.coeffs[i].0);
}
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
v.iter().map(|&x| x.into()).collect()
}
// Get a working row for InterpolationGate.

View File

@ -443,7 +443,7 @@ mod tests {
.take(gate.num_points() - 2)
.flat_map(|ff| ff.0),
);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
v.iter().map(|&x| x.into()).collect()
}
// Get a working row for LowDegreeInterpolationGate.

View File

@ -1,6 +1,7 @@
// Gates have `new` methods that return `GateRef`s.
#![allow(clippy::new_ret_no_self)]
pub mod add_many_u32;
pub mod arithmetic_base;
pub mod arithmetic_extension;
pub mod arithmetic_u32;
@ -20,6 +21,7 @@ pub mod poseidon;
pub(crate) mod poseidon_mds;
pub(crate) mod public_input;
pub mod random_access;
pub mod range_check_u32;
pub mod reducing;
pub mod reducing_extension;
pub mod subtraction_u32;

View File

@ -125,7 +125,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGa
);
g
})
.collect::<Vec<_>>()
.collect()
}
fn num_wires(&self) -> usize {

View File

@ -209,7 +209,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGa
);
g
})
.collect::<Vec<_>>()
.collect()
}
fn num_wires(&self) -> usize {

View File

@ -0,0 +1,322 @@
use std::marker::PhantomData;
use plonky2_util::ceil_div_usize;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::hash::hash_types::RichField;
use crate::iop::ext_target::ExtensionTarget;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate which range-checks `U32Target`s by decomposing each 32-bit value into base-B little-endian limbs.
#[derive(Copy, Clone, Debug)]
pub struct U32RangeCheckGate<F: RichField + Extendable<D>, const D: usize> {
pub num_input_limbs: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> U32RangeCheckGate<F, D> {
pub fn new(num_input_limbs: usize) -> Self {
Self {
num_input_limbs,
_phantom: PhantomData,
}
}
pub const AUX_LIMB_BITS: usize = 2;
pub const BASE: usize = 1 << Self::AUX_LIMB_BITS;
fn aux_limbs_per_input_limb(&self) -> usize {
ceil_div_usize(32, Self::AUX_LIMB_BITS)
}
pub fn wire_ith_input_limb(&self, i: usize) -> usize {
debug_assert!(i < self.num_input_limbs);
i
}
pub fn wire_ith_input_limb_jth_aux_limb(&self, i: usize, j: usize) -> usize {
debug_assert!(i < self.num_input_limbs);
debug_assert!(j < self.aux_limbs_per_input_limb());
self.num_input_limbs + self.aux_limbs_per_input_limb() * i + j
}
}
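// Layout sketch: each 32-bit input limb gets ceil(32 / 2) = 16 two-bit aux limbs, so a
// gate over 8 input limbs uses 8 + 8 * 16 = 136 wires. Per input limb there is one
// recomposition constraint plus a degree-4 range check for each aux limb, matching
// num_constraints() = num_input_limbs * 17 below.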
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32RangeCheckGate<F, D> {
fn id(&self) -> String {
format!("{:?}", self)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let base = F::Extension::from_canonical_usize(Self::BASE);
for i in 0..self.num_input_limbs {
let input_limb = vars.local_wires[self.wire_ith_input_limb(i)];
let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb())
.map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)])
.collect();
let computed_sum = reduce_with_powers(&aux_limbs, base);
constraints.push(computed_sum - input_limb);
for aux_limb in aux_limbs {
constraints.push(
(0..Self::BASE)
.map(|i| aux_limb - F::Extension::from_canonical_usize(i))
.product(),
);
}
}
constraints
}
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let base = F::from_canonical_usize(Self::BASE);
for i in 0..self.num_input_limbs {
let input_limb = vars.local_wires[self.wire_ith_input_limb(i)];
let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb())
.map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)])
.collect();
let computed_sum = reduce_with_powers(&aux_limbs, base);
yield_constr.one(computed_sum - input_limb);
for aux_limb in aux_limbs {
yield_constr.one(
(0..Self::BASE)
.map(|i| aux_limb - F::from_canonical_usize(i))
.product(),
);
}
}
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let base = builder.constant(F::from_canonical_usize(Self::BASE));
for i in 0..self.num_input_limbs {
let input_limb = vars.local_wires[self.wire_ith_input_limb(i)];
let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb())
.map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)])
.collect();
let computed_sum = reduce_with_powers_ext_recursive(builder, &aux_limbs, base);
constraints.push(builder.sub_extension(computed_sum, input_limb));
for aux_limb in aux_limbs {
constraints.push({
let mut acc = builder.one_extension();
(0..Self::BASE).for_each(|i| {
// We update our accumulator as:
// acc' = acc (x - i)
// = acc x + (-i) acc
// Since -i is constant, we can do this in one arithmetic_extension call.
let neg_i = -F::from_canonical_usize(i);
acc = builder.arithmetic_extension(F::ONE, neg_i, acc, aux_limb, acc)
});
acc
});
}
}
constraints
}
fn generators(
&self,
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = U32RangeCheckGenerator {
gate: *self,
gate_index,
};
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {
self.num_input_limbs * (1 + self.aux_limbs_per_input_limb())
}
fn num_constants(&self) -> usize {
0
}
// Bounded by the range check (x - 0) * (x - 1) * ... * (x - (BASE - 1)), which has degree BASE.
fn degree(&self) -> usize {
Self::BASE
}
// For each input limb: 1 constraint checking the sum of its aux limbs, plus a range check for each aux limb.
fn num_constraints(&self) -> usize {
self.num_input_limbs * (1 + self.aux_limbs_per_input_limb())
}
}
#[derive(Debug)]
pub struct U32RangeCheckGenerator<F: RichField + Extendable<D>, const D: usize> {
gate: U32RangeCheckGate<F, D>,
gate_index: usize,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for U32RangeCheckGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
let num_input_limbs = self.gate.num_input_limbs;
(0..num_input_limbs)
.map(|i| Target::wire(self.gate_index, self.gate.wire_ith_input_limb(i)))
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let num_input_limbs = self.gate.num_input_limbs;
for i in 0..num_input_limbs {
let sum_value = witness
.get_target(Target::wire(
self.gate_index,
self.gate.wire_ith_input_limb(i),
))
.to_canonical_u64() as u32;
let base = U32RangeCheckGate::<F, D>::BASE as u32;
let limbs = (0..self.gate.aux_limbs_per_input_limb()).map(|j| {
Target::wire(
self.gate_index,
self.gate.wire_ith_input_limb_jth_aux_limb(i, j),
)
});
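// Decompose the 32-bit input value into base-B aux limbs, least significant first.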
let limbs_value = (0..self.gate.aux_limbs_per_input_limb())
.scan(sum_value, |acc, _| {
let tmp = *acc % base;
*acc /= base;
Some(F::from_canonical_u32(tmp))
})
.collect::<Vec<_>>();
for (b, b_value) in limbs.zip(limbs_value) {
out_buffer.set_target(b, b_value);
}
}
}
}
#[cfg(test)]
mod tests {
use std::marker::PhantomData;
use anyhow::Result;
use itertools::unfold;
use plonky2_util::ceil_div_usize;
use rand::Rng;
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::range_check_u32::U32RangeCheckGate;
use crate::hash::hash_types::HashOut;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::plonk::vars::EvaluationVars;
#[test]
fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(U32RangeCheckGate::new(8))
}
#[test]
fn eval_fns() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
test_eval_fns::<F, C, _, D>(U32RangeCheckGate::new(8))
}
fn test_gate_constraint(input_limbs: Vec<u64>) {
type F = GoldilocksField;
type FF = QuarticExtension<GoldilocksField>;
const D: usize = 4;
const AUX_LIMB_BITS: usize = 2;
const BASE: usize = 1 << AUX_LIMB_BITS;
const AUX_LIMBS_PER_INPUT_LIMB: usize = ceil_div_usize(32, AUX_LIMB_BITS);
fn get_wires(input_limbs: Vec<u64>) -> Vec<FF> {
let num_input_limbs = input_limbs.len();
let mut v = Vec::new();
for i in 0..num_input_limbs {
let input_limb = input_limbs[i];
let split_to_limbs = |mut val, num| {
unfold((), move |_| {
let ret = val % (BASE as u64);
val /= BASE as u64;
Some(ret)
})
.take(num)
.map(F::from_canonical_u64)
};
let mut aux_limbs: Vec<_> =
split_to_limbs(input_limb, AUX_LIMBS_PER_INPUT_LIMB).collect();
v.append(&mut aux_limbs);
}
input_limbs
.iter()
.cloned()
.map(F::from_canonical_u64)
.chain(v.iter().cloned())
.map(|x| x.into())
.collect()
}
let gate = U32RangeCheckGate::<F, D> {
num_input_limbs: 8,
_phantom: PhantomData,
};
let vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(input_limbs),
public_inputs_hash: &HashOut::rand(),
};
assert!(
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
"Gate constraints are not satisfied."
);
}
#[test]
fn test_gate_constraint_good() {
let mut rng = rand::thread_rng();
let input_limbs: Vec<_> = (0..8).map(|_| rng.gen::<u32>() as u64).collect();
test_gate_constraint(input_limbs);
}
#[test]
#[should_panic]
fn test_gate_constraint_bad() {
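// A uniformly random u64 exceeds 32 bits with overwhelming probability, so the
// recomposition constraint should fail and the assertion below should panic.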
let mut rng = rand::thread_rng();
let input_limbs: Vec<_> = (0..8).map(|_| rng.gen()).collect();
test_gate_constraint(input_limbs);
}
}

View File

@ -416,10 +416,7 @@ mod tests {
v1.append(&mut output_limbs);
}
v0.iter()
.chain(v1.iter())
.map(|&x| x.into())
.collect::<Vec<_>>()
v0.iter().chain(v1.iter()).map(|&x| x.into()).collect()
}
let mut rng = rand::thread_rng();

View File

@ -432,7 +432,7 @@ mod tests {
v.push(F::from_bool(switch));
}
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
v.iter().map(|&x| x.into()).collect()
}
let first_inputs: Vec<Vec<F>> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect();

View File

@ -10,7 +10,7 @@ use crate::gadgets::biguint::BigUintTarget;
use crate::gadgets::nonnative::NonNativeTarget;
use crate::hash::hash_types::{HashOut, HashOutTarget, RichField};
use crate::iop::ext_target::ExtensionTarget;
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
use crate::iop::witness::{PartialWitness, PartitionWitness, Witness};
use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData};
@ -161,12 +161,17 @@ impl<F: Field> GeneratedValues<F> {
self.target_values.push((target, value))
}
fn set_u32_target(&mut self, target: U32Target, value: u32) {
pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) {
self.set_target(target.target, F::from_bool(value))
}
pub fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value))
}
pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) {
let mut limbs = value.to_u32_digits();
assert!(target.num_limbs() >= limbs.len());
limbs.resize(target.num_limbs(), 0);

View File

@ -16,6 +16,7 @@ use crate::gadgets::arithmetic::BaseArithmeticOperation;
use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation;
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gadgets::polynomial::PolynomialCoeffsExtTarget;
use crate::gates::add_many_u32::U32AddManyGate;
use crate::gates::arithmetic_base::ArithmeticGate;
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::arithmetic_u32::U32ArithmeticGate;
@ -203,6 +204,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
BoolTarget::new_unsafe(self.add_virtual_target())
}
pub fn add_virtual_bool_target_safe(&mut self) -> BoolTarget {
let b = BoolTarget::new_unsafe(self.add_virtual_target());
self.assert_bool(b);
b
}
/// Adds a gate to the circuit, and returns its index.
pub fn add_gate<G: Gate<F, D>>(&mut self, gate_type: G, constants: Vec<F>) -> usize {
self.check_gate_compatibility(&gate_type);
@ -233,7 +240,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn check_gate_compatibility<G: Gate<F, D>>(&self, gate: &G) {
assert!(
gate.num_wires() <= self.config.num_wires,
"{:?} requires {} wires, but our GateConfig has only {}",
"{:?} requires {} wires, but our CircuitConfig has only {}",
gate.id(),
gate.num_wires(),
self.config.num_wires
@ -654,14 +661,14 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let subgroup = F::two_adic_subgroup(degree_bits);
let constant_vecs = timed!(
&mut timing,
timing,
"generate constant polynomials",
self.constant_polys(&prefixed_gates, num_constants)
);
let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires);
let (sigma_vecs, forest) = timed!(
&mut timing,
timing,
"generate sigma polynomials",
self.sigma_vecs(&k_is, &subgroup)
);
@ -816,9 +823,12 @@ pub struct BatchedGates<F: RichField + Extendable<D>, const D: usize> {
/// of switches
pub(crate) current_switch_gates: Vec<Option<(SwitchGate<F, D>, usize, usize)>>,
/// A map `n -> (g, i)` from the number of addends `n` to an available `U32AddManyGate` of that size, with gate
/// index `g` and `i` of its copies already in use.
pub(crate) free_u32_add_many: HashMap<usize, (usize, usize)>,
/// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one)
pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>,
/// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one)
pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>,
@ -834,6 +844,7 @@ impl<F: RichField + Extendable<D>, const D: usize> BatchedGates<F, D> {
free_mul: HashMap::new(),
free_random_access: HashMap::new(),
current_switch_gates: Vec::new(),
free_u32_add_many: HashMap::new(),
current_u32_arithmetic_gate: None,
current_u32_subtraction_gate: None,
free_constant: None,
@ -931,8 +942,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
(gate, i)
}
/// Finds the last available random access gate with the given `vec_size` or add one if there aren't any.
/// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index
/// Finds the last available random access gate with the given `bits` or adds one if there aren't any.
/// Returns `(g,i)` such that there is a random access gate for the given `bits` at index
/// `g` and the gate's `i`-th random access is available.
pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) {
let (gate, i) = self
@ -994,6 +1005,35 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
(gate, gate_index, next_copy)
}
/// Finds the last available U32 add-many gate with the given `num_addends` or adds one if there aren't any.
/// Returns `(g,i)` such that there is a `U32AddManyGate` for the given `num_addends` at index
/// `g` and the gate's `i`-th copy is available.
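/// For example (hypothetical indices): if a 4-addend gate sits at index `g = 17` with copies
/// 0 and 1 already used, this returns `(17, 2)`.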
pub(crate) fn find_u32_add_many_gate(&mut self, num_addends: usize) -> (usize, usize) {
let (gate, i) = self
.batched_gates
.free_u32_add_many
.get(&num_addends)
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(
U32AddManyGate::new_from_config(&self.config, num_addends),
vec![],
);
(gate, 0)
});
// Update `free_u32_add_many` with new values.
if i + 1 < U32AddManyGate::<F, D>::new_from_config(&self.config, num_addends).num_ops {
self.batched_gates
.free_u32_add_many
.insert(num_addends, (gate, i + 1));
} else {
self.batched_gates.free_u32_add_many.remove(&num_addends);
}
(gate, i)
}
pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) {
let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate {
None => {
@ -1140,6 +1180,28 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
/// Fill the remaining unused u32 add-many operations with zeros, so that all
/// `U32AddManyGenerator`s are run.
fn fill_u32_add_many_gates(&mut self) {
let zero = self.zero_u32();
for (num_addends, (_, i)) in self.batched_gates.free_u32_add_many.clone() {
let max_copies =
U32AddManyGate::<F, D>::new_from_config(&self.config, num_addends).num_ops;
for _ in i..max_copies {
let gate = U32AddManyGate::<F, D>::new_from_config(&self.config, num_addends);
let (gate_index, copy) = self.find_u32_add_many_gate(num_addends);
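// Tie every addend and the carry of this unused copy to zero, so its generator still
// produces well-defined outputs.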
for j in 0..num_addends {
self.connect(
Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)),
zero.0,
);
}
self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), zero.0);
}
}
}
/// Fill the remaining unused U32 arithmetic operations with zeros, so that all
/// `U32ArithmeticGenerator`s are run.
fn fill_u32_arithmetic_gates(&mut self) {
@ -1172,6 +1234,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.fill_mul_gates();
self.fill_random_access_gates();
self.fill_switch_gates();
self.fill_u32_add_many_gates();
self.fill_u32_arithmetic_gates();
self.fill_u32_subtraction_gates();
}

View File

@ -49,7 +49,7 @@ pub struct CircuitConfig {
impl Default for CircuitConfig {
fn default() -> Self {
CircuitConfig::standard_recursion_config()
Self::standard_recursion_config()
}
}
@ -79,6 +79,13 @@ impl CircuitConfig {
}
}
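// Note: 136 wires match, e.g., the width of a `U32RangeCheckGate` over 8 input limbs.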
pub fn standard_ecc_config() -> Self {
Self {
num_wires: 136,
..Self::standard_recursion_config()
}
}
pub fn standard_recursion_zk_config() -> Self {
CircuitConfig {
zero_knowledge: true,