From 94198b5c18f8b82c2e56d79d7741ff583399f9e2 Mon Sep 17 00:00:00 2001 From: George Kadianakis Date: Fri, 24 Feb 2023 17:38:30 +0200 Subject: [PATCH] Refactor use of MSM around the base code (#159) * Separate naive MSM and fast MSM into separate functions * Use naive MSM in batch verify, and fast MSM when points are trusted --- src/c_kzg_4844.c | 45 ++++++++++++++++++++++++++++--------------- src/test_c_kzg_4844.c | 12 +++++------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/src/c_kzg_4844.c b/src/c_kzg_4844.c index 0aa2d0d..a4cf64b 100644 --- a/src/c_kzg_4844.c +++ b/src/c_kzg_4844.c @@ -752,6 +752,28 @@ static void compute_challenge( * Calculates `[coeffs_0]p_0 + [coeffs_1]p_1 + ... + [coeffs_n]p_n` * where `n` is `len - 1`. * + * This function computes the result naively without using Pippenger's + * algorithm. + */ +static void g1_lincomb_naive( + g1_t *out, const g1_t *p, const fr_t *coeffs, uint64_t len +) { + g1_t tmp; + *out = G1_IDENTITY; + for (uint64_t i = 0; i < len; i++) { + g1_mul(&tmp, &p[i], &coeffs[i]); + blst_p1_add_or_double(out, out, &tmp); + } +} + +/** + * Calculate a linear combination of G1 group elements. + * + * Calculates `[coeffs_0]p_0 + [coeffs_1]p_1 + ... + [coeffs_n]p_n` + * where `n` is `len - 1`. + * + * @remark This function MUST NOT be called with the point at infinity in `p`. + * @param[out] out The resulting sum-product * @param[in] p Array of G1 group elements, length @p len * @param[in] coeffs Array of field elements, length @p len @@ -768,7 +790,7 @@ static void compute_challenge( * * We do the second of these to save memory here. */ -static C_KZG_RET g1_lincomb( +static C_KZG_RET g1_lincomb_fast( g1_t *out, const g1_t *p, const fr_t *coeffs, uint64_t len ) { C_KZG_RET ret; @@ -778,13 +800,7 @@ static C_KZG_RET g1_lincomb( // Tunable parameter: must be at least 2 since Blst fails for 0 or 1 if (len < 8) { - // Direct approach - g1_t tmp; - *out = G1_IDENTITY; - for (uint64_t i = 0; i < len; i++) { - g1_mul(&tmp, &p[i], &coeffs[i]); - blst_p1_add_or_double(out, out, &tmp); - } + g1_lincomb_naive(out, p, coeffs, len); } else { // Blst's implementation of the Pippenger method size_t scratch_size = blst_p1s_mult_pippenger_scratch_sizeof(len); @@ -910,7 +926,7 @@ out: static C_KZG_RET poly_to_kzg_commitment( g1_t *out, const Polynomial *p, const KZGSettings *s ) { - return g1_lincomb( + return g1_lincomb_fast( out, s->g1_values, (const fr_t *)(&p->evals), FIELD_ELEMENTS_PER_BLOB ); } @@ -1139,7 +1155,7 @@ static C_KZG_RET compute_kzg_proof_impl( } g1_t out_g1; - ret = g1_lincomb( + ret = g1_lincomb_fast( &out_g1, s->g1_values, (const fr_t *)(&q.evals), FIELD_ELEMENTS_PER_BLOB ); if (ret != C_KZG_OK) goto out; @@ -1353,8 +1369,7 @@ static C_KZG_RET verify_kzg_proof_batch( if (ret != C_KZG_OK) goto out; /* Compute \sum r^i * Proof_i */ - ret = g1_lincomb(&proof_lincomb, proofs_g1, r_powers, n); - if (ret != C_KZG_OK) goto out; + g1_lincomb_naive(&proof_lincomb, proofs_g1, r_powers, n); for (size_t i = 0; i < n; i++) { g1_t ys_encrypted; @@ -1368,11 +1383,9 @@ static C_KZG_RET verify_kzg_proof_batch( } /* Get \sum r^i z_i Proof_i */ - ret = g1_lincomb(&proof_z_lincomb, proofs_g1, r_times_z, n); - if (ret != C_KZG_OK) goto out; + g1_lincomb_naive(&proof_z_lincomb, proofs_g1, r_times_z, n); /* Get \sum r^i (C_i - [y_i]) */ - ret = g1_lincomb(&C_minus_y_lincomb, C_minus_y, r_powers, n); - if (ret != C_KZG_OK) goto out; + g1_lincomb_naive(&C_minus_y_lincomb, C_minus_y, r_powers, n); /* Get C_minus_y_lincomb + proof_z_lincomb */ blst_p1_add_or_double(&rhs_g1, &C_minus_y_lincomb, &proof_z_lincomb); diff --git a/src/test_c_kzg_4844.c b/src/test_c_kzg_4844.c index 5b0e60c..29e681d 100644 --- a/src/test_c_kzg_4844.c +++ b/src/test_c_kzg_4844.c @@ -898,23 +898,21 @@ static void test_compute_powers__succeeds_expected_powers(void) { static void test_g1_lincomb__verify_consistent(void) { C_KZG_RET ret; - g1_t points[128], out, check, tmp; + g1_t points[128], out, check; fr_t scalars[128]; check = G1_IDENTITY; for (size_t i = 0; i < 128; i++) { get_rand_fr(&scalars[i]); get_rand_g1(&points[i]); - g1_mul(&tmp, &points[i], &scalars[i]); - blst_p1_add(&check, &check, &tmp); } - ret = g1_lincomb(&out, points, scalars, 128); + g1_lincomb_naive(&check, points, scalars, 128); + + ret = g1_lincomb_fast(&out, points, scalars, 128); ASSERT_EQUALS(ret, C_KZG_OK); - ASSERT( - "lincomb matches direct multiplication", blst_p1_is_equal(&out, &check) - ); + ASSERT("pippenger matches naive MSM", blst_p1_is_equal(&out, &check)); } ///////////////////////////////////////////////////////////////////////////////