From 0eb39eb5c9f8d1a6121dd3d891783cb84e44d278 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Tue, 4 Nov 2025 22:40:33 +0100 Subject: [PATCH] specialized algorithms for short NTT/INTT (size 4, 8, 16) --- reference/src/NTT/FFT/Short.hs | 590 ++++++++++++++++++++++++++++++++ reference/src/cbits/compile.sh | 1 + reference/src/cbits/short_dft.c | 566 ++++++++++++++++++++++++++++++ reference/src/cbits/short_dft.h | 24 ++ reference/src/runi.sh | 2 +- 5 files changed, 1182 insertions(+), 1 deletion(-) create mode 100644 reference/src/NTT/FFT/Short.hs create mode 100644 reference/src/cbits/short_dft.c create mode 100644 reference/src/cbits/short_dft.h diff --git a/reference/src/NTT/FFT/Short.hs b/reference/src/NTT/FFT/Short.hs new file mode 100644 index 0000000..1b5cc47 --- /dev/null +++ b/reference/src/NTT/FFT/Short.hs @@ -0,0 +1,590 @@ + +-- Short DFT algorithms (sizes 2, 4, 8 and 16) +-- +-- See : +-- +-- * Nussbaumer: "Fast Fourier Transform and Convolution Algorithms", Chapter 5.5 +-- + +{-# LANGUAGE ForeignFunctionInterface #-} +module NTT.FFT.Short where + +-------------------------------------------------------------------------------- + +import Data.Array +import Data.Word +import Data.List ( sort ) + +import Control.Monad + +import Foreign.C +import Foreign.Ptr +import Foreign.Marshal +import System.IO.Unsafe + +import Class.Field +import Field.Goldilocks ( F , fromF ) + +import NTT.FFT.Slow +import NTT.Poly.Naive + +import NTT.Subgroup +import Misc + +-------------------------------------------------------------------------------- +-- DFT mathematical definition - O(n^2) + +defForwardDFT :: Subgroup F -> Array Int F -> Array Int F +defForwardDFT sg input + | subgroupOrder sg /= n = error "defForwardDFT: input size does not match the subgroup" + | otherwise = listToArray [ f k | k <-[0..n-1] ] + where + n = arrayLength input + omega = subgroupGen sg + f k = sum [ input!i * power_ omega (k*i) | i<-[0..n-1] ] + +defInverseDFT :: Subgroup F -> Array Int F -> Array Int F +defInverseDFT sg input + | subgroupOrder sg /= n = error "defInverseDFT: input size does not match the subgroup" + | otherwise = listToArray [ f k / (fromIntegral n) | k <-[0..n-1] ] + where + n = arrayLength input + omega = subgroupGen sg + invomega = inverse omega + f k = sum [ input!i * power_ invomega (k*i) | i<-[0..n-1] ] + +-------------------------------------------------------------------------------- + +foreign import ccall unsafe "short_fwd_DFT_size_4" c_dft4_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () +foreign import ccall unsafe "short_inv_DFT_size_4_unscaled" c_dft4_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () +foreign import ccall unsafe "short_inv_DFT_size_4_rescaled" c_dft4_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () + +foreign import ccall unsafe "short_fwd_DFT_size_8" c_dft8_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () +foreign import ccall unsafe "short_inv_DFT_size_8_unscaled" c_dft8_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () +foreign import ccall unsafe "short_inv_DFT_size_8_rescaled" c_dft8_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () + +foreign import ccall unsafe "short_fwd_DFT_size_16" c_dft16_fwd :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () +foreign import ccall unsafe "short_inv_DFT_size_16_unscaled" c_dft16_inv_unscaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () +foreign import ccall unsafe "short_inv_DFT_size_16_rescaled" c_dft16_inv_rescaled :: CInt -> CInt -> Ptr Word64 -> Ptr Word64 -> IO () + +---------------------------------------- + +{-# NOINLINE shortForwardDFT4 #-} +shortForwardDFT4 :: Array Int F -> Array Int F +shortForwardDFT4 input + | arrayLength input /= 4 = error "shortForwardDFT4: expecting an input array of size 4" + | otherwise = unsafePerformIO $ do + allocaArray 4 $ \(ptr1 :: Ptr F) -> do + allocaArray 4 $ \(ptr2 :: Ptr F) -> do + pokeArray ptr1 (elems input) + c_dft4_fwd 1 1 (castPtr ptr1) (castPtr ptr2) + ys <- peekArray 4 ptr2 + return $ listToArray ys + +{-# NOINLINE shortInverseDFT4 #-} +shortInverseDFT4 :: Array Int F -> Array Int F +shortInverseDFT4 input + | arrayLength input /= 4 = error "shortInverseDFT4: expecting an input array of size 4" + | otherwise = unsafePerformIO $ do + allocaArray 4 $ \(ptr1 :: Ptr F) -> do + allocaArray 4 $ \(ptr2 :: Ptr F) -> do + pokeArray ptr1 (elems input) + c_dft4_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2) + ys <- peekArray 4 ptr2 + return $ listToArray ys + +---------------------------------------- + +{-# NOINLINE shortForwardDFT8 #-} +shortForwardDFT8 :: Array Int F -> Array Int F +shortForwardDFT8 input + | arrayLength input /= 8 = error "shortForwardDFT8: expecting an input array of size 8" + | otherwise = unsafePerformIO $ do + allocaArray 8 $ \(ptr1 :: Ptr F) -> do + allocaArray 8 $ \(ptr2 :: Ptr F) -> do + pokeArray ptr1 (elems input) + c_dft8_fwd 1 1 (castPtr ptr1) (castPtr ptr2) + ys <- peekArray 8 ptr2 + return $ listToArray ys + +{-# NOINLINE shortInverseDFT8 #-} +shortInverseDFT8 :: Array Int F -> Array Int F +shortInverseDFT8 input + | arrayLength input /= 8 = error "shortInverseDFT8: expecting an input array of size 8" + | otherwise = unsafePerformIO $ do + allocaArray 8 $ \(ptr1 :: Ptr F) -> do + allocaArray 8 $ \(ptr2 :: Ptr F) -> do + pokeArray ptr1 (elems input) + c_dft8_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2) + ys <- peekArray 8 ptr2 + return $ listToArray ys + +---------------------------------------- + +{-# NOINLINE shortForwardDFT16 #-} +shortForwardDFT16 :: Array Int F -> Array Int F +shortForwardDFT16 input + | arrayLength input /= 16 = error "shortForwardDFT16: expecting an input array of size 16" + | otherwise = unsafePerformIO $ do + allocaArray 16 $ \(ptr1 :: Ptr F) -> do + allocaArray 16 $ \(ptr2 :: Ptr F) -> do + pokeArray ptr1 (elems input) + c_dft16_fwd 1 1 (castPtr ptr1) (castPtr ptr2) + ys <- peekArray 16 ptr2 + return $ listToArray ys + +{-# NOINLINE shortInverseDFT16 #-} +shortInverseDFT16 :: Array Int F -> Array Int F +shortInverseDFT16 input + | arrayLength input /= 16 = error "shortInverseDFT16: expecting an input array of size 16" + | otherwise = unsafePerformIO $ do + allocaArray 16 $ \(ptr1 :: Ptr F) -> do + allocaArray 16 $ \(ptr2 :: Ptr F) -> do + pokeArray ptr1 (elems input) + c_dft16_inv_rescaled 1 1 (castPtr ptr1) (castPtr ptr2) + ys <- peekArray 16 ptr2 + return $ listToArray ys + +-------------------------------------------------------------------------------- + +testDFT4 :: IO () +testDFT4 = do + let sg = getSubgroup (Log2 2) + xs <- listToArray <$> (replicateM 4 rndIO :: IO [F]) + let ys0 = defForwardDFT sg xs + let ys1 = forwardNTT sg (Poly xs) + let ys2 = shortForwardDFT4 xs + print $ "xs = " ++ show (elems xs ) + print $ "ys0 (defin) = " ++ show (elems ys0) + print $ "ys1 (ref) = " ++ show (elems ys1) + print $ "ys2 (short) = " ++ show (elems ys2) + print $ "ok: " ++ show (ys1 == ys2) + +testIDFT4 :: IO () +testIDFT4 = do + let sg = getSubgroup (Log2 2) + xs <- listToArray <$> (replicateM 4 rndIO :: IO [F]) + let ys0 = defInverseDFT sg xs + let ys1 = fromPoly (inverseNTT sg xs) + let ys2 = shortInverseDFT4 xs + print $ "xs = " ++ show (elems xs ) + print $ "ys0 (defin) = " ++ show (elems ys0) + print $ "ys1 (ref) = " ++ show (elems ys1) + print $ "ys2 (short) = " ++ show (elems ys2) + print $ "ok: " ++ show (ys1 == ys2) + +---------------------------------------- + +testDFT8 :: IO () +testDFT8 = do + let sg = getSubgroup (Log2 3) + xs <- listToArray <$> (replicateM 8 rndIO :: IO [F]) + let ys0 = defForwardDFT sg xs + let ys1 = forwardNTT sg (Poly xs) + let ys2 = shortForwardDFT8 xs + print $ "xs = " ++ show (elems xs ) + print $ "ys0 (defin) = " ++ show (elems ys0) + print $ "ys1 (ref) = " ++ show (elems ys1) + print $ "ys2 (short) = " ++ show (elems ys2) + print $ "ok: " ++ show (ys1 == ys2) + +testIDFT8 :: IO () +testIDFT8 = do + let sg = getSubgroup (Log2 3) + xs <- listToArray <$> (replicateM 8 rndIO :: IO [F]) + let ys0 = defInverseDFT sg xs + let ys1 = fromPoly (inverseNTT sg xs) + let ys2 = shortInverseDFT8 xs + print $ "xs = " ++ show (elems xs ) + print $ "ys0 (defin) = " ++ show (elems ys0) + print $ "ys1 (ref) = " ++ show (elems ys1) + print $ "ys2 (short) = " ++ show (elems ys2) + print $ "ok: " ++ show (ys1 == ys2) + +-------------------------------------------------------------------------------- + +-- tmp for debugging +instance Ord F where + compare a b = compare (fromF a) (fromF b) + +testDFT16 :: IO () +testDFT16 = do + let sg = getSubgroup (Log2 4) + xs <- listToArray <$> (replicateM 16 rndIO :: IO [F]) + let ys0 = defForwardDFT sg xs + let ys1 = forwardNTT sg (Poly xs) + let ys2 = experimental_DFT16 xs + let ys3 = shortForwardDFT16 xs + print $ "xs = " ++ show (elems xs ) + print $ "ys0 (defin) = " ++ show (elems ys0) + print $ "ys1 (ref) = " ++ show (elems ys1) + print $ "ys2 (hsexp) = " ++ show (elems ys2) + print $ "ys3 (short) = " ++ show (elems ys3) + print $ "ok-def: " ++ show (ys0 == ys1) + print $ "ok-hs: " ++ show (ys1 == ys2) + print $ "ok-hs (modulo order): " ++ show (sort (elems ys1) == sort (elems ys2)) + print $ "ok-C: " ++ show (ys1 == ys3) + +testIDFT16 :: IO () +testIDFT16 = do + let sg = getSubgroup (Log2 4) + xs <- listToArray <$> (replicateM 16 rndIO :: IO [F]) + let ys0 = defInverseDFT sg xs + let ys1 = fromPoly (inverseNTT sg xs) + let ys2 = experimental_IDFT16 xs + let ys3 = shortInverseDFT16 xs + print $ "xs = " ++ show (elems xs ) + print $ "ys0 (defin) = " ++ show (elems ys0) + print $ "ys1 (ref) = " ++ show (elems ys1) + print $ "ys2 (hsexp) = " ++ show (elems ys2) + print $ "ys3 (short) = " ++ show (elems ys3) + print $ "ok-def: " ++ show (ys0 == ys1) + print $ "ok-hs: " ++ show (ys1 == ys2) + print $ "ok-C: " ++ show (ys1 == ys3) + +-------------------------------------------------------------------------------- + +dft2_omega, dft4_omega, dft8_omega, dft16_omega :: F +dft2_omega = subgroupGen $ getSubgroup (Log2 1) +dft4_omega = subgroupGen $ getSubgroup (Log2 2) +dft8_omega = subgroupGen $ getSubgroup (Log2 3) +dft16_omega = subgroupGen $ getSubgroup (Log2 4) + +dft2_inv_omega = inverse dft2_omega +dft4_inv_omega = inverse dft4_omega +dft8_inv_omega = inverse dft8_omega +dft16_inv_omega = inverse dft16_omega + +-------------------------------------------------------------------------------- +-- * Size = 4 + +dft4_j = dft4_omega +idft4_inv_4 = 1 / (4 :: F) + +printIDFT4 = do + putStrLn $ "const uint64_t IDFT4_OMEGA = " ++ show dft4_omega ++ " ;" + putStrLn $ "const uint64_t IDFT4_INV_OMEGA = " ++ show dft4_inv_omega ++ " ;" + putStrLn $ "const uint64_t IDFT4_J = " ++ show dft4_j ++ " ;" + putStrLn $ "const uint64_t IDFT4_INV_4 = " ++ show idft4_inv_4 ++ " ;" + +-------------------------------------------------------------------------------- +-- * Size = 8 + +dft8_j = square dft8_omega +dft8_cos_u = - (dft8_omega + dft8_omega^7) / 2 +dft8_j_sin_u = - (dft8_omega - dft8_omega^7) / 2 +dft8_minus_j_sin_u = - dft8_j_sin_u + +printDFT8 = do + putStrLn $ "const uint64_t DFT8_OMEGA = " ++ show dft8_omega ++ " ;" + putStrLn $ "const uint64_t DFT8_INV_OMEGA = " ++ show dft8_inv_omega ++ " ;" + putStrLn $ "const uint64_t DFT8_J = " ++ show dft8_j ++ " ;" + putStrLn $ "const uint64_t DFT8_COS_U = " ++ show dft8_cos_u ++ " ;" + putStrLn $ "const uint64_t DFT8_MINUS_J_SIN_U = " ++ show dft8_minus_j_sin_u ++ " ;" + +idft8_j = square dft8_omega +idft8_cos_u = (dft8_omega + dft8_omega^7) / 2 +idft8_j_sin_u = (dft8_omega - dft8_omega^7) / 2 +idft8_minus_j_sin_u = - idft8_j_sin_u +idft8_inv_8 = 1 / (8 :: F) + +printIDFT8 = do + putStrLn $ "const uint64_t IDFT8_OMEGA = " ++ show dft8_omega ++ " ;" + putStrLn $ "const uint64_t IDFT8_INV_OMEGA = " ++ show dft8_inv_omega ++ " ;" + putStrLn $ "const uint64_t IDFT8_J = " ++ show idft8_j ++ " ;" + putStrLn $ "const uint64_t IDFT8_COS_U = " ++ show idft8_cos_u ++ " ;" + putStrLn $ "const uint64_t IDFT8_MINUS_J_SIN_U = " ++ show idft8_minus_j_sin_u ++ " ;" + putStrLn $ "const uint64_t IDFT8_INV_8 = " ++ show idft8_inv_8 ++ " ;" + +-------------------------------------------------------------------------------- +-- * Size = 16 + +dft16_j = square (square dft16_omega) + +dft16_cos_u = (dft16_omega + dft16_omega^15) / 2 +dft16_cos_2u = (dft16_omega^2 + dft16_omega^14) / 2 +dft16_cos_3u = (dft16_omega^3 + dft16_omega^13) / 2 + +dft16_j_sin_u = (dft16_omega - dft16_omega^15) / 2 -- ???? +dft16_j_sin_2u = (dft16_omega^2 - dft16_omega^14) / 2 -- but it seems to work... +dft16_j_sin_3u = (dft16_omega^3 - dft16_omega^13) / 2 + +printDFT16 :: IO () +printDFT16 = do + putStrLn $ "const uint64_t DFT16_OMEGA = " ++ show dft16_omega ++ " ;" + putStrLn $ "const uint64_t DFT16_INV_OMEGA = " ++ show dft16_inv_omega ++ " ;" + putStrLn $ "const uint64_t DFT16_J = " ++ show dft16_j ++ " ;" + putStrLn $ "const uint64_t DFT16_COS_U = " ++ show dft16_cos_u ++ " ;" + putStrLn $ "const uint64_t DFT16_COS_2U = " ++ show dft16_cos_2u ++ " ;" + putStrLn $ "const uint64_t DFT16_COS_3U = " ++ show dft16_cos_3u ++ " ;" + putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_U = " ++ show (- dft16_j_sin_u ) ++ " ;" + putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_2U = " ++ show (- dft16_j_sin_2u) ++ " ;" + putStrLn $ "const uint64_t DFT16_MINUS_J_SIN_3U = " ++ show (- dft16_j_sin_3u) ++ " ;" + + putStrLn $ "const uint64_t DFT16_COS_3U_PLUS_U = " ++ show ( dft16_cos_3u + dft16_cos_u ) ++ " ;" + putStrLn $ "const uint64_t DFT16_COS_3U_MINUS_U = " ++ show ( dft16_cos_3u - dft16_cos_u ) ++ " ;" + putStrLn $ "const uint64_t DFT16_J_SIN_3U_MINUS_U = " ++ show ( dft16_j_sin_3u - dft16_j_sin_u) ++ " ;" + putStrLn $ "const uint64_t DFT16_J_SIN_MINUS_3U_MINUS_U = " ++ show (- dft16_j_sin_3u - dft16_j_sin_u) ++ " ;" + +-------------------- + +idft16_inv_16 = 1 / (16 :: F) +idft16_j = square (square dft16_omega) + +idft16_cos_u = (dft16_omega + dft16_omega^15) / 2 +idft16_cos_2u = (dft16_omega^2 + dft16_omega^14) / 2 +idft16_cos_3u = (dft16_omega^3 + dft16_omega^13) / 2 + +idft16_j_sin_u = (dft16_omega - dft16_omega^15) / 2 +idft16_j_sin_2u = (dft16_omega^2 - dft16_omega^14) / 2 +idft16_j_sin_3u = (dft16_omega^3 - dft16_omega^13) / 2 + +printIDFT16 :: IO () +printIDFT16 = do + putStrLn $ "const uint64_t IDFT16_OMEGA = " ++ show dft16_omega ++ " ;" + putStrLn $ "const uint64_t IDFT16_INV_OMEGA = " ++ show dft16_inv_omega ++ " ;" + putStrLn $ "const uint64_t IDFT16_INV_16 = " ++ show idft16_inv_16 ++ " ;" + putStrLn $ "const uint64_t IDFT16_J = " ++ show idft16_j ++ " ;" + putStrLn $ "const uint64_t IDFT16_COS_U = " ++ show idft16_cos_u ++ " ;" + putStrLn $ "const uint64_t IDFT16_COS_2U = " ++ show idft16_cos_2u ++ " ;" + putStrLn $ "const uint64_t IDFT16_COS_3U = " ++ show idft16_cos_3u ++ " ;" + putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_U = " ++ show (- idft16_j_sin_u ) ++ " ;" + putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_2U = " ++ show (- idft16_j_sin_2u) ++ " ;" + putStrLn $ "const uint64_t IDFT16_MINUS_J_SIN_3U = " ++ show (- idft16_j_sin_3u) ++ " ;" + + putStrLn $ "const uint64_t IDFT16_COS_3U_PLUS_U = " ++ show ( idft16_cos_3u + idft16_cos_u ) ++ " ;" + putStrLn $ "const uint64_t IDFT16_COS_3U_MINUS_U = " ++ show ( idft16_cos_3u - idft16_cos_u ) ++ " ;" + putStrLn $ "const uint64_t IDFT16_J_SIN_3U_MINUS_U = " ++ show ( idft16_j_sin_3u - idft16_j_sin_u) ++ " ;" + putStrLn $ "const uint64_t IDFT16_J_SIN_MINUS_3U_MINUS_U = " ++ show (- idft16_j_sin_3u - idft16_j_sin_u) ++ " ;" + +-------------------------------------------------------------------------------- + +experimental_IDFT16 :: Array Int F -> Array Int F +experimental_IDFT16 input = output where + + rescale :: F -> F + rescale x = x / 16 + + [x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15] = elems input + output = listArray (0,15) $ map rescale [y0,y1,y2,y3,y4,y5,y6,y7,y8,y9,y10,y11,y12,y13,y14,y15] + + t1 = x0 + x8 + t2 = x4 + x12 + t3 = x2 + x10 + t4 = x2 - x10 + t5 = x6 + x14 + t6 = x6 - x14 + t7 = x1 + x9 + t8 = x1 - x9 + t9 = x3 + x11 + t10 = x3 - x11 + t11 = x5 + x13 + t12 = x5 - x13 + t13 = x7 + x15 + t14 = x7 - x15 + t15 = t1 + t2 + t16 = t3 + t5 + t17 = t15 + t16 + t18 = t7 + t11 + t19 = t7 - t11 + t20 = t9 + t13 + t21 = t9 - t13 + t22 = t18 + t20 + t23 = t8 + t14 + t24 = t8 - t14 + t25 = t10 + t12 + t26 = t12 - t10 + + m0 = t17 + t22 + m1 = t17 - t22 + m2 = t15 - t16 + m3 = t1 - t2 + m4 = x0 - x8 + + m5 = idft16_cos_2u * (t19 - t21) + m6 = idft16_cos_2u * (t4 - t6 ) + m7 = idft16_cos_3u * (t24 + t26) + m8 = (idft16_cos_3u + idft16_cos_u) * t24 + m9 = (idft16_cos_3u - idft16_cos_u) * t26 + m10 = idft16_j * (t20 - t18) + m11 = idft16_j * (t5 - t3 ) + m12 = idft16_j * (x12 - x4 ) + m13 = - idft16_j_sin_2u * ( t19 + t21) + m14 = - idft16_j_sin_2u * ( t4 + t6 ) + m15 = - idft16_j_sin_3u * ( t23 + t25) + m16 = (idft16_j_sin_3u - idft16_j_sin_u) * t23 + m17 = - (idft16_j_sin_3u + idft16_j_sin_u) * t25 + + s7 = m8 - m7 + s8 = m9 - m7 + + s15 = m15 + m16 + s16 = m15 - m17 + + s1 = m3 + m5 + s2 = m3 - m5 + s3 = m11 + m13 + s4 = m13 - m11 + s5 = m4 + m6 + s6 = m4 - m6 + + s9 = s5 + s7 + s10 = s5 - s7 + s11 = s6 + s8 + s12 = s6 - s8 + + s13 = m12 + m14 + s14 = m12 - m14 + + s17 = s13 + s15 + s18 = s13 - s15 + s19 = s14 + s16 + s20 = s14 - s16 + + y0 = m0 + y1 = s9 + s17 + y2 = s1 + s3 + y3 = s12 - s20 + y4 = m2 + m10 + y5 = s11 + s19 + y6 = s2 + s4 + y7 = s10 - s18 + y8 = m1 + y9 = s10 + s18 + y10 = s2 - s4 + y11 = s11 - s19 + y12 = m2 - m10 + y13 = s12 + s20 + y14 = s1 - s3 + y15 = s9 - s17 + +---------------------------------------- + +-- is it always true, that in the NTT setting, the DFT and the unscaled +-- inverse DFT are the same up to permutation?? +-- +experimental_DFT16 :: Array Int F -> Array Int F +experimental_DFT16 input = output where + + [x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15] = elems input + output = listArray (0,15) [y0,y1,y2,y3,y4,y5,y6,y7,y8,y9,y10,y11,y12,y13,y14,y15] + + t1 = x0 + x8 + t2 = x4 + x12 + t3 = x2 + x10 + t4 = x2 - x10 + t5 = x6 + x14 + t6 = x6 - x14 + t7 = x1 + x9 + t8 = x1 - x9 + t9 = x3 + x11 + t10 = x3 - x11 + t11 = x5 + x13 + t12 = x5 - x13 + t13 = x7 + x15 + t14 = x7 - x15 + t15 = t1 + t2 + t16 = t3 + t5 + t17 = t15 + t16 + t18 = t7 + t11 + t19 = t7 - t11 + t20 = t9 + t13 + t21 = t9 - t13 + t22 = t18 + t20 + t23 = t8 + t14 + t24 = t8 - t14 + t25 = t10 + t12 + t26 = t12 - t10 + + m0 = t17 + t22 + m1 = t17 - t22 + m2 = t15 - t16 + m3 = t1 - t2 + m4 = x0 - x8 + + m5 = dft16_cos_2u * (t19 - t21) + m6 = dft16_cos_2u * (t4 - t6 ) + m7 = dft16_cos_3u * (t24 + t26) + m8 = (dft16_cos_3u + dft16_cos_u) * t24 + m9 = (dft16_cos_3u - dft16_cos_u) * t26 + m10 = dft16_j * (t20 - t18) + m11 = dft16_j * (t5 - t3 ) + m12 = dft16_j * (x12 - x4 ) + m13 = - dft16_j_sin_2u * ( t19 + t21) + m14 = - dft16_j_sin_2u * ( t4 + t6 ) + m15 = - dft16_j_sin_3u * ( t23 + t25) + m16 = (dft16_j_sin_3u - dft16_j_sin_u) * t23 + m17 = - (dft16_j_sin_3u + dft16_j_sin_u) * t25 + + s7 = m8 - m7 + s8 = m9 - m7 + + s15 = m15 + m16 + s16 = m15 - m17 + + s1 = m3 + m5 + s2 = m3 - m5 + s3 = m11 + m13 + s4 = m13 - m11 + s5 = m4 + m6 + s6 = m4 - m6 + + s9 = s5 + s7 + s10 = s5 - s7 + s11 = s6 + s8 + s12 = s6 - s8 + + s13 = m12 + m14 + s14 = m12 - m14 + + s17 = s13 + s15 + s18 = s13 - s15 + s19 = s14 + s16 + s20 = s14 - s16 + + y0 = m0 + y1 = s9 - s17 + y2 = s1 - s3 + y3 = s12 + s20 + y4 = m2 - m10 + y5 = s11 - s19 + y6 = s2 - s4 + y7 = s10 + s18 + y8 = m1 + y9 = s10 - s18 + y10 = s2 + s4 + y11 = s11 + s19 + y12 = m2 + m10 + y13 = s12 - s20 + y14 = s1 + s3 + y15 = s9 + s17 + +-------------------------------------------------------------------------------- + +data Cost = MkCost + { _nAdds :: Int + , _nMuls :: Int + } + deriving (Eq,Show) + +addCost :: Cost -> Cost -> Cost +addCost (MkCost a1 m1) (MkCost a2 m2) = MkCost (a1+a2) (m1+m2) + +scaleCost :: Int -> Cost -> Cost +scaleCost s (MkCost a m) = MkCost (s*a) (s*m) + +doubleCost :: Cost -> Cost +doubleCost = scaleCost 2 + +estimateNTTCost :: Log2 -> Cost +estimateNTTCost = go where + go :: Log2 -> Cost + go 0 = MkCost 0 0 + go 1 = MkCost 2 0 + go m = recursive `addCost` post where + recursive = doubleCost (go (m-1)) + post = scaleCost halfN (MkCost 3 1) + halfN = exp2_ (m-1) + +-------------------------------------------------------------------------------- + diff --git a/reference/src/cbits/compile.sh b/reference/src/cbits/compile.sh index c52bddd..afa43ab 100755 --- a/reference/src/cbits/compile.sh +++ b/reference/src/cbits/compile.sh @@ -4,4 +4,5 @@ gcc -c -O2 goldilocks.c gcc -c -O2 goldilocks_ext.c gcc -c -O2 monolith.c gcc -c -O2 ntt.c +gcc -c -O2 short_dft.c diff --git a/reference/src/cbits/short_dft.c b/reference/src/cbits/short_dft.c new file mode 100644 index 0000000..d5444c9 --- /dev/null +++ b/reference/src/cbits/short_dft.c @@ -0,0 +1,566 @@ + +// +// Short DFT algoritmus (size 8 and 16) +// +// See: +// +// - Nussbaumer: "Fast Fourier Transform and Convolution Algorithms", Chapter 5.5 +// +// Note: As they describe complex DFT and we need NTT, the conventions differ a bit. +// +// Hence some formulas in the comments look false, that's because we follow the +// formulas from the book but need to change some coefficients and/or ordering... +// +// Note #2: If the multiplicative generator is changed, the constants need +// to be regenerated. See the module "NTT.FTT.Short" +// + +#include + +#include "goldilocks.h" +#include "short_dft.h" + +//------------------------------------------------------------------------------ +// SIZE = 4 + +const uint64_t DFT4_J = 0x0001000000000000 ; + +const uint64_t IDFT4_OMEGA = 0x0001000000000000 ; +const uint64_t IDFT4_INV_OMEGA = 0xfffeffff00000001 ; +const uint64_t IDFT4_J = 0x0001000000000000 ; +const uint64_t IDFT4_INV_4 = 0xbfffffff40000001 ; + +void short_inv_DFT_size_4_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + + int src_stride2 = src_stride + src_stride ; + int src_stride3 = src_stride2 + src_stride ; + + uint64_t x0 = src[ 0]; + uint64_t x1 = src[src_stride ]; + uint64_t x2 = src[src_stride2]; + uint64_t x3 = src[src_stride3]; + + uint64_t t1 = goldilocks_add( x0 , x2 ); // x0 + x2 + uint64_t t2 = goldilocks_add( x1 , x3 ); // x1 + x3 + uint64_t m0 = goldilocks_add( t1 , t2 ); // t1 + t2 + uint64_t m1 = goldilocks_sub( t1 , t2 ); // t1 - t2 + uint64_t m2 = goldilocks_sub( x0 , x2 ); // x0 - x2 + uint64_t m3 = goldilocks_mul( DFT4_J , goldilocks_sub( x3 , x1 ) ); // j * (x3 - x1) + + int tgt_stride2 = tgt_stride + tgt_stride ; + int tgt_stride3 = tgt_stride2 + tgt_stride ; + + tgt[ 0] = m0; + tgt[tgt_stride ] = goldilocks_add( m2 , m3 ); + tgt[tgt_stride2] = m1; + tgt[tgt_stride3] = goldilocks_sub( m2 , m3 ); + +} + +void short_fwd_DFT_size_4( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + short_inv_DFT_size_4_unscaled( src_stride, tgt_stride, src, tgt ); + + int tgt_stride3 = 3*tgt_stride; + + uint64_t tmp = tgt[tgt_stride ]; + tgt[tgt_stride ] = tgt[tgt_stride3]; + tgt[tgt_stride3] = tmp; +} + +void short_inv_DFT_size_4_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + + short_inv_DFT_size_4_unscaled( src_stride, tgt_stride, src, tgt ); + + int tgt_stride2 = tgt_stride + tgt_stride ; + int tgt_stride3 = tgt_stride2 + tgt_stride ; + + tgt[ 0] = goldilocks_mul( IDFT4_INV_4 , tgt[ 0] ); + tgt[tgt_stride ] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride ] ); + tgt[tgt_stride2] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride2] ); + tgt[tgt_stride3] = goldilocks_mul( IDFT4_INV_4 , tgt[tgt_stride3] ); +} + +//------------------------------------------------------------------------------ +// SIZE = 8 + +const uint64_t DFT8_OMEGA = 0xfffffffeff000001 ; +const uint64_t DFT8_INV_OMEGA = 0x000000ffffffff00 ; +const uint64_t DFT8_J = 0x0001000000000000 ; +const uint64_t DFT8_COS_U = 0xffffff7f00800081 ; +const uint64_t DFT8_MINUS_J_SIN_U = 0xffffff7eff800081 ; + +void short_fwd_DFT_size_8( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + // u = 2pi/8 + // omega = cos(u) + i*sin(u) + // + // cos_u ~> - (omega+omega^7) / 2 + // j_sin_u ~> - (omega-omega^7) / 2 + // -j_sin_u ~> + (omega-omega^7) / 2 + // j ~> omega^2 + + int src_stride2 = src_stride + src_stride ; + int src_stride3 = src_stride2 + src_stride ; + int src_stride4 = src_stride2 + src_stride2; + int src_stride5 = src_stride4 + src_stride ; + int src_stride6 = src_stride4 + src_stride2; + int src_stride7 = src_stride4 + src_stride3; + + uint64_t x0 = src[0 ]; + uint64_t x1 = src[src_stride ]; + uint64_t x2 = src[src_stride2]; + uint64_t x3 = src[src_stride3]; + uint64_t x4 = src[src_stride4]; + uint64_t x5 = src[src_stride5]; + uint64_t x6 = src[src_stride6]; + uint64_t x7 = src[src_stride7]; + + uint64_t t1 = goldilocks_add( x0 , x4 ); // x0 + x4 + uint64_t t2 = goldilocks_add( x2 , x6 ); // x2 + x6 + uint64_t t3 = goldilocks_add( x1 , x5 ); // x1 + x5 + uint64_t t4 = goldilocks_sub( x1 , x5 ); // x1 - x5 + uint64_t t5 = goldilocks_add( x3 , x7 ); // x3 + x7 + uint64_t t6 = goldilocks_sub( x3 , x7 ); // x3 - x7 + uint64_t t7 = goldilocks_add( t1 , t2 ); // t1 + t2 + uint64_t t8 = goldilocks_add( t3 , t5 ); // t3 + t5 + + uint64_t m0 = goldilocks_add( t7 , t8 ); // (t7 + t8) + uint64_t m1 = goldilocks_sub( t7 , t8 ); // (t7 - t8) + uint64_t m2 = goldilocks_sub( t1 , t2 ); // (t1 - t2) + uint64_t m3 = goldilocks_sub( x0 , x4 ); // (x0 - x4) + + uint64_t m4 = goldilocks_mul( DFT8_COS_U , goldilocks_sub( t4 , t6 ) ); // cos_u * (t4 - t6) + uint64_t m5 = goldilocks_mul( DFT8_J , goldilocks_sub( t5 , t3 ) ); // j*(t5 - t3) + uint64_t m6 = goldilocks_mul( DFT8_J , goldilocks_sub( x6 , x2 ) ); // j*(x6 - x2) + uint64_t m7 = goldilocks_mul( DFT8_MINUS_J_SIN_U , goldilocks_add( t4 , t6 ) ); // - j_sin_u * (t4 + t6) + + uint64_t s1 = goldilocks_add( m3 , m4 ); // m3 + m4 + uint64_t s2 = goldilocks_sub( m3 , m4 ); // m3 - m4 + uint64_t s3 = goldilocks_add( m6 , m7 ); // m6 + m7 + uint64_t s4 = goldilocks_sub( m6 , m7 ); // m6 - m7 + + int tgt_stride2 = tgt_stride + tgt_stride ; + int tgt_stride3 = tgt_stride2 + tgt_stride ; + int tgt_stride4 = tgt_stride2 + tgt_stride2; + int tgt_stride5 = tgt_stride4 + tgt_stride ; + int tgt_stride6 = tgt_stride4 + tgt_stride2; + int tgt_stride7 = tgt_stride4 + tgt_stride3; + + tgt[ 0] = m0; // m0 + tgt[tgt_stride ] = goldilocks_sub( s2 , s4 ); // s2 - s4 + tgt[tgt_stride2] = goldilocks_sub( m2 , m5 ); // m2 - m5 + tgt[tgt_stride3] = goldilocks_add( s1 , s3 ); // s1 + s3 + tgt[tgt_stride4] = m1; // m1 + tgt[tgt_stride5] = goldilocks_sub( s1 , s3 ); // s1 - s3 + tgt[tgt_stride6] = goldilocks_add( m2 , m5 ); // m2 + m5 + tgt[tgt_stride7] = goldilocks_add( s2 , s4 ); // s2 + s4 +} + +const uint64_t IDFT8_OMEGA = 0xfffffffeff000001 ; +const uint64_t IDFT8_INV_OMEGA = 0x000000ffffffff00 ; +const uint64_t IDFT8_J = 0x0001000000000000 ; +const uint64_t IDFT8_COS_U = 0x0000007fff7fff80 ; +const uint64_t IDFT8_MINUS_J_SIN_U = 0x00000080007fff80 ; +const uint64_t IDFT8_INV_8 = 0xdfffffff20000001 ; + +//-------------------------------------- + +void short_inv_DFT_size_8_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + // u = 2pi/8 + // omega = cos(u) + i*sin(u) + // + // cos_u ~> (omega+omega^7) / 2 + // -j_sin_u ~> - (omega-omega^7) / 2 + // j ~> omega^2 + + int src_stride2 = src_stride + src_stride ; + int src_stride3 = src_stride2 + src_stride ; + int src_stride4 = src_stride2 + src_stride2; + int src_stride5 = src_stride4 + src_stride ; + int src_stride6 = src_stride4 + src_stride2; + int src_stride7 = src_stride4 + src_stride3; + + uint64_t x0 = src[0 ]; + uint64_t x1 = src[src_stride ]; + uint64_t x2 = src[src_stride2]; + uint64_t x3 = src[src_stride3]; + uint64_t x4 = src[src_stride4]; + uint64_t x5 = src[src_stride5]; + uint64_t x6 = src[src_stride6]; + uint64_t x7 = src[src_stride7]; + + uint64_t t1 = goldilocks_add( x0 , x4 ); // x0 + x4 + uint64_t t2 = goldilocks_add( x2 , x6 ); // x2 + x6 + uint64_t t3 = goldilocks_add( x1 , x5 ); // x1 + x5 + uint64_t t4 = goldilocks_sub( x1 , x5 ); // x1 - x5 + uint64_t t5 = goldilocks_add( x3 , x7 ); // x3 + x7 + uint64_t t6 = goldilocks_sub( x3 , x7 ); // x3 - x7 + uint64_t t7 = goldilocks_add( t1 , t2 ); // t1 + t2 + uint64_t t8 = goldilocks_add( t3 , t5 ); // t3 + t5 + + uint64_t m0 = goldilocks_add( t7 , t8 ); // (t7 + t8) + uint64_t m1 = goldilocks_sub( t7 , t8 ); // (t7 - t8) + uint64_t m2 = goldilocks_sub( t1 , t2 ); // (t1 - t2) + uint64_t m3 = goldilocks_sub( x0 , x4 ); // (x0 - x4) + + uint64_t m4 = goldilocks_mul( IDFT8_COS_U , goldilocks_sub( t4 , t6 ) ); // cos_u * (t4 - t6) + uint64_t m5 = goldilocks_mul( IDFT8_J , goldilocks_sub( t5 , t3 ) ); // j*(t5 - t3) + uint64_t m6 = goldilocks_mul( IDFT8_J , goldilocks_sub( x6 , x2 ) ); // j*(x6 - x2) + uint64_t m7 = goldilocks_mul( IDFT8_MINUS_J_SIN_U , goldilocks_add( t4 , t6 ) ); // - j_sin_u * (t4 + t6) + + uint64_t s1 = goldilocks_add( m3 , m4 ); // m3 + m4 + uint64_t s2 = goldilocks_sub( m3 , m4 ); // m3 - m4 + uint64_t s3 = goldilocks_add( m6 , m7 ); // m6 + m7 + uint64_t s4 = goldilocks_sub( m6 , m7 ); // m6 - m7 + + int tgt_stride2 = tgt_stride + tgt_stride ; + int tgt_stride3 = tgt_stride2 + tgt_stride ; + int tgt_stride4 = tgt_stride2 + tgt_stride2; + int tgt_stride5 = tgt_stride4 + tgt_stride ; + int tgt_stride6 = tgt_stride4 + tgt_stride2; + int tgt_stride7 = tgt_stride4 + tgt_stride3; + + tgt[ 0] = m0 ; // m0 + tgt[tgt_stride ] = goldilocks_add( s1 , s3 ); // s1 + s3 + tgt[tgt_stride2] = goldilocks_add( m2 , m5 ); // m2 + m5 + tgt[tgt_stride3] = goldilocks_sub( s2 , s4 ); // s2 - s4 + tgt[tgt_stride4] = m1 ; // m1 + tgt[tgt_stride5] = goldilocks_add( s2 , s4 ); // s2 + s4 + tgt[tgt_stride6] = goldilocks_sub( m2 , m5 ); // m2 - m5 + tgt[tgt_stride7] = goldilocks_sub( s1 , s3 ); // s1 - s3 +} + +//------------------ + +void short_inv_DFT_size_8_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + + short_inv_DFT_size_8_unscaled( src_stride, tgt_stride, src, tgt ); + + int tgt_stride2 = tgt_stride + tgt_stride ; + int tgt_stride3 = tgt_stride2 + tgt_stride ; + int tgt_stride4 = tgt_stride2 + tgt_stride2; + int tgt_stride5 = tgt_stride4 + tgt_stride ; + int tgt_stride6 = tgt_stride4 + tgt_stride2; + int tgt_stride7 = tgt_stride4 + tgt_stride3; + + tgt[ 0] = goldilocks_mul( IDFT8_INV_8 , tgt[ 0] ); + tgt[tgt_stride ] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride ] ); + tgt[tgt_stride2] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride2] ); + tgt[tgt_stride3] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride3] ); + tgt[tgt_stride4] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride4] ); + tgt[tgt_stride5] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride5] ); + tgt[tgt_stride6] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride6] ); + tgt[tgt_stride7] = goldilocks_mul( IDFT8_INV_8 , tgt[tgt_stride7] ); +} + +//------------------------------------------------------------------------------ +// SIZE = 16 + +const uint64_t DFT16_OMEGA = 0xefffffff00000001 ; +const uint64_t DFT16_INV_OMEGA = 0x0000001000000000 ; +const uint64_t DFT16_J = 0x0001000000000000 ; +const uint64_t DFT16_COS_U = 0xf800000700000001 ; +const uint64_t DFT16_COS_2U = 0x0000007fff7fff80 ; +const uint64_t DFT16_COS_3U = 0x0007fffffff7f800 ; +const uint64_t DFT16_MINUS_J_SIN_U = 0x0800000800000000 ; +const uint64_t DFT16_MINUS_J_SIN_2U = 0x00000080007fff80 ; +const uint64_t DFT16_MINUS_J_SIN_3U = 0xfff7ffff0007f801 ; +const uint64_t DFT16_COS_3U_PLUS_U = 0xf8080006fff7f801 ; +const uint64_t DFT16_COS_3U_MINUS_U = 0x0807fff7fff7f800 ; +const uint64_t DFT16_J_SIN_3U_MINUS_U = 0x08080007fff80800 ; +const uint64_t DFT16_J_SIN_MINUS_3U_MINUS_U = 0x07f800080007f800 ; + +void short_fwd_DFT_size_16( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + + int src_stride2 = src_stride + src_stride ; + int src_stride3 = src_stride2 + src_stride ; + int src_stride4 = src_stride2 + src_stride2; + int src_stride5 = src_stride4 + src_stride ; + int src_stride6 = src_stride4 + src_stride2; + int src_stride7 = src_stride4 + src_stride3; + int src_stride8 = src_stride4 + src_stride4; + + uint64_t x0 = src[ 0 ]; + uint64_t x1 = src[ src_stride ]; + uint64_t x2 = src[ src_stride2 ]; + uint64_t x3 = src[ src_stride3 ]; + uint64_t x4 = src[ src_stride4 ]; + uint64_t x5 = src[ src_stride5 ]; + uint64_t x6 = src[ src_stride6 ]; + uint64_t x7 = src[ src_stride7 ]; + uint64_t x8 = src[ src_stride8 ]; + uint64_t x9 = src[ src_stride + src_stride8 ]; + uint64_t x10 = src[ src_stride2 + src_stride8 ]; + uint64_t x11 = src[ src_stride3 + src_stride8 ]; + uint64_t x12 = src[ src_stride4 + src_stride8 ]; + uint64_t x13 = src[ src_stride5 + src_stride8 ]; + uint64_t x14 = src[ src_stride6 + src_stride8 ]; + uint64_t x15 = src[ src_stride7 + src_stride8 ]; + + uint64_t t1 = goldilocks_add( x0 , x8 ); // x0 + x8 + uint64_t t2 = goldilocks_add( x4 , x12 ); // x4 + x12 + uint64_t t3 = goldilocks_add( x2 , x10 ); // x2 + x10 + uint64_t t4 = goldilocks_sub( x2 , x10 ); // x2 - x10 + uint64_t t5 = goldilocks_add( x6 , x14 ); // x6 + x14 + uint64_t t6 = goldilocks_sub( x6 , x14 ); // x6 - x14 + uint64_t t7 = goldilocks_add( x1 , x9 ); // x1 + x9 + uint64_t t8 = goldilocks_sub( x1 , x9 ); // x1 - x9 + uint64_t t9 = goldilocks_add( x3 , x11 ); // x3 + x11 + uint64_t t10 = goldilocks_sub( x3 , x11 ); // x3 - x11 + uint64_t t11 = goldilocks_add( x5 , x13 ); // x5 + x13 + uint64_t t12 = goldilocks_sub( x5 , x13 ); // x5 - x13 + uint64_t t13 = goldilocks_add( x7 , x15 ); // x7 + x15 + uint64_t t14 = goldilocks_sub( x7 , x15 ); // x7 - x15 + uint64_t t15 = goldilocks_add( t1 , t2 ); // t1 + t2 + uint64_t t16 = goldilocks_add( t3 , t5 ); // t3 + t5 + uint64_t t17 = goldilocks_add( t15 , t16 ); // t15 + t16 + uint64_t t18 = goldilocks_add( t7 , t11 ); // t7 + t11 + uint64_t t19 = goldilocks_sub( t7 , t11 ); // t7 - t11 + uint64_t t20 = goldilocks_add( t9 , t13 ); // t9 + t13 + uint64_t t21 = goldilocks_sub( t9 , t13 ); // t9 - t13 + uint64_t t22 = goldilocks_add( t18 , t20 ); // t18 + t20 + uint64_t t23 = goldilocks_add( t8 , t14 ); // t8 + t14 + uint64_t t24 = goldilocks_sub( t8 , t14 ); // t8 - t14 + uint64_t t25 = goldilocks_add( t10 , t12 ); // t10 + t12 + uint64_t t26 = goldilocks_sub( t12 , t10 ); // t12 - t10 + + uint64_t m0 = goldilocks_add( t17 , t22 ); // t17 + t22 + uint64_t m1 = goldilocks_sub( t17 , t22 ); // t17 - t22 + uint64_t m2 = goldilocks_sub( t15 , t16 ); // t15 - t16 + uint64_t m3 = goldilocks_sub( t1 , t2 ); // t1 - t2 + uint64_t m4 = goldilocks_sub( x0 , x8 ); // x0 - x8 + + uint64_t m5 = goldilocks_mul( DFT16_COS_2U , goldilocks_sub(t19 , t21) ); + uint64_t m6 = goldilocks_mul( DFT16_COS_2U , goldilocks_sub(t4 , t6 ) ); + uint64_t m7 = goldilocks_mul( DFT16_COS_3U , goldilocks_add(t24 , t26) ); + uint64_t m8 = goldilocks_mul( DFT16_COS_3U_PLUS_U , t24 ); + uint64_t m9 = goldilocks_mul( DFT16_COS_3U_MINUS_U , t26 ); + uint64_t m10 = goldilocks_mul( DFT16_J , goldilocks_sub(t20 , t18) ); + uint64_t m11 = goldilocks_mul( DFT16_J , goldilocks_sub(t5 , t3 ) ); + uint64_t m12 = goldilocks_mul( DFT16_J , goldilocks_sub(x12 , x4 ) ); + uint64_t m13 = goldilocks_mul( DFT16_MINUS_J_SIN_2U , goldilocks_add( t19 , t21) ); + uint64_t m14 = goldilocks_mul( DFT16_MINUS_J_SIN_2U , goldilocks_add( t4 , t6 ) ); + uint64_t m15 = goldilocks_mul( DFT16_MINUS_J_SIN_3U , goldilocks_add( t23 , t25) ); + uint64_t m16 = goldilocks_mul( DFT16_J_SIN_3U_MINUS_U , t23 ); + uint64_t m17 = goldilocks_mul( DFT16_J_SIN_MINUS_3U_MINUS_U , t25 ); + + uint64_t s1 = goldilocks_add( m3 , m5 ); // m3 + m5 + uint64_t s2 = goldilocks_sub( m3 , m5 ); // m3 - m5 + uint64_t s3 = goldilocks_add( m11 , m13 ); // m11 + m13 + uint64_t s4 = goldilocks_sub( m13 , m11 ); // m13 - m11 + uint64_t s5 = goldilocks_add( m4 , m6 ); // m4 + m6 + uint64_t s6 = goldilocks_sub( m4 , m6 ); // m4 - m6 + uint64_t s7 = goldilocks_sub( m8 , m7 ); // m8 - m7 + uint64_t s8 = goldilocks_sub( m9 , m7 ); // m9 - m7 + uint64_t s9 = goldilocks_add( s5 , s7 ); // s5 + s7 + uint64_t s10 = goldilocks_sub( s5 , s7 ); // s5 - s7 + uint64_t s11 = goldilocks_add( s6 , s8 ); // s6 + s8 + uint64_t s12 = goldilocks_sub( s6 , s8 ); // s6 - s8 + uint64_t s13 = goldilocks_add( m12 , m14 ); // m12 + m14 + uint64_t s14 = goldilocks_sub( m12 , m14 ); // m12 - m14 + uint64_t s15 = goldilocks_add( m15 , m16 ); // m15 + m16 + uint64_t s16 = goldilocks_sub( m15 , m17 ); // m15 - m17 + uint64_t s17 = goldilocks_add( s13 , s15 ); // s13 + s15 + uint64_t s18 = goldilocks_sub( s13 , s15 ); // s13 - s15 + uint64_t s19 = goldilocks_add( s14 , s16 ); // s14 + s16 + uint64_t s20 = goldilocks_sub( s14 , s16 ); // s14 - s16 + + int tgt_stride2 = tgt_stride + tgt_stride ; + int tgt_stride3 = tgt_stride2 + tgt_stride ; + int tgt_stride4 = tgt_stride2 + tgt_stride2; + int tgt_stride5 = tgt_stride4 + tgt_stride ; + int tgt_stride6 = tgt_stride4 + tgt_stride2; + int tgt_stride7 = tgt_stride4 + tgt_stride3; + int tgt_stride8 = tgt_stride4 + tgt_stride4; + + tgt[ 0 ] = m0; // m0 + tgt[ tgt_stride ] = goldilocks_sub( s9 , s17 ); // s9 - s17 + tgt[ tgt_stride2 ] = goldilocks_sub( s1 , s3 ); // s1 - s3 + tgt[ tgt_stride3 ] = goldilocks_add( s12 , s20 ); // s12 + s20 + tgt[ tgt_stride4 ] = goldilocks_sub( m2 , m10 ); // m2 - m10 + tgt[ tgt_stride5 ] = goldilocks_sub( s11 , s19 ); // s11 - s19 + tgt[ tgt_stride6 ] = goldilocks_sub( s2 , s4 ); // s2 - s4 + tgt[ tgt_stride7 ] = goldilocks_add( s10 , s18 ); // s10 + s18 + tgt[ tgt_stride8 ] = m1; // m1 + tgt[ tgt_stride + tgt_stride8 ] = goldilocks_sub( s10 , s18 ); // s10 - s18 + tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_add( s2 , s4 ); // s2 + s4 + tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_add( s11 , s19 ); // s11 + s19 + tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_add( m2 , m10 ); // m2 + m10 + tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_sub( s12 , s20 ); // s12 - s20 + tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_add( s1 , s3 ); // s1 + s3 + tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_add( s9 , s17 ); // s9 + s17 +} + +//-------------------------------------- + +const uint64_t IDFT16_OMEGA = 0xefffffff00000001 ; +const uint64_t IDFT16_INV_OMEGA = 0x0000001000000000 ; +const uint64_t IDFT16_INV_16 = 0xefffffff10000001 ; +const uint64_t IDFT16_J = 0x0001000000000000 ; +const uint64_t IDFT16_COS_U = 0xf800000700000001 ; +const uint64_t IDFT16_COS_2U = 0x0000007fff7fff80 ; +const uint64_t IDFT16_COS_3U = 0x0007fffffff7f800 ; +const uint64_t IDFT16_MINUS_J_SIN_U = 0x0800000800000000 ; +const uint64_t IDFT16_MINUS_J_SIN_2U = 0x00000080007fff80 ; +const uint64_t IDFT16_MINUS_J_SIN_3U = 0xfff7ffff0007f801 ; + +const uint64_t IDFT16_COS_3U_PLUS_U = 0xf8080006fff7f801 ; +const uint64_t IDFT16_COS_3U_MINUS_U = 0x0807fff7fff7f800 ; +const uint64_t IDFT16_J_SIN_3U_MINUS_U = 0x08080007fff80800 ; +const uint64_t IDFT16_J_SIN_MINUS_3U_MINUS_U = 0x07f800080007f800 ; + +void short_inv_DFT_size_16_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + + int src_stride2 = src_stride + src_stride ; + int src_stride3 = src_stride2 + src_stride ; + int src_stride4 = src_stride2 + src_stride2; + int src_stride5 = src_stride4 + src_stride ; + int src_stride6 = src_stride4 + src_stride2; + int src_stride7 = src_stride4 + src_stride3; + int src_stride8 = src_stride4 + src_stride4; + + uint64_t x0 = src[ 0 ]; + uint64_t x1 = src[ src_stride ]; + uint64_t x2 = src[ src_stride2 ]; + uint64_t x3 = src[ src_stride3 ]; + uint64_t x4 = src[ src_stride4 ]; + uint64_t x5 = src[ src_stride5 ]; + uint64_t x6 = src[ src_stride6 ]; + uint64_t x7 = src[ src_stride7 ]; + uint64_t x8 = src[ src_stride8 ]; + uint64_t x9 = src[ src_stride + src_stride8 ]; + uint64_t x10 = src[ src_stride2 + src_stride8 ]; + uint64_t x11 = src[ src_stride3 + src_stride8 ]; + uint64_t x12 = src[ src_stride4 + src_stride8 ]; + uint64_t x13 = src[ src_stride5 + src_stride8 ]; + uint64_t x14 = src[ src_stride6 + src_stride8 ]; + uint64_t x15 = src[ src_stride7 + src_stride8 ]; + + uint64_t t1 = goldilocks_add( x0 , x8 ); // x0 + x8 + uint64_t t2 = goldilocks_add( x4 , x12 ); // x4 + x12 + uint64_t t3 = goldilocks_add( x2 , x10 ); // x2 + x10 + uint64_t t4 = goldilocks_sub( x2 , x10 ); // x2 - x10 + uint64_t t5 = goldilocks_add( x6 , x14 ); // x6 + x14 + uint64_t t6 = goldilocks_sub( x6 , x14 ); // x6 - x14 + uint64_t t7 = goldilocks_add( x1 , x9 ); // x1 + x9 + uint64_t t8 = goldilocks_sub( x1 , x9 ); // x1 - x9 + uint64_t t9 = goldilocks_add( x3 , x11 ); // x3 + x11 + uint64_t t10 = goldilocks_sub( x3 , x11 ); // x3 - x11 + uint64_t t11 = goldilocks_add( x5 , x13 ); // x5 + x13 + uint64_t t12 = goldilocks_sub( x5 , x13 ); // x5 - x13 + uint64_t t13 = goldilocks_add( x7 , x15 ); // x7 + x15 + uint64_t t14 = goldilocks_sub( x7 , x15 ); // x7 - x15 + uint64_t t15 = goldilocks_add( t1 , t2 ); // t1 + t2 + uint64_t t16 = goldilocks_add( t3 , t5 ); // t3 + t5 + uint64_t t17 = goldilocks_add( t15 , t16 ); // t15 + t16 + uint64_t t18 = goldilocks_add( t7 , t11 ); // t7 + t11 + uint64_t t19 = goldilocks_sub( t7 , t11 ); // t7 - t11 + uint64_t t20 = goldilocks_add( t9 , t13 ); // t9 + t13 + uint64_t t21 = goldilocks_sub( t9 , t13 ); // t9 - t13 + uint64_t t22 = goldilocks_add( t18 , t20 ); // t18 + t20 + uint64_t t23 = goldilocks_add( t8 , t14 ); // t8 + t14 + uint64_t t24 = goldilocks_sub( t8 , t14 ); // t8 - t14 + uint64_t t25 = goldilocks_add( t10 , t12 ); // t10 + t12 + uint64_t t26 = goldilocks_sub( t12 , t10 ); // t12 - t10 + + uint64_t m0 = goldilocks_add( t17 , t22 ); // t17 + t22 + uint64_t m1 = goldilocks_sub( t17 , t22 ); // t17 - t22 + uint64_t m2 = goldilocks_sub( t15 , t16 ); // t15 - t16 + uint64_t m3 = goldilocks_sub( t1 , t2 ); // t1 - t2 + uint64_t m4 = goldilocks_sub( x0 , x8 ); // x0 - x8 + + uint64_t m5 = goldilocks_mul( IDFT16_COS_2U , goldilocks_sub(t19 , t21) ); + uint64_t m6 = goldilocks_mul( IDFT16_COS_2U , goldilocks_sub(t4 , t6 ) ); + uint64_t m7 = goldilocks_mul( IDFT16_COS_3U , goldilocks_add(t24 , t26) ); + uint64_t m8 = goldilocks_mul( IDFT16_COS_3U_PLUS_U , t24 ); + uint64_t m9 = goldilocks_mul( IDFT16_COS_3U_MINUS_U , t26 ); + uint64_t m10 = goldilocks_mul( IDFT16_J , goldilocks_sub(t20 , t18) ); + uint64_t m11 = goldilocks_mul( IDFT16_J , goldilocks_sub(t5 , t3 ) ); + uint64_t m12 = goldilocks_mul( IDFT16_J , goldilocks_sub(x12 , x4 ) ); + uint64_t m13 = goldilocks_mul( IDFT16_MINUS_J_SIN_2U , goldilocks_add( t19 , t21) ); + uint64_t m14 = goldilocks_mul( IDFT16_MINUS_J_SIN_2U , goldilocks_add( t4 , t6 ) ); + uint64_t m15 = goldilocks_mul( IDFT16_MINUS_J_SIN_3U , goldilocks_add( t23 , t25) ); + uint64_t m16 = goldilocks_mul( IDFT16_J_SIN_3U_MINUS_U , t23 ); + uint64_t m17 = goldilocks_mul( IDFT16_J_SIN_MINUS_3U_MINUS_U , t25 ); + + uint64_t s1 = goldilocks_add( m3 , m5 ); // m3 + m5 + uint64_t s2 = goldilocks_sub( m3 , m5 ); // m3 - m5 + uint64_t s3 = goldilocks_add( m11 , m13 ); // m11 + m13 + uint64_t s4 = goldilocks_sub( m13 , m11 ); // m13 - m11 + uint64_t s5 = goldilocks_add( m4 , m6 ); // m4 + m6 + uint64_t s6 = goldilocks_sub( m4 , m6 ); // m4 - m6 + uint64_t s7 = goldilocks_sub( m8 , m7 ); // m8 - m7 + uint64_t s8 = goldilocks_sub( m9 , m7 ); // m9 - m7 + uint64_t s9 = goldilocks_add( s5 , s7 ); // s5 + s7 + uint64_t s10 = goldilocks_sub( s5 , s7 ); // s5 - s7 + uint64_t s11 = goldilocks_add( s6 , s8 ); // s6 + s8 + uint64_t s12 = goldilocks_sub( s6 , s8 ); // s6 - s8 + uint64_t s13 = goldilocks_add( m12 , m14 ); // m12 + m14 + uint64_t s14 = goldilocks_sub( m12 , m14 ); // m12 - m14 + uint64_t s15 = goldilocks_add( m15 , m16 ); // m15 + m16 + uint64_t s16 = goldilocks_sub( m15 , m17 ); // m15 - m17 + uint64_t s17 = goldilocks_add( s13 , s15 ); // s13 + s15 + uint64_t s18 = goldilocks_sub( s13 , s15 ); // s13 - s15 + uint64_t s19 = goldilocks_add( s14 , s16 ); // s14 + s16 + uint64_t s20 = goldilocks_sub( s14 , s16 ); // s14 - s16 + + int tgt_stride2 = tgt_stride + tgt_stride ; + int tgt_stride3 = tgt_stride2 + tgt_stride ; + int tgt_stride4 = tgt_stride2 + tgt_stride2; + int tgt_stride5 = tgt_stride4 + tgt_stride ; + int tgt_stride6 = tgt_stride4 + tgt_stride2; + int tgt_stride7 = tgt_stride4 + tgt_stride3; + int tgt_stride8 = tgt_stride4 + tgt_stride4; + + tgt[ 0 ] = m0; // m0 + tgt[ tgt_stride ] = goldilocks_add( s9 , s17 ); // s9 + s17 + tgt[ tgt_stride2 ] = goldilocks_add( s1 , s3 ); // s1 + s3 + tgt[ tgt_stride3 ] = goldilocks_sub( s12 , s20 ); // s12 - s20 + tgt[ tgt_stride4 ] = goldilocks_add( m2 , m10 ); // m2 + m10 + tgt[ tgt_stride5 ] = goldilocks_add( s11 , s19 ); // s11 + s19 + tgt[ tgt_stride6 ] = goldilocks_add( s2 , s4 ); // s2 + s4 + tgt[ tgt_stride7 ] = goldilocks_sub( s10 , s18 ); // s10 - s18 + tgt[ tgt_stride8 ] = m1; // m1 + tgt[ tgt_stride + tgt_stride8 ] = goldilocks_add( s10 , s18 ); // s10 + s18 + tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_sub( s2 , s4 ); // s2 - s4 + tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_sub( s11 , s19 ); // s11 - s19 + tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_sub( m2 , m10 ); // m2 - m10 + tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_add( s12 , s20 ); // s12 + s20 + tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_sub( s1 , s3 ); // s1 - s3 + tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_sub( s9 , s17 ); // s9 - s17 + +} + +//------------------ + +void short_inv_DFT_size_16_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ) { + + short_inv_DFT_size_16_unscaled( src_stride, tgt_stride, src, tgt ); + + int tgt_stride2 = tgt_stride + tgt_stride ; + int tgt_stride3 = tgt_stride2 + tgt_stride ; + int tgt_stride4 = tgt_stride2 + tgt_stride2; + int tgt_stride5 = tgt_stride4 + tgt_stride ; + int tgt_stride6 = tgt_stride4 + tgt_stride2; + int tgt_stride7 = tgt_stride4 + tgt_stride3; + int tgt_stride8 = tgt_stride4 + tgt_stride4; + + tgt[ 0 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ 0 ] ); + tgt[ tgt_stride ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride ] ); + tgt[ tgt_stride2 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride2 ] ); + tgt[ tgt_stride3 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride3 ] ); + tgt[ tgt_stride4 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride4 ] ); + tgt[ tgt_stride5 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride5 ] ); + tgt[ tgt_stride6 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride6 ] ); + tgt[ tgt_stride7 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride7 ] ); + tgt[ tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride8 ] ); + tgt[ tgt_stride + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride + tgt_stride8 ] ); + tgt[ tgt_stride2 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride2 + tgt_stride8 ] ); + tgt[ tgt_stride3 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride3 + tgt_stride8 ] ); + tgt[ tgt_stride4 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride4 + tgt_stride8 ] ); + tgt[ tgt_stride5 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride5 + tgt_stride8 ] ); + tgt[ tgt_stride6 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride6 + tgt_stride8 ] ); + tgt[ tgt_stride7 + tgt_stride8 ] = goldilocks_mul( IDFT16_INV_16 , tgt[ tgt_stride7 + tgt_stride8 ] ); + +} + +//------------------------------------------------------------------------------ diff --git a/reference/src/cbits/short_dft.h b/reference/src/cbits/short_dft.h new file mode 100644 index 0000000..fbc7022 --- /dev/null +++ b/reference/src/cbits/short_dft.h @@ -0,0 +1,24 @@ + +#include + +//------------------------------------------------------------------------------ + +void short_fwd_DFT_size_4 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); +void short_inv_DFT_size_4_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); +void short_inv_DFT_size_4_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); + +void short_fwd_DFT_size_8 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); +void short_inv_DFT_size_8_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); +void short_inv_DFT_size_8_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); + +void short_fwd_DFT_size_16 ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); +void short_inv_DFT_size_16_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); +void short_inv_DFT_size_16_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); + +//------------------------------------------------------------------------------ + +// void short_fwd_DFT_size_4_ext ( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); +// void short_inv_DFT_size_4_ext_unscaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); +// void short_inv_DFT_size_4_ext_rescaled( int src_stride, int tgt_stride, uint64_t *src, uint64_t *tgt ); + +//------------------------------------------------------------------------------ diff --git a/reference/src/runi.sh b/reference/src/runi.sh index 51aabbb..1321dfb 100755 --- a/reference/src/runi.sh +++ b/reference/src/runi.sh @@ -1,3 +1,3 @@ #!/bin/bash -ghci testMain.hs cbits/goldilocks.o cbits/goldilocks_ext.o cbits/monolith.o cbits/ntt.o +ghci testMain.hs cbits/goldilocks.o cbits/goldilocks_ext.o cbits/monolith.o cbits/ntt.o cbits/short_dft.o