specialized algorithms for short NTT/INTT (size 4, 8, 16)

This commit is contained in:
Balazs Komuves 2025-11-04 22:40:33 +01:00
parent bd888d5b57
commit 0eb39eb5c9
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
5 changed files with 1182 additions and 1 deletions

View File

@ -0,0 +1,590 @@
-- Short DFT algorithms (sizes 2, 4, 8 and 16)
--
-- See :
--
-- * Nussbaumer: "Fast Fourier Transform and Convolution Algorithms", Chapter 5.5
--
{-# LANGUAGE ForeignFunctionInterface #-}
module NTT.FFT.Short where
--------------------------------------------------------------------------------
import Data.Array
import Data.Word
import Data.List ( sort )
import Control.Monad
import Foreign.C
import Foreign.Ptr
import Foreign.Marshal
import System.IO.Unsafe
import Class.Field
import Field.Goldilocks ( F , fromF )
import NTT.FFT.Slow
import NTT.Poly.Naive
import NTT.Subgroup
import Misc
--------------------------------------------------------------------------------
-- DFT mathematical definition - O(n^2)
defForwardDFT :: Subgroup F -> Array Int F -> Array Int F
defForwardDFT sg input
| subgroupOrder sg /= n = error "defForwardDFT: input size does not match the subgroup"
| otherwise = listToArray [ f k | k <-[0..n-1] ]
where
n = arrayLength input
omega = subgroupGen sg
f k = sum [ input!i * power_ omega (k*i) | i<-[0..n-1] ]
defInverseDFT :: Subgroup F -> Array Int F -> Array Int F
defInverseDFT sg input
| subgroupOrder sg /= n = error "defInverseDFT: input size does not match the subgroup"
| otherwise = listToArray [ f k / (fromIntegral n) | k <-[0..n-1] ]
where
n = arrayLength input
omega = subgroupGen sg
invomega = inverse omega
f k = sum [ input!i * power_ invomega (k*i) | i<-[0..n-1] ]
--------------------------------------------------------------------------------
foreign import ccall unsafe "short_fwd_DFT_size_4" c_dft4_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "short_inv_DFT_size_4_unscaled" c_dft4_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "short_inv_DFT_size_4_rescaled" c_dft4_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "short_fwd_DFT_size_8" c_dft8_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "short_inv_DFT_size_8_unscaled" c_dft8_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "short_inv_DFT_size_8_rescaled" c_dft8_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "short_fwd_DFT_size_16" c_dft16_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "short_inv_DFT_size_16_unscaled" c_dft16_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "short_inv_DFT_size_16_rescaled" c_dft16_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO ()
----------------------------------------
{-# NOINLINE shortForwardDFT4 #-}
shortForwardDFT4 :: Array Int F -> Array Int F
shortForwardDFT4 input
| arrayLength input /= 4 = error "shortForwardDFT4: expecting an input array of size 4"
| otherwise = unsafePerformIO $ do
allocaArray 4 $ \(ptr1 :: Ptr F) -> do
allocaArray 4 $ \(ptr2 :: Ptr F) -> do
pokeArray ptr1 (elems input)
c_dft4_fwd 1 1 (castPtr ptr1) (castPtr ptr2)
ys <- peekArray 4 ptr2
return $ listToArray ys
{-# NOINLINE shortInverseDFT4 #-}
shortInverseDFT4 :: Array Int F -> Array Int F
shortInverseDFT4 input
| arrayLength input /= 4 = error "shortInverseDFT4: expecting an input array of size 4"
| otherwise = unsafePerformIO $ do
allocaArray 4 $ \(ptr1 :: Ptr F) -> do
allocaArray 4 $ \(ptr2 :: Ptr F) -> do
pokeArray ptr1 (elems input)
c_dft4_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2)
ys <- peekArray 4 ptr2
return $ listToArray ys
----------------------------------------
{-# NOINLINE shortForwardDFT8 #-}
shortForwardDFT8 :: Array Int F -> Array Int F
shortForwardDFT8 input
| arrayLength input /= 8 = error "shortForwardDFT8: expecting an input array of size 8"
| otherwise = unsafePerformIO $ do
allocaArray 8 $ \(ptr1 :: Ptr F) -> do
allocaArray 8 $ \(ptr2 :: Ptr F) -> do
pokeArray ptr1 (elems input)
c_dft8_fwd 1 1 (castPtr ptr1) (castPtr ptr2)
ys <- peekArray 8 ptr2
return $ listToArray ys
{-# NOINLINE shortInverseDFT8 #-}
shortInverseDFT8 :: Array Int F -> Array Int F
shortInverseDFT8 input
| arrayLength input /= 8 = error "shortInverseDFT8: expecting an input array of size 8"
| otherwise = unsafePerformIO $ do
allocaArray 8 $ \(ptr1 :: Ptr F) -> do
allocaArray 8 $ \(ptr2 :: Ptr F) -> do
pokeArray ptr1 (elems input)
c_dft8_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2)
ys <- peekArray 8 ptr2
return $ listToArray ys
----------------------------------------
{-# NOINLINE shortForwardDFT16 #-}
shortForwardDFT16 :: Array Int F -> Array Int F
shortForwardDFT16 input
| arrayLength input /= 16 = error "shortForwardDFT16: expecting an input array of size 16"
| otherwise = unsafePerformIO $ do
allocaArray 16 $ \(ptr1 :: Ptr F) -> do
allocaArray 16 $ \(ptr2 :: Ptr F) -> do
pokeArray ptr1 (elems input)
c_dft16_fwd 1 1 (castPtr ptr1) (castPtr ptr2)
ys <- peekArray 16 ptr2
return $ listToArray ys
{-# NOINLINE shortInverseDFT16 #-}
shortInverseDFT16 :: Array Int F -> Array Int F
shortInverseDFT16 input
| arrayLength input /= 16 = error "shortInverseDFT16: expecting an input array of size 16"
| otherwise = unsafePerformIO $ do
allocaArray 16 $ \(ptr1 :: Ptr F) -> do
allocaArray 16 $ \(ptr2 :: Ptr F) -> do
pokeArray ptr1 (elems input)
c_dft16_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2)
ys <- peekArray 16 ptr2
return $ listToArray ys
--------------------------------------------------------------------------------
testDFT4 :: IO ()
testDFT4 = do
let sg = getSubgroup (Log2 2)
xs <- listToArray <$> (replicateM 4 rndIO :: IO [F])
let ys0 = defForwardDFT sg xs
let ys1 = forwardNTT sg (Poly xs)
let ys2 = shortForwardDFT4 xs
print $ "xs = " ++ show (elems xs )
print $ "ys0 (defin) = " ++ show (elems ys0)
print $ "ys1 (ref) = " ++ show (elems ys1)
print $ "ys2 (short) = " ++ show (elems ys2)
print $ "ok: " ++ show (ys1 == ys2)
testIDFT4 :: IO ()
testIDFT4 = do
let sg = getSubgroup (Log2 2)
xs <- listToArray <$> (replicateM 4 rndIO :: IO [F])
let ys0 = defInverseDFT sg xs
let ys1 = fromPoly (inverseNTT sg xs)
let ys2 = shortInverseDFT4 xs
print $ "xs = " ++ show (elems xs )
print $ "ys0 (defin) = " ++ show (elems ys0)
print $ "ys1 (ref) = " ++ show (elems ys1)
print $ "ys2 (short) = " ++ show (elems ys2)
print $ "ok: " ++ show (ys1 == ys2)
----------------------------------------
testDFT8 :: IO ()
testDFT8 = do
let sg = getSubgroup (Log2 3)
xs <- listToArray <$> (replicateM 8 rndIO :: IO [F])
let ys0 = defForwardDFT sg xs
let ys1 = forwardNTT sg (Poly xs)
let ys2 = shortForwardDFT8 xs
print $ "xs = " ++ show (elems xs )
print $ "ys0 (defin) = " ++ show (elems ys0)
print $ "ys1 (ref) = " ++ show (elems ys1)
print $ "ys2 (short) = " ++ show (elems ys2)
print $ "ok: " ++ show (ys1 == ys2)
testIDFT8 :: IO ()
testIDFT8 = do
let sg = getSubgroup (Log2 3)
xs <- listToArray <$> (replicateM 8 rndIO :: IO [F])
let ys0 = defInverseDFT sg xs
let ys1 = fromPoly (inverseNTT sg xs)
let ys2 = shortInverseDFT8 xs
print $ "xs = " ++ show (elems xs )
print $ "ys0 (defin) = " ++ show (elems ys0)
print $ "ys1 (ref) = " ++ show (elems ys1)
print $ "ys2 (short) = " ++ show (elems ys2)
print $ "ok: " ++ show (ys1 == ys2)
--------------------------------------------------------------------------------
-- tmp for debugging
instance Ord F where
compare a b = compare (fromF a) (fromF b)
testDFT16 :: IO ()
testDFT16 = do
let sg = getSubgroup (Log2 4)
xs <- listToArray <$> (replicateM 16 rndIO :: IO [F])
let ys0 = defForwardDFT sg xs
let ys1 = forwardNTT sg (Poly xs)
let ys2 = experimental_DFT16 xs
let ys3 = shortForwardDFT16 xs
print $ "xs = " ++ show (elems xs )
print $ "ys0 (defin) = " ++ show (elems ys0)
print $ "ys1 (ref) = " ++ show (elems ys1)
print $ "ys2 (hsexp) = " ++ show (elems ys2)
print $ "ys3 (short) = " ++ show (elems ys3)
print $ "ok-def: " ++ show (ys0 == ys1)
print $ "ok-hs: " ++ show (ys1 == ys2)
print $ "ok-hs (modulo order): " ++ show (sort (elems ys1) == sort (elems ys2))
print $ "ok-C: " ++ show (ys1 == ys3)
testIDFT16 :: IO ()
testIDFT16 = do
let sg = getSubgroup (Log2 4)
xs <- listToArray <$> (replicateM 16 rndIO :: IO [F])
let ys0 = defInverseDFT sg xs
let ys1 = fromPoly (inverseNTT sg xs)
let ys2 = experimental_IDFT16 xs
let ys3 = shortInverseDFT16 xs
print $ "xs = " ++ show (elems xs )
print $ "ys0 (defin) = " ++ show (elems ys0)
print $ "ys1 (ref) = " ++ show (elems ys1)
print $ "ys2 (hsexp) = " ++ show (elems ys2)
print $ "ys3 (short) = " ++ show (elems ys3)
print $ "ok-def: " ++ show (ys0 == ys1)
print $ "ok-hs: " ++ show (ys1 == ys2)
print $ "ok-C: " ++ show (ys1 == ys3)
--------------------------------------------------------------------------------
dft2_omega, dft4_omega, dft8_omega, dft16_omega :: F
dft2_omega = subgroupGen $ getSubgroup (Log2 1)
dft4_omega = subgroupGen $ getSubgroup (Log2 2)
dft8_omega = subgroupGen $ getSubgroup (Log2 3)
dft16_omega = subgroupGen $ getSubgroup (Log2 4)
dft2_inv_omega = inverse dft2_omega
dft4_inv_omega = inverse dft4_omega
dft8_inv_omega = inverse dft8_omega
dft16_inv_omega = inverse dft16_omega
--------------------------------------------------------------------------------
-- * Size = 4
dft4_j = dft4_omega
idft4_inv_4 = 1 / (4 :: F)
printIDFT4 = do
putStrLn $ "const uint64_t IDFT4_OMEGA = " ++ show dft4_omega ++ " ;"
putStrLn $ "const uint64_t IDFT4_INV_OMEGA = " ++ show dft4_inv_omega ++ " ;"
putStrLn $ "const uint64_t IDFT4_J = " ++ show dft4_j ++ " ;"
putStrLn $ "const uint64_t IDFT4_INV_4 = " ++ show idft4_inv_4 ++ " ;"
--------------------------------------------------------------------------------
-- * Size = 8
dft8_j = square dft8_omega
dft8_cos_u = - (dft8_omega + dft8_omega^7) / 2
dft8_j_sin_u = - (dft8_omega - dft8_omega^7) / 2
dft8_minus_j_sin_u = - dft8_j_sin_u
printDFT8 = do
putStrLn $ "const uint64_t DFT8_OMEGA = " ++ show dft8_omega ++ " ;"
putStrLn $ "const uint64_t DFT8_INV_OMEGA = " ++ show dft8_inv_omega ++ " ;"
putStrLn $ "const uint64_t DFT8_J = " ++ show dft8_j ++ " ;"
putStrLn $ "const uint64_t DFT8_COS_U = " ++ show dft8_cos_u ++ " ;"
putStrLn $ "const uint64_t DFT8_MINUS_J_SIN_U = " ++ show dft8_minus_j_sin_u ++ " ;"
idft8_j = square dft8_omega
idft8_cos_u = (dft8_omega + dft8_omega^7) / 2
idft8_j_sin_u = (dft8_omega - dft8_omega^7) / 2
idft8_minus_j_sin_u = - idft8_j_sin_u
idft8_inv_8 = 1 / (8 :: F)
printIDFT8 = do
putStrLn $ "const uint64_t IDFT8_OMEGA = " ++ show dft8_omega ++ " ;"
putStrLn $ "const uint64_t IDFT8_INV_OMEGA = " ++ show dft8_inv_omega ++ " ;"
putStrLn $ "const uint64_t IDFT8_J = " ++ show idft8_j ++ " ;"
putStrLn $ "const uint64_t IDFT8_COS_U = " ++ show idft8_cos_u ++ " ;"
putStrLn $ "const uint64_t IDFT8_MINUS_J_SIN_U = " ++ show idft8_minus_j_sin_u ++ " ;"
putStrLn $ "const uint64_t IDFT8_INV_8 = " ++ show idft8_inv_8 ++ " ;"
--------------------------------------------------------------------------------
-- * Size = 16
dft16_j = square (square dft16_omega)
dft16_cos_u = (dft16_omega + dft16_omega^15) / 2
dft16_cos_2u = (dft16_omega^2 + dft16_omega^14) / 2
dft16_cos_3u = (dft16_omega^3 + dft16_omega^13) / 2
dft16_j_sin_u = (dft16_omega - dft16_omega^15) / 2 -- ????
dft16_j_sin_2u = (dft16_omega^2 - dft16_omega^14) / 2 -- but it seems to work...
dft16_j_sin_3u = (dft16_omega^3 - dft16_omega^13) / 2
printDFT16 :: IO ()
printDFT16 = do
putStrLn $ "const uint64_t DFT16_OMEGA = " ++ show dft16_omega ++ " ;"
putStrLn $ "const uint64_t DFT16_INV_OMEGA = " ++ show dft16_inv_omega ++ " ;"
putStrLn $ "const uint64_t DFT16_J = " ++ show dft16_j ++ " ;"
putStrLn $ "const uint64_t DFT16_COS_U = " ++ show dft16_cos_u ++ " ;"
putStrLn $ "const uint64_t DFT16_COS_2U = " ++ show dft16_cos_2u ++ " ;"
putStrLn $ "const uint64_t DFT16_COS_3U = " ++ show dft16_cos_3u ++ " ;"
putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_U = " ++ show (- dft16_j_sin_u ) ++ " ;"
putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_2U = " ++ show (- dft16_j_sin_2u) ++ " ;"
putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_3U = " ++ show (- dft16_j_sin_3u) ++ " ;"
putStrLn $ "const uint64_t DFT16_COS_3U_PLUS_U = " ++ show ( dft16_cos_3u + dft16_cos_u ) ++ " ;"
putStrLn $ "const uint64_t DFT16_COS_3U_MINUS_U = " ++ show ( dft16_cos_3u - dft16_cos_u ) ++ " ;"
putStrLn $ "const uint64_t DFT16_J_SIN_3U_MINUS_U = " ++ show ( dft16_j_sin_3u - dft16_j_sin_u) ++ " ;"
putStrLn $ "const uint64_t DFT16_J_SIN_MINUS_3U_MINUS_U = " ++ show (- dft16_j_sin_3u - dft16_j_sin_u) ++ " ;"
--------------------
idft16_inv_16 = 1 / (16 :: F)
idft16_j = square (square dft16_omega)
idft16_cos_u = (dft16_omega + dft16_omega^15) / 2
idft16_cos_2u = (dft16_omega^2 + dft16_omega^14) / 2
idft16_cos_3u = (dft16_omega^3 + dft16_omega^13) / 2
idft16_j_sin_u = (dft16_omega - dft16_omega^15) / 2
idft16_j_sin_2u = (dft16_omega^2 - dft16_omega^14) / 2
idft16_j_sin_3u = (dft16_omega^3 - dft16_omega^13) / 2
printIDFT16 :: IO ()
printIDFT16 = do
putStrLn $ "const uint64_t IDFT16_OMEGA = " ++ show dft16_omega ++ " ;"
putStrLn $ "const uint64_t IDFT16_INV_OMEGA = " ++ show dft16_inv_omega ++ " ;"
putStrLn $ "const uint64_t IDFT16_INV_16 = " ++ show idft16_inv_16 ++ " ;"
putStrLn $ "const uint64_t IDFT16_J = " ++ show idft16_j ++ " ;"
putStrLn $ "const uint64_t IDFT16_COS_U = " ++ show idft16_cos_u ++ " ;"
putStrLn $ "const uint64_t IDFT16_COS_2U = " ++ show idft16_cos_2u ++ " ;"
putStrLn $ "const uint64_t IDFT16_COS_3U = " ++ show idft16_cos_3u ++ " ;"
putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_U = " ++ show (- idft16_j_sin_u ) ++ " ;"
putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_2U = " ++ show (- idft16_j_sin_2u) ++ " ;"
putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_3U = " ++ show (- idft16_j_sin_3u) ++ " ;"
putStrLn $ "const uint64_t IDFT16_COS_3U_PLUS_U = " ++ show ( idft16_cos_3u + idft16_cos_u ) ++ " ;"
putStrLn $ "const uint64_t IDFT16_COS_3U_MINUS_U = " ++ show ( idft16_cos_3u - idft16_cos_u ) ++ " ;"
putStrLn $ "const uint64_t IDFT16_J_SIN_3U_MINUS_U = " ++ show ( idft16_j_sin_3u - idft16_j_sin_u) ++ " ;"
putStrLn $ "const uint64_t IDFT16_J_SIN_MINUS_3U_MINUS_U = " ++ show (- idft16_j_sin_3u - idft16_j_sin_u) ++ " ;"
--------------------------------------------------------------------------------
experimental_IDFT16 :: Array Int F -> Array Int F
experimental_IDFT16 input = output where
rescale :: F -> F
rescale x = x / 16
[x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15] = elems input
output = listArray (0,15) $ map rescale [y0,y1,y2,y3,y4,y5,y6,y7,y8,y9,y10,y11,y12,y13,y14,y15]
t1 = x0 + x8
t2 = x4 + x12
t3 = x2 + x10
t4 = x2 - x10
t5 = x6 + x14
t6 = x6 - x14
t7 = x1 + x9
t8 = x1 - x9
t9 = x3 + x11
t10 = x3 - x11
t11 = x5 + x13
t12 = x5 - x13
t13 = x7 + x15
t14 = x7 - x15
t15 = t1 + t2
t16 = t3 + t5
t17 = t15 + t16
t18 = t7 + t11
t19 = t7 - t11
t20 = t9 + t13
t21 = t9 - t13
t22 = t18 + t20
t23 = t8 + t14
t24 = t8 - t14
t25 = t10 + t12
t26 = t12 - t10
m0 = t17 + t22
m1 = t17 - t22
m2 = t15 - t16
m3 = t1 - t2
m4 = x0 - x8
m5 = idft16_cos_2u * (t19 - t21)
m6 = idft16_cos_2u * (t4 - t6 )
m7 = idft16_cos_3u * (t24 + t26)
m8 = (idft16_cos_3u + idft16_cos_u) * t24
m9 = (idft16_cos_3u - idft16_cos_u) * t26
m10 = idft16_j * (t20 - t18)
m11 = idft16_j * (t5 - t3 )
m12 = idft16_j * (x12 - x4 )
m13 = - idft16_j_sin_2u * ( t19 + t21)
m14 = - idft16_j_sin_2u * ( t4 + t6 )
m15 = - idft16_j_sin_3u * ( t23 + t25)
m16 = (idft16_j_sin_3u - idft16_j_sin_u) * t23
m17 = - (idft16_j_sin_3u + idft16_j_sin_u) * t25
s7 = m8 - m7
s8 = m9 - m7
s15 = m15 + m16
s16 = m15 - m17
s1 = m3 + m5
s2 = m3 - m5
s3 = m11 + m13
s4 = m13 - m11
s5 = m4 + m6
s6 = m4 - m6
s9 = s5 + s7
s10 = s5 - s7
s11 = s6 + s8
s12 = s6 - s8
s13 = m12 + m14
s14 = m12 - m14
s17 = s13 + s15
s18 = s13 - s15
s19 = s14 + s16
s20 = s14 - s16
y0 = m0
y1 = s9 + s17
y2 = s1 + s3
y3 = s12 - s20
y4 = m2 + m10
y5 = s11 + s19
y6 = s2 + s4
y7 = s10 - s18
y8 = m1
y9 = s10 + s18
y10 = s2 - s4
y11 = s11 - s19
y12 = m2 - m10
y13 = s12 + s20
y14 = s1 - s3
y15 = s9 - s17
----------------------------------------
-- is it always true, that in the NTT setting, the DFT and the unscaled
-- inverse DFT are the same up to permutation??
--
experimental_DFT16 :: Array Int F -> Array Int F
experimental_DFT16 input = output where
[x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15] = elems input
output = listArray (0,15) [y0,y1,y2,y3,y4,y5,y6,y7,y8,y9,y10,y11,y12,y13,y14,y15]
t1 = x0 + x8
t2 = x4 + x12
t3 = x2 + x10
t4 = x2 - x10
t5 = x6 + x14
t6 = x6 - x14
t7 = x1 + x9
t8 = x1 - x9
t9 = x3 + x11
t10 = x3 - x11
t11 = x5 + x13
t12 = x5 - x13
t13 = x7 + x15
t14 = x7 - x15
t15 = t1 + t2
t16 = t3 + t5
t17 = t15 + t16
t18 = t7 + t11
t19 = t7 - t11
t20 = t9 + t13
t21 = t9 - t13
t22 = t18 + t20
t23 = t8 + t14
t24 = t8 - t14
t25 = t10 + t12
t26 = t12 - t10
m0 = t17 + t22
m1 = t17 - t22
m2 = t15 - t16
m3 = t1 - t2
m4 = x0 - x8
m5 = dft16_cos_2u * (t19 - t21)
m6 = dft16_cos_2u * (t4 - t6 )
m7 = dft16_cos_3u * (t24 + t26)
m8 = (dft16_cos_3u + dft16_cos_u) * t24
m9 = (dft16_cos_3u - dft16_cos_u) * t26
m10 = dft16_j * (t20 - t18)
m11 = dft16_j * (t5 - t3 )
m12 = dft16_j * (x12 - x4 )
m13 = - dft16_j_sin_2u * ( t19 + t21)
m14 = - dft16_j_sin_2u * ( t4 + t6 )
m15 = - dft16_j_sin_3u * ( t23 + t25)
m16 = (dft16_j_sin_3u - dft16_j_sin_u) * t23
m17 = - (dft16_j_sin_3u + dft16_j_sin_u) * t25
s7 = m8 - m7
s8 = m9 - m7
s15 = m15 + m16
s16 = m15 - m17
s1 = m3 + m5
s2 = m3 - m5
s3 = m11 + m13
s4 = m13 - m11
s5 = m4 + m6
s6 = m4 - m6
s9 = s5 + s7
s10 = s5 - s7
s11 = s6 + s8
s12 = s6 - s8
s13 = m12 + m14
s14 = m12 - m14
s17 = s13 + s15
s18 = s13 - s15
s19 = s14 + s16
s20 = s14 - s16
y0 = m0
y1 = s9 - s17
y2 = s1 - s3
y3 = s12 + s20
y4 = m2 - m10
y5 = s11 - s19
y6 = s2 - s4
y7 = s10 + s18
y8 = m1
y9 = s10 - s18
y10 = s2 + s4
y11 = s11 + s19
y12 = m2 + m10
y13 = s12 - s20
y14 = s1 + s3
y15 = s9 + s17
--------------------------------------------------------------------------------
data Cost = MkCost
{ _nAdds :: Int
, _nMuls :: Int
}
deriving (Eq,Show)
addCost :: Cost -> Cost -> Cost
addCost (MkCost a1 m1) (MkCost a2 m2) = MkCost (a1+a2) (m1+m2)
scaleCost :: Int -> Cost -> Cost
scaleCost s (MkCost a m) = MkCost (s*a) (s*m)
doubleCost :: Cost -> Cost
doubleCost = scaleCost 2
estimateNTTCost :: Log2 -> Cost
estimateNTTCost = go where
go :: Log2 -> Cost
go 0 = MkCost 0 0
go 1 = MkCost 2 0
go m = recursive `addCost` post where
recursive = doubleCost (go (m-1))
post = scaleCost halfN (MkCost 3 1)
halfN = exp2_ (m-1)
--------------------------------------------------------------------------------

View File

@ -4,4 +4,5 @@ gcc -c -O2 goldilocks.c
gcc -c -O2 goldilocks_ext.c
gcc -c -O2 monolith.c
gcc -c -O2 ntt.c
gcc -c -O2 short_dft.c

View File

@ -0,0 +1,566 @@
//
// Short DFT algoritmus (size 8 and 16)
//
// See:
//
// - Nussbaumer: "Fast Fourier Transform and Convolution Algorithms", Chapter 5.5
//
// Note: As they describe complex DFT and we need NTT, the conventions differ a bit.
//
// Hence some formulas in the comments look false, that's because we follow the
// formulas from the book but need to change some coefficients and/or ordering...
//
// Note #2: If the multiplicative generator is changed, the constants need
// to be regenerated. See the module "NTT.FTT.Short"
//
#include <stdint.h>
#include "goldilocks.h"
#include "short_dft.h"
//------------------------------------------------------------------------------
// SIZE = 4
const uint64_t DFT4_J = 0x0001000000000000 ;
const uint64_t IDFT4_OMEGA = 0x0001000000000000 ;
const uint64_t IDFT4_INV_OMEGA = 0xfffeffff00000001 ;
const uint64_t IDFT4_J = 0x0001000000000000 ;
const uint64_t IDFT4_INV_4 = 0xbfffffff40000001 ;
void short_inv_DFT_size_4_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
int src_stride2 = src_stride + src_stride ;
int src_stride3 = src_stride2 + src_stride ;
uint64_t x0 = src[ 0];
uint64_t x1 = src[src_stride ];
uint64_t x2 = src[src_stride2];
uint64_t x3 = src[src_stride3];
uint64_t t1 = goldilocks_add( x0 , x2 ); // x0 + x2
uint64_t t2 = goldilocks_add( x1 , x3 ); // x1 + x3
uint64_t m0 = goldilocks_add( t1 , t2 ); // t1 + t2
uint64_t m1 = goldilocks_sub( t1 , t2 ); // t1 - t2
uint64_t m2 = goldilocks_sub( x0 , x2 ); // x0 - x2
uint64_t m3 = goldilocks_mul( DFT4_J , goldilocks_sub( x3 , x1 ) ); // j * (x3 - x1)
int tgt_stride2 = tgt_stride + tgt_stride ;
int tgt_stride3 = tgt_stride2 + tgt_stride ;
tgt[ 0] = m0;
tgt[tgt_stride ] = goldilocks_add( m2 , m3 );
tgt[tgt_stride2] = m1;
tgt[tgt_stride3] = goldilocks_sub( m2 , m3 );
}
void short_fwd_DFT_size_4( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
short_inv_DFT_size_4_unscaled( src_stride, tgt_stride, src, tgt );
int tgt_stride3 = 3*tgt_stride;
uint64_t tmp = tgt[tgt_stride ];
tgt[tgt_stride ] = tgt[tgt_stride3];
tgt[tgt_stride3] = tmp;
}
void short_inv_DFT_size_4_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
short_inv_DFT_size_4_unscaled( src_stride, tgt_stride, src, tgt );
int tgt_stride2 = tgt_stride + tgt_stride ;
int tgt_stride3 = tgt_stride2 + tgt_stride ;
tgt[ 0] = goldilocks_mul( IDFT4_INV_4 , tgt[ 0] );
tgt[tgt_stride ] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride ] );
tgt[tgt_stride2] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride2] );
tgt[tgt_stride3] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride3] );
}
//------------------------------------------------------------------------------
// SIZE = 8
const uint64_t DFT8_OMEGA = 0xfffffffeff000001 ;
const uint64_t DFT8_INV_OMEGA = 0x000000ffffffff00 ;
const uint64_t DFT8_J = 0x0001000000000000 ;
const uint64_t DFT8_COS_U = 0xffffff7f00800081 ;
const uint64_t DFT8_MINUS_J_SIN_U = 0xffffff7eff800081 ;
void short_fwd_DFT_size_8( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
// u = 2pi/8
// omega = cos(u) + i*sin(u)
//
// cos_u ~> - (omega+omega^7) / 2
// j_sin_u ~> - (omega-omega^7) / 2
// -j_sin_u ~> + (omega-omega^7) / 2
// j ~> omega^2
int src_stride2 = src_stride + src_stride ;
int src_stride3 = src_stride2 + src_stride ;
int src_stride4 = src_stride2 + src_stride2;
int src_stride5 = src_stride4 + src_stride ;
int src_stride6 = src_stride4 + src_stride2;
int src_stride7 = src_stride4 + src_stride3;
uint64_t x0 = src[0 ];
uint64_t x1 = src[src_stride ];
uint64_t x2 = src[src_stride2];
uint64_t x3 = src[src_stride3];
uint64_t x4 = src[src_stride4];
uint64_t x5 = src[src_stride5];
uint64_t x6 = src[src_stride6];
uint64_t x7 = src[src_stride7];
uint64_t t1 = goldilocks_add( x0 , x4 ); // x0 + x4
uint64_t t2 = goldilocks_add( x2 , x6 ); // x2 + x6
uint64_t t3 = goldilocks_add( x1 , x5 ); // x1 + x5
uint64_t t4 = goldilocks_sub( x1 , x5 ); // x1 - x5
uint64_t t5 = goldilocks_add( x3 , x7 ); // x3 + x7
uint64_t t6 = goldilocks_sub( x3 , x7 ); // x3 - x7
uint64_t t7 = goldilocks_add( t1 , t2 ); // t1 + t2
uint64_t t8 = goldilocks_add( t3 , t5 ); // t3 + t5
uint64_t m0 = goldilocks_add( t7 , t8 ); // (t7 + t8)
uint64_t m1 = goldilocks_sub( t7 , t8 ); // (t7 - t8)
uint64_t m2 = goldilocks_sub( t1 , t2 ); // (t1 - t2)
uint64_t m3 = goldilocks_sub( x0 , x4 ); // (x0 - x4)
uint64_t m4 = goldilocks_mul( DFT8_COS_U , goldilocks_sub( t4 , t6 ) ); // cos_u * (t4 - t6)
uint64_t m5 = goldilocks_mul( DFT8_J , goldilocks_sub( t5 , t3 ) ); // j*(t5 - t3)
uint64_t m6 = goldilocks_mul( DFT8_J , goldilocks_sub( x6 , x2 ) ); // j*(x6 - x2)
uint64_t m7 = goldilocks_mul( DFT8_MINUS_J_SIN_U , goldilocks_add( t4 , t6 ) ); // - j_sin_u * (t4 + t6)
uint64_t s1 = goldilocks_add( m3 , m4 ); // m3 + m4
uint64_t s2 = goldilocks_sub( m3 , m4 ); // m3 - m4
uint64_t s3 = goldilocks_add( m6 , m7 ); // m6 + m7
uint64_t s4 = goldilocks_sub( m6 , m7 ); // m6 - m7
int tgt_stride2 = tgt_stride + tgt_stride ;
int tgt_stride3 = tgt_stride2 + tgt_stride ;
int tgt_stride4 = tgt_stride2 + tgt_stride2;
int tgt_stride5 = tgt_stride4 + tgt_stride ;
int tgt_stride6 = tgt_stride4 + tgt_stride2;
int tgt_stride7 = tgt_stride4 + tgt_stride3;
tgt[ 0] = m0; // m0
tgt[tgt_stride ] = goldilocks_sub( s2 , s4 ); // s2 - s4
tgt[tgt_stride2] = goldilocks_sub( m2 , m5 ); // m2 - m5
tgt[tgt_stride3] = goldilocks_add( s1 , s3 ); // s1 + s3
tgt[tgt_stride4] = m1; // m1
tgt[tgt_stride5] = goldilocks_sub( s1 , s3 ); // s1 - s3
tgt[tgt_stride6] = goldilocks_add( m2 , m5 ); // m2 + m5
tgt[tgt_stride7] = goldilocks_add( s2 , s4 ); // s2 + s4
}
const uint64_t IDFT8_OMEGA = 0xfffffffeff000001 ;
const uint64_t IDFT8_INV_OMEGA = 0x000000ffffffff00 ;
const uint64_t IDFT8_J = 0x0001000000000000 ;
const uint64_t IDFT8_COS_U = 0x0000007fff7fff80 ;
const uint64_t IDFT8_MINUS_J_SIN_U = 0x00000080007fff80 ;
const uint64_t IDFT8_INV_8 = 0xdfffffff20000001 ;
//--------------------------------------
void short_inv_DFT_size_8_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
// u = 2pi/8
// omega = cos(u) + i*sin(u)
//
// cos_u ~> (omega+omega^7) / 2
// -j_sin_u ~> - (omega-omega^7) / 2
// j ~> omega^2
int src_stride2 = src_stride + src_stride ;
int src_stride3 = src_stride2 + src_stride ;
int src_stride4 = src_stride2 + src_stride2;
int src_stride5 = src_stride4 + src_stride ;
int src_stride6 = src_stride4 + src_stride2;
int src_stride7 = src_stride4 + src_stride3;
uint64_t x0 = src[0 ];
uint64_t x1 = src[src_stride ];
uint64_t x2 = src[src_stride2];
uint64_t x3 = src[src_stride3];
uint64_t x4 = src[src_stride4];
uint64_t x5 = src[src_stride5];
uint64_t x6 = src[src_stride6];
uint64_t x7 = src[src_stride7];
uint64_t t1 = goldilocks_add( x0 , x4 ); // x0 + x4
uint64_t t2 = goldilocks_add( x2 , x6 ); // x2 + x6
uint64_t t3 = goldilocks_add( x1 , x5 ); // x1 + x5
uint64_t t4 = goldilocks_sub( x1 , x5 ); // x1 - x5
uint64_t t5 = goldilocks_add( x3 , x7 ); // x3 + x7
uint64_t t6 = goldilocks_sub( x3 , x7 ); // x3 - x7
uint64_t t7 = goldilocks_add( t1 , t2 ); // t1 + t2
uint64_t t8 = goldilocks_add( t3 , t5 ); // t3 + t5
uint64_t m0 = goldilocks_add( t7 , t8 ); // (t7 + t8)
uint64_t m1 = goldilocks_sub( t7 , t8 ); // (t7 - t8)
uint64_t m2 = goldilocks_sub( t1 , t2 ); // (t1 - t2)
uint64_t m3 = goldilocks_sub( x0 , x4 ); // (x0 - x4)
uint64_t m4 = goldilocks_mul( IDFT8_COS_U , goldilocks_sub( t4 , t6 ) ); // cos_u * (t4 - t6)
uint64_t m5 = goldilocks_mul( IDFT8_J , goldilocks_sub( t5 , t3 ) ); // j*(t5 - t3)
uint64_t m6 = goldilocks_mul( IDFT8_J , goldilocks_sub( x6 , x2 ) ); // j*(x6 - x2)
uint64_t m7 = goldilocks_mul( IDFT8_MINUS_J_SIN_U , goldilocks_add( t4 , t6 ) ); // - j_sin_u * (t4 + t6)
uint64_t s1 = goldilocks_add( m3 , m4 ); // m3 + m4
uint64_t s2 = goldilocks_sub( m3 , m4 ); // m3 - m4
uint64_t s3 = goldilocks_add( m6 , m7 ); // m6 + m7
uint64_t s4 = goldilocks_sub( m6 , m7 ); // m6 - m7
int tgt_stride2 = tgt_stride + tgt_stride ;
int tgt_stride3 = tgt_stride2 + tgt_stride ;
int tgt_stride4 = tgt_stride2 + tgt_stride2;
int tgt_stride5 = tgt_stride4 + tgt_stride ;
int tgt_stride6 = tgt_stride4 + tgt_stride2;
int tgt_stride7 = tgt_stride4 + tgt_stride3;
tgt[ 0] = m0 ; // m0
tgt[tgt_stride ] = goldilocks_add( s1 , s3 ); // s1 + s3
tgt[tgt_stride2] = goldilocks_add( m2 , m5 ); // m2 + m5
tgt[tgt_stride3] = goldilocks_sub( s2 , s4 ); // s2 - s4
tgt[tgt_stride4] = m1 ; // m1
tgt[tgt_stride5] = goldilocks_add( s2 , s4 ); // s2 + s4
tgt[tgt_stride6] = goldilocks_sub( m2 , m5 ); // m2 - m5
tgt[tgt_stride7] = goldilocks_sub( s1 , s3 ); // s1 - s3
}
//------------------
void short_inv_DFT_size_8_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
short_inv_DFT_size_8_unscaled( src_stride, tgt_stride, src, tgt );
int tgt_stride2 = tgt_stride + tgt_stride ;
int tgt_stride3 = tgt_stride2 + tgt_stride ;
int tgt_stride4 = tgt_stride2 + tgt_stride2;
int tgt_stride5 = tgt_stride4 + tgt_stride ;
int tgt_stride6 = tgt_stride4 + tgt_stride2;
int tgt_stride7 = tgt_stride4 + tgt_stride3;
tgt[ 0] = goldilocks_mul( IDFT8_INV_8 , tgt[ 0] );
tgt[tgt_stride ] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride ] );
tgt[tgt_stride2] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride2] );
tgt[tgt_stride3] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride3] );
tgt[tgt_stride4] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride4] );
tgt[tgt_stride5] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride5] );
tgt[tgt_stride6] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride6] );
tgt[tgt_stride7] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride7] );
}
//------------------------------------------------------------------------------
// SIZE = 16
const uint64_t DFT16_OMEGA = 0xefffffff00000001 ;
const uint64_t DFT16_INV_OMEGA = 0x0000001000000000 ;
const uint64_t DFT16_J = 0x0001000000000000 ;
const uint64_t DFT16_COS_U = 0xf800000700000001 ;
const uint64_t DFT16_COS_2U = 0x0000007fff7fff80 ;
const uint64_t DFT16_COS_3U = 0x0007fffffff7f800 ;
const uint64_t DFT16_MINUS_J_SIN_U = 0x0800000800000000 ;
const uint64_t DFT16_MINUS_J_SIN_2U = 0x00000080007fff80 ;
const uint64_t DFT16_MINUS_J_SIN_3U = 0xfff7ffff0007f801 ;
const uint64_t DFT16_COS_3U_PLUS_U = 0xf8080006fff7f801 ;
const uint64_t DFT16_COS_3U_MINUS_U = 0x0807fff7fff7f800 ;
const uint64_t DFT16_J_SIN_3U_MINUS_U = 0x08080007fff80800 ;
const uint64_t DFT16_J_SIN_MINUS_3U_MINUS_U = 0x07f800080007f800 ;
void short_fwd_DFT_size_16( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
int src_stride2 = src_stride + src_stride ;
int src_stride3 = src_stride2 + src_stride ;
int src_stride4 = src_stride2 + src_stride2;
int src_stride5 = src_stride4 + src_stride ;
int src_stride6 = src_stride4 + src_stride2;
int src_stride7 = src_stride4 + src_stride3;
int src_stride8 = src_stride4 + src_stride4;
uint64_t x0 = src[ 0 ];
uint64_t x1 = src[ src_stride ];
uint64_t x2 = src[ src_stride2 ];
uint64_t x3 = src[ src_stride3 ];
uint64_t x4 = src[ src_stride4 ];
uint64_t x5 = src[ src_stride5 ];
uint64_t x6 = src[ src_stride6 ];
uint64_t x7 = src[ src_stride7 ];
uint64_t x8 = src[ src_stride8 ];
uint64_t x9 = src[ src_stride + src_stride8 ];
uint64_t x10 = src[ src_stride2 + src_stride8 ];
uint64_t x11 = src[ src_stride3 + src_stride8 ];
uint64_t x12 = src[ src_stride4 + src_stride8 ];
uint64_t x13 = src[ src_stride5 + src_stride8 ];
uint64_t x14 = src[ src_stride6 + src_stride8 ];
uint64_t x15 = src[ src_stride7 + src_stride8 ];
uint64_t t1 = goldilocks_add( x0 , x8 ); // x0 + x8
uint64_t t2 = goldilocks_add( x4 , x12 ); // x4 + x12
uint64_t t3 = goldilocks_add( x2 , x10 ); // x2 + x10
uint64_t t4 = goldilocks_sub( x2 , x10 ); // x2 - x10
uint64_t t5 = goldilocks_add( x6 , x14 ); // x6 + x14
uint64_t t6 = goldilocks_sub( x6 , x14 ); // x6 - x14
uint64_t t7 = goldilocks_add( x1 , x9 ); // x1 + x9
uint64_t t8 = goldilocks_sub( x1 , x9 ); // x1 - x9
uint64_t t9 = goldilocks_add( x3 , x11 ); // x3 + x11
uint64_t t10 = goldilocks_sub( x3 , x11 ); // x3 - x11
uint64_t t11 = goldilocks_add( x5 , x13 ); // x5 + x13
uint64_t t12 = goldilocks_sub( x5 , x13 ); // x5 - x13
uint64_t t13 = goldilocks_add( x7 , x15 ); // x7 + x15
uint64_t t14 = goldilocks_sub( x7 , x15 ); // x7 - x15
uint64_t t15 = goldilocks_add( t1 , t2 ); // t1 + t2
uint64_t t16 = goldilocks_add( t3 , t5 ); // t3 + t5
uint64_t t17 = goldilocks_add( t15 , t16 ); // t15 + t16
uint64_t t18 = goldilocks_add( t7 , t11 ); // t7 + t11
uint64_t t19 = goldilocks_sub( t7 , t11 ); // t7 - t11
uint64_t t20 = goldilocks_add( t9 , t13 ); // t9 + t13
uint64_t t21 = goldilocks_sub( t9 , t13 ); // t9 - t13
uint64_t t22 = goldilocks_add( t18 , t20 ); // t18 + t20
uint64_t t23 = goldilocks_add( t8 , t14 ); // t8 + t14
uint64_t t24 = goldilocks_sub( t8 , t14 ); // t8 - t14
uint64_t t25 = goldilocks_add( t10 , t12 ); // t10 + t12
uint64_t t26 = goldilocks_sub( t12 , t10 ); // t12 - t10
uint64_t m0 = goldilocks_add( t17 , t22 ); // t17 + t22
uint64_t m1 = goldilocks_sub( t17 , t22 ); // t17 - t22
uint64_t m2 = goldilocks_sub( t15 , t16 ); // t15 - t16
uint64_t m3 = goldilocks_sub( t1 , t2 ); // t1 - t2
uint64_t m4 = goldilocks_sub( x0 , x8 ); // x0 - x8
uint64_t m5 = goldilocks_mul( DFT16_COS_2U , goldilocks_sub(t19 , t21) );
uint64_t m6 = goldilocks_mul( DFT16_COS_2U , goldilocks_sub(t4 , t6 ) );
uint64_t m7 = goldilocks_mul( DFT16_COS_3U , goldilocks_add(t24 , t26) );
uint64_t m8 = goldilocks_mul( DFT16_COS_3U_PLUS_U , t24 );
uint64_t m9 = goldilocks_mul( DFT16_COS_3U_MINUS_U , t26 );
uint64_t m10 = goldilocks_mul( DFT16_J , goldilocks_sub(t20 , t18) );
uint64_t m11 = goldilocks_mul( DFT16_J , goldilocks_sub(t5 , t3 ) );
uint64_t m12 = goldilocks_mul( DFT16_J , goldilocks_sub(x12 , x4 ) );
uint64_t m13 = goldilocks_mul( DFT16_MINUS_J_SIN_2U , goldilocks_add( t19 , t21) );
uint64_t m14 = goldilocks_mul( DFT16_MINUS_J_SIN_2U , goldilocks_add( t4 , t6 ) );
uint64_t m15 = goldilocks_mul( DFT16_MINUS_J_SIN_3U , goldilocks_add( t23 , t25) );
uint64_t m16 = goldilocks_mul( DFT16_J_SIN_3U_MINUS_U , t23 );
uint64_t m17 = goldilocks_mul( DFT16_J_SIN_MINUS_3U_MINUS_U , t25 );
uint64_t s1 = goldilocks_add( m3 , m5 ); // m3 + m5
uint64_t s2 = goldilocks_sub( m3 , m5 ); // m3 - m5
uint64_t s3 = goldilocks_add( m11 , m13 ); // m11 + m13
uint64_t s4 = goldilocks_sub( m13 , m11 ); // m13 - m11
uint64_t s5 = goldilocks_add( m4 , m6 ); // m4 + m6
uint64_t s6 = goldilocks_sub( m4 , m6 ); // m4 - m6
uint64_t s7 = goldilocks_sub( m8 , m7 ); // m8 - m7
uint64_t s8 = goldilocks_sub( m9 , m7 ); // m9 - m7
uint64_t s9 = goldilocks_add( s5 , s7 ); // s5 + s7
uint64_t s10 = goldilocks_sub( s5 , s7 ); // s5 - s7
uint64_t s11 = goldilocks_add( s6 , s8 ); // s6 + s8
uint64_t s12 = goldilocks_sub( s6 , s8 ); // s6 - s8
uint64_t s13 = goldilocks_add( m12 , m14 ); // m12 + m14
uint64_t s14 = goldilocks_sub( m12 , m14 ); // m12 - m14
uint64_t s15 = goldilocks_add( m15 , m16 ); // m15 + m16
uint64_t s16 = goldilocks_sub( m15 , m17 ); // m15 - m17
uint64_t s17 = goldilocks_add( s13 , s15 ); // s13 + s15
uint64_t s18 = goldilocks_sub( s13 , s15 ); // s13 - s15
uint64_t s19 = goldilocks_add( s14 , s16 ); // s14 + s16
uint64_t s20 = goldilocks_sub( s14 , s16 ); // s14 - s16
int tgt_stride2 = tgt_stride + tgt_stride ;
int tgt_stride3 = tgt_stride2 + tgt_stride ;
int tgt_stride4 = tgt_stride2 + tgt_stride2;
int tgt_stride5 = tgt_stride4 + tgt_stride ;
int tgt_stride6 = tgt_stride4 + tgt_stride2;
int tgt_stride7 = tgt_stride4 + tgt_stride3;
int tgt_stride8 = tgt_stride4 + tgt_stride4;
tgt[ 0 ] = m0; // m0
tgt[ tgt_stride ] = goldilocks_sub( s9 , s17 ); // s9 - s17
tgt[ tgt_stride2 ] = goldilocks_sub( s1 , s3 ); // s1 - s3
tgt[ tgt_stride3 ] = goldilocks_add( s12 , s20 ); // s12 + s20
tgt[ tgt_stride4 ] = goldilocks_sub( m2 , m10 ); // m2 - m10
tgt[ tgt_stride5 ] = goldilocks_sub( s11 , s19 ); // s11 - s19
tgt[ tgt_stride6 ] = goldilocks_sub( s2 , s4 ); // s2 - s4
tgt[ tgt_stride7 ] = goldilocks_add( s10 , s18 ); // s10 + s18
tgt[ tgt_stride8 ] = m1; // m1
tgt[ tgt_stride + tgt_stride8 ] = goldilocks_sub( s10 , s18 ); // s10 - s18
tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_add( s2 , s4 ); // s2 + s4
tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_add( s11 , s19 ); // s11 + s19
tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_add( m2 , m10 ); // m2 + m10
tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_sub( s12 , s20 ); // s12 - s20
tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_add( s1 , s3 ); // s1 + s3
tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_add( s9 , s17 ); // s9 + s17
}
//--------------------------------------
const uint64_t IDFT16_OMEGA = 0xefffffff00000001 ;
const uint64_t IDFT16_INV_OMEGA = 0x0000001000000000 ;
const uint64_t IDFT16_INV_16 = 0xefffffff10000001 ;
const uint64_t IDFT16_J = 0x0001000000000000 ;
const uint64_t IDFT16_COS_U = 0xf800000700000001 ;
const uint64_t IDFT16_COS_2U = 0x0000007fff7fff80 ;
const uint64_t IDFT16_COS_3U = 0x0007fffffff7f800 ;
const uint64_t IDFT16_MINUS_J_SIN_U = 0x0800000800000000 ;
const uint64_t IDFT16_MINUS_J_SIN_2U = 0x00000080007fff80 ;
const uint64_t IDFT16_MINUS_J_SIN_3U = 0xfff7ffff0007f801 ;
const uint64_t IDFT16_COS_3U_PLUS_U = 0xf8080006fff7f801 ;
const uint64_t IDFT16_COS_3U_MINUS_U = 0x0807fff7fff7f800 ;
const uint64_t IDFT16_J_SIN_3U_MINUS_U = 0x08080007fff80800 ;
const uint64_t IDFT16_J_SIN_MINUS_3U_MINUS_U = 0x07f800080007f800 ;
void short_inv_DFT_size_16_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
int src_stride2 = src_stride + src_stride ;
int src_stride3 = src_stride2 + src_stride ;
int src_stride4 = src_stride2 + src_stride2;
int src_stride5 = src_stride4 + src_stride ;
int src_stride6 = src_stride4 + src_stride2;
int src_stride7 = src_stride4 + src_stride3;
int src_stride8 = src_stride4 + src_stride4;
uint64_t x0 = src[ 0 ];
uint64_t x1 = src[ src_stride ];
uint64_t x2 = src[ src_stride2 ];
uint64_t x3 = src[ src_stride3 ];
uint64_t x4 = src[ src_stride4 ];
uint64_t x5 = src[ src_stride5 ];
uint64_t x6 = src[ src_stride6 ];
uint64_t x7 = src[ src_stride7 ];
uint64_t x8 = src[ src_stride8 ];
uint64_t x9 = src[ src_stride + src_stride8 ];
uint64_t x10 = src[ src_stride2 + src_stride8 ];
uint64_t x11 = src[ src_stride3 + src_stride8 ];
uint64_t x12 = src[ src_stride4 + src_stride8 ];
uint64_t x13 = src[ src_stride5 + src_stride8 ];
uint64_t x14 = src[ src_stride6 + src_stride8 ];
uint64_t x15 = src[ src_stride7 + src_stride8 ];
uint64_t t1 = goldilocks_add( x0 , x8 ); // x0 + x8
uint64_t t2 = goldilocks_add( x4 , x12 ); // x4 + x12
uint64_t t3 = goldilocks_add( x2 , x10 ); // x2 + x10
uint64_t t4 = goldilocks_sub( x2 , x10 ); // x2 - x10
uint64_t t5 = goldilocks_add( x6 , x14 ); // x6 + x14
uint64_t t6 = goldilocks_sub( x6 , x14 ); // x6 - x14
uint64_t t7 = goldilocks_add( x1 , x9 ); // x1 + x9
uint64_t t8 = goldilocks_sub( x1 , x9 ); // x1 - x9
uint64_t t9 = goldilocks_add( x3 , x11 ); // x3 + x11
uint64_t t10 = goldilocks_sub( x3 , x11 ); // x3 - x11
uint64_t t11 = goldilocks_add( x5 , x13 ); // x5 + x13
uint64_t t12 = goldilocks_sub( x5 , x13 ); // x5 - x13
uint64_t t13 = goldilocks_add( x7 , x15 ); // x7 + x15
uint64_t t14 = goldilocks_sub( x7 , x15 ); // x7 - x15
uint64_t t15 = goldilocks_add( t1 , t2 ); // t1 + t2
uint64_t t16 = goldilocks_add( t3 , t5 ); // t3 + t5
uint64_t t17 = goldilocks_add( t15 , t16 ); // t15 + t16
uint64_t t18 = goldilocks_add( t7 , t11 ); // t7 + t11
uint64_t t19 = goldilocks_sub( t7 , t11 ); // t7 - t11
uint64_t t20 = goldilocks_add( t9 , t13 ); // t9 + t13
uint64_t t21 = goldilocks_sub( t9 , t13 ); // t9 - t13
uint64_t t22 = goldilocks_add( t18 , t20 ); // t18 + t20
uint64_t t23 = goldilocks_add( t8 , t14 ); // t8 + t14
uint64_t t24 = goldilocks_sub( t8 , t14 ); // t8 - t14
uint64_t t25 = goldilocks_add( t10 , t12 ); // t10 + t12
uint64_t t26 = goldilocks_sub( t12 , t10 ); // t12 - t10
uint64_t m0 = goldilocks_add( t17 , t22 ); // t17 + t22
uint64_t m1 = goldilocks_sub( t17 , t22 ); // t17 - t22
uint64_t m2 = goldilocks_sub( t15 , t16 ); // t15 - t16
uint64_t m3 = goldilocks_sub( t1 , t2 ); // t1 - t2
uint64_t m4 = goldilocks_sub( x0 , x8 ); // x0 - x8
uint64_t m5 = goldilocks_mul( IDFT16_COS_2U , goldilocks_sub(t19 , t21) );
uint64_t m6 = goldilocks_mul( IDFT16_COS_2U , goldilocks_sub(t4 , t6 ) );
uint64_t m7 = goldilocks_mul( IDFT16_COS_3U , goldilocks_add(t24 , t26) );
uint64_t m8 = goldilocks_mul( IDFT16_COS_3U_PLUS_U , t24 );
uint64_t m9 = goldilocks_mul( IDFT16_COS_3U_MINUS_U , t26 );
uint64_t m10 = goldilocks_mul( IDFT16_J , goldilocks_sub(t20 , t18) );
uint64_t m11 = goldilocks_mul( IDFT16_J , goldilocks_sub(t5 , t3 ) );
uint64_t m12 = goldilocks_mul( IDFT16_J , goldilocks_sub(x12 , x4 ) );
uint64_t m13 = goldilocks_mul( IDFT16_MINUS_J_SIN_2U , goldilocks_add( t19 , t21) );
uint64_t m14 = goldilocks_mul( IDFT16_MINUS_J_SIN_2U , goldilocks_add( t4 , t6 ) );
uint64_t m15 = goldilocks_mul( IDFT16_MINUS_J_SIN_3U , goldilocks_add( t23 , t25) );
uint64_t m16 = goldilocks_mul( IDFT16_J_SIN_3U_MINUS_U , t23 );
uint64_t m17 = goldilocks_mul( IDFT16_J_SIN_MINUS_3U_MINUS_U , t25 );
uint64_t s1 = goldilocks_add( m3 , m5 ); // m3 + m5
uint64_t s2 = goldilocks_sub( m3 , m5 ); // m3 - m5
uint64_t s3 = goldilocks_add( m11 , m13 ); // m11 + m13
uint64_t s4 = goldilocks_sub( m13 , m11 ); // m13 - m11
uint64_t s5 = goldilocks_add( m4 , m6 ); // m4 + m6
uint64_t s6 = goldilocks_sub( m4 , m6 ); // m4 - m6
uint64_t s7 = goldilocks_sub( m8 , m7 ); // m8 - m7
uint64_t s8 = goldilocks_sub( m9 , m7 ); // m9 - m7
uint64_t s9 = goldilocks_add( s5 , s7 ); // s5 + s7
uint64_t s10 = goldilocks_sub( s5 , s7 ); // s5 - s7
uint64_t s11 = goldilocks_add( s6 , s8 ); // s6 + s8
uint64_t s12 = goldilocks_sub( s6 , s8 ); // s6 - s8
uint64_t s13 = goldilocks_add( m12 , m14 ); // m12 + m14
uint64_t s14 = goldilocks_sub( m12 , m14 ); // m12 - m14
uint64_t s15 = goldilocks_add( m15 , m16 ); // m15 + m16
uint64_t s16 = goldilocks_sub( m15 , m17 ); // m15 - m17
uint64_t s17 = goldilocks_add( s13 , s15 ); // s13 + s15
uint64_t s18 = goldilocks_sub( s13 , s15 ); // s13 - s15
uint64_t s19 = goldilocks_add( s14 , s16 ); // s14 + s16
uint64_t s20 = goldilocks_sub( s14 , s16 ); // s14 - s16
int tgt_stride2 = tgt_stride + tgt_stride ;
int tgt_stride3 = tgt_stride2 + tgt_stride ;
int tgt_stride4 = tgt_stride2 + tgt_stride2;
int tgt_stride5 = tgt_stride4 + tgt_stride ;
int tgt_stride6 = tgt_stride4 + tgt_stride2;
int tgt_stride7 = tgt_stride4 + tgt_stride3;
int tgt_stride8 = tgt_stride4 + tgt_stride4;
tgt[ 0 ] = m0; // m0
tgt[ tgt_stride ] = goldilocks_add( s9 , s17 ); // s9 + s17
tgt[ tgt_stride2 ] = goldilocks_add( s1 , s3 ); // s1 + s3
tgt[ tgt_stride3 ] = goldilocks_sub( s12 , s20 ); // s12 - s20
tgt[ tgt_stride4 ] = goldilocks_add( m2 , m10 ); // m2 + m10
tgt[ tgt_stride5 ] = goldilocks_add( s11 , s19 ); // s11 + s19
tgt[ tgt_stride6 ] = goldilocks_add( s2 , s4 ); // s2 + s4
tgt[ tgt_stride7 ] = goldilocks_sub( s10 , s18 ); // s10 - s18
tgt[ tgt_stride8 ] = m1; // m1
tgt[ tgt_stride + tgt_stride8 ] = goldilocks_add( s10 , s18 ); // s10 + s18
tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_sub( s2 , s4 ); // s2 - s4
tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_sub( s11 , s19 ); // s11 - s19
tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_sub( m2 , m10 ); // m2 - m10
tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_add( s12 , s20 ); // s12 + s20
tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_sub( s1 , s3 ); // s1 - s3
tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_sub( s9 , s17 ); // s9 - s17
}
//------------------
void short_inv_DFT_size_16_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) {
short_inv_DFT_size_16_unscaled( src_stride, tgt_stride, src, tgt );
int tgt_stride2 = tgt_stride + tgt_stride ;
int tgt_stride3 = tgt_stride2 + tgt_stride ;
int tgt_stride4 = tgt_stride2 + tgt_stride2;
int tgt_stride5 = tgt_stride4 + tgt_stride ;
int tgt_stride6 = tgt_stride4 + tgt_stride2;
int tgt_stride7 = tgt_stride4 + tgt_stride3;
int tgt_stride8 = tgt_stride4 + tgt_stride4;
tgt[ 0 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ 0 ] );
tgt[ tgt_stride ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride ] );
tgt[ tgt_stride2 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride2 ] );
tgt[ tgt_stride3 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride3 ] );
tgt[ tgt_stride4 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride4 ] );
tgt[ tgt_stride5 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride5 ] );
tgt[ tgt_stride6 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride6 ] );
tgt[ tgt_stride7 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride7 ] );
tgt[ tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride8 ] );
tgt[ tgt_stride + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride + tgt_stride8 ] );
tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride2 + tgt_stride8 ] );
tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride3 + tgt_stride8 ] );
tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride4 + tgt_stride8 ] );
tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride5 + tgt_stride8 ] );
tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride6 + tgt_stride8 ] );
tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride7 + tgt_stride8 ] );
}
//------------------------------------------------------------------------------

View File

@ -0,0 +1,24 @@
#include <stdint.h>
//------------------------------------------------------------------------------
void short_fwd_DFT_size_4 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
void short_inv_DFT_size_4_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
void short_inv_DFT_size_4_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
void short_fwd_DFT_size_8 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
void short_inv_DFT_size_8_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
void short_inv_DFT_size_8_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
void short_fwd_DFT_size_16 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
void short_inv_DFT_size_16_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
void short_inv_DFT_size_16_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
//------------------------------------------------------------------------------
// void short_fwd_DFT_size_4_ext ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
// void short_inv_DFT_size_4_ext_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
// void short_inv_DFT_size_4_ext_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt );
//------------------------------------------------------------------------------

View File

@ -1,3 +1,3 @@
#!/bin/bash
ghci testMain.hs cbits/goldilocks.o cbits/goldilocks_ext.o cbits/monolith.o cbits/ntt.o
ghci testMain.hs cbits/goldilocks.o cbits/goldilocks_ext.o cbits/monolith.o cbits/ntt.o cbits/short_dft.o