From 6ab01e51f3da79d8bcc524905674123d5f5c91a5 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 1 Nov 2021 16:02:46 -0700 Subject: [PATCH] u32 arithmetic check for special cases --- src/gadgets/arithmetic_u32.rs | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index ba076a8f..22957075 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -36,6 +36,37 @@ impl, const D: usize> CircuitBuilder { self.assert_zero(x.0) } + /// Checks for special cases where the value of + /// `x * y + z` + /// can be determined without adding a `U32ArithmeticGate`. + pub fn arithmetic_u32_special_cases( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> Option<(U32Target, U32Target)> { + let x_const = self.target_as_constant(x.0); + let y_const = self.target_as_constant(y.0); + let z_const = self.target_as_constant(z.0); + + // If both terms are constant, return their (constant) sum. + let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) { + Some(xx * yy) + } else { + None + }; + + if let (Some(a), Some(b)) = (first_term_const, z_const) { + let sum_u64 = (a + b).to_canonical_u64(); + let (low_u64, high_u64) = (sum_u64 % (1u64 << 32), sum_u64 >> 32); + let low = F::from_canonical_u64(low_u64); + let high = F::from_canonical_u64(high_u64); + return Some((self.constant_u32(low), self.constant_u32(high))); + } + + None + } + // Returns x * y + z. pub fn mul_add_u32( &mut self, @@ -43,6 +74,10 @@ impl, const D: usize> CircuitBuilder { y: U32Target, z: U32Target, ) -> (U32Target, U32Target) { + if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) { + return result; + } + let (gate_index, copy) = self.find_u32_arithmetic_gate(); self.connect(