Allow one BaseSumGate to handle 64 bits (#365)

This commit is contained in:
Daniel Lubarov 2021-11-16 09:29:14 -08:00 committed by GitHub
parent 1e66cb9aee
commit eb5a60bef1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 10 deletions

View File

@ -24,8 +24,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut bits = Vec::with_capacity(num_bits); let mut bits = Vec::with_capacity(num_bits);
for &gate in &gates { for &gate in &gates {
let start_limbs = BaseSumGate::<2>::START_LIMBS; for limb_input in gate_type.limbs() {
for limb_input in start_limbs..start_limbs + gate_type.num_limbs {
// `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`. // `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`.
bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input))); bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input)));
} }
@ -35,10 +34,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
let zero = self.zero(); let zero = self.zero();
let base = F::TWO.exp_u64(gate_type.num_limbs as u64);
let mut acc = zero; let mut acc = zero;
for &gate in gates.iter().rev() { for &gate in gates.iter().rev() {
let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
acc = self.mul_const_add(F::from_canonical_usize(1 << gate_type.num_limbs), acc, sum); acc = self.mul_const_add(base, acc, sum);
} }
self.connect(acc, integer); self.connect(acc, integer);
@ -96,11 +96,18 @@ impl<F: RichField> SimpleGenerator<F> for WireSplitGenerator {
for &gate in &self.gates { for &gate in &self.gates {
let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
out_buffer.set_target(
sum, // If num_limbs >= 64, we don't need to truncate since `integer_value` is already
F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)), // limited to 64 bits, and trying to do so would cause overflow. Hence the conditional.
); let mut truncated_value = integer_value;
integer_value >>= self.num_limbs; if self.num_limbs < 64 {
truncated_value = integer_value & ((1 << self.num_limbs) - 1);
integer_value >>= self.num_limbs;
} else {
integer_value = 0;
};
out_buffer.set_target(sum, F::from_canonical_u64(truncated_value));
} }
debug_assert_eq!( debug_assert_eq!(

View File

@ -24,8 +24,7 @@ impl<const B: usize> BaseSumGate<B> {
} }
pub fn new_from_config<F: PrimeField>(config: &CircuitConfig) -> Self { pub fn new_from_config<F: PrimeField>(config: &CircuitConfig) -> Self {
let num_limbs = ((F::ORDER as f64).log(B as f64).floor() as usize) let num_limbs = F::bits().min(config.num_routed_wires - Self::START_LIMBS);
.min(config.num_routed_wires - Self::START_LIMBS);
Self::new(num_limbs) Self::new(num_limbs)
} }