Add add_const, mul_const, mul_const_add methods (#312)

* Add mul_const, mul_const_add methods

To replace some arithmetic calls; I think it's easier to read.

* One more

* Couple more

* tweak

* tweak
This commit is contained in:
Daniel Lubarov 2021-10-20 23:43:52 -07:00 committed by GitHub
parent 0b75b24c09
commit 22ce2da9e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 56 additions and 46 deletions

View File

@ -51,6 +51,24 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.arithmetic(F::ONE, x, y, F::ONE, z)
}
/// Computes `x + C`.
pub fn add_const(&mut self, x: Target, c: F) -> Target {
let one = self.one();
self.arithmetic(F::ONE, one, x, c, one)
}
/// Computes `C * x`.
pub fn mul_const(&mut self, c: F, x: Target) -> Target {
let zero = self.zero();
self.mul_const_add(c, x, zero)
}
/// Computes `C * x + y`.
pub fn mul_const_add(&mut self, c: F, x: Target, y: Target) -> Target {
let one = self.one();
self.arithmetic(c, x, one, F::ONE, y)
}
/// Computes `x * y - z`.
pub fn mul_sub(&mut self, x: Target, y: Target, z: Target) -> Target {
self.arithmetic(F::ONE, x, y, F::NEG_ONE, z)

View File

@ -297,6 +297,29 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.arithmetic_extension(F::ONE, F::ONE, a, b, c)
}
/// Like `add_const`, but for `ExtensionTarget`s.
pub fn add_const_extension(&mut self, x: ExtensionTarget<D>, c: F) -> ExtensionTarget<D> {
let one = self.one_extension();
self.arithmetic_extension(F::ONE, c, one, x, one)
}
/// Like `mul_const`, but for `ExtensionTarget`s.
pub fn mul_const_extension(&mut self, c: F, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
let zero = self.zero_extension();
self.mul_const_add_extension(c, x, zero)
}
/// Like `mul_const_add`, but for `ExtensionTarget`s.
pub fn mul_const_add_extension(
&mut self,
c: F,
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let one = self.one_extension();
self.arithmetic_extension(c, F::ONE, x, one, y)
}
/// Like `mul_add`, but for `ExtensionTarget`s.
pub fn scalar_mul_add_extension(
&mut self,

View File

@ -34,17 +34,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
bits.drain(num_bits..);
let zero = self.zero();
let one = self.one();
let mut acc = zero;
for &gate in gates.iter().rev() {
let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
acc = self.arithmetic(
F::from_canonical_usize(1 << bits_per_gate),
acc,
one,
F::ONE,
sum,
);
acc = self.mul_const_add(F::from_canonical_usize(1 << bits_per_gate), acc, sum);
}
self.connect(acc, integer);

View File

@ -258,7 +258,6 @@ where
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let one = builder.one_extension();
let mut constraints = Vec::with_capacity(self.num_constraints());
// Assert that `swap` is binary.
@ -305,12 +304,9 @@ where
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(builder.sub_extension(state[0], sbox_in));
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
state[0] = builder.arithmetic_extension(
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
F::ONE,
one,
one,
state[0] = builder.add_const_extension(
state[0],
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(builder, &state, r);
}

View File

@ -206,15 +206,12 @@ where
r: usize,
v: &[ExtensionTarget<D>; WIDTH],
) -> ExtensionTarget<D> {
let one = builder.one_extension();
debug_assert!(r < WIDTH);
let mut res = builder.zero_extension();
for i in 0..WIDTH {
res = builder.arithmetic_extension(
res = builder.mul_const_add_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]),
F::ONE,
one,
v[(i + r) % WIDTH],
res,
);
@ -292,14 +289,10 @@ where
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
) {
let one = builder.one_extension();
for i in 0..WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
F::ONE,
one,
one,
state[i] = builder.add_const_extension(
state[i],
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
);
}
}
@ -341,7 +334,6 @@ where
builder: &mut CircuitBuilder<F, D>,
state: &[ExtensionTarget<D>; WIDTH],
) -> [ExtensionTarget<D>; WIDTH] {
let one = builder.one_extension();
let mut result = [builder.zero_extension(); WIDTH];
result[0] = state[0];
@ -350,7 +342,7 @@ where
for c in 1..WIDTH {
let t =
F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1]);
result[c] = builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]);
result[c] = builder.mul_const_add_extension(t, state[r], result[c]);
}
}
result
@ -423,27 +415,19 @@ where
state: &[ExtensionTarget<D>; WIDTH],
r: usize,
) -> [ExtensionTarget<D>; WIDTH] {
let zero = builder.zero_extension();
let one = builder.one_extension();
let s0 = state[0];
let mut d = builder.arithmetic_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]),
F::ONE,
one,
s0,
zero,
);
let mut d =
builder.mul_const_extension(F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), s0);
for i in 1..WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d = builder.arithmetic_extension(t, F::ONE, one, state[i], d);
d = builder.mul_const_add_extension(t, state[i], d);
}
let mut result = [zero; WIDTH];
let mut result = [builder.zero_extension(); WIDTH];
result[0] = d;
for i in 1..WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
result[i] = builder.arithmetic_extension(t, F::ONE, one, state[0], state[i]);
result[i] = builder.mul_const_add_extension(t, state[0], state[i]);
}
result
}
@ -478,14 +462,10 @@ where
state: &mut [ExtensionTarget<D>; WIDTH],
round_ctr: usize,
) {
let one = builder.one_extension();
for i in 0..WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
F::ONE,
one,
one,
state[i] = builder.add_const_extension(
state[i],
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
);
}
}

View File

@ -319,7 +319,7 @@ pub(crate) fn eval_vanishing_poly_recursively<F: RichField + Extendable<D>, cons
for i in 0..common_data.config.num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
vanishing_z_1_terms.push(builder.arithmetic_extension(F::ONE, F::NEG_ONE, l1_x, z_x, l1_x));
vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x));
let numerator_values = (0..common_data.config.num_routed_wires)
.map(|j| {