From 3d207464f5a721c6dab5628e464d98b04d5b1546 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 29 Sep 2021 20:04:42 -0700 Subject: [PATCH] combine halves separately --- src/gates/arithmetic_u32.rs | 68 ++++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index baa52497..557a8500 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -81,7 +81,9 @@ impl, const D: usize> Gate for U32ArithmeticG constraints.push(combined_output - computed_output); - let mut combined_limbs = F::Extension::ZERO; + let mut combined_low_limbs = F::Extension::ZERO; + let mut combined_high_limbs = F::Extension::ZERO; + let midpoint = Self::num_limbs() / 2; for j in 0..Self::num_limbs() { let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); @@ -90,13 +92,18 @@ impl, const D: usize> Gate for U32ArithmeticG .product(); constraints.push(product); - let base = F::Extension::from_canonical_u64(1u64 << (j * Self::limb_bits())); - combined_limbs += base * this_limb; + if j < midpoint { + let base = F::Extension::from_canonical_u64(1u64 << (j * Self::limb_bits())); + combined_low_limbs += base * this_limb; + } else { + let base = F::Extension::from_canonical_u64( + 1u64 << ((j - midpoint) * Self::limb_bits()), + ); + combined_high_limbs += base * this_limb; + } } - - let combined_halves = - output_low + F::Extension::from_canonical_u64(1 << 32u64) * output_high; - constraints.push(combined_limbs - combined_halves); + constraints.push(combined_low_limbs - output_low); + constraints.push(combined_high_limbs - output_high); } constraints @@ -119,7 +126,9 @@ impl, const D: usize> Gate for U32ArithmeticG constraints.push(combined_output - computed_output); - let mut combined_limbs = F::ZERO; + let mut combined_low_limbs = F::ZERO; + let mut combined_high_limbs = F::ZERO; + let midpoint = Self::num_limbs() / 2; for j in 0..Self::num_limbs() { let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); @@ -128,12 +137,16 @@ impl, const D: usize> Gate for U32ArithmeticG .product(); constraints.push(product); - let base = F::from_canonical_u64(1u64 << (j * Self::limb_bits())); - combined_limbs += base * this_limb; + if j < midpoint { + let base = F::from_canonical_u64(1u64 << (j * Self::limb_bits())); + combined_low_limbs += base * this_limb; + } else { + let base = F::from_canonical_u64(1u64 << ((j - midpoint) * Self::limb_bits())); + combined_high_limbs += base * this_limb; + } } - - let combined_halves = output_low + F::from_canonical_u64(1 << 32u64) * output_high; - constraints.push(combined_limbs - combined_halves); + constraints.push(combined_low_limbs - output_low); + constraints.push(combined_high_limbs - output_high); } constraints @@ -162,7 +175,9 @@ impl, const D: usize> Gate for U32ArithmeticG constraints.push(builder.sub_extension(combined_output, computed_output)); - let mut combined_limbs = builder.zero_extension(); + let mut combined_low_limbs = builder.zero_extension(); + let mut combined_high_limbs = builder.zero_extension(); + let midpoint = Self::num_limbs() / 2; for j in 0..Self::num_limbs() { let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); @@ -176,16 +191,23 @@ impl, const D: usize> Gate for U32ArithmeticG } constraints.push(product); - let base = builder.constant_extension(F::Extension::from_canonical_u64( - 1u64 << (j * Self::limb_bits()), - )); - combined_limbs = builder.mul_add_extension(base, this_limb, combined_limbs); + if j < midpoint { + let base = builder.constant_extension(F::Extension::from_canonical_u64( + 1u64 << (j * Self::limb_bits()), + )); + combined_low_limbs = + builder.mul_add_extension(base, this_limb, combined_low_limbs); + } else { + let base = builder.constant_extension(F::Extension::from_canonical_u64( + 1u64 << ((j - midpoint) * Self::limb_bits()), + )); + combined_high_limbs = + builder.mul_add_extension(base, this_limb, combined_high_limbs); + } } - let high_base = - builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64)); - let combined_halves = builder.mul_add_extension(output_high, high_base, output_low); - constraints.push(builder.sub_extension(combined_limbs, combined_halves)); + constraints.push(builder.sub_extension(combined_low_limbs, output_low)); + constraints.push(builder.sub_extension(combined_high_limbs, output_high)); } constraints @@ -224,7 +246,7 @@ impl, const D: usize> Gate for U32ArithmeticG } fn num_constraints(&self) -> usize { - NUM_U32_ARITHMETIC_OPS * (2 + Self::num_limbs()) + NUM_U32_ARITHMETIC_OPS * (3 + Self::num_limbs()) } }