diff --git a/src/c-kzg.h b/src/c-kzg.h index bbc4978..4e16f61 100644 --- a/src/c-kzg.h +++ b/src/c-kzg.h @@ -14,6 +14,30 @@ * limitations under the License. */ +#ifndef C_KZG_H +#define C_KZG_H + +typedef enum { + C_KZG_SUCCESS = 0, + C_KZG_BADARGS, + c_KZG_ERROR, +} C_KZG_RET; + #include -#include #include "../inc/blst.h" + +#define DEBUG + +#include +#ifdef DEBUG +#include +#define ASSERT(cond, ret) if (!(cond)) \ + { \ + printf("\n%s:%d: Failed ASSERT: %s\n", __FILE__, __LINE__, #cond); \ + abort(); \ + } +#else +#define ASSERT(cond, ret) if (!(cond)) return (ret) +#endif + +#endif diff --git a/src/fft_fr.c b/src/fft_fr.c index f9391f6..64ad8ea 100644 --- a/src/fft_fr.c +++ b/src/fft_fr.c @@ -50,10 +50,10 @@ void fft_fr_fast(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uin } // The main entry point for forward and reverse FFTs -void fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n) { +C_KZG_RET fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n) { uint64_t stride = fs->max_width / n; - assert(n <= fs->max_width); - assert(is_power_of_two(n)); + ASSERT(n <= fs->max_width, C_KZG_BADARGS); + ASSERT(is_power_of_two(n), C_KZG_BADARGS); if (inv) { blst_fr inv_len; fr_from_uint64(&inv_len, n); @@ -65,4 +65,5 @@ void fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n) { } else { fft_fr_fast(out, in, 1, fs->expanded_roots_of_unity, stride, fs->max_width); } + return C_KZG_SUCCESS; } diff --git a/src/fft_fr.h b/src/fft_fr.h index 48f96d7..03cb147 100644 --- a/src/fft_fr.h +++ b/src/fft_fr.h @@ -19,4 +19,4 @@ void fft_fr_slow(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); void fft_fr_fast(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); -void fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n); +C_KZG_RET fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n); diff --git a/src/fft_fr_test.c b/src/fft_fr_test.c index de45f7d..99ec96a 100644 --- a/src/fft_fr_test.c +++ b/src/fft_fr_test.c @@ -69,8 +69,8 @@ void roundtrip_fft(void) { } // Forward and reverse FFT - fft_fr(coeffs, data, &fs, false, fs.max_width); - fft_fr(data, coeffs, &fs, true, fs.max_width); + TEST_CHECK(fft_fr(coeffs, data, &fs, false, fs.max_width) == C_KZG_SUCCESS); + TEST_CHECK(fft_fr(data, coeffs, &fs, true, fs.max_width) == C_KZG_SUCCESS); // Verify that the result is still ascending values of i for (int i = 0; i < fs.max_width; i++) { @@ -91,7 +91,7 @@ void inverse_fft(void) { } // Inverst FFT - fft_fr(out, data, &fs, true, fs.max_width); + TEST_CHECK(fft_fr(out, data, &fs, true, fs.max_width) == C_KZG_SUCCESS); // Verify against the known result, `inv_fft_expected` int n = sizeof(inv_fft_expected) / sizeof(inv_fft_expected[0]); diff --git a/src/fft_g1.c b/src/fft_g1.c index c1ed208..afed84f 100644 --- a/src/fft_g1.c +++ b/src/fft_g1.c @@ -63,10 +63,10 @@ void fft_g1_fast(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uin } // The main entry point for forward and reverse FFTs -void fft_g1 (blst_p1 *out, blst_p1 *in, FFTSettings *fs, bool inv, uint64_t n) { +C_KZG_RET fft_g1 (blst_p1 *out, blst_p1 *in, FFTSettings *fs, bool inv, uint64_t n) { uint64_t stride = fs->max_width / n; - assert(n <= fs->max_width); - assert(is_power_of_two(n)); + ASSERT(n <= fs->max_width, C_KZG_BADARGS); + ASSERT(is_power_of_two(n), C_KZG_BADARGS); if (inv) { blst_fr inv_len; fr_from_uint64(&inv_len, n); @@ -78,4 +78,5 @@ void fft_g1 (blst_p1 *out, blst_p1 *in, FFTSettings *fs, bool inv, uint64_t n) { } else { fft_g1_fast(out, in, 1, fs->expanded_roots_of_unity, stride, fs->max_width); } + return C_KZG_SUCCESS; } diff --git a/src/fft_g1.h b/src/fft_g1.h index b5ffbd6..dde2bda 100644 --- a/src/fft_g1.h +++ b/src/fft_g1.h @@ -21,4 +21,4 @@ void p1_mul(blst_p1 *out, const blst_p1 *a, const blst_fr *b); void p1_sub(blst_p1 *out, const blst_p1 *a, const blst_p1 *b); void fft_g1_slow(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); void fft_g1_fast(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l); -void fft_g1 (blst_p1 *out, blst_p1 *in, FFTSettings *fs, bool inv, uint64_t n); +C_KZG_RET fft_g1 (blst_p1 *out, blst_p1 *in, FFTSettings *fs, bool inv, uint64_t n); diff --git a/src/fft_g1_test.c b/src/fft_g1_test.c index adc360e..35b74ad 100644 --- a/src/fft_g1_test.c +++ b/src/fft_g1_test.c @@ -84,8 +84,8 @@ void roundtrip_fft(void) { make_data(data, fs.max_width); // Forward and reverse FFT - fft_g1(coeffs, data, &fs, false, fs.max_width); - fft_g1(data, coeffs, &fs, true, fs.max_width); + TEST_CHECK(fft_g1(coeffs, data, &fs, false, fs.max_width) == C_KZG_SUCCESS); + TEST_CHECK(fft_g1(data, coeffs, &fs, true, fs.max_width) == C_KZG_SUCCESS); // Verify that the result is still ascending values of i for (int i = 0; i < fs.max_width; i++) { diff --git a/src/fft_util.c b/src/fft_util.c index b926f38..e273afb 100644 --- a/src/fft_util.c +++ b/src/fft_util.c @@ -39,6 +39,7 @@ blst_fr *expand_root_of_unity(blst_fr *root_of_unity, uint64_t width) { roots[1] = *root_of_unity; for (int i = 2; !is_one(&roots[i - 1]); i++) { + //ASSERT(i <= width, C_KZG_ERROR); assert(i <= width); blst_fr_mul(&roots[i], &roots[i - 1], root_of_unity); }