diff --git a/u32/src/gates/arithmetic_u32.rs b/u32/src/gates/arithmetic_u32.rs index c46c9e47..c05ed86c 100644 --- a/u32/src/gates/arithmetic_u32.rs +++ b/u32/src/gates/arithmetic_u32.rs @@ -36,31 +36,36 @@ impl, const D: usize> U32ArithmeticGate { } pub(crate) fn num_ops(config: &CircuitConfig) -> usize { - let wires_per_op = 5 + Self::num_limbs(); - let routed_wires_per_op = 5; - (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + let wires_per_op = Self::routed_wires_per_op() + Self::num_limbs(); + (config.num_wires / wires_per_op).min(config.num_routed_wires / Self::routed_wires_per_op()) } pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize { debug_assert!(i < self.num_ops); - 5 * i + Self::routed_wires_per_op() * i } pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize { debug_assert!(i < self.num_ops); - 5 * i + 1 + Self::routed_wires_per_op() * i + 1 } pub fn wire_ith_addend(&self, i: usize) -> usize { debug_assert!(i < self.num_ops); - 5 * i + 2 + Self::routed_wires_per_op() * i + 2 } pub fn wire_ith_output_low_half(&self, i: usize) -> usize { debug_assert!(i < self.num_ops); - 5 * i + 3 + Self::routed_wires_per_op() * i + 3 } + pub fn wire_ith_output_high_half(&self, i: usize) -> usize { debug_assert!(i < self.num_ops); - 5 * i + 4 + Self::routed_wires_per_op() * i + 4 + } + + pub fn wire_ith_inverse(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + Self::routed_wires_per_op() * i + 5 } pub fn limb_bits() -> usize { @@ -69,11 +74,13 @@ impl, const D: usize> U32ArithmeticGate { pub fn num_limbs() -> usize { 64 / Self::limb_bits() } - + pub fn routed_wires_per_op() -> usize { + 6 + } pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { debug_assert!(i < self.num_ops); debug_assert!(j < Self::num_limbs()); - 5 * self.num_ops + Self::num_limbs() * i + j + Self::routed_wires_per_op() * self.num_ops + Self::num_limbs() * i + j } } @@ -93,9 +100,28 @@ impl, const D: usize> Gate for U32ArithmeticG let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; + let inverse = vars.local_wires[self.wire_ith_inverse(i)]; - let base = F::Extension::from_canonical_u64(1 << 32u64); - let combined_output = output_high * base + output_low; + // Check canonicity of combined_output = output_high * 2^32 + output_low + let combined_output = { + let base = F::Extension::from_canonical_u64(1 << 32u64); + let one = F::Extension::ONE; + let u32_max = F::Extension::from_canonical_u32(u32::MAX); + + // This is zero if and only if the high limb is `u32::MAX`. + // u32::MAX - output_high + let diff = u32_max - output_high; + // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. + // inverse * diff - 1 + let hi_not_max = inverse * diff - one; + // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. + // hi_not_max * limb_0_u32 + let hi_not_max_or_lo_zero = hi_not_max * output_low; + + constraints.push(hi_not_max_or_lo_zero); + + output_high * base + output_low + }; constraints.push(combined_output - computed_output); @@ -152,10 +178,27 @@ impl, const D: usize> Gate for U32ArithmeticG let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; + let inverse = vars.local_wires[self.wire_ith_inverse(i)]; - let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); - let base_target = builder.constant_extension(base); - let combined_output = builder.mul_add_extension(output_high, base_target, output_low); + // Check canonicity of combined_output = output_high * 2^32 + output_low + let combined_output = { + let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); + let base_target = builder.constant_extension(base); + let one = builder.one_extension(); + let u32_max = + builder.constant_extension(F::Extension::from_canonical_u32(u32::MAX)); + + // This is zero if and only if the high limb is `u32::MAX`. + let diff = builder.sub_extension(u32_max, output_high); + // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. + let hi_not_max = builder.mul_sub_extension(inverse, diff, one); + // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. + let hi_not_max_or_lo_zero = builder.mul_extension(hi_not_max, output_low); + + constraints.push(hi_not_max_or_lo_zero); + + builder.mul_add_extension(output_high, base_target, output_low) + }; constraints.push(builder.sub_extension(combined_output, computed_output)); @@ -211,7 +254,7 @@ impl, const D: usize> Gate for U32ArithmeticG } fn num_wires(&self) -> usize { - self.num_ops * (5 + Self::num_limbs()) + self.num_ops * (Self::routed_wires_per_op() + Self::num_limbs()) } fn num_constants(&self) -> usize { @@ -223,7 +266,7 @@ impl, const D: usize> Gate for U32ArithmeticG } fn num_constraints(&self) -> usize { - self.num_ops * (3 + Self::num_limbs()) + self.num_ops * (4 + Self::num_limbs()) } } @@ -244,9 +287,27 @@ impl, const D: usize> PackedEvaluableBase let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; + let inverse = vars.local_wires[self.wire_ith_inverse(i)]; - let base = F::from_canonical_u64(1 << 32u64); - let combined_output = output_high * base + output_low; + let combined_output = { + let base = P::from(F::from_canonical_u64(1 << 32u64)); + let one = P::ONES; + let u32_max = P::from(F::from_canonical_u32(u32::MAX)); + + // This is zero if and only if the high limb is `u32::MAX`. + // u32::MAX - output_high + let diff = u32_max - output_high; + // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. + // inverse * diff - 1 + let hi_not_max = inverse * diff - one; + // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. + // hi_not_max * limb_0_u32 + let hi_not_max_or_lo_zero = hi_not_max * output_low; + + yield_constr.one(hi_not_max_or_lo_zero); + + output_high * base + output_low + }; yield_constr.one(combined_output - computed_output); @@ -322,6 +383,15 @@ impl, const D: usize> SimpleGenerator out_buffer.set_wire(output_high_wire, output_high); out_buffer.set_wire(output_low_wire, output_low); + let diff = u32::MAX as u64 - output_high_u64; + let inverse = if diff == 0 { + F::ZERO + } else { + F::from_canonical_u64(diff).inverse() + }; + let inverse_wire = local_wire(self.gate.wire_ith_inverse(self.i)); + out_buffer.set_wire(inverse_wire, inverse); + let num_limbs = U32ArithmeticGate::::num_limbs(); let limb_base = 1 << U32ArithmeticGate::::limb_bits(); let output_limbs_u64 = unfold((), move |_| { @@ -347,8 +417,10 @@ mod tests { use plonky2::gates::gate::Gate; use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; use plonky2::hash::hash_types::HashOut; + use plonky2::hash::hash_types::RichField; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::plonk::vars::EvaluationVars; + use plonky2_field::extension::Extendable; use plonky2_field::goldilocks_field::GoldilocksField; use plonky2_field::types::Field; use rand::Rng; @@ -374,6 +446,59 @@ mod tests { }) } + fn get_wires< + F: RichField + Extendable, + FF: From, + const D: usize, + const NUM_U32_ARITHMETIC_OPS: usize, + >( + multiplicands_0: Vec, + multiplicands_1: Vec, + addends: Vec, + ) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let limb_bits = U32ArithmeticGate::::limb_bits(); + let num_limbs = U32ArithmeticGate::::num_limbs(); + let limb_base = 1 << limb_bits; + for c in 0..NUM_U32_ARITHMETIC_OPS { + let m0 = multiplicands_0[c]; + let m1 = multiplicands_1[c]; + let a = addends[c]; + + let mut output = m0 * m1 + a; + let output_low = output & ((1 << 32) - 1); + let output_high = output >> 32; + let diff = u32::MAX as u64 - output_high; + let inverse = if diff == 0 { + F::ZERO + } else { + F::from_canonical_u64(diff).inverse() + }; + + let mut output_limbs = Vec::with_capacity(num_limbs); + for _i in 0..num_limbs { + output_limbs.push(output % limb_base); + output /= limb_base; + } + let mut output_limbs_f: Vec<_> = output_limbs + .into_iter() + .map(F::from_canonical_u64) + .collect(); + + v0.push(F::from_canonical_u64(m0)); + v0.push(F::from_canonical_u64(m1)); + v0.push(F::from_noncanonical_u64(a)); + v0.push(F::from_canonical_u64(output_low)); + v0.push(F::from_canonical_u64(output_high)); + v0.push(inverse); + v1.append(&mut output_limbs_f); + } + + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() + } + #[test] fn test_gate_constraint() { const D: usize = 2; @@ -382,47 +507,6 @@ mod tests { type FF = >::FE; const NUM_U32_ARITHMETIC_OPS: usize = 3; - fn get_wires( - multiplicands_0: Vec, - multiplicands_1: Vec, - addends: Vec, - ) -> Vec { - let mut v0 = Vec::new(); - let mut v1 = Vec::new(); - - let limb_bits = U32ArithmeticGate::::limb_bits(); - let num_limbs = U32ArithmeticGate::::num_limbs(); - let limb_base = 1 << limb_bits; - for c in 0..NUM_U32_ARITHMETIC_OPS { - let m0 = multiplicands_0[c]; - let m1 = multiplicands_1[c]; - let a = addends[c]; - - let mut output = m0 * m1 + a; - let output_low = output & ((1 << 32) - 1); - let output_high = output >> 32; - - let mut output_limbs = Vec::with_capacity(num_limbs); - for _i in 0..num_limbs { - output_limbs.push(output % limb_base); - output /= limb_base; - } - let mut output_limbs_f: Vec<_> = output_limbs - .into_iter() - .map(F::from_canonical_u64) - .collect(); - - v0.push(F::from_canonical_u64(m0)); - v0.push(F::from_canonical_u64(m1)); - v0.push(F::from_canonical_u64(a)); - v0.push(F::from_canonical_u64(output_low)); - v0.push(F::from_canonical_u64(output_high)); - v1.append(&mut output_limbs_f); - } - - v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() - } - let mut rng = rand::thread_rng(); let multiplicands_0: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS) .map(|_| rng.gen::() as u64) @@ -441,7 +525,11 @@ mod tests { let vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(multiplicands_0, multiplicands_1, addends), + local_wires: &get_wires::( + multiplicands_0, + multiplicands_1, + addends, + ), public_inputs_hash: &HashOut::rand(), }; @@ -450,4 +538,39 @@ mod tests { "Gate constraints are not satisfied." ); } + + #[test] + fn test_canonicity() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + const NUM_U32_ARITHMETIC_OPS: usize = 3; + + let multiplicands_0 = vec![0; NUM_U32_ARITHMETIC_OPS]; + let multiplicands_1 = vec![0; NUM_U32_ARITHMETIC_OPS]; + // A non-canonical addend will produce a non-canonical output using + // get_wires. + let addends = vec![0xFFFFFFFF00000001; NUM_U32_ARITHMETIC_OPS]; + + let gate = U32ArithmeticGate:: { + num_ops: NUM_U32_ARITHMETIC_OPS, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires::( + multiplicands_0, + multiplicands_1, + addends, + ), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + !gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Non-canonical output should not pass constraints." + ); + } }