From 22ce2da9e182acd8f8b566e1697cee813f69e45c Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 20 Oct 2021 23:43:52 -0700 Subject: [PATCH] Add add_const, mul_const, mul_const_add methods (#312) * Add mul_const, mul_const_add methods To replace some arithmetic calls; I think it's easier to read. * One more * Couple more * tweak * tweak --- src/gadgets/arithmetic.rs | 18 +++++++++++++ src/gadgets/arithmetic_extension.rs | 23 ++++++++++++++++ src/gadgets/split_join.rs | 9 +------ src/gates/poseidon.rs | 8 ++---- src/hash/poseidon.rs | 42 ++++++++--------------------- src/plonk/vanishing_poly.rs | 2 +- 6 files changed, 56 insertions(+), 46 deletions(-) diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index be5af1ce..95e7453f 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -51,6 +51,24 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, y, F::ONE, z) } + /// Computes `x + C`. + pub fn add_const(&mut self, x: Target, c: F) -> Target { + let one = self.one(); + self.arithmetic(F::ONE, one, x, c, one) + } + + /// Computes `C * x`. + pub fn mul_const(&mut self, c: F, x: Target) -> Target { + let zero = self.zero(); + self.mul_const_add(c, x, zero) + } + + /// Computes `C * x + y`. + pub fn mul_const_add(&mut self, c: F, x: Target, y: Target) -> Target { + let one = self.one(); + self.arithmetic(c, x, one, F::ONE, y) + } + /// Computes `x * y - z`. pub fn mul_sub(&mut self, x: Target, y: Target, z: Target) -> Target { self.arithmetic(F::ONE, x, y, F::NEG_ONE, z) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index fd5aae7f..2a602f04 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -297,6 +297,29 @@ impl, const D: usize> CircuitBuilder { self.arithmetic_extension(F::ONE, F::ONE, a, b, c) } + /// Like `add_const`, but for `ExtensionTarget`s. + pub fn add_const_extension(&mut self, x: ExtensionTarget, c: F) -> ExtensionTarget { + let one = self.one_extension(); + self.arithmetic_extension(F::ONE, c, one, x, one) + } + + /// Like `mul_const`, but for `ExtensionTarget`s. + pub fn mul_const_extension(&mut self, c: F, x: ExtensionTarget) -> ExtensionTarget { + let zero = self.zero_extension(); + self.mul_const_add_extension(c, x, zero) + } + + /// Like `mul_const_add`, but for `ExtensionTarget`s. + pub fn mul_const_add_extension( + &mut self, + c: F, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + self.arithmetic_extension(c, F::ONE, x, one, y) + } + /// Like `mul_add`, but for `ExtensionTarget`s. pub fn scalar_mul_add_extension( &mut self, diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 9b84ed76..8875d041 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -34,17 +34,10 @@ impl, const D: usize> CircuitBuilder { bits.drain(num_bits..); let zero = self.zero(); - let one = self.one(); let mut acc = zero; for &gate in gates.iter().rev() { let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM); - acc = self.arithmetic( - F::from_canonical_usize(1 << bits_per_gate), - acc, - one, - F::ONE, - sum, - ); + acc = self.mul_const_add(F::from_canonical_usize(1 << bits_per_gate), acc, sum); } self.connect(acc, integer); diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index b6902ada..b7e3ff7b 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -258,7 +258,6 @@ where builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { - let one = builder.one_extension(); let mut constraints = Vec::with_capacity(self.num_constraints()); // Assert that `swap` is binary. @@ -305,12 +304,9 @@ where let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; constraints.push(builder.sub_extension(state[0], sbox_in)); state[0] = >::sbox_monomial_recursive(builder, sbox_in); - state[0] = builder.arithmetic_extension( - F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), - F::ONE, - one, - one, + state[0] = builder.add_const_extension( state[0], + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), ); state = >::mds_partial_layer_fast_recursive(builder, &state, r); } diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 8da878aa..ae0b132c 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -206,15 +206,12 @@ where r: usize, v: &[ExtensionTarget; WIDTH], ) -> ExtensionTarget { - let one = builder.one_extension(); debug_assert!(r < WIDTH); let mut res = builder.zero_extension(); for i in 0..WIDTH { - res = builder.arithmetic_extension( + res = builder.mul_const_add_extension( F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), - F::ONE, - one, v[(i + r) % WIDTH], res, ); @@ -292,14 +289,10 @@ where builder: &mut CircuitBuilder, state: &mut [ExtensionTarget; WIDTH], ) { - let one = builder.one_extension(); for i in 0..WIDTH { - state[i] = builder.arithmetic_extension( - F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), - F::ONE, - one, - one, + state[i] = builder.add_const_extension( state[i], + F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), ); } } @@ -341,7 +334,6 @@ where builder: &mut CircuitBuilder, state: &[ExtensionTarget; WIDTH], ) -> [ExtensionTarget; WIDTH] { - let one = builder.one_extension(); let mut result = [builder.zero_extension(); WIDTH]; result[0] = state[0]; @@ -350,7 +342,7 @@ where for c in 1..WIDTH { let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1]); - result[c] = builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]); + result[c] = builder.mul_const_add_extension(t, state[r], result[c]); } } result @@ -423,27 +415,19 @@ where state: &[ExtensionTarget; WIDTH], r: usize, ) -> [ExtensionTarget; WIDTH] { - let zero = builder.zero_extension(); - let one = builder.one_extension(); - let s0 = state[0]; - let mut d = builder.arithmetic_extension( - F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), - F::ONE, - one, - s0, - zero, - ); + let mut d = + builder.mul_const_extension(F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), s0); for i in 1..WIDTH { let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); - d = builder.arithmetic_extension(t, F::ONE, one, state[i], d); + d = builder.mul_const_add_extension(t, state[i], d); } - let mut result = [zero; WIDTH]; + let mut result = [builder.zero_extension(); WIDTH]; result[0] = d; for i in 1..WIDTH { let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); - result[i] = builder.arithmetic_extension(t, F::ONE, one, state[0], state[i]); + result[i] = builder.mul_const_add_extension(t, state[0], state[i]); } result } @@ -478,14 +462,10 @@ where state: &mut [ExtensionTarget; WIDTH], round_ctr: usize, ) { - let one = builder.one_extension(); for i in 0..WIDTH { - state[i] = builder.arithmetic_extension( - F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), - F::ONE, - one, - one, + state[i] = builder.add_const_extension( state[i], + F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), ); } } diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index b22ebb33..0bb03b50 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -319,7 +319,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; let z_gz = next_zs[i]; - vanishing_z_1_terms.push(builder.arithmetic_extension(F::ONE, F::NEG_ONE, l1_x, z_x, l1_x)); + vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x)); let numerator_values = (0..common_data.config.num_routed_wires) .map(|j| {