Merge branch 'main' into filtered_ctl

This commit is contained in:
wborgeaud 2022-06-07 22:59:43 +02:00
commit 2ff738633b
9 changed files with 42 additions and 27 deletions

View File

@ -421,8 +421,8 @@ pub(crate) fn eval_permutation_checks_circuit<F, S, const D: usize>(
) )
}) })
.unzip(); .unzip();
let reduced_lhs_product = builder.mul_many_extension(&reduced_lhs); let reduced_lhs_product = builder.mul_many_extension(reduced_lhs);
let reduced_rhs_product = builder.mul_many_extension(&reduced_rhs); let reduced_rhs_product = builder.mul_many_extension(reduced_rhs);
// constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product // constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product
let constraint = { let constraint = {
let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product); let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product);

View File

@ -188,8 +188,13 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
/// Add `n` `Target`s. /// Add `n` `Target`s.
pub fn add_many(&mut self, terms: &[Target]) -> Target { pub fn add_many<T>(&mut self, terms: impl IntoIterator<Item = T>) -> Target
terms.iter().fold(self.zero(), |acc, &t| self.add(acc, t)) where
T: Borrow<Target>,
{
terms
.into_iter()
.fold(self.zero(), |acc, t| self.add(acc, *t.borrow()))
} }
/// Computes `x - y`. /// Computes `x - y`.

View File

@ -1,3 +1,5 @@
use std::borrow::Borrow;
use plonky2_field::extension_field::FieldExtension; use plonky2_field::extension_field::FieldExtension;
use plonky2_field::extension_field::{Extendable, OEF}; use plonky2_field::extension_field::{Extendable, OEF};
use plonky2_field::field_types::{Field, Field64}; use plonky2_field::field_types::{Field, Field64};
@ -204,12 +206,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
/// Add `n` `ExtensionTarget`s. /// Add `n` `ExtensionTarget`s.
pub fn add_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> { pub fn add_many_extension<T>(
let mut sum = self.zero_extension(); &mut self,
for &term in terms { terms: impl IntoIterator<Item = T>,
sum = self.add_extension(sum, term); ) -> ExtensionTarget<D>
} where
sum T: Borrow<ExtensionTarget<D>>,
{
terms.into_iter().fold(self.zero_extension(), |acc, t| {
self.add_extension(acc, *t.borrow())
})
} }
pub fn sub_extension( pub fn sub_extension(
@ -257,7 +263,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `x^3`. /// Computes `x^3`.
pub fn cube_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> { pub fn cube_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
self.mul_many_extension(&[x, x, x]) self.mul_many_extension([x, x, x])
} }
/// Returns `a * b + c`. /// Returns `a * b + c`.
@ -301,12 +307,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
/// Multiply `n` `ExtensionTarget`s. /// Multiply `n` `ExtensionTarget`s.
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> { pub fn mul_many_extension<T>(
terms &mut self,
.iter() terms: impl IntoIterator<Item = T>,
.copied() ) -> ExtensionTarget<D>
.reduce(|acc, t| self.mul_extension(acc, t)) where
.unwrap_or_else(|| self.one_extension()) T: Borrow<ExtensionTarget<D>>,
{
terms.into_iter().fold(self.one_extension(), |acc, t| {
self.mul_extension(acc, *t.borrow())
})
} }
/// Like `mul_add`, but for `ExtensionTarget`s. /// Like `mul_add`, but for `ExtensionTarget`s.

View File

@ -102,7 +102,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
let output = vars.local_wires[Self::wire_ith_output(i)]; let output = vars.local_wires[Self::wire_ith_output(i)];
let computed_output = { let computed_output = {
let scaled_mul = let scaled_mul =
builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); builder.mul_many_extension([const_0, multiplicand_0, multiplicand_1]);
builder.mul_add_extension(const_1, addend, scaled_mul) builder.mul_add_extension(const_1, addend, scaled_mul)
}; };

View File

@ -265,5 +265,5 @@ fn compute_filter_circuit<F: RichField + Extendable<D>, const D: usize>(
builder.sub_extension(c, s) builder.sub_extension(c, s)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
builder.mul_many_extension(&v) builder.mul_many_extension(v)
} }

View File

@ -385,8 +385,8 @@ pub(crate) fn eval_permutation_checks_circuit<F, S, const D: usize>(
) )
}) })
.unzip(); .unzip();
let reduced_lhs_product = builder.mul_many_extension(&reduced_lhs); let reduced_lhs_product = builder.mul_many_extension(reduced_lhs);
let reduced_rhs_product = builder.mul_many_extension(&reduced_rhs); let reduced_rhs_product = builder.mul_many_extension(reduced_rhs);
// constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product // constraint = next_zs[i] * reduced_rhs_product - local_zs[i] * reduced_lhs_product
let constraint = { let constraint = {
let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product); let tmp = builder.mul_extension(local_zs[i], reduced_lhs_product);

View File

@ -62,7 +62,7 @@ pub(crate) fn eval_addition_circuit<F: RichField + Extendable<D>, const D: usize
// this sum can be around 48 bits at most. // this sum can be around 48 bits at most.
let out = reduce_with_powers_ext_circuit(builder, &[out_1, out_2, out_3], limb_base); let out = reduce_with_powers_ext_circuit(builder, &[out_1, out_2, out_3], limb_base);
let computed_out = builder.add_many_extension(&[in_1, in_2, in_3]); let computed_out = builder.add_many_extension([in_1, in_2, in_3]);
let diff = builder.sub_extension(out, computed_out); let diff = builder.sub_extension(out, computed_out);
let filtered_diff = builder.mul_extension(is_add, diff); let filtered_diff = builder.mul_extension(is_add, diff);

View File

@ -163,7 +163,7 @@ fn eval_bitop_32_circuit<F: RichField + Extendable<D>, const D: usize>(
let b_bits = input_b_regs.map(|r| lv[r]); let b_bits = input_b_regs.map(|r| lv[r]);
// Ensure that the inputs are bits // Ensure that the inputs are bits
let inst_constr = builder.add_many_extension(&[is_and, is_ior, is_xor, is_andnot]); let inst_constr = builder.add_many_extension([is_and, is_ior, is_xor, is_andnot]);
constrain_all_to_bits_circuit(builder, yield_constr, inst_constr, &a_bits); constrain_all_to_bits_circuit(builder, yield_constr, inst_constr, &a_bits);
constrain_all_to_bits_circuit(builder, yield_constr, inst_constr, &b_bits); constrain_all_to_bits_circuit(builder, yield_constr, inst_constr, &b_bits);
@ -204,7 +204,7 @@ fn eval_bitop_32_circuit<F: RichField + Extendable<D>, const D: usize>(
builder.mul_extension(t1, is_andnot) builder.mul_extension(t1, is_andnot)
}; };
let constr = builder.add_many_extension(&[and_constr, ior_constr, xor_constr, andnot_constr]); let constr = builder.add_many_extension([and_constr, ior_constr, xor_constr, andnot_constr]);
yield_constr.constraint(builder, constr); yield_constr.constraint(builder, constr);
} }

View File

@ -195,7 +195,7 @@ pub(crate) fn eval_permutation_unit_circuit<F: RichField + Extendable<D>, const
builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]);
yield_constr.constraint(builder, diff); yield_constr.constraint(builder, diff);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); state[i] = builder.mul_many_extension([state[i], state_cubed, state_cubed]);
// Form state ** 7. // Form state ** 7.
} }
@ -216,7 +216,7 @@ pub(crate) fn eval_permutation_unit_circuit<F: RichField + Extendable<D>, const
let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]);
yield_constr.constraint(builder, diff); yield_constr.constraint(builder, diff);
let state0_cubed = local_values[col_partial_mid_sbox(r)]; let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. state[0] = builder.mul_many_extension([state[0], state0_cubed, state0_cubed]); // Form state ** 7.
let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]);
yield_constr.constraint(builder, diff); yield_constr.constraint(builder, diff);
state[0] = local_values[col_partial_after_sbox(r)]; state[0] = local_values[col_partial_after_sbox(r)];
@ -237,7 +237,7 @@ pub(crate) fn eval_permutation_unit_circuit<F: RichField + Extendable<D>, const
builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]);
yield_constr.constraint(builder, diff); yield_constr.constraint(builder, diff);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); state[i] = builder.mul_many_extension([state[i], state_cubed, state_cubed]);
// Form state ** 7. // Form state ** 7.
} }