Make polynomial division allocate space for the result

This commit is contained in:
Ben Edgington 2021-02-07 08:03:10 +00:00
parent 951ce118cd
commit ff014c293f
4 changed files with 21 additions and 24 deletions

View File

@ -18,9 +18,9 @@
#define C_KZG_H #define C_KZG_H
typedef enum { typedef enum {
C_KZG_OK = 0, C_KZG_OK = 0, // Success!
C_KZG_BADARGS, C_KZG_BADARGS, // The supplied data is invalid in some way
C_KZG_ERROR, C_KZG_ERROR, // Internal error - should never occur
} C_KZG_RET; } C_KZG_RET;
#include <stdbool.h> #include <stdbool.h>

View File

@ -52,7 +52,6 @@ bool check_proof_single(const KZGSettings *ks, const blst_p1 *commitment, const
C_KZG_RET compute_proof_multi(blst_p1 *out, const KZGSettings *ks, poly *p, const blst_fr *x0, uint64_t n) { C_KZG_RET compute_proof_multi(blst_p1 *out, const KZGSettings *ks, poly *p, const blst_fr *x0, uint64_t n) {
poly divisor, q; poly divisor, q;
blst_fr x_pow_n; blst_fr x_pow_n;
C_KZG_RET ret;
// Construct x^n - x0^n = (x - w^0)(x - w^1)...(x - w^(n-1)) // Construct x^n - x0^n = (x - w^0)(x - w^1)...(x - w^(n-1))
init_poly(&divisor, n + 1); init_poly(&divisor, n + 1);
@ -70,10 +69,7 @@ C_KZG_RET compute_proof_multi(blst_p1 *out, const KZGSettings *ks, poly *p, cons
divisor.coeffs[n] = fr_one; divisor.coeffs[n] = fr_one;
// Calculate q = p / (x^n - x0^n) // Calculate q = p / (x^n - x0^n)
init_poly(&q, poly_quotient_length(p, &divisor)); ASSERT(poly_long_div(&q, p, &divisor) == C_KZG_OK, C_KZG_ERROR);
if ((ret = poly_long_div(&q, p, &divisor) != C_KZG_OK)) {
return C_KZG_ERROR;
}
commit_to_poly(out, ks, &q); commit_to_poly(out, ks, &q);

View File

@ -24,6 +24,7 @@ static void poly_factor_div(blst_fr *out, const blst_fr *a, const blst_fr *b) {
void init_poly(poly *out, const uint64_t length) { void init_poly(poly *out, const uint64_t length) {
out->length = length; out->length = length;
out->coeffs = length > 0 ? malloc(length * sizeof(blst_fr)): NULL; out->coeffs = length > 0 ? malloc(length * sizeof(blst_fr)): NULL;
// TODO: check malloc return and handle accordingly
} }
void free_poly(poly *p) { void free_poly(poly *p) {
@ -62,18 +63,22 @@ uint64_t poly_quotient_length(const poly *dividend, const poly *divisor) {
return dividend->length >= divisor->length ? dividend->length - divisor->length + 1 : 0; return dividend->length >= divisor->length ? dividend->length - divisor->length + 1 : 0;
} }
// `out` must have been pre-allocated to the correct size, see `poly_quotient_length()` // `out` must be an uninitialised poly and has space allocated for it here, which
// must be freed by calling `free_poly()` later.
C_KZG_RET poly_long_div(poly *out, const poly *dividend, const poly *divisor) { C_KZG_RET poly_long_div(poly *out, const poly *dividend, const poly *divisor) {
uint64_t a_pos = dividend->length - 1; uint64_t a_pos = dividend->length - 1;
uint64_t b_pos = divisor->length - 1; uint64_t b_pos = divisor->length - 1;
uint64_t diff = a_pos - b_pos; uint64_t diff = a_pos - b_pos;
blst_fr a[dividend->length]; blst_fr a[dividend->length];
// Dividing by zero is undefined
ASSERT(divisor->length > 0, C_KZG_BADARGS); ASSERT(divisor->length > 0, C_KZG_BADARGS);
ASSERT(out->length == poly_quotient_length(dividend, divisor), C_KZG_BADARGS);
// Initialise the output polynomial
init_poly(out, poly_quotient_length(dividend, divisor));
// If the divisor is larger than the dividend, the result is zero-length // If the divisor is larger than the dividend, the result is zero-length
if (divisor->length > dividend->length) return C_KZG_OK; if (out->length == 0) return C_KZG_OK;
for (uint64_t i = 0; i < dividend->length; i++) { for (uint64_t i = 0; i < dividend->length; i++) {
a[i] = dividend->coeffs[i]; a[i] = dividend->coeffs[i];

View File

@ -33,7 +33,7 @@ void poly_div_length(void) {
} }
void poly_div_0(void) { void poly_div_0(void) {
blst_fr a[3], b[2], c[2], expected[2]; blst_fr a[3], b[2], expected[2];
poly dividend, divisor, actual; poly dividend, divisor, actual;
// Calculate (x^2 - 1) / (x + 1) = x - 1 // Calculate (x^2 - 1) / (x + 1) = x - 1
@ -57,16 +57,15 @@ void poly_div_0(void) {
fr_negate(&expected[0], &expected[0]); fr_negate(&expected[0], &expected[0]);
fr_from_uint64(&expected[1], 1); fr_from_uint64(&expected[1], 1);
actual.length = 2;
actual.coeffs = c;
TEST_CHECK(C_KZG_OK == poly_long_div(&actual, &dividend, &divisor)); TEST_CHECK(C_KZG_OK == poly_long_div(&actual, &dividend, &divisor));
TEST_CHECK(fr_equal(&expected[0], &actual.coeffs[0])); TEST_CHECK(fr_equal(&expected[0], &actual.coeffs[0]));
TEST_CHECK(fr_equal(&expected[1], &actual.coeffs[1])); TEST_CHECK(fr_equal(&expected[1], &actual.coeffs[1]));
free_poly(&actual);
} }
void poly_div_1(void) { void poly_div_1(void) {
blst_fr a[4], b[2], c[3], expected[3]; blst_fr a[4], b[2], expected[3];
poly dividend, divisor, actual; poly dividend, divisor, actual;
// Calculate (12x^3 - 11x^2 + 9x + 18) / (4x + 3) = 3x^2 - 5x + 6 // Calculate (12x^3 - 11x^2 + 9x + 18) / (4x + 3) = 3x^2 - 5x + 6
@ -92,13 +91,12 @@ void poly_div_1(void) {
fr_negate(&expected[1], &expected[1]); fr_negate(&expected[1], &expected[1]);
fr_from_uint64(&expected[2], 3); fr_from_uint64(&expected[2], 3);
actual.length = 3;
actual.coeffs = c;
TEST_CHECK(C_KZG_OK == poly_long_div(&actual, &dividend, &divisor)); TEST_CHECK(C_KZG_OK == poly_long_div(&actual, &dividend, &divisor));
TEST_CHECK(fr_equal(&expected[0], &actual.coeffs[0])); TEST_CHECK(fr_equal(&expected[0], &actual.coeffs[0]));
TEST_CHECK(fr_equal(&expected[1], &actual.coeffs[1])); TEST_CHECK(fr_equal(&expected[1], &actual.coeffs[1]));
TEST_CHECK(fr_equal(&expected[2], &actual.coeffs[2])); TEST_CHECK(fr_equal(&expected[2], &actual.coeffs[2]));
free_poly(&actual);
} }
void poly_div_2(void) { void poly_div_2(void) {
@ -121,8 +119,6 @@ void poly_div_2(void) {
divisor.length = 3; divisor.length = 3;
divisor.coeffs = b; divisor.coeffs = b;
init_poly(&actual, poly_quotient_length(&dividend, &divisor));
TEST_CHECK(C_KZG_OK == poly_long_div(&actual, &dividend, &divisor)); TEST_CHECK(C_KZG_OK == poly_long_div(&actual, &dividend, &divisor));
TEST_CHECK(NULL == actual.coeffs); TEST_CHECK(NULL == actual.coeffs);
@ -131,7 +127,7 @@ void poly_div_2(void) {
void poly_div_by_zero(void) { void poly_div_by_zero(void) {
blst_fr a[2]; blst_fr a[2];
poly dividend, divisor; poly dividend, divisor, dummy;
// Calculate (x + 1) / 0 = FAIL // Calculate (x + 1) / 0 = FAIL
@ -144,7 +140,7 @@ void poly_div_by_zero(void) {
// Divisor // Divisor
init_poly(&divisor, 0); init_poly(&divisor, 0);
TEST_CHECK(C_KZG_BADARGS == poly_long_div(NULL, &dividend, &divisor)); TEST_CHECK(C_KZG_BADARGS == poly_long_div(&dummy, &dividend, &divisor));
} }
void poly_eval_check(void) { void poly_eval_check(void) {