diff --git a/reference/src/FRI/Types.hs b/reference/src/FRI/Types.hs index b9076e5..eaed620 100644 --- a/reference/src/FRI/Types.hs +++ b/reference/src/FRI/Types.hs @@ -87,6 +87,9 @@ newtype ReductionStrategy = MkRedStrategy } deriving (Eq,Show) +reductionStrategyLength :: ReductionStrategy -> Int +reductionStrategyLength = length . fromReductionStrategy + instance Binary ReductionStrategy where put = putSmallList . fromReductionStrategy get = MkRedStrategy <$> getSmallList @@ -123,6 +126,18 @@ data FriConfig = MkFriConfig } deriving (Eq,Show) +friCommitPhaseVectorSizes :: FriConfig -> [Log2] +friCommitPhaseVectorSizes (MkFriConfig{..}) = result where + result = go (rsEncodedSize friRSConfig) (fromReductionStrategy friReductionStrategy) + go n [] = [] + go n (a:as) = n : go (n-a) as + +friCommitPhaseTreeSizes :: FriConfig -> [Log2] +friCommitPhaseTreeSizes (MkFriConfig{..}) = result where + result = go (rsEncodedSize friRSConfig) (fromReductionStrategy friReductionStrategy) + go n [] = [] + go n (a:as) = (n-a) : go (n-a) as + instance Binary FriConfig where put (MkFriConfig{..}) = do put friRSConfig @@ -252,3 +267,44 @@ instance Binary FriProof where <*> get -------------------------------------------------------------------------------- + +estimateFriProofSize :: FriConfig -> Int +estimateFriProofSize friConfig@(MkFriConfig{..}) = total where + + total = friCfgSize + commitPhaseCaps + finalPolySize + (friNQueryRounds * queryRoundSize) + 8 + + arities = fromReductionStrategy friReductionStrategy + nsteps = length arities + + rsCfgSize = 1 + 1 + 8 + redStratSize = 1 + nsteps + friCfgSize = rsCfgSize + 8 + 1 + redStratSize + 8 + 1 + + (phases2,_) = friCommitPhaseSizesLog2 friConfig + + merkleCapSize = 32 * exp2_ friMerkleCapSize + commitPhaseCaps = merkleCapSize * nsteps + finalPolySize = 8 * exp2_ (rsDataSize - sum arities) + + queryStepSize arity = 16 * exp2_ arity + + +data FriQueryStep = MkFriQueryStep + { queryEvals :: [FExt] + , queryMerklePath :: RawMerklePath + } + deriving (Eq,Show) + +data FriQueryRound = MkFriQueryRound + { queryRow :: FRow + , queryInitialTreeProof :: RawMerklePath + , querySteps :: [FriQueryStep] + } + deriving (Eq,Show) + +data FriProof = MkFriProof + { proofFriConfig :: FriConfig -- ^ the FRI configuration + , proofCommitPhaseCaps :: [MerkleCap] -- ^ commit phase Merkle caps + , proofFinalPoly :: Poly FExt -- ^ the final polynomial in coefficient form + , proofQueryRounds :: [FriQueryRound] -- ^ query rounds + , proofPowWitness :: F -- ^ witness showing that the prover did PoW + } diff --git a/reference/src/Field/Class.hs b/reference/src/Field/Class.hs index 01bce31..bade3a7 100644 --- a/reference/src/Field/Class.hs +++ b/reference/src/Field/Class.hs @@ -1,8 +1,10 @@ +{-# LANGUAGE TypeFamilies #-} module Field.Class where -------------------------------------------------------------------------------- +import Data.Kind import Data.Proxy import System.Random @@ -26,6 +28,15 @@ class (Show a, Eq a, Num a, Fractional a) => Field a where inverse :: Field a => a -> a inverse = recip +-- | Quadratic extensions +class (Field (BaseField ext), Field ext) => QuadraticExt ext where + type BaseField ext :: Type + inject :: BaseField ext -> ext + project :: ext -> Maybe (BaseField ext) + scale :: BaseField ext -> ext -> ext + quadraticPack :: (BaseField ext, BaseField ext) -> ext + quadraticUnpack :: ext -> (BaseField ext, BaseField ext) + -------------------------------------------------------------------------------- instance Field Goldi.F where @@ -53,3 +64,14 @@ instance Field GoldiExt.FExt where rndIO = randomIO -------------------------------------------------------------------------------- + +instance QuadraticExt GoldiExt.FExt where + type BaseField GoldiExt.FExt = Goldi.F + + inject = GoldiExt.inj + project = GoldiExt.proj + scale = GoldiExt.scl + quadraticPack = GoldiExt.pack + quadraticUnpack = GoldiExt.unpack + +-------------------------------------------------------------------------------- diff --git a/reference/src/Field/Goldilocks/Extension/BindC.hs b/reference/src/Field/Goldilocks/Extension/BindC.hs index f85ec96..88f7ab8 100644 --- a/reference/src/Field/Goldilocks/Extension/BindC.hs +++ b/reference/src/Field/Goldilocks/Extension/BindC.hs @@ -139,6 +139,17 @@ binaryOpIO c_action x y = allocaBytesAligned 48 8 $ \ptr1 -> do inj :: F -> F2 inj r = F2 r 0 +proj :: F2 -> Maybe F +proj (F2 r i) = if Goldi.isZero i then Just r else Nothing + +pack :: (F,F) -> F2 +pack (r,i) = F2 r i + +unpack :: F2 -> (F,F) +unpack (F2 r i) = (r,i) + +-------------------------------------------------------------------------------- + neg, sqr, inv :: F2 -> F2 neg x = unsafePerformIO (unaryOpIO c_goldilocks_ext_neg x) sqr x = unsafePerformIO (unaryOpIO c_goldilocks_ext_sqr x) diff --git a/reference/src/Field/Goldilocks/Extension/Haskell.hs b/reference/src/Field/Goldilocks/Extension/Haskell.hs index fa7c586..b182b18 100644 --- a/reference/src/Field/Goldilocks/Extension/Haskell.hs +++ b/reference/src/Field/Goldilocks/Extension/Haskell.hs @@ -108,6 +108,17 @@ isOne (F2 r i) = Goldi.isOne r && Goldi.isZero i inj :: F -> F2 inj r = F2 r 0 +proj :: F2 -> Maybe F +proj (F2 r i) = if Goldi.isZero i then Just r else Nothing + +pack :: (F,F) -> F2 +pack (r,i) = F2 r i + +unpack :: F2 -> (F,F) +unpack (F2 r i) = (r,i) + +-------------------------------------------------------------------------------- + neg :: F2 -> F2 neg (F2 r i) = F2 (negate r) (negate i) diff --git a/reference/src/NTT.hs b/reference/src/NTT.hs index 8fe02a0..5c19c78 100644 --- a/reference/src/NTT.hs +++ b/reference/src/NTT.hs @@ -5,11 +5,11 @@ module NTT ( module Field.Goldilocks , module NTT.Subgroup , module NTT.Poly - , module NTT.Slow + , module NTT.FFT ) where import Field.Goldilocks import NTT.Subgroup import NTT.Poly -import NTT.Slow +import NTT.FFT diff --git a/reference/src/NTT/Class.hs b/reference/src/NTT/Class.hs new file mode 100644 index 0000000..6541863 --- /dev/null +++ b/reference/src/NTT/Class.hs @@ -0,0 +1,22 @@ + +module NTT.Class where + +-------------------------------------------------------------------------------- + +import Data.Kind + +import NTT.Subgroup + +-------------------------------------------------------------------------------- + +{- +class NTT ntt where + type Field ntt :: Type + type Poly ntt :: Type + + forwardNTT :: Subgroup F -> Poly F -> FlatArray F + inverseNTT :: Subgroup F -> FlatArray F -> Poly F + shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F + shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F + asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F +-} diff --git a/reference/src/NTT/FFT.hs b/reference/src/NTT/FFT.hs new file mode 100644 index 0000000..fb25fd0 --- /dev/null +++ b/reference/src/NTT/FFT.hs @@ -0,0 +1,14 @@ + +{-# LANGUAGE CPP #-} + +#ifdef USE_NAIVE_HASKELL + +module NTT.FFT ( module NTT.FFT.Slow ) where +import NTT.FFT.Slow + +#else + +module NTT.FFT ( module NTT.FFT.Fast ) where +import NTT.FFT.Fast + +#endif diff --git a/reference/src/NTT/FFT/Fast.hs b/reference/src/NTT/FFT/Fast.hs new file mode 100644 index 0000000..19a13a4 --- /dev/null +++ b/reference/src/NTT/FFT/Fast.hs @@ -0,0 +1,132 @@ + +{-# LANGUAGE StrictData, ForeignFunctionInterface #-} + +module NTT.FFT.Fast where + +-------------------------------------------------------------------------------- + +import Data.Word + +import Foreign.C +import Foreign.Ptr +import Foreign.ForeignPtr +import Foreign.Marshal + +import System.IO.Unsafe + +import Data.Flat + +import NTT.Poly +import NTT.Subgroup +import Field.Goldilocks +import Misc + +-------------------------------------------------------------------------------- + +{- +void goldilocks_ntt_forward ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt); +void goldilocks_ntt_forward_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt); + +void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt); + +void goldilocks_ntt_inverse ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt); +void goldilocks_ntt_inverse_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt); +-} + +foreign import ccall unsafe "goldilocks_ntt_forward" c_ntt_forward :: CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO () +foreign import ccall unsafe "goldilocks_ntt_forward_shifted" c_ntt_forward_shifted :: Word64 -> CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO () + +foreign import ccall unsafe "goldilocks_ntt_forward_asymmetric" c_ntt_forward_asymmetric :: CInt -> CInt -> Word64 -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO (); + +foreign import ccall unsafe "goldilocks_ntt_inverse" c_ntt_inverse :: CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO () +foreign import ccall unsafe "goldilocks_ntt_inverse_shifted" c_ntt_inverse_shifted :: Word64 -> CInt -> Word64 -> Ptr Word64 -> Ptr Word64 -> IO () + +-------------------------------------------------------------------------------- + +{-# NOINLINE forwardNTT #-} +forwardNTT :: Subgroup F -> Poly F -> FlatArray F +forwardNTT sg (MkPoly (MkFlatArray n fptr2)) + | subgroupSize sg /= n = error "forwardNTT: subgroup size differs from the array size" + | otherwise = unsafePerformIO $ do + fptr3 <- mallocForeignPtrArray n -- (n*4) + let MkGoldilocks gen = subgroupGen sg + withForeignPtr fptr2 $ \ptr2 -> do + withForeignPtr fptr3 $ \ptr3 -> do + c_ntt_forward (subgroupCLogSize sg) gen ptr2 ptr3 + return (MkFlatArray n fptr3) + +{-# NOINLINE inverseNTT #-} +inverseNTT :: Subgroup F -> FlatArray F -> Poly F +inverseNTT sg (MkFlatArray n fptr2) + | subgroupSize sg /= n = error "inverseNTT: subgroup size differs from the array size" + | otherwise = unsafePerformIO $ do + fptr3 <- mallocForeignPtrArray n --(n*4) + let MkGoldilocks gen = subgroupGen sg + withForeignPtr fptr2 $ \ptr2 -> do + withForeignPtr fptr3 $ \ptr3 -> do + c_ntt_inverse (subgroupCLogSize sg) gen ptr2 ptr3 + return (MkPoly (MkFlatArray n fptr3)) + +-- | Pre-multiplies the coefficients by powers of eta, effectively evaluating @f(eta*x)@ on the subgroup +{-# NOINLINE shiftedForwardNTT #-} +shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F +shiftedForwardNTT sg (MkGoldilocks eta) (MkPoly (MkFlatArray n fptr2)) + | subgroupSize sg /= n = error "shiftedForwardNTT: subgroup size differs from the array size" + | otherwise = unsafePerformIO $ do + fptr3 <- mallocForeignPtrArray n -- (n*4) + let MkGoldilocks gen = subgroupGen sg + withForeignPtr fptr2 $ \ptr2 -> do + withForeignPtr fptr3 $ \ptr3 -> do + c_ntt_forward_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3 + return (MkFlatArray n fptr3) + +-- | Post-multiplies the coefficients by powers of eta, effectively interpolating @f@ such that @f(eta^-1 * omega^k) = y_k@ +{-# NOINLINE shiftedInverseNTT #-} +shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F +shiftedInverseNTT sg (MkGoldilocks eta) (MkFlatArray n fptr2) + | subgroupSize sg /= n = error "shiftedInverseNTT: subgroup size differs from the array size" + | otherwise = unsafePerformIO $ do + fptr3 <- mallocForeignPtrArray n -- (n*4) + withForeignPtr fptr2 $ \ptr2 -> do + withForeignPtr fptr3 $ \ptr3 -> do + c_ntt_inverse_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3 + return (MkPoly (MkFlatArray n fptr3)) + +-- | Evaluates a polynomial f on a larger subgroup than it's defined on +{-# NOINLINE asymmForwardNTT #-} +asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F +asymmForwardNTT sg_src (MkPoly (MkFlatArray n fptr2)) sg_tgt + | subgroupSize sg_src /= n = error "asymmForwardNTT: subgroup size differs from the array size" + | m < n = error "asymmForwardNTT: target subgroup size should be at least the source subgroup src" + | otherwise = unsafePerformIO $ do + fptr3 <- mallocForeignPtrArray m -- (m*4) + let MkGoldilocks sgen1 = subgroupGen sg_src + let MkGoldilocks sgen2 = subgroupGen sg_tgt + withForeignPtr fptr2 $ \ptr2 -> do + withForeignPtr fptr3 $ \ptr3 -> do + c_ntt_forward_asymmetric + (subgroupCLogSize sg_src) + (subgroupCLogSize sg_tgt) + sgen1 sgen2 ptr2 ptr3 + return (MkFlatArray m fptr3) + where + m = subgroupSize sg_tgt + +{- +instance P.UnivariateFFT Poly where + ntt = forwardNTT + intt = inverseNTT + shiftedNTT = shiftedForwardNTT + shiftedINTT = shiftedInverseNTT + asymmNTT = asymmForwardNTT +-} + +-------------------------------------------------------------------------------- + +subgroupCLogSize :: Subgroup a -> CInt +subgroupCLogSize = fromIntegral . fromLog2 . subgroupLogSize + +subgroupGenAsWord64 :: Subgroup F -> Word64 +subgroupGenAsWord64 sg = let MkGoldilocks x = subgroupGen sg in x + +-------------------------------------------------------------------------------- diff --git a/reference/src/NTT/Slow.hs b/reference/src/NTT/FFT/Slow.hs similarity index 99% rename from reference/src/NTT/Slow.hs rename to reference/src/NTT/FFT/Slow.hs index 77110ce..0c41c76 100644 --- a/reference/src/NTT/Slow.hs +++ b/reference/src/NTT/FFT/Slow.hs @@ -1,13 +1,13 @@ {-# LANGUAGE ScopedTypeVariables #-} -module NTT.Slow where +module NTT.FFT.Slow where -------------------------------------------------------------------------------- import Data.Array import Data.Bits -import NTT.Poly +import NTT.Poly.Naive import NTT.Subgroup import Field.Goldilocks diff --git a/reference/src/NTT/Poly.hs b/reference/src/NTT/Poly.hs index 3676a7c..e2f1e51 100644 --- a/reference/src/NTT/Poly.hs +++ b/reference/src/NTT/Poly.hs @@ -1,227 +1,13 @@ +{-# LANGUAGE CPP #-} --- | Dense univariate polynomials +#ifdef USE_NAIVE_HASKELL -{-# LANGUAGE StrictData, BangPatterns, ScopedTypeVariables, DeriveFunctor #-} -module NTT.Poly where +module NTT.Poly ( module NTT.Poly.Naive ) where +import NTT.Poly.Naive --------------------------------------------------------------------------------- +#else -import Data.List -import Data.Array -import Data.Array.ST (STArray) -import Data.Array.MArray (newArray, readArray, writeArray, thaw, freeze) +module NTT.Poly ( module NTT.Poly.Flat ) where +import NTT.Poly.Flat -import Control.Monad -import Control.Monad.ST.Strict - -import System.Random - -import Data.Binary - -import Field.Goldilocks -import Field.Goldilocks.Extension ( FExt ) -import Field.Encode - -import Misc - --------------------------------------------------------------------------------- --- * Univariate polynomials - --- | A dense univariate polynomial. The array index corresponds to the exponent. -newtype Poly a - = Poly (Array Int a) - deriving (Show,Functor) - -instance Binary a => Binary (Poly a) where - put (Poly arr) = putSmallArray arr - get = Poly <$> getSmallArray - -instance (Num a, Eq a) => Eq (Poly a) where - p == q = polyIsZero (polySub p q) - -instance FieldEncode (Poly F) where - fieldEncode (Poly arr) = fieldEncode arr - -instance FieldEncode (Poly FExt) where - fieldEncode (Poly arr) = fieldEncode arr - -mkPoly :: [a] -> Poly a -mkPoly coeffs = Poly $ listArray (0,length coeffs-1) coeffs - --- | Degree of the polynomial -polyDegree :: (Eq a, Num a) => Poly a -> Int -polyDegree (Poly arr) = worker d0 where - (0,d0) = bounds arr - worker d - | d < 0 = -1 - | arr!d /= 0 = d - | otherwise = worker (d-1) - --- | Size of the polynomial (can be larger than @degree + 1@, if the some top coefficients are zeros) -polySize :: (Eq a, Num a) => Poly a -> Int -polySize (Poly p) = arraySize p - -polyIsZero :: (Eq a, Num a) => Poly a -> Bool -polyIsZero (Poly arr) = all (==0) (elems arr) - --- | Returns the coefficient of @x^k@ -polyCoeff :: Num a => Poly a -> Int -> a -polyCoeff (Poly coeffs) k = safeIndex 0 coeffs k - --- | Note: this can include zero coeffs at higher than the actual degree! -polyCoeffArray :: Poly a -> Array Int a -polyCoeffArray (Poly coeffs) = coeffs - --- | Note: this cuts off the potential extra zeros at the end. --- The order is little-endian (constant term first). -polyCoeffList :: (Eq a, Num a) => Poly a -> [a] -polyCoeffList poly@(Poly arr) = take (polyDegree poly + 1) (elems arr) - --------------------------------------------------------------------------------- --- * Elementary polynomials - --- | Constant polynomial -polyConst :: a -> Poly a -polyConst x = Poly $ listArray (0,0) [x] - --- | Zero polynomial -polyZero :: Num a => Poly a -polyZero = polyConst 0 - --- | The polynomial @f(x) = x@ -polyVarX :: Num a => Poly a -polyVarX = mkPoly [0,1] - --- | @polyLinear (A,B)@ means the linear polynomial @f(x) = A*x + B@ -polyLinear :: (a,a) -> Poly a -polyLinear (a,b) = mkPoly [b,a] - --- | The monomial @x^n@ -polyXpowN :: Num a => Int -> Poly a -polyXpowN n = Poly $ listArray (0,n) (replicate n 0 ++ [1]) - --- | The binomial @(x^n - 1)@ -polyXpowNminus1 :: Num a => Int -> Poly a -polyXpowNminus1 n = Poly $ listArray (0,n) (-1 : replicate (n-1) 0 ++ [1]) - --------------------------------------------------------------------------------- --- * Evaluate polynomials - -polyEvalAt :: forall f. Fractional f => Poly f -> f -> f -polyEvalAt (Poly arr) x = go 0 1 0 where - (0,d) = bounds arr - go :: f -> f -> Int -> f - go !acc !y !i = if i > d - then acc - else go (acc + (arr!i)*y) (y*x) (i+1) - -polyEvalOnList :: forall f. Fractional f => Poly f -> [f] -> [f] -polyEvalOnList poly = map (polyEvalAt poly) - -polyEvalOnArray :: forall f. Fractional f => Poly f -> Array Int f -> Array Int f -polyEvalOnArray poly = fmap (polyEvalAt poly) - --------------------------------------------------------------------------------- --- * Basic arithmetic operations on polynomials - -polyNeg :: Num a => Poly a -> Poly a -polyNeg (Poly arr) = Poly $ fmap negate arr - -polyAdd :: Num a => Poly a -> Poly a -> Poly a -polyAdd (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where - (0,d1) = bounds arr1 - (0,d2) = bounds arr2 - d3 = max d1 d2 - zs = zipWith (+) (elems arr1 ++ replicate (d3-d1) 0) - (elems arr2 ++ replicate (d3-d2) 0) - -polySub :: Num a => Poly a -> Poly a -> Poly a -polySub (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where - (0,d1) = bounds arr1 - (0,d2) = bounds arr2 - d3 = max d1 d2 - zs = zipWith (-) (elems arr1 ++ replicate (d3-d1) 0) - (elems arr2 ++ replicate (d3-d2) 0) - -polyMul :: Num a => Poly a -> Poly a -> Poly a -polyMul (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where - (0,d1) = bounds arr1 - (0,d2) = bounds arr2 - d3 = d1 + d2 - zs = [ f k | k<-[0..d3] ] - f !k = foldl' (+) 0 [ arr1!i * arr2!(k-i) | i<-[ max 0 (k-d2) .. min d1 k ] ] - -instance Num a => Num (Poly a) where - fromInteger = polyConst . fromInteger - negate = polyNeg - (+) = polyAdd - (-) = polySub - (*) = polyMul - abs = id - signum = \_ -> polyConst 1 - -polySum :: Num a => [Poly a] -> Poly a -polySum = foldl' polyAdd 0 - -polyProd :: Num a => [Poly a] -> Poly a -polyProd = foldl' polyMul 1 - --------------------------------------------------------------------------------- --- * Polynomial long division - --- | @polyDiv f h@ returns @(q,r)@ such that @f = q*h + r@ and @deg r < deg h@ -polyDiv :: forall f. (Eq f, Fractional f) => Poly f -> Poly f -> (Poly f, Poly f) -polyDiv poly_f@(Poly arr_f) poly_h@(Poly arr_h) - | deg_q < 0 = (polyZero, poly_f) - | otherwise = runST action - where - deg_f = polyDegree poly_f - deg_h = polyDegree poly_h - deg_q = deg_f - deg_h - - -- inverse of the top coefficient of divisor - b_inv = recip (arr_h ! deg_h) - - action :: forall s. ST s (Poly f, Poly f) - action = do - p <- thaw arr_f :: ST s (STArray s Int f) - q <- newArray (0,deg_q) 0 :: ST s (STArray s Int f) - forM_ [deg_q,deg_q-1..0] $ \k -> do - top <- readArray p (deg_h + k) - let y = b_inv * top - writeArray q k y - forM_ [0..deg_h] $ \j -> do - a <- readArray p (j+k) - writeArray p (j+k) (a - y*(arr_h!j)) - qarr <- freeze q - rs <- forM [0..deg_h-1] $ \i -> readArray p i - let rarr = listArray (0,deg_h-1) rs - return (Poly qarr, Poly rarr) - --- | Returns only the quotient -polyDivQuo :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f -polyDivQuo f g = fst $ polyDiv f g - --- | Returns only the remainder -polyDivRem :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f -polyDivRem f g = snd $ polyDiv f g - --------------------------------------------------------------------------------- --- * Sample random polynomials - -randomPoly :: (RandomGen g, Random a) => Int -> g -> (Poly a, g) -randomPoly deg g0 = - let (coeffs,gfinal) = worker (deg+1) g0 - poly = Poly (listArray (0,deg) coeffs) - in (poly, gfinal) - - where - worker 0 g = ([] , g) - worker n g = let (x ,g1) = random g - (xs,g2) = worker (n-1) g1 - in ((x:xs) , g) - -randomPolyIO :: Random a => Int -> IO (Poly a) -randomPolyIO deg = getStdRandom (randomPoly deg) - --------------------------------------------------------------------------------- \ No newline at end of file +#endif diff --git a/reference/src/NTT/Poly/Flat.hs b/reference/src/NTT/Poly/Flat.hs new file mode 100644 index 0000000..b1179ee --- /dev/null +++ b/reference/src/NTT/Poly/Flat.hs @@ -0,0 +1,55 @@ + +{-# LANGUAGE StrictData, ScopedTypeVariables, PatternSynonyms #-} +module NTT.Poly.Flat where + +import Prelude hiding (div,quot,rem) +import GHC.Real hiding (div,quot,rem) + +import Data.Bits +import Data.Word +import Data.List +import Data.Array + +import Control.Monad + +import Foreign.C +import Foreign.Ptr +import Foreign.Marshal +import Foreign.ForeignPtr + +import System.Random +import System.IO.Unsafe + +import NTT.Subgroup +import Field.Class +import Data.Flat as L + +-------------------------------------------------------------------------------- + +newtype Poly a = MkPoly (L.FlatArray a) + +pattern XPoly n arr = MkPoly (L.MkFlatArray n arr) + +mkPoly :: Flat f => [f] -> Poly f +mkPoly = MkPoly . L.packFlatArrayFromList + +mkPoly' :: Flat f => Int -> [f] -> Poly f +mkPoly' len xs = MkPoly $ L.packFlatArrayFromList' len xs + +mkPolyArr :: Flat f => Array Int f -> Poly f +mkPolyArr = MkPoly . L.packFlatArray + +mkPolyFlatArr :: L.FlatArray f -> Poly f +mkPolyFlatArr = MkPoly + +coeffs :: Flat f => Poly f -> [f] +coeffs (MkPoly arr) = L.unpackFlatArrayToList arr + +coeffsArr :: Flat f => Poly f -> Array Int f +coeffsArr (MkPoly arr) = L.unpackFlatArray arr + +coeffsFlatArr :: Poly f -> L.FlatArray f +coeffsFlatArr (MkPoly flat) = flat + +-------------------------------------------------------------------------------- + diff --git a/reference/src/NTT/Poly/Naive.hs b/reference/src/NTT/Poly/Naive.hs new file mode 100644 index 0000000..665d8c6 --- /dev/null +++ b/reference/src/NTT/Poly/Naive.hs @@ -0,0 +1,227 @@ + +-- | Dense univariate polynomials + +{-# LANGUAGE StrictData, BangPatterns, ScopedTypeVariables, DeriveFunctor #-} +module NTT.Poly.Naive where + +-------------------------------------------------------------------------------- + +import Data.List +import Data.Array +import Data.Array.ST (STArray) +import Data.Array.MArray (newArray, readArray, writeArray, thaw, freeze) + +import Control.Monad +import Control.Monad.ST.Strict + +import System.Random + +import Data.Binary + +import Field.Goldilocks +import Field.Goldilocks.Extension ( FExt ) +import Field.Encode + +import Misc + +-------------------------------------------------------------------------------- +-- * Univariate polynomials + +-- | A dense univariate polynomial. The array index corresponds to the exponent. +newtype Poly a + = Poly (Array Int a) + deriving (Show,Functor) + +instance Binary a => Binary (Poly a) where + put (Poly arr) = putSmallArray arr + get = Poly <$> getSmallArray + +instance (Num a, Eq a) => Eq (Poly a) where + p == q = polyIsZero (polySub p q) + +instance FieldEncode (Poly F) where + fieldEncode (Poly arr) = fieldEncode arr + +instance FieldEncode (Poly FExt) where + fieldEncode (Poly arr) = fieldEncode arr + +mkPoly :: [a] -> Poly a +mkPoly coeffs = Poly $ listArray (0,length coeffs-1) coeffs + +-- | Degree of the polynomial +polyDegree :: (Eq a, Num a) => Poly a -> Int +polyDegree (Poly arr) = worker d0 where + (0,d0) = bounds arr + worker d + | d < 0 = -1 + | arr!d /= 0 = d + | otherwise = worker (d-1) + +-- | Size of the polynomial (can be larger than @degree + 1@, if the some top coefficients are zeros) +polySize :: (Eq a, Num a) => Poly a -> Int +polySize (Poly p) = arraySize p + +polyIsZero :: (Eq a, Num a) => Poly a -> Bool +polyIsZero (Poly arr) = all (==0) (elems arr) + +-- | Returns the coefficient of @x^k@ +polyCoeff :: Num a => Poly a -> Int -> a +polyCoeff (Poly coeffs) k = safeIndex 0 coeffs k + +-- | Note: this can include zero coeffs at higher than the actual degree! +polyCoeffArray :: Poly a -> Array Int a +polyCoeffArray (Poly coeffs) = coeffs + +-- | Note: this cuts off the potential extra zeros at the end. +-- The order is little-endian (constant term first). +polyCoeffList :: (Eq a, Num a) => Poly a -> [a] +polyCoeffList poly@(Poly arr) = take (polyDegree poly + 1) (elems arr) + +-------------------------------------------------------------------------------- +-- * Elementary polynomials + +-- | Constant polynomial +polyConst :: a -> Poly a +polyConst x = Poly $ listArray (0,0) [x] + +-- | Zero polynomial +polyZero :: Num a => Poly a +polyZero = polyConst 0 + +-- | The polynomial @f(x) = x@ +polyVarX :: Num a => Poly a +polyVarX = mkPoly [0,1] + +-- | @polyLinear (A,B)@ means the linear polynomial @f(x) = A*x + B@ +polyLinear :: (a,a) -> Poly a +polyLinear (a,b) = mkPoly [b,a] + +-- | The monomial @x^n@ +polyXpowN :: Num a => Int -> Poly a +polyXpowN n = Poly $ listArray (0,n) (replicate n 0 ++ [1]) + +-- | The binomial @(x^n - 1)@ +polyXpowNminus1 :: Num a => Int -> Poly a +polyXpowNminus1 n = Poly $ listArray (0,n) (-1 : replicate (n-1) 0 ++ [1]) + +-------------------------------------------------------------------------------- +-- * Evaluate polynomials + +polyEvalAt :: forall f. Fractional f => Poly f -> f -> f +polyEvalAt (Poly arr) x = go 0 1 0 where + (0,d) = bounds arr + go :: f -> f -> Int -> f + go !acc !y !i = if i > d + then acc + else go (acc + (arr!i)*y) (y*x) (i+1) + +polyEvalOnList :: forall f. Fractional f => Poly f -> [f] -> [f] +polyEvalOnList poly = map (polyEvalAt poly) + +polyEvalOnArray :: forall f. Fractional f => Poly f -> Array Int f -> Array Int f +polyEvalOnArray poly = fmap (polyEvalAt poly) + +-------------------------------------------------------------------------------- +-- * Basic arithmetic operations on polynomials + +polyNeg :: Num a => Poly a -> Poly a +polyNeg (Poly arr) = Poly $ fmap negate arr + +polyAdd :: Num a => Poly a -> Poly a -> Poly a +polyAdd (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where + (0,d1) = bounds arr1 + (0,d2) = bounds arr2 + d3 = max d1 d2 + zs = zipWith (+) (elems arr1 ++ replicate (d3-d1) 0) + (elems arr2 ++ replicate (d3-d2) 0) + +polySub :: Num a => Poly a -> Poly a -> Poly a +polySub (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where + (0,d1) = bounds arr1 + (0,d2) = bounds arr2 + d3 = max d1 d2 + zs = zipWith (-) (elems arr1 ++ replicate (d3-d1) 0) + (elems arr2 ++ replicate (d3-d2) 0) + +polyMul :: Num a => Poly a -> Poly a -> Poly a +polyMul (Poly arr1) (Poly arr2) = Poly $ listArray (0,d3) zs where + (0,d1) = bounds arr1 + (0,d2) = bounds arr2 + d3 = d1 + d2 + zs = [ f k | k<-[0..d3] ] + f !k = foldl' (+) 0 [ arr1!i * arr2!(k-i) | i<-[ max 0 (k-d2) .. min d1 k ] ] + +instance Num a => Num (Poly a) where + fromInteger = polyConst . fromInteger + negate = polyNeg + (+) = polyAdd + (-) = polySub + (*) = polyMul + abs = id + signum = \_ -> polyConst 1 + +polySum :: Num a => [Poly a] -> Poly a +polySum = foldl' polyAdd 0 + +polyProd :: Num a => [Poly a] -> Poly a +polyProd = foldl' polyMul 1 + +-------------------------------------------------------------------------------- +-- * Polynomial long division + +-- | @polyDiv f h@ returns @(q,r)@ such that @f = q*h + r@ and @deg r < deg h@ +polyDiv :: forall f. (Eq f, Fractional f) => Poly f -> Poly f -> (Poly f, Poly f) +polyDiv poly_f@(Poly arr_f) poly_h@(Poly arr_h) + | deg_q < 0 = (polyZero, poly_f) + | otherwise = runST action + where + deg_f = polyDegree poly_f + deg_h = polyDegree poly_h + deg_q = deg_f - deg_h + + -- inverse of the top coefficient of divisor + b_inv = recip (arr_h ! deg_h) + + action :: forall s. ST s (Poly f, Poly f) + action = do + p <- thaw arr_f :: ST s (STArray s Int f) + q <- newArray (0,deg_q) 0 :: ST s (STArray s Int f) + forM_ [deg_q,deg_q-1..0] $ \k -> do + top <- readArray p (deg_h + k) + let y = b_inv * top + writeArray q k y + forM_ [0..deg_h] $ \j -> do + a <- readArray p (j+k) + writeArray p (j+k) (a - y*(arr_h!j)) + qarr <- freeze q + rs <- forM [0..deg_h-1] $ \i -> readArray p i + let rarr = listArray (0,deg_h-1) rs + return (Poly qarr, Poly rarr) + +-- | Returns only the quotient +polyDivQuo :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f +polyDivQuo f g = fst $ polyDiv f g + +-- | Returns only the remainder +polyDivRem :: (Eq f, Fractional f) => Poly f -> Poly f -> Poly f +polyDivRem f g = snd $ polyDiv f g + +-------------------------------------------------------------------------------- +-- * Sample random polynomials + +randomPoly :: (RandomGen g, Random a) => Int -> g -> (Poly a, g) +randomPoly deg g0 = + let (coeffs,gfinal) = worker (deg+1) g0 + poly = Poly (listArray (0,deg) coeffs) + in (poly, gfinal) + + where + worker 0 g = ([] , g) + worker n g = let (x ,g1) = random g + (xs,g2) = worker (n-1) g1 + in ((x:xs) , g) + +randomPolyIO :: Random a => Int -> IO (Poly a) +randomPolyIO deg = getStdRandom (randomPoly deg) + +-------------------------------------------------------------------------------- \ No newline at end of file diff --git a/reference/src/NTT/Subgroup.hs b/reference/src/NTT/Subgroup.hs index 6f86864..56cf1c6 100644 --- a/reference/src/NTT/Subgroup.hs +++ b/reference/src/NTT/Subgroup.hs @@ -20,6 +20,9 @@ data Subgroup g = MkSubgroup subgroupSize :: Subgroup g -> Int subgroupSize = subgroupOrder +subgroupLogSize :: Subgroup g -> Log2 +subgroupLogSize = exactLog2__ . subgroupSize + getSubgroup :: Log2 -> Subgroup F getSubgroup log2@(Log2 n) | n<0 = error "getSubgroup: negative logarithm" diff --git a/reference/src/NTT/Tests.hs b/reference/src/NTT/Tests.hs new file mode 100644 index 0000000..de80263 --- /dev/null +++ b/reference/src/NTT/Tests.hs @@ -0,0 +1,55 @@ + +module NTT.Tests where + +-------------------------------------------------------------------------------- + +import Data.Array + +import Control.Monad + +import Field.Class +import Field.Goldilocks ( F ) +import Misc + +import Data.Flat as L +import NTT.Subgroup + +import qualified NTT.Poly.Flat as F +import qualified NTT.FFT.Fast as F + +import qualified NTT.Poly.Naive as S +import qualified NTT.FFT.Slow as S + + +-------------------------------------------------------------------------------- + +-- | Compare slow and fast FFT implementations to each other +compareFFTs :: Log2 -> IO Bool +compareFFTs logSize = do + let size = exp2_ logSize + let sg = getSubgroup logSize + + values1_s <- listToArray <$> replicateM size rndIO :: IO (Array Int F) + let values1_f = L.packFlatArray values1_s :: L.FlatArray F + let poly1_s = S.Poly values1_s :: S.Poly F + let poly1_f = F.mkPolyArr values1_s :: F.Poly F + + let values2_s = S.subgroupNTT sg poly1_s :: Array Int F + let values2_f = F.forwardNTT sg poly1_f :: L.FlatArray F + + let poly2_s = S.subgroupINTT sg values1_s :: S.Poly F + let poly2_f = F.inverseNTT sg values1_f :: F.Poly F + + let ok_ntt = values2_s == L.unpackFlatArray values2_f + let ok_intt = S.polyCoeffArray poly2_s == F.coeffsArr poly2_f + + return (ok_ntt && ok_intt) + +-------------------------------------------------------------------------------- + +runTests :: Int -> Log2 -> IO Bool +runTests n size = do + oks <- replicateM n (compareFFTs size) + return (and oks) + +-------------------------------------------------------------------------------- diff --git a/reference/src/cbits/ntt.c b/reference/src/cbits/ntt.c new file mode 100644 index 0000000..6ef15f7 --- /dev/null +++ b/reference/src/cbits/ntt.c @@ -0,0 +1,228 @@ + +#include +#include +#include +#include + +#include "goldilocks.h" +#include "ntt.h" + +// ----------------------------------------------------------------------------- + +void goldilocks_ntt_forward_noalloc(int m, int src_stride, const uint64_t *gpows, const uint64_t *src, uint64_t *buf, uint64_t *tgt) { + + if (m==0) { + tgt[0] = src[0]; + return; + } + + if (m==1) { + // N = 2 + tgt[0] = goldilocks_add( src[0] , src[src_stride] ); // x + y + tgt[1] = goldilocks_sub( src[0] , src[src_stride] ); // x - y + return; + } + + else { + int N = (1<< m ); + int halfN = (1<<(m-1)); + + goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src , buf + N , buf ); + goldilocks_ntt_forward_noalloc( m-1 , src_stride<<1 , gpows , src + src_stride , buf + N , buf + halfN ); + + for(int j=0; j>1); + + // precalculate [1,g,g^2,g^3...] + uint64_t *gpows = (uint64_t*) malloc( 8 * halfN ); + assert( gpows != 0 ); + uint64_t x = gen; + gpows[0] = 1; + gpows[1] = gen; + for(int i=2; i= m_src ); + int N_src = (1 << m_src); + int N_tgt = (1 << m_tgt); + int halfN_src = (N_src >> 1); + int K = (1 << (m_tgt - m_src)); + + // precalculate [1,g,g^2,g^3...] + uint64_t *gpows = malloc( 8 * halfN_src ); + assert( gpows != 0 ); + uint64_t x = gen_src; + gpows[0] = 1; + gpows[1] = gen_src; + for(int i=2; i>1); + + // precalculate [1/2,g^{-1}/2,g^{-2}/2,g^{-3}/2...] + uint64_t *gpows = malloc( 8 * halfN ); + assert( gpows != 0 ); + uint64_t x = goldilocks_oneHalf; // 1/2 + uint64_t ginv = goldilocks_inv(gen); // gen^-1 + for(int i=0; i + +//------------------------------------------------------------------------------ + +void goldilocks_ntt_forward ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt); +void goldilocks_ntt_forward_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt); + +void goldilocks_ntt_forward_asymmetric(int m_src, int m_tgt, uint64_t gen_src, uint64_t gen_tgt, const uint64_t *src, uint64_t *tgt); + +void goldilocks_ntt_inverse ( int m, uint64_t gen, const uint64_t *src, uint64_t *tgt); +void goldilocks_ntt_inverse_shifted (uint64_t eta, int m, uint64_t gen, const uint64_t *src, uint64_t *tgt); + +//------------------------------------------------------------------------------