2024-09-24 13:19:16 +02:00

88 lines
2.7 KiB
Haskell

-- | Generate test cases for Nim
module TestGen where
--------------------------------------------------------------------------------
import System.IO
import Goldilocks
--------------------------------------------------------------------------------
-- | All integers at distance at most @width@ from @center@, in ascending
-- order: @[center - width .. center + width]@. The result is empty when
-- @width@ is negative.
centered :: Integer -> Integer -> [Integer]
centered c w = enumFromTo lo hi
  where
    lo = c - w
    hi = c + w
-- | Field elements clustered around the centers 0, 2^16, 2^31, 2^32, 2^33,
-- 2^48 and 2^63. Each cluster is @centered c width@, i.e. @2*width+1@ points
-- for a non-negative @width@. Clusters appear in the order the centers are
-- listed.
--
-- The integers are converted with 'fromInteger'. The cluster around 0
-- includes negative integers, so this relies on F's 'fromInteger' to map
-- them into the field. NOTE(review): presumably reduction modulo the
-- Goldilocks prime; confirm in "Goldilocks".
mkTestFieldElems :: Integer -> [F]
mkTestFieldElems width =
  [ fromInteger n | c <- centers, n <- centered c width ]
  where
    centers :: [Integer]
    centers = [0, 2^16, 2^31, 2^32, 2^33, 2^48, 2^63]
-- | Inputs for the unary tests: 'mkTestFieldElems' with a half-width of 7,
-- i.e. 15 points around each of the 7 centers.
testFieldElems :: [F]
testFieldElems = mkTestFieldElems halfWidth
  where
    halfWidth = 7
-- | Inputs for the binary tests: every ordered pair drawn from
-- 'mkTestFieldElems' with half-width 3 (49 elements, hence 49*49 pairs).
-- The first component varies slowest.
testFieldPairs :: [(F,F)]
testFieldPairs = (,) <$> elems <*> elems
  where
    elems = mkTestFieldElems 3
--------------------------------------------------------------------------------
-- | Render a field element as a Nim @uint64@ literal by appending the
-- @'u64@ suffix to its 'show' output.
-- NOTE(review): assumes 'show' for F prints a bare decimal integer; confirm
-- in "Goldilocks".
nimShow :: F -> String
nimShow = (++ "'u64") . show
-- | Render a pair of field elements as a Nim tuple constructor,
-- e.g. @( 1'u64 , 2'u64 )@.
nimShowPair :: (F,F) -> String
nimShowPair (x,y) = "( " ++ inner ++ " )"
  where
    inner = nimShow x ++ " , " ++ nimShow y
-- | Render a triple of field elements as a Nim tuple constructor,
-- e.g. @( 1'u64 , 2'u64 , 3'u64 )@.
nimShowTriple :: (F,F,F) -> String
nimShowTriple (x,y,z) =
  concat ["( ", nimShow x, " , ", nimShow y, " , ", nimShow z, " )"]
-- | Render pairs as the body lines of a Nim array literal.
--
-- The first line is prefixed with @" [ "@ and every later line with
-- @" , "@. The closing @" ]"@ is left to the caller.
--
-- For an empty list a lone @" [ "@ line is returned. Previously nothing was
-- returned, so the opening bracket was lost and a caller appending @" ]"@
-- produced malformed Nim. Now the result still forms an (empty) array
-- literal.
showPairs :: [(F,F)] -> [String]
showPairs []  = [" [ "]
showPairs xys = zipWith (++) prefix (map nimShowPair xys)
  where
    prefix = " [ " : repeat " , "
-- | Render triples as the body lines of a Nim array literal.
--
-- The first line is prefixed with @" [ "@ and every later line with
-- @" , "@. The closing @" ]"@ is left to the caller.
--
-- For an empty list a lone @" [ "@ line is returned. Previously nothing was
-- returned, so the opening bracket was lost and a caller appending @" ]"@
-- produced malformed Nim. Now the result still forms an (empty) array
-- literal.
showTriples :: [(F,F,F)] -> [String]
showTriples []   = [" [ "]
showTriples xyzs = zipWith (++) prefix (map nimShowTriple xyzs)
  where
    prefix = " [ " : repeat " , "
----------------------------------------
-- | Emit a Nim @const@ declaration named @varname@ (exported with @*@).
--
-- It holds an array of @(x, f x)@ tuples, one per input in @xs@. The array
-- length in the declared type is @length xs@. The text ends with the
-- closing bracket line followed by an empty line.
unary :: String -> (F -> F) -> [F] -> String
unary varname f xs =
  unlines ([decl] ++ body ++ [" ]", ""])
  where
    decl = concat
      [ "const ", varname, "* : array[", show (length xs)
      , ", tuple[x:uint64, y:uint64]] = "
      ]
    body = showPairs (zip xs (map f xs))
-- | Emit a Nim @const@ declaration named @varname@ (exported with @*@).
--
-- It holds an array of @(x, y, f x y)@ tuples, one per input pair in
-- @xys@. The array length in the declared type is @length xys@. The text
-- ends with the closing bracket line followed by an empty line.
binary :: String -> (F -> F -> F) -> [(F,F)] -> String
binary varname f xys =
  unlines ([decl] ++ body ++ [" ]", ""])
  where
    decl = concat
      [ "const ", varname, "* : array[", show (length xys)
      , ", tuple[x:uint64, y:uint64, z:uint64]] = "
      ]
    body = showTriples (map (\(a, b) -> (a, b, f a b)) xys)
--------------------------------------------------------------------------------
-- | Print the generated Nim test tables to standard output.
printTests :: IO ()
printTests = do
  let out = stdout
  hPrintTests out
-- | Write the Nim test tables to the handle @h@, one after another.
--
-- The tables are, in order: negation over 'testFieldElems', then addition,
-- subtraction and multiplication over 'testFieldPairs'.
hPrintTests :: Handle -> IO ()
hPrintTests h = hPutStrLn h (unlines tables)
  where
    tables =
      [ unary  "testcases_neg" negate testFieldElems
      , binary "testcases_add" (+)    testFieldPairs
      , binary "testcases_sub" (-)    testFieldPairs
      , binary "testcases_mul" (*)    testFieldPairs
      ]
-- | Write the generated tests to @fieldTestCases.nim@ (a path relative to
-- the current working directory).
--
-- A @# generated by@ banner comes first. The file is opened in 'WriteMode',
-- which truncates any existing file, and 'withFile' closes it afterwards.
writeTests :: IO ()
writeTests = withFile outPath WriteMode emit
  where
    outPath = "fieldTestCases.nim"
    emit h = do
      hPutStrLn h "# generated by TestGen.hs\n"
      -- hPutStrLn h "import poseidon2/types\n"
      hPrintTests h
--------------------------------------------------------------------------------