diff --git a/src/c_kzg.h b/src/c_kzg.h index 0f28592..51d4771 100644 --- a/src/c_kzg.h +++ b/src/c_kzg.h @@ -18,9 +18,9 @@ #define C_KZG_H typedef enum { - C_KZG_OK = 0, - C_KZG_BADARGS, - C_KZG_ERROR, + C_KZG_OK = 0, // Success! + C_KZG_BADARGS, // The supplied data is invalid in some way + C_KZG_ERROR, // Internal error - should never occur } C_KZG_RET; #include diff --git a/src/kzg_proofs.c b/src/kzg_proofs.c index 3f5ba88..6b4de1c 100644 --- a/src/kzg_proofs.c +++ b/src/kzg_proofs.c @@ -52,7 +52,6 @@ bool check_proof_single(const KZGSettings *ks, const blst_p1 *commitment, const C_KZG_RET compute_proof_multi(blst_p1 *out, const KZGSettings *ks, poly *p, const blst_fr *x0, uint64_t n) { poly divisor, q; blst_fr x_pow_n; - C_KZG_RET ret; // Construct x^n - x0^n = (x - w^0)(x - w^1)...(x - w^(n-1)) init_poly(&divisor, n + 1); @@ -70,10 +69,7 @@ C_KZG_RET compute_proof_multi(blst_p1 *out, const KZGSettings *ks, poly *p, cons divisor.coeffs[n] = fr_one; // Calculate q = p / (x^n - x0^n) - init_poly(&q, poly_quotient_length(p, &divisor)); - if ((ret = poly_long_div(&q, p, &divisor) != C_KZG_OK)) { - return C_KZG_ERROR; - } + ASSERT(poly_long_div(&q, p, &divisor) == C_KZG_OK, C_KZG_ERROR); commit_to_poly(out, ks, &q); @@ -121,4 +117,4 @@ bool check_proof_multi(const KZGSettings *ks, const blst_p1 *commitment, const b free_poly(&interp); return pairings_verify(&commit_minus_interp, blst_p2_generator(), proof, &xn_minus_yn); -} \ No newline at end of file +} diff --git a/src/poly.c b/src/poly.c index a4fb923..a82822d 100644 --- a/src/poly.c +++ b/src/poly.c @@ -24,6 +24,7 @@ static void poly_factor_div(blst_fr *out, const blst_fr *a, const blst_fr *b) { void init_poly(poly *out, const uint64_t length) { out->length = length; out->coeffs = length > 0 ? malloc(length * sizeof(blst_fr)): NULL; + // TODO: check malloc return and handle accordingly } void free_poly(poly *p) { @@ -62,18 +63,22 @@ uint64_t poly_quotient_length(const poly *dividend, const poly *divisor) { return dividend->length >= divisor->length ? dividend->length - divisor->length + 1 : 0; } -// `out` must have been pre-allocated to the correct size, see `poly_quotient_length()` +// `out` must be an uninitialised poly and has space allocated for it here, which +// must be freed by calling `free_poly()` later. C_KZG_RET poly_long_div(poly *out, const poly *dividend, const poly *divisor) { uint64_t a_pos = dividend->length - 1; uint64_t b_pos = divisor->length - 1; uint64_t diff = a_pos - b_pos; blst_fr a[dividend->length]; + // Dividing by zero is undefined ASSERT(divisor->length > 0, C_KZG_BADARGS); - ASSERT(out->length == poly_quotient_length(dividend, divisor), C_KZG_BADARGS); + + // Initialise the output polynomial + init_poly(out, poly_quotient_length(dividend, divisor)); // If the divisor is larger than the dividend, the result is zero-length - if (divisor->length > dividend->length) return C_KZG_OK; + if (out->length == 0) return C_KZG_OK; for (uint64_t i = 0; i < dividend->length; i++) { a[i] = dividend->coeffs[i]; diff --git a/src/poly_test.c b/src/poly_test.c index b5e9f0d..94bfb5c 100644 --- a/src/poly_test.c +++ b/src/poly_test.c @@ -33,7 +33,7 @@ void poly_div_length(void) { } void poly_div_0(void) { - blst_fr a[3], b[2], c[2], expected[2]; + blst_fr a[3], b[2], expected[2]; poly dividend, divisor, actual; // Calculate (x^2 - 1) / (x + 1) = x - 1 @@ -57,16 +57,15 @@ void poly_div_0(void) { fr_negate(&expected[0], &expected[0]); fr_from_uint64(&expected[1], 1); - actual.length = 2; - actual.coeffs = c; - TEST_CHECK(C_KZG_OK == poly_long_div(&actual, ÷nd, &divisor)); TEST_CHECK(fr_equal(&expected[0], &actual.coeffs[0])); TEST_CHECK(fr_equal(&expected[1], &actual.coeffs[1])); + + free_poly(&actual); } void poly_div_1(void) { - blst_fr a[4], b[2], c[3], expected[3]; + blst_fr a[4], b[2], expected[3]; poly dividend, divisor, actual; // Calculate (12x^3 - 11x^2 + 9x + 18) / (4x + 3) = 3x^2 - 5x + 6 @@ -92,13 +91,12 @@ void poly_div_1(void) { fr_negate(&expected[1], &expected[1]); fr_from_uint64(&expected[2], 3); - actual.length = 3; - actual.coeffs = c; - TEST_CHECK(C_KZG_OK == poly_long_div(&actual, ÷nd, &divisor)); TEST_CHECK(fr_equal(&expected[0], &actual.coeffs[0])); TEST_CHECK(fr_equal(&expected[1], &actual.coeffs[1])); TEST_CHECK(fr_equal(&expected[2], &actual.coeffs[2])); + + free_poly(&actual); } void poly_div_2(void) { @@ -121,8 +119,6 @@ void poly_div_2(void) { divisor.length = 3; divisor.coeffs = b; - init_poly(&actual, poly_quotient_length(÷nd, &divisor)); - TEST_CHECK(C_KZG_OK == poly_long_div(&actual, ÷nd, &divisor)); TEST_CHECK(NULL == actual.coeffs); @@ -131,7 +127,7 @@ void poly_div_2(void) { void poly_div_by_zero(void) { blst_fr a[2]; - poly dividend, divisor; + poly dividend, divisor, dummy; // Calculate (x + 1) / 0 = FAIL @@ -144,7 +140,7 @@ void poly_div_by_zero(void) { // Divisor init_poly(&divisor, 0); - TEST_CHECK(C_KZG_BADARGS == poly_long_div(NULL, ÷nd, &divisor)); + TEST_CHECK(C_KZG_BADARGS == poly_long_div(&dummy, ÷nd, &divisor)); } void poly_eval_check(void) {