From eb5a60bef110d1a0b2fd731ea9a4dac6e780ac63 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 16 Nov 2021 09:29:14 -0800 Subject: [PATCH] Allow one BaseSumGate to handle 64 bits (#365) --- src/gadgets/split_join.rs | 23 +++++++++++++++-------- src/gates/base_sum.rs | 3 +-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 39527c6a..72786bd8 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -24,8 +24,7 @@ impl, const D: usize> CircuitBuilder { let mut bits = Vec::with_capacity(num_bits); for &gate in &gates { - let start_limbs = BaseSumGate::<2>::START_LIMBS; - for limb_input in start_limbs..start_limbs + gate_type.num_limbs { + for limb_input in gate_type.limbs() { // `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`. bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input))); } @@ -35,10 +34,11 @@ impl, const D: usize> CircuitBuilder { } let zero = self.zero(); + let base = F::TWO.exp_u64(gate_type.num_limbs as u64); let mut acc = zero; for &gate in gates.iter().rev() { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); - acc = self.mul_const_add(F::from_canonical_usize(1 << gate_type.num_limbs), acc, sum); + acc = self.mul_const_add(base, acc, sum); } self.connect(acc, integer); @@ -96,11 +96,18 @@ impl SimpleGenerator for WireSplitGenerator { for &gate in &self.gates { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); - out_buffer.set_target( - sum, - F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)), - ); - integer_value >>= self.num_limbs; + + // If num_limbs >= 64, we don't need to truncate since `integer_value` is already + // limited to 64 bits, and trying to do so would cause overflow. Hence the conditional. + let mut truncated_value = integer_value; + if self.num_limbs < 64 { + truncated_value = integer_value & ((1 << self.num_limbs) - 1); + integer_value >>= self.num_limbs; + } else { + integer_value = 0; + }; + + out_buffer.set_target(sum, F::from_canonical_u64(truncated_value)); } debug_assert_eq!( diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 99ee05eb..2ab5345b 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -24,8 +24,7 @@ impl BaseSumGate { } pub fn new_from_config(config: &CircuitConfig) -> Self { - let num_limbs = ((F::ORDER as f64).log(B as f64).floor() as usize) - .min(config.num_routed_wires - Self::START_LIMBS); + let num_limbs = F::bits().min(config.num_routed_wires - Self::START_LIMBS); Self::new(num_limbs) }