Simplify reduction loop

This commit is contained in:
Ben Edgington 2021-04-28 19:56:22 +01:00
parent f25ada5ea9
commit 0924f243b4
2 changed files with 18 additions and 19 deletions

View File

@ -20,6 +20,13 @@
#include "c_kzg.h"
#include "poly.h"
#define min_u64(a, b) \
({ \
uint64_t _a = (a); \
uint64_t _b = (b); \
_a < _b ? _a : _b; \
})
C_KZG_RET c_kzg_malloc(void **p, size_t n);
C_KZG_RET new_uint64_array(uint64_t **x, size_t n);
C_KZG_RET new_fr_array(fr_t **x, size_t n);

View File

@ -189,12 +189,13 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
uint64_t missing_per_partial = degree_of_partial - 1;
uint64_t domain_stride = fs->max_width / length;
uint64_t partial_count = (len_missing + missing_per_partial - 1) / missing_per_partial;
uint64_t n = next_power_of_two(partial_count * degree_of_partial);
if (n > length) n = length;
uint64_t n = min_u64(next_power_of_two(partial_count * degree_of_partial), length);
if (len_missing <= missing_per_partial) {
TRY(do_zero_poly_mul_partial(zero_poly, missing_indices, len_missing, domain_stride, fs));
TRY(fft_fr(zero_eval, zero_poly->coeffs, false, length, fs));
} else {
// Work space for building and reducing the partials
@ -209,8 +210,7 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
TRY(new_poly_array(&partials, partial_count));
uint64_t offset = 0, out_offset = 0, max = len_missing;
for (int i = 0; i < partial_count; i++) {
uint64_t end = offset + missing_per_partial;
if (end > max) end = max;
uint64_t end = min_u64(offset + missing_per_partial, max);
partials[i].coeffs = &work[out_offset];
partials[i].length = degree_of_partial;
TRY(do_zero_poly_mul_partial(&partials[i], &missing_indices[offset], end - offset, domain_stride, fs));
@ -221,8 +221,7 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
partials[partial_count - 1].length = 1 + len_missing - (partial_count - 1) * missing_per_partial;
// Reduce all the partials to a single polynomial
int reduction_factor = 4; // must be a power of 2 (for sake of the FFTs in reduce partials)
int reduction_factor = 4; // must be a power of 2 (for sake of the FFTs in reduce_partials)
fr_t *scratch;
TRY(new_fr_array(&scratch, n * 3));
while (partial_count > 1) {
@ -230,18 +229,12 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
uint64_t partial_size = next_power_of_two(partials[0].length);
for (uint64_t i = 0; i < reduced_count; i++) {
uint64_t start = i * reduction_factor;
uint64_t end = start + reduction_factor;
uint64_t out_end = end * partial_size;
if (out_end > n) out_end = n;
fr_t *reduced = work + start * partial_size;
uint64_t reduced_len = out_end - start * partial_size;
if (reduced_len > length) reduced_len = length;
if (end > partial_count) end = partial_count;
uint64_t partials_slice = end - start;
partials[i].coeffs = reduced;
if (partials_slice > 1) {
TRY(reduce_partials(&partials[i], reduced_len, scratch, n * 3, &partials[start], partials_slice,
fs));
uint64_t out_end = min_u64((start + reduction_factor) * partial_size, n);
uint64_t reduced_len = min_u64(out_end - start * partial_size, length);
uint64_t partials_num = min_u64(reduction_factor, partial_count - start);
partials[i].coeffs = work + start * partial_size;
if (partials_num > 1) {
TRY(reduce_partials(&partials[i], reduced_len, scratch, n * 3, &partials[start], partials_num, fs));
} else {
partials[i].length = partials[start].length;
}
@ -250,7 +243,6 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
}
// Process final output
TRY(pad_p(zero_poly->coeffs, length, &partials[0]));
TRY(fft_fr(zero_eval, zero_poly->coeffs, false, length, fs));