preliminary C implementation of NTT

This commit is contained in:
Balazs Komuves 2025-11-04 10:58:02 +01:00
parent f7955ac21b
commit a9cb0a96a6
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
16 changed files with 862 additions and 226 deletions

View File

@ -87,6 +87,9 @@ newtype ReductionStrategy = MkRedStrategy
}
deriving (Eq,Show)
-- | Number of entries (folding steps) stored in a reduction strategy.
reductionStrategyLength :: ReductionStrategy -> Int
reductionStrategyLength (MkRedStrategy steps) = length steps
-- | A reduction strategy is serialized as the \"small list\" of its entries
-- (see 'putSmallList' / 'getSmallList').
instance Binary ReductionStrategy where
  put (MkRedStrategy steps) = putSmallList steps
  get = fmap MkRedStrategy getSmallList
@ -123,6 +126,18 @@ data FriConfig = MkFriConfig
}
deriving (Eq,Show)
-- | The (log2) size of the vector at the /start/ of each commit phase step:
-- begins with the RS-encoded size and subtracts the arities one by one.
-- The size remaining after the very last step is not included, so the result
-- has exactly as many entries as the reduction strategy.
friCommitPhaseVectorSizes :: FriConfig -> [Log2]
friCommitPhaseVectorSizes (MkFriConfig{..}) = zipWith const runningSizes arities
  where
    arities      = fromReductionStrategy friReductionStrategy
    -- [ n , n-a1 , n-a1-a2 , ... ]
    runningSizes = scanl (-) (rsEncodedSize friRSConfig) arities
-- | The (log2) size /after/ each commit phase step: the RS-encoded size with
-- the first @k@ arities subtracted, for @k = 1 .. number of steps@.
friCommitPhaseTreeSizes :: FriConfig -> [Log2]
friCommitPhaseTreeSizes (MkFriConfig{..}) = drop 1 runningSizes
  where
    arities      = fromReductionStrategy friReductionStrategy
    -- [ n , n-a1 , n-a1-a2 , ... ]; the leading @n@ is dropped above
    runningSizes = scanl (-) (rsEncodedSize friRSConfig) arities
instance Binary FriConfig where
put (MkFriConfig{..}) = do
put friRSConfig
@ -252,3 +267,44 @@ instance Binary FriProof where
<*> get
--------------------------------------------------------------------------------
estimateFriProofSize :: FriConfig -> Int
estimateFriProofSize friConfig@(MkFriConfig{..}) = total where
total = friCfgSize + commitPhaseCaps + finalPolySize + (friNQueryRounds * queryRoundSize) + 8
arities = fromReductionStrategy friReductionStrategy
nsteps = length arities
rsCfgSize = 1 + 1 + 8
redStratSize = 1 + nsteps
friCfgSize = rsCfgSize + 8 + 1 + redStratSize + 8 + 1
(phases2,_) = friCommitPhaseSizesLog2 friConfig
merkleCapSize = 32 * exp2_ friMerkleCapSize
commitPhaseCaps = merkleCapSize * nsteps
finalPolySize = 8 * exp2_ (rsDataSize - sum arities)
queryStepSize arity = 16 * exp2_ arity +
-- | The data opened in a single step of a FRI query round.
-- NOTE(review): presumably one of these per folding step of the reduction
-- strategy -- confirm against the prover / verifier.
data FriQueryStep = MkFriQueryStep
{ queryEvals :: [FExt] -- ^ extension field values opened in this step
, queryMerklePath :: RawMerklePath -- ^ Merkle path belonging to this step
}
deriving (Eq,Show)
-- | A single FRI query round: a row, a Merkle path for the initial tree,
-- and the list of per-step openings.
data FriQueryRound = MkFriQueryRound
{ queryRow :: FRow -- ^ the queried row
, queryInitialTreeProof :: RawMerklePath -- ^ Merkle path in the initial tree
, querySteps :: [FriQueryStep] -- ^ the subsequent query steps
}
deriving (Eq,Show)
-- | A complete FRI proof: the configuration it was produced with, the commit
-- phase Merkle caps, the final polynomial, the query rounds and the
-- proof-of-work witness.
data FriProof = MkFriProof
{ proofFriConfig :: FriConfig -- ^ the FRI configuration
, proofCommitPhaseCaps :: [MerkleCap] -- ^ commit phase Merkle caps
, proofFinalPoly :: Poly FExt -- ^ the final polynomial in coefficient form
, proofQueryRounds :: [FriQueryRound] -- ^ query rounds
, proofPowWitness :: F -- ^ witness showing that the prover did PoW
}

View File

@ -1,8 +1,10 @@
{-# LANGUAGE TypeFamilies #-}
module Field.Class where
--------------------------------------------------------------------------------
import Data.Kind
import Data.Proxy
import System.Random
@ -26,6 +28,15 @@ class (Show a, Eq a, Num a, Fractional a) => Field a where
inverse :: Field a => a -> a
inverse = recip
-- | Quadratic extensions
--
-- An extension field @ext@ together with its base field @BaseField ext@.
-- Elements of @ext@ can be converted to and from pairs of base field elements.
class (Field (BaseField ext), Field ext) => QuadraticExt ext where
-- | The base field the extension is built over
type BaseField ext :: Type
-- | Embeds a base field element into the extension
inject :: BaseField ext -> ext
-- | Partial inverse of 'inject': @Nothing@ when the element is not in the base field
project :: ext -> Maybe (BaseField ext)
-- | Presumably multiplies an extension element by a base field scalar (confirm against the instance)
scale :: BaseField ext -> ext -> ext
-- | Builds an extension element from its two base field components
quadraticPack :: (BaseField ext, BaseField ext) -> ext
-- | Splits an extension element into its two base field components
quadraticUnpack :: ext -> (BaseField ext, BaseField ext)
--------------------------------------------------------------------------------
instance Field Goldi.F where
@ -53,3 +64,14 @@ instance Field GoldiExt.FExt where
rndIO = randomIO
--------------------------------------------------------------------------------
-- | The Goldilocks quadratic extension over the Goldilocks base field.
-- Every method simply delegates to the corresponding function of the
-- extension field module.
instance QuadraticExt GoldiExt.FExt where
type BaseField GoldiExt.FExt = Goldi.F
inject = GoldiExt.inj
project = GoldiExt.proj
scale = GoldiExt.scl
quadraticPack = GoldiExt.pack
quadraticUnpack = GoldiExt.unpack
--------------------------------------------------------------------------------

View File

@ -139,6 +139,17 @@ binaryOpIO c_action x y = allocaBytesAligned 48 8 $ \ptr1 -> do
-- | Embeds a base field element (second component is zero)
inj :: F -> F2
inj re = F2 re 0

-- | Projects back to the base field; fails if the second component is nonzero
proj :: F2 -> Maybe F
proj (F2 re im)
  | Goldi.isZero im = Just re
  | otherwise       = Nothing

-- | Builds an extension element from its two components
pack :: (F,F) -> F2
pack (re,im) = F2 re im

-- | Returns the two components of an extension element
unpack :: F2 -> (F,F)
unpack (F2 re im) = (re,im)
--------------------------------------------------------------------------------
neg, sqr, inv :: F2 -> F2
neg x = unsafePerformIO (unaryOpIO c_goldilocks_ext_neg x)
sqr x = unsafePerformIO (unaryOpIO c_goldilocks_ext_sqr x)

View File

@ -108,6 +108,17 @@ isOne (F2 r i) = Goldi.isOne r && Goldi.isZero i
-- | Embeds a base field element (second component is zero)
inj :: F -> F2
inj x = F2 x 0

-- | Projects back to the base field; @Nothing@ if the second component is nonzero
proj :: F2 -> Maybe F
proj (F2 x y) = case Goldi.isZero y of
  True  -> Just x
  False -> Nothing

-- | Builds an extension element from its two components
pack :: (F,F) -> F2
pack (x,y) = F2 x y

-- | Returns the two components of an extension element
unpack :: F2 -> (F,F)
unpack (F2 x y) = (x,y)
--------------------------------------------------------------------------------
-- | Negation, computed componentwise
neg :: F2 -> F2
neg (F2 re im) = F2 (negate re) (negate im)

View File

@ -5,11 +5,11 @@ module NTT
( module Field.Goldilocks
, module NTT.Subgroup
, module NTT.Poly
, module NTT.Slow
, module NTT.FFT
) where
import Field.Goldilocks
import NTT.Subgroup
import NTT.Poly
import NTT.Slow
import NTT.FFT

View File

@ -0,0 +1,22 @@
module NTT.Class where
--------------------------------------------------------------------------------
import Data.Kind
import NTT.Subgroup
--------------------------------------------------------------------------------
{-
class NTT ntt where
type Field ntt :: Type
type Poly ntt :: Type
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
-}

14
reference/src/NTT/FFT.hs Normal file
View File

@ -0,0 +1,14 @@
{-# LANGUAGE CPP #-}
#ifdef USE_NAIVE_HASKELL
module NTT.FFT ( module NTT.FFT.Slow ) where
import NTT.FFT.Slow
#else
module NTT.FFT ( module NTT.FFT.Fast ) where
import NTT.FFT.Fast
#endif

View File

@ -0,0 +1,132 @@
{-# LANGUAGE StrictData, ForeignFunctionInterface #-}
module NTT.FFT.Fast where
--------------------------------------------------------------------------------
import Data.Word
import Foreign.C
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Marshal
import System.IO.Unsafe
import Data.Flat
import NTT.Poly
import NTT.Subgroup
import Field.Goldilocks
import Misc
--------------------------------------------------------------------------------
{-
void goldilocks_ntt_forward ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_forward_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_inverse ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_inverse_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
-}
-- Raw bindings to the C implementation (@cbits/ntt.c@). The argument order
-- follows the C prototypes quoted above: the optional shift @eta@ comes first,
-- then the log2 size(s), then the subgroup generator(s), and finally the
-- source and target arrays.
foreign import ccall unsafe "goldilocks_ntt_forward"            c_ntt_forward            ::           CInt ->         Word64 ->           Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "goldilocks_ntt_forward_shifted"    c_ntt_forward_shifted    :: Word64 -> CInt ->         Word64 ->           Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "goldilocks_ntt_forward_asymmetric" c_ntt_forward_asymmetric :: CInt   -> CInt -> Word64 -> Word64 ->         Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "goldilocks_ntt_inverse"            c_ntt_inverse            ::           CInt ->         Word64 ->           Ptr Word64 -> Ptr Word64 -> IO ()
foreign import ccall unsafe "goldilocks_ntt_inverse_shifted"    c_ntt_inverse_shifted    :: Word64 -> CInt ->         Word64 ->           Ptr Word64 -> Ptr Word64 -> IO ()
--------------------------------------------------------------------------------
-- | Forward NTT: takes the coefficients of a polynomial and produces an array
-- of the same size, computed by the C routine @goldilocks_ntt_forward@.
--
-- Calls 'error' if the subgroup size differs from the number of coefficients.
-- A fresh output buffer is allocated on every call; the input is only read.
{-# NOINLINE forwardNTT #-}
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
forwardNTT sg (MkPoly (MkFlatArray n srcFPtr))
  | subgroupSize sg /= n = error "forwardNTT: subgroup size differs from the array size"
  | otherwise            = unsafePerformIO $ do
      tgtFPtr <- mallocForeignPtrArray n
      withForeignPtr srcFPtr $ \srcPtr ->
        withForeignPtr tgtFPtr $ \tgtPtr ->
          c_ntt_forward (subgroupCLogSize sg) (subgroupGenAsWord64 sg) srcPtr tgtPtr
      return (MkFlatArray n tgtFPtr)
-- | Inverse NTT: takes an array of values and produces a polynomial with the
-- same number of coefficients, computed by the C routine
-- @goldilocks_ntt_inverse@.
--
-- Calls 'error' if the subgroup size differs from the array size.
-- A fresh output buffer is allocated on every call; the input is only read.
{-# NOINLINE inverseNTT #-}
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
inverseNTT sg (MkFlatArray n srcFPtr)
  | subgroupSize sg /= n = error "inverseNTT: subgroup size differs from the array size"
  | otherwise            = unsafePerformIO $ do
      tgtFPtr <- mallocForeignPtrArray n
      withForeignPtr srcFPtr $ \srcPtr ->
        withForeignPtr tgtFPtr $ \tgtPtr ->
          c_ntt_inverse (subgroupCLogSize sg) (subgroupGenAsWord64 sg) srcPtr tgtPtr
      return (MkPoly (MkFlatArray n tgtFPtr))
-- | Pre-multiplies the coefficients by powers of eta, effectively evaluating @f(eta*x)@ on the subgroup
--
-- Calls 'error' if the subgroup size differs from the number of coefficients.
-- The work is done by the C routine @goldilocks_ntt_forward_shifted@, writing
-- into a freshly allocated output buffer.
{-# NOINLINE shiftedForwardNTT #-}
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
shiftedForwardNTT sg (MkGoldilocks eta) (MkPoly (MkFlatArray n fptr2))
  | subgroupSize sg /= n = error "shiftedForwardNTT: subgroup size differs from the array size"
  | otherwise            = unsafePerformIO $ do
      fptr3 <- mallocForeignPtrArray n
      withForeignPtr fptr2 $ \ptr2 ->
        withForeignPtr fptr3 $ \ptr3 ->
          c_ntt_forward_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3
      return (MkFlatArray n fptr3)
-- | Inverse NTT followed by multiplying the k-th resulting coefficient by
-- @eta^k@; that is, it interpolates the polynomial @f@ for which
-- @f(eta^-1 * omega^k) = y_k@.
--
-- Calls 'error' if the subgroup size differs from the array size.
-- The work is done by the C routine @goldilocks_ntt_inverse_shifted@, writing
-- into a freshly allocated output buffer.
{-# NOINLINE shiftedInverseNTT #-}
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
shiftedInverseNTT sg (MkGoldilocks eta) (MkFlatArray n srcFPtr)
  | subgroupSize sg /= n = error "shiftedInverseNTT: subgroup size differs from the array size"
  | otherwise            = unsafePerformIO $ do
      tgtFPtr <- mallocForeignPtrArray n
      withForeignPtr srcFPtr $ \srcPtr ->
        withForeignPtr tgtFPtr $ \tgtPtr ->
          c_ntt_inverse_shifted eta logSize gen srcPtr tgtPtr
      return (MkPoly (MkFlatArray n tgtFPtr))
  where
    logSize = subgroupCLogSize    sg
    gen     = subgroupGenAsWord64 sg
-- | Evaluates a polynomial f on a larger subgroup than it's defined on
--
-- The first subgroup must have the same size as the coefficient array, and the
-- target subgroup must be at least as large; otherwise 'error' is called.
-- The result has the size of the target subgroup and is computed by the
-- C routine @goldilocks_ntt_forward_asymmetric@ into a fresh buffer.
{-# NOINLINE asymmForwardNTT #-}
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
asymmForwardNTT sg_src (MkPoly (MkFlatArray n fptr2)) sg_tgt
  | subgroupSize sg_src /= n = error "asymmForwardNTT: subgroup size differs from the array size"
  | m < n                    = error "asymmForwardNTT: target subgroup size should be at least the source subgroup size"
  | otherwise                = unsafePerformIO $ do
      fptr3 <- mallocForeignPtrArray m
      withForeignPtr fptr2 $ \ptr2 ->
        withForeignPtr fptr3 $ \ptr3 ->
          c_ntt_forward_asymmetric
            (subgroupCLogSize sg_src)
            (subgroupCLogSize sg_tgt)
            (subgroupGenAsWord64 sg_src)
            (subgroupGenAsWord64 sg_tgt)
            ptr2 ptr3
      return (MkFlatArray m fptr3)
  where
    -- size of the target subgroup (= size of the output array)
    m = subgroupSize sg_tgt
{-
instance P.UnivariateFFT Poly where
ntt = forwardNTT
intt = inverseNTT
shiftedNTT = shiftedForwardNTT
shiftedINTT = shiftedInverseNTT
asymmNTT = asymmForwardNTT
-}
--------------------------------------------------------------------------------
-- | The base-2 logarithm of the subgroup size, as a C @int@ (the form the
-- C routines expect).
subgroupCLogSize :: Subgroup a -> CInt
subgroupCLogSize sg = fromIntegral (fromLog2 (subgroupLogSize sg))
-- | The subgroup generator, unwrapped to its raw 'Word64' representation.
subgroupGenAsWord64 :: Subgroup F -> Word64
subgroupGenAsWord64 sg = case subgroupGen sg of
  MkGoldilocks w -> w
--------------------------------------------------------------------------------

View File

@ -1,13 +1,13 @@
{-# LANGUAGE ScopedTypeVariables #-}
module NTT.Slow where
module NTT.FFT.Slow where
--------------------------------------------------------------------------------
import Data.Array
import Data.Bits
import NTT.Poly
import NTT.Poly.Naive
import NTT.Subgroup
import Field.Goldilocks

View File

@ -1,227 +1,13 @@
{-# LANGUAGE CPP #-}
-- | Dense univariate polynomials
#ifdef USE_NAIVE_HASKELL
{-# LANGUAGE StrictData, BangPatterns, ScopedTypeVariables, DeriveFunctor #-}
module NTT.Poly where
module NTT.Poly ( module NTT.Poly.Naive ) where
import NTT.Poly.Naive
--------------------------------------------------------------------------------
#else
import Data.List
import Data.Array
import Data.Array.ST (STArray)
import Data.Array.MArray (newArray, readArray, writeArray, thaw, freeze)
module NTT.Poly ( module NTT.Poly.Flat ) where
import NTT.Poly.Flat
import Control.Monad
import Control.Monad.ST.Strict
import System.Random
import Data.Binary
import Field.Goldilocks
import Field.Goldilocks.Extension ( FExt )
import Field.Encode
import Misc
--------------------------------------------------------------------------------
-- * Univariate polynomials
-- | A dense univariate polynomial. The array index corresponds to the exponent.
newtype Poly a
= Poly (Array Int a)
deriving (Show,Functor)
instance Binary a => Binary (Poly a) where
put (Poly arr) = putSmallArray arr
get = Poly <$> getSmallArray
instance (Num a, Eq a) => Eq (Poly a) where
p == q = polyIsZero (polySub p q)
instance FieldEncode (Poly F) where
fieldEncode (Poly arr) = fieldEncode arr
instance FieldEncode (Poly FExt) where
fieldEncode (Poly arr) = fieldEncode arr
mkPoly :: [a] -> Poly a
mkPoly coeffs = Poly $ listArray (0,length coeffs-1) coeffs
-- | Degree of the polynomial
polyDegree :: (Eq a, Num a) => Poly a -> Int
polyDegree (Poly arr) = worker d0 where
(0,d0) = bounds arr
worker d
| d < 0 = -1
| arr!d /= 0 = d
| otherwise = worker (d-1)
-- | Size of the polynomial (can be larger than @degree + 1@, if the some top coefficients are zeros)
polySize :: (Eq a, Num a) => Poly a -> Int
polySize (Poly p) = arraySize p
polyIsZero :: (Eq a, Num a) => Poly a -> Bool
polyIsZero (Poly arr) = all (==0) (elems arr)
-- | Returns the coefficient of @x^k@
polyCoeff :: Num a => Poly a -> Int -> a
polyCoeff (Poly coeffs) k = safeIndex 0 coeffs k
-- | Note: this can include zero coeffs at higher than the actual degree!
polyCoeffArray :: Poly a -> Array Int a
polyCoeffArray (Poly coeffs) = coeffs
-- | Note: this cuts off the potential extra zeros at the end.
-- The order is little-endian (constant term first).
polyCoeffList :: (Eq a, Num a) => Poly a -> [a]
polyCoeffList poly@(Poly arr) = take (polyDegree poly + 1) (elems arr)
--------------------------------------------------------------------------------
-- * Elementary polynomials
-- | Constant polynomial
polyConst :: a -> Poly a
polyConst x = Poly $ listArray (0,0) [x]
-- | Zero polynomial
polyZero :: Num a => Poly a
polyZero = polyConst 0
-- | The polynomial @f(x) = x@
polyVarX :: Num a => Poly a
polyVarX = mkPoly [0,1]
-- | @polyLinear (A,B)@ means the linear polynomial @f(x) = A*x + B@
polyLinear :: (a,a) -> Poly a
polyLinear (a,b) = mkPoly [b,a]
-- | The monomial @x^n@
polyXpowN :: Num a => Int -> Poly a
polyXpowN n = Poly $ listArray (0,n) (replicate n 0 ++ [1])
-- | The binomial @(x^n - 1)@
polyXpowNminus1 :: Num a => Int -> Poly a
polyXpowNminus1 n = Poly $ listArray (0,n) (-1 : replicate (n-1) 0 ++ [1])
--------------------------------------------------------------------------------
-- * Evaluate polynomials
polyEvalAt :: forall f. Fractional f => Poly f -> f -> f
polyEvalAt (Poly arr) x = go 0 1 0 where
(0,d) = bounds arr
go :: f -> f -> Int -> f
go !acc !y !i = if i > d
then acc
else go (acc + (arr!i)*y) (y*x) (i+1)
polyEvalOnList :: forall f. Fractional f => Poly f -> [f] -> [f]
polyEvalOnList poly = map (polyEvalAt poly)
polyEvalOnArray :: forall f. Fractional f => Poly f -> Array Int f -> Array Int f
polyEvalOnArray poly = fmap (polyEvalAt poly)
--------------------------------------------------------------------------------
-- * Basic arithmetic operations on polynomials
polyNeg :: Num a => Poly a -> Poly a
polyNeg (Poly arr) = Poly $ fmap negate arr
polyAdd :: Num a => Poly a -> Poly a -> Poly a
polyAdd (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
(0,d1) = bounds arr1
(0,d2) = bounds arr2
d3 = max d1 d2
zs = zipWith (+) (elems arr1 ++ replicate (d3-d1) 0)
(elems arr2 ++ replicate (d3-d2) 0)
polySub :: Num a => Poly a -> Poly a -> Poly a
polySub (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
(0,d1) = bounds arr1
(0,d2) = bounds arr2
d3 = max d1 d2
zs = zipWith (-) (elems arr1 ++ replicate (d3-d1) 0)
(elems arr2 ++ replicate (d3-d2) 0)
polyMul :: Num a => Poly a -> Poly a -> Poly a
polyMul (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
(0,d1) = bounds arr1
(0,d2) = bounds arr2
d3 = d1 + d2
zs = [ f k | k<-[0..d3] ]
f !k = foldl' (+) 0 [ arr1!i * arr2!(k-i) | i<-[ max 0 (k-d2) .. min d1 k ] ]
instance Num a => Num (Poly a) where
fromInteger = polyConst . fromInteger
negate = polyNeg
(+) = polyAdd
(-) = polySub
(*) = polyMul
abs = id
signum = \_ -> polyConst 1
polySum :: Num a => [Poly a] -> Poly a
polySum = foldl' polyAdd 0
polyProd :: Num a => [Poly a] -> Poly a
polyProd = foldl' polyMul 1
--------------------------------------------------------------------------------
-- * Polynomial long division
-- | @polyDiv f h@ returns @(q,r)@ such that @f = q*h + r@ and @deg r < deg h@
polyDiv :: forall f. (Eq f, Fractional f) => Poly f -> Poly f -> (Poly f, Poly f)
polyDiv poly_f@(Poly arr_f) poly_h@(Poly arr_h)
| deg_q < 0 = (polyZero, poly_f)
| otherwise = runST action
where
deg_f = polyDegree poly_f
deg_h = polyDegree poly_h
deg_q = deg_f - deg_h
-- inverse of the top coefficient of divisor
b_inv = recip (arr_h ! deg_h)
action :: forall s. ST s (Poly f, Poly f)
action = do
p <- thaw arr_f :: ST s (STArray s Int f)
q <- newArray (0,deg_q) 0 :: ST s (STArray s Int f)
forM_ [deg_q,deg_q-1..0] $ \k -> do
top <- readArray p (deg_h + k)
let y = b_inv * top
writeArray q k y
forM_ [0..deg_h] $ \j -> do
a <- readArray p (j+k)
writeArray p (j+k) (a - y*(arr_h!j))
qarr <- freeze q
rs <- forM [0..deg_h-1] $ \i -> readArray p i
let rarr = listArray (0,deg_h-1) rs
return (Poly qarr, Poly rarr)
-- | Returns only the quotient
polyDivQuo :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
polyDivQuo f g = fst $ polyDiv f g
-- | Returns only the remainder
polyDivRem :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
polyDivRem f g = snd $ polyDiv f g
--------------------------------------------------------------------------------
-- * Sample random polynomials
-- | Samples a random polynomial with @deg+1@ coefficients, threading the
-- random generator through: the returned generator is the one left over
-- /after/ all coefficients were drawn (so repeated calls give fresh results).
randomPoly :: (RandomGen g, Random a) => Int -> g -> (Poly a, g)
randomPoly deg g0 = (Poly (listArray (0,deg) coeffs), gfinal)
  where
    (coeffs, gfinal) = worker (deg+1) g0

    -- draws @n@ values, returning them together with the advanced generator
    worker n g
      | n <= 0    = ([], g)
      | otherwise =
          let (x , g1) = random g
              (xs, g2) = worker (n-1) g1
          in  (x:xs, g2)
randomPolyIO :: Random a => Int -> IO (Poly a)
randomPolyIO deg = getStdRandom (randomPoly deg)
--------------------------------------------------------------------------------
#endif

View File

@ -0,0 +1,55 @@
{-# LANGUAGE StrictData, ScopedTypeVariables, PatternSynonyms #-}
module NTT.Poly.Flat where
import Prelude hiding (div,quot,rem)
import GHC.Real hiding (div,quot,rem)
import Data.Bits
import Data.Word
import Data.List
import Data.Array
import Control.Monad
import Foreign.C
import Foreign.Ptr
import Foreign.Marshal
import Foreign.ForeignPtr
import System.Random
import System.IO.Unsafe
import NTT.Subgroup
import Field.Class
import Data.Flat as L
--------------------------------------------------------------------------------
-- | A polynomial backed by a flat array.
-- NOTE(review): presumably the array holds the coefficients in order of
-- increasing exponent, as in the naive implementation -- confirm.
newtype Poly a = MkPoly (L.FlatArray a)
-- | Pattern synonym looking through both wrappers at once: matches the two
-- fields of the underlying 'L.MkFlatArray' directly.
pattern XPoly n arr = MkPoly (L.MkFlatArray n arr)
-- | Packs a list into a flat-array backed polynomial
mkPoly :: Flat f => [f] -> Poly f
mkPoly xs = MkPoly (L.packFlatArrayFromList xs)

-- | Same as 'mkPoly', but the length is passed in explicitly
mkPoly' :: Flat f => Int -> [f] -> Poly f
mkPoly' len = MkPoly . L.packFlatArrayFromList' len

-- | Packs a boxed array into a flat-array backed polynomial
mkPolyArr :: Flat f => Array Int f -> Poly f
mkPolyArr arr = MkPoly (L.packFlatArray arr)

-- | Wraps an existing flat array (no copying or conversion)
mkPolyFlatArr :: L.FlatArray f -> Poly f
mkPolyFlatArr flat = MkPoly flat
-- | Unpacks the underlying flat array into a list
coeffs :: Flat f => Poly f -> [f]
coeffs = L.unpackFlatArrayToList . coeffsFlatArr

-- | Unpacks the underlying flat array into a boxed array
coeffsArr :: Flat f => Poly f -> Array Int f
coeffsArr = L.unpackFlatArray . coeffsFlatArr

-- | Returns the underlying flat array (no copying or conversion)
coeffsFlatArr :: Poly f -> L.FlatArray f
coeffsFlatArr (MkPoly arr) = arr
--------------------------------------------------------------------------------

View File

@ -0,0 +1,227 @@
-- | Dense univariate polynomials
{-# LANGUAGE StrictData, BangPatterns, ScopedTypeVariables, DeriveFunctor #-}
module NTT.Poly.Naive where
--------------------------------------------------------------------------------
import Data.List
import Data.Array
import Data.Array.ST (STArray)
import Data.Array.MArray (newArray, readArray, writeArray, thaw, freeze)
import Control.Monad
import Control.Monad.ST.Strict
import System.Random
import Data.Binary
import Field.Goldilocks
import Field.Goldilocks.Extension ( FExt )
import Field.Encode
import Misc
--------------------------------------------------------------------------------
-- * Univariate polynomials
-- | A dense univariate polynomial. The array index corresponds to the exponent.
--
-- The functions in this module assume that the array is indexed from 0.
-- The array may contain trailing zero coefficients, so its size can be larger
-- than @degree + 1@.
newtype Poly a
= Poly (Array Int a)
deriving (Show,Functor)
-- | Serialized as the \"small array\" of coefficients
instance Binary a => Binary (Poly a) where
  put (Poly coeffArr) = putSmallArray coeffArr
  get = fmap Poly getSmallArray

-- | Two polynomials are equal if their difference has only zero coefficients
-- (so trailing zeros do not matter)
instance (Num a, Eq a) => Eq (Poly a) where
  (==) p q = polyIsZero (polySub p q)

-- | Encoded the same way as the underlying coefficient array
instance FieldEncode (Poly F) where
  fieldEncode (Poly coeffArr) = fieldEncode coeffArr

-- | Encoded the same way as the underlying coefficient array
instance FieldEncode (Poly FExt) where
  fieldEncode (Poly coeffArr) = fieldEncode coeffArr
-- | Builds a polynomial from the list of its coefficients (constant term first)
mkPoly :: [a] -> Poly a
mkPoly cs = Poly (listArray (0, n-1) cs)
  where
    n = length cs
-- | Degree of the polynomial: the largest index with a nonzero coefficient,
-- or @-1@ if there is no such index (the zero polynomial).
polyDegree :: (Eq a, Num a) => Poly a -> Int
polyDegree (Poly arr) = maybe (-1) id (find isNonZeroAt [top, top-1 .. 0])
  where
    (0,top) = bounds arr
    isNonZeroAt k = arr!k /= 0
-- | Size of the polynomial, that is, the size of the coefficient array
-- (can be larger than @degree + 1@, if some of the top coefficients are zeros)
polySize :: (Eq a, Num a) => Poly a -> Int
polySize (Poly coeffArr) = arraySize coeffArr

-- | Checks whether all the stored coefficients are zero
polyIsZero :: (Eq a, Num a) => Poly a -> Bool
polyIsZero (Poly coeffArr) = and [ c == 0 | c <- elems coeffArr ]
-- | Returns the coefficient of @x^k@ (looked up with 'safeIndex', using 0 as the default)
polyCoeff :: Num a => Poly a -> Int -> a
polyCoeff (Poly arr) = safeIndex 0 arr

-- | The raw coefficient array.
-- Note: this can include zero coeffs at higher than the actual degree!
polyCoeffArray :: Poly a -> Array Int a
polyCoeffArray (Poly arr) = arr

-- | The coefficients as a list, in little-endian order (constant term first).
-- Note: this cuts off the potential extra zeros at the end.
polyCoeffList :: (Eq a, Num a) => Poly a -> [a]
polyCoeffList poly@(Poly arr) = take howMany (elems arr)
  where
    howMany = polyDegree poly + 1
--------------------------------------------------------------------------------
-- * Elementary polynomials
-- | Constant polynomial (a single coefficient)
polyConst :: a -> Poly a
polyConst c = Poly (listArray (0,0) [c])

-- | Zero polynomial (represented as the constant 0)
polyZero :: Num a => Poly a
polyZero = polyConst 0

-- | The polynomial @f(x) = x@
polyVarX :: Num a => Poly a
polyVarX = mkPoly [0,1]

-- | @polyLinear (A,B)@ means the linear polynomial @f(x) = A*x + B@
polyLinear :: (a,a) -> Poly a
polyLinear (slope,constTerm) = mkPoly [constTerm,slope]

-- | The monomial @x^n@
polyXpowN :: Num a => Int -> Poly a
polyXpowN n = Poly (listArray (0,n) coeffList)
  where
    coeffList = replicate n 0 ++ [1]
-- | The binomial @(x^n - 1)@
--
-- For @n = 0@ this is the zero polynomial (@x^0 - 1 = 0@); this case needs
-- special treatment, because the constant term and the leading term coincide.
polyXpowNminus1 :: Num a => Int -> Poly a
polyXpowNminus1 n
  | n == 0    = Poly $ listArray (0,0) [0]
  | otherwise = Poly $ listArray (0,n) (-1 : replicate (n-1) 0 ++ [1])
--------------------------------------------------------------------------------
-- * Evaluate polynomials
-- | Evaluates the polynomial at the given point, summing @coeff_i * x^i@ from
-- the constant term upwards while maintaining the running power of @x@.
polyEvalAt :: forall f. Fractional f => Poly f -> f -> f
polyEvalAt (Poly arr) x = fst (foldl' step (0,1) [0..d])
  where
    (0,d) = bounds arr

    -- the state is (partial sum, current power of x)
    step :: (f,f) -> Int -> (f,f)
    step (!acc, !xpow) i = (acc + (arr!i)*xpow, xpow*x)
-- | Evaluates the polynomial at each point of a list
polyEvalOnList :: forall f. Fractional f => Poly f -> [f] -> [f]
polyEvalOnList poly points = [ polyEvalAt poly x | x <- points ]

-- | Evaluates the polynomial at each point of an array
polyEvalOnArray :: forall f. Fractional f => Poly f -> Array Int f -> Array Int f
polyEvalOnArray poly points = fmap (\x -> polyEvalAt poly x) points
--------------------------------------------------------------------------------
-- * Basic arithmetic operations on polynomials
-- | Negates every coefficient
polyNeg :: Num a => Poly a -> Poly a
polyNeg (Poly coeffArr) = Poly (negate <$> coeffArr)
-- | Coefficientwise sum; the shorter coefficient array is padded with zeros,
-- so the result is as large as the larger input.
polyAdd :: Num a => Poly a -> Poly a -> Poly a
polyAdd (Poly xs) (Poly ys) = Poly (listArray (0,top) summed)
  where
    (0,dx) = bounds xs
    (0,dy) = bounds ys
    top    = max dx dy
    summed = zipWith (+) (elems xs ++ replicate (top - dx) 0)
                         (elems ys ++ replicate (top - dy) 0)

-- | Coefficientwise difference; the shorter coefficient array is padded with
-- zeros, so the result is as large as the larger input.
polySub :: Num a => Poly a -> Poly a -> Poly a
polySub (Poly xs) (Poly ys) = Poly (listArray (0,top) diffs)
  where
    (0,dx) = bounds xs
    (0,dy) = bounds ys
    top    = max dx dy
    diffs  = zipWith (-) (elems xs ++ replicate (top - dx) 0)
                         (elems ys ++ replicate (top - dy) 0)
-- | Schoolbook polynomial multiplication: the k-th coefficient of the product
-- is the sum of @xs[i] * ys[k-i]@ over all index pairs that fit in both arrays.
polyMul :: Num a => Poly a -> Poly a -> Poly a
polyMul (Poly xs) (Poly ys) = Poly (listArray (0,top) (map coeffAt [0..top]))
  where
    (0,dx) = bounds xs
    (0,dy) = bounds ys
    top    = dx + dy
    coeffAt !k = foldl' (+) 0 [ xs!i * ys!(k-i) | i <- [ max 0 (k-dy) .. min dx k ] ]
-- | Ring operations on polynomials. Note that 'abs' is the identity and
-- 'signum' is constant 1: these have no meaningful definition for polynomials.
instance Num a => Num (Poly a) where
  fromInteger k = polyConst (fromInteger k)
  negate p      = polyNeg p
  p + q         = polyAdd p q
  p - q         = polySub p q
  p * q         = polyMul p q
  abs p         = p
  signum        = const (polyConst 1)
-- | Sum of a list of polynomials (the constant 0 for the empty list)
polySum :: Num a => [Poly a] -> Poly a
polySum ps = foldl' polyAdd (polyConst 0) ps

-- | Product of a list of polynomials (the constant 1 for the empty list)
polyProd :: Num a => [Poly a] -> Poly a
polyProd ps = foldl' polyMul (polyConst 1) ps
--------------------------------------------------------------------------------
-- * Polynomial long division
-- | @polyDiv f h@ returns @(q,r)@ such that @f = q*h + r@ and @deg r < deg h@
--
-- Calls 'error' if the divisor @h@ is the zero polynomial.
-- If @deg f < deg h@, the quotient is zero and the remainder is @f@ itself.
-- Otherwise the remainder has exactly @deg h@ coefficients.
polyDiv :: forall f. (Eq f, Fractional f) => Poly f -> Poly f -> (Poly f, Poly f)
polyDiv poly_f@(Poly arr_f) poly_h@(Poly arr_h)
  | deg_h < 0 = error "polyDiv: division by the zero polynomial"
  | deg_q < 0 = (polyZero, poly_f)
  | otherwise = runST action
  where
    deg_f = polyDegree poly_f
    deg_h = polyDegree poly_h
    deg_q = deg_f - deg_h

    -- inverse of the top coefficient of divisor
    b_inv = recip (arr_h ! deg_h)

    action :: forall s. ST s (Poly f, Poly f)
    action = do
      -- @p@ starts as a mutable copy of @f@ and is reduced step by step
      p <- thaw arr_f           :: ST s (STArray s Int f)
      q <- newArray (0,deg_q) 0 :: ST s (STArray s Int f)
      -- eliminate the top coefficients one by one, from the highest downwards
      forM_ [deg_q,deg_q-1..0] $ \k -> do
        top <- readArray p (deg_h + k)
        let y = b_inv * top
        writeArray q k y
        -- subtract @y * x^k * h@ from @p@
        forM_ [0..deg_h] $ \j -> do
          a <- readArray p (j+k)
          writeArray p (j+k) (a - y*(arr_h!j))
      qarr <- freeze q
      -- what remains in the low @deg_h@ coefficients is the remainder
      rs   <- forM [0..deg_h-1] $ \i -> readArray p i
      let rarr = listArray (0,deg_h-1) rs
      return (Poly qarr, Poly rarr)
-- | Returns only the quotient of 'polyDiv'
polyDivQuo :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
polyDivQuo f g = quotient
  where
    (quotient, _) = polyDiv f g

-- | Returns only the remainder of 'polyDiv'
polyDivRem :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
polyDivRem f g = remainder
  where
    (_, remainder) = polyDiv f g
--------------------------------------------------------------------------------
-- * Sample random polynomials
-- | Samples a random polynomial with @deg+1@ coefficients, threading the
-- random generator through: the returned generator is the one left over
-- /after/ all coefficients were drawn (so repeated calls give fresh results).
randomPoly :: (RandomGen g, Random a) => Int -> g -> (Poly a, g)
randomPoly deg g0 = (Poly (listArray (0,deg) coeffs), gfinal)
  where
    (coeffs, gfinal) = worker (deg+1) g0

    -- draws @n@ values, returning them together with the advanced generator
    worker n g
      | n <= 0    = ([], g)
      | otherwise =
          let (x , g1) = random g
              (xs, g2) = worker (n-1) g1
          in  (x:xs, g2)
-- | Samples a random polynomial using (and updating) the global standard generator
randomPolyIO :: Random a => Int -> IO (Poly a)
randomPolyIO = getStdRandom . randomPoly
--------------------------------------------------------------------------------

View File

@ -20,6 +20,9 @@ data Subgroup g = MkSubgroup
-- | Number of elements in the subgroup (same as the subgroup order)
subgroupSize :: Subgroup g -> Int
subgroupSize sg = subgroupOrder sg

-- | Base-2 logarithm of the subgroup size, computed with 'exactLog2__'
subgroupLogSize :: Subgroup g -> Log2
subgroupLogSize sg = exactLog2__ (subgroupSize sg)
getSubgroup :: Log2 -> Subgroup F
getSubgroup log2@(Log2 n)
| n<0 = error "getSubgroup: negative logarithm"

View File

@ -0,0 +1,55 @@
module NTT.Tests where
--------------------------------------------------------------------------------
import Data.Array
import Control.Monad
import Field.Class
import Field.Goldilocks ( F )
import Misc
import Data.Flat as L
import NTT.Subgroup
import qualified NTT.Poly.Flat as F
import qualified NTT.FFT.Fast as F
import qualified NTT.Poly.Naive as S
import qualified NTT.FFT.Slow as S
--------------------------------------------------------------------------------
-- | Compare slow and fast FFT implementations to each other
--
-- Draws a random array of the given (log2) size, then runs both the forward
-- and the inverse transform on it with both implementations. Returns @True@
-- if the two implementations agree on both transforms.
compareFFTs :: Log2 -> IO Bool
compareFFTs logSize = do
  inputArr <- listToArray <$> replicateM (exp2_ logSize) rndIO :: IO (Array Int F)
  let sg        = getSubgroup logSize
      inputFlat = L.packFlatArray inputArr                 :: L.FlatArray F
      -- forward transform: the random input plays the role of the coefficients
      slowEvals = S.subgroupNTT sg (S.Poly inputArr)       :: Array Int F
      fastEvals = F.forwardNTT  sg (F.mkPolyArr inputArr)  :: L.FlatArray F
      -- inverse transform: the same random input plays the role of the values
      slowPoly  = S.subgroupINTT sg inputArr               :: S.Poly F
      fastPoly  = F.inverseNTT   sg inputFlat              :: F.Poly F
      nttAgrees  = slowEvals == L.unpackFlatArray fastEvals
      inttAgrees = S.polyCoeffArray slowPoly == F.coeffsArr fastPoly
  return (nttAgrees && inttAgrees)
--------------------------------------------------------------------------------
-- | Runs 'compareFFTs' @n@ times at the given size; succeeds only if every run does
runTests :: Int -> Log2 -> IO Bool
runTests n size = and <$> replicateM n (compareFFTs size)
--------------------------------------------------------------------------------

228
reference/src/cbits/ntt.c Normal file
View File

@ -0,0 +1,228 @@
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include "goldilocks.h"
#include "ntt.h"
// -----------------------------------------------------------------------------
// Recursive worker of the forward NTT; it does not allocate any memory.
//
//   m          : log2 of the transform size (N = 2^m)
//   src_stride : distance between consecutive input elements in `src`
//   gpows      : table of generator powers; indexed as `gpows[j*src_stride]` for j < N/2
//   src        : the input, read at indices 0, src_stride, 2*src_stride, ...
//   buf        : scratch space; the first N elements hold the results of the two
//                half-sized subproblems, the rest (from `buf + N`) is handed down
//                to the recursive calls (the callers in this file pass 2*N elements)
//   tgt        : the output, N contiguous elements
//
// The even- and odd-indexed inputs are transformed separately (by doubling the
// stride), and the two half-sized results are then combined with butterflies.
void goldilocks_ntt_forward_noalloc(int m, int src_stride, const uint64_t *gpows, const uint64_t *src, uint64_t *buf, uint64_t *tgt) {
// N = 1: nothing to do, just copy
if (m==0) {
tgt[0] = src[0];
return;
}
if (m==1) {
// N = 2
tgt[0] = goldilocks_add( src[0] , src[src_stride] ); // x + y
tgt[1] = goldilocks_sub( src[0] , src[src_stride] ); // x - y
return;
}
else {
int N = (1<< m );
int halfN = (1<<(m-1));
// even-indexed inputs go to buf[0..halfN), odd-indexed ones to buf[halfN..N);
// both recursive calls share the same scratch area starting at buf + N
goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src , buf + N , buf );
goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src + src_stride , buf + N , buf + halfN );
// butterflies; note that tgt[j] is used as a temporary, so the order of the statements matters
for(int j=0; j<halfN; j++) {
const uint64_t gpow = gpows[j*src_stride];
tgt[j ] = goldilocks_mul( buf[j+halfN] , gpow ); // g*v[k]
tgt[j+halfN] = goldilocks_neg( tgt[j ] ); // - g*v[k]
tgt[j ] = goldilocks_add( tgt[j ] , buf[j] ); // u[k] + g*v[k]
tgt[j+halfN] = goldilocks_add( tgt[j+halfN] , buf[j] ); // u[k] - g*v[k]
}
}
}
// forward number-theoretical transform (evaluation of a polynomial)
// `src` and `tgt` should be `N = 2^m` sized arrays of field elements
// `gen` should be the generator of the multiplicative subgroup sized `N`
//
// Temporary buffers are allocated with `malloc` and freed before returning.
void goldilocks_ntt_forward(int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
  int N     = (1<<m);
  int halfN = (N>>1);

  // precalculate [1,g,g^2,...,g^(N/2-1)]
  // The table has `halfN` entries, which is 0 for m=0 and 1 for m=1, so it must
  // be filled by a loop (never by writing index 0 and 1 unconditionally).
  // We always allocate at least one slot, as `malloc(0)` is allowed to return NULL.
  uint64_t *gpows = (uint64_t*) malloc( sizeof(uint64_t) * (size_t)(halfN > 0 ? halfN : 1) );
  assert( gpows != 0 );
  uint64_t x = 1;
  for(int i=0; i<halfN; i++) {
    gpows[i] = x;
    x = goldilocks_mul( x , gen );
  }

  // scratch space for the recursion; sizes are computed in `size_t` to avoid `int` overflow
  uint64_t *buf = (uint64_t*) malloc( sizeof(uint64_t) * 2 * (size_t)N );
  assert( buf != 0 );

  goldilocks_ntt_forward_noalloc( m, 1, gpows, src, buf, tgt);

  free(buf);
  free(gpows);
}
// it's like `ntt_forward` but we pre-multiply the coefficients with `eta^k`
// resulting in evaluating f(eta*x) instead of f(x)
//
// `src` and `tgt` should be `N = 2^m` sized arrays; `src` is not modified
// (the shifted coefficients are stored in a temporary buffer).
void goldilocks_ntt_forward_shifted(const uint64_t eta, int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
  int N = (1<<m);

  // the byte size is computed in `size_t` to avoid `int` overflow for large `m`
  uint64_t *shifted = malloc( sizeof(uint64_t) * (size_t)N );
  assert( shifted != 0 );

  // shifted[i] = src[i] * eta^i
  uint64_t x = 1;
  for(int i=0; i<N; i++) {
    shifted[i] = goldilocks_mul( src[i] , x );
    x = goldilocks_mul( x , eta );
  }

  goldilocks_ntt_forward( m, gen, shifted, tgt );

  free(shifted);
}
// it's like `ntt_forward` but asymmetric, evaluating on a larger target subgroup
//
// `src` should be an `N_src = 2^m_src` sized array (the coefficients),
// `tgt` should be an `N_tgt = 2^m_tgt` sized array, with `m_tgt >= m_src`.
// `gen_src` and `gen_tgt` should be the generators of the subgroups of these sizes.
//
// With `K = N_tgt / N_src`, the computation is done as `K` shifted NTTs of size `N_src`:
// the k-th one uses the shift `eta = gen_tgt^k`, and its i-th result is stored
// at `tgt[k + i*K]` (this layout assumes `gen_src == gen_tgt^K`).
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, const uint64_t gen_src, const uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt) {
  assert( m_tgt >= m_src );
  int N_src     = (1 << m_src);
  int halfN_src = (N_src >> 1);
  int K         = (1 << (m_tgt - m_src));

  // precalculate [1,g,g^2,...,g^(N_src/2-1)]
  // The table has `halfN_src` entries (0 for m_src=0, 1 for m_src=1), so it is
  // filled by a loop; at least one slot is allocated as `malloc(0)` may return NULL.
  uint64_t *gpows = malloc( sizeof(uint64_t) * (size_t)(halfN_src > 0 ? halfN_src : 1) );
  assert( gpows != 0 );
  uint64_t g = 1;
  for(int i=0; i<halfN_src; i++) {
    gpows[i] = g;
    g = goldilocks_mul( g , gen_src );
  }

  uint64_t *shifted = malloc( sizeof(uint64_t) * (size_t)N_src );
  assert( shifted != 0 );
  uint64_t *buf = malloc( sizeof(uint64_t) * 2 * (size_t)N_src );
  assert( buf != 0 );

  // temporary target buffer (we could replace this by adding `tgt_stride`)
  uint64_t *tgt_small = malloc( sizeof(uint64_t) * (size_t)N_src );
  assert( tgt_small != 0 );

  // eta will be the shift
  uint64_t eta = 1;
  for(int k=0; k<K; k++) {
    if (k==0) {
      // no shift: use the coefficients as they are
      memcpy( shifted, src, sizeof(uint64_t) * (size_t)N_src );
    }
    else {
      // shifted[i] = src[i] * eta^i  where  eta = gen_tgt^k
      eta = goldilocks_mul( eta , gen_tgt );
      uint64_t x = 1;
      for(int i=0; i<N_src; i++) {
        shifted[i] = goldilocks_mul( src[i] , x );
        x = goldilocks_mul( x , eta );
      }
    }

    goldilocks_ntt_forward_noalloc( m_src, 1, gpows, shifted, buf, tgt_small );

    // interleave the results into the target array with stride K, starting at offset k
    for(int i=0; i<N_src; i++) {
      tgt[ (size_t)i * (size_t)K + (size_t)k ] = tgt_small[i];
    }
  }

  free(tgt_small);
  free(buf);
  free(gpows);
  free(shifted);
}
// -----------------------------------------------------------------------------
// inverse of 2 (which is the same as `(p+1)/2`)
const uint64_t goldilocks_oneHalf = 0x7fffffff80000001ull;
// Recursive worker of the inverse NTT; it does not allocate any memory.
//
//   m          : log2 of the transform size (N = 2^m)
//   tgt_stride : distance between consecutive output elements in `tgt`
//   gpows      : table of multipliers; indexed as `gpows[j*tgt_stride]` for j < N/2
//                (the caller in this file fills it with `g^{-i}/2`)
//   src        : the input, N contiguous elements
//   buf        : scratch space; the first N elements hold the inputs of the two
//                half-sized subproblems, the rest (from `buf + N`) is handed down
//                to the recursive calls (the caller in this file passes 2*N elements)
//   tgt        : the output, written at indices 0, tgt_stride, 2*tgt_stride, ...
//
// This mirrors the forward transform: the butterflies are undone first, then
// the two halves are transformed recursively into the even resp. odd output positions.
void goldilocks_ntt_inverse_noalloc(int m, int tgt_stride, const uint64_t *gpows, const uint64_t *src, uint64_t *buf, uint64_t *tgt) {
// N = 1: nothing to do, just copy
if (m==0) {
tgt[0] = src[0];
return;
}
if (m==1) {
// N = 2
tgt[0 ] = goldilocks_add( src[0] , src[1] ); // x + y
tgt[tgt_stride] = goldilocks_sub( src[0] , src[1] ); // x - y
tgt[0 ] = goldilocks_div_by_2( tgt[0 ] ); // (x + y)/2
tgt[tgt_stride] = goldilocks_div_by_2( tgt[tgt_stride] ); // (x - y)/2
return;
}
else {
int N = (1<< m );
int halfN = (1<<(m-1));
// inverse butterflies: the first half of `buf` gets the input of the "even" subproblem,
// the second half the input of the "odd" one; the division by 2 of the latter
// comes from the multiplier `gpow`
for(int j=0; j<halfN; j++) {
uint64_t gpow = gpows[j*tgt_stride];
buf[j ] = goldilocks_add( src[j] , src[j+halfN] ); // x + y
buf[j+halfN] = goldilocks_sub( src[j] , src[j+halfN] ); // x - y
buf[j ] = goldilocks_div_by_2( buf[j ] ); // (x + y) / 2
buf[j+halfN] = goldilocks_mul ( buf[j+halfN] , gpow ); // (x - y) / (2*g^k)
}
// the two halves are written to the even resp. odd output positions (stride doubled);
// both recursive calls share the same scratch area starting at buf + N
goldilocks_ntt_inverse_noalloc( m-1 , tgt_stride<<1 , gpows , buf , buf + N , tgt );
goldilocks_ntt_inverse_noalloc( m-1 , tgt_stride<<1 , gpows , buf + halfN , buf + N , tgt + tgt_stride );
}
}
// inverse number-theoretical transform (interpolation of a polynomial)
// `src` and `tgt` should be `N = 2^m` sized arrays of field elements
// `gen` should be the generator of the multiplicative subgroup sized `N`
//
// Temporary buffers are allocated with `malloc` and freed before returning.
void goldilocks_ntt_inverse(int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
  int N     = (1<<m);
  int halfN = (N>>1);

  // precalculate [1/2,g^{-1}/2,g^{-2}/2,g^{-3}/2...]
  // The table has `halfN` entries, which is 0 for m=0; we always allocate at least
  // one slot, as `malloc(0)` is allowed to return NULL (which would trip the assert).
  uint64_t *gpows = malloc( sizeof(uint64_t) * (size_t)(halfN > 0 ? halfN : 1) );
  assert( gpows != 0 );
  uint64_t x    = goldilocks_oneHalf;    // 1/2
  uint64_t ginv = goldilocks_inv(gen);   // gen^-1
  for(int i=0; i<halfN; i++) {
    gpows[i] = x;
    x = goldilocks_mul(x, ginv);
  }

  // scratch space for the recursion; sizes are computed in `size_t` to avoid `int` overflow
  uint64_t *buf = malloc( sizeof(uint64_t) * 2 * (size_t)N );
  assert( buf != 0 );

  goldilocks_ntt_inverse_noalloc( m, 1, gpows, src, buf, tgt );

  free(buf);
  free(gpows);
}
// it's like `ntt_inverse` but we post-multiply the resulting coefficients with `eta^k`
// resulting in interpolating an f such that f(eta^-1 * omega^k) = y_k
//
// `src` and `tgt` should be `N = 2^m` sized arrays; the plain inverse NTT is
// computed into a temporary buffer first.
void goldilocks_ntt_inverse_shifted(const uint64_t eta, int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
  int N = (1<<m);

  // the byte size is computed in `size_t` to avoid `int` overflow for large `m`
  uint64_t *unshifted = malloc( sizeof(uint64_t) * (size_t)N );
  assert( unshifted != 0 );

  goldilocks_ntt_inverse( m, gen, src, unshifted );

  // tgt[i] = unshifted[i] * eta^i
  uint64_t x = 1;
  for(int i=0; i<N; i++) {
    tgt[i] = goldilocks_mul( unshifted[i] , x );
    x = goldilocks_mul( x , eta );
  }

  free(unshifted);
}
// -----------------------------------------------------------------------------

14
reference/src/cbits/ntt.h Normal file
View File

@ -0,0 +1,14 @@
#ifndef GOLDILOCKS_NTT_H
#define GOLDILOCKS_NTT_H

#include <stdint.h>

//------------------------------------------------------------------------------
// Number-theoretic transforms over the Goldilocks field.
//
//   m, m_src, m_tgt : log2 of the array sizes
//   gen, gen_*      : generator of the subgroup of the corresponding size
//   eta             : the shift used by the `_shifted` variants
//   src / tgt       : input / output arrays of field elements

void goldilocks_ntt_forward           (              int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_forward_shifted   (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt);

void goldilocks_ntt_inverse           (              int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
void goldilocks_ntt_inverse_shifted   (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);

//------------------------------------------------------------------------------

#endif // GOLDILOCKS_NTT_H