diff --git a/src/c_kzg.h b/src/c_kzg.h index 880313a..157113e 100644 --- a/src/c_kzg.h +++ b/src/c_kzg.h @@ -108,7 +108,7 @@ void free_poly_l(poly_l *p); // kzg_proofs.c // -void fr_vector_lincomb(fr_t out[], const fr_t *vectors, const fr_t *scalars, uint64_t n, uint64_t m); +void fr_vector_lincomb(fr_t out[], const fr_t *vectors[], const fr_t *scalars, uint64_t n, uint64_t m); /** * Stores the setup and parameters needed for computing KZG proofs. diff --git a/src/c_kzg_4844.c b/src/c_kzg_4844.c index e09658d..7d35ab8 100644 --- a/src/c_kzg_4844.c +++ b/src/c_kzg_4844.c @@ -69,7 +69,7 @@ void free_trusted_setup(KZGSettings *s) { void compute_powers(BLSFieldElement out[], const BLSFieldElement *x, uint64_t n) { fr_pow(out, x, n); } -void vector_lincomb(BLSFieldElement out[], const BLSFieldElement *vectors, const BLSFieldElement *scalars, uint64_t num_vectors, uint64_t vector_len) { +void vector_lincomb(BLSFieldElement out[], const BLSFieldElement[] *vectors, const BLSFieldElement *scalars, uint64_t num_vectors, uint64_t vector_len) { fr_vector_lincomb(out, vectors, scalars, num_vectors, vector_len); } diff --git a/src/c_kzg_4844.h b/src/c_kzg_4844.h index ecf4371..c6cc095 100644 --- a/src/c_kzg_4844.h +++ b/src/c_kzg_4844.h @@ -38,7 +38,7 @@ void free_trusted_setup(KZGSettings *s); void compute_powers(BLSFieldElement out[], const BLSFieldElement *x, uint64_t n); -void vector_lincomb(BLSFieldElement out[], const BLSFieldElement *vectors, const BLSFieldElement *scalars, uint64_t num_vectors, uint64_t vector_len); +void vector_lincomb(BLSFieldElement out[], const BLSFieldElement *vectors[], const BLSFieldElement *scalars, uint64_t num_vectors, uint64_t vector_len); void g1_lincomb(KZGCommitment *out, const KZGCommitment points[], const BLSFieldElement scalars[], uint64_t num_points); diff --git a/src/kzg_proofs.c b/src/kzg_proofs.c index a22775b..85de536 100644 --- a/src/kzg_proofs.c +++ b/src/kzg_proofs.c @@ -31,15 +31,14 @@ /** * Compute linear combinations of a sequence of vectors with some scalars */ -void fr_vector_lincomb(fr_t out[], const fr_t *vectors, const fr_t *scalars, uint64_t n, uint64_t m) { - fr_t (*vectors_ptr)[n][m] = (fr_t (*)[n][m]) vectors; +void fr_vector_lincomb(fr_t out[], const fr_t *vectors[], const fr_t *scalars, uint64_t n, uint64_t m) { fr_t tmp; uint64_t i, j; for (j = 0; j < m; j++) out[j] = fr_zero; for (i = 0; i < n; i++) { for (j = 0; j < m; j++) { - fr_mul(&tmp, &scalars[i], &((*vectors_ptr)[i][j])); + fr_mul(&tmp, &scalars[i], &vectors[i][j]); fr_add(&out[j], &out[j], &tmp); } } @@ -744,22 +743,24 @@ void fr_vector_lincomb_simple_test(void) { fr_add(&fr2, &fr_one, &fr_one); fr_add(&fr3, &fr2, &fr_one); fr_t out[m]; - const fr_t vectors[2][3] = { { fr_one, fr2, fr3 }, { fr3, fr2, fr_zero } }; + const fr_t v1[3] = { fr_one, fr2, fr3 }; + const fr_t v2[3] = { fr3, fr2, fr_zero }; + const fr_t* vectors[2] = { v1, v2 }; fr_t scalars[2] = { fr_zero, fr_one }; - fr_vector_lincomb(out, (fr_t*)vectors, (fr_t*)scalars, n, m); + fr_vector_lincomb(out, vectors, (fr_t*)scalars, n, m); for (i = 0; i < m; i++) { TEST_CHECK(fr_equal(&out[i], &vectors[1][i])); } scalars[0] = fr_one; scalars[1] = fr_zero; - fr_vector_lincomb(out, (fr_t*)vectors, (fr_t*)scalars, n, m); + fr_vector_lincomb(out, vectors, (fr_t*)scalars, n, m); for (i = 0; i < m; i++) { TEST_CHECK(fr_equal(&out[i], &vectors[0][i])); } scalars[1] = fr_one; - fr_vector_lincomb(out, (fr_t*)vectors, (fr_t*)scalars, n, m); + fr_vector_lincomb(out, vectors, (fr_t*)scalars, n, m); for (i = 0; i < m; i++) { fr_add(&tmp, &vectors[0][i], &vectors[1][i]); TEST_CHECK(fr_equal(&out[i], &tmp));