diff --git a/insertion/src/insertion_gate.rs b/insertion/src/insertion_gate.rs index 8ee60483..442416d3 100644 --- a/insertion/src/insertion_gate.rs +++ b/insertion/src/insertion_gate.rs @@ -404,7 +404,7 @@ mod tests { v.extend(equality_dummy_vals); v.extend(insert_here_vals); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } let orig_vec = vec![FF::rand(); 3]; diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 8726fde7..31b3cb1b 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -171,9 +171,25 @@ impl, const D: usize> CircuitBuilder { a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let result = self.mul_biguint(&a.value, &b.value); + let prod = self.add_virtual_nonnative_target::(); + let modulus = self.constant_biguint(&FF::order()); + let overflow = self.add_virtual_biguint_target(a.value.num_limbs() + b.value.num_limbs() - modulus.num_limbs()); - self.reduce(&result) + self.add_simple_generator(NonNativeMultiplicationGenerator:: { + a: a.clone(), + b: b.clone(), + prod: prod.clone(), + overflow: overflow.clone(), + _phantom: PhantomData, + }); + + let prod_expected = self.mul_biguint(&a.value, &b.value); + + let mod_times_overflow = self.mul_biguint(&modulus, &overflow); + let prod_actual = self.add_biguint(&prod.value, &mod_times_overflow); + self.connect_biguint(&prod_expected, &prod_actual); + + prod } pub fn mul_many_nonnative( @@ -226,20 +242,6 @@ impl, const D: usize> CircuitBuilder { inv } - pub fn div_rem_nonnative( - &mut self, - x: &NonNativeTarget, - y: &NonNativeTarget, - ) -> (NonNativeTarget, NonNativeTarget) { - let x_biguint = self.nonnative_to_biguint(x); - let y_biguint = self.nonnative_to_biguint(y); - - let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint); - let div = self.biguint_to_nonnative(&div_biguint); - let rem = self.biguint_to_nonnative(&rem_biguint); - (div, rem) - } - /// Returns `x % |FF|` as a `NonNativeTarget`. fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { let modulus = FF::order(); @@ -252,8 +254,7 @@ impl, const D: usize> CircuitBuilder { } } - #[allow(dead_code)] - fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + pub fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } @@ -416,6 +417,45 @@ impl, const D: usize, FF: Field> SimpleGenerator } } +#[derive(Debug)] +struct NonNativeMultiplicationGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + prod: NonNativeTarget, + overflow: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeMultiplicationGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_biguint(); + let b_biguint = b.to_biguint(); + + let prod_biguint = a_biguint * b_biguint; + + let modulus = FF::order(); + let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); + + out_buffer.set_biguint_target(self.prod.value.clone(), prod_reduced); + out_buffer.set_biguint_target(self.overflow.clone(), overflow_biguint); + } +} + #[derive(Debug)] struct NonNativeInverseGenerator, const D: usize, FF: Field> { x: NonNativeTarget, @@ -566,7 +606,6 @@ mod tests { let x = builder.constant_nonnative(x_ff); let y = builder.constant_nonnative(y_ff); - println!("LIMBS LIMBS LIMBS {}", y.value.limbs.len()); let product = builder.mul_nonnative(&x, &y); let product_expected = builder.constant_nonnative(product_ff); diff --git a/plonky2/src/gates/add_many_u32.rs b/plonky2/src/gates/add_many_u32.rs index c8b5f8af..01c7ed30 100644 --- a/plonky2/src/gates/add_many_u32.rs +++ b/plonky2/src/gates/add_many_u32.rs @@ -248,7 +248,7 @@ impl, const D: usize> Gate for U32AddManyGate ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { @@ -426,7 +426,7 @@ mod tests { v0.iter() .chain(v1.iter()) .map(|&x| x.into()) - .collect::>() + .collect() } let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/arithmetic_base.rs b/plonky2/src/gates/arithmetic_base.rs index 8f67dab2..738b8ad4 100644 --- a/plonky2/src/gates/arithmetic_base.rs +++ b/plonky2/src/gates/arithmetic_base.rs @@ -131,7 +131,7 @@ impl, const D: usize> Gate for ArithmeticGate ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/arithmetic_u32.rs b/plonky2/src/gates/arithmetic_u32.rs index 1d4a834c..bef21a97 100644 --- a/plonky2/src/gates/arithmetic_u32.rs +++ b/plonky2/src/gates/arithmetic_u32.rs @@ -212,7 +212,7 @@ impl, const D: usize> Gate for U32ArithmeticG ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/assert_le.rs b/plonky2/src/gates/assert_le.rs index 6a99acd9..c385bb31 100644 --- a/plonky2/src/gates/assert_le.rs +++ b/plonky2/src/gates/assert_le.rs @@ -578,7 +578,7 @@ mod tests { v.append(&mut chunks_equal); v.append(&mut intermediate_values); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() }; let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index b7307470..845da5ab 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -443,7 +443,7 @@ mod tests { .take(gate.num_points() - 2) .flat_map(|ff| ff.0), ); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } // Get a working row for LowDegreeInterpolationGate. diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index 177db7cf..163a7dac 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -11,7 +11,6 @@ pub mod binary_arithmetic; pub mod binary_subtraction; pub mod comparison; pub mod constant; -// pub mod curve_double; pub mod exponentiation; pub mod gate; pub mod gate_tree; @@ -24,6 +23,7 @@ pub mod poseidon; pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; +pub mod range_check_u32; pub mod reducing; pub mod reducing_extension; pub mod subtraction_u32; diff --git a/plonky2/src/gates/multiplication_extension.rs b/plonky2/src/gates/multiplication_extension.rs index 9ccfe637..54629a47 100644 --- a/plonky2/src/gates/multiplication_extension.rs +++ b/plonky2/src/gates/multiplication_extension.rs @@ -125,7 +125,7 @@ impl, const D: usize> Gate for MulExtensionGa ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/random_access.rs b/plonky2/src/gates/random_access.rs index 77359a19..6379f99f 100644 --- a/plonky2/src/gates/random_access.rs +++ b/plonky2/src/gates/random_access.rs @@ -209,7 +209,7 @@ impl, const D: usize> Gate for RandomAccessGa ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/range_check_u32.rs b/plonky2/src/gates/range_check_u32.rs new file mode 100644 index 00000000..2533b51f --- /dev/null +++ b/plonky2/src/gates/range_check_u32.rs @@ -0,0 +1,305 @@ +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::gate::Gate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::util::ceil_div_usize; + +/// A gate which can decompose a number into base B little-endian limbs. +#[derive(Copy, Clone, Debug)] +pub struct U32RangeCheckGate, const D: usize> { + pub num_input_limbs: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32RangeCheckGate { + pub fn new(num_input_limbs: usize) -> Self { + Self { + num_input_limbs, + _phantom: PhantomData, + } + } + + pub const AUX_LIMB_BITS: usize = 3; + pub const BASE: usize = 1 << Self::AUX_LIMB_BITS; + + fn aux_limbs_per_input_limb(&self) -> usize { + ceil_div_usize(32, Self::AUX_LIMB_BITS) + } + pub fn wire_ith_input_limb(&self, i: usize) -> usize{ + debug_assert!(i < self.num_input_limbs); + i + } + pub fn wire_ith_input_limb_jth_aux_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_input_limbs); + debug_assert!(j < self.aux_limbs_per_input_limb()); + self.num_input_limbs + self.aux_limbs_per_input_limb() * i + j + } +} + +impl, const D: usize> Gate for U32RangeCheckGate{ + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = F::Extension::from_canonical_usize(Self::BASE); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()).map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]).collect(); + let computed_sum = reduce_with_powers(&aux_limbs, base); + + constraints.push(computed_sum - input_limb); + for aux_limb in aux_limbs { + constraints.push( + (0..Self::BASE) + .map(|i| aux_limb - F::Extension::from_canonical_usize(i)) + .product(), + ); + } + + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = F::from_canonical_usize(Self::BASE); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()).map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]).collect(); + let computed_sum = reduce_with_powers(&aux_limbs, base); + + constraints.push(computed_sum - input_limb); + for aux_limb in aux_limbs { + constraints.push( + (0..Self::BASE) + .map(|i| aux_limb - F::from_canonical_usize(i)) + .product(), + ); + } + + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = builder.constant(F::from_canonical_usize(Self::BASE)); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()).map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]).collect(); + let computed_sum = reduce_with_powers_ext_recursive(builder, &aux_limbs, base); + + constraints.push(builder.sub_extension(computed_sum, input_limb)); + for aux_limb in aux_limbs { + constraints.push({ + let mut acc = builder.one_extension(); + (0..Self::BASE).for_each(|i| { + // We update our accumulator as: + // acc' = acc (x - i) + // = acc x + (-i) acc + // Since -i is constant, we can do this in one arithmetic_extension call. + let neg_i = -F::from_canonical_usize(i); + acc = builder.arithmetic_extension(F::ONE, neg_i, acc, aux_limb, acc) + }); + acc + }); + } + + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = U32RangeCheckGenerator { + gate: self.clone(), + gate_index, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) + } + + fn num_constants(&self) -> usize { + 0 + } + + // Bounded by the range-check (x-0)*(x-1)*...*(x-BASE+1). + fn degree(&self) -> usize { + Self::BASE + } + + // 1 for checking the each sum of aux limbs, plus a range check for each aux limb. + fn num_constraints(&self) -> usize { + self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) + } +} + +#[derive(Debug)] +pub struct U32RangeCheckGenerator, const D: usize> { + gate: U32RangeCheckGate, + gate_index: usize, +} + +impl, const D: usize> SimpleGenerator for U32RangeCheckGenerator { + fn dependencies(&self) -> Vec { + let num_input_limbs = self.gate.num_input_limbs; + (0..num_input_limbs).map(|i| Target::wire(self.gate_index, self.gate.wire_ith_input_limb(i))).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let num_input_limbs = self.gate.num_input_limbs; + for i in 0..num_input_limbs { + let sum_value = witness + .get_target(Target::wire(self.gate_index, self.gate.wire_ith_input_limb(i))) + .to_canonical_u64() as u32; + + let base = U32RangeCheckGate::::BASE as u32; + let limbs = (0..self.gate.aux_limbs_per_input_limb()) + .map(|j| Target::wire(self.gate_index, self.gate.wire_ith_input_limb_jth_aux_limb(i, j))); + let limbs_value = (0..self.gate.aux_limbs_per_input_limb()) + .scan(sum_value, |acc, _| { + let tmp = *acc % base; + *acc /= base; + Some(F::from_canonical_u32(tmp)) + }) + .collect::>(); + + for (b, b_value) in limbs.zip(limbs_value) { + out_buffer.set_target(b, b_value); + } + } + + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use itertools::unfold; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::range_check_u32::U32RangeCheckGate; + use crate::hash::hash_types::HashOut; + use crate::plonk::vars::EvaluationVars; + use crate::util::ceil_div_usize; + + #[test] + fn low_degree() { + test_low_degree::(U32RangeCheckGate::new(8)) + } + + #[test] + fn eval_fns() -> Result<()> { + test_eval_fns::(U32RangeCheckGate::new(8)) + } + + fn test_gate_constraint(input_limbs: Vec) { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const AUX_LIMB_BITS: usize = 3; + const BASE: usize = 1 << AUX_LIMB_BITS; + const AUX_LIMBS_PER_INPUT_LIMB: usize = ceil_div_usize(32, AUX_LIMB_BITS); + + fn get_wires(input_limbs: Vec) -> Vec { + let num_input_limbs = input_limbs.len(); + let mut v = Vec::new(); + + for i in 0..num_input_limbs { + let input_limb = input_limbs[i]; + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % (BASE as u64); + val /= BASE as u64; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let mut aux_limbs: Vec<_> = + split_to_limbs(input_limb, AUX_LIMBS_PER_INPUT_LIMB).collect(); + + v.append(&mut aux_limbs); + } + + input_limbs.iter() + .cloned() + .map(F::from_canonical_u64) + .chain(v.iter().cloned()) + .map(|x| x.into()) + .collect() + } + + let gate = U32RangeCheckGate:: { + num_input_limbs: 8, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(input_limbs), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } + + #[test] + fn test_gate_constraint_good() { + let mut rng = rand::thread_rng(); + let input_limbs: Vec<_> = (0..8) + .map(|_| rng.gen::() as u64) + .collect(); + + test_gate_constraint(input_limbs); + } + + #[test] + #[should_panic] + fn test_gate_constraint_bad() { + let mut rng = rand::thread_rng(); + let input_limbs: Vec<_> = (0..8) + .map(|_| rng.gen()) + .collect(); + + test_gate_constraint(input_limbs); + } +} diff --git a/plonky2/src/gates/subtraction_u32.rs b/plonky2/src/gates/subtraction_u32.rs index fa817ce4..ffb2e2cb 100644 --- a/plonky2/src/gates/subtraction_u32.rs +++ b/plonky2/src/gates/subtraction_u32.rs @@ -419,7 +419,7 @@ mod tests { v0.iter() .chain(v1.iter()) .map(|&x| x.into()) - .collect::>() + .collect() } let mut rng = rand::thread_rng();