Implement FFT with G1 elements

This commit is contained in:
Ben Edgington 2021-02-02 23:06:42 +00:00
parent 4ee33f7aba
commit ae43c79f3f
10 changed files with 282 additions and 40 deletions

View File

@ -6,6 +6,7 @@ Initially, at least, this largely follows the [go-kzg](https://github.com/protol
Done so far:
- Rough and ready FFT and inverse FFT over the finite field.
- Ditto for FFTs over the G1 group
## Installation

View File

@ -1,4 +1,4 @@
tests = fft_fr_test fft_util_test
tests = fft_util_test fft_fr_test fft_g1_test
.PRECIOUS: %.o
@ -9,6 +9,10 @@ fft_fr_test: fft_fr.o fft_fr_test.c test_util.o fft_util.o
clang -Wall -o $@ $@.c test_util.o fft_fr.o fft_util.o -L../lib -lblst
./$@
fft_g1_test: fft_g1.o fft_g1_test.c test_util.o fft_util.o
clang -Wall -o $@ $@.c test_util.o fft_g1.o fft_util.o -L../lib -lblst
./$@
%_test: %.o %_test.c test_util.o
clang -Wall -o $@ $@.c test_util.o $*.o -L../lib -lblst
./$@

View File

@ -17,44 +17,43 @@
#include "fft_fr.h"
// Slow Fourier Transform (simple, good for small sizes)
void fft_fr_slow(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
blst_fr v, last, tmp;
void fft_fr_slow(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
blst_fr v, last, jv, r;
for (uint64_t i = 0; i < l; i++) {
blst_fr jv = in[offset];
blst_fr r = roots[0];
blst_fr_mul(&v, &jv, &r);
last = v;
blst_fr_mul(&last, &in[0], &roots[0]);
for (uint64_t j = 1; j < l; j++) {
jv = in[offset + j * stride];
jv = in[j * stride];
r = roots[((i * j) % l) * roots_stride];
blst_fr_mul(&v, &jv, &r);
tmp = last;
blst_fr_add(&last, &tmp, &v);
blst_fr_add(&last, &last, &v);
}
out[i] = last;
}
}
// Fast Fourier Transform
void fft_fr_fast(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
void fft_fr_fast(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
uint64_t half = l / 2;
fft_fr_helper(out, in, offset, stride * 2, roots, roots_stride * 2, l / 2);
fft_fr_helper(out + half, in, offset + stride, stride * 2, roots, roots_stride * 2, l / 2);
for (uint64_t i = 0; i < half; i++) {
blst_fr y_times_root;
blst_fr x = out[i];
blst_fr_mul(&y_times_root, &out[i + half], &roots[i * roots_stride]);
blst_fr_add(&out[i], &x, &y_times_root);
blst_fr_sub(&out[i + half], &x, &y_times_root);
if (half > 0) {
fft_fr_helper(out, in, stride * 2, roots, roots_stride * 2, half);
fft_fr_helper(out + half, in + stride, stride * 2, roots, roots_stride * 2, half);
for (uint64_t i = 0; i < half; i++) {
blst_fr y_times_root;
blst_fr x = out[i];
blst_fr_mul(&y_times_root, &out[i + half], &roots[i * roots_stride]);
blst_fr_add(&out[i], &x, &y_times_root);
blst_fr_sub(&out[i + half], &x, &y_times_root);
}
} else {
blst_fr_mul(out, in, roots);
}
}
void fft_fr_helper(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
// TODO: Tunable parameter
if (l <= 4) {
fft_fr_slow(out, in, offset, stride, roots, roots_stride, l);
void fft_fr_helper(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
if (l <= 4) { // TODO: Tunable parameter
fft_fr_slow(out, in, stride, roots, roots_stride, l);
} else {
fft_fr_fast(out, in, offset, stride, roots, roots_stride, l);
fft_fr_fast(out, in, stride, roots, roots_stride, l);
}
}
@ -67,11 +66,11 @@ void fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n) {
blst_fr inv_len;
fr_from_uint64(&inv_len, n);
blst_fr_eucl_inverse(&inv_len, &inv_len);
fft_fr_helper(out, in, 0, 1, fs->reverse_roots_of_unity, stride, fs->max_width);
fft_fr_helper(out, in, 1, fs->reverse_roots_of_unity, stride, fs->max_width);
for (uint64_t i = 0; i < fs->max_width; i++) {
blst_fr_mul(&out[i], &out[i], &inv_len);
}
} else {
fft_fr_helper(out, in, 0, 1, fs->expanded_roots_of_unity, stride, fs->max_width);
fft_fr_helper(out, in, 1, fs->expanded_roots_of_unity, stride, fs->max_width);
}
}

View File

@ -17,7 +17,7 @@
#include "c-kzg.h"
#include "fft_util.h"
void fft_fr_slow(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_fr_fast(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_fr_helper(blst_fr *out, blst_fr *in, uint64_t offset, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_fr_slow(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_fr_fast(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_fr_helper(blst_fr *out, blst_fr *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_fr (blst_fr *out, blst_fr *in, FFTSettings *fs, bool inv, uint64_t n);

View File

@ -40,7 +40,7 @@ const uint64_t inv_fft_expected[][4] =
void compare_sft_fft(void) {
// Initialise: ascending values of i (could be anything), and arbitrary size
unsigned int size = 8;
unsigned int size = 12;
FFTSettings fs = new_fft_settings(size);
blst_fr data[fs.max_width], out0[fs.max_width], out1[fs.max_width];
for (int i = 0; i < fs.max_width; i++) {
@ -48,8 +48,8 @@ void compare_sft_fft(void) {
}
// Do both fast and slow transforms
fft_fr_slow(out0, data, 0, 1, fs.expanded_roots_of_unity, 1, fs.max_width);
fft_fr_fast(out1, data, 0, 1, fs.expanded_roots_of_unity, 1, fs.max_width);
fft_fr_slow(out0, data, 1, fs.expanded_roots_of_unity, 1, fs.max_width);
fft_fr_fast(out1, data, 1, fs.expanded_roots_of_unity, 1, fs.max_width);
// Verify the results are identical
for (int i = 0; i < fs.max_width; i++) {

91
src/fft_g1.c Normal file
View File

@ -0,0 +1,91 @@
/*
* Copyright 2021 Benjamin Edgington
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fft_g1.h"
#include "test_util.h"
void p1_mul(blst_p1 *out, const blst_p1 *a, const blst_fr *b) {
blst_scalar s;
blst_scalar_from_fr(&s, b);
//blst_p1_mult(out, a, s.b, 8 * sizeof(blst_scalar));
blst_p1_mult(out, a, s.b, 256);
}
void p1_sub(blst_p1 *out, const blst_p1 *a, const blst_p1 *b) {
blst_p1 bneg = *b;
blst_p1_cneg(&bneg, true);
blst_p1_add_or_double(out, a, &bneg);
}
// Slow Fourier Transform (simple, good for small sizes)
void fft_g1_slow(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
blst_p1 v, last, jv;
blst_fr r;
for (uint64_t i = 0; i < l; i++) {
p1_mul(&last, &in[0], &roots[0]);
for (uint64_t j = 1; j < l; j++) {
jv = in[j * stride];
r = roots[((i * j) % l) * roots_stride];
p1_mul(&v, &jv, &r);
blst_p1_add_or_double(&last, &last, &v);
}
out[i] = last;
}
}
// Fast Fourier Transform
void fft_g1_fast(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
uint64_t half = l / 2;
if (half > 0) {
fft_g1_helper(out, in, stride * 2, roots, roots_stride * 2, half);
fft_g1_helper(out + half, in + stride, stride * 2, roots, roots_stride * 2, half);
for (uint64_t i = 0; i < half; i++) {
blst_p1 y_times_root;
blst_p1 x = out[i];
p1_mul(&y_times_root, &out[i + half], &roots[i * roots_stride]);
blst_p1_add_or_double(&out[i], &x, &y_times_root);
p1_sub(&out[i + half], &x, &y_times_root);
}
} else {
p1_mul(out, in, roots);
}
}
void fft_g1_helper(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l) {
if (l <= 4) { // TODO: Tunable parameter
fft_g1_slow(out, in, stride, roots, roots_stride, l);
} else {
fft_g1_fast(out, in, stride, roots, roots_stride, l);
}
}
// The main entry point for forward and reverse FFTs
void fft_g1 (blst_p1 *out, blst_p1 *in, FFTSettings *fs, bool inv, uint64_t n) {
uint64_t stride = fs->max_width / n;
assert(n <= fs->max_width);
assert(is_power_of_two(n));
if (inv) {
blst_fr inv_len;
fr_from_uint64(&inv_len, n);
blst_fr_eucl_inverse(&inv_len, &inv_len);
fft_g1_helper(out, in, 1, fs->reverse_roots_of_unity, stride, fs->max_width);
for (uint64_t i = 0; i < fs->max_width; i++) {
p1_mul(&out[i], &out[i], &inv_len);
}
} else {
fft_g1_helper(out, in, 1, fs->expanded_roots_of_unity, stride, fs->max_width);
}
}

25
src/fft_g1.h Normal file
View File

@ -0,0 +1,25 @@
/*
* Copyright 2021 Benjamin Edgington
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c-kzg.h"
#include "fft_util.h"
void p1_mul(blst_p1 *out, const blst_p1 *a, const blst_fr *b);
void p1_sub(blst_p1 *out, const blst_p1 *a, const blst_p1 *b);
void fft_g1_slow(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_g1_fast(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_g1_helper(blst_p1 *out, blst_p1 *in, uint64_t stride, blst_fr *roots, uint64_t roots_stride, uint64_t l);
void fft_g1 (blst_p1 *out, blst_p1 *in, FFTSettings *fs, bool inv, uint64_t n);

108
src/fft_g1_test.c Normal file
View File

@ -0,0 +1,108 @@
/*
* Copyright 2021 Benjamin Edgington
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../inc/acutest.h"
#include "test_util.h"
#include "fft_g1.h"
// The G1 subgroup size minus 1
const uint64_t r_minus_1[] = {0xffffffff00000000L, 0x53bda402fffe5bfeL, 0x3339d80809a1d805L, 0x73eda753299d7d48L};
void make_data(blst_p1 *out, uint64_t n) {
// Multiples of g1_gen
assert(n > 0);
blst_p1_from_affine(out + 0, &BLS12_381_G1);
for (int i = 1; i < n; i++) {
blst_p1_add_affine(out + i, out + i - 1, &BLS12_381_G1);
}
}
void p1_mul_works(void) {
blst_fr rm1;
blst_p1 g1_gen, g1_gen_neg, res;
// Multiply the generator by the group order minus one
blst_p1_from_affine(&g1_gen, &BLS12_381_G1);
blst_fr_from_uint64(&rm1, r_minus_1);
p1_mul(&res, &g1_gen, &rm1);
// We should end up with negative the generator
blst_p1_from_affine(&g1_gen_neg, &BLS12_381_NEG_G1);
TEST_CHECK(blst_p1_is_equal(&res, &g1_gen_neg));
}
void p1_sub_works(void) {
blst_p1 g1_gen, g1_gen_neg;
blst_p1 tmp, res;
blst_p1_from_affine(&g1_gen, &BLS12_381_G1);
blst_p1_from_affine(&g1_gen_neg, &BLS12_381_NEG_G1);
// 2 * g1_gen = g1_gen - g1_gen_neg
blst_p1_double(&tmp, &g1_gen);
p1_sub(&res, &g1_gen, &g1_gen_neg);
TEST_CHECK(blst_p1_is_equal(&tmp, &res));
}
void compare_sft_fft(void) {
// Initialise: arbitrary size
unsigned int size = 6;
FFTSettings fs = new_fft_settings(size);
blst_p1 data[fs.max_width], slow[fs.max_width], fast[fs.max_width];
make_data(data, fs.max_width);
// Do both fast and slow transforms
fft_g1_slow(slow, data, 1, fs.expanded_roots_of_unity, 1, fs.max_width);
fft_g1_fast(fast, data, 1, fs.expanded_roots_of_unity, 1, fs.max_width);
// Verify the results are identical
for (int i = 0; i < fs.max_width; i++) {
TEST_CHECK(blst_p1_is_equal(slow + i, fast + i));
}
free_fft_settings(&fs);
}
void roundtrip_fft(void) {
// Initialise: arbitrary size
unsigned int size = 10;
FFTSettings fs = new_fft_settings(size);
blst_p1 expected[fs.max_width], data[fs.max_width], coeffs[fs.max_width];
make_data(expected, fs.max_width);
make_data(data, fs.max_width);
// Forward and reverse FFT
fft_g1(coeffs, data, &fs, false, fs.max_width);
fft_g1(data, coeffs, &fs, true, fs.max_width);
// Verify that the result is still ascending values of i
for (int i = 0; i < fs.max_width; i++) {
TEST_CHECK(blst_p1_is_equal(expected + i, data + i));
}
free_fft_settings(&fs);
}
TEST_LIST =
{
{"p1_mul_works", p1_mul_works},
{"p1_sub_works", p1_sub_works},
{"compare_sft_fft", compare_sft_fft},
{"roundtrip_fft", roundtrip_fft},
{ NULL, NULL } /* zero record marks the end of the list */
};

View File

@ -56,16 +56,28 @@ bool fr_equal(blst_fr *aa, blst_fr *bb) {
// G1 and G2 utilities
//
void print_p1_bytes(byte p1[96]) {
printf("[0x");
print_bytes_as_hex(p1, 0, 48);
printf(",0x");
print_bytes_as_hex(p1, 48, 48);
printf("]\n");
}
/* "Pretty" print a point in G1 */
void print_p1(const blst_p1 *p1) {
byte *p1_bytes = (byte *)malloc(96);
blst_p1_serialize(p1_bytes, p1);
print_p1_bytes(p1_bytes);
free(p1_bytes);
}
/* "Pretty" print an affine point in G1 */
void print_p1_affine(const blst_p1_affine *p1) {
byte *p1_hex = (byte *)malloc(96);
blst_p1_affine_serialize(p1_hex, p1);
printf("[0x");
print_bytes_as_hex(p1_hex, 0, 48);
printf(",0x");
print_bytes_as_hex(p1_hex, 48, 48);
printf("]\n");
free(p1_hex);
byte *p1_bytes = (byte *)malloc(96);
blst_p1_affine_serialize(p1_bytes, p1);
print_p1_bytes(p1_bytes);
free(p1_bytes);
}
/* "Pretty" print an affine point in G2 */

View File

@ -27,5 +27,7 @@ void print_fr(const blst_fr *a);
bool fr_equal(blst_fr *aa, blst_fr *bb);
// G1 and G2 utilities
void print_p1_bytes(byte p1[96]);
void print_p1(const blst_p1 *p1);
void print_p1_affine(const blst_p1_affine *p1);
void print_p2_affine(const blst_p2_affine *p2);