mirror of
https://github.com/logos-storage/outsourcing-Reed-Solomon.git
synced 2026-01-02 13:43:07 +00:00
preliminary C implementation of NTT
This commit is contained in:
parent
f7955ac21b
commit
a9cb0a96a6
@ -87,6 +87,9 @@ newtype ReductionStrategy = MkRedStrategy
|
||||
}
|
||||
deriving (Eq,Show)
|
||||
|
||||
reductionStrategyLength :: ReductionStrategy -> Int
|
||||
reductionStrategyLength = length . fromReductionStrategy
|
||||
|
||||
instance Binary ReductionStrategy where
|
||||
put = putSmallList . fromReductionStrategy
|
||||
get = MkRedStrategy <$> getSmallList
|
||||
@ -123,6 +126,18 @@ data FriConfig = MkFriConfig
|
||||
}
|
||||
deriving (Eq,Show)
|
||||
|
||||
friCommitPhaseVectorSizes :: FriConfig -> [Log2]
|
||||
friCommitPhaseVectorSizes (MkFriConfig{..}) = result where
|
||||
result = go (rsEncodedSize friRSConfig) (fromReductionStrategy friReductionStrategy)
|
||||
go n [] = []
|
||||
go n (a:as) = n : go (n-a) as
|
||||
|
||||
friCommitPhaseTreeSizes :: FriConfig -> [Log2]
|
||||
friCommitPhaseTreeSizes (MkFriConfig{..}) = result where
|
||||
result = go (rsEncodedSize friRSConfig) (fromReductionStrategy friReductionStrategy)
|
||||
go n [] = []
|
||||
go n (a:as) = (n-a) : go (n-a) as
|
||||
|
||||
instance Binary FriConfig where
|
||||
put (MkFriConfig{..}) = do
|
||||
put friRSConfig
|
||||
@ -252,3 +267,44 @@ instance Binary FriProof where
|
||||
<*> get
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
estimateFriProofSize :: FriConfig -> Int
|
||||
estimateFriProofSize friConfig@(MkFriConfig{..}) = total where
|
||||
|
||||
total = friCfgSize + commitPhaseCaps + finalPolySize + (friNQueryRounds * queryRoundSize) + 8
|
||||
|
||||
arities = fromReductionStrategy friReductionStrategy
|
||||
nsteps = length arities
|
||||
|
||||
rsCfgSize = 1 + 1 + 8
|
||||
redStratSize = 1 + nsteps
|
||||
friCfgSize = rsCfgSize + 8 + 1 + redStratSize + 8 + 1
|
||||
|
||||
(phases2,_) = friCommitPhaseSizesLog2 friConfig
|
||||
|
||||
merkleCapSize = 32 * exp2_ friMerkleCapSize
|
||||
commitPhaseCaps = merkleCapSize * nsteps
|
||||
finalPolySize = 8 * exp2_ (rsDataSize - sum arities)
|
||||
|
||||
queryStepSize arity = 16 * exp2_ arity +
|
||||
|
||||
data FriQueryStep = MkFriQueryStep
|
||||
{ queryEvals :: [FExt]
|
||||
, queryMerklePath :: RawMerklePath
|
||||
}
|
||||
deriving (Eq,Show)
|
||||
|
||||
data FriQueryRound = MkFriQueryRound
|
||||
{ queryRow :: FRow
|
||||
, queryInitialTreeProof :: RawMerklePath
|
||||
, querySteps :: [FriQueryStep]
|
||||
}
|
||||
deriving (Eq,Show)
|
||||
|
||||
data FriProof = MkFriProof
|
||||
{ proofFriConfig :: FriConfig -- ^ the FRI configuration
|
||||
, proofCommitPhaseCaps :: [MerkleCap] -- ^ commit phase Merkle caps
|
||||
, proofFinalPoly :: Poly FExt -- ^ the final polynomial in coefficient form
|
||||
, proofQueryRounds :: [FriQueryRound] -- ^ query rounds
|
||||
, proofPowWitness :: F -- ^ witness showing that the prover did PoW
|
||||
}
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
module Field.Class where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Kind
|
||||
import Data.Proxy
|
||||
|
||||
import System.Random
|
||||
@ -26,6 +28,15 @@ class (Show a, Eq a, Num a, Fractional a) => Field a where
|
||||
inverse :: Field a => a -> a
|
||||
inverse = recip
|
||||
|
||||
-- | Quadratic extensions
|
||||
class (Field (BaseField ext), Field ext) => QuadraticExt ext where
|
||||
type BaseField ext :: Type
|
||||
inject :: BaseField ext -> ext
|
||||
project :: ext -> Maybe (BaseField ext)
|
||||
scale :: BaseField ext -> ext -> ext
|
||||
quadraticPack :: (BaseField ext, BaseField ext) -> ext
|
||||
quadraticUnpack :: ext -> (BaseField ext, BaseField ext)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
instance Field Goldi.F where
|
||||
@ -53,3 +64,14 @@ instance Field GoldiExt.FExt where
|
||||
rndIO = randomIO
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
instance QuadraticExt GoldiExt.FExt where
|
||||
type BaseField GoldiExt.FExt = Goldi.F
|
||||
|
||||
inject = GoldiExt.inj
|
||||
project = GoldiExt.proj
|
||||
scale = GoldiExt.scl
|
||||
quadraticPack = GoldiExt.pack
|
||||
quadraticUnpack = GoldiExt.unpack
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
@ -139,6 +139,17 @@ binaryOpIO c_action x y = allocaBytesAligned 48 8 $ \ptr1 -> do
|
||||
inj :: F -> F2
|
||||
inj r = F2 r 0
|
||||
|
||||
proj :: F2 -> Maybe F
|
||||
proj (F2 r i) = if Goldi.isZero i then Just r else Nothing
|
||||
|
||||
pack :: (F,F) -> F2
|
||||
pack (r,i) = F2 r i
|
||||
|
||||
unpack :: F2 -> (F,F)
|
||||
unpack (F2 r i) = (r,i)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
neg, sqr, inv :: F2 -> F2
|
||||
neg x = unsafePerformIO (unaryOpIO c_goldilocks_ext_neg x)
|
||||
sqr x = unsafePerformIO (unaryOpIO c_goldilocks_ext_sqr x)
|
||||
|
||||
@ -108,6 +108,17 @@ isOne (F2 r i) = Goldi.isOne r && Goldi.isZero i
|
||||
inj :: F -> F2
|
||||
inj r = F2 r 0
|
||||
|
||||
proj :: F2 -> Maybe F
|
||||
proj (F2 r i) = if Goldi.isZero i then Just r else Nothing
|
||||
|
||||
pack :: (F,F) -> F2
|
||||
pack (r,i) = F2 r i
|
||||
|
||||
unpack :: F2 -> (F,F)
|
||||
unpack (F2 r i) = (r,i)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
neg :: F2 -> F2
|
||||
neg (F2 r i) = F2 (negate r) (negate i)
|
||||
|
||||
|
||||
@ -5,11 +5,11 @@ module NTT
|
||||
( module Field.Goldilocks
|
||||
, module NTT.Subgroup
|
||||
, module NTT.Poly
|
||||
, module NTT.Slow
|
||||
, module NTT.FFT
|
||||
) where
|
||||
|
||||
import Field.Goldilocks
|
||||
import NTT.Subgroup
|
||||
import NTT.Poly
|
||||
import NTT.Slow
|
||||
import NTT.FFT
|
||||
|
||||
|
||||
22
reference/src/NTT/Class.hs
Normal file
22
reference/src/NTT/Class.hs
Normal file
@ -0,0 +1,22 @@
|
||||
|
||||
module NTT.Class where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Kind
|
||||
|
||||
import NTT.Subgroup
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
{-
|
||||
class NTT ntt where
|
||||
type Field ntt :: Type
|
||||
type Poly ntt :: Type
|
||||
|
||||
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
|
||||
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
|
||||
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
|
||||
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
|
||||
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
|
||||
-}
|
||||
14
reference/src/NTT/FFT.hs
Normal file
14
reference/src/NTT/FFT.hs
Normal file
@ -0,0 +1,14 @@
|
||||
|
||||
{-# LANGUAGE CPP #-}
|
||||
|
||||
#ifdef USE_NAIVE_HASKELL
|
||||
|
||||
module NTT.FFT ( module NTT.FFT.Slow ) where
|
||||
import NTT.FFT.Slow
|
||||
|
||||
#else
|
||||
|
||||
module NTT.FFT ( module NTT.FFT.Fast ) where
|
||||
import NTT.FFT.Fast
|
||||
|
||||
#endif
|
||||
132
reference/src/NTT/FFT/Fast.hs
Normal file
132
reference/src/NTT/FFT/Fast.hs
Normal file
@ -0,0 +1,132 @@
|
||||
|
||||
{-# LANGUAGE StrictData, ForeignFunctionInterface #-}
|
||||
|
||||
module NTT.FFT.Fast where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Word
|
||||
|
||||
import Foreign.C
|
||||
import Foreign.Ptr
|
||||
import Foreign.ForeignPtr
|
||||
import Foreign.Marshal
|
||||
|
||||
import System.IO.Unsafe
|
||||
|
||||
import Data.Flat
|
||||
|
||||
import NTT.Poly
|
||||
import NTT.Subgroup
|
||||
import Field.Goldilocks
|
||||
import Misc
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
{-
|
||||
void goldilocks_ntt_forward ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||
void goldilocks_ntt_forward_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||
|
||||
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt);
|
||||
|
||||
void goldilocks_ntt_inverse ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||
void goldilocks_ntt_inverse_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||
-}
|
||||
|
||||
foreign import ccall unsafe "goldilocks_ntt_forward" c_ntt_forward :: CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
foreign import ccall unsafe "goldilocks_ntt_forward_shifted" c_ntt_forward_shifted :: Word64 -> CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
|
||||
foreign import ccall unsafe "goldilocks_ntt_forward_asymmetric" c_ntt_forward_asymmetric :: CInt -> CInt -> Word64 -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ();
|
||||
|
||||
foreign import ccall unsafe "goldilocks_ntt_inverse" c_ntt_inverse :: CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
foreign import ccall unsafe "goldilocks_ntt_inverse_shifted" c_ntt_inverse_shifted :: Word64 -> CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
{-# NOINLINE forwardNTT #-}
|
||||
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
|
||||
forwardNTT sg (MkPoly (MkFlatArray n fptr2))
|
||||
| subgroupSize sg /= n = error "forwardNTT: subgroup size differs from the array size"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
fptr3 <- mallocForeignPtrArray n -- (n*4)
|
||||
let MkGoldilocks gen = subgroupGen sg
|
||||
withForeignPtr fptr2 $ \ptr2 -> do
|
||||
withForeignPtr fptr3 $ \ptr3 -> do
|
||||
c_ntt_forward (subgroupCLogSize sg) gen ptr2 ptr3
|
||||
return (MkFlatArray n fptr3)
|
||||
|
||||
{-# NOINLINE inverseNTT #-}
|
||||
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
|
||||
inverseNTT sg (MkFlatArray n fptr2)
|
||||
| subgroupSize sg /= n = error "inverseNTT: subgroup size differs from the array size"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
fptr3 <- mallocForeignPtrArray n --(n*4)
|
||||
let MkGoldilocks gen = subgroupGen sg
|
||||
withForeignPtr fptr2 $ \ptr2 -> do
|
||||
withForeignPtr fptr3 $ \ptr3 -> do
|
||||
c_ntt_inverse (subgroupCLogSize sg) gen ptr2 ptr3
|
||||
return (MkPoly (MkFlatArray n fptr3))
|
||||
|
||||
-- | Pre-multiplies the coefficients by powers of eta, effectively evaluating @f(eta*x)@ on the subgroup
|
||||
{-# NOINLINE shiftedForwardNTT #-}
|
||||
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
|
||||
shiftedForwardNTT sg (MkGoldilocks eta) (MkPoly (MkFlatArray n fptr2))
|
||||
| subgroupSize sg /= n = error "shiftedForwardNTT: subgroup size differs from the array size"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
fptr3 <- mallocForeignPtrArray n -- (n*4)
|
||||
let MkGoldilocks gen = subgroupGen sg
|
||||
withForeignPtr fptr2 $ \ptr2 -> do
|
||||
withForeignPtr fptr3 $ \ptr3 -> do
|
||||
c_ntt_forward_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3
|
||||
return (MkFlatArray n fptr3)
|
||||
|
||||
-- | Post-multiplies the coefficients by powers of eta, effectively interpolating @f@ such that @f(eta^-1 * omega^k) = y_k@
|
||||
{-# NOINLINE shiftedInverseNTT #-}
|
||||
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
|
||||
shiftedInverseNTT sg (MkGoldilocks eta) (MkFlatArray n fptr2)
|
||||
| subgroupSize sg /= n = error "shiftedInverseNTT: subgroup size differs from the array size"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
fptr3 <- mallocForeignPtrArray n -- (n*4)
|
||||
withForeignPtr fptr2 $ \ptr2 -> do
|
||||
withForeignPtr fptr3 $ \ptr3 -> do
|
||||
c_ntt_inverse_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3
|
||||
return (MkPoly (MkFlatArray n fptr3))
|
||||
|
||||
-- | Evaluates a polynomial f on a larger subgroup than it's defined on
|
||||
{-# NOINLINE asymmForwardNTT #-}
|
||||
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
|
||||
asymmForwardNTT sg_src (MkPoly (MkFlatArray n fptr2)) sg_tgt
|
||||
| subgroupSize sg_src /= n = error "asymmForwardNTT: subgroup size differs from the array size"
|
||||
| m < n = error "asymmForwardNTT: target subgroup size should be at least the source subgroup src"
|
||||
| otherwise = unsafePerformIO $ do
|
||||
fptr3 <- mallocForeignPtrArray m -- (m*4)
|
||||
let MkGoldilocks sgen1 = subgroupGen sg_src
|
||||
let MkGoldilocks sgen2 = subgroupGen sg_tgt
|
||||
withForeignPtr fptr2 $ \ptr2 -> do
|
||||
withForeignPtr fptr3 $ \ptr3 -> do
|
||||
c_ntt_forward_asymmetric
|
||||
(subgroupCLogSize sg_src)
|
||||
(subgroupCLogSize sg_tgt)
|
||||
sgen1 sgen2 ptr2 ptr3
|
||||
return (MkFlatArray m fptr3)
|
||||
where
|
||||
m = subgroupSize sg_tgt
|
||||
|
||||
{-
|
||||
instance P.UnivariateFFT Poly where
|
||||
ntt = forwardNTT
|
||||
intt = inverseNTT
|
||||
shiftedNTT = shiftedForwardNTT
|
||||
shiftedINTT = shiftedInverseNTT
|
||||
asymmNTT = asymmForwardNTT
|
||||
-}
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
subgroupCLogSize :: Subgroup a -> CInt
|
||||
subgroupCLogSize = fromIntegral . fromLog2 . subgroupLogSize
|
||||
|
||||
subgroupGenAsWord64 :: Subgroup F -> Word64
|
||||
subgroupGenAsWord64 sg = let MkGoldilocks x = subgroupGen sg in x
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
@ -1,13 +1,13 @@
|
||||
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
module NTT.Slow where
|
||||
module NTT.FFT.Slow where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Array
|
||||
import Data.Bits
|
||||
|
||||
import NTT.Poly
|
||||
import NTT.Poly.Naive
|
||||
import NTT.Subgroup
|
||||
|
||||
import Field.Goldilocks
|
||||
@ -1,227 +1,13 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
|
||||
-- | Dense univariate polynomials
|
||||
#ifdef USE_NAIVE_HASKELL
|
||||
|
||||
{-# LANGUAGE StrictData, BangPatterns, ScopedTypeVariables, DeriveFunctor #-}
|
||||
module NTT.Poly where
|
||||
module NTT.Poly ( module NTT.Poly.Naive ) where
|
||||
import NTT.Poly.Naive
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
#else
|
||||
|
||||
import Data.List
|
||||
import Data.Array
|
||||
import Data.Array.ST (STArray)
|
||||
import Data.Array.MArray (newArray, readArray, writeArray, thaw, freeze)
|
||||
module NTT.Poly ( module NTT.Poly.Flat ) where
|
||||
import NTT.Poly.Flat
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.ST.Strict
|
||||
|
||||
import System.Random
|
||||
|
||||
import Data.Binary
|
||||
|
||||
import Field.Goldilocks
|
||||
import Field.Goldilocks.Extension ( FExt )
|
||||
import Field.Encode
|
||||
|
||||
import Misc
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Univariate polynomials
|
||||
|
||||
-- | A dense univariate polynomial. The array index corresponds to the exponent.
|
||||
newtype Poly a
|
||||
= Poly (Array Int a)
|
||||
deriving (Show,Functor)
|
||||
|
||||
instance Binary a => Binary (Poly a) where
|
||||
put (Poly arr) = putSmallArray arr
|
||||
get = Poly <$> getSmallArray
|
||||
|
||||
instance (Num a, Eq a) => Eq (Poly a) where
|
||||
p == q = polyIsZero (polySub p q)
|
||||
|
||||
instance FieldEncode (Poly F) where
|
||||
fieldEncode (Poly arr) = fieldEncode arr
|
||||
|
||||
instance FieldEncode (Poly FExt) where
|
||||
fieldEncode (Poly arr) = fieldEncode arr
|
||||
|
||||
mkPoly :: [a] -> Poly a
|
||||
mkPoly coeffs = Poly $ listArray (0,length coeffs-1) coeffs
|
||||
|
||||
-- | Degree of the polynomial
|
||||
polyDegree :: (Eq a, Num a) => Poly a -> Int
|
||||
polyDegree (Poly arr) = worker d0 where
|
||||
(0,d0) = bounds arr
|
||||
worker d
|
||||
| d < 0 = -1
|
||||
| arr!d /= 0 = d
|
||||
| otherwise = worker (d-1)
|
||||
|
||||
-- | Size of the polynomial (can be larger than @degree + 1@, if the some top coefficients are zeros)
|
||||
polySize :: (Eq a, Num a) => Poly a -> Int
|
||||
polySize (Poly p) = arraySize p
|
||||
|
||||
polyIsZero :: (Eq a, Num a) => Poly a -> Bool
|
||||
polyIsZero (Poly arr) = all (==0) (elems arr)
|
||||
|
||||
-- | Returns the coefficient of @x^k@
|
||||
polyCoeff :: Num a => Poly a -> Int -> a
|
||||
polyCoeff (Poly coeffs) k = safeIndex 0 coeffs k
|
||||
|
||||
-- | Note: this can include zero coeffs at higher than the actual degree!
|
||||
polyCoeffArray :: Poly a -> Array Int a
|
||||
polyCoeffArray (Poly coeffs) = coeffs
|
||||
|
||||
-- | Note: this cuts off the potential extra zeros at the end.
|
||||
-- The order is little-endian (constant term first).
|
||||
polyCoeffList :: (Eq a, Num a) => Poly a -> [a]
|
||||
polyCoeffList poly@(Poly arr) = take (polyDegree poly + 1) (elems arr)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Elementary polynomials
|
||||
|
||||
-- | Constant polynomial
|
||||
polyConst :: a -> Poly a
|
||||
polyConst x = Poly $ listArray (0,0) [x]
|
||||
|
||||
-- | Zero polynomial
|
||||
polyZero :: Num a => Poly a
|
||||
polyZero = polyConst 0
|
||||
|
||||
-- | The polynomial @f(x) = x@
|
||||
polyVarX :: Num a => Poly a
|
||||
polyVarX = mkPoly [0,1]
|
||||
|
||||
-- | @polyLinear (A,B)@ means the linear polynomial @f(x) = A*x + B@
|
||||
polyLinear :: (a,a) -> Poly a
|
||||
polyLinear (a,b) = mkPoly [b,a]
|
||||
|
||||
-- | The monomial @x^n@
|
||||
polyXpowN :: Num a => Int -> Poly a
|
||||
polyXpowN n = Poly $ listArray (0,n) (replicate n 0 ++ [1])
|
||||
|
||||
-- | The binomial @(x^n - 1)@
|
||||
polyXpowNminus1 :: Num a => Int -> Poly a
|
||||
polyXpowNminus1 n = Poly $ listArray (0,n) (-1 : replicate (n-1) 0 ++ [1])
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Evaluate polynomials
|
||||
|
||||
polyEvalAt :: forall f. Fractional f => Poly f -> f -> f
|
||||
polyEvalAt (Poly arr) x = go 0 1 0 where
|
||||
(0,d) = bounds arr
|
||||
go :: f -> f -> Int -> f
|
||||
go !acc !y !i = if i > d
|
||||
then acc
|
||||
else go (acc + (arr!i)*y) (y*x) (i+1)
|
||||
|
||||
polyEvalOnList :: forall f. Fractional f => Poly f -> [f] -> [f]
|
||||
polyEvalOnList poly = map (polyEvalAt poly)
|
||||
|
||||
polyEvalOnArray :: forall f. Fractional f => Poly f -> Array Int f -> Array Int f
|
||||
polyEvalOnArray poly = fmap (polyEvalAt poly)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Basic arithmetic operations on polynomials
|
||||
|
||||
polyNeg :: Num a => Poly a -> Poly a
|
||||
polyNeg (Poly arr) = Poly $ fmap negate arr
|
||||
|
||||
polyAdd :: Num a => Poly a -> Poly a -> Poly a
|
||||
polyAdd (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||
(0,d1) = bounds arr1
|
||||
(0,d2) = bounds arr2
|
||||
d3 = max d1 d2
|
||||
zs = zipWith (+) (elems arr1 ++ replicate (d3-d1) 0)
|
||||
(elems arr2 ++ replicate (d3-d2) 0)
|
||||
|
||||
polySub :: Num a => Poly a -> Poly a -> Poly a
|
||||
polySub (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||
(0,d1) = bounds arr1
|
||||
(0,d2) = bounds arr2
|
||||
d3 = max d1 d2
|
||||
zs = zipWith (-) (elems arr1 ++ replicate (d3-d1) 0)
|
||||
(elems arr2 ++ replicate (d3-d2) 0)
|
||||
|
||||
polyMul :: Num a => Poly a -> Poly a -> Poly a
|
||||
polyMul (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||
(0,d1) = bounds arr1
|
||||
(0,d2) = bounds arr2
|
||||
d3 = d1 + d2
|
||||
zs = [ f k | k<-[0..d3] ]
|
||||
f !k = foldl' (+) 0 [ arr1!i * arr2!(k-i) | i<-[ max 0 (k-d2) .. min d1 k ] ]
|
||||
|
||||
instance Num a => Num (Poly a) where
|
||||
fromInteger = polyConst . fromInteger
|
||||
negate = polyNeg
|
||||
(+) = polyAdd
|
||||
(-) = polySub
|
||||
(*) = polyMul
|
||||
abs = id
|
||||
signum = \_ -> polyConst 1
|
||||
|
||||
polySum :: Num a => [Poly a] -> Poly a
|
||||
polySum = foldl' polyAdd 0
|
||||
|
||||
polyProd :: Num a => [Poly a] -> Poly a
|
||||
polyProd = foldl' polyMul 1
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Polynomial long division
|
||||
|
||||
-- | @polyDiv f h@ returns @(q,r)@ such that @f = q*h + r@ and @deg r < deg h@
|
||||
polyDiv :: forall f. (Eq f, Fractional f) => Poly f -> Poly f -> (Poly f, Poly f)
|
||||
polyDiv poly_f@(Poly arr_f) poly_h@(Poly arr_h)
|
||||
| deg_q < 0 = (polyZero, poly_f)
|
||||
| otherwise = runST action
|
||||
where
|
||||
deg_f = polyDegree poly_f
|
||||
deg_h = polyDegree poly_h
|
||||
deg_q = deg_f - deg_h
|
||||
|
||||
-- inverse of the top coefficient of divisor
|
||||
b_inv = recip (arr_h ! deg_h)
|
||||
|
||||
action :: forall s. ST s (Poly f, Poly f)
|
||||
action = do
|
||||
p <- thaw arr_f :: ST s (STArray s Int f)
|
||||
q <- newArray (0,deg_q) 0 :: ST s (STArray s Int f)
|
||||
forM_ [deg_q,deg_q-1..0] $ \k -> do
|
||||
top <- readArray p (deg_h + k)
|
||||
let y = b_inv * top
|
||||
writeArray q k y
|
||||
forM_ [0..deg_h] $ \j -> do
|
||||
a <- readArray p (j+k)
|
||||
writeArray p (j+k) (a - y*(arr_h!j))
|
||||
qarr <- freeze q
|
||||
rs <- forM [0..deg_h-1] $ \i -> readArray p i
|
||||
let rarr = listArray (0,deg_h-1) rs
|
||||
return (Poly qarr, Poly rarr)
|
||||
|
||||
-- | Returns only the quotient
|
||||
polyDivQuo :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
|
||||
polyDivQuo f g = fst $ polyDiv f g
|
||||
|
||||
-- | Returns only the remainder
|
||||
polyDivRem :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
|
||||
polyDivRem f g = snd $ polyDiv f g
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Sample random polynomials
|
||||
|
||||
randomPoly :: (RandomGen g, Random a) => Int -> g -> (Poly a, g)
|
||||
randomPoly deg g0 =
|
||||
let (coeffs,gfinal) = worker (deg+1) g0
|
||||
poly = Poly (listArray (0,deg) coeffs)
|
||||
in (poly, gfinal)
|
||||
|
||||
where
|
||||
worker 0 g = ([] , g)
|
||||
worker n g = let (x ,g1) = random g
|
||||
(xs,g2) = worker (n-1) g1
|
||||
in ((x:xs) , g)
|
||||
|
||||
randomPolyIO :: Random a => Int -> IO (Poly a)
|
||||
randomPolyIO deg = getStdRandom (randomPoly deg)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
#endif
|
||||
|
||||
55
reference/src/NTT/Poly/Flat.hs
Normal file
55
reference/src/NTT/Poly/Flat.hs
Normal file
@ -0,0 +1,55 @@
|
||||
|
||||
{-# LANGUAGE StrictData, ScopedTypeVariables, PatternSynonyms #-}
|
||||
module NTT.Poly.Flat where
|
||||
|
||||
import Prelude hiding (div,quot,rem)
|
||||
import GHC.Real hiding (div,quot,rem)
|
||||
|
||||
import Data.Bits
|
||||
import Data.Word
|
||||
import Data.List
|
||||
import Data.Array
|
||||
|
||||
import Control.Monad
|
||||
|
||||
import Foreign.C
|
||||
import Foreign.Ptr
|
||||
import Foreign.Marshal
|
||||
import Foreign.ForeignPtr
|
||||
|
||||
import System.Random
|
||||
import System.IO.Unsafe
|
||||
|
||||
import NTT.Subgroup
|
||||
import Field.Class
|
||||
import Data.Flat as L
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
newtype Poly a = MkPoly (L.FlatArray a)
|
||||
|
||||
pattern XPoly n arr = MkPoly (L.MkFlatArray n arr)
|
||||
|
||||
mkPoly :: Flat f => [f] -> Poly f
|
||||
mkPoly = MkPoly . L.packFlatArrayFromList
|
||||
|
||||
mkPoly' :: Flat f => Int -> [f] -> Poly f
|
||||
mkPoly' len xs = MkPoly $ L.packFlatArrayFromList' len xs
|
||||
|
||||
mkPolyArr :: Flat f => Array Int f -> Poly f
|
||||
mkPolyArr = MkPoly . L.packFlatArray
|
||||
|
||||
mkPolyFlatArr :: L.FlatArray f -> Poly f
|
||||
mkPolyFlatArr = MkPoly
|
||||
|
||||
coeffs :: Flat f => Poly f -> [f]
|
||||
coeffs (MkPoly arr) = L.unpackFlatArrayToList arr
|
||||
|
||||
coeffsArr :: Flat f => Poly f -> Array Int f
|
||||
coeffsArr (MkPoly arr) = L.unpackFlatArray arr
|
||||
|
||||
coeffsFlatArr :: Poly f -> L.FlatArray f
|
||||
coeffsFlatArr (MkPoly flat) = flat
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
227
reference/src/NTT/Poly/Naive.hs
Normal file
227
reference/src/NTT/Poly/Naive.hs
Normal file
@ -0,0 +1,227 @@
|
||||
|
||||
-- | Dense univariate polynomials
|
||||
|
||||
{-# LANGUAGE StrictData, BangPatterns, ScopedTypeVariables, DeriveFunctor #-}
|
||||
module NTT.Poly.Naive where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.List
|
||||
import Data.Array
|
||||
import Data.Array.ST (STArray)
|
||||
import Data.Array.MArray (newArray, readArray, writeArray, thaw, freeze)
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.ST.Strict
|
||||
|
||||
import System.Random
|
||||
|
||||
import Data.Binary
|
||||
|
||||
import Field.Goldilocks
|
||||
import Field.Goldilocks.Extension ( FExt )
|
||||
import Field.Encode
|
||||
|
||||
import Misc
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Univariate polynomials
|
||||
|
||||
-- | A dense univariate polynomial. The array index corresponds to the exponent.
|
||||
newtype Poly a
|
||||
= Poly (Array Int a)
|
||||
deriving (Show,Functor)
|
||||
|
||||
instance Binary a => Binary (Poly a) where
|
||||
put (Poly arr) = putSmallArray arr
|
||||
get = Poly <$> getSmallArray
|
||||
|
||||
instance (Num a, Eq a) => Eq (Poly a) where
|
||||
p == q = polyIsZero (polySub p q)
|
||||
|
||||
instance FieldEncode (Poly F) where
|
||||
fieldEncode (Poly arr) = fieldEncode arr
|
||||
|
||||
instance FieldEncode (Poly FExt) where
|
||||
fieldEncode (Poly arr) = fieldEncode arr
|
||||
|
||||
mkPoly :: [a] -> Poly a
|
||||
mkPoly coeffs = Poly $ listArray (0,length coeffs-1) coeffs
|
||||
|
||||
-- | Degree of the polynomial
|
||||
polyDegree :: (Eq a, Num a) => Poly a -> Int
|
||||
polyDegree (Poly arr) = worker d0 where
|
||||
(0,d0) = bounds arr
|
||||
worker d
|
||||
| d < 0 = -1
|
||||
| arr!d /= 0 = d
|
||||
| otherwise = worker (d-1)
|
||||
|
||||
-- | Size of the polynomial (can be larger than @degree + 1@, if the some top coefficients are zeros)
|
||||
polySize :: (Eq a, Num a) => Poly a -> Int
|
||||
polySize (Poly p) = arraySize p
|
||||
|
||||
polyIsZero :: (Eq a, Num a) => Poly a -> Bool
|
||||
polyIsZero (Poly arr) = all (==0) (elems arr)
|
||||
|
||||
-- | Returns the coefficient of @x^k@
|
||||
polyCoeff :: Num a => Poly a -> Int -> a
|
||||
polyCoeff (Poly coeffs) k = safeIndex 0 coeffs k
|
||||
|
||||
-- | Note: this can include zero coeffs at higher than the actual degree!
|
||||
polyCoeffArray :: Poly a -> Array Int a
|
||||
polyCoeffArray (Poly coeffs) = coeffs
|
||||
|
||||
-- | Note: this cuts off the potential extra zeros at the end.
|
||||
-- The order is little-endian (constant term first).
|
||||
polyCoeffList :: (Eq a, Num a) => Poly a -> [a]
|
||||
polyCoeffList poly@(Poly arr) = take (polyDegree poly + 1) (elems arr)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Elementary polynomials
|
||||
|
||||
-- | Constant polynomial
|
||||
polyConst :: a -> Poly a
|
||||
polyConst x = Poly $ listArray (0,0) [x]
|
||||
|
||||
-- | Zero polynomial
|
||||
polyZero :: Num a => Poly a
|
||||
polyZero = polyConst 0
|
||||
|
||||
-- | The polynomial @f(x) = x@
|
||||
polyVarX :: Num a => Poly a
|
||||
polyVarX = mkPoly [0,1]
|
||||
|
||||
-- | @polyLinear (A,B)@ means the linear polynomial @f(x) = A*x + B@
|
||||
polyLinear :: (a,a) -> Poly a
|
||||
polyLinear (a,b) = mkPoly [b,a]
|
||||
|
||||
-- | The monomial @x^n@
|
||||
polyXpowN :: Num a => Int -> Poly a
|
||||
polyXpowN n = Poly $ listArray (0,n) (replicate n 0 ++ [1])
|
||||
|
||||
-- | The binomial @(x^n - 1)@
|
||||
polyXpowNminus1 :: Num a => Int -> Poly a
|
||||
polyXpowNminus1 n = Poly $ listArray (0,n) (-1 : replicate (n-1) 0 ++ [1])
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Evaluate polynomials
|
||||
|
||||
polyEvalAt :: forall f. Fractional f => Poly f -> f -> f
|
||||
polyEvalAt (Poly arr) x = go 0 1 0 where
|
||||
(0,d) = bounds arr
|
||||
go :: f -> f -> Int -> f
|
||||
go !acc !y !i = if i > d
|
||||
then acc
|
||||
else go (acc + (arr!i)*y) (y*x) (i+1)
|
||||
|
||||
polyEvalOnList :: forall f. Fractional f => Poly f -> [f] -> [f]
|
||||
polyEvalOnList poly = map (polyEvalAt poly)
|
||||
|
||||
polyEvalOnArray :: forall f. Fractional f => Poly f -> Array Int f -> Array Int f
|
||||
polyEvalOnArray poly = fmap (polyEvalAt poly)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Basic arithmetic operations on polynomials
|
||||
|
||||
polyNeg :: Num a => Poly a -> Poly a
|
||||
polyNeg (Poly arr) = Poly $ fmap negate arr
|
||||
|
||||
polyAdd :: Num a => Poly a -> Poly a -> Poly a
|
||||
polyAdd (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||
(0,d1) = bounds arr1
|
||||
(0,d2) = bounds arr2
|
||||
d3 = max d1 d2
|
||||
zs = zipWith (+) (elems arr1 ++ replicate (d3-d1) 0)
|
||||
(elems arr2 ++ replicate (d3-d2) 0)
|
||||
|
||||
polySub :: Num a => Poly a -> Poly a -> Poly a
|
||||
polySub (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||
(0,d1) = bounds arr1
|
||||
(0,d2) = bounds arr2
|
||||
d3 = max d1 d2
|
||||
zs = zipWith (-) (elems arr1 ++ replicate (d3-d1) 0)
|
||||
(elems arr2 ++ replicate (d3-d2) 0)
|
||||
|
||||
polyMul :: Num a => Poly a -> Poly a -> Poly a
|
||||
polyMul (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||
(0,d1) = bounds arr1
|
||||
(0,d2) = bounds arr2
|
||||
d3 = d1 + d2
|
||||
zs = [ f k | k<-[0..d3] ]
|
||||
f !k = foldl' (+) 0 [ arr1!i * arr2!(k-i) | i<-[ max 0 (k-d2) .. min d1 k ] ]
|
||||
|
||||
instance Num a => Num (Poly a) where
|
||||
fromInteger = polyConst . fromInteger
|
||||
negate = polyNeg
|
||||
(+) = polyAdd
|
||||
(-) = polySub
|
||||
(*) = polyMul
|
||||
abs = id
|
||||
signum = \_ -> polyConst 1
|
||||
|
||||
polySum :: Num a => [Poly a] -> Poly a
|
||||
polySum = foldl' polyAdd 0
|
||||
|
||||
polyProd :: Num a => [Poly a] -> Poly a
|
||||
polyProd = foldl' polyMul 1
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Polynomial long division
|
||||
|
||||
-- | @polyDiv f h@ returns @(q,r)@ such that @f = q*h + r@ and @deg r < deg h@
|
||||
polyDiv :: forall f. (Eq f, Fractional f) => Poly f -> Poly f -> (Poly f, Poly f)
|
||||
polyDiv poly_f@(Poly arr_f) poly_h@(Poly arr_h)
|
||||
| deg_q < 0 = (polyZero, poly_f)
|
||||
| otherwise = runST action
|
||||
where
|
||||
deg_f = polyDegree poly_f
|
||||
deg_h = polyDegree poly_h
|
||||
deg_q = deg_f - deg_h
|
||||
|
||||
-- inverse of the top coefficient of divisor
|
||||
b_inv = recip (arr_h ! deg_h)
|
||||
|
||||
action :: forall s. ST s (Poly f, Poly f)
|
||||
action = do
|
||||
p <- thaw arr_f :: ST s (STArray s Int f)
|
||||
q <- newArray (0,deg_q) 0 :: ST s (STArray s Int f)
|
||||
forM_ [deg_q,deg_q-1..0] $ \k -> do
|
||||
top <- readArray p (deg_h + k)
|
||||
let y = b_inv * top
|
||||
writeArray q k y
|
||||
forM_ [0..deg_h] $ \j -> do
|
||||
a <- readArray p (j+k)
|
||||
writeArray p (j+k) (a - y*(arr_h!j))
|
||||
qarr <- freeze q
|
||||
rs <- forM [0..deg_h-1] $ \i -> readArray p i
|
||||
let rarr = listArray (0,deg_h-1) rs
|
||||
return (Poly qarr, Poly rarr)
|
||||
|
||||
-- | Returns only the quotient
|
||||
polyDivQuo :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
|
||||
polyDivQuo f g = fst $ polyDiv f g
|
||||
|
||||
-- | Returns only the remainder
|
||||
polyDivRem :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
|
||||
polyDivRem f g = snd $ polyDiv f g
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Sample random polynomials
|
||||
|
||||
randomPoly :: (RandomGen g, Random a) => Int -> g -> (Poly a, g)
|
||||
randomPoly deg g0 =
|
||||
let (coeffs,gfinal) = worker (deg+1) g0
|
||||
poly = Poly (listArray (0,deg) coeffs)
|
||||
in (poly, gfinal)
|
||||
|
||||
where
|
||||
worker 0 g = ([] , g)
|
||||
worker n g = let (x ,g1) = random g
|
||||
(xs,g2) = worker (n-1) g1
|
||||
in ((x:xs) , g)
|
||||
|
||||
randomPolyIO :: Random a => Int -> IO (Poly a)
|
||||
randomPolyIO deg = getStdRandom (randomPoly deg)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
@ -20,6 +20,9 @@ data Subgroup g = MkSubgroup
|
||||
subgroupSize :: Subgroup g -> Int
|
||||
subgroupSize = subgroupOrder
|
||||
|
||||
subgroupLogSize :: Subgroup g -> Log2
|
||||
subgroupLogSize = exactLog2__ . subgroupSize
|
||||
|
||||
getSubgroup :: Log2 -> Subgroup F
|
||||
getSubgroup log2@(Log2 n)
|
||||
| n<0 = error "getSubgroup: negative logarithm"
|
||||
|
||||
55
reference/src/NTT/Tests.hs
Normal file
55
reference/src/NTT/Tests.hs
Normal file
@ -0,0 +1,55 @@
|
||||
|
||||
module NTT.Tests where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Array
|
||||
|
||||
import Control.Monad
|
||||
|
||||
import Field.Class
|
||||
import Field.Goldilocks ( F )
|
||||
import Misc
|
||||
|
||||
import Data.Flat as L
|
||||
import NTT.Subgroup
|
||||
|
||||
import qualified NTT.Poly.Flat as F
|
||||
import qualified NTT.FFT.Fast as F
|
||||
|
||||
import qualified NTT.Poly.Naive as S
|
||||
import qualified NTT.FFT.Slow as S
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
-- | Compare slow and fast FFT implementations to each other
|
||||
compareFFTs :: Log2 -> IO Bool
|
||||
compareFFTs logSize = do
|
||||
let size = exp2_ logSize
|
||||
let sg = getSubgroup logSize
|
||||
|
||||
values1_s <- listToArray <$> replicateM size rndIO :: IO (Array Int F)
|
||||
let values1_f = L.packFlatArray values1_s :: L.FlatArray F
|
||||
let poly1_s = S.Poly values1_s :: S.Poly F
|
||||
let poly1_f = F.mkPolyArr values1_s :: F.Poly F
|
||||
|
||||
let values2_s = S.subgroupNTT sg poly1_s :: Array Int F
|
||||
let values2_f = F.forwardNTT sg poly1_f :: L.FlatArray F
|
||||
|
||||
let poly2_s = S.subgroupINTT sg values1_s :: S.Poly F
|
||||
let poly2_f = F.inverseNTT sg values1_f :: F.Poly F
|
||||
|
||||
let ok_ntt = values2_s == L.unpackFlatArray values2_f
|
||||
let ok_intt = S.polyCoeffArray poly2_s == F.coeffsArr poly2_f
|
||||
|
||||
return (ok_ntt && ok_intt)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
runTests :: Int -> Log2 -> IO Bool
|
||||
runTests n size = do
|
||||
oks <- replicateM n (compareFFTs size)
|
||||
return (and oks)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
228
reference/src/cbits/ntt.c
Normal file
228
reference/src/cbits/ntt.c
Normal file
@ -0,0 +1,228 @@
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
|
||||
#include "goldilocks.h"
|
||||
#include "ntt.h"
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
void goldilocks_ntt_forward_noalloc(int m, int src_stride, const uint64_t *gpows, const uint64_t *src, uint64_t *buf, uint64_t *tgt) {
|
||||
|
||||
if (m==0) {
|
||||
tgt[0] = src[0];
|
||||
return;
|
||||
}
|
||||
|
||||
if (m==1) {
|
||||
// N = 2
|
||||
tgt[0] = goldilocks_add( src[0] , src[src_stride] ); // x + y
|
||||
tgt[1] = goldilocks_sub( src[0] , src[src_stride] ); // x - y
|
||||
return;
|
||||
}
|
||||
|
||||
else {
|
||||
int N = (1<< m );
|
||||
int halfN = (1<<(m-1));
|
||||
|
||||
goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src , buf + N , buf );
|
||||
goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src + src_stride , buf + N , buf + halfN );
|
||||
|
||||
for(int j=0; j<halfN; j++) {
|
||||
const uint64_t gpow = gpows[j*src_stride];
|
||||
tgt[j ] = goldilocks_mul( buf[j+halfN] , gpow ); // g*v[k]
|
||||
tgt[j+halfN] = goldilocks_neg( tgt[j ] ); // - g*v[k]
|
||||
tgt[j ] = goldilocks_add( tgt[j ] , buf[j] ); // u[k] + g*v[k]
|
||||
tgt[j+halfN] = goldilocks_add( tgt[j+halfN] , buf[j] ); // u[k] - g*v[k]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forward number-theoretical transform (evaluation of a polynomial)
|
||||
// `src` and `tgt` should be `N = 2^m` sized arrays of field elements
|
||||
// `gen` should be the generator of the multiplicative subgroup sized `N`
|
||||
void goldilocks_ntt_forward(int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
|
||||
int N = (1<<m);
|
||||
int halfN = (N>>1);
|
||||
|
||||
// precalculate [1,g,g^2,g^3...]
|
||||
uint64_t *gpows = (uint64_t*) malloc( 8 * halfN );
|
||||
assert( gpows != 0 );
|
||||
uint64_t x = gen;
|
||||
gpows[0] = 1;
|
||||
gpows[1] = gen;
|
||||
for(int i=2; i<halfN; i++) {
|
||||
x = goldilocks_mul( x , gen );
|
||||
gpows[i] = x;
|
||||
}
|
||||
|
||||
uint64_t *buf = (uint64_t*) malloc( 8 * (2*N) );
|
||||
assert( buf != 0 );
|
||||
goldilocks_ntt_forward_noalloc( m, 1, gpows, src, buf, tgt);
|
||||
free(buf);
|
||||
free(gpows);
|
||||
}
|
||||
|
||||
// it's like `ntt_forward` but we pre-multiply the coefficients with `eta^k`
|
||||
// resulting in evaluating f(eta*x) instead of f(x)
|
||||
void goldilocks_ntt_forward_shifted(const uint64_t eta, int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
|
||||
int N = (1<<m);
|
||||
uint64_t *shifted = malloc( 8 * N );
|
||||
assert( shifted != 0 );
|
||||
uint64_t x = 1;
|
||||
for(int i=0; i<N; i++) {
|
||||
shifted[i] = goldilocks_mul( src[i] , x );
|
||||
x = goldilocks_mul( x , eta );
|
||||
}
|
||||
goldilocks_ntt_forward( m, gen, shifted, tgt );
|
||||
free(shifted);
|
||||
}
|
||||
|
||||
// it's like `ntt_forward` but asymmetric, evaluating on a larger target subgroup
|
||||
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, const uint64_t gen_src, const uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt) {
|
||||
assert( m_tgt >= m_src );
|
||||
int N_src = (1 << m_src);
|
||||
int N_tgt = (1 << m_tgt);
|
||||
int halfN_src = (N_src >> 1);
|
||||
int K = (1 << (m_tgt - m_src));
|
||||
|
||||
// precalculate [1,g,g^2,g^3...]
|
||||
uint64_t *gpows = malloc( 8 * halfN_src );
|
||||
assert( gpows != 0 );
|
||||
uint64_t x = gen_src;
|
||||
gpows[0] = 1;
|
||||
gpows[1] = gen_src;
|
||||
for(int i=2; i<halfN_src; i++) {
|
||||
x = goldilocks_mul(x, gen_src);
|
||||
gpows[i] = x;
|
||||
}
|
||||
|
||||
uint64_t *shifted = malloc( 8 * N_src );
|
||||
assert( shifted != 0 );
|
||||
|
||||
uint64_t *buf = malloc( 8 * (2*N_src) );
|
||||
assert( buf != 0 );
|
||||
|
||||
// temporary target buffer (we could replace this by adding `tgt_stride`)
|
||||
uint64_t *tgt_small = malloc( 8 * N_src );
|
||||
assert( tgt_small != 0 );
|
||||
|
||||
// eta will be the shift
|
||||
uint64_t eta = 1;
|
||||
|
||||
for(int k=0; k<K; k++) {
|
||||
if (k==0) {
|
||||
memcpy( shifted, src, N_src*8 );
|
||||
}
|
||||
else {
|
||||
eta = goldilocks_mul( eta , gen_tgt );
|
||||
|
||||
uint64_t x = 1;
|
||||
for(int i=0; i<N_src; i++) {
|
||||
shifted[i] = goldilocks_mul( src[i] , x );
|
||||
x = goldilocks_mul(x, eta);
|
||||
}
|
||||
}
|
||||
|
||||
goldilocks_ntt_forward_noalloc( m_src, 1, gpows, shifted, buf, tgt_small );
|
||||
|
||||
uint64_t *p = tgt_small;
|
||||
uint64_t *q = tgt + k;
|
||||
int tgt_stride = K;
|
||||
for(int i=0; i<N_src; i++) {
|
||||
q[i] = p[i];
|
||||
p += 1;
|
||||
q += tgt_stride;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
free(tgt_small);
|
||||
free(buf);
|
||||
free(gpows);
|
||||
free(shifted);
|
||||
}
|
||||
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// inverse of 2 (which is is the same as `(p+1)/2`)
|
||||
const uint64_t goldilocks_oneHalf = 0x7fffffff80000001ull;
|
||||
|
||||
void goldilocks_ntt_inverse_noalloc(int m, int tgt_stride, const uint64_t *gpows, const uint64_t *src, uint64_t *buf, uint64_t *tgt) {
|
||||
|
||||
if (m==0) {
|
||||
tgt[0] = src[0];
|
||||
return;
|
||||
}
|
||||
|
||||
if (m==1) {
|
||||
// N = 2
|
||||
|
||||
tgt[0 ] = goldilocks_add( src[0] , src[1] ); // x + y
|
||||
tgt[tgt_stride] = goldilocks_sub( src[0] , src[1] ); // x - y
|
||||
tgt[0 ] = goldilocks_div_by_2( tgt[0 ] ); // (x + y)/2
|
||||
tgt[tgt_stride] = goldilocks_div_by_2( tgt[tgt_stride] ); // (x - y)/2
|
||||
return;
|
||||
}
|
||||
|
||||
else {
|
||||
int N = (1<< m );
|
||||
int halfN = (1<<(m-1));
|
||||
|
||||
for(int j=0; j<halfN; j++) {
|
||||
uint64_t gpow = gpows[j*tgt_stride];
|
||||
buf[j ] = goldilocks_add( src[j] , src[j+halfN] ); // x + y
|
||||
buf[j+halfN] = goldilocks_sub( src[j] , src[j+halfN] ); // x - y
|
||||
buf[j ] = goldilocks_div_by_2( buf[j ] ); // (x + y) / 2
|
||||
buf[j+halfN] = goldilocks_mul ( buf[j+halfN] , gpow ); // (x - y) / (2*g^k)
|
||||
}
|
||||
|
||||
goldilocks_ntt_inverse_noalloc( m-1 , tgt_stride<<1 , gpows , buf , buf + N , tgt );
|
||||
goldilocks_ntt_inverse_noalloc( m-1 , tgt_stride<<1 , gpows , buf + halfN , buf + N , tgt + tgt_stride );
|
||||
}
|
||||
}
|
||||
|
||||
// inverse number-theoretical transform (interpolation of a polynomial)
|
||||
// `src` and `tgt` should be `N = 2^m` sized arrays of field elements
|
||||
// `gen` should be the generator of the multiplicative subgroup sized `N`
|
||||
void goldilocks_ntt_inverse(int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
|
||||
int N = (1<<m);
|
||||
int halfN = (N>>1);
|
||||
|
||||
// precalculate [1/2,g^{-1}/2,g^{-2}/2,g^{-3}/2...]
|
||||
uint64_t *gpows = malloc( 8 * halfN );
|
||||
assert( gpows != 0 );
|
||||
uint64_t x = goldilocks_oneHalf; // 1/2
|
||||
uint64_t ginv = goldilocks_inv(gen); // gen^-1
|
||||
for(int i=0; i<halfN; i++) {
|
||||
gpows[i] = x;
|
||||
x = goldilocks_mul(x, ginv);
|
||||
}
|
||||
|
||||
uint64_t *buf = malloc( 8 * (2*N) );
|
||||
assert( buf !=0 );
|
||||
goldilocks_ntt_inverse_noalloc( m, 1, gpows, src, buf, tgt );
|
||||
free(buf);
|
||||
free(gpows);
|
||||
}
|
||||
|
||||
// it's like `ntt_inverse` but we post-multiply the resulting coefficients with `eta^k`
|
||||
// resulting in interpolating an f such that f(eta^-1 * omega^k) = y_k
|
||||
void goldilocks_ntt_inverse_shifted(const uint64_t eta, int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
|
||||
int N = (1<<m);
|
||||
uint64_t *unshifted = malloc( 8*N );
|
||||
assert( unshifted != 0 );
|
||||
goldilocks_ntt_inverse( m, gen, src, unshifted );
|
||||
uint64_t x = 1;
|
||||
for(int i=0; i<N; i++) {
|
||||
tgt[i] = goldilocks_mul( unshifted[i] , x );
|
||||
x = goldilocks_mul( x , eta );
|
||||
}
|
||||
free(unshifted);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
14
reference/src/cbits/ntt.h
Normal file
14
reference/src/cbits/ntt.h
Normal file
@ -0,0 +1,14 @@
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
void goldilocks_ntt_forward ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||
void goldilocks_ntt_forward_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||
|
||||
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt);
|
||||
|
||||
void goldilocks_ntt_inverse ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||
void goldilocks_ntt_inverse_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
Loading…
x
Reference in New Issue
Block a user