diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index d1e6cdf0..dd0a8e9c 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -44,9 +44,13 @@ impl, const D: usize> ExponentiationGate { 2 + i } + pub fn wires_output(&self) -> usize { + 2 + self.num_power_bits + } + pub fn wires_intermediate_value(&self, i: usize) -> usize { debug_assert!(i < self.num_power_bits); - 2 + self.num_power_bits + i + 3 + self.num_power_bits + i } } @@ -58,7 +62,6 @@ impl, const D: usize> Gate for ExponentiationGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let base = vars.local_wires[self.wires_base()]; let power = vars.local_wires[self.wires_power()]; - let computed_output = base.exp(power.to_canonical_u64()); let power_bits: Vec<_> = (0..self.num_power_bits) .map(|i| vars.local_wires[self.wires_power_bit(i)]) @@ -67,6 +70,8 @@ impl, const D: usize> Gate for ExponentiationGate { .map(|i| vars.local_wires[self.wires_intermediate_value(i)]) .collect(); + let output = vars.local_wires[self.wires_output()]; + let mut constraints = Vec::new(); let computed_power = reduce_with_powers(&power_bits, F::Extension::TWO); @@ -88,6 +93,8 @@ impl, const D: usize> Gate for ExponentiationGate { constraints.push(computed_intermediate_value - intermediate_values[i]); } + constraints.push(output - intermediate_values[self.num_power_bits - 1]); + constraints } @@ -124,7 +131,7 @@ impl, const D: usize> Gate for ExponentiationGate { } fn num_constraints(&self) -> usize { - self.num_power_bits + 1 + self.num_power_bits + 2 } } @@ -157,11 +164,12 @@ impl, const D: usize> SimpleGenerator for ExponentiationGene let num_power_bits = self.gate.num_power_bits; let base = get_local_wire(self.gate.wires_base()); + let power_bits = (0..num_power_bits) .map(|i| get_local_wire(self.gate.wires_power_bit(i))) .collect::>(); - let mut intermediate_values = Vec::new(); + let mut current_intermediate_value = F::ONE; for i in 0..num_power_bits { if power_bits[i] == F::ONE { @@ -177,6 +185,9 @@ impl, const D: usize> SimpleGenerator for ExponentiationGene result.set_wire(intermediate_value_wire, intermediate_values[i]); } + let output_wire = local_wire(self.gate.wires_output()); + result.set_wire(output_wire, intermediate_values[num_power_bits - 1]); + result } } @@ -208,8 +219,9 @@ mod tests { assert_eq!(gate.wires_power(), 1); assert_eq!(gate.wires_power_bit(0), 2); assert_eq!(gate.wires_power_bit(4), 6); - assert_eq!(gate.wires_intermediate_value(0), 7); - assert_eq!(gate.wires_intermediate_value(4), 11); + assert_eq!(gate.wires_output(), 7); + assert_eq!(gate.wires_intermediate_value(0), 8); + assert_eq!(gate.wires_intermediate_value(4), 12); } #[test] @@ -255,6 +267,8 @@ mod tests { intermediate_values.push(current_intermediate_value); current_intermediate_value *= current_intermediate_value; } + let output_value = intermediate_values[num_power_bits - 1]; + v.push(output_value); v.extend(intermediate_values); v.iter().map(|&x| x.into()).collect::>()