mirror of
https://github.com/logos-storage/outsourcing-Reed-Solomon.git
synced 2026-01-03 22:23:07 +00:00
preliminary C implementation of NTT
This commit is contained in:
parent
f7955ac21b
commit
a9cb0a96a6
@ -87,6 +87,9 @@ newtype ReductionStrategy = MkRedStrategy
|
|||||||
}
|
}
|
||||||
deriving (Eq,Show)
|
deriving (Eq,Show)
|
||||||
|
|
||||||
|
reductionStrategyLength :: ReductionStrategy -> Int
|
||||||
|
reductionStrategyLength = length . fromReductionStrategy
|
||||||
|
|
||||||
instance Binary ReductionStrategy where
|
instance Binary ReductionStrategy where
|
||||||
put = putSmallList . fromReductionStrategy
|
put = putSmallList . fromReductionStrategy
|
||||||
get = MkRedStrategy <$> getSmallList
|
get = MkRedStrategy <$> getSmallList
|
||||||
@ -123,6 +126,18 @@ data FriConfig = MkFriConfig
|
|||||||
}
|
}
|
||||||
deriving (Eq,Show)
|
deriving (Eq,Show)
|
||||||
|
|
||||||
|
friCommitPhaseVectorSizes :: FriConfig -> [Log2]
|
||||||
|
friCommitPhaseVectorSizes (MkFriConfig{..}) = result where
|
||||||
|
result = go (rsEncodedSize friRSConfig) (fromReductionStrategy friReductionStrategy)
|
||||||
|
go n [] = []
|
||||||
|
go n (a:as) = n : go (n-a) as
|
||||||
|
|
||||||
|
friCommitPhaseTreeSizes :: FriConfig -> [Log2]
|
||||||
|
friCommitPhaseTreeSizes (MkFriConfig{..}) = result where
|
||||||
|
result = go (rsEncodedSize friRSConfig) (fromReductionStrategy friReductionStrategy)
|
||||||
|
go n [] = []
|
||||||
|
go n (a:as) = (n-a) : go (n-a) as
|
||||||
|
|
||||||
instance Binary FriConfig where
|
instance Binary FriConfig where
|
||||||
put (MkFriConfig{..}) = do
|
put (MkFriConfig{..}) = do
|
||||||
put friRSConfig
|
put friRSConfig
|
||||||
@ -252,3 +267,44 @@ instance Binary FriProof where
|
|||||||
<*> get
|
<*> get
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
estimateFriProofSize :: FriConfig -> Int
|
||||||
|
estimateFriProofSize friConfig@(MkFriConfig{..}) = total where
|
||||||
|
|
||||||
|
total = friCfgSize + commitPhaseCaps + finalPolySize + (friNQueryRounds * queryRoundSize) + 8
|
||||||
|
|
||||||
|
arities = fromReductionStrategy friReductionStrategy
|
||||||
|
nsteps = length arities
|
||||||
|
|
||||||
|
rsCfgSize = 1 + 1 + 8
|
||||||
|
redStratSize = 1 + nsteps
|
||||||
|
friCfgSize = rsCfgSize + 8 + 1 + redStratSize + 8 + 1
|
||||||
|
|
||||||
|
(phases2,_) = friCommitPhaseSizesLog2 friConfig
|
||||||
|
|
||||||
|
merkleCapSize = 32 * exp2_ friMerkleCapSize
|
||||||
|
commitPhaseCaps = merkleCapSize * nsteps
|
||||||
|
finalPolySize = 8 * exp2_ (rsDataSize - sum arities)
|
||||||
|
|
||||||
|
queryStepSize arity = 16 * exp2_ arity +
|
||||||
|
|
||||||
|
data FriQueryStep = MkFriQueryStep
|
||||||
|
{ queryEvals :: [FExt]
|
||||||
|
, queryMerklePath :: RawMerklePath
|
||||||
|
}
|
||||||
|
deriving (Eq,Show)
|
||||||
|
|
||||||
|
data FriQueryRound = MkFriQueryRound
|
||||||
|
{ queryRow :: FRow
|
||||||
|
, queryInitialTreeProof :: RawMerklePath
|
||||||
|
, querySteps :: [FriQueryStep]
|
||||||
|
}
|
||||||
|
deriving (Eq,Show)
|
||||||
|
|
||||||
|
data FriProof = MkFriProof
|
||||||
|
{ proofFriConfig :: FriConfig -- ^ the FRI configuration
|
||||||
|
, proofCommitPhaseCaps :: [MerkleCap] -- ^ commit phase Merkle caps
|
||||||
|
, proofFinalPoly :: Poly FExt -- ^ the final polynomial in coefficient form
|
||||||
|
, proofQueryRounds :: [FriQueryRound] -- ^ query rounds
|
||||||
|
, proofPowWitness :: F -- ^ witness showing that the prover did PoW
|
||||||
|
}
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
|
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
module Field.Class where
|
module Field.Class where
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Kind
|
||||||
import Data.Proxy
|
import Data.Proxy
|
||||||
|
|
||||||
import System.Random
|
import System.Random
|
||||||
@ -26,6 +28,15 @@ class (Show a, Eq a, Num a, Fractional a) => Field a where
|
|||||||
inverse :: Field a => a -> a
|
inverse :: Field a => a -> a
|
||||||
inverse = recip
|
inverse = recip
|
||||||
|
|
||||||
|
-- | Quadratic extensions
|
||||||
|
class (Field (BaseField ext), Field ext) => QuadraticExt ext where
|
||||||
|
type BaseField ext :: Type
|
||||||
|
inject :: BaseField ext -> ext
|
||||||
|
project :: ext -> Maybe (BaseField ext)
|
||||||
|
scale :: BaseField ext -> ext -> ext
|
||||||
|
quadraticPack :: (BaseField ext, BaseField ext) -> ext
|
||||||
|
quadraticUnpack :: ext -> (BaseField ext, BaseField ext)
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
instance Field Goldi.F where
|
instance Field Goldi.F where
|
||||||
@ -53,3 +64,14 @@ instance Field GoldiExt.FExt where
|
|||||||
rndIO = randomIO
|
rndIO = randomIO
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
instance QuadraticExt GoldiExt.FExt where
|
||||||
|
type BaseField GoldiExt.FExt = Goldi.F
|
||||||
|
|
||||||
|
inject = GoldiExt.inj
|
||||||
|
project = GoldiExt.proj
|
||||||
|
scale = GoldiExt.scl
|
||||||
|
quadraticPack = GoldiExt.pack
|
||||||
|
quadraticUnpack = GoldiExt.unpack
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|||||||
@ -139,6 +139,17 @@ binaryOpIO c_action x y = allocaBytesAligned 48 8 $ \ptr1 -> do
|
|||||||
inj :: F -> F2
|
inj :: F -> F2
|
||||||
inj r = F2 r 0
|
inj r = F2 r 0
|
||||||
|
|
||||||
|
proj :: F2 -> Maybe F
|
||||||
|
proj (F2 r i) = if Goldi.isZero i then Just r else Nothing
|
||||||
|
|
||||||
|
pack :: (F,F) -> F2
|
||||||
|
pack (r,i) = F2 r i
|
||||||
|
|
||||||
|
unpack :: F2 -> (F,F)
|
||||||
|
unpack (F2 r i) = (r,i)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
neg, sqr, inv :: F2 -> F2
|
neg, sqr, inv :: F2 -> F2
|
||||||
neg x = unsafePerformIO (unaryOpIO c_goldilocks_ext_neg x)
|
neg x = unsafePerformIO (unaryOpIO c_goldilocks_ext_neg x)
|
||||||
sqr x = unsafePerformIO (unaryOpIO c_goldilocks_ext_sqr x)
|
sqr x = unsafePerformIO (unaryOpIO c_goldilocks_ext_sqr x)
|
||||||
|
|||||||
@ -108,6 +108,17 @@ isOne (F2 r i) = Goldi.isOne r && Goldi.isZero i
|
|||||||
inj :: F -> F2
|
inj :: F -> F2
|
||||||
inj r = F2 r 0
|
inj r = F2 r 0
|
||||||
|
|
||||||
|
proj :: F2 -> Maybe F
|
||||||
|
proj (F2 r i) = if Goldi.isZero i then Just r else Nothing
|
||||||
|
|
||||||
|
pack :: (F,F) -> F2
|
||||||
|
pack (r,i) = F2 r i
|
||||||
|
|
||||||
|
unpack :: F2 -> (F,F)
|
||||||
|
unpack (F2 r i) = (r,i)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
neg :: F2 -> F2
|
neg :: F2 -> F2
|
||||||
neg (F2 r i) = F2 (negate r) (negate i)
|
neg (F2 r i) = F2 (negate r) (negate i)
|
||||||
|
|
||||||
|
|||||||
@ -5,11 +5,11 @@ module NTT
|
|||||||
( module Field.Goldilocks
|
( module Field.Goldilocks
|
||||||
, module NTT.Subgroup
|
, module NTT.Subgroup
|
||||||
, module NTT.Poly
|
, module NTT.Poly
|
||||||
, module NTT.Slow
|
, module NTT.FFT
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Field.Goldilocks
|
import Field.Goldilocks
|
||||||
import NTT.Subgroup
|
import NTT.Subgroup
|
||||||
import NTT.Poly
|
import NTT.Poly
|
||||||
import NTT.Slow
|
import NTT.FFT
|
||||||
|
|
||||||
|
|||||||
22
reference/src/NTT/Class.hs
Normal file
22
reference/src/NTT/Class.hs
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
|
||||||
|
module NTT.Class where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Kind
|
||||||
|
|
||||||
|
import NTT.Subgroup
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
{-
|
||||||
|
class NTT ntt where
|
||||||
|
type Field ntt :: Type
|
||||||
|
type Poly ntt :: Type
|
||||||
|
|
||||||
|
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
|
||||||
|
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
|
||||||
|
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
|
||||||
|
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
|
||||||
|
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
|
||||||
|
-}
|
||||||
14
reference/src/NTT/FFT.hs
Normal file
14
reference/src/NTT/FFT.hs
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
|
||||||
|
{-# LANGUAGE CPP #-}
|
||||||
|
|
||||||
|
#ifdef USE_NAIVE_HASKELL
|
||||||
|
|
||||||
|
module NTT.FFT ( module NTT.FFT.Slow ) where
|
||||||
|
import NTT.FFT.Slow
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
module NTT.FFT ( module NTT.FFT.Fast ) where
|
||||||
|
import NTT.FFT.Fast
|
||||||
|
|
||||||
|
#endif
|
||||||
132
reference/src/NTT/FFT/Fast.hs
Normal file
132
reference/src/NTT/FFT/Fast.hs
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
|
||||||
|
{-# LANGUAGE StrictData, ForeignFunctionInterface #-}
|
||||||
|
|
||||||
|
module NTT.FFT.Fast where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Word
|
||||||
|
|
||||||
|
import Foreign.C
|
||||||
|
import Foreign.Ptr
|
||||||
|
import Foreign.ForeignPtr
|
||||||
|
import Foreign.Marshal
|
||||||
|
|
||||||
|
import System.IO.Unsafe
|
||||||
|
|
||||||
|
import Data.Flat
|
||||||
|
|
||||||
|
import NTT.Poly
|
||||||
|
import NTT.Subgroup
|
||||||
|
import Field.Goldilocks
|
||||||
|
import Misc
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
{-
|
||||||
|
void goldilocks_ntt_forward ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||||
|
void goldilocks_ntt_forward_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||||
|
|
||||||
|
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt);
|
||||||
|
|
||||||
|
void goldilocks_ntt_inverse ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||||
|
void goldilocks_ntt_inverse_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||||
|
-}
|
||||||
|
|
||||||
|
foreign import ccall unsafe "goldilocks_ntt_forward" c_ntt_forward :: CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||||
|
foreign import ccall unsafe "goldilocks_ntt_forward_shifted" c_ntt_forward_shifted :: Word64 -> CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||||
|
|
||||||
|
foreign import ccall unsafe "goldilocks_ntt_forward_asymmetric" c_ntt_forward_asymmetric :: CInt -> CInt -> Word64 -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ();
|
||||||
|
|
||||||
|
foreign import ccall unsafe "goldilocks_ntt_inverse" c_ntt_inverse :: CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||||
|
foreign import ccall unsafe "goldilocks_ntt_inverse_shifted" c_ntt_inverse_shifted :: Word64 -> CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO ()
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
{-# NOINLINE forwardNTT #-}
|
||||||
|
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
|
||||||
|
forwardNTT sg (MkPoly (MkFlatArray n fptr2))
|
||||||
|
| subgroupSize sg /= n = error "forwardNTT: subgroup size differs from the array size"
|
||||||
|
| otherwise = unsafePerformIO $ do
|
||||||
|
fptr3 <- mallocForeignPtrArray n -- (n*4)
|
||||||
|
let MkGoldilocks gen = subgroupGen sg
|
||||||
|
withForeignPtr fptr2 $ \ptr2 -> do
|
||||||
|
withForeignPtr fptr3 $ \ptr3 -> do
|
||||||
|
c_ntt_forward (subgroupCLogSize sg) gen ptr2 ptr3
|
||||||
|
return (MkFlatArray n fptr3)
|
||||||
|
|
||||||
|
{-# NOINLINE inverseNTT #-}
|
||||||
|
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
|
||||||
|
inverseNTT sg (MkFlatArray n fptr2)
|
||||||
|
| subgroupSize sg /= n = error "inverseNTT: subgroup size differs from the array size"
|
||||||
|
| otherwise = unsafePerformIO $ do
|
||||||
|
fptr3 <- mallocForeignPtrArray n --(n*4)
|
||||||
|
let MkGoldilocks gen = subgroupGen sg
|
||||||
|
withForeignPtr fptr2 $ \ptr2 -> do
|
||||||
|
withForeignPtr fptr3 $ \ptr3 -> do
|
||||||
|
c_ntt_inverse (subgroupCLogSize sg) gen ptr2 ptr3
|
||||||
|
return (MkPoly (MkFlatArray n fptr3))
|
||||||
|
|
||||||
|
-- | Pre-multiplies the coefficients by powers of eta, effectively evaluating @f(eta*x)@ on the subgroup
|
||||||
|
{-# NOINLINE shiftedForwardNTT #-}
|
||||||
|
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
|
||||||
|
shiftedForwardNTT sg (MkGoldilocks eta) (MkPoly (MkFlatArray n fptr2))
|
||||||
|
| subgroupSize sg /= n = error "shiftedForwardNTT: subgroup size differs from the array size"
|
||||||
|
| otherwise = unsafePerformIO $ do
|
||||||
|
fptr3 <- mallocForeignPtrArray n -- (n*4)
|
||||||
|
let MkGoldilocks gen = subgroupGen sg
|
||||||
|
withForeignPtr fptr2 $ \ptr2 -> do
|
||||||
|
withForeignPtr fptr3 $ \ptr3 -> do
|
||||||
|
c_ntt_forward_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3
|
||||||
|
return (MkFlatArray n fptr3)
|
||||||
|
|
||||||
|
-- | Post-multiplies the coefficients by powers of eta, effectively interpolating @f@ such that @f(eta^-1 * omega^k) = y_k@
|
||||||
|
{-# NOINLINE shiftedInverseNTT #-}
|
||||||
|
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
|
||||||
|
shiftedInverseNTT sg (MkGoldilocks eta) (MkFlatArray n fptr2)
|
||||||
|
| subgroupSize sg /= n = error "shiftedInverseNTT: subgroup size differs from the array size"
|
||||||
|
| otherwise = unsafePerformIO $ do
|
||||||
|
fptr3 <- mallocForeignPtrArray n -- (n*4)
|
||||||
|
withForeignPtr fptr2 $ \ptr2 -> do
|
||||||
|
withForeignPtr fptr3 $ \ptr3 -> do
|
||||||
|
c_ntt_inverse_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3
|
||||||
|
return (MkPoly (MkFlatArray n fptr3))
|
||||||
|
|
||||||
|
-- | Evaluates a polynomial f on a larger subgroup than it's defined on
|
||||||
|
{-# NOINLINE asymmForwardNTT #-}
|
||||||
|
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
|
||||||
|
asymmForwardNTT sg_src (MkPoly (MkFlatArray n fptr2)) sg_tgt
|
||||||
|
| subgroupSize sg_src /= n = error "asymmForwardNTT: subgroup size differs from the array size"
|
||||||
|
| m < n = error "asymmForwardNTT: target subgroup size should be at least the source subgroup src"
|
||||||
|
| otherwise = unsafePerformIO $ do
|
||||||
|
fptr3 <- mallocForeignPtrArray m -- (m*4)
|
||||||
|
let MkGoldilocks sgen1 = subgroupGen sg_src
|
||||||
|
let MkGoldilocks sgen2 = subgroupGen sg_tgt
|
||||||
|
withForeignPtr fptr2 $ \ptr2 -> do
|
||||||
|
withForeignPtr fptr3 $ \ptr3 -> do
|
||||||
|
c_ntt_forward_asymmetric
|
||||||
|
(subgroupCLogSize sg_src)
|
||||||
|
(subgroupCLogSize sg_tgt)
|
||||||
|
sgen1 sgen2 ptr2 ptr3
|
||||||
|
return (MkFlatArray m fptr3)
|
||||||
|
where
|
||||||
|
m = subgroupSize sg_tgt
|
||||||
|
|
||||||
|
{-
|
||||||
|
instance P.UnivariateFFT Poly where
|
||||||
|
ntt = forwardNTT
|
||||||
|
intt = inverseNTT
|
||||||
|
shiftedNTT = shiftedForwardNTT
|
||||||
|
shiftedINTT = shiftedInverseNTT
|
||||||
|
asymmNTT = asymmForwardNTT
|
||||||
|
-}
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
subgroupCLogSize :: Subgroup a -> CInt
|
||||||
|
subgroupCLogSize = fromIntegral . fromLog2 . subgroupLogSize
|
||||||
|
|
||||||
|
subgroupGenAsWord64 :: Subgroup F -> Word64
|
||||||
|
subgroupGenAsWord64 sg = let MkGoldilocks x = subgroupGen sg in x
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
@ -1,13 +1,13 @@
|
|||||||
|
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
module NTT.Slow where
|
module NTT.FFT.Slow where
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
import Data.Array
|
import Data.Array
|
||||||
import Data.Bits
|
import Data.Bits
|
||||||
|
|
||||||
import NTT.Poly
|
import NTT.Poly.Naive
|
||||||
import NTT.Subgroup
|
import NTT.Subgroup
|
||||||
|
|
||||||
import Field.Goldilocks
|
import Field.Goldilocks
|
||||||
@ -1,227 +1,13 @@
|
|||||||
|
{-# LANGUAGE CPP #-}
|
||||||
|
|
||||||
-- | Dense univariate polynomials
|
#ifdef USE_NAIVE_HASKELL
|
||||||
|
|
||||||
{-# LANGUAGE StrictData, BangPatterns, ScopedTypeVariables, DeriveFunctor #-}
|
module NTT.Poly ( module NTT.Poly.Naive ) where
|
||||||
module NTT.Poly where
|
import NTT.Poly.Naive
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
#else
|
||||||
|
|
||||||
import Data.List
|
module NTT.Poly ( module NTT.Poly.Flat ) where
|
||||||
import Data.Array
|
import NTT.Poly.Flat
|
||||||
import Data.Array.ST (STArray)
|
|
||||||
import Data.Array.MArray (newArray, readArray, writeArray, thaw, freeze)
|
|
||||||
|
|
||||||
import Control.Monad
|
#endif
|
||||||
import Control.Monad.ST.Strict
|
|
||||||
|
|
||||||
import System.Random
|
|
||||||
|
|
||||||
import Data.Binary
|
|
||||||
|
|
||||||
import Field.Goldilocks
|
|
||||||
import Field.Goldilocks.Extension ( FExt )
|
|
||||||
import Field.Encode
|
|
||||||
|
|
||||||
import Misc
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
-- * Univariate polynomials
|
|
||||||
|
|
||||||
-- | A dense univariate polynomial. The array index corresponds to the exponent.
|
|
||||||
newtype Poly a
|
|
||||||
= Poly (Array Int a)
|
|
||||||
deriving (Show,Functor)
|
|
||||||
|
|
||||||
instance Binary a => Binary (Poly a) where
|
|
||||||
put (Poly arr) = putSmallArray arr
|
|
||||||
get = Poly <$> getSmallArray
|
|
||||||
|
|
||||||
instance (Num a, Eq a) => Eq (Poly a) where
|
|
||||||
p == q = polyIsZero (polySub p q)
|
|
||||||
|
|
||||||
instance FieldEncode (Poly F) where
|
|
||||||
fieldEncode (Poly arr) = fieldEncode arr
|
|
||||||
|
|
||||||
instance FieldEncode (Poly FExt) where
|
|
||||||
fieldEncode (Poly arr) = fieldEncode arr
|
|
||||||
|
|
||||||
mkPoly :: [a] -> Poly a
|
|
||||||
mkPoly coeffs = Poly $ listArray (0,length coeffs-1) coeffs
|
|
||||||
|
|
||||||
-- | Degree of the polynomial
|
|
||||||
polyDegree :: (Eq a, Num a) => Poly a -> Int
|
|
||||||
polyDegree (Poly arr) = worker d0 where
|
|
||||||
(0,d0) = bounds arr
|
|
||||||
worker d
|
|
||||||
| d < 0 = -1
|
|
||||||
| arr!d /= 0 = d
|
|
||||||
| otherwise = worker (d-1)
|
|
||||||
|
|
||||||
-- | Size of the polynomial (can be larger than @degree + 1@, if the some top coefficients are zeros)
|
|
||||||
polySize :: (Eq a, Num a) => Poly a -> Int
|
|
||||||
polySize (Poly p) = arraySize p
|
|
||||||
|
|
||||||
polyIsZero :: (Eq a, Num a) => Poly a -> Bool
|
|
||||||
polyIsZero (Poly arr) = all (==0) (elems arr)
|
|
||||||
|
|
||||||
-- | Returns the coefficient of @x^k@
|
|
||||||
polyCoeff :: Num a => Poly a -> Int -> a
|
|
||||||
polyCoeff (Poly coeffs) k = safeIndex 0 coeffs k
|
|
||||||
|
|
||||||
-- | Note: this can include zero coeffs at higher than the actual degree!
|
|
||||||
polyCoeffArray :: Poly a -> Array Int a
|
|
||||||
polyCoeffArray (Poly coeffs) = coeffs
|
|
||||||
|
|
||||||
-- | Note: this cuts off the potential extra zeros at the end.
|
|
||||||
-- The order is little-endian (constant term first).
|
|
||||||
polyCoeffList :: (Eq a, Num a) => Poly a -> [a]
|
|
||||||
polyCoeffList poly@(Poly arr) = take (polyDegree poly + 1) (elems arr)
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
-- * Elementary polynomials
|
|
||||||
|
|
||||||
-- | Constant polynomial
|
|
||||||
polyConst :: a -> Poly a
|
|
||||||
polyConst x = Poly $ listArray (0,0) [x]
|
|
||||||
|
|
||||||
-- | Zero polynomial
|
|
||||||
polyZero :: Num a => Poly a
|
|
||||||
polyZero = polyConst 0
|
|
||||||
|
|
||||||
-- | The polynomial @f(x) = x@
|
|
||||||
polyVarX :: Num a => Poly a
|
|
||||||
polyVarX = mkPoly [0,1]
|
|
||||||
|
|
||||||
-- | @polyLinear (A,B)@ means the linear polynomial @f(x) = A*x + B@
|
|
||||||
polyLinear :: (a,a) -> Poly a
|
|
||||||
polyLinear (a,b) = mkPoly [b,a]
|
|
||||||
|
|
||||||
-- | The monomial @x^n@
|
|
||||||
polyXpowN :: Num a => Int -> Poly a
|
|
||||||
polyXpowN n = Poly $ listArray (0,n) (replicate n 0 ++ [1])
|
|
||||||
|
|
||||||
-- | The binomial @(x^n - 1)@
|
|
||||||
polyXpowNminus1 :: Num a => Int -> Poly a
|
|
||||||
polyXpowNminus1 n = Poly $ listArray (0,n) (-1 : replicate (n-1) 0 ++ [1])
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
-- * Evaluate polynomials
|
|
||||||
|
|
||||||
polyEvalAt :: forall f. Fractional f => Poly f -> f -> f
|
|
||||||
polyEvalAt (Poly arr) x = go 0 1 0 where
|
|
||||||
(0,d) = bounds arr
|
|
||||||
go :: f -> f -> Int -> f
|
|
||||||
go !acc !y !i = if i > d
|
|
||||||
then acc
|
|
||||||
else go (acc + (arr!i)*y) (y*x) (i+1)
|
|
||||||
|
|
||||||
polyEvalOnList :: forall f. Fractional f => Poly f -> [f] -> [f]
|
|
||||||
polyEvalOnList poly = map (polyEvalAt poly)
|
|
||||||
|
|
||||||
polyEvalOnArray :: forall f. Fractional f => Poly f -> Array Int f -> Array Int f
|
|
||||||
polyEvalOnArray poly = fmap (polyEvalAt poly)
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
-- * Basic arithmetic operations on polynomials
|
|
||||||
|
|
||||||
polyNeg :: Num a => Poly a -> Poly a
|
|
||||||
polyNeg (Poly arr) = Poly $ fmap negate arr
|
|
||||||
|
|
||||||
polyAdd :: Num a => Poly a -> Poly a -> Poly a
|
|
||||||
polyAdd (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
|
||||||
(0,d1) = bounds arr1
|
|
||||||
(0,d2) = bounds arr2
|
|
||||||
d3 = max d1 d2
|
|
||||||
zs = zipWith (+) (elems arr1 ++ replicate (d3-d1) 0)
|
|
||||||
(elems arr2 ++ replicate (d3-d2) 0)
|
|
||||||
|
|
||||||
polySub :: Num a => Poly a -> Poly a -> Poly a
|
|
||||||
polySub (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
|
||||||
(0,d1) = bounds arr1
|
|
||||||
(0,d2) = bounds arr2
|
|
||||||
d3 = max d1 d2
|
|
||||||
zs = zipWith (-) (elems arr1 ++ replicate (d3-d1) 0)
|
|
||||||
(elems arr2 ++ replicate (d3-d2) 0)
|
|
||||||
|
|
||||||
polyMul :: Num a => Poly a -> Poly a -> Poly a
|
|
||||||
polyMul (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
|
||||||
(0,d1) = bounds arr1
|
|
||||||
(0,d2) = bounds arr2
|
|
||||||
d3 = d1 + d2
|
|
||||||
zs = [ f k | k<-[0..d3] ]
|
|
||||||
f !k = foldl' (+) 0 [ arr1!i * arr2!(k-i) | i<-[ max 0 (k-d2) .. min d1 k ] ]
|
|
||||||
|
|
||||||
instance Num a => Num (Poly a) where
|
|
||||||
fromInteger = polyConst . fromInteger
|
|
||||||
negate = polyNeg
|
|
||||||
(+) = polyAdd
|
|
||||||
(-) = polySub
|
|
||||||
(*) = polyMul
|
|
||||||
abs = id
|
|
||||||
signum = \_ -> polyConst 1
|
|
||||||
|
|
||||||
polySum :: Num a => [Poly a] -> Poly a
|
|
||||||
polySum = foldl' polyAdd 0
|
|
||||||
|
|
||||||
polyProd :: Num a => [Poly a] -> Poly a
|
|
||||||
polyProd = foldl' polyMul 1
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
-- * Polynomial long division
|
|
||||||
|
|
||||||
-- | @polyDiv f h@ returns @(q,r)@ such that @f = q*h + r@ and @deg r < deg h@
|
|
||||||
polyDiv :: forall f. (Eq f, Fractional f) => Poly f -> Poly f -> (Poly f, Poly f)
|
|
||||||
polyDiv poly_f@(Poly arr_f) poly_h@(Poly arr_h)
|
|
||||||
| deg_q < 0 = (polyZero, poly_f)
|
|
||||||
| otherwise = runST action
|
|
||||||
where
|
|
||||||
deg_f = polyDegree poly_f
|
|
||||||
deg_h = polyDegree poly_h
|
|
||||||
deg_q = deg_f - deg_h
|
|
||||||
|
|
||||||
-- inverse of the top coefficient of divisor
|
|
||||||
b_inv = recip (arr_h ! deg_h)
|
|
||||||
|
|
||||||
action :: forall s. ST s (Poly f, Poly f)
|
|
||||||
action = do
|
|
||||||
p <- thaw arr_f :: ST s (STArray s Int f)
|
|
||||||
q <- newArray (0,deg_q) 0 :: ST s (STArray s Int f)
|
|
||||||
forM_ [deg_q,deg_q-1..0] $ \k -> do
|
|
||||||
top <- readArray p (deg_h + k)
|
|
||||||
let y = b_inv * top
|
|
||||||
writeArray q k y
|
|
||||||
forM_ [0..deg_h] $ \j -> do
|
|
||||||
a <- readArray p (j+k)
|
|
||||||
writeArray p (j+k) (a - y*(arr_h!j))
|
|
||||||
qarr <- freeze q
|
|
||||||
rs <- forM [0..deg_h-1] $ \i -> readArray p i
|
|
||||||
let rarr = listArray (0,deg_h-1) rs
|
|
||||||
return (Poly qarr, Poly rarr)
|
|
||||||
|
|
||||||
-- | Returns only the quotient
|
|
||||||
polyDivQuo :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
|
|
||||||
polyDivQuo f g = fst $ polyDiv f g
|
|
||||||
|
|
||||||
-- | Returns only the remainder
|
|
||||||
polyDivRem :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
|
|
||||||
polyDivRem f g = snd $ polyDiv f g
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
-- * Sample random polynomials
|
|
||||||
|
|
||||||
randomPoly :: (RandomGen g, Random a) => Int -> g -> (Poly a, g)
|
|
||||||
randomPoly deg g0 =
|
|
||||||
let (coeffs,gfinal) = worker (deg+1) g0
|
|
||||||
poly = Poly (listArray (0,deg) coeffs)
|
|
||||||
in (poly, gfinal)
|
|
||||||
|
|
||||||
where
|
|
||||||
worker 0 g = ([] , g)
|
|
||||||
worker n g = let (x ,g1) = random g
|
|
||||||
(xs,g2) = worker (n-1) g1
|
|
||||||
in ((x:xs) , g)
|
|
||||||
|
|
||||||
randomPolyIO :: Random a => Int -> IO (Poly a)
|
|
||||||
randomPolyIO deg = getStdRandom (randomPoly deg)
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
|
|||||||
55
reference/src/NTT/Poly/Flat.hs
Normal file
55
reference/src/NTT/Poly/Flat.hs
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
|
||||||
|
{-# LANGUAGE StrictData, ScopedTypeVariables, PatternSynonyms #-}
|
||||||
|
module NTT.Poly.Flat where
|
||||||
|
|
||||||
|
import Prelude hiding (div,quot,rem)
|
||||||
|
import GHC.Real hiding (div,quot,rem)
|
||||||
|
|
||||||
|
import Data.Bits
|
||||||
|
import Data.Word
|
||||||
|
import Data.List
|
||||||
|
import Data.Array
|
||||||
|
|
||||||
|
import Control.Monad
|
||||||
|
|
||||||
|
import Foreign.C
|
||||||
|
import Foreign.Ptr
|
||||||
|
import Foreign.Marshal
|
||||||
|
import Foreign.ForeignPtr
|
||||||
|
|
||||||
|
import System.Random
|
||||||
|
import System.IO.Unsafe
|
||||||
|
|
||||||
|
import NTT.Subgroup
|
||||||
|
import Field.Class
|
||||||
|
import Data.Flat as L
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
newtype Poly a = MkPoly (L.FlatArray a)
|
||||||
|
|
||||||
|
pattern XPoly n arr = MkPoly (L.MkFlatArray n arr)
|
||||||
|
|
||||||
|
mkPoly :: Flat f => [f] -> Poly f
|
||||||
|
mkPoly = MkPoly . L.packFlatArrayFromList
|
||||||
|
|
||||||
|
mkPoly' :: Flat f => Int -> [f] -> Poly f
|
||||||
|
mkPoly' len xs = MkPoly $ L.packFlatArrayFromList' len xs
|
||||||
|
|
||||||
|
mkPolyArr :: Flat f => Array Int f -> Poly f
|
||||||
|
mkPolyArr = MkPoly . L.packFlatArray
|
||||||
|
|
||||||
|
mkPolyFlatArr :: L.FlatArray f -> Poly f
|
||||||
|
mkPolyFlatArr = MkPoly
|
||||||
|
|
||||||
|
coeffs :: Flat f => Poly f -> [f]
|
||||||
|
coeffs (MkPoly arr) = L.unpackFlatArrayToList arr
|
||||||
|
|
||||||
|
coeffsArr :: Flat f => Poly f -> Array Int f
|
||||||
|
coeffsArr (MkPoly arr) = L.unpackFlatArray arr
|
||||||
|
|
||||||
|
coeffsFlatArr :: Poly f -> L.FlatArray f
|
||||||
|
coeffsFlatArr (MkPoly flat) = flat
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
227
reference/src/NTT/Poly/Naive.hs
Normal file
227
reference/src/NTT/Poly/Naive.hs
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
|
||||||
|
-- | Dense univariate polynomials
|
||||||
|
|
||||||
|
{-# LANGUAGE StrictData, BangPatterns, ScopedTypeVariables, DeriveFunctor #-}
|
||||||
|
module NTT.Poly.Naive where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.List
|
||||||
|
import Data.Array
|
||||||
|
import Data.Array.ST (STArray)
|
||||||
|
import Data.Array.MArray (newArray, readArray, writeArray, thaw, freeze)
|
||||||
|
|
||||||
|
import Control.Monad
|
||||||
|
import Control.Monad.ST.Strict
|
||||||
|
|
||||||
|
import System.Random
|
||||||
|
|
||||||
|
import Data.Binary
|
||||||
|
|
||||||
|
import Field.Goldilocks
|
||||||
|
import Field.Goldilocks.Extension ( FExt )
|
||||||
|
import Field.Encode
|
||||||
|
|
||||||
|
import Misc
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * Univariate polynomials
|
||||||
|
|
||||||
|
-- | A dense univariate polynomial. The array index corresponds to the exponent.
|
||||||
|
newtype Poly a
|
||||||
|
= Poly (Array Int a)
|
||||||
|
deriving (Show,Functor)
|
||||||
|
|
||||||
|
instance Binary a => Binary (Poly a) where
|
||||||
|
put (Poly arr) = putSmallArray arr
|
||||||
|
get = Poly <$> getSmallArray
|
||||||
|
|
||||||
|
instance (Num a, Eq a) => Eq (Poly a) where
|
||||||
|
p == q = polyIsZero (polySub p q)
|
||||||
|
|
||||||
|
instance FieldEncode (Poly F) where
|
||||||
|
fieldEncode (Poly arr) = fieldEncode arr
|
||||||
|
|
||||||
|
instance FieldEncode (Poly FExt) where
|
||||||
|
fieldEncode (Poly arr) = fieldEncode arr
|
||||||
|
|
||||||
|
mkPoly :: [a] -> Poly a
|
||||||
|
mkPoly coeffs = Poly $ listArray (0,length coeffs-1) coeffs
|
||||||
|
|
||||||
|
-- | Degree of the polynomial
|
||||||
|
polyDegree :: (Eq a, Num a) => Poly a -> Int
|
||||||
|
polyDegree (Poly arr) = worker d0 where
|
||||||
|
(0,d0) = bounds arr
|
||||||
|
worker d
|
||||||
|
| d < 0 = -1
|
||||||
|
| arr!d /= 0 = d
|
||||||
|
| otherwise = worker (d-1)
|
||||||
|
|
||||||
|
-- | Size of the polynomial (can be larger than @degree + 1@, if the some top coefficients are zeros)
|
||||||
|
polySize :: (Eq a, Num a) => Poly a -> Int
|
||||||
|
polySize (Poly p) = arraySize p
|
||||||
|
|
||||||
|
polyIsZero :: (Eq a, Num a) => Poly a -> Bool
|
||||||
|
polyIsZero (Poly arr) = all (==0) (elems arr)
|
||||||
|
|
||||||
|
-- | Returns the coefficient of @x^k@
|
||||||
|
polyCoeff :: Num a => Poly a -> Int -> a
|
||||||
|
polyCoeff (Poly coeffs) k = safeIndex 0 coeffs k
|
||||||
|
|
||||||
|
-- | Note: this can include zero coeffs at higher than the actual degree!
|
||||||
|
polyCoeffArray :: Poly a -> Array Int a
|
||||||
|
polyCoeffArray (Poly coeffs) = coeffs
|
||||||
|
|
||||||
|
-- | Note: this cuts off the potential extra zeros at the end.
|
||||||
|
-- The order is little-endian (constant term first).
|
||||||
|
polyCoeffList :: (Eq a, Num a) => Poly a -> [a]
|
||||||
|
polyCoeffList poly@(Poly arr) = take (polyDegree poly + 1) (elems arr)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * Elementary polynomials
|
||||||
|
|
||||||
|
-- | Constant polynomial
|
||||||
|
polyConst :: a -> Poly a
|
||||||
|
polyConst x = Poly $ listArray (0,0) [x]
|
||||||
|
|
||||||
|
-- | Zero polynomial
|
||||||
|
polyZero :: Num a => Poly a
|
||||||
|
polyZero = polyConst 0
|
||||||
|
|
||||||
|
-- | The polynomial @f(x) = x@
|
||||||
|
polyVarX :: Num a => Poly a
|
||||||
|
polyVarX = mkPoly [0,1]
|
||||||
|
|
||||||
|
-- | @polyLinear (A,B)@ means the linear polynomial @f(x) = A*x + B@
|
||||||
|
polyLinear :: (a,a) -> Poly a
|
||||||
|
polyLinear (a,b) = mkPoly [b,a]
|
||||||
|
|
||||||
|
-- | The monomial @x^n@
|
||||||
|
polyXpowN :: Num a => Int -> Poly a
|
||||||
|
polyXpowN n = Poly $ listArray (0,n) (replicate n 0 ++ [1])
|
||||||
|
|
||||||
|
-- | The binomial @(x^n - 1)@
|
||||||
|
polyXpowNminus1 :: Num a => Int -> Poly a
|
||||||
|
polyXpowNminus1 n = Poly $ listArray (0,n) (-1 : replicate (n-1) 0 ++ [1])
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * Evaluate polynomials
|
||||||
|
|
||||||
|
polyEvalAt :: forall f. Fractional f => Poly f -> f -> f
|
||||||
|
polyEvalAt (Poly arr) x = go 0 1 0 where
|
||||||
|
(0,d) = bounds arr
|
||||||
|
go :: f -> f -> Int -> f
|
||||||
|
go !acc !y !i = if i > d
|
||||||
|
then acc
|
||||||
|
else go (acc + (arr!i)*y) (y*x) (i+1)
|
||||||
|
|
||||||
|
polyEvalOnList :: forall f. Fractional f => Poly f -> [f] -> [f]
|
||||||
|
polyEvalOnList poly = map (polyEvalAt poly)
|
||||||
|
|
||||||
|
polyEvalOnArray :: forall f. Fractional f => Poly f -> Array Int f -> Array Int f
|
||||||
|
polyEvalOnArray poly = fmap (polyEvalAt poly)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * Basic arithmetic operations on polynomials
|
||||||
|
|
||||||
|
polyNeg :: Num a => Poly a -> Poly a
|
||||||
|
polyNeg (Poly arr) = Poly $ fmap negate arr
|
||||||
|
|
||||||
|
polyAdd :: Num a => Poly a -> Poly a -> Poly a
|
||||||
|
polyAdd (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||||
|
(0,d1) = bounds arr1
|
||||||
|
(0,d2) = bounds arr2
|
||||||
|
d3 = max d1 d2
|
||||||
|
zs = zipWith (+) (elems arr1 ++ replicate (d3-d1) 0)
|
||||||
|
(elems arr2 ++ replicate (d3-d2) 0)
|
||||||
|
|
||||||
|
polySub :: Num a => Poly a -> Poly a -> Poly a
|
||||||
|
polySub (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||||
|
(0,d1) = bounds arr1
|
||||||
|
(0,d2) = bounds arr2
|
||||||
|
d3 = max d1 d2
|
||||||
|
zs = zipWith (-) (elems arr1 ++ replicate (d3-d1) 0)
|
||||||
|
(elems arr2 ++ replicate (d3-d2) 0)
|
||||||
|
|
||||||
|
polyMul :: Num a => Poly a -> Poly a -> Poly a
|
||||||
|
polyMul (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where
|
||||||
|
(0,d1) = bounds arr1
|
||||||
|
(0,d2) = bounds arr2
|
||||||
|
d3 = d1 + d2
|
||||||
|
zs = [ f k | k<-[0..d3] ]
|
||||||
|
f !k = foldl' (+) 0 [ arr1!i * arr2!(k-i) | i<-[ max 0 (k-d2) .. min d1 k ] ]
|
||||||
|
|
||||||
|
instance Num a => Num (Poly a) where
|
||||||
|
fromInteger = polyConst . fromInteger
|
||||||
|
negate = polyNeg
|
||||||
|
(+) = polyAdd
|
||||||
|
(-) = polySub
|
||||||
|
(*) = polyMul
|
||||||
|
abs = id
|
||||||
|
signum = \_ -> polyConst 1
|
||||||
|
|
||||||
|
polySum :: Num a => [Poly a] -> Poly a
|
||||||
|
polySum = foldl' polyAdd 0
|
||||||
|
|
||||||
|
polyProd :: Num a => [Poly a] -> Poly a
|
||||||
|
polyProd = foldl' polyMul 1
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * Polynomial long division
|
||||||
|
|
||||||
|
-- | @polyDiv f h@ returns @(q,r)@ such that @f = q*h + r@ and @deg r < deg h@
|
||||||
|
polyDiv :: forall f. (Eq f, Fractional f) => Poly f -> Poly f -> (Poly f, Poly f)
|
||||||
|
polyDiv poly_f@(Poly arr_f) poly_h@(Poly arr_h)
|
||||||
|
| deg_q < 0 = (polyZero, poly_f)
|
||||||
|
| otherwise = runST action
|
||||||
|
where
|
||||||
|
deg_f = polyDegree poly_f
|
||||||
|
deg_h = polyDegree poly_h
|
||||||
|
deg_q = deg_f - deg_h
|
||||||
|
|
||||||
|
-- inverse of the top coefficient of divisor
|
||||||
|
b_inv = recip (arr_h ! deg_h)
|
||||||
|
|
||||||
|
action :: forall s. ST s (Poly f, Poly f)
|
||||||
|
action = do
|
||||||
|
p <- thaw arr_f :: ST s (STArray s Int f)
|
||||||
|
q <- newArray (0,deg_q) 0 :: ST s (STArray s Int f)
|
||||||
|
forM_ [deg_q,deg_q-1..0] $ \k -> do
|
||||||
|
top <- readArray p (deg_h + k)
|
||||||
|
let y = b_inv * top
|
||||||
|
writeArray q k y
|
||||||
|
forM_ [0..deg_h] $ \j -> do
|
||||||
|
a <- readArray p (j+k)
|
||||||
|
writeArray p (j+k) (a - y*(arr_h!j))
|
||||||
|
qarr <- freeze q
|
||||||
|
rs <- forM [0..deg_h-1] $ \i -> readArray p i
|
||||||
|
let rarr = listArray (0,deg_h-1) rs
|
||||||
|
return (Poly qarr, Poly rarr)
|
||||||
|
|
||||||
|
-- | Returns only the quotient
|
||||||
|
polyDivQuo :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
|
||||||
|
polyDivQuo f g = fst $ polyDiv f g
|
||||||
|
|
||||||
|
-- | Returns only the remainder
|
||||||
|
polyDivRem :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f
|
||||||
|
polyDivRem f g = snd $ polyDiv f g
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * Sample random polynomials
|
||||||
|
|
||||||
|
randomPoly :: (RandomGen g, Random a) => Int -> g -> (Poly a, g)
|
||||||
|
randomPoly deg g0 =
|
||||||
|
let (coeffs,gfinal) = worker (deg+1) g0
|
||||||
|
poly = Poly (listArray (0,deg) coeffs)
|
||||||
|
in (poly, gfinal)
|
||||||
|
|
||||||
|
where
|
||||||
|
worker 0 g = ([] , g)
|
||||||
|
worker n g = let (x ,g1) = random g
|
||||||
|
(xs,g2) = worker (n-1) g1
|
||||||
|
in ((x:xs) , g)
|
||||||
|
|
||||||
|
randomPolyIO :: Random a => Int -> IO (Poly a)
|
||||||
|
randomPolyIO deg = getStdRandom (randomPoly deg)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
@ -20,6 +20,9 @@ data Subgroup g = MkSubgroup
|
|||||||
subgroupSize :: Subgroup g -> Int
|
subgroupSize :: Subgroup g -> Int
|
||||||
subgroupSize = subgroupOrder
|
subgroupSize = subgroupOrder
|
||||||
|
|
||||||
|
subgroupLogSize :: Subgroup g -> Log2
|
||||||
|
subgroupLogSize = exactLog2__ . subgroupSize
|
||||||
|
|
||||||
getSubgroup :: Log2 -> Subgroup F
|
getSubgroup :: Log2 -> Subgroup F
|
||||||
getSubgroup log2@(Log2 n)
|
getSubgroup log2@(Log2 n)
|
||||||
| n<0 = error "getSubgroup: negative logarithm"
|
| n<0 = error "getSubgroup: negative logarithm"
|
||||||
|
|||||||
55
reference/src/NTT/Tests.hs
Normal file
55
reference/src/NTT/Tests.hs
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
|
||||||
|
module NTT.Tests where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Array
|
||||||
|
|
||||||
|
import Control.Monad
|
||||||
|
|
||||||
|
import Field.Class
|
||||||
|
import Field.Goldilocks ( F )
|
||||||
|
import Misc
|
||||||
|
|
||||||
|
import Data.Flat as L
|
||||||
|
import NTT.Subgroup
|
||||||
|
|
||||||
|
import qualified NTT.Poly.Flat as F
|
||||||
|
import qualified NTT.FFT.Fast as F
|
||||||
|
|
||||||
|
import qualified NTT.Poly.Naive as S
|
||||||
|
import qualified NTT.FFT.Slow as S
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
-- | Compare slow and fast FFT implementations to each other
|
||||||
|
compareFFTs :: Log2 -> IO Bool
|
||||||
|
compareFFTs logSize = do
|
||||||
|
let size = exp2_ logSize
|
||||||
|
let sg = getSubgroup logSize
|
||||||
|
|
||||||
|
values1_s <- listToArray <$> replicateM size rndIO :: IO (Array Int F)
|
||||||
|
let values1_f = L.packFlatArray values1_s :: L.FlatArray F
|
||||||
|
let poly1_s = S.Poly values1_s :: S.Poly F
|
||||||
|
let poly1_f = F.mkPolyArr values1_s :: F.Poly F
|
||||||
|
|
||||||
|
let values2_s = S.subgroupNTT sg poly1_s :: Array Int F
|
||||||
|
let values2_f = F.forwardNTT sg poly1_f :: L.FlatArray F
|
||||||
|
|
||||||
|
let poly2_s = S.subgroupINTT sg values1_s :: S.Poly F
|
||||||
|
let poly2_f = F.inverseNTT sg values1_f :: F.Poly F
|
||||||
|
|
||||||
|
let ok_ntt = values2_s == L.unpackFlatArray values2_f
|
||||||
|
let ok_intt = S.polyCoeffArray poly2_s == F.coeffsArr poly2_f
|
||||||
|
|
||||||
|
return (ok_ntt && ok_intt)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
runTests :: Int -> Log2 -> IO Bool
|
||||||
|
runTests n size = do
|
||||||
|
oks <- replicateM n (compareFFTs size)
|
||||||
|
return (and oks)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
228
reference/src/cbits/ntt.c
Normal file
228
reference/src/cbits/ntt.c
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "goldilocks.h"
|
||||||
|
#include "ntt.h"
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
void goldilocks_ntt_forward_noalloc(int m, int src_stride, const uint64_t *gpows, const uint64_t *src, uint64_t *buf, uint64_t *tgt) {
|
||||||
|
|
||||||
|
if (m==0) {
|
||||||
|
tgt[0] = src[0];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (m==1) {
|
||||||
|
// N = 2
|
||||||
|
tgt[0] = goldilocks_add( src[0] , src[src_stride] ); // x + y
|
||||||
|
tgt[1] = goldilocks_sub( src[0] , src[src_stride] ); // x - y
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
else {
|
||||||
|
int N = (1<< m );
|
||||||
|
int halfN = (1<<(m-1));
|
||||||
|
|
||||||
|
goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src , buf + N , buf );
|
||||||
|
goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src + src_stride , buf + N , buf + halfN );
|
||||||
|
|
||||||
|
for(int j=0; j<halfN; j++) {
|
||||||
|
const uint64_t gpow = gpows[j*src_stride];
|
||||||
|
tgt[j ] = goldilocks_mul( buf[j+halfN] , gpow ); // g*v[k]
|
||||||
|
tgt[j+halfN] = goldilocks_neg( tgt[j ] ); // - g*v[k]
|
||||||
|
tgt[j ] = goldilocks_add( tgt[j ] , buf[j] ); // u[k] + g*v[k]
|
||||||
|
tgt[j+halfN] = goldilocks_add( tgt[j+halfN] , buf[j] ); // u[k] - g*v[k]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// forward number-theoretical transform (evaluation of a polynomial)
|
||||||
|
// `src` and `tgt` should be `N = 2^m` sized arrays of field elements
|
||||||
|
// `gen` should be the generator of the multiplicative subgroup sized `N`
|
||||||
|
void goldilocks_ntt_forward(int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
|
||||||
|
int N = (1<<m);
|
||||||
|
int halfN = (N>>1);
|
||||||
|
|
||||||
|
// precalculate [1,g,g^2,g^3...]
|
||||||
|
uint64_t *gpows = (uint64_t*) malloc( 8 * halfN );
|
||||||
|
assert( gpows != 0 );
|
||||||
|
uint64_t x = gen;
|
||||||
|
gpows[0] = 1;
|
||||||
|
gpows[1] = gen;
|
||||||
|
for(int i=2; i<halfN; i++) {
|
||||||
|
x = goldilocks_mul( x , gen );
|
||||||
|
gpows[i] = x;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t *buf = (uint64_t*) malloc( 8 * (2*N) );
|
||||||
|
assert( buf != 0 );
|
||||||
|
goldilocks_ntt_forward_noalloc( m, 1, gpows, src, buf, tgt);
|
||||||
|
free(buf);
|
||||||
|
free(gpows);
|
||||||
|
}
|
||||||
|
|
||||||
|
// it's like `ntt_forward` but we pre-multiply the coefficients with `eta^k`
|
||||||
|
// resulting in evaluating f(eta*x) instead of f(x)
|
||||||
|
void goldilocks_ntt_forward_shifted(const uint64_t eta, int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
|
||||||
|
int N = (1<<m);
|
||||||
|
uint64_t *shifted = malloc( 8 * N );
|
||||||
|
assert( shifted != 0 );
|
||||||
|
uint64_t x = 1;
|
||||||
|
for(int i=0; i<N; i++) {
|
||||||
|
shifted[i] = goldilocks_mul( src[i] , x );
|
||||||
|
x = goldilocks_mul( x , eta );
|
||||||
|
}
|
||||||
|
goldilocks_ntt_forward( m, gen, shifted, tgt );
|
||||||
|
free(shifted);
|
||||||
|
}
|
||||||
|
|
||||||
|
// it's like `ntt_forward` but asymmetric, evaluating on a larger target subgroup
|
||||||
|
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, const uint64_t gen_src, const uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt) {
|
||||||
|
assert( m_tgt >= m_src );
|
||||||
|
int N_src = (1 << m_src);
|
||||||
|
int N_tgt = (1 << m_tgt);
|
||||||
|
int halfN_src = (N_src >> 1);
|
||||||
|
int K = (1 << (m_tgt - m_src));
|
||||||
|
|
||||||
|
// precalculate [1,g,g^2,g^3...]
|
||||||
|
uint64_t *gpows = malloc( 8 * halfN_src );
|
||||||
|
assert( gpows != 0 );
|
||||||
|
uint64_t x = gen_src;
|
||||||
|
gpows[0] = 1;
|
||||||
|
gpows[1] = gen_src;
|
||||||
|
for(int i=2; i<halfN_src; i++) {
|
||||||
|
x = goldilocks_mul(x, gen_src);
|
||||||
|
gpows[i] = x;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t *shifted = malloc( 8 * N_src );
|
||||||
|
assert( shifted != 0 );
|
||||||
|
|
||||||
|
uint64_t *buf = malloc( 8 * (2*N_src) );
|
||||||
|
assert( buf != 0 );
|
||||||
|
|
||||||
|
// temporary target buffer (we could replace this by adding `tgt_stride`)
|
||||||
|
uint64_t *tgt_small = malloc( 8 * N_src );
|
||||||
|
assert( tgt_small != 0 );
|
||||||
|
|
||||||
|
// eta will be the shift
|
||||||
|
uint64_t eta = 1;
|
||||||
|
|
||||||
|
for(int k=0; k<K; k++) {
|
||||||
|
if (k==0) {
|
||||||
|
memcpy( shifted, src, N_src*8 );
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
eta = goldilocks_mul( eta , gen_tgt );
|
||||||
|
|
||||||
|
uint64_t x = 1;
|
||||||
|
for(int i=0; i<N_src; i++) {
|
||||||
|
shifted[i] = goldilocks_mul( src[i] , x );
|
||||||
|
x = goldilocks_mul(x, eta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
goldilocks_ntt_forward_noalloc( m_src, 1, gpows, shifted, buf, tgt_small );
|
||||||
|
|
||||||
|
uint64_t *p = tgt_small;
|
||||||
|
uint64_t *q = tgt + k;
|
||||||
|
int tgt_stride = K;
|
||||||
|
for(int i=0; i<N_src; i++) {
|
||||||
|
q[i] = p[i];
|
||||||
|
p += 1;
|
||||||
|
q += tgt_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
free(tgt_small);
|
||||||
|
free(buf);
|
||||||
|
free(gpows);
|
||||||
|
free(shifted);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// inverse of 2 (which is is the same as `(p+1)/2`)
|
||||||
|
const uint64_t goldilocks_oneHalf = 0x7fffffff80000001ull;
|
||||||
|
|
||||||
|
void goldilocks_ntt_inverse_noalloc(int m, int tgt_stride, const uint64_t *gpows, const uint64_t *src, uint64_t *buf, uint64_t *tgt) {
|
||||||
|
|
||||||
|
if (m==0) {
|
||||||
|
tgt[0] = src[0];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (m==1) {
|
||||||
|
// N = 2
|
||||||
|
|
||||||
|
tgt[0 ] = goldilocks_add( src[0] , src[1] ); // x + y
|
||||||
|
tgt[tgt_stride] = goldilocks_sub( src[0] , src[1] ); // x - y
|
||||||
|
tgt[0 ] = goldilocks_div_by_2( tgt[0 ] ); // (x + y)/2
|
||||||
|
tgt[tgt_stride] = goldilocks_div_by_2( tgt[tgt_stride] ); // (x - y)/2
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
else {
|
||||||
|
int N = (1<< m );
|
||||||
|
int halfN = (1<<(m-1));
|
||||||
|
|
||||||
|
for(int j=0; j<halfN; j++) {
|
||||||
|
uint64_t gpow = gpows[j*tgt_stride];
|
||||||
|
buf[j ] = goldilocks_add( src[j] , src[j+halfN] ); // x + y
|
||||||
|
buf[j+halfN] = goldilocks_sub( src[j] , src[j+halfN] ); // x - y
|
||||||
|
buf[j ] = goldilocks_div_by_2( buf[j ] ); // (x + y) / 2
|
||||||
|
buf[j+halfN] = goldilocks_mul ( buf[j+halfN] , gpow ); // (x - y) / (2*g^k)
|
||||||
|
}
|
||||||
|
|
||||||
|
goldilocks_ntt_inverse_noalloc( m-1 , tgt_stride<<1 , gpows , buf , buf + N , tgt );
|
||||||
|
goldilocks_ntt_inverse_noalloc( m-1 , tgt_stride<<1 , gpows , buf + halfN , buf + N , tgt + tgt_stride );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inverse number-theoretical transform (interpolation of a polynomial)
|
||||||
|
// `src` and `tgt` should be `N = 2^m` sized arrays of field elements
|
||||||
|
// `gen` should be the generator of the multiplicative subgroup sized `N`
|
||||||
|
void goldilocks_ntt_inverse(int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
|
||||||
|
int N = (1<<m);
|
||||||
|
int halfN = (N>>1);
|
||||||
|
|
||||||
|
// precalculate [1/2,g^{-1}/2,g^{-2}/2,g^{-3}/2...]
|
||||||
|
uint64_t *gpows = malloc( 8 * halfN );
|
||||||
|
assert( gpows != 0 );
|
||||||
|
uint64_t x = goldilocks_oneHalf; // 1/2
|
||||||
|
uint64_t ginv = goldilocks_inv(gen); // gen^-1
|
||||||
|
for(int i=0; i<halfN; i++) {
|
||||||
|
gpows[i] = x;
|
||||||
|
x = goldilocks_mul(x, ginv);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t *buf = malloc( 8 * (2*N) );
|
||||||
|
assert( buf !=0 );
|
||||||
|
goldilocks_ntt_inverse_noalloc( m, 1, gpows, src, buf, tgt );
|
||||||
|
free(buf);
|
||||||
|
free(gpows);
|
||||||
|
}
|
||||||
|
|
||||||
|
// it's like `ntt_inverse` but we post-multiply the resulting coefficients with `eta^k`
|
||||||
|
// resulting in interpolating an f such that f(eta^-1 * omega^k) = y_k
|
||||||
|
void goldilocks_ntt_inverse_shifted(const uint64_t eta, int m, const uint64_t gen, const uint64_t *src, uint64_t *tgt) {
|
||||||
|
int N = (1<<m);
|
||||||
|
uint64_t *unshifted = malloc( 8*N );
|
||||||
|
assert( unshifted != 0 );
|
||||||
|
goldilocks_ntt_inverse( m, gen, src, unshifted );
|
||||||
|
uint64_t x = 1;
|
||||||
|
for(int i=0; i<N; i++) {
|
||||||
|
tgt[i] = goldilocks_mul( unshifted[i] , x );
|
||||||
|
x = goldilocks_mul( x , eta );
|
||||||
|
}
|
||||||
|
free(unshifted);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
14
reference/src/cbits/ntt.h
Normal file
14
reference/src/cbits/ntt.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
void goldilocks_ntt_forward ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||||
|
void goldilocks_ntt_forward_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||||
|
|
||||||
|
void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt);
|
||||||
|
|
||||||
|
void goldilocks_ntt_inverse ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||||
|
void goldilocks_ntt_inverse_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt);
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
Loading…
x
Reference in New Issue
Block a user