diff --git a/reference/src/cbits/ntt.c b/reference/src/cbits/ntt.c index 1b94b1e..929b937 100644 --- a/reference/src/cbits/ntt.c +++ b/reference/src/cbits/ntt.c @@ -6,41 +6,57 @@ #include #include "goldilocks.h" +#include "short_dft.h" #include "ntt.h" // ----------------------------------------------------------------------------- void goldilocks_ntt_forward_noalloc(int m, int src_stride, const uint64_t *gpows, const uint64_t *src, uint64_t *buf, uint64_t *tgt) { - if (m==0) { - tgt[0] = src[0]; - return; - } + switch(m) { - if (m==1) { - // N = 2 - tgt[0] = goldilocks_add( src[0] , src[src_stride] ); // x + y - tgt[1] = goldilocks_sub( src[0] , src[src_stride] ); // x - y - return; - } + case 4: + short_fwd_DFT_size_16( src_stride, 1, src, tgt ); + break; - else { - int N = (1<< m ); - int halfN = (1<<(m-1)); + case 3: + short_fwd_DFT_size_8( src_stride, 1, src, tgt ); + break; - goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src , buf + N , buf ); - goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src + src_stride , buf + N , buf + halfN ); + case 2: + short_fwd_DFT_size_4( src_stride, 1, src, tgt ); + break; - for(int j=0; j