fixed the (trivial) C NTT bug, and started a refactor towards a unified type-class based interface

This commit is contained in:
Balazs Komuves 2025-11-04 22:17:29 +01:00
parent a9cb0a96a6
commit bd888d5b57
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
20 changed files with 452 additions and 117 deletions

View File

@ -0,0 +1,40 @@
{-# LANGUAGE TypeFamilies #-}
module Class.Field where
--------------------------------------------------------------------------------
import Data.Kind
import Data.Proxy
import System.Random
--------------------------------------------------------------------------------
class (Show a, Eq a, Num a) => Ring a where
zero :: a
one :: a
isZero :: a -> Bool
isOne :: a -> Bool
square :: a -> a
power :: a -> Integer -> a
power_ :: a -> Int -> a
class (Ring a, Fractional a) => Field a where
inverse :: a -> a
inverse = recip
class Field a => FiniteField a where
rndIO :: IO a
fieldSize :: Proxy a -> Integer
-- | Quadratic extensions
class (Field (BaseField ext), Field ext) => QuadraticExt ext where
type BaseField ext :: Type
inject :: BaseField ext -> ext
project :: ext -> Maybe (BaseField ext)
scale :: BaseField ext -> ext -> ext
quadraticPack :: (BaseField ext, BaseField ext) -> ext
quadraticUnpack :: ext -> (BaseField ext, BaseField ext)
--------------------------------------------------------------------------------

View File

@ -0,0 +1,7 @@
module Class.Flat
( module Data.Flat.Class
)
where
import Data.Flat.Class

View File

@ -0,0 +1,40 @@
module NTT.Class where
--------------------------------------------------------------------------------
import Data.Kind
import Class.Field
import Class.Poly
import NTT.Subgroup
--------------------------------------------------------------------------------
{-
class NTT ntt where
type Field ntt :: Type
type Poly ntt :: Type
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
-}
{-
-- | Polynomials which over an FFT-friend field support NTT operations
class (Univariate p, FFTField (Coeff p)) => UnivariateFFT p where
-- | Number-theoretical transform (evaluate on a subgroup)
ntt :: FFTSubgroup (Coeff p) -> p -> FlatArray (Coeff p)
-- | Inverse number-theoretical transform (interpolate on a subgroup)
intt :: FFTSubgroup (Coeff p) -> FlatArray (Coeff p) -> p
-- | Shifts @f@ by @eta@, evaluating @f(eta*omega^k)@
shiftedNTT :: FFTSubgroup (Coeff p) -> Coeff p -> p -> FlatArray (Coeff p)
-- | Shifts @f@ by @eta^-1@, interpolating @f@ so that @f(eta^-1 * omega^k) = y_k@
shiftedINTT :: FFTSubgroup (Coeff p) -> Coeff p -> FlatArray (Coeff p) -> p
-- | Evaluate on a larger subgroup than the polynomial is defined on
asymmNTT :: FFTSubgroup (Coeff p) -> p -> FFTSubgroup (Coeff p) -> FlatArray (Coeff p)
-}

View File

@ -0,0 +1,99 @@
-- | Univariate polynomials
module Class.Poly where
--------------------------------------------------------------------------------
import Data.Kind
import Class.Field
--------------------------------------------------------------------------------
-- * Dense univariate polynomials over (finite) fields
class (Ring p, Field (Coeff p)) => Univariate p where
-- | the type of coefficients
type Coeff p :: Type
-- | Degree
degree :: p -> Int
-- | Size (can be larger than degree+1 if the top coefficients are zeor)
polySize :: p -> Int
-- | The k-th coefficient
kthCoeff :: Int -> p -> Coeff p
-- | Evaluation
evalAt :: Coeff p -> p -> Coeff p
-- | Scaling
scale :: Coeff p -> p -> p
-- | Create a polynomial from coefficiens
mkPoly :: [Coeff p] -> p
-- | Coefficients of the polynomial as a list
coeffs :: p -> [Coeff p]
-- | Coefficients as an Array
coeffsArr :: p -> Array Int (Coeff p)
{-
-- | Polynomial long division
polyLongDiv :: p -> p -> (p,p)
-- | Polynomial quotient
polyQuot :: p -> p -> p
-- | Polynomial remainder
polyRem :: p -> p -> p
-- | Divide by the coset vanishing polynomial @(x^n - eta)@
divByVanishing :: p -> (Int, Coeff p) -> (p,p)
-- | Quotient by the coset vanishing polynomial @(x^n - eta)@
quotByVanishing :: p -> (Int, Coeff p) -> Maybe p
-}
--------------------------------------------------------------------------------
-- * Some generic functions
-- | Checks whether the input is the constant one polynomial?
polyIsOne :: Univariate p => p -> Bool
polyIsOne p = case mbConst p of
Nothing -> False
Just x -> ZK.Algebra.Class.Field.isOne x
-- | The constant term of a polynomial
constTermOf :: Univariate p => p -> Coeff p
constTermOf = kthCoeff 0
-- | Is this a constant polynomial?
--
-- TODO: this is not efficient for high-degree polynomials, where we could exit early...
mbConst :: Univariate p => p -> Maybe (Coeff p)
mbConst p = if degree p <= 0 then Just (constTermOf p) else Nothing
-- | Create a constant polynomial
constPoly :: Univariate p => Coeff p -> p
constPoly y = mkPoly [y]
-- | The polynomial @p(x)=x@
idPoly :: Univariate p => p
idPoly = mkPoly [0,1]
-- | Create a linear polynomial.
--
-- > linearPoly a b == a*x + b@
--
linearPoly :: Univariate p => Coeff p -> Coeff p -> p
linearPoly a b = mkPoly [b,a]
showPoly :: Univariate p => p -> String
showPoly = showPoly' True
showPoly' :: Univariate p => Bool -> p -> String
showPoly' newlines_flag poly =
case newlines_flag of
False -> intercalate " +" terms
True -> intercalate " +\n" terms
where
pairs = filter (\kx -> snd kx /= 0)
$ zip [0..] (coeffs poly)
terms = case pairs of
[] -> [" 0"]
_ -> map f pairs
f (0,x) = ' ' : show x
f (1,x) = ' ' : show x ++ " * x"
f (k,x) = ' ' : show x ++ " * x^" ++ show k
--------------------------------------------------------------------------------

View File

@ -0,0 +1,112 @@
{-# LANGUAGE TypeFamilies #-}
module Class.Vector where
--------------------------------------------------------------------------------
import Data.Kind
import Data.Array.IArray
import Class.Field
import Misc
--------------------------------------------------------------------------------
class Vector v where
type VecElem v :: Type
vecLength :: v -> Int
vecPeek :: v -> Int -> VecElem v
vecToList :: v -> [VecElem v]
vecFromList :: [VecElem v] -> v
vecToArray :: v -> Array Int (VecElem v)
vecFromArray :: Array Int (VecElem v) -> v
vecTake :: Int -> v -> v
vecDrop :: Int -> v -> v
vecAppend :: v -> v -> v
vecConcat :: [v] -> v
vecMap :: (VecElem v -> VecElem v) -> v -> v
vecSize :: Vector v => v -> Int
vecSize = vecLength
--------------------------------------------------------------------------------
instance Vector [a] where
type VecElem [a] = a
vecLength = length
vecPeek = (!!)
vecToList = id
vecFromList = id
vecToArray = listToArray
vecFromArray = elems
vecTake = take
vecDrop = drop
vecAppend = (++)
vecConcat = concat
vecMap = map
instance Vector (Array Int a) where
type VecElem (Array Int a) = a
vecLength = \arr -> let (0,n1) = bounds arr in n1+1
vecPeek = (!)
vecToList = elems
vecFromList = listToArray
vecToArray = id
vecFromArray = id
vecTake = takeArray
vecDrop = dropArray
vecAppend = appendArrays
vecConcat = concatArrays
vecMap = amap
--------------------------------------------------------------------------------
-- * Pointwise operations
class PointwiseGroup a where
-- | Pointwise negation
pwNeg :: a -> a
-- | Pointwise addition
pwAdd :: a -> a -> a
-- | Pointwise subtraction
pwSub :: a -> a -> a
class PointwiseGroup a => PointwiseRing a where
-- | Pointwise squaring
pwSqr :: a -> a
-- | Pointwise multiplication
pwMul :: a -> a -> a
-- | Pointwise @a*b+c@
pwMulAdd :: a -> a -> a -> a
-- | Pointwise @a*b-c@
pwMulSub :: a -> a -> a -> a
class PointwiseRing a => PointwiseField a where
-- | Pointwise inversion
pwInv :: a -> a
-- | Pointwise division
pwDiv :: a -> a -> a
--------------------------------------------------------------------------------
-- | Finite dimensional vector spaces
class (Vector v, Field (VecElem v), PointwiseField v) => VectorSpace v where
-- | Scaling by an element
vecScale :: VecElem v -> v -> v
-- | Dot product
dotProd :: v -> v -> VecElem v
-- | The array @[ a*b^k | k<-[0..n-1] ]@
powers :: VecElem v -> VecElem v -> Int -> v
-- | Pointwise multiplication by the array @[ a*b^k | k<-[0..n-1] ]@
mulByPowers :: VecElem v -> VecElem v -> v -> v
-- | Linear combination @a*x + y@
linComb1 :: (VecElem v, v) -> v -> v
-- | Linear combination @a*x + b*y@
linComb2 :: (VecElem v, v) -> (VecElem v, v) -> v
--------------------------------------------------------------------------------

View File

@ -25,6 +25,7 @@ import System.IO
import System.IO.Unsafe
import Data.Flat.Class
import Class.Vector
import Misc
--------------------------------------------------------------------------------
@ -33,7 +34,7 @@ import Misc
-- foreign memory (not managed by the Haskell runtime).
--
-- Note: the @Int@ means the number of objects in the array.
data FlatArray a
data FlatArray (a :: Type)
= MkFlatArray !Int !(ForeignPtr Word64)
deriving Show
@ -128,6 +129,14 @@ dropFlatArrayIO k (MkFlatArray n fptr1) = do
when (m>0) $ copyBytes ptr2 src (8*sz*m)
return (MkFlatArray m fptr2)
{-# NOINLINE mapFlatArrayIO #-}
mapFlatArrayIO :: (a -> a) -> FlatArray a -> IO (FlatArray a)
mapFlatArrayIO = error "mapFlatArrayIO: not yet implemented"
{-# NOINLINE mapFlatArray #-}
mapFlatArray :: (a -> a) -> FlatArray a -> FlatArray a
mapFlatArray f arr = unsafePerformIO (mapFlatArrayIO f arr)
----------------------------------------
-- | Read a flat array from a raw binary file. The size of the file determines the length of the array.
@ -214,3 +223,20 @@ unpackFlatArrayToListIO (MkFlatArray len fptr) = do
makeFlat src
--------------------------------------------------------------------------------
instance Flat a => Vector (FlatArray a) where
type VecElem (FlatArray a) = a
vecLength = flatArrayLength
vecPeek = peekFlatArray
vecToList = unpackFlatArrayToList
vecFromList = packFlatArrayFromList
vecToArray = unpackFlatArray
vecFromArray = packFlatArray
vecAppend = error "Vector/vecAppend/FlatArray: not implemented"
vecConcat = error "Vector/vecAConcat/FlatArray: not implemented"
vecMap = mapFlatArray

View File

@ -1,77 +0,0 @@
{-# LANGUAGE TypeFamilies #-}
module Field.Class where
--------------------------------------------------------------------------------
import Data.Kind
import Data.Proxy
import System.Random
import qualified Field.Goldilocks as Goldi
import qualified Field.Goldilocks.Extension as GoldiExt
--------------------------------------------------------------------------------
class (Show a, Eq a, Num a, Fractional a) => Field a where
fieldSize :: Proxy a -> Integer
zero :: a
one :: a
isZero :: a -> Bool
isOne :: a -> Bool
square :: a -> a
power :: a -> Integer -> a
power_ :: a -> Int -> a
rndIO :: IO a
inverse :: Field a => a -> a
inverse = recip
-- | Quadratic extensions
class (Field (BaseField ext), Field ext) => QuadraticExt ext where
type BaseField ext :: Type
inject :: BaseField ext -> ext
project :: ext -> Maybe (BaseField ext)
scale :: BaseField ext -> ext -> ext
quadraticPack :: (BaseField ext, BaseField ext) -> ext
quadraticUnpack :: ext -> (BaseField ext, BaseField ext)
--------------------------------------------------------------------------------
instance Field Goldi.F where
fieldSize _ = Goldi.goldilocksPrime
zero = Goldi.zero
one = Goldi.one
isZero = Goldi.isZero
isOne = Goldi.isOne
square = Goldi.sqr
power = Goldi.pow
power_ = Goldi.pow_
rndIO = randomIO
--------------------------------------------------------------------------------
instance Field GoldiExt.FExt where
fieldSize _ = (Goldi.goldilocksPrime ^ 2)
zero = GoldiExt.zero
one = GoldiExt.one
isZero = GoldiExt.isZero
isOne = GoldiExt.isOne
square = GoldiExt.sqr
power = GoldiExt.pow
power_ = GoldiExt.pow_
rndIO = randomIO
--------------------------------------------------------------------------------
instance QuadraticExt GoldiExt.FExt where
type BaseField GoldiExt.FExt = Goldi.F
inject = GoldiExt.inj
project = GoldiExt.proj
scale = GoldiExt.scl
quadraticPack = GoldiExt.pack
quadraticUnpack = GoldiExt.unpack
--------------------------------------------------------------------------------

View File

@ -5,7 +5,7 @@
-- with the fast Goldilocks base field operations, but the C versions should be useful
-- for the vector operations, and this way we can test them easily.
{-# LANGUAGE ForeignFunctionInterface, BangPatterns, NumericUnderscores #-}
{-# LANGUAGE ForeignFunctionInterface, BangPatterns, NumericUnderscores, TypeFamilies #-}
module Field.Goldilocks.Extension.BindC where
--------------------------------------------------------------------------------
@ -31,10 +31,11 @@ import Data.Binary.Put ( putWord64le )
import Text.Printf
import Field.Goldilocks ( F , Goldilocks(..) )
import qualified Field.Goldilocks as Goldi
import Field.Goldilocks.Fast ( F , Goldilocks(..) )
import qualified Field.Goldilocks.Fast as Goldi
import Data.Flat
import Class.Field
--------------------------------------------------------------------------------
@ -103,6 +104,29 @@ instance Random F2 where
in (F2 x y, g'')
randomR = error "randomR/F2: doesn't make any sense"
instance Ring FExt where
zero = Field.Goldilocks.Extension.BindC.zero
one = Field.Goldilocks.Extension.BindC.one
isZero = Field.Goldilocks.Extension.BindC.isZero
isOne = Field.Goldilocks.Extension.BindC.isOne
square = Field.Goldilocks.Extension.BindC.sqr
power = Field.Goldilocks.Extension.BindC.pow
power_ = Field.Goldilocks.Extension.BindC.pow_
instance Field FExt
instance FiniteField FExt where
fieldSize _ = (Goldi.goldilocksPrime ^ 2)
rndIO = randomIO
instance QuadraticExt FExt where
type BaseField FExt = Goldi.F
inject = Field.Goldilocks.Extension.BindC.inj
project = Field.Goldilocks.Extension.BindC.proj
scale = Field.Goldilocks.Extension.BindC.scl
quadraticPack = Field.Goldilocks.Extension.BindC.pack
quadraticUnpack = Field.Goldilocks.Extension.BindC.unpack
--------------------------------------------------------------------------------
zero, one, two :: F2

View File

@ -3,6 +3,7 @@
--
-- We use the irreducble polynomial @x^2 - 7@ to be compatible with Plonky3
{-# LANGUAGE TypeFamilies #-}
module Field.Goldilocks.Extension.Haskell where
--------------------------------------------------------------------------------
@ -20,10 +21,11 @@ import Foreign.Marshal
import Data.Binary
import Data.Flat
import Field.Goldilocks.Slow ( F )
import qualified Field.Goldilocks.Slow as Goldi
import Field.Goldilocks ( F )
import qualified Field.Goldilocks as Goldi
import Data.Flat
import Class.Field
--------------------------------------------------------------------------------
@ -92,6 +94,29 @@ instance Random F2 where
in (F2 x y, g'')
randomR = error "randomR/F2: doesn't make any sense"
instance Ring FExt where
zero = Field.Goldilocks.Extension.Haskell.zero
one = Field.Goldilocks.Extension.Haskell.one
isZero = Field.Goldilocks.Extension.Haskell.isZero
isOne = Field.Goldilocks.Extension.Haskell.isOne
square = Field.Goldilocks.Extension.Haskell.sqr
power = Field.Goldilocks.Extension.Haskell.pow
power_ = Field.Goldilocks.Extension.Haskell.pow_
instance Field FExt
instance FiniteField FExt where
fieldSize _ = (Goldi.goldilocksPrime ^ 2)
rndIO = randomIO
instance QuadraticExt FExt where
type BaseField FExt = Goldi.F
inject = Field.Goldilocks.Extension.Haskell.inj
project = Field.Goldilocks.Extension.Haskell.proj
scale = Field.Goldilocks.Extension.Haskell.scl
quadraticPack = Field.Goldilocks.Extension.Haskell.pack
quadraticUnpack = Field.Goldilocks.Extension.Haskell.unpack
--------------------------------------------------------------------------------
zero, one, two :: F2

View File

@ -27,6 +27,7 @@ import Data.Binary.Put ( putWord64le )
import Text.Printf
import Data.Flat
import Class.Field
--------------------------------------------------------------------------------
@ -99,6 +100,21 @@ instance Random Goldilocks where
random g = let (x,g') = randomR (0,goldilocksPrimeWord64-1) g in (MkGoldilocks x, g')
randomR = error "randomR/Goldilocks: doesn't make much sense"
instance Ring F where
zero = Field.Goldilocks.Fast.zero
one = Field.Goldilocks.Fast.one
isZero = Field.Goldilocks.Fast.isZero
isOne = Field.Goldilocks.Fast.isOne
square = Field.Goldilocks.Fast.sqr
power = Field.Goldilocks.Fast.pow
power_ = Field.Goldilocks.Fast.pow_
instance Field F
instance FiniteField F where
fieldSize _ = Field.Goldilocks.Fast.goldilocksPrime
rndIO = randomIO
--------------------------------------------------------------------------------
-- | @p = 2^64 - 2^32 + 1@

View File

@ -26,6 +26,7 @@ import Data.Binary.Put ( putWord64le )
import Text.Printf
import Data.Flat
import Class.Field
--------------------------------------------------------------------------------
@ -94,9 +95,24 @@ instance Fractional Goldilocks where
instance Random Goldilocks where
-- random :: RandomGen g => g -> (a, g)
random g = let (x,g') = randomR (0,goldilocksPrime-1) g in (MkGoldilocks x, g')
random g = let (x,g') = randomR (0,goldilocksPrime-1) g in (MkGoldilocks x, g')
randomR = error "randomR/Goldilocks: doesn't make much sense"
instance Ring F where
zero = Field.Goldilocks.Slow.zero
one = Field.Goldilocks.Slow.one
isZero = Field.Goldilocks.Slow.isZero
isOne = Field.Goldilocks.Slow.isOne
square = Field.Goldilocks.Slow.sqr
power = Field.Goldilocks.Slow.pow
power_ = Field.Goldilocks.Slow.pow_
instance Field F
instance FiniteField F where
fieldSize _ = Field.Goldilocks.Slow.goldilocksPrime
rndIO = randomIO
--------------------------------------------------------------------------------
-- | @p = 2^64 - 2^32 + 1@
@ -153,4 +169,3 @@ pow x e
_ -> go (acc*s) (sqr s) (shiftR expo 1)
--------------------------------------------------------------------------------

View File

@ -14,7 +14,7 @@ import Control.Monad
import System.IO
import System.Random
import Field.Class
import Class.Field
--------------------------------------------------------------------------------
-- compatibility hacks

View File

@ -8,7 +8,7 @@ import Control.Monad
import Data.Proxy
import Data.IORef
import Field.Class
import Class.Field
import Field.Properties
import Field.Goldilocks ( F )

View File

@ -214,6 +214,24 @@ safeIndex def arr j
where
(a,b) = bounds arr
takeArray :: Int -> Array Int a -> Array Int a
takeArray k arr = listArray (0,k-1) $ take k (elems arr)
dropArray :: Int -> Array Int a -> Array Int a
dropArray k arr = listArray (0,n1-k) [ arr!(k+j) | j<-[0..n1-k] ] where
(0,n1) = bounds arr
appendArrays :: Array Int a -> Array Int a -> Array Int a
appendArrays arr1 arr2 = listArray (0,n1+n2+1) $ (elems arr1 ++ elems arr2) where
(0,n1) = bounds arr1
(0,n2) = bounds arr2
concatArrays :: [Array Int a] -> Array Int a
concatArrays [] = error "concatArrays: empty list"
concatArrays arrs = listArray (0,nn-1) (concatMap elems arrs) where
sizes = map arraySize arrs
nn = sum sizes
interleaveArrays' :: Array Int (Array Int a) -> Array Int a
interleaveArrays' arrs
| nubOrd (elems sizes) == [n] = big

View File

@ -1,22 +0,0 @@
module NTT.Class where
--------------------------------------------------------------------------------
import Data.Kind
import NTT.Subgroup
--------------------------------------------------------------------------------
{-
class NTT ntt where
type Field ntt :: Type
type Poly ntt :: Type
forwardNTT :: Subgroup F -> Poly F -> FlatArray F
inverseNTT :: Subgroup F -> FlatArray F -> Poly F
shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
asymmForwardNTT :: Subgroup F -> Poly F -> Subgroup F -> FlatArray F
-}

View File

@ -48,7 +48,7 @@ forwardNTT :: Subgroup F -> Poly F -> FlatArray F
forwardNTT sg (MkPoly (MkFlatArray n fptr2))
| subgroupSize sg /= n = error "forwardNTT: subgroup size differs from the array size"
| otherwise = unsafePerformIO $ do
fptr3 <- mallocForeignPtrArray n -- (n*4)
fptr3 <- mallocForeignPtrArray n
let MkGoldilocks gen = subgroupGen sg
withForeignPtr fptr2 $ \ptr2 -> do
withForeignPtr fptr3 $ \ptr3 -> do
@ -60,7 +60,7 @@ inverseNTT :: Subgroup F -> FlatArray F -> Poly F
inverseNTT sg (MkFlatArray n fptr2)
| subgroupSize sg /= n = error "inverseNTT: subgroup size differs from the array size"
| otherwise = unsafePerformIO $ do
fptr3 <- mallocForeignPtrArray n --(n*4)
fptr3 <- mallocForeignPtrArray n
let MkGoldilocks gen = subgroupGen sg
withForeignPtr fptr2 $ \ptr2 -> do
withForeignPtr fptr3 $ \ptr3 -> do
@ -73,7 +73,7 @@ shiftedForwardNTT :: Subgroup F -> F -> Poly F -> FlatArray F
shiftedForwardNTT sg (MkGoldilocks eta) (MkPoly (MkFlatArray n fptr2))
| subgroupSize sg /= n = error "shiftedForwardNTT: subgroup size differs from the array size"
| otherwise = unsafePerformIO $ do
fptr3 <- mallocForeignPtrArray n -- (n*4)
fptr3 <- mallocForeignPtrArray n
let MkGoldilocks gen = subgroupGen sg
withForeignPtr fptr2 $ \ptr2 -> do
withForeignPtr fptr3 $ \ptr3 -> do
@ -86,7 +86,7 @@ shiftedInverseNTT :: Subgroup F -> F -> FlatArray F -> Poly F
shiftedInverseNTT sg (MkGoldilocks eta) (MkFlatArray n fptr2)
| subgroupSize sg /= n = error "shiftedInverseNTT: subgroup size differs from the array size"
| otherwise = unsafePerformIO $ do
fptr3 <- mallocForeignPtrArray n -- (n*4)
fptr3 <- mallocForeignPtrArray n
withForeignPtr fptr2 $ \ptr2 -> do
withForeignPtr fptr3 $ \ptr3 -> do
c_ntt_inverse_shifted eta (subgroupCLogSize sg) (subgroupGenAsWord64 sg) ptr2 ptr3
@ -99,7 +99,7 @@ asymmForwardNTT sg_src (MkPoly (MkFlatArray n fptr2)) sg_tgt
| subgroupSize sg_src /= n = error "asymmForwardNTT: subgroup size differs from the array size"
| m < n = error "asymmForwardNTT: target subgroup size should be at least the source subgroup src"
| otherwise = unsafePerformIO $ do
fptr3 <- mallocForeignPtrArray m -- (m*4)
fptr3 <- mallocForeignPtrArray m
let MkGoldilocks sgen1 = subgroupGen sg_src
let MkGoldilocks sgen2 = subgroupGen sg_tgt
withForeignPtr fptr2 $ \ptr2 -> do

View File

@ -25,6 +25,15 @@ polyEvaluate = subgroupNTT
polyInterpolate :: Subgroup F -> Array Int F -> Poly F
polyInterpolate = subgroupINTT
--------------------------------------------------------------------------------
-- compatible names
forwardNTT :: Subgroup F -> Poly F -> Array Int F
forwardNTT = subgroupNTT
inverseNTT :: Subgroup F -> Array Int F -> Poly F
inverseNTT = subgroupINTT
--------------------------------------------------------------------------------
-- | Evaluates a polynomial on a subgroup /of the same size/

View File

@ -20,8 +20,8 @@ import Foreign.ForeignPtr
import System.Random
import System.IO.Unsafe
import Class.Field
import NTT.Subgroup
import Field.Class
import Data.Flat as L
--------------------------------------------------------------------------------

View File

@ -32,6 +32,9 @@ newtype Poly a
= Poly (Array Int a)
deriving (Show,Functor)
fromPoly :: Poly a -> Array Int a
fromPoly (Poly arr) = arr
instance Binary a => Binary (Poly a) where
put (Poly arr) = putSmallArray arr
get = Poly <$> getSmallArray

View File

@ -7,7 +7,7 @@ import Data.Array
import Control.Monad
import Field.Class
import Class.Field
import Field.Goldilocks ( F )
import Misc