diff --git a/u32/src/gates/arithmetic_u32.rs b/u32/src/gates/arithmetic_u32.rs index c46c9e47..d2a3860e 100644 --- a/u32/src/gates/arithmetic_u32.rs +++ b/u32/src/gates/arithmetic_u32.rs @@ -347,8 +347,10 @@ mod tests { use plonky2::gates::gate::Gate; use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; use plonky2::hash::hash_types::HashOut; + use plonky2::hash::hash_types::RichField; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::plonk::vars::EvaluationVars; + use plonky2_field::extension::Extendable; use plonky2_field::goldilocks_field::GoldilocksField; use plonky2_field::types::Field; use rand::Rng; @@ -374,6 +376,52 @@ mod tests { }) } + fn get_wires< + F: RichField + Extendable, + FF: From, + const D: usize, + const NUM_U32_ARITHMETIC_OPS: usize, + >( + multiplicands_0: Vec, + multiplicands_1: Vec, + addends: Vec, + ) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let limb_bits = U32ArithmeticGate::::limb_bits(); + let num_limbs = U32ArithmeticGate::::num_limbs(); + let limb_base = 1 << limb_bits; + for c in 0..NUM_U32_ARITHMETIC_OPS { + let m0 = multiplicands_0[c]; + let m1 = multiplicands_1[c]; + let a = addends[c]; + + let mut output = m0 * m1 + a; + let output_low = output & ((1 << 32) - 1); + let output_high = output >> 32; + + let mut output_limbs = Vec::with_capacity(num_limbs); + for _i in 0..num_limbs { + output_limbs.push(output % limb_base); + output /= limb_base; + } + let mut output_limbs_f: Vec<_> = output_limbs + .into_iter() + .map(F::from_canonical_u64) + .collect(); + + v0.push(F::from_canonical_u64(m0)); + v0.push(F::from_canonical_u64(m1)); + v0.push(F::from_noncanonical_u64(a)); + v0.push(F::from_canonical_u64(output_low)); + v0.push(F::from_canonical_u64(output_high)); + v1.append(&mut output_limbs_f); + } + + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() + } + #[test] fn test_gate_constraint() { const D: usize = 2; @@ -382,47 +430,6 @@ mod tests { type FF = >::FE; const NUM_U32_ARITHMETIC_OPS: usize = 3; - fn get_wires( - multiplicands_0: Vec, - multiplicands_1: Vec, - addends: Vec, - ) -> Vec { - let mut v0 = Vec::new(); - let mut v1 = Vec::new(); - - let limb_bits = U32ArithmeticGate::::limb_bits(); - let num_limbs = U32ArithmeticGate::::num_limbs(); - let limb_base = 1 << limb_bits; - for c in 0..NUM_U32_ARITHMETIC_OPS { - let m0 = multiplicands_0[c]; - let m1 = multiplicands_1[c]; - let a = addends[c]; - - let mut output = m0 * m1 + a; - let output_low = output & ((1 << 32) - 1); - let output_high = output >> 32; - - let mut output_limbs = Vec::with_capacity(num_limbs); - for _i in 0..num_limbs { - output_limbs.push(output % limb_base); - output /= limb_base; - } - let mut output_limbs_f: Vec<_> = output_limbs - .into_iter() - .map(F::from_canonical_u64) - .collect(); - - v0.push(F::from_canonical_u64(m0)); - v0.push(F::from_canonical_u64(m1)); - v0.push(F::from_canonical_u64(a)); - v0.push(F::from_canonical_u64(output_low)); - v0.push(F::from_canonical_u64(output_high)); - v1.append(&mut output_limbs_f); - } - - v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() - } - let mut rng = rand::thread_rng(); let multiplicands_0: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS) .map(|_| rng.gen::() as u64) @@ -441,7 +448,11 @@ mod tests { let vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(multiplicands_0, multiplicands_1, addends), + local_wires: &get_wires::( + multiplicands_0, + multiplicands_1, + addends, + ), public_inputs_hash: &HashOut::rand(), };