Simplify reduction loop
This commit is contained in:
parent
f25ada5ea9
commit
0924f243b4
|
@ -20,6 +20,13 @@
|
||||||
#include "c_kzg.h"
|
#include "c_kzg.h"
|
||||||
#include "poly.h"
|
#include "poly.h"
|
||||||
|
|
||||||
|
#define min_u64(a, b) \
|
||||||
|
({ \
|
||||||
|
uint64_t _a = (a); \
|
||||||
|
uint64_t _b = (b); \
|
||||||
|
_a < _b ? _a : _b; \
|
||||||
|
})
|
||||||
|
|
||||||
C_KZG_RET c_kzg_malloc(void **p, size_t n);
|
C_KZG_RET c_kzg_malloc(void **p, size_t n);
|
||||||
C_KZG_RET new_uint64_array(uint64_t **x, size_t n);
|
C_KZG_RET new_uint64_array(uint64_t **x, size_t n);
|
||||||
C_KZG_RET new_fr_array(fr_t **x, size_t n);
|
C_KZG_RET new_fr_array(fr_t **x, size_t n);
|
||||||
|
|
|
@ -189,12 +189,13 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
|
||||||
uint64_t missing_per_partial = degree_of_partial - 1;
|
uint64_t missing_per_partial = degree_of_partial - 1;
|
||||||
uint64_t domain_stride = fs->max_width / length;
|
uint64_t domain_stride = fs->max_width / length;
|
||||||
uint64_t partial_count = (len_missing + missing_per_partial - 1) / missing_per_partial;
|
uint64_t partial_count = (len_missing + missing_per_partial - 1) / missing_per_partial;
|
||||||
uint64_t n = next_power_of_two(partial_count * degree_of_partial);
|
uint64_t n = min_u64(next_power_of_two(partial_count * degree_of_partial), length);
|
||||||
if (n > length) n = length;
|
|
||||||
|
|
||||||
if (len_missing <= missing_per_partial) {
|
if (len_missing <= missing_per_partial) {
|
||||||
|
|
||||||
TRY(do_zero_poly_mul_partial(zero_poly, missing_indices, len_missing, domain_stride, fs));
|
TRY(do_zero_poly_mul_partial(zero_poly, missing_indices, len_missing, domain_stride, fs));
|
||||||
TRY(fft_fr(zero_eval, zero_poly->coeffs, false, length, fs));
|
TRY(fft_fr(zero_eval, zero_poly->coeffs, false, length, fs));
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
// Work space for building and reducing the partials
|
// Work space for building and reducing the partials
|
||||||
|
@ -209,8 +210,7 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
|
||||||
TRY(new_poly_array(&partials, partial_count));
|
TRY(new_poly_array(&partials, partial_count));
|
||||||
uint64_t offset = 0, out_offset = 0, max = len_missing;
|
uint64_t offset = 0, out_offset = 0, max = len_missing;
|
||||||
for (int i = 0; i < partial_count; i++) {
|
for (int i = 0; i < partial_count; i++) {
|
||||||
uint64_t end = offset + missing_per_partial;
|
uint64_t end = min_u64(offset + missing_per_partial, max);
|
||||||
if (end > max) end = max;
|
|
||||||
partials[i].coeffs = &work[out_offset];
|
partials[i].coeffs = &work[out_offset];
|
||||||
partials[i].length = degree_of_partial;
|
partials[i].length = degree_of_partial;
|
||||||
TRY(do_zero_poly_mul_partial(&partials[i], &missing_indices[offset], end - offset, domain_stride, fs));
|
TRY(do_zero_poly_mul_partial(&partials[i], &missing_indices[offset], end - offset, domain_stride, fs));
|
||||||
|
@ -221,8 +221,7 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
|
||||||
partials[partial_count - 1].length = 1 + len_missing - (partial_count - 1) * missing_per_partial;
|
partials[partial_count - 1].length = 1 + len_missing - (partial_count - 1) * missing_per_partial;
|
||||||
|
|
||||||
// Reduce all the partials to a single polynomial
|
// Reduce all the partials to a single polynomial
|
||||||
|
int reduction_factor = 4; // must be a power of 2 (for sake of the FFTs in reduce_partials)
|
||||||
int reduction_factor = 4; // must be a power of 2 (for sake of the FFTs in reduce partials)
|
|
||||||
fr_t *scratch;
|
fr_t *scratch;
|
||||||
TRY(new_fr_array(&scratch, n * 3));
|
TRY(new_fr_array(&scratch, n * 3));
|
||||||
while (partial_count > 1) {
|
while (partial_count > 1) {
|
||||||
|
@ -230,18 +229,12 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
|
||||||
uint64_t partial_size = next_power_of_two(partials[0].length);
|
uint64_t partial_size = next_power_of_two(partials[0].length);
|
||||||
for (uint64_t i = 0; i < reduced_count; i++) {
|
for (uint64_t i = 0; i < reduced_count; i++) {
|
||||||
uint64_t start = i * reduction_factor;
|
uint64_t start = i * reduction_factor;
|
||||||
uint64_t end = start + reduction_factor;
|
uint64_t out_end = min_u64((start + reduction_factor) * partial_size, n);
|
||||||
uint64_t out_end = end * partial_size;
|
uint64_t reduced_len = min_u64(out_end - start * partial_size, length);
|
||||||
if (out_end > n) out_end = n;
|
uint64_t partials_num = min_u64(reduction_factor, partial_count - start);
|
||||||
fr_t *reduced = work + start * partial_size;
|
partials[i].coeffs = work + start * partial_size;
|
||||||
uint64_t reduced_len = out_end - start * partial_size;
|
if (partials_num > 1) {
|
||||||
if (reduced_len > length) reduced_len = length;
|
TRY(reduce_partials(&partials[i], reduced_len, scratch, n * 3, &partials[start], partials_num, fs));
|
||||||
if (end > partial_count) end = partial_count;
|
|
||||||
uint64_t partials_slice = end - start;
|
|
||||||
partials[i].coeffs = reduced;
|
|
||||||
if (partials_slice > 1) {
|
|
||||||
TRY(reduce_partials(&partials[i], reduced_len, scratch, n * 3, &partials[start], partials_slice,
|
|
||||||
fs));
|
|
||||||
} else {
|
} else {
|
||||||
partials[i].length = partials[start].length;
|
partials[i].length = partials[start].length;
|
||||||
}
|
}
|
||||||
|
@ -250,7 +243,6 @@ C_KZG_RET zero_polynomial_via_multiplication(fr_t *zero_eval, poly *zero_poly, u
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process final output
|
// Process final output
|
||||||
|
|
||||||
TRY(pad_p(zero_poly->coeffs, length, &partials[0]));
|
TRY(pad_p(zero_poly->coeffs, length, &partials[0]));
|
||||||
TRY(fft_fr(zero_eval, zero_poly->coeffs, false, length, fs));
|
TRY(fft_fr(zero_eval, zero_poly->coeffs, false, length, fs));
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue