diff --git a/src/c_kzg_util.h b/src/c_kzg_util.h index c2487a1..b49b2a2 100644 --- a/src/c_kzg_util.h +++ b/src/c_kzg_util.h @@ -20,6 +20,13 @@ #include "c_kzg.h" #include "poly.h" +#define min_u64(a, b) \ + ({ \ + uint64_t _a = (a); \ + uint64_t _b = (b); \ + _a < _b ? _a : _b; \ + }) + C_KZG_RET c_kzg_malloc(void **p, size_t n); C_KZG_RET new_uint64_array(uint64_t **x, size_t n); C_KZG_RET new_fr_array(fr_t **x, size_t n); diff --git a/src/zero_poly.c b/src/zero_poly.c index b1ce64c..cdd7cd0 100644 --- a/src/zero_poly.c +++ b/src/zero_poly.c @@ -189,12 +189,13 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u uint64_t missing_per_partial = degree_of_partial - 1; uint64_t domain_stride = fs->max_width / length; uint64_t partial_count = (len_missing + missing_per_partial - 1) / missing_per_partial; - uint64_t n = next_power_of_two(partial_count * degree_of_partial); - if (n > length) n = length; + uint64_t n = min_u64(next_power_of_two(partial_count * degree_of_partial), length); if (len_missing <= missing_per_partial) { + TRY(do_zero_poly_mul_partial(zero_poly, missing_indices, len_missing, domain_stride, fs)); TRY(fft_fr(zero_eval, zero_poly->coeffs, false, length, fs)); + } else { // Work space for building and reducing the partials @@ -209,8 +210,7 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u TRY(new_poly_array(&partials, partial_count)); uint64_t offset = 0, out_offset = 0, max = len_missing; for (int i = 0; i < partial_count; i++) { - uint64_t end = offset + missing_per_partial; - if (end > max) end = max; + uint64_t end = min_u64(offset + missing_per_partial, max); partials[i].coeffs = &work[out_offset]; partials[i].length = degree_of_partial; TRY(do_zero_poly_mul_partial(&partials[i], &missing_indices[offset], end - offset, domain_stride, fs)); @@ -221,8 +221,7 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u partials[partial_count - 1].length = 1 + len_missing - (partial_count - 1) * missing_per_partial; // Reduce all the partials to a single polynomial - - int reduction_factor = 4; // must be a power of 2 (for sake of the FFTs in reduce partials) + int reduction_factor = 4; // must be a power of 2 (for sake of the FFTs in reduce_partials) fr_t *scratch; TRY(new_fr_array(&scratch, n * 3)); while (partial_count > 1) { @@ -230,18 +229,12 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u uint64_t partial_size = next_power_of_two(partials[0].length); for (uint64_t i = 0; i < reduced_count; i++) { uint64_t start = i * reduction_factor; - uint64_t end = start + reduction_factor; - uint64_t out_end = end * partial_size; - if (out_end > n) out_end = n; - fr_t *reduced = work + start * partial_size; - uint64_t reduced_len = out_end - start * partial_size; - if (reduced_len > length) reduced_len = length; - if (end > partial_count) end = partial_count; - uint64_t partials_slice = end - start; - partials[i].coeffs = reduced; - if (partials_slice > 1) { - TRY(reduce_partials(&partials[i], reduced_len, scratch, n * 3, &partials[start], partials_slice, - fs)); + uint64_t out_end = min_u64((start + reduction_factor) * partial_size, n); + uint64_t reduced_len = min_u64(out_end - start * partial_size, length); + uint64_t partials_num = min_u64(reduction_factor, partial_count - start); + partials[i].coeffs = work + start * partial_size; + if (partials_num > 1) { + TRY(reduce_partials(&partials[i], reduced_len, scratch, n * 3, &partials[start], partials_num, fs)); } else { partials[i].length = partials[start].length; } @@ -250,7 +243,6 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u } // Process final output - TRY(pad_p(zero_poly->coeffs, length, &partials[0])); TRY(fft_fr(zero_eval, zero_poly->coeffs, false, length, fs));