mirror of
https://github.com/logos-storage/outsourcing-Reed-Solomon.git
synced 2026-01-02 13:43:07 +00:00
specialized algorithms for short NTT/INTT (size 4, 8, 16)
This commit is contained in:
parent
bd888d5b57
commit
0eb39eb5c9
590
reference/src/NTT/FFT/Short.hs
Normal file
590
reference/src/NTT/FFT/Short.hs
Normal file
@ -0,0 +1,590 @@
|
||||
|
||||
-- Short DFT algorithms (sizes 2, 4, 8 and 16)
|
||||
--
|
||||
-- See :
|
||||
--
|
||||
-- * Nussbaumer: "Fast Fourier Transform and Convolution Algorithms", Chapter 5.5
|
||||
--
|
||||
|
||||
{-# LANGUAGE ForeignFunctionInterface #-}
|
||||
module NTT.FFT.Short where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Array
|
||||
import Data.Word
|
||||
import Data.List ( sort )
|
||||
|
||||
import Control.Monad
|
||||
|
||||
import Foreign.C
|
||||
import Foreign.Ptr
|
||||
import Foreign.Marshal
|
||||
import System.IO.Unsafe
|
||||
|
||||
import Class.Field
|
||||
import Field.Goldilocks ( F , fromF )
|
||||
|
||||
import NTT.FFT.Slow
|
||||
import NTT.Poly.Naive
|
||||
|
||||
import NTT.Subgroup
|
||||
import Misc
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- DFT mathematical definition - O(n^2)
|
||||
|
||||
defForwardDFT :: Subgroup F -> Array Int F -> Array Int F
|
||||
defForwardDFT sg input
|
||||
| subgroupOrder sg /= n = error "defForwardDFT: input size does not match the subgroup"
|
||||
| otherwise = listToArray [ f k | k <-[0..n-1] ]
|
||||
where
|
||||
n = arrayLength input
|
||||
omega = subgroupGen sg
|
||||
f k = sum [ input!i * power_ omega (k*i) | i<-[0..n-1] ]
|
||||
|
||||
defInverseDFT :: Subgroup F -> Array Int F -> Array Int F
|
||||
defInverseDFT sg input
|
||||
| subgroupOrder sg /= n = error "defInverseDFT: input size does not match the subgroup"
|
||||
| otherwise = listToArray [ f k / (fromIntegral n) | k <-[0..n-1] ]
|
||||
where
|
||||
n = arrayLength input
|
||||
omega = subgroupGen sg
|
||||
invomega = inverse omega
|
||||
f k = sum [ input!i * power_ invomega (k*i) | i<-[0..n-1] ]
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
foreign import ccall unsafe "short_fwd_DFT_size_4" c_dft4_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
foreign import ccall unsafe "short_inv_DFT_size_4_unscaled" c_dft4_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
foreign import ccall unsafe "short_inv_DFT_size_4_rescaled" c_dft4_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
|
||||
foreign import ccall unsafe "short_fwd_DFT_size_8" c_dft8_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
foreign import ccall unsafe "short_inv_DFT_size_8_unscaled" c_dft8_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
foreign import ccall unsafe "short_inv_DFT_size_8_rescaled" c_dft8_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
|
||||
foreign import ccall unsafe "short_fwd_DFT_size_16" c_dft16_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
foreign import ccall unsafe "short_inv_DFT_size_16_unscaled" c_dft16_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
foreign import ccall unsafe "short_inv_DFT_size_16_rescaled" c_dft16_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
|
||||
----------------------------------------
|
||||
|
||||
{-# NOINLINE shortForwardDFT4 #-}
|
||||
shortForwardDFT4 :: Array Int F -> Array Int F
|
||||
shortForwardDFT4 input
|
||||
| arrayLength input /= 4 = error "shortForwardDFT4: expecting an input array of size 4"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
allocaArray 4 $ \(ptr1 :: Ptr F) -> do
|
||||
allocaArray 4 $ \(ptr2 :: Ptr F) -> do
|
||||
pokeArray ptr1 (elems input)
|
||||
c_dft4_fwd 1 1 (castPtr ptr1) (castPtr ptr2)
|
||||
ys <- peekArray 4 ptr2
|
||||
return $ listToArray ys
|
||||
|
||||
{-# NOINLINE shortInverseDFT4 #-}
|
||||
shortInverseDFT4 :: Array Int F -> Array Int F
|
||||
shortInverseDFT4 input
|
||||
| arrayLength input /= 4 = error "shortInverseDFT4: expecting an input array of size 4"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
allocaArray 4 $ \(ptr1 :: Ptr F) -> do
|
||||
allocaArray 4 $ \(ptr2 :: Ptr F) -> do
|
||||
pokeArray ptr1 (elems input)
|
||||
c_dft4_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2)
|
||||
ys <- peekArray 4 ptr2
|
||||
return $ listToArray ys
|
||||
|
||||
----------------------------------------
|
||||
|
||||
{-# NOINLINE shortForwardDFT8 #-}
|
||||
shortForwardDFT8 :: Array Int F -> Array Int F
|
||||
shortForwardDFT8 input
|
||||
| arrayLength input /= 8 = error "shortForwardDFT8: expecting an input array of size 8"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
allocaArray 8 $ \(ptr1 :: Ptr F) -> do
|
||||
allocaArray 8 $ \(ptr2 :: Ptr F) -> do
|
||||
pokeArray ptr1 (elems input)
|
||||
c_dft8_fwd 1 1 (castPtr ptr1) (castPtr ptr2)
|
||||
ys <- peekArray 8 ptr2
|
||||
return $ listToArray ys
|
||||
|
||||
{-# NOINLINE shortInverseDFT8 #-}
|
||||
shortInverseDFT8 :: Array Int F -> Array Int F
|
||||
shortInverseDFT8 input
|
||||
| arrayLength input /= 8 = error "shortInverseDFT8: expecting an input array of size 8"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
allocaArray 8 $ \(ptr1 :: Ptr F) -> do
|
||||
allocaArray 8 $ \(ptr2 :: Ptr F) -> do
|
||||
pokeArray ptr1 (elems input)
|
||||
c_dft8_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2)
|
||||
ys <- peekArray 8 ptr2
|
||||
return $ listToArray ys
|
||||
|
||||
----------------------------------------
|
||||
|
||||
{-# NOINLINE shortForwardDFT16 #-}
|
||||
shortForwardDFT16 :: Array Int F -> Array Int F
|
||||
shortForwardDFT16 input
|
||||
| arrayLength input /= 16 = error "shortForwardDFT16: expecting an input array of size 16"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
allocaArray 16 $ \(ptr1 :: Ptr F) -> do
|
||||
allocaArray 16 $ \(ptr2 :: Ptr F) -> do
|
||||
pokeArray ptr1 (elems input)
|
||||
c_dft16_fwd 1 1 (castPtr ptr1) (castPtr ptr2)
|
||||
ys <- peekArray 16 ptr2
|
||||
return $ listToArray ys
|
||||
|
||||
{-# NOINLINE shortInverseDFT16 #-}
|
||||
shortInverseDFT16 :: Array Int F -> Array Int F
|
||||
shortInverseDFT16 input
|
||||
| arrayLength input /= 16 = error "shortInverseDFT16: expecting an input array of size 16"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
allocaArray 16 $ \(ptr1 :: Ptr F) -> do
|
||||
allocaArray 16 $ \(ptr2 :: Ptr F) -> do
|
||||
pokeArray ptr1 (elems input)
|
||||
c_dft16_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2)
|
||||
ys <- peekArray 16 ptr2
|
||||
return $ listToArray ys
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
testDFT4 :: IO ()
|
||||
testDFT4 = do
|
||||
let sg = getSubgroup (Log2 2)
|
||||
xs <- listToArray <$> (replicateM 4 rndIO :: IO [F])
|
||||
let ys0 = defForwardDFT sg xs
|
||||
let ys1 = forwardNTT sg (Poly xs)
|
||||
let ys2 = shortForwardDFT4 xs
|
||||
print $ "xs = " ++ show (elems xs )
|
||||
print $ "ys0 (defin) = " ++ show (elems ys0)
|
||||
print $ "ys1 (ref) = " ++ show (elems ys1)
|
||||
print $ "ys2 (short) = " ++ show (elems ys2)
|
||||
print $ "ok: " ++ show (ys1 == ys2)
|
||||
|
||||
testIDFT4 :: IO ()
|
||||
testIDFT4 = do
|
||||
let sg = getSubgroup (Log2 2)
|
||||
xs <- listToArray <$> (replicateM 4 rndIO :: IO [F])
|
||||
let ys0 = defInverseDFT sg xs
|
||||
let ys1 = fromPoly (inverseNTT sg xs)
|
||||
let ys2 = shortInverseDFT4 xs
|
||||
print $ "xs = " ++ show (elems xs )
|
||||
print $ "ys0 (defin) = " ++ show (elems ys0)
|
||||
print $ "ys1 (ref) = " ++ show (elems ys1)
|
||||
print $ "ys2 (short) = " ++ show (elems ys2)
|
||||
print $ "ok: " ++ show (ys1 == ys2)
|
||||
|
||||
----------------------------------------
|
||||
|
||||
testDFT8 :: IO ()
|
||||
testDFT8 = do
|
||||
let sg = getSubgroup (Log2 3)
|
||||
xs <- listToArray <$> (replicateM 8 rndIO :: IO [F])
|
||||
let ys0 = defForwardDFT sg xs
|
||||
let ys1 = forwardNTT sg (Poly xs)
|
||||
let ys2 = shortForwardDFT8 xs
|
||||
print $ "xs = " ++ show (elems xs )
|
||||
print $ "ys0 (defin) = " ++ show (elems ys0)
|
||||
print $ "ys1 (ref) = " ++ show (elems ys1)
|
||||
print $ "ys2 (short) = " ++ show (elems ys2)
|
||||
print $ "ok: " ++ show (ys1 == ys2)
|
||||
|
||||
testIDFT8 :: IO ()
|
||||
testIDFT8 = do
|
||||
let sg = getSubgroup (Log2 3)
|
||||
xs <- listToArray <$> (replicateM 8 rndIO :: IO [F])
|
||||
let ys0 = defInverseDFT sg xs
|
||||
let ys1 = fromPoly (inverseNTT sg xs)
|
||||
let ys2 = shortInverseDFT8 xs
|
||||
print $ "xs = " ++ show (elems xs )
|
||||
print $ "ys0 (defin) = " ++ show (elems ys0)
|
||||
print $ "ys1 (ref) = " ++ show (elems ys1)
|
||||
print $ "ys2 (short) = " ++ show (elems ys2)
|
||||
print $ "ok: " ++ show (ys1 == ys2)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
-- tmp for debugging
|
||||
instance Ord F where
|
||||
compare a b = compare (fromF a) (fromF b)
|
||||
|
||||
testDFT16 :: IO ()
|
||||
testDFT16 = do
|
||||
let sg = getSubgroup (Log2 4)
|
||||
xs <- listToArray <$> (replicateM 16 rndIO :: IO [F])
|
||||
let ys0 = defForwardDFT sg xs
|
||||
let ys1 = forwardNTT sg (Poly xs)
|
||||
let ys2 = experimental_DFT16 xs
|
||||
let ys3 = shortForwardDFT16 xs
|
||||
print $ "xs = " ++ show (elems xs )
|
||||
print $ "ys0 (defin) = " ++ show (elems ys0)
|
||||
print $ "ys1 (ref) = " ++ show (elems ys1)
|
||||
print $ "ys2 (hsexp) = " ++ show (elems ys2)
|
||||
print $ "ys3 (short) = " ++ show (elems ys3)
|
||||
print $ "ok-def: " ++ show (ys0 == ys1)
|
||||
print $ "ok-hs: " ++ show (ys1 == ys2)
|
||||
print $ "ok-hs (modulo order): " ++ show (sort (elems ys1) == sort (elems ys2))
|
||||
print $ "ok-C: " ++ show (ys1 == ys3)
|
||||
|
||||
testIDFT16 :: IO ()
|
||||
testIDFT16 = do
|
||||
let sg = getSubgroup (Log2 4)
|
||||
xs <- listToArray <$> (replicateM 16 rndIO :: IO [F])
|
||||
let ys0 = defInverseDFT sg xs
|
||||
let ys1 = fromPoly (inverseNTT sg xs)
|
||||
let ys2 = experimental_IDFT16 xs
|
||||
let ys3 = shortInverseDFT16 xs
|
||||
print $ "xs = " ++ show (elems xs )
|
||||
print $ "ys0 (defin) = " ++ show (elems ys0)
|
||||
print $ "ys1 (ref) = " ++ show (elems ys1)
|
||||
print $ "ys2 (hsexp) = " ++ show (elems ys2)
|
||||
print $ "ys3 (short) = " ++ show (elems ys3)
|
||||
print $ "ok-def: " ++ show (ys0 == ys1)
|
||||
print $ "ok-hs: " ++ show (ys1 == ys2)
|
||||
print $ "ok-C: " ++ show (ys1 == ys3)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
dft2_omega, dft4_omega, dft8_omega, dft16_omega :: F
|
||||
dft2_omega = subgroupGen $ getSubgroup (Log2 1)
|
||||
dft4_omega = subgroupGen $ getSubgroup (Log2 2)
|
||||
dft8_omega = subgroupGen $ getSubgroup (Log2 3)
|
||||
dft16_omega = subgroupGen $ getSubgroup (Log2 4)
|
||||
|
||||
dft2_inv_omega = inverse dft2_omega
|
||||
dft4_inv_omega = inverse dft4_omega
|
||||
dft8_inv_omega = inverse dft8_omega
|
||||
dft16_inv_omega = inverse dft16_omega
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Size = 4
|
||||
|
||||
dft4_j = dft4_omega
|
||||
idft4_inv_4 = 1 / (4 :: F)
|
||||
|
||||
printIDFT4 = do
|
||||
putStrLn $ "const uint64_t IDFT4_OMEGA = " ++ show dft4_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT4_INV_OMEGA = " ++ show dft4_inv_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT4_J = " ++ show dft4_j ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT4_INV_4 = " ++ show idft4_inv_4 ++ " ;"
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Size = 8
|
||||
|
||||
dft8_j = square dft8_omega
|
||||
dft8_cos_u = - (dft8_omega + dft8_omega^7) / 2
|
||||
dft8_j_sin_u = - (dft8_omega - dft8_omega^7) / 2
|
||||
dft8_minus_j_sin_u = - dft8_j_sin_u
|
||||
|
||||
printDFT8 = do
|
||||
putStrLn $ "const uint64_t DFT8_OMEGA = " ++ show dft8_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT8_INV_OMEGA = " ++ show dft8_inv_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT8_J = " ++ show dft8_j ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT8_COS_U = " ++ show dft8_cos_u ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT8_MINUS_J_SIN_U = " ++ show dft8_minus_j_sin_u ++ " ;"
|
||||
|
||||
idft8_j = square dft8_omega
|
||||
idft8_cos_u = (dft8_omega + dft8_omega^7) / 2
|
||||
idft8_j_sin_u = (dft8_omega - dft8_omega^7) / 2
|
||||
idft8_minus_j_sin_u = - idft8_j_sin_u
|
||||
idft8_inv_8 = 1 / (8 :: F)
|
||||
|
||||
printIDFT8 = do
|
||||
putStrLn $ "const uint64_t IDFT8_OMEGA = " ++ show dft8_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT8_INV_OMEGA = " ++ show dft8_inv_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT8_J = " ++ show idft8_j ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT8_COS_U = " ++ show idft8_cos_u ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT8_MINUS_J_SIN_U = " ++ show idft8_minus_j_sin_u ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT8_INV_8 = " ++ show idft8_inv_8 ++ " ;"
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Size = 16
|
||||
|
||||
dft16_j = square (square dft16_omega)
|
||||
|
||||
dft16_cos_u = (dft16_omega + dft16_omega^15) / 2
|
||||
dft16_cos_2u = (dft16_omega^2 + dft16_omega^14) / 2
|
||||
dft16_cos_3u = (dft16_omega^3 + dft16_omega^13) / 2
|
||||
|
||||
dft16_j_sin_u = (dft16_omega - dft16_omega^15) / 2 -- ????
|
||||
dft16_j_sin_2u = (dft16_omega^2 - dft16_omega^14) / 2 -- but it seems to work...
|
||||
dft16_j_sin_3u = (dft16_omega^3 - dft16_omega^13) / 2
|
||||
|
||||
printDFT16 :: IO ()
|
||||
printDFT16 = do
|
||||
putStrLn $ "const uint64_t DFT16_OMEGA = " ++ show dft16_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_INV_OMEGA = " ++ show dft16_inv_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_J = " ++ show dft16_j ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_COS_U = " ++ show dft16_cos_u ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_COS_2U = " ++ show dft16_cos_2u ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_COS_3U = " ++ show dft16_cos_3u ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_U = " ++ show (- dft16_j_sin_u ) ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_2U = " ++ show (- dft16_j_sin_2u) ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_3U = " ++ show (- dft16_j_sin_3u) ++ " ;"
|
||||
|
||||
putStrLn $ "const uint64_t DFT16_COS_3U_PLUS_U = " ++ show ( dft16_cos_3u + dft16_cos_u ) ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_COS_3U_MINUS_U = " ++ show ( dft16_cos_3u - dft16_cos_u ) ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_J_SIN_3U_MINUS_U = " ++ show ( dft16_j_sin_3u - dft16_j_sin_u) ++ " ;"
|
||||
putStrLn $ "const uint64_t DFT16_J_SIN_MINUS_3U_MINUS_U = " ++ show (- dft16_j_sin_3u - dft16_j_sin_u) ++ " ;"
|
||||
|
||||
--------------------
|
||||
|
||||
idft16_inv_16 = 1 / (16 :: F)
|
||||
idft16_j = square (square dft16_omega)
|
||||
|
||||
idft16_cos_u = (dft16_omega + dft16_omega^15) / 2
|
||||
idft16_cos_2u = (dft16_omega^2 + dft16_omega^14) / 2
|
||||
idft16_cos_3u = (dft16_omega^3 + dft16_omega^13) / 2
|
||||
|
||||
idft16_j_sin_u = (dft16_omega - dft16_omega^15) / 2
|
||||
idft16_j_sin_2u = (dft16_omega^2 - dft16_omega^14) / 2
|
||||
idft16_j_sin_3u = (dft16_omega^3 - dft16_omega^13) / 2
|
||||
|
||||
printIDFT16 :: IO ()
|
||||
printIDFT16 = do
|
||||
putStrLn $ "const uint64_t IDFT16_OMEGA = " ++ show dft16_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_INV_OMEGA = " ++ show dft16_inv_omega ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_INV_16 = " ++ show idft16_inv_16 ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_J = " ++ show idft16_j ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_COS_U = " ++ show idft16_cos_u ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_COS_2U = " ++ show idft16_cos_2u ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_COS_3U = " ++ show idft16_cos_3u ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_U = " ++ show (- idft16_j_sin_u ) ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_2U = " ++ show (- idft16_j_sin_2u) ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_3U = " ++ show (- idft16_j_sin_3u) ++ " ;"
|
||||
|
||||
putStrLn $ "const uint64_t IDFT16_COS_3U_PLUS_U = " ++ show ( idft16_cos_3u + idft16_cos_u ) ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_COS_3U_MINUS_U = " ++ show ( idft16_cos_3u - idft16_cos_u ) ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_J_SIN_3U_MINUS_U = " ++ show ( idft16_j_sin_3u - idft16_j_sin_u) ++ " ;"
|
||||
putStrLn $ "const uint64_t IDFT16_J_SIN_MINUS_3U_MINUS_U = " ++ show (- idft16_j_sin_3u - idft16_j_sin_u) ++ " ;"
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
experimental_IDFT16 :: Array Int F -> Array Int F
|
||||
experimental_IDFT16 input = output where
|
||||
|
||||
rescale :: F -> F
|
||||
rescale x = x / 16
|
||||
|
||||
[x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15] = elems input
|
||||
output = listArray (0,15) $ map rescale [y0,y1,y2,y3,y4,y5,y6,y7,y8,y9,y10,y11,y12,y13,y14,y15]
|
||||
|
||||
t1 = x0 + x8
|
||||
t2 = x4 + x12
|
||||
t3 = x2 + x10
|
||||
t4 = x2 - x10
|
||||
t5 = x6 + x14
|
||||
t6 = x6 - x14
|
||||
t7 = x1 + x9
|
||||
t8 = x1 - x9
|
||||
t9 = x3 + x11
|
||||
t10 = x3 - x11
|
||||
t11 = x5 + x13
|
||||
t12 = x5 - x13
|
||||
t13 = x7 + x15
|
||||
t14 = x7 - x15
|
||||
t15 = t1 + t2
|
||||
t16 = t3 + t5
|
||||
t17 = t15 + t16
|
||||
t18 = t7 + t11
|
||||
t19 = t7 - t11
|
||||
t20 = t9 + t13
|
||||
t21 = t9 - t13
|
||||
t22 = t18 + t20
|
||||
t23 = t8 + t14
|
||||
t24 = t8 - t14
|
||||
t25 = t10 + t12
|
||||
t26 = t12 - t10
|
||||
|
||||
m0 = t17 + t22
|
||||
m1 = t17 - t22
|
||||
m2 = t15 - t16
|
||||
m3 = t1 - t2
|
||||
m4 = x0 - x8
|
||||
|
||||
m5 = idft16_cos_2u * (t19 - t21)
|
||||
m6 = idft16_cos_2u * (t4 - t6 )
|
||||
m7 = idft16_cos_3u * (t24 + t26)
|
||||
m8 = (idft16_cos_3u + idft16_cos_u) * t24
|
||||
m9 = (idft16_cos_3u - idft16_cos_u) * t26
|
||||
m10 = idft16_j * (t20 - t18)
|
||||
m11 = idft16_j * (t5 - t3 )
|
||||
m12 = idft16_j * (x12 - x4 )
|
||||
m13 = - idft16_j_sin_2u * ( t19 + t21)
|
||||
m14 = - idft16_j_sin_2u * ( t4 + t6 )
|
||||
m15 = - idft16_j_sin_3u * ( t23 + t25)
|
||||
m16 = (idft16_j_sin_3u - idft16_j_sin_u) * t23
|
||||
m17 = - (idft16_j_sin_3u + idft16_j_sin_u) * t25
|
||||
|
||||
s7 = m8 - m7
|
||||
s8 = m9 - m7
|
||||
|
||||
s15 = m15 + m16
|
||||
s16 = m15 - m17
|
||||
|
||||
s1 = m3 + m5
|
||||
s2 = m3 - m5
|
||||
s3 = m11 + m13
|
||||
s4 = m13 - m11
|
||||
s5 = m4 + m6
|
||||
s6 = m4 - m6
|
||||
|
||||
s9 = s5 + s7
|
||||
s10 = s5 - s7
|
||||
s11 = s6 + s8
|
||||
s12 = s6 - s8
|
||||
|
||||
s13 = m12 + m14
|
||||
s14 = m12 - m14
|
||||
|
||||
s17 = s13 + s15
|
||||
s18 = s13 - s15
|
||||
s19 = s14 + s16
|
||||
s20 = s14 - s16
|
||||
|
||||
y0 = m0
|
||||
y1 = s9 + s17
|
||||
y2 = s1 + s3
|
||||
y3 = s12 - s20
|
||||
y4 = m2 + m10
|
||||
y5 = s11 + s19
|
||||
y6 = s2 + s4
|
||||
y7 = s10 - s18
|
||||
y8 = m1
|
||||
y9 = s10 + s18
|
||||
y10 = s2 - s4
|
||||
y11 = s11 - s19
|
||||
y12 = m2 - m10
|
||||
y13 = s12 + s20
|
||||
y14 = s1 - s3
|
||||
y15 = s9 - s17
|
||||
|
||||
----------------------------------------
|
||||
|
||||
-- is it always true, that in the NTT setting, the DFT and the unscaled
|
||||
-- inverse DFT are the same up to permutation??
|
||||
--
|
||||
experimental_DFT16 :: Array Int F -> Array Int F
|
||||
experimental_DFT16 input = output where
|
||||
|
||||
[x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15] = elems input
|
||||
output = listArray (0,15) [y0,y1,y2,y3,y4,y5,y6,y7,y8,y9,y10,y11,y12,y13,y14,y15]
|
||||
|
||||
t1 = x0 + x8
|
||||
t2 = x4 + x12
|
||||
t3 = x2 + x10
|
||||
t4 = x2 - x10
|
||||
t5 = x6 + x14
|
||||
t6 = x6 - x14
|
||||
t7 = x1 + x9
|
||||
t8 = x1 - x9
|
||||
t9 = x3 + x11
|
||||
t10 = x3 - x11
|
||||
t11 = x5 + x13
|
||||
t12 = x5 - x13
|
||||
t13 = x7 + x15
|
||||
t14 = x7 - x15
|
||||
t15 = t1 + t2
|
||||
t16 = t3 + t5
|
||||
t17 = t15 + t16
|
||||
t18 = t7 + t11
|
||||
t19 = t7 - t11
|
||||
t20 = t9 + t13
|
||||
t21 = t9 - t13
|
||||
t22 = t18 + t20
|
||||
t23 = t8 + t14
|
||||
t24 = t8 - t14
|
||||
t25 = t10 + t12
|
||||
t26 = t12 - t10
|
||||
|
||||
m0 = t17 + t22
|
||||
m1 = t17 - t22
|
||||
m2 = t15 - t16
|
||||
m3 = t1 - t2
|
||||
m4 = x0 - x8
|
||||
|
||||
m5 = dft16_cos_2u * (t19 - t21)
|
||||
m6 = dft16_cos_2u * (t4 - t6 )
|
||||
m7 = dft16_cos_3u * (t24 + t26)
|
||||
m8 = (dft16_cos_3u + dft16_cos_u) * t24
|
||||
m9 = (dft16_cos_3u - dft16_cos_u) * t26
|
||||
m10 = dft16_j * (t20 - t18)
|
||||
m11 = dft16_j * (t5 - t3 )
|
||||
m12 = dft16_j * (x12 - x4 )
|
||||
m13 = - dft16_j_sin_2u * ( t19 + t21)
|
||||
m14 = - dft16_j_sin_2u * ( t4 + t6 )
|
||||
m15 = - dft16_j_sin_3u * ( t23 + t25)
|
||||
m16 = (dft16_j_sin_3u - dft16_j_sin_u) * t23
|
||||
m17 = - (dft16_j_sin_3u + dft16_j_sin_u) * t25
|
||||
|
||||
s7 = m8 - m7
|
||||
s8 = m9 - m7
|
||||
|
||||
s15 = m15 + m16
|
||||
s16 = m15 - m17
|
||||
|
||||
s1 = m3 + m5
|
||||
s2 = m3 - m5
|
||||
s3 = m11 + m13
|
||||
s4 = m13 - m11
|
||||
s5 = m4 + m6
|
||||
s6 = m4 - m6
|
||||
|
||||
s9 = s5 + s7
|
||||
s10 = s5 - s7
|
||||
s11 = s6 + s8
|
||||
s12 = s6 - s8
|
||||
|
||||
s13 = m12 + m14
|
||||
s14 = m12 - m14
|
||||
|
||||
s17 = s13 + s15
|
||||
s18 = s13 - s15
|
||||
s19 = s14 + s16
|
||||
s20 = s14 - s16
|
||||
|
||||
y0 = m0
|
||||
y1 = s9 - s17
|
||||
y2 = s1 - s3
|
||||
y3 = s12 + s20
|
||||
y4 = m2 - m10
|
||||
y5 = s11 - s19
|
||||
y6 = s2 - s4
|
||||
y7 = s10 + s18
|
||||
y8 = m1
|
||||
y9 = s10 - s18
|
||||
y10 = s2 + s4
|
||||
y11 = s11 + s19
|
||||
y12 = m2 + m10
|
||||
y13 = s12 - s20
|
||||
y14 = s1 + s3
|
||||
y15 = s9 + s17
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
data Cost = MkCost
|
||||
{ _nAdds :: Int
|
||||
, _nMuls :: Int
|
||||
}
|
||||
deriving (Eq,Show)
|
||||
|
||||
addCost :: Cost -> Cost -> Cost
|
||||
addCost (MkCost a1 m1) (MkCost a2 m2) = MkCost (a1+a2) (m1+m2)
|
||||
|
||||
scaleCost :: Int -> Cost -> Cost
|
||||
scaleCost s (MkCost a m) = MkCost (s*a) (s*m)
|
||||
|
||||
doubleCost :: Cost -> Cost
|
||||
doubleCost = scaleCost 2
|
||||
|
||||
estimateNTTCost :: Log2 -> Cost
|
||||
estimateNTTCost = go where
|
||||
go :: Log2 -> Cost
|
||||
go 0 = MkCost 0 0
|
||||
go 1 = MkCost 2 0
|
||||
go m = recursive `addCost` post where
|
||||
recursive = doubleCost (go (m-1))
|
||||
post = scaleCost halfN (MkCost 3 1)
|
||||
halfN = exp2_ (m-1)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
@ -4,4 +4,5 @@ gcc -c -O2 goldilocks.c
|
||||
gcc -c -O2 goldilocks_ext.c
|
||||
gcc -c -O2 monolith.c
|
||||
gcc -c -O2 ntt.c
|
||||
gcc -c -O2 short_dft.c
|
||||
|
||||
|
||||
566
reference/src/cbits/short_dft.c
Normal file
566
reference/src/cbits/short_dft.c
Normal file
@ -0,0 +1,566 @@
|
||||
|
||||
//
|
||||
// Short DFT algoritmus (size 8 and 16)
|
||||
//
|
||||
// See:
|
||||
//
|
||||
// - Nussbaumer: "Fast Fourier Transform and Convolution Algorithms", Chapter 5.5
|
||||
//
|
||||
// Note: As they describe complex DFT and we need NTT, the conventions differ a bit.
|
||||
//
|
||||
// Hence some formulas in the comments look false, that's because we follow the
|
||||
// formulas from the book but need to change some coefficients and/or ordering...
|
||||
//
|
||||
// Note #2: If the multiplicative generator is changed, the constants need
|
||||
// to be regenerated. See the module "NTT.FTT.Short"
|
||||
//
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "goldilocks.h"
|
||||
#include "short_dft.h"
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// SIZE = 4
|
||||
|
||||
const uint64_t DFT4_J = 0x0001000000000000 ;
|
||||
|
||||
const uint64_t IDFT4_OMEGA = 0x0001000000000000 ;
|
||||
const uint64_t IDFT4_INV_OMEGA = 0xfffeffff00000001 ;
|
||||
const uint64_t IDFT4_J = 0x0001000000000000 ;
|
||||
const uint64_t IDFT4_INV_4 = 0xbfffffff40000001 ;
|
||||
|
||||
void short_inv_DFT_size_4_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
|
||||
int src_stride2 = src_stride + src_stride ;
|
||||
int src_stride3 = src_stride2 + src_stride ;
|
||||
|
||||
uint64_t x0 = src[ 0];
|
||||
uint64_t x1 = src[src_stride ];
|
||||
uint64_t x2 = src[src_stride2];
|
||||
uint64_t x3 = src[src_stride3];
|
||||
|
||||
uint64_t t1 = goldilocks_add( x0 , x2 ); // x0 + x2
|
||||
uint64_t t2 = goldilocks_add( x1 , x3 ); // x1 + x3
|
||||
uint64_t m0 = goldilocks_add( t1 , t2 ); // t1 + t2
|
||||
uint64_t m1 = goldilocks_sub( t1 , t2 ); // t1 - t2
|
||||
uint64_t m2 = goldilocks_sub( x0 , x2 ); // x0 - x2
|
||||
uint64_t m3 = goldilocks_mul( DFT4_J , goldilocks_sub( x3 , x1 ) ); // j * (x3 - x1)
|
||||
|
||||
int tgt_stride2 = tgt_stride + tgt_stride ;
|
||||
int tgt_stride3 = tgt_stride2 + tgt_stride ;
|
||||
|
||||
tgt[ 0] = m0;
|
||||
tgt[tgt_stride ] = goldilocks_add( m2 , m3 );
|
||||
tgt[tgt_stride2] = m1;
|
||||
tgt[tgt_stride3] = goldilocks_sub( m2 , m3 );
|
||||
|
||||
}
|
||||
|
||||
void short_fwd_DFT_size_4( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
short_inv_DFT_size_4_unscaled( src_stride, tgt_stride, src, tgt );
|
||||
|
||||
int tgt_stride3 = 3*tgt_stride;
|
||||
|
||||
uint64_t tmp = tgt[tgt_stride ];
|
||||
tgt[tgt_stride ] = tgt[tgt_stride3];
|
||||
tgt[tgt_stride3] = tmp;
|
||||
}
|
||||
|
||||
void short_inv_DFT_size_4_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
|
||||
short_inv_DFT_size_4_unscaled( src_stride, tgt_stride, src, tgt );
|
||||
|
||||
int tgt_stride2 = tgt_stride + tgt_stride ;
|
||||
int tgt_stride3 = tgt_stride2 + tgt_stride ;
|
||||
|
||||
tgt[ 0] = goldilocks_mul( IDFT4_INV_4 , tgt[ 0] );
|
||||
tgt[tgt_stride ] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride ] );
|
||||
tgt[tgt_stride2] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride2] );
|
||||
tgt[tgt_stride3] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride3] );
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// SIZE = 8
|
||||
|
||||
const uint64_t DFT8_OMEGA = 0xfffffffeff000001 ;
|
||||
const uint64_t DFT8_INV_OMEGA = 0x000000ffffffff00 ;
|
||||
const uint64_t DFT8_J = 0x0001000000000000 ;
|
||||
const uint64_t DFT8_COS_U = 0xffffff7f00800081 ;
|
||||
const uint64_t DFT8_MINUS_J_SIN_U = 0xffffff7eff800081 ;
|
||||
|
||||
void short_fwd_DFT_size_8( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
// u = 2pi/8
|
||||
// omega = cos(u) + i*sin(u)
|
||||
//
|
||||
// cos_u ~> - (omega+omega^7) / 2
|
||||
// j_sin_u ~> - (omega-omega^7) / 2
|
||||
// -j_sin_u ~> + (omega-omega^7) / 2
|
||||
// j ~> omega^2
|
||||
|
||||
int src_stride2 = src_stride + src_stride ;
|
||||
int src_stride3 = src_stride2 + src_stride ;
|
||||
int src_stride4 = src_stride2 + src_stride2;
|
||||
int src_stride5 = src_stride4 + src_stride ;
|
||||
int src_stride6 = src_stride4 + src_stride2;
|
||||
int src_stride7 = src_stride4 + src_stride3;
|
||||
|
||||
uint64_t x0 = src[0 ];
|
||||
uint64_t x1 = src[src_stride ];
|
||||
uint64_t x2 = src[src_stride2];
|
||||
uint64_t x3 = src[src_stride3];
|
||||
uint64_t x4 = src[src_stride4];
|
||||
uint64_t x5 = src[src_stride5];
|
||||
uint64_t x6 = src[src_stride6];
|
||||
uint64_t x7 = src[src_stride7];
|
||||
|
||||
uint64_t t1 = goldilocks_add( x0 , x4 ); // x0 + x4
|
||||
uint64_t t2 = goldilocks_add( x2 , x6 ); // x2 + x6
|
||||
uint64_t t3 = goldilocks_add( x1 , x5 ); // x1 + x5
|
||||
uint64_t t4 = goldilocks_sub( x1 , x5 ); // x1 - x5
|
||||
uint64_t t5 = goldilocks_add( x3 , x7 ); // x3 + x7
|
||||
uint64_t t6 = goldilocks_sub( x3 , x7 ); // x3 - x7
|
||||
uint64_t t7 = goldilocks_add( t1 , t2 ); // t1 + t2
|
||||
uint64_t t8 = goldilocks_add( t3 , t5 ); // t3 + t5
|
||||
|
||||
uint64_t m0 = goldilocks_add( t7 , t8 ); // (t7 + t8)
|
||||
uint64_t m1 = goldilocks_sub( t7 , t8 ); // (t7 - t8)
|
||||
uint64_t m2 = goldilocks_sub( t1 , t2 ); // (t1 - t2)
|
||||
uint64_t m3 = goldilocks_sub( x0 , x4 ); // (x0 - x4)
|
||||
|
||||
uint64_t m4 = goldilocks_mul( DFT8_COS_U , goldilocks_sub( t4 , t6 ) ); // cos_u * (t4 - t6)
|
||||
uint64_t m5 = goldilocks_mul( DFT8_J , goldilocks_sub( t5 , t3 ) ); // j*(t5 - t3)
|
||||
uint64_t m6 = goldilocks_mul( DFT8_J , goldilocks_sub( x6 , x2 ) ); // j*(x6 - x2)
|
||||
uint64_t m7 = goldilocks_mul( DFT8_MINUS_J_SIN_U , goldilocks_add( t4 , t6 ) ); // - j_sin_u * (t4 + t6)
|
||||
|
||||
uint64_t s1 = goldilocks_add( m3 , m4 ); // m3 + m4
|
||||
uint64_t s2 = goldilocks_sub( m3 , m4 ); // m3 - m4
|
||||
uint64_t s3 = goldilocks_add( m6 , m7 ); // m6 + m7
|
||||
uint64_t s4 = goldilocks_sub( m6 , m7 ); // m6 - m7
|
||||
|
||||
int tgt_stride2 = tgt_stride + tgt_stride ;
|
||||
int tgt_stride3 = tgt_stride2 + tgt_stride ;
|
||||
int tgt_stride4 = tgt_stride2 + tgt_stride2;
|
||||
int tgt_stride5 = tgt_stride4 + tgt_stride ;
|
||||
int tgt_stride6 = tgt_stride4 + tgt_stride2;
|
||||
int tgt_stride7 = tgt_stride4 + tgt_stride3;
|
||||
|
||||
tgt[ 0] = m0; // m0
|
||||
tgt[tgt_stride ] = goldilocks_sub( s2 , s4 ); // s2 - s4
|
||||
tgt[tgt_stride2] = goldilocks_sub( m2 , m5 ); // m2 - m5
|
||||
tgt[tgt_stride3] = goldilocks_add( s1 , s3 ); // s1 + s3
|
||||
tgt[tgt_stride4] = m1; // m1
|
||||
tgt[tgt_stride5] = goldilocks_sub( s1 , s3 ); // s1 - s3
|
||||
tgt[tgt_stride6] = goldilocks_add( m2 , m5 ); // m2 + m5
|
||||
tgt[tgt_stride7] = goldilocks_add( s2 , s4 ); // s2 + s4
|
||||
}
|
||||
|
||||
const uint64_t IDFT8_OMEGA = 0xfffffffeff000001 ;
|
||||
const uint64_t IDFT8_INV_OMEGA = 0x000000ffffffff00 ;
|
||||
const uint64_t IDFT8_J = 0x0001000000000000 ;
|
||||
const uint64_t IDFT8_COS_U = 0x0000007fff7fff80 ;
|
||||
const uint64_t IDFT8_MINUS_J_SIN_U = 0x00000080007fff80 ;
|
||||
const uint64_t IDFT8_INV_8 = 0xdfffffff20000001 ;
|
||||
|
||||
//--------------------------------------
|
||||
|
||||
void short_inv_DFT_size_8_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
// u = 2pi/8
|
||||
// omega = cos(u) + i*sin(u)
|
||||
//
|
||||
// cos_u ~> (omega+omega^7) / 2
|
||||
// -j_sin_u ~> - (omega-omega^7) / 2
|
||||
// j ~> omega^2
|
||||
|
||||
int src_stride2 = src_stride + src_stride ;
|
||||
int src_stride3 = src_stride2 + src_stride ;
|
||||
int src_stride4 = src_stride2 + src_stride2;
|
||||
int src_stride5 = src_stride4 + src_stride ;
|
||||
int src_stride6 = src_stride4 + src_stride2;
|
||||
int src_stride7 = src_stride4 + src_stride3;
|
||||
|
||||
uint64_t x0 = src[0 ];
|
||||
uint64_t x1 = src[src_stride ];
|
||||
uint64_t x2 = src[src_stride2];
|
||||
uint64_t x3 = src[src_stride3];
|
||||
uint64_t x4 = src[src_stride4];
|
||||
uint64_t x5 = src[src_stride5];
|
||||
uint64_t x6 = src[src_stride6];
|
||||
uint64_t x7 = src[src_stride7];
|
||||
|
||||
uint64_t t1 = goldilocks_add( x0 , x4 ); // x0 + x4
|
||||
uint64_t t2 = goldilocks_add( x2 , x6 ); // x2 + x6
|
||||
uint64_t t3 = goldilocks_add( x1 , x5 ); // x1 + x5
|
||||
uint64_t t4 = goldilocks_sub( x1 , x5 ); // x1 - x5
|
||||
uint64_t t5 = goldilocks_add( x3 , x7 ); // x3 + x7
|
||||
uint64_t t6 = goldilocks_sub( x3 , x7 ); // x3 - x7
|
||||
uint64_t t7 = goldilocks_add( t1 , t2 ); // t1 + t2
|
||||
uint64_t t8 = goldilocks_add( t3 , t5 ); // t3 + t5
|
||||
|
||||
uint64_t m0 = goldilocks_add( t7 , t8 ); // (t7 + t8)
|
||||
uint64_t m1 = goldilocks_sub( t7 , t8 ); // (t7 - t8)
|
||||
uint64_t m2 = goldilocks_sub( t1 , t2 ); // (t1 - t2)
|
||||
uint64_t m3 = goldilocks_sub( x0 , x4 ); // (x0 - x4)
|
||||
|
||||
uint64_t m4 = goldilocks_mul( IDFT8_COS_U , goldilocks_sub( t4 , t6 ) ); // cos_u * (t4 - t6)
|
||||
uint64_t m5 = goldilocks_mul( IDFT8_J , goldilocks_sub( t5 , t3 ) ); // j*(t5 - t3)
|
||||
uint64_t m6 = goldilocks_mul( IDFT8_J , goldilocks_sub( x6 , x2 ) ); // j*(x6 - x2)
|
||||
uint64_t m7 = goldilocks_mul( IDFT8_MINUS_J_SIN_U , goldilocks_add( t4 , t6 ) ); // - j_sin_u * (t4 + t6)
|
||||
|
||||
uint64_t s1 = goldilocks_add( m3 , m4 ); // m3 + m4
|
||||
uint64_t s2 = goldilocks_sub( m3 , m4 ); // m3 - m4
|
||||
uint64_t s3 = goldilocks_add( m6 , m7 ); // m6 + m7
|
||||
uint64_t s4 = goldilocks_sub( m6 , m7 ); // m6 - m7
|
||||
|
||||
int tgt_stride2 = tgt_stride + tgt_stride ;
|
||||
int tgt_stride3 = tgt_stride2 + tgt_stride ;
|
||||
int tgt_stride4 = tgt_stride2 + tgt_stride2;
|
||||
int tgt_stride5 = tgt_stride4 + tgt_stride ;
|
||||
int tgt_stride6 = tgt_stride4 + tgt_stride2;
|
||||
int tgt_stride7 = tgt_stride4 + tgt_stride3;
|
||||
|
||||
tgt[ 0] = m0 ; // m0
|
||||
tgt[tgt_stride ] = goldilocks_add( s1 , s3 ); // s1 + s3
|
||||
tgt[tgt_stride2] = goldilocks_add( m2 , m5 ); // m2 + m5
|
||||
tgt[tgt_stride3] = goldilocks_sub( s2 , s4 ); // s2 - s4
|
||||
tgt[tgt_stride4] = m1 ; // m1
|
||||
tgt[tgt_stride5] = goldilocks_add( s2 , s4 ); // s2 + s4
|
||||
tgt[tgt_stride6] = goldilocks_sub( m2 , m5 ); // m2 - m5
|
||||
tgt[tgt_stride7] = goldilocks_sub( s1 , s3 ); // s1 - s3
|
||||
}
|
||||
|
||||
//------------------
|
||||
|
||||
void short_inv_DFT_size_8_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
|
||||
short_inv_DFT_size_8_unscaled( src_stride, tgt_stride, src, tgt );
|
||||
|
||||
int tgt_stride2 = tgt_stride + tgt_stride ;
|
||||
int tgt_stride3 = tgt_stride2 + tgt_stride ;
|
||||
int tgt_stride4 = tgt_stride2 + tgt_stride2;
|
||||
int tgt_stride5 = tgt_stride4 + tgt_stride ;
|
||||
int tgt_stride6 = tgt_stride4 + tgt_stride2;
|
||||
int tgt_stride7 = tgt_stride4 + tgt_stride3;
|
||||
|
||||
tgt[ 0] = goldilocks_mul( IDFT8_INV_8 , tgt[ 0] );
|
||||
tgt[tgt_stride ] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride ] );
|
||||
tgt[tgt_stride2] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride2] );
|
||||
tgt[tgt_stride3] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride3] );
|
||||
tgt[tgt_stride4] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride4] );
|
||||
tgt[tgt_stride5] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride5] );
|
||||
tgt[tgt_stride6] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride6] );
|
||||
tgt[tgt_stride7] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride7] );
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// SIZE = 16
|
||||
|
||||
const uint64_t DFT16_OMEGA = 0xefffffff00000001 ;
|
||||
const uint64_t DFT16_INV_OMEGA = 0x0000001000000000 ;
|
||||
const uint64_t DFT16_J = 0x0001000000000000 ;
|
||||
const uint64_t DFT16_COS_U = 0xf800000700000001 ;
|
||||
const uint64_t DFT16_COS_2U = 0x0000007fff7fff80 ;
|
||||
const uint64_t DFT16_COS_3U = 0x0007fffffff7f800 ;
|
||||
const uint64_t DFT16_MINUS_J_SIN_U = 0x0800000800000000 ;
|
||||
const uint64_t DFT16_MINUS_J_SIN_2U = 0x00000080007fff80 ;
|
||||
const uint64_t DFT16_MINUS_J_SIN_3U = 0xfff7ffff0007f801 ;
|
||||
const uint64_t DFT16_COS_3U_PLUS_U = 0xf8080006fff7f801 ;
|
||||
const uint64_t DFT16_COS_3U_MINUS_U = 0x0807fff7fff7f800 ;
|
||||
const uint64_t DFT16_J_SIN_3U_MINUS_U = 0x08080007fff80800 ;
|
||||
const uint64_t DFT16_J_SIN_MINUS_3U_MINUS_U = 0x07f800080007f800 ;
|
||||
|
||||
void short_fwd_DFT_size_16( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
|
||||
int src_stride2 = src_stride + src_stride ;
|
||||
int src_stride3 = src_stride2 + src_stride ;
|
||||
int src_stride4 = src_stride2 + src_stride2;
|
||||
int src_stride5 = src_stride4 + src_stride ;
|
||||
int src_stride6 = src_stride4 + src_stride2;
|
||||
int src_stride7 = src_stride4 + src_stride3;
|
||||
int src_stride8 = src_stride4 + src_stride4;
|
||||
|
||||
uint64_t x0 = src[ 0 ];
|
||||
uint64_t x1 = src[ src_stride ];
|
||||
uint64_t x2 = src[ src_stride2 ];
|
||||
uint64_t x3 = src[ src_stride3 ];
|
||||
uint64_t x4 = src[ src_stride4 ];
|
||||
uint64_t x5 = src[ src_stride5 ];
|
||||
uint64_t x6 = src[ src_stride6 ];
|
||||
uint64_t x7 = src[ src_stride7 ];
|
||||
uint64_t x8 = src[ src_stride8 ];
|
||||
uint64_t x9 = src[ src_stride + src_stride8 ];
|
||||
uint64_t x10 = src[ src_stride2 + src_stride8 ];
|
||||
uint64_t x11 = src[ src_stride3 + src_stride8 ];
|
||||
uint64_t x12 = src[ src_stride4 + src_stride8 ];
|
||||
uint64_t x13 = src[ src_stride5 + src_stride8 ];
|
||||
uint64_t x14 = src[ src_stride6 + src_stride8 ];
|
||||
uint64_t x15 = src[ src_stride7 + src_stride8 ];
|
||||
|
||||
uint64_t t1 = goldilocks_add( x0 , x8 ); // x0 + x8
|
||||
uint64_t t2 = goldilocks_add( x4 , x12 ); // x4 + x12
|
||||
uint64_t t3 = goldilocks_add( x2 , x10 ); // x2 + x10
|
||||
uint64_t t4 = goldilocks_sub( x2 , x10 ); // x2 - x10
|
||||
uint64_t t5 = goldilocks_add( x6 , x14 ); // x6 + x14
|
||||
uint64_t t6 = goldilocks_sub( x6 , x14 ); // x6 - x14
|
||||
uint64_t t7 = goldilocks_add( x1 , x9 ); // x1 + x9
|
||||
uint64_t t8 = goldilocks_sub( x1 , x9 ); // x1 - x9
|
||||
uint64_t t9 = goldilocks_add( x3 , x11 ); // x3 + x11
|
||||
uint64_t t10 = goldilocks_sub( x3 , x11 ); // x3 - x11
|
||||
uint64_t t11 = goldilocks_add( x5 , x13 ); // x5 + x13
|
||||
uint64_t t12 = goldilocks_sub( x5 , x13 ); // x5 - x13
|
||||
uint64_t t13 = goldilocks_add( x7 , x15 ); // x7 + x15
|
||||
uint64_t t14 = goldilocks_sub( x7 , x15 ); // x7 - x15
|
||||
uint64_t t15 = goldilocks_add( t1 , t2 ); // t1 + t2
|
||||
uint64_t t16 = goldilocks_add( t3 , t5 ); // t3 + t5
|
||||
uint64_t t17 = goldilocks_add( t15 , t16 ); // t15 + t16
|
||||
uint64_t t18 = goldilocks_add( t7 , t11 ); // t7 + t11
|
||||
uint64_t t19 = goldilocks_sub( t7 , t11 ); // t7 - t11
|
||||
uint64_t t20 = goldilocks_add( t9 , t13 ); // t9 + t13
|
||||
uint64_t t21 = goldilocks_sub( t9 , t13 ); // t9 - t13
|
||||
uint64_t t22 = goldilocks_add( t18 , t20 ); // t18 + t20
|
||||
uint64_t t23 = goldilocks_add( t8 , t14 ); // t8 + t14
|
||||
uint64_t t24 = goldilocks_sub( t8 , t14 ); // t8 - t14
|
||||
uint64_t t25 = goldilocks_add( t10 , t12 ); // t10 + t12
|
||||
uint64_t t26 = goldilocks_sub( t12 , t10 ); // t12 - t10
|
||||
|
||||
uint64_t m0 = goldilocks_add( t17 , t22 ); // t17 + t22
|
||||
uint64_t m1 = goldilocks_sub( t17 , t22 ); // t17 - t22
|
||||
uint64_t m2 = goldilocks_sub( t15 , t16 ); // t15 - t16
|
||||
uint64_t m3 = goldilocks_sub( t1 , t2 ); // t1 - t2
|
||||
uint64_t m4 = goldilocks_sub( x0 , x8 ); // x0 - x8
|
||||
|
||||
uint64_t m5 = goldilocks_mul( DFT16_COS_2U , goldilocks_sub(t19 , t21) );
|
||||
uint64_t m6 = goldilocks_mul( DFT16_COS_2U , goldilocks_sub(t4 , t6 ) );
|
||||
uint64_t m7 = goldilocks_mul( DFT16_COS_3U , goldilocks_add(t24 , t26) );
|
||||
uint64_t m8 = goldilocks_mul( DFT16_COS_3U_PLUS_U , t24 );
|
||||
uint64_t m9 = goldilocks_mul( DFT16_COS_3U_MINUS_U , t26 );
|
||||
uint64_t m10 = goldilocks_mul( DFT16_J , goldilocks_sub(t20 , t18) );
|
||||
uint64_t m11 = goldilocks_mul( DFT16_J , goldilocks_sub(t5 , t3 ) );
|
||||
uint64_t m12 = goldilocks_mul( DFT16_J , goldilocks_sub(x12 , x4 ) );
|
||||
uint64_t m13 = goldilocks_mul( DFT16_MINUS_J_SIN_2U , goldilocks_add( t19 , t21) );
|
||||
uint64_t m14 = goldilocks_mul( DFT16_MINUS_J_SIN_2U , goldilocks_add( t4 , t6 ) );
|
||||
uint64_t m15 = goldilocks_mul( DFT16_MINUS_J_SIN_3U , goldilocks_add( t23 , t25) );
|
||||
uint64_t m16 = goldilocks_mul( DFT16_J_SIN_3U_MINUS_U , t23 );
|
||||
uint64_t m17 = goldilocks_mul( DFT16_J_SIN_MINUS_3U_MINUS_U , t25 );
|
||||
|
||||
uint64_t s1 = goldilocks_add( m3 , m5 ); // m3 + m5
|
||||
uint64_t s2 = goldilocks_sub( m3 , m5 ); // m3 - m5
|
||||
uint64_t s3 = goldilocks_add( m11 , m13 ); // m11 + m13
|
||||
uint64_t s4 = goldilocks_sub( m13 , m11 ); // m13 - m11
|
||||
uint64_t s5 = goldilocks_add( m4 , m6 ); // m4 + m6
|
||||
uint64_t s6 = goldilocks_sub( m4 , m6 ); // m4 - m6
|
||||
uint64_t s7 = goldilocks_sub( m8 , m7 ); // m8 - m7
|
||||
uint64_t s8 = goldilocks_sub( m9 , m7 ); // m9 - m7
|
||||
uint64_t s9 = goldilocks_add( s5 , s7 ); // s5 + s7
|
||||
uint64_t s10 = goldilocks_sub( s5 , s7 ); // s5 - s7
|
||||
uint64_t s11 = goldilocks_add( s6 , s8 ); // s6 + s8
|
||||
uint64_t s12 = goldilocks_sub( s6 , s8 ); // s6 - s8
|
||||
uint64_t s13 = goldilocks_add( m12 , m14 ); // m12 + m14
|
||||
uint64_t s14 = goldilocks_sub( m12 , m14 ); // m12 - m14
|
||||
uint64_t s15 = goldilocks_add( m15 , m16 ); // m15 + m16
|
||||
uint64_t s16 = goldilocks_sub( m15 , m17 ); // m15 - m17
|
||||
uint64_t s17 = goldilocks_add( s13 , s15 ); // s13 + s15
|
||||
uint64_t s18 = goldilocks_sub( s13 , s15 ); // s13 - s15
|
||||
uint64_t s19 = goldilocks_add( s14 , s16 ); // s14 + s16
|
||||
uint64_t s20 = goldilocks_sub( s14 , s16 ); // s14 - s16
|
||||
|
||||
int tgt_stride2 = tgt_stride + tgt_stride ;
|
||||
int tgt_stride3 = tgt_stride2 + tgt_stride ;
|
||||
int tgt_stride4 = tgt_stride2 + tgt_stride2;
|
||||
int tgt_stride5 = tgt_stride4 + tgt_stride ;
|
||||
int tgt_stride6 = tgt_stride4 + tgt_stride2;
|
||||
int tgt_stride7 = tgt_stride4 + tgt_stride3;
|
||||
int tgt_stride8 = tgt_stride4 + tgt_stride4;
|
||||
|
||||
tgt[ 0 ] = m0; // m0
|
||||
tgt[ tgt_stride ] = goldilocks_sub( s9 , s17 ); // s9 - s17
|
||||
tgt[ tgt_stride2 ] = goldilocks_sub( s1 , s3 ); // s1 - s3
|
||||
tgt[ tgt_stride3 ] = goldilocks_add( s12 , s20 ); // s12 + s20
|
||||
tgt[ tgt_stride4 ] = goldilocks_sub( m2 , m10 ); // m2 - m10
|
||||
tgt[ tgt_stride5 ] = goldilocks_sub( s11 , s19 ); // s11 - s19
|
||||
tgt[ tgt_stride6 ] = goldilocks_sub( s2 , s4 ); // s2 - s4
|
||||
tgt[ tgt_stride7 ] = goldilocks_add( s10 , s18 ); // s10 + s18
|
||||
tgt[ tgt_stride8 ] = m1; // m1
|
||||
tgt[ tgt_stride + tgt_stride8 ] = goldilocks_sub( s10 , s18 ); // s10 - s18
|
||||
tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_add( s2 , s4 ); // s2 + s4
|
||||
tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_add( s11 , s19 ); // s11 + s19
|
||||
tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_add( m2 , m10 ); // m2 + m10
|
||||
tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_sub( s12 , s20 ); // s12 - s20
|
||||
tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_add( s1 , s3 ); // s1 + s3
|
||||
tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_add( s9 , s17 ); // s9 + s17
|
||||
}
|
||||
|
||||
//--------------------------------------
|
||||
|
||||
const uint64_t IDFT16_OMEGA = 0xefffffff00000001 ;
|
||||
const uint64_t IDFT16_INV_OMEGA = 0x0000001000000000 ;
|
||||
const uint64_t IDFT16_INV_16 = 0xefffffff10000001 ;
|
||||
const uint64_t IDFT16_J = 0x0001000000000000 ;
|
||||
const uint64_t IDFT16_COS_U = 0xf800000700000001 ;
|
||||
const uint64_t IDFT16_COS_2U = 0x0000007fff7fff80 ;
|
||||
const uint64_t IDFT16_COS_3U = 0x0007fffffff7f800 ;
|
||||
const uint64_t IDFT16_MINUS_J_SIN_U = 0x0800000800000000 ;
|
||||
const uint64_t IDFT16_MINUS_J_SIN_2U = 0x00000080007fff80 ;
|
||||
const uint64_t IDFT16_MINUS_J_SIN_3U = 0xfff7ffff0007f801 ;
|
||||
|
||||
const uint64_t IDFT16_COS_3U_PLUS_U = 0xf8080006fff7f801 ;
|
||||
const uint64_t IDFT16_COS_3U_MINUS_U = 0x0807fff7fff7f800 ;
|
||||
const uint64_t IDFT16_J_SIN_3U_MINUS_U = 0x08080007fff80800 ;
|
||||
const uint64_t IDFT16_J_SIN_MINUS_3U_MINUS_U = 0x07f800080007f800 ;
|
||||
|
||||
void short_inv_DFT_size_16_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
|
||||
int src_stride2 = src_stride + src_stride ;
|
||||
int src_stride3 = src_stride2 + src_stride ;
|
||||
int src_stride4 = src_stride2 + src_stride2;
|
||||
int src_stride5 = src_stride4 + src_stride ;
|
||||
int src_stride6 = src_stride4 + src_stride2;
|
||||
int src_stride7 = src_stride4 + src_stride3;
|
||||
int src_stride8 = src_stride4 + src_stride4;
|
||||
|
||||
uint64_t x0 = src[ 0 ];
|
||||
uint64_t x1 = src[ src_stride ];
|
||||
uint64_t x2 = src[ src_stride2 ];
|
||||
uint64_t x3 = src[ src_stride3 ];
|
||||
uint64_t x4 = src[ src_stride4 ];
|
||||
uint64_t x5 = src[ src_stride5 ];
|
||||
uint64_t x6 = src[ src_stride6 ];
|
||||
uint64_t x7 = src[ src_stride7 ];
|
||||
uint64_t x8 = src[ src_stride8 ];
|
||||
uint64_t x9 = src[ src_stride + src_stride8 ];
|
||||
uint64_t x10 = src[ src_stride2 + src_stride8 ];
|
||||
uint64_t x11 = src[ src_stride3 + src_stride8 ];
|
||||
uint64_t x12 = src[ src_stride4 + src_stride8 ];
|
||||
uint64_t x13 = src[ src_stride5 + src_stride8 ];
|
||||
uint64_t x14 = src[ src_stride6 + src_stride8 ];
|
||||
uint64_t x15 = src[ src_stride7 + src_stride8 ];
|
||||
|
||||
uint64_t t1 = goldilocks_add( x0 , x8 ); // x0 + x8
|
||||
uint64_t t2 = goldilocks_add( x4 , x12 ); // x4 + x12
|
||||
uint64_t t3 = goldilocks_add( x2 , x10 ); // x2 + x10
|
||||
uint64_t t4 = goldilocks_sub( x2 , x10 ); // x2 - x10
|
||||
uint64_t t5 = goldilocks_add( x6 , x14 ); // x6 + x14
|
||||
uint64_t t6 = goldilocks_sub( x6 , x14 ); // x6 - x14
|
||||
uint64_t t7 = goldilocks_add( x1 , x9 ); // x1 + x9
|
||||
uint64_t t8 = goldilocks_sub( x1 , x9 ); // x1 - x9
|
||||
uint64_t t9 = goldilocks_add( x3 , x11 ); // x3 + x11
|
||||
uint64_t t10 = goldilocks_sub( x3 , x11 ); // x3 - x11
|
||||
uint64_t t11 = goldilocks_add( x5 , x13 ); // x5 + x13
|
||||
uint64_t t12 = goldilocks_sub( x5 , x13 ); // x5 - x13
|
||||
uint64_t t13 = goldilocks_add( x7 , x15 ); // x7 + x15
|
||||
uint64_t t14 = goldilocks_sub( x7 , x15 ); // x7 - x15
|
||||
uint64_t t15 = goldilocks_add( t1 , t2 ); // t1 + t2
|
||||
uint64_t t16 = goldilocks_add( t3 , t5 ); // t3 + t5
|
||||
uint64_t t17 = goldilocks_add( t15 , t16 ); // t15 + t16
|
||||
uint64_t t18 = goldilocks_add( t7 , t11 ); // t7 + t11
|
||||
uint64_t t19 = goldilocks_sub( t7 , t11 ); // t7 - t11
|
||||
uint64_t t20 = goldilocks_add( t9 , t13 ); // t9 + t13
|
||||
uint64_t t21 = goldilocks_sub( t9 , t13 ); // t9 - t13
|
||||
uint64_t t22 = goldilocks_add( t18 , t20 ); // t18 + t20
|
||||
uint64_t t23 = goldilocks_add( t8 , t14 ); // t8 + t14
|
||||
uint64_t t24 = goldilocks_sub( t8 , t14 ); // t8 - t14
|
||||
uint64_t t25 = goldilocks_add( t10 , t12 ); // t10 + t12
|
||||
uint64_t t26 = goldilocks_sub( t12 , t10 ); // t12 - t10
|
||||
|
||||
uint64_t m0 = goldilocks_add( t17 , t22 ); // t17 + t22
|
||||
uint64_t m1 = goldilocks_sub( t17 , t22 ); // t17 - t22
|
||||
uint64_t m2 = goldilocks_sub( t15 , t16 ); // t15 - t16
|
||||
uint64_t m3 = goldilocks_sub( t1 , t2 ); // t1 - t2
|
||||
uint64_t m4 = goldilocks_sub( x0 , x8 ); // x0 - x8
|
||||
|
||||
uint64_t m5 = goldilocks_mul( IDFT16_COS_2U , goldilocks_sub(t19 , t21) );
|
||||
uint64_t m6 = goldilocks_mul( IDFT16_COS_2U , goldilocks_sub(t4 , t6 ) );
|
||||
uint64_t m7 = goldilocks_mul( IDFT16_COS_3U , goldilocks_add(t24 , t26) );
|
||||
uint64_t m8 = goldilocks_mul( IDFT16_COS_3U_PLUS_U , t24 );
|
||||
uint64_t m9 = goldilocks_mul( IDFT16_COS_3U_MINUS_U , t26 );
|
||||
uint64_t m10 = goldilocks_mul( IDFT16_J , goldilocks_sub(t20 , t18) );
|
||||
uint64_t m11 = goldilocks_mul( IDFT16_J , goldilocks_sub(t5 , t3 ) );
|
||||
uint64_t m12 = goldilocks_mul( IDFT16_J , goldilocks_sub(x12 , x4 ) );
|
||||
uint64_t m13 = goldilocks_mul( IDFT16_MINUS_J_SIN_2U , goldilocks_add( t19 , t21) );
|
||||
uint64_t m14 = goldilocks_mul( IDFT16_MINUS_J_SIN_2U , goldilocks_add( t4 , t6 ) );
|
||||
uint64_t m15 = goldilocks_mul( IDFT16_MINUS_J_SIN_3U , goldilocks_add( t23 , t25) );
|
||||
uint64_t m16 = goldilocks_mul( IDFT16_J_SIN_3U_MINUS_U , t23 );
|
||||
uint64_t m17 = goldilocks_mul( IDFT16_J_SIN_MINUS_3U_MINUS_U , t25 );
|
||||
|
||||
uint64_t s1 = goldilocks_add( m3 , m5 ); // m3 + m5
|
||||
uint64_t s2 = goldilocks_sub( m3 , m5 ); // m3 - m5
|
||||
uint64_t s3 = goldilocks_add( m11 , m13 ); // m11 + m13
|
||||
uint64_t s4 = goldilocks_sub( m13 , m11 ); // m13 - m11
|
||||
uint64_t s5 = goldilocks_add( m4 , m6 ); // m4 + m6
|
||||
uint64_t s6 = goldilocks_sub( m4 , m6 ); // m4 - m6
|
||||
uint64_t s7 = goldilocks_sub( m8 , m7 ); // m8 - m7
|
||||
uint64_t s8 = goldilocks_sub( m9 , m7 ); // m9 - m7
|
||||
uint64_t s9 = goldilocks_add( s5 , s7 ); // s5 + s7
|
||||
uint64_t s10 = goldilocks_sub( s5 , s7 ); // s5 - s7
|
||||
uint64_t s11 = goldilocks_add( s6 , s8 ); // s6 + s8
|
||||
uint64_t s12 = goldilocks_sub( s6 , s8 ); // s6 - s8
|
||||
uint64_t s13 = goldilocks_add( m12 , m14 ); // m12 + m14
|
||||
uint64_t s14 = goldilocks_sub( m12 , m14 ); // m12 - m14
|
||||
uint64_t s15 = goldilocks_add( m15 , m16 ); // m15 + m16
|
||||
uint64_t s16 = goldilocks_sub( m15 , m17 ); // m15 - m17
|
||||
uint64_t s17 = goldilocks_add( s13 , s15 ); // s13 + s15
|
||||
uint64_t s18 = goldilocks_sub( s13 , s15 ); // s13 - s15
|
||||
uint64_t s19 = goldilocks_add( s14 , s16 ); // s14 + s16
|
||||
uint64_t s20 = goldilocks_sub( s14 , s16 ); // s14 - s16
|
||||
|
||||
int tgt_stride2 = tgt_stride + tgt_stride ;
|
||||
int tgt_stride3 = tgt_stride2 + tgt_stride ;
|
||||
int tgt_stride4 = tgt_stride2 + tgt_stride2;
|
||||
int tgt_stride5 = tgt_stride4 + tgt_stride ;
|
||||
int tgt_stride6 = tgt_stride4 + tgt_stride2;
|
||||
int tgt_stride7 = tgt_stride4 + tgt_stride3;
|
||||
int tgt_stride8 = tgt_stride4 + tgt_stride4;
|
||||
|
||||
tgt[ 0 ] = m0; // m0
|
||||
tgt[ tgt_stride ] = goldilocks_add( s9 , s17 ); // s9 + s17
|
||||
tgt[ tgt_stride2 ] = goldilocks_add( s1 , s3 ); // s1 + s3
|
||||
tgt[ tgt_stride3 ] = goldilocks_sub( s12 , s20 ); // s12 - s20
|
||||
tgt[ tgt_stride4 ] = goldilocks_add( m2 , m10 ); // m2 + m10
|
||||
tgt[ tgt_stride5 ] = goldilocks_add( s11 , s19 ); // s11 + s19
|
||||
tgt[ tgt_stride6 ] = goldilocks_add( s2 , s4 ); // s2 + s4
|
||||
tgt[ tgt_stride7 ] = goldilocks_sub( s10 , s18 ); // s10 - s18
|
||||
tgt[ tgt_stride8 ] = m1; // m1
|
||||
tgt[ tgt_stride + tgt_stride8 ] = goldilocks_add( s10 , s18 ); // s10 + s18
|
||||
tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_sub( s2 , s4 ); // s2 - s4
|
||||
tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_sub( s11 , s19 ); // s11 - s19
|
||||
tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_sub( m2 , m10 ); // m2 - m10
|
||||
tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_add( s12 , s20 ); // s12 + s20
|
||||
tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_sub( s1 , s3 ); // s1 - s3
|
||||
tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_sub( s9 , s17 ); // s9 - s17
|
||||
|
||||
}
|
||||
|
||||
//------------------
|
||||
|
||||
void short_inv_DFT_size_16_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
|
||||
|
||||
short_inv_DFT_size_16_unscaled( src_stride, tgt_stride, src, tgt );
|
||||
|
||||
int tgt_stride2 = tgt_stride + tgt_stride ;
|
||||
int tgt_stride3 = tgt_stride2 + tgt_stride ;
|
||||
int tgt_stride4 = tgt_stride2 + tgt_stride2;
|
||||
int tgt_stride5 = tgt_stride4 + tgt_stride ;
|
||||
int tgt_stride6 = tgt_stride4 + tgt_stride2;
|
||||
int tgt_stride7 = tgt_stride4 + tgt_stride3;
|
||||
int tgt_stride8 = tgt_stride4 + tgt_stride4;
|
||||
|
||||
tgt[ 0 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ 0 ] );
|
||||
tgt[ tgt_stride ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride ] );
|
||||
tgt[ tgt_stride2 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride2 ] );
|
||||
tgt[ tgt_stride3 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride3 ] );
|
||||
tgt[ tgt_stride4 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride4 ] );
|
||||
tgt[ tgt_stride5 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride5 ] );
|
||||
tgt[ tgt_stride6 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride6 ] );
|
||||
tgt[ tgt_stride7 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride7 ] );
|
||||
tgt[ tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride8 ] );
|
||||
tgt[ tgt_stride + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride + tgt_stride8 ] );
|
||||
tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride2 + tgt_stride8 ] );
|
||||
tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride3 + tgt_stride8 ] );
|
||||
tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride4 + tgt_stride8 ] );
|
||||
tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride5 + tgt_stride8 ] );
|
||||
tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride6 + tgt_stride8 ] );
|
||||
tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride7 + tgt_stride8 ] );
|
||||
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
24
reference/src/cbits/short_dft.h
Normal file
24
reference/src/cbits/short_dft.h
Normal file
@ -0,0 +1,24 @@
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
void short_fwd_DFT_size_4 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
void short_inv_DFT_size_4_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
void short_inv_DFT_size_4_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
|
||||
void short_fwd_DFT_size_8 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
void short_inv_DFT_size_8_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
void short_inv_DFT_size_8_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
|
||||
void short_fwd_DFT_size_16 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
void short_inv_DFT_size_16_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
void short_inv_DFT_size_16_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// void short_fwd_DFT_size_4_ext ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
// void short_inv_DFT_size_4_ext_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
// void short_inv_DFT_size_4_ext_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
@ -1,3 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
ghci testMain.hs cbits/goldilocks.o cbits/goldilocks_ext.o cbits/monolith.o cbits/ntt.o
|
||||
ghci testMain.hs cbits/goldilocks.o cbits/goldilocks_ext.o cbits/monolith.o cbits/ntt.o cbits/short_dft.o
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user