From a38a5e227d7ccb98f46497a766edd9052eb7de56 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 27 Jul 2021 22:51:40 -0700 Subject: [PATCH] select_ext takes bit as extension; used in recursive eval --- src/gadgets/select.rs | 16 ++++++++-------- src/gates/exponentiation.rs | 6 ++---- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs index bbd36d76..9df395d8 100644 --- a/src/gadgets/select.rs +++ b/src/gadgets/select.rs @@ -10,24 +10,24 @@ impl, const D: usize> CircuitBuilder { /// Note: This does not range-check `b`. pub fn select_ext( &mut self, - b: Target, + b: ExtensionTarget, x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { - let b_ext = self.convert_to_ext(b); let gate = self.num_gates(); // Holds `by - y`. let first_out = ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); - self.double_arithmetic_extension(F::ONE, F::NEG_ONE, b_ext, y, y, b_ext, x, first_out) + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, b, y, y, b, x, first_out) .1 } /// See `select_ext`. pub fn select(&mut self, b: Target, x: Target, y: Target) -> Target { + let b_ext = self.convert_to_ext(b); let x_ext = self.convert_to_ext(x); let y_ext = self.convert_to_ext(y); - self.select_ext(b, x_ext, y_ext).to_target_array()[0] + self.select_ext(b_ext, x_ext, y_ext).to_target_array()[0] } } @@ -54,13 +54,13 @@ mod tests { let (x, y) = (FF::rand(), FF::rand()); let xt = builder.add_virtual_extension_target(); let yt = builder.add_virtual_extension_target(); - let truet = builder.add_virtual_target(); - let falset = builder.add_virtual_target(); + let truet = builder.add_virtual_extension_target(); + let falset = builder.add_virtual_extension_target(); pw.set_extension_target(xt, x); pw.set_extension_target(yt, y); - pw.set_target(truet, F::ONE); - pw.set_target(falset, F::ZERO); + pw.set_extension_target(truet, FF::ONE); + pw.set_extension_target(falset, FF::ZERO); let should_be_x = builder.select_ext(truet, xt, yt); let should_be_y = builder.select_ext(falset, xt, yt); diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index d7b99166..db051bcb 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -171,9 +171,7 @@ impl, const D: usize> Gate for ExponentiationGate { // power_bits is in LE order, but we accumulate in BE order. let cur_bit = power_bits[self.num_power_bits - i - 1]; - - let not_cur_bit = builder.sub_extension(one, cur_bit); - let mul_by = builder.mul_add_extension(cur_bit, base, not_cur_bit); + let mul_by = builder.select_ext(cur_bit, base, one); let computed_intermediate_value = builder.mul_extension(prev_intermediate_value, mul_by); let intermediate_value_diff = @@ -261,7 +259,7 @@ impl, const D: usize> SimpleGenerator for ExponentiationGene current_intermediate_value *= current_intermediate_value; } - let mut result = GeneratedValues::::with_capacity(num_power_bits); + let mut result = GeneratedValues::::with_capacity(num_power_bits + 1); for i in 0..num_power_bits { let intermediate_value_wire = local_wire(self.gate.wires_intermediate_value(i)); result.set_wire(intermediate_value_wire, intermediate_values[i]);