Change some fn to take iterators instead of slices

This commit is contained in:
wborgeaud 2022-06-03 18:06:14 +02:00
parent d6006f8fff
commit ccc9c024a2
9 changed files with 42 additions and 27 deletions

View File

@ -421,8 +421,8 @@ pub(crate) fn eval_permutation_checks_circuit<F, S, const D: usize>(
)
})
.unzip();
let reduced_lhs_product = builder.mul_many_extension(&reduced_lhs);
let reduced_rhs_product = builder.mul_many_extension(&reduced_rhs);
let reduced_lhs_product = builder.mul_many_extension(reduced_lhs);
let reduced_rhs_product = builder.mul_many_extension(reduced_rhs);
// constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product
let constraint = {
let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product);

View File

@ -188,8 +188,13 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
/// Add `n` `Target`s.
pub fn add_many(&mut self, terms: &[Target]) -> Target {
terms.iter().fold(self.zero(), |acc, &t| self.add(acc, t))
pub fn add_many<T>(&mut self, terms: impl IntoIterator<Item = T>) -> Target
where
T: Borrow<Target>,
{
terms
.into_iter()
.fold(self.zero(), |acc, t| self.add(acc, *t.borrow()))
}
/// Computes `x - y`.

View File

@ -1,3 +1,5 @@
use std::borrow::Borrow;
use plonky2_field::extension_field::FieldExtension;
use plonky2_field::extension_field::{Extendable, OEF};
use plonky2_field::field_types::{Field, Field64};
@ -204,12 +206,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
/// Add `n` `ExtensionTarget`s.
pub fn add_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let mut sum = self.zero_extension();
for &term in terms {
sum = self.add_extension(sum, term);
}
sum
pub fn add_many_extension<T>(
&mut self,
terms: impl IntoIterator<Item = T>,
) -> ExtensionTarget<D>
where
T: Borrow<ExtensionTarget<D>>,
{
terms.into_iter().fold(self.zero_extension(), |acc, t| {
self.add_extension(acc, *t.borrow())
})
}
pub fn sub_extension(
@ -257,7 +263,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `x^3`.
pub fn cube_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
self.mul_many_extension(&[x, x, x])
self.mul_many_extension([x, x, x])
}
/// Returns `a * b + c`.
@ -301,12 +307,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
/// Multiply `n` `ExtensionTarget`s.
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
terms
.iter()
.copied()
.reduce(|acc, t| self.mul_extension(acc, t))
.unwrap_or_else(|| self.one_extension())
pub fn mul_many_extension<T>(
&mut self,
terms: impl IntoIterator<Item = T>,
) -> ExtensionTarget<D>
where
T: Borrow<ExtensionTarget<D>>,
{
terms.into_iter().fold(self.one_extension(), |acc, t| {
self.mul_extension(acc, *t.borrow())
})
}
/// Like `mul_add`, but for `ExtensionTarget`s.

View File

@ -102,7 +102,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
let output = vars.local_wires[Self::wire_ith_output(i)];
let computed_output = {
let scaled_mul =
builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]);
builder.mul_many_extension([const_0, multiplicand_0, multiplicand_1]);
builder.mul_add_extension(const_1, addend, scaled_mul)
};

View File

@ -265,5 +265,5 @@ fn compute_filter_circuit<F: RichField + Extendable<D>, const D: usize>(
builder.sub_extension(c, s)
})
.collect::<Vec<_>>();
builder.mul_many_extension(&v)
builder.mul_many_extension(v)
}

View File

@ -385,8 +385,8 @@ pub(crate) fn eval_permutation_checks_circuit<F, S, const D: usize>(
)
})
.unzip();
let reduced_lhs_product = builder.mul_many_extension(&reduced_lhs);
let reduced_rhs_product = builder.mul_many_extension(&reduced_rhs);
let reduced_lhs_product = builder.mul_many_extension(reduced_lhs);
let reduced_rhs_product = builder.mul_many_extension(reduced_rhs);
// constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product
let constraint = {
let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product);

View File

@ -62,7 +62,7 @@ pub(crate) fn eval_addition_circuit<F: RichField + Extendable<D>, const D: usize
// this sum can be around 48 bits at most.
let out = reduce_with_powers_ext_circuit(builder, &[out_1, out_2, out_3], limb_base);
let computed_out = builder.add_many_extension(&[in_1, in_2, in_3]);
let computed_out = builder.add_many_extension([in_1, in_2, in_3]);
let diff = builder.sub_extension(out, computed_out);
let filtered_diff = builder.mul_extension(is_add, diff);

View File

@ -163,7 +163,7 @@ fn eval_bitop_32_circuit<F: RichField + Extendable<D>, const D: usize>(
let b_bits = input_b_regs.map(|r| lv[r]);
// Ensure that the inputs are bits
let inst_constr = builder.add_many_extension(&[is_and, is_ior, is_xor, is_andnot]);
let inst_constr = builder.add_many_extension([is_and, is_ior, is_xor, is_andnot]);
constrain_all_to_bits_circuit(builder, yield_constr, inst_constr, &a_bits);
constrain_all_to_bits_circuit(builder, yield_constr, inst_constr, &b_bits);
@ -204,7 +204,7 @@ fn eval_bitop_32_circuit<F: RichField + Extendable<D>, const D: usize>(
builder.mul_extension(t1, is_andnot)
};
let constr = builder.add_many_extension(&[and_constr, ior_constr, xor_constr, andnot_constr]);
let constr = builder.add_many_extension([and_constr, ior_constr, xor_constr, andnot_constr]);
yield_constr.constraint(builder, constr);
}

View File

@ -195,7 +195,7 @@ pub(crate) fn eval_permutation_unit_circuit<F: RichField + Extendable<D>, const
builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]);
yield_constr.constraint(builder, diff);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
state[i] = builder.mul_many_extension([state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
@ -216,7 +216,7 @@ pub(crate) fn eval_permutation_unit_circuit<F: RichField + Extendable<D>, const
let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]);
yield_constr.constraint(builder, diff);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7.
state[0] = builder.mul_many_extension([state[0], state0_cubed, state0_cubed]); // Form state ** 7.
let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]);
yield_constr.constraint(builder, diff);
state[0] = local_values[col_partial_after_sbox(r)];
@ -237,7 +237,7 @@ pub(crate) fn eval_permutation_unit_circuit<F: RichField + Extendable<D>, const
builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]);
yield_constr.constraint(builder, diff);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
state[i] = builder.mul_many_extension([state[i], state_cubed, state_cubed]);
// Form state ** 7.
}