From ccc9c024a21a7cbc1f010c067be6aacb6fee28a1 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 3 Jun 2022 18:06:14 +0200 Subject: [PATCH] Change some fn to take iterators instead of slices --- evm/src/permutation.rs | 4 +-- plonky2/src/gadgets/arithmetic.rs | 9 ++++-- plonky2/src/gadgets/arithmetic_extension.rs | 36 +++++++++++++-------- plonky2/src/gates/arithmetic_base.rs | 2 +- plonky2/src/gates/gate.rs | 2 +- starky/src/permutation.rs | 4 +-- system_zero/src/alu/addition.rs | 2 +- system_zero/src/alu/bitops.rs | 4 +-- system_zero/src/permutation_unit.rs | 6 ++-- 9 files changed, 42 insertions(+), 27 deletions(-) diff --git a/evm/src/permutation.rs b/evm/src/permutation.rs index 35bd92e5..cb690064 100644 --- a/evm/src/permutation.rs +++ b/evm/src/permutation.rs @@ -421,8 +421,8 @@ pub(crate) fn eval_permutation_checks_circuit( ) }) .unzip(); - let reduced_lhs_product = builder.mul_many_extension(&reduced_lhs); - let reduced_rhs_product = builder.mul_many_extension(&reduced_rhs); + let reduced_lhs_product = builder.mul_many_extension(reduced_lhs); + let reduced_rhs_product = builder.mul_many_extension(reduced_rhs); // constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product let constraint = { let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product); diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 4306a3a9..fdfce6aa 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -188,8 +188,13 @@ impl, const D: usize> CircuitBuilder { } /// Add `n` `Target`s. - pub fn add_many(&mut self, terms: &[Target]) -> Target { - terms.iter().fold(self.zero(), |acc, &t| self.add(acc, t)) + pub fn add_many(&mut self, terms: impl IntoIterator) -> Target + where + T: Borrow, + { + terms + .into_iter() + .fold(self.zero(), |acc, t| self.add(acc, *t.borrow())) } /// Computes `x - y`. diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index ea3e8b13..83054a79 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -1,3 +1,5 @@ +use std::borrow::Borrow; + use plonky2_field::extension_field::FieldExtension; use plonky2_field::extension_field::{Extendable, OEF}; use plonky2_field::field_types::{Field, Field64}; @@ -204,12 +206,16 @@ impl, const D: usize> CircuitBuilder { } /// Add `n` `ExtensionTarget`s. - pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - let mut sum = self.zero_extension(); - for &term in terms { - sum = self.add_extension(sum, term); - } - sum + pub fn add_many_extension( + &mut self, + terms: impl IntoIterator, + ) -> ExtensionTarget + where + T: Borrow>, + { + terms.into_iter().fold(self.zero_extension(), |acc, t| { + self.add_extension(acc, *t.borrow()) + }) } pub fn sub_extension( @@ -257,7 +263,7 @@ impl, const D: usize> CircuitBuilder { /// Computes `x^3`. pub fn cube_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { - self.mul_many_extension(&[x, x, x]) + self.mul_many_extension([x, x, x]) } /// Returns `a * b + c`. @@ -301,12 +307,16 @@ impl, const D: usize> CircuitBuilder { } /// Multiply `n` `ExtensionTarget`s. - pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { - terms - .iter() - .copied() - .reduce(|acc, t| self.mul_extension(acc, t)) - .unwrap_or_else(|| self.one_extension()) + pub fn mul_many_extension( + &mut self, + terms: impl IntoIterator, + ) -> ExtensionTarget + where + T: Borrow>, + { + terms.into_iter().fold(self.one_extension(), |acc, t| { + self.mul_extension(acc, *t.borrow()) + }) } /// Like `mul_add`, but for `ExtensionTarget`s. diff --git a/plonky2/src/gates/arithmetic_base.rs b/plonky2/src/gates/arithmetic_base.rs index 9aa3c51a..86338ecf 100644 --- a/plonky2/src/gates/arithmetic_base.rs +++ b/plonky2/src/gates/arithmetic_base.rs @@ -102,7 +102,7 @@ impl, const D: usize> Gate for ArithmeticGate let output = vars.local_wires[Self::wire_ith_output(i)]; let computed_output = { let scaled_mul = - builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); + builder.mul_many_extension([const_0, multiplicand_0, multiplicand_1]); builder.mul_add_extension(const_1, addend, scaled_mul) }; diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 02811bf5..56026103 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -265,5 +265,5 @@ fn compute_filter_circuit, const D: usize>( builder.sub_extension(c, s) }) .collect::>(); - builder.mul_many_extension(&v) + builder.mul_many_extension(v) } diff --git a/starky/src/permutation.rs b/starky/src/permutation.rs index c96e8ae7..6ee9ccb4 100644 --- a/starky/src/permutation.rs +++ b/starky/src/permutation.rs @@ -385,8 +385,8 @@ pub(crate) fn eval_permutation_checks_circuit( ) }) .unzip(); - let reduced_lhs_product = builder.mul_many_extension(&reduced_lhs); - let reduced_rhs_product = builder.mul_many_extension(&reduced_rhs); + let reduced_lhs_product = builder.mul_many_extension(reduced_lhs); + let reduced_rhs_product = builder.mul_many_extension(reduced_rhs); // constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product let constraint = { let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product); diff --git a/system_zero/src/alu/addition.rs b/system_zero/src/alu/addition.rs index a2ddeea4..5ce6bfb3 100644 --- a/system_zero/src/alu/addition.rs +++ b/system_zero/src/alu/addition.rs @@ -62,7 +62,7 @@ pub(crate) fn eval_addition_circuit, const D: usize // this sum can be around 48 bits at most. let out = reduce_with_powers_ext_circuit(builder, &[out_1, out_2, out_3], limb_base); - let computed_out = builder.add_many_extension(&[in_1, in_2, in_3]); + let computed_out = builder.add_many_extension([in_1, in_2, in_3]); let diff = builder.sub_extension(out, computed_out); let filtered_diff = builder.mul_extension(is_add, diff); diff --git a/system_zero/src/alu/bitops.rs b/system_zero/src/alu/bitops.rs index 14501303..fb5effcd 100644 --- a/system_zero/src/alu/bitops.rs +++ b/system_zero/src/alu/bitops.rs @@ -163,7 +163,7 @@ fn eval_bitop_32_circuit, const D: usize>( let b_bits = input_b_regs.map(|r| lv[r]); // Ensure that the inputs are bits - let inst_constr = builder.add_many_extension(&[is_and, is_ior, is_xor, is_andnot]); + let inst_constr = builder.add_many_extension([is_and, is_ior, is_xor, is_andnot]); constrain_all_to_bits_circuit(builder, yield_constr, inst_constr, &a_bits); constrain_all_to_bits_circuit(builder, yield_constr, inst_constr, &b_bits); @@ -204,7 +204,7 @@ fn eval_bitop_32_circuit, const D: usize>( builder.mul_extension(t1, is_andnot) }; - let constr = builder.add_many_extension(&[and_constr, ior_constr, xor_constr, andnot_constr]); + let constr = builder.add_many_extension([and_constr, ior_constr, xor_constr, andnot_constr]); yield_constr.constraint(builder, constr); } diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs index e2de7d0b..430481da 100644 --- a/system_zero/src/permutation_unit.rs +++ b/system_zero/src/permutation_unit.rs @@ -195,7 +195,7 @@ pub(crate) fn eval_permutation_unit_circuit, const builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); yield_constr.constraint(builder, diff); let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; - state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + state[i] = builder.mul_many_extension([state[i], state_cubed, state_cubed]); // Form state ** 7. } @@ -216,7 +216,7 @@ pub(crate) fn eval_permutation_unit_circuit, const let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); yield_constr.constraint(builder, diff); let state0_cubed = local_values[col_partial_mid_sbox(r)]; - state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. + state[0] = builder.mul_many_extension([state[0], state0_cubed, state0_cubed]); // Form state ** 7. let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); yield_constr.constraint(builder, diff); state[0] = local_values[col_partial_after_sbox(r)]; @@ -237,7 +237,7 @@ pub(crate) fn eval_permutation_unit_circuit, const builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); yield_constr.constraint(builder, diff); let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; - state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + state[i] = builder.mul_many_extension([state[i], state_cubed, state_cubed]); // Form state ** 7. }