diff --git a/reference/src/Class/Field.hs b/reference/src/Class/Field.hs new file mode 100644 index 0000000..27fbfe2 --- /dev/null +++ b/reference/src/Class/Field.hs @@ -0,0 +1,40 @@ + +{-# LANGUAGE TypeFamilies #-} +module Class.Field where + +-------------------------------------------------------------------------------- + +import Data.Kind +import Data.Proxy + +import System.Random + +-------------------------------------------------------------------------------- + +class (Show a, Eq a, Num a) => Ring a where + zero :: a + one :: a + isZero :: a -> Bool + isOne :: a -> Bool + square :: a -> a + power :: a -> Integer -> a + power_ :: a -> Int -> a + +class (Ring a, Fractional a) => Field a where + inverse :: a -> a + inverse = recip + +class Field a => FiniteField a where + rndIO :: IO a + fieldSize :: Proxy a -> Integer + +-- | Quadratic extensions +class (Field (BaseField ext), Field ext) => QuadraticExt ext where + type BaseField ext :: Type + inject :: BaseField ext -> ext + project :: ext -> Maybe (BaseField ext) + scale :: BaseField ext -> ext -> ext + quadraticPack :: (BaseField ext, BaseField ext) -> ext + quadraticUnpack :: ext -> (BaseField ext, BaseField ext) + +-------------------------------------------------------------------------------- diff --git a/reference/src/Class/Flat.hs b/reference/src/Class/Flat.hs new file mode 100644 index 0000000..29b640a --- /dev/null +++ b/reference/src/Class/Flat.hs @@ -0,0 +1,7 @@ + +module Class.Flat + ( module Data.Flat.Class + ) + where + +import Data.Flat.Class diff --git a/reference/src/Class/NTT.hs b/reference/src/Class/NTT.hs new file mode 100644 index 0000000..3def71d --- /dev/null +++ b/reference/src/Class/NTT.hs @@ -0,0 +1,40 @@ + +module NTT.Class where + +-------------------------------------------------------------------------------- + +import Data.Kind + +import Class.Field +import Class.Poly + +import NTT.Subgroup + +-------------------------------------------------------------------------------- + +{- +class NTT ntt where + type Field ntt :: Type + type Poly ntt :: Type + + forwardNTT :: Subgroup F -> Poly F -> FlatArray F + inverseNTT :: Subgroup F -> FlatArray F -> Poly F + shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F + shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F + asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F +-} + +{- +-- | Polynomials which over an FFT-friend field support NTT operations +class (Univariate p, FFTField (Coeff p)) => UnivariateFFT p where + -- | Number-theoretical transform (evaluate on a subgroup) + ntt :: FFTSubgroup (Coeff p) -> p -> FlatArray (Coeff p) + -- | Inverse number-theoretical transform (interpolate on a subgroup) + intt :: FFTSubgroup (Coeff p) -> FlatArray (Coeff p) -> p + -- | Shifts @f@ by @eta@, evaluating @f(eta*omega^k)@ + shiftedNTT :: FFTSubgroup (Coeff p) -> Coeff p -> p -> FlatArray (Coeff p) + -- | Shifts @f@ by @eta^-1@, interpolating @f@ so that @f(eta^-1 * omega^k) = y_k@ + shiftedINTT :: FFTSubgroup (Coeff p) -> Coeff p -> FlatArray (Coeff p) -> p + -- | Evaluate on a larger subgroup than the polynomial is defined on + asymmNTT :: FFTSubgroup (Coeff p) -> p -> FFTSubgroup (Coeff p) -> FlatArray (Coeff p) +-} diff --git a/reference/src/Class/Poly.hs b/reference/src/Class/Poly.hs new file mode 100644 index 0000000..bcacb2f --- /dev/null +++ b/reference/src/Class/Poly.hs @@ -0,0 +1,99 @@ + +-- | Univariate polynomials + +module Class.Poly where + +-------------------------------------------------------------------------------- + +import Data.Kind + +import Class.Field + +-------------------------------------------------------------------------------- +-- * Dense univariate polynomials over (finite) fields + +class (Ring p, Field (Coeff p)) => Univariate p where + -- | the type of coefficients + type Coeff p :: Type + -- | Degree + degree :: p -> Int + -- | Size (can be larger than degree+1 if the top coefficients are zeor) + polySize :: p -> Int + -- | The k-th coefficient + kthCoeff :: Int -> p -> Coeff p + -- | Evaluation + evalAt :: Coeff p -> p -> Coeff p + -- | Scaling + scale :: Coeff p -> p -> p + -- | Create a polynomial from coefficiens + mkPoly :: [Coeff p] -> p + -- | Coefficients of the polynomial as a list + coeffs :: p -> [Coeff p] + -- | Coefficients as an Array + coeffsArr :: p -> Array Int (Coeff p) +{- + -- | Polynomial long division + polyLongDiv :: p -> p -> (p,p) + -- | Polynomial quotient + polyQuot :: p -> p -> p + -- | Polynomial remainder + polyRem :: p -> p -> p + -- | Divide by the coset vanishing polynomial @(x^n - eta)@ + divByVanishing :: p -> (Int, Coeff p) -> (p,p) + -- | Quotient by the coset vanishing polynomial @(x^n - eta)@ + quotByVanishing :: p -> (Int, Coeff p) -> Maybe p +-} + +-------------------------------------------------------------------------------- +-- * Some generic functions + +-- | Checks whether the input is the constant one polynomial? +polyIsOne :: Univariate p => p -> Bool +polyIsOne p = case mbConst p of + Nothing -> False + Just x -> ZK.Algebra.Class.Field.isOne x + +-- | The constant term of a polynomial +constTermOf :: Univariate p => p -> Coeff p +constTermOf = kthCoeff 0 + +-- | Is this a constant polynomial? +-- +-- TODO: this is not efficient for high-degree polynomials, where we could exit early... +mbConst :: Univariate p => p -> Maybe (Coeff p) +mbConst p = if degree p <= 0 then Just (constTermOf p) else Nothing + +-- | Create a constant polynomial +constPoly :: Univariate p => Coeff p -> p +constPoly y = mkPoly [y] + +-- | The polynomial @p(x)=x@ +idPoly :: Univariate p => p +idPoly = mkPoly [0,1] + +-- | Create a linear polynomial. +-- +-- > linearPoly a b == a*x + b@ +-- +linearPoly :: Univariate p => Coeff p -> Coeff p -> p +linearPoly a b = mkPoly [b,a] + +showPoly :: Univariate p => p -> String +showPoly = showPoly' True + +showPoly' :: Univariate p => Bool -> p -> String +showPoly' newlines_flag poly = + case newlines_flag of + False -> intercalate " +" terms + True -> intercalate " +\n" terms + where + pairs = filter (\kx -> snd kx /= 0) + $ zip [0..] (coeffs poly) + terms = case pairs of + [] -> [" 0"] + _ -> map f pairs + f (0,x) = ' ' : show x + f (1,x) = ' ' : show x ++ " * x" + f (k,x) = ' ' : show x ++ " * x^" ++ show k + +-------------------------------------------------------------------------------- \ No newline at end of file diff --git a/reference/src/Class/Vector.hs b/reference/src/Class/Vector.hs new file mode 100644 index 0000000..5d28d59 --- /dev/null +++ b/reference/src/Class/Vector.hs @@ -0,0 +1,112 @@ + +{-# LANGUAGE TypeFamilies #-} +module Class.Vector where + +-------------------------------------------------------------------------------- + +import Data.Kind + +import Data.Array.IArray + +import Class.Field +import Misc + +-------------------------------------------------------------------------------- + +class Vector v where + + type VecElem v :: Type + + vecLength :: v -> Int + vecPeek :: v -> Int -> VecElem v + + vecToList :: v -> [VecElem v] + vecFromList :: [VecElem v] -> v + + vecToArray :: v -> Array Int (VecElem v) + vecFromArray :: Array Int (VecElem v) -> v + + vecTake :: Int -> v -> v + vecDrop :: Int -> v -> v + vecAppend :: v -> v -> v + vecConcat :: [v] -> v + vecMap :: (VecElem v -> VecElem v) -> v -> v + +vecSize :: Vector v => v -> Int +vecSize = vecLength + +-------------------------------------------------------------------------------- + +instance Vector [a] where + type VecElem [a] = a + vecLength = length + vecPeek = (!!) + vecToList = id + vecFromList = id + vecToArray = listToArray + vecFromArray = elems + vecTake = take + vecDrop = drop + vecAppend = (++) + vecConcat = concat + vecMap = map + +instance Vector (Array Int a) where + type VecElem (Array Int a) = a + vecLength = \arr -> let (0,n1) = bounds arr in n1+1 + vecPeek = (!) + vecToList = elems + vecFromList = listToArray + vecToArray = id + vecFromArray = id + vecTake = takeArray + vecDrop = dropArray + vecAppend = appendArrays + vecConcat = concatArrays + vecMap = amap + +-------------------------------------------------------------------------------- +-- * Pointwise operations + +class PointwiseGroup a where + -- | Pointwise negation + pwNeg :: a -> a + -- | Pointwise addition + pwAdd :: a -> a -> a + -- | Pointwise subtraction + pwSub :: a -> a -> a + +class PointwiseGroup a => PointwiseRing a where + -- | Pointwise squaring + pwSqr :: a -> a + -- | Pointwise multiplication + pwMul :: a -> a -> a + -- | Pointwise @a*b+c@ + pwMulAdd :: a -> a -> a -> a + -- | Pointwise @a*b-c@ + pwMulSub :: a -> a -> a -> a + +class PointwiseRing a => PointwiseField a where + -- | Pointwise inversion + pwInv :: a -> a + -- | Pointwise division + pwDiv :: a -> a -> a + +-------------------------------------------------------------------------------- + +-- | Finite dimensional vector spaces +class (Vector v, Field (VecElem v), PointwiseField v) => VectorSpace v where + -- | Scaling by an element + vecScale :: VecElem v -> v -> v + -- | Dot product + dotProd :: v -> v -> VecElem v + -- | The array @[ a*b^k | k<-[0..n-1] ]@ + powers :: VecElem v -> VecElem v -> Int -> v + -- | Pointwise multiplication by the array @[ a*b^k | k<-[0..n-1] ]@ + mulByPowers :: VecElem v -> VecElem v -> v -> v + -- | Linear combination @a*x + y@ + linComb1 :: (VecElem v, v) -> v -> v + -- | Linear combination @a*x + b*y@ + linComb2 :: (VecElem v, v) -> (VecElem v, v) -> v + +-------------------------------------------------------------------------------- diff --git a/reference/src/Data/Flat/Array.hs b/reference/src/Data/Flat/Array.hs index 45ff599..674aa3b 100644 --- a/reference/src/Data/Flat/Array.hs +++ b/reference/src/Data/Flat/Array.hs @@ -25,6 +25,7 @@ import System.IO import System.IO.Unsafe import Data.Flat.Class +import Class.Vector import Misc -------------------------------------------------------------------------------- @@ -33,7 +34,7 @@ import Misc -- foreign memory (not managed by the Haskell runtime). -- -- Note: the @Int@ means the number of objects in the array. -data FlatArray a +data FlatArray (a :: Type) = MkFlatArray !Int !(ForeignPtr Word64) deriving Show @@ -128,6 +129,14 @@ dropFlatArrayIO k (MkFlatArray n fptr1) = do when (m>0) $ copyBytes ptr2 src (8*sz*m) return (MkFlatArray m fptr2) +{-# NOINLINE mapFlatArrayIO #-} +mapFlatArrayIO :: (a -> a) -> FlatArray a -> IO (FlatArray a) +mapFlatArrayIO = error "mapFlatArrayIO: not yet implemented" + +{-# NOINLINE mapFlatArray #-} +mapFlatArray :: (a -> a) -> FlatArray a -> FlatArray a +mapFlatArray f arr = unsafePerformIO (mapFlatArrayIO f arr) + ---------------------------------------- -- | Read a flat array from a raw binary file. The size of the file determines the length of the array. @@ -214,3 +223,20 @@ unpackFlatArrayToListIO (MkFlatArray len fptr) = do makeFlat src -------------------------------------------------------------------------------- + +instance Flat a => Vector (FlatArray a) where + + type VecElem (FlatArray a) = a + + vecLength = flatArrayLength + vecPeek = peekFlatArray + + vecToList = unpackFlatArrayToList + vecFromList = packFlatArrayFromList + + vecToArray = unpackFlatArray + vecFromArray = packFlatArray + + vecAppend = error "Vector/vecAppend/FlatArray: not implemented" + vecConcat = error "Vector/vecAConcat/FlatArray: not implemented" + vecMap = mapFlatArray diff --git a/reference/src/Field/Class.hs b/reference/src/Field/Class.hs deleted file mode 100644 index bade3a7..0000000 --- a/reference/src/Field/Class.hs +++ /dev/null @@ -1,77 +0,0 @@ - -{-# LANGUAGE TypeFamilies #-} -module Field.Class where - --------------------------------------------------------------------------------- - -import Data.Kind -import Data.Proxy - -import System.Random - -import qualified Field.Goldilocks as Goldi -import qualified Field.Goldilocks.Extension as GoldiExt - --------------------------------------------------------------------------------- - -class (Show a, Eq a, Num a, Fractional a) => Field a where - fieldSize :: Proxy a -> Integer - zero :: a - one :: a - isZero :: a -> Bool - isOne :: a -> Bool - square :: a -> a - power :: a -> Integer -> a - power_ :: a -> Int -> a - rndIO :: IO a - -inverse :: Field a => a -> a -inverse = recip - --- | Quadratic extensions -class (Field (BaseField ext), Field ext) => QuadraticExt ext where - type BaseField ext :: Type - inject :: BaseField ext -> ext - project :: ext -> Maybe (BaseField ext) - scale :: BaseField ext -> ext -> ext - quadraticPack :: (BaseField ext, BaseField ext) -> ext - quadraticUnpack :: ext -> (BaseField ext, BaseField ext) - --------------------------------------------------------------------------------- - -instance Field Goldi.F where - fieldSize _ = Goldi.goldilocksPrime - zero = Goldi.zero - one = Goldi.one - isZero = Goldi.isZero - isOne = Goldi.isOne - square = Goldi.sqr - power = Goldi.pow - power_ = Goldi.pow_ - rndIO = randomIO - --------------------------------------------------------------------------------- - -instance Field GoldiExt.FExt where - fieldSize _ = (Goldi.goldilocksPrime ^ 2) - zero = GoldiExt.zero - one = GoldiExt.one - isZero = GoldiExt.isZero - isOne = GoldiExt.isOne - square = GoldiExt.sqr - power = GoldiExt.pow - power_ = GoldiExt.pow_ - rndIO = randomIO - --------------------------------------------------------------------------------- - -instance QuadraticExt GoldiExt.FExt where - type BaseField GoldiExt.FExt = Goldi.F - - inject = GoldiExt.inj - project = GoldiExt.proj - scale = GoldiExt.scl - quadraticPack = GoldiExt.pack - quadraticUnpack = GoldiExt.unpack - --------------------------------------------------------------------------------- diff --git a/reference/src/Field/Goldilocks/Extension/BindC.hs b/reference/src/Field/Goldilocks/Extension/BindC.hs index 88f7ab8..08c07b8 100644 --- a/reference/src/Field/Goldilocks/Extension/BindC.hs +++ b/reference/src/Field/Goldilocks/Extension/BindC.hs @@ -5,7 +5,7 @@ -- with the fast Goldilocks base field operations, but the C versions should be useful -- for the vector operations, and this way we can test them easily. -{-# LANGUAGE ForeignFunctionInterface, BangPatterns, NumericUnderscores #-} +{-# LANGUAGE ForeignFunctionInterface, BangPatterns, NumericUnderscores, TypeFamilies #-} module Field.Goldilocks.Extension.BindC where -------------------------------------------------------------------------------- @@ -31,10 +31,11 @@ import Data.Binary.Put ( putWord64le ) import Text.Printf -import Field.Goldilocks ( F , Goldilocks(..) ) -import qualified Field.Goldilocks as Goldi +import Field.Goldilocks.Fast ( F , Goldilocks(..) ) +import qualified Field.Goldilocks.Fast as Goldi import Data.Flat +import Class.Field -------------------------------------------------------------------------------- @@ -103,6 +104,29 @@ instance Random F2 where in (F2 x y, g'') randomR = error "randomR/F2: doesn't make any sense" +instance Ring FExt where + zero = Field.Goldilocks.Extension.BindC.zero + one = Field.Goldilocks.Extension.BindC.one + isZero = Field.Goldilocks.Extension.BindC.isZero + isOne = Field.Goldilocks.Extension.BindC.isOne + square = Field.Goldilocks.Extension.BindC.sqr + power = Field.Goldilocks.Extension.BindC.pow + power_ = Field.Goldilocks.Extension.BindC.pow_ + +instance Field FExt + +instance FiniteField FExt where + fieldSize _ = (Goldi.goldilocksPrime ^ 2) + rndIO = randomIO + +instance QuadraticExt FExt where + type BaseField FExt = Goldi.F + inject = Field.Goldilocks.Extension.BindC.inj + project = Field.Goldilocks.Extension.BindC.proj + scale = Field.Goldilocks.Extension.BindC.scl + quadraticPack = Field.Goldilocks.Extension.BindC.pack + quadraticUnpack = Field.Goldilocks.Extension.BindC.unpack + -------------------------------------------------------------------------------- zero, one, two :: F2 diff --git a/reference/src/Field/Goldilocks/Extension/Haskell.hs b/reference/src/Field/Goldilocks/Extension/Haskell.hs index b182b18..99e0e6f 100644 --- a/reference/src/Field/Goldilocks/Extension/Haskell.hs +++ b/reference/src/Field/Goldilocks/Extension/Haskell.hs @@ -3,6 +3,7 @@ -- -- We use the irreducble polynomial @x^2 - 7@ to be compatible with Plonky3 +{-# LANGUAGE TypeFamilies #-} module Field.Goldilocks.Extension.Haskell where -------------------------------------------------------------------------------- @@ -20,10 +21,11 @@ import Foreign.Marshal import Data.Binary -import Data.Flat +import Field.Goldilocks.Slow ( F ) +import qualified Field.Goldilocks.Slow as Goldi -import Field.Goldilocks ( F ) -import qualified Field.Goldilocks as Goldi +import Data.Flat +import Class.Field -------------------------------------------------------------------------------- @@ -92,6 +94,29 @@ instance Random F2 where in (F2 x y, g'') randomR = error "randomR/F2: doesn't make any sense" +instance Ring FExt where + zero = Field.Goldilocks.Extension.Haskell.zero + one = Field.Goldilocks.Extension.Haskell.one + isZero = Field.Goldilocks.Extension.Haskell.isZero + isOne = Field.Goldilocks.Extension.Haskell.isOne + square = Field.Goldilocks.Extension.Haskell.sqr + power = Field.Goldilocks.Extension.Haskell.pow + power_ = Field.Goldilocks.Extension.Haskell.pow_ + +instance Field FExt + +instance FiniteField FExt where + fieldSize _ = (Goldi.goldilocksPrime ^ 2) + rndIO = randomIO + +instance QuadraticExt FExt where + type BaseField FExt = Goldi.F + inject = Field.Goldilocks.Extension.Haskell.inj + project = Field.Goldilocks.Extension.Haskell.proj + scale = Field.Goldilocks.Extension.Haskell.scl + quadraticPack = Field.Goldilocks.Extension.Haskell.pack + quadraticUnpack = Field.Goldilocks.Extension.Haskell.unpack + -------------------------------------------------------------------------------- zero, one, two :: F2 diff --git a/reference/src/Field/Goldilocks/Fast.hs b/reference/src/Field/Goldilocks/Fast.hs index e04b80a..7cb2578 100644 --- a/reference/src/Field/Goldilocks/Fast.hs +++ b/reference/src/Field/Goldilocks/Fast.hs @@ -27,6 +27,7 @@ import Data.Binary.Put ( putWord64le ) import Text.Printf import Data.Flat +import Class.Field -------------------------------------------------------------------------------- @@ -99,6 +100,21 @@ instance Random Goldilocks where random g = let (x,g') = randomR (0,goldilocksPrimeWord64-1) g in (MkGoldilocks x, g') randomR = error "randomR/Goldilocks: doesn't make much sense" +instance Ring F where + zero = Field.Goldilocks.Fast.zero + one = Field.Goldilocks.Fast.one + isZero = Field.Goldilocks.Fast.isZero + isOne = Field.Goldilocks.Fast.isOne + square = Field.Goldilocks.Fast.sqr + power = Field.Goldilocks.Fast.pow + power_ = Field.Goldilocks.Fast.pow_ + +instance Field F + +instance FiniteField F where + fieldSize _ = Field.Goldilocks.Fast.goldilocksPrime + rndIO = randomIO + -------------------------------------------------------------------------------- -- | @p = 2^64 - 2^32 + 1@ diff --git a/reference/src/Field/Goldilocks/Slow.hs b/reference/src/Field/Goldilocks/Slow.hs index ef400a6..556791d 100644 --- a/reference/src/Field/Goldilocks/Slow.hs +++ b/reference/src/Field/Goldilocks/Slow.hs @@ -26,6 +26,7 @@ import Data.Binary.Put ( putWord64le ) import Text.Printf import Data.Flat +import Class.Field -------------------------------------------------------------------------------- @@ -94,9 +95,24 @@ instance Fractional Goldilocks where instance Random Goldilocks where -- random :: RandomGen g => g -> (a, g) - random g = let (x,g') = randomR (0,goldilocksPrime-1) g in (MkGoldilocks x, g') + random g = let (x,g') = randomR (0,goldilocksPrime-1) g in (MkGoldilocks x, g') randomR = error "randomR/Goldilocks: doesn't make much sense" +instance Ring F where + zero = Field.Goldilocks.Slow.zero + one = Field.Goldilocks.Slow.one + isZero = Field.Goldilocks.Slow.isZero + isOne = Field.Goldilocks.Slow.isOne + square = Field.Goldilocks.Slow.sqr + power = Field.Goldilocks.Slow.pow + power_ = Field.Goldilocks.Slow.pow_ + +instance Field F + +instance FiniteField F where + fieldSize _ = Field.Goldilocks.Slow.goldilocksPrime + rndIO = randomIO + -------------------------------------------------------------------------------- -- | @p = 2^64 - 2^32 + 1@ @@ -153,4 +169,3 @@ pow x e _ -> go (acc*s) (sqr s) (shiftR expo 1) -------------------------------------------------------------------------------- - diff --git a/reference/src/Field/Properties.hs b/reference/src/Field/Properties.hs index 972b9e2..f7188e1 100644 --- a/reference/src/Field/Properties.hs +++ b/reference/src/Field/Properties.hs @@ -14,7 +14,7 @@ import Control.Monad import System.IO import System.Random -import Field.Class +import Class.Field -------------------------------------------------------------------------------- -- compatibility hacks diff --git a/reference/src/Field/Tests.hs b/reference/src/Field/Tests.hs index 423ab8a..9f3607d 100644 --- a/reference/src/Field/Tests.hs +++ b/reference/src/Field/Tests.hs @@ -8,7 +8,7 @@ import Control.Monad import Data.Proxy import Data.IORef -import Field.Class +import Class.Field import Field.Properties import Field.Goldilocks ( F ) diff --git a/reference/src/Misc.hs b/reference/src/Misc.hs index 5cf62dd..1f09e52 100644 --- a/reference/src/Misc.hs +++ b/reference/src/Misc.hs @@ -214,6 +214,24 @@ safeIndex def arr j where (a,b) = bounds arr +takeArray :: Int -> Array Int a -> Array Int a +takeArray k arr = listArray (0,k-1) $ take k (elems arr) + +dropArray :: Int -> Array Int a -> Array Int a +dropArray k arr = listArray (0,n1-k) [ arr!(k+j) | j<-[0..n1-k] ] where + (0,n1) = bounds arr + +appendArrays :: Array Int a -> Array Int a -> Array Int a +appendArrays arr1 arr2 = listArray (0,n1+n2+1) $ (elems arr1 ++ elems arr2) where + (0,n1) = bounds arr1 + (0,n2) = bounds arr2 + +concatArrays :: [Array Int a] -> Array Int a +concatArrays [] = error "concatArrays: empty list" +concatArrays arrs = listArray (0,nn-1) (concatMap elems arrs) where + sizes = map arraySize arrs + nn = sum sizes + interleaveArrays' :: Array Int (Array Int a) -> Array Int a interleaveArrays' arrs | nubOrd (elems sizes) == [n] = big diff --git a/reference/src/NTT/Class.hs b/reference/src/NTT/Class.hs deleted file mode 100644 index 6541863..0000000 --- a/reference/src/NTT/Class.hs +++ /dev/null @@ -1,22 +0,0 @@ - -module NTT.Class where - --------------------------------------------------------------------------------- - -import Data.Kind - -import NTT.Subgroup - --------------------------------------------------------------------------------- - -{- -class NTT ntt where - type Field ntt :: Type - type Poly ntt :: Type - - forwardNTT :: Subgroup F -> Poly F -> FlatArray F - inverseNTT :: Subgroup F -> FlatArray F -> Poly F - shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F - shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F - asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F --} diff --git a/reference/src/NTT/FFT/Fast.hs b/reference/src/NTT/FFT/Fast.hs index 19a13a4..07aee54 100644 --- a/reference/src/NTT/FFT/Fast.hs +++ b/reference/src/NTT/FFT/Fast.hs @@ -48,7 +48,7 @@ forwardNTT :: Subgroup F -> Poly F -> FlatArray F forwardNTT sg (MkPoly (MkFlatArray n fptr2)) | subgroupSize sg /= n = error "forwardNTT: subgroup size differs from the array size" | otherwise = unsafePerformIO $ do - fptr3 <- mallocForeignPtrArray n -- (n*4) + fptr3 <- mallocForeignPtrArray n let MkGoldilocks gen = subgroupGen sg withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do @@ -60,7 +60,7 @@ inverseNTT :: Subgroup F -> FlatArray F -> Poly F inverseNTT sg (MkFlatArray n fptr2) | subgroupSize sg /= n = error "inverseNTT: subgroup size differs from the array size" | otherwise = unsafePerformIO $ do - fptr3 <- mallocForeignPtrArray n --(n*4) + fptr3 <- mallocForeignPtrArray n let MkGoldilocks gen = subgroupGen sg withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do @@ -73,7 +73,7 @@ shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F shiftedForwardNTT sg (MkGoldilocks eta) (MkPoly (MkFlatArray n fptr2)) | subgroupSize sg /= n = error "shiftedForwardNTT: subgroup size differs from the array size" | otherwise = unsafePerformIO $ do - fptr3 <- mallocForeignPtrArray n -- (n*4) + fptr3 <- mallocForeignPtrArray n let MkGoldilocks gen = subgroupGen sg withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do @@ -86,7 +86,7 @@ shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F shiftedInverseNTT sg (MkGoldilocks eta) (MkFlatArray n fptr2) | subgroupSize sg /= n = error "shiftedInverseNTT: subgroup size differs from the array size" | otherwise = unsafePerformIO $ do - fptr3 <- mallocForeignPtrArray n -- (n*4) + fptr3 <- mallocForeignPtrArray n withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do c_ntt_inverse_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3 @@ -99,7 +99,7 @@ asymmForwardNTT sg_src (MkPoly (MkFlatArray n fptr2)) sg_tgt | subgroupSize sg_src /= n = error "asymmForwardNTT: subgroup size differs from the array size" | m < n = error "asymmForwardNTT: target subgroup size should be at least the source subgroup src" | otherwise = unsafePerformIO $ do - fptr3 <- mallocForeignPtrArray m -- (m*4) + fptr3 <- mallocForeignPtrArray m let MkGoldilocks sgen1 = subgroupGen sg_src let MkGoldilocks sgen2 = subgroupGen sg_tgt withForeignPtr fptr2 $ \ptr2 -> do diff --git a/reference/src/NTT/FFT/Slow.hs b/reference/src/NTT/FFT/Slow.hs index 0c41c76..3de633a 100644 --- a/reference/src/NTT/FFT/Slow.hs +++ b/reference/src/NTT/FFT/Slow.hs @@ -25,6 +25,15 @@ polyEvaluate = subgroupNTT polyInterpolate :: Subgroup F -> Array Int F -> Poly F polyInterpolate = subgroupINTT +-------------------------------------------------------------------------------- +-- compatible names + +forwardNTT :: Subgroup F -> Poly F -> Array Int F +forwardNTT = subgroupNTT + +inverseNTT :: Subgroup F -> Array Int F -> Poly F +inverseNTT = subgroupINTT + -------------------------------------------------------------------------------- -- | Evaluates a polynomial on a subgroup /of the same size/ diff --git a/reference/src/NTT/Poly/Flat.hs b/reference/src/NTT/Poly/Flat.hs index b1179ee..82e3c4c 100644 --- a/reference/src/NTT/Poly/Flat.hs +++ b/reference/src/NTT/Poly/Flat.hs @@ -20,8 +20,8 @@ import Foreign.ForeignPtr import System.Random import System.IO.Unsafe +import Class.Field import NTT.Subgroup -import Field.Class import Data.Flat as L -------------------------------------------------------------------------------- diff --git a/reference/src/NTT/Poly/Naive.hs b/reference/src/NTT/Poly/Naive.hs index 665d8c6..0e5ada5 100644 --- a/reference/src/NTT/Poly/Naive.hs +++ b/reference/src/NTT/Poly/Naive.hs @@ -32,6 +32,9 @@ newtype Poly a = Poly (Array Int a) deriving (Show,Functor) +fromPoly :: Poly a -> Array Int a +fromPoly (Poly arr) = arr + instance Binary a => Binary (Poly a) where put (Poly arr) = putSmallArray arr get = Poly <$> getSmallArray diff --git a/reference/src/NTT/Tests.hs b/reference/src/NTT/Tests.hs index de80263..6c75e19 100644 --- a/reference/src/NTT/Tests.hs +++ b/reference/src/NTT/Tests.hs @@ -7,7 +7,7 @@ import Data.Array import Control.Monad -import Field.Class +import Class.Field import Field.Goldilocks ( F ) import Misc