133 lines
5.7 KiB
Haskell

{-# LANGUAGE StrictData, ForeignFunctionInterface #-}
module NTT.FFT.Fast where
--------------------------------------------------------------------------------
import Data.Word
import Foreign.C
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Marshal
import System.IO.Unsafe
import Data.Flat
import NTT.Poly
import NTT.Subgroup
import Field.Goldilocks
import Misc
--------------------------------------------------------------------------------
{-
void goldilocks_ntt_forward ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_forward_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_inverse ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_inverse_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
-}
-- Raw bindings to the C NTT routines. The Haskell argument order follows the
-- C prototypes. All imports are @unsafe@ calls, so the C code must not call
-- back into Haskell.

-- | @goldilocks_ntt_forward(m, gen, src, tgt)@
foreign import ccall unsafe "goldilocks_ntt_forward"
  c_ntt_forward :: CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()

-- | @goldilocks_ntt_forward_shifted(eta, m, gen, src, tgt)@
foreign import ccall unsafe "goldilocks_ntt_forward_shifted"
  c_ntt_forward_shifted :: Word64 -> CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()

-- | @goldilocks_ntt_forward_asymmetric(m_src, m_tgt, gen_src, gen_tgt, src, tgt)@
foreign import ccall unsafe "goldilocks_ntt_forward_asymmetric"
  c_ntt_forward_asymmetric :: CInt -> CInt -> Word64 -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()

-- | @goldilocks_ntt_inverse(m, gen, src, tgt)@
foreign import ccall unsafe "goldilocks_ntt_inverse"
  c_ntt_inverse :: CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()

-- | @goldilocks_ntt_inverse_shifted(eta, m, gen, src, tgt)@
foreign import ccall unsafe "goldilocks_ntt_inverse_shifted"
  c_ntt_inverse_shifted :: Word64 -> CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
--------------------------------------------------------------------------------
{-# NOINLINE forwardNTT #-}
-- | Forward NTT, delegated to the C routine @goldilocks_ntt_forward@.
--
-- The coefficient array of the polynomial must have exactly
-- @subgroupSize sg@ elements; otherwise this calls 'error'. The result is a
-- freshly allocated array of the same length, filled in by the C routine
-- (presumably with the values of the polynomial on the subgroup -- confirm
-- against the C implementation).
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
forwardNTT sg (MkPoly (MkFlatArray len srcFPtr))
  | len == subgroupSize sg = unsafePerformIO $ do
      -- The target buffer is allocated here and only handed out after the
      -- C call has returned.
      tgtFPtr <- mallocForeignPtrArray len
      withForeignPtr srcFPtr $ \srcPtr ->
        withForeignPtr tgtFPtr $ \tgtPtr ->
          c_ntt_forward logSize gen srcPtr tgtPtr
      return (MkFlatArray len tgtFPtr)
  | otherwise = error "forwardNTT: subgroup size differs from the array size"
  where
    logSize = subgroupCLogSize sg
    gen     = subgroupGenAsWord64 sg
{-# NOINLINE inverseNTT #-}
-- | Inverse NTT, delegated to the C routine @goldilocks_ntt_inverse@.
--
-- The input array must have exactly @subgroupSize sg@ elements; otherwise
-- this calls 'error'. The output of the C routine is written into a freshly
-- allocated array of the same length and wrapped as a 'Poly'.
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
inverseNTT sg (MkFlatArray len srcFPtr)
  | len == subgroupSize sg = unsafePerformIO $ do
      tgtFPtr <- mallocForeignPtrArray len
      withForeignPtr srcFPtr $ \srcPtr ->
        withForeignPtr tgtFPtr $ \tgtPtr ->
          c_ntt_inverse logSize gen srcPtr tgtPtr
      return (MkPoly (MkFlatArray len tgtFPtr))
  | otherwise = error "inverseNTT: subgroup size differs from the array size"
  where
    logSize = subgroupCLogSize sg
    gen     = subgroupGenAsWord64 sg
{-# NOINLINE shiftedForwardNTT #-}
-- | Pre-multiplies the coefficients by powers of eta, effectively evaluating
-- @f(eta*x)@ on the subgroup.
--
-- Delegates to the C routine @goldilocks_ntt_forward_shifted@. The
-- coefficient array must have exactly @subgroupSize sg@ elements; otherwise
-- this calls 'error'. The result is a freshly allocated array of the same
-- length.
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
shiftedForwardNTT sg (MkGoldilocks eta) (MkPoly (MkFlatArray n fptr2))
  | subgroupSize sg /= n = error "shiftedForwardNTT: subgroup size differs from the array size"
  | otherwise = unsafePerformIO $ do
      -- The target buffer is allocated here and only handed out after the
      -- C call has returned.
      fptr3 <- mallocForeignPtrArray n
      withForeignPtr fptr2 $ \ptr2 ->
        withForeignPtr fptr3 $ \ptr3 ->
          c_ntt_forward_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3
      return (MkFlatArray n fptr3)
{-# NOINLINE shiftedInverseNTT #-}
-- | Inverse NTT followed by a multiplication of the coefficients by powers
-- of eta; effectively interpolates @f@ such that
-- @f(eta^-1 * omega^k) = y_k@.
--
-- Delegates to the C routine @goldilocks_ntt_inverse_shifted@. The input
-- array must have exactly @subgroupSize sg@ elements; otherwise this calls
-- 'error'. The result is a freshly allocated array of the same length,
-- wrapped as a 'Poly'.
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
shiftedInverseNTT sg (MkGoldilocks eta) (MkFlatArray len srcFPtr)
  | len == subgroupSize sg = unsafePerformIO $ do
      tgtFPtr <- mallocForeignPtrArray len
      withForeignPtr srcFPtr $ \srcPtr ->
        withForeignPtr tgtFPtr $ \tgtPtr ->
          c_ntt_inverse_shifted eta logSize gen srcPtr tgtPtr
      return (MkPoly (MkFlatArray len tgtFPtr))
  | otherwise = error "shiftedInverseNTT: subgroup size differs from the array size"
  where
    logSize = subgroupCLogSize sg
    gen     = subgroupGenAsWord64 sg
{-# NOINLINE asymmForwardNTT #-}
-- | Evaluates a polynomial @f@ on a larger subgroup than it's defined on.
--
-- Delegates to the C routine @goldilocks_ntt_forward_asymmetric@. Calls
-- 'error' if the coefficient array does not have exactly
-- @subgroupSize sg_src@ elements, or if the target subgroup is smaller than
-- the source subgroup (equal sizes are accepted). The result is a freshly
-- allocated array with @subgroupSize sg_tgt@ elements.
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
asymmForwardNTT sg_src (MkPoly (MkFlatArray n fptr2)) sg_tgt
  | subgroupSize sg_src /= n = error "asymmForwardNTT: subgroup size differs from the array size"
  | m < n = error "asymmForwardNTT: target subgroup size should be at least the source subgroup size"
  | otherwise = unsafePerformIO $ do
      -- The output buffer is sized for the *target* subgroup.
      fptr3 <- mallocForeignPtrArray m
      withForeignPtr fptr2 $ \ptr2 ->
        withForeignPtr fptr3 $ \ptr3 ->
          c_ntt_forward_asymmetric
            (subgroupCLogSize sg_src)
            (subgroupCLogSize sg_tgt)
            (subgroupGenAsWord64 sg_src)
            (subgroupGenAsWord64 sg_tgt)
            ptr2
            ptr3
      return (MkFlatArray m fptr3)
  where
    m = subgroupSize sg_tgt
{-
instance P.UnivariateFFT Poly where
ntt = forwardNTT
intt = inverseNTT
shiftedNTT = shiftedForwardNTT
shiftedINTT = shiftedInverseNTT
asymmNTT = asymmForwardNTT
-}
--------------------------------------------------------------------------------
-- | The subgroup's log-size (as returned by 'subgroupLogSize' and unwrapped
-- with 'fromLog2'), converted to a 'CInt' so it can be passed to the C
-- routines.
subgroupCLogSize :: Subgroup a -> CInt
subgroupCLogSize sg = fromIntegral (fromLog2 (subgroupLogSize sg))
-- | The subgroup generator with the 'MkGoldilocks' wrapper removed, i.e. the
-- underlying 'Word64' that is handed to the C routines.
subgroupGenAsWord64 :: Subgroup F -> Word64
subgroupGenAsWord64 sg =
  case subgroupGen sg of
    MkGoldilocks w -> w
--------------------------------------------------------------------------------