-- Mirror of https://github.com/logos-storage/rust-poseidon-bn254-pure.git
-- (synced 2026-05-22 12:29:35 +00:00; 146 lines, 3.8 KiB, Haskell)
|
|
module Reference.Poseidon2 where
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
import Data.List
|
|
import Control.Monad
|
|
|
|
import Field
|
|
|
|
import HorizenLabs
|
|
import qualified HorizenLabs.Old as Old
|
|
import qualified HorizenLabs.New as New
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | Selects which set of constants an instance uses: 'OldParams' picks the
-- values from "HorizenLabs.Old", 'NewParams' those from "HorizenLabs.New".
data ParamSet = OldParams | NewParams
  deriving (Eq, Show)
|
|
|
|
-- | The supported state widths (number of field elements in the state).
data Width
  = T2
  | T3
  | T4
  deriving (Eq, Show)
|
|
|
|
-- | The numeric value of a 'Width': the number of elements in the state.
fromWidth :: Width -> Integer
fromWidth w = case w of
  T2 -> 2
  T3 -> 3
  T4 -> 4
|
|
|
|
-- | A concrete permutation instance: a state width paired with the set of
-- constants to use. Both fields are strict.
data Instance = MkInstance
  { width    :: !Width      -- ^ size of the state
  , paramSet :: !ParamSet   -- ^ which set of constants to look up
  }
  deriving (Eq, Show)
|
|
|
|
-- | Every available parameter set.
allParamSets :: [ParamSet]
allParamSets = [ OldParams, NewParams ]

-- | Every supported state width.
allWidths :: [Width]
allWidths = [ T2, T3, T4 ]

-- | All width / parameter-set combinations, enumerated width-major
-- (for each width, first the old and then the new parameters).
allInstances :: [Instance]
allInstances = [ MkInstance w s | w <- allWidths, s <- allParamSets ]
|
|
|
|
-- | For every instance in 'allInstances', runs 'permute' on the input
-- @[0, 1, .., t-1]@ (where @t@ is the state width), prints the instance and
-- the resulting state, and reports whether the result equals the expected
-- vector returned by 'kat'.
checkPoseidon2 :: IO ()
checkPoseidon2 = mapM_ checkOne allInstances
  where
    checkOne inst@(MkInstance w _) = do
      let result = permute inst (map toF [0 .. fromWidth w - 1])
      putStrLn "---------------------------------------"
      print inst
      mapM_ print result
      putStrLn ("matches the KAT = " ++ show (result == kat inst))
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | The matrix used by the external rounds, as a list of rows of field
-- elements. It depends only on the width of the instance; the parameter
-- set is ignored.
externalMDS :: Instance -> [[F]]
externalMDS (MkInstance w _) = map (map toF) $ case w of
  T2 -> mds_matrix_M2_row_major
  T3 -> mds_matrix_M3_row_major
  T4 -> mds_matrix_M4_row_major
|
|
|
|
-- | The diagonal entries used by the internal rounds (see 'diagMul'),
-- looked up by parameter set and width and converted to field elements.
internalDiag :: Instance -> [F]
internalDiag (MkInstance w ps) = map toF $ case (ps, w) of
  (OldParams, T2) -> Old.diag_T2
  (OldParams, T3) -> Old.diag_T3
  (OldParams, T4) -> Old.diag_T4
  (NewParams, T2) -> New.diag_T2
  (NewParams, T3) -> New.diag_T3
  (NewParams, T4) -> New.diag_T4
|
|
|
|
-- | The expected output vector for the given instance, looked up by
-- parameter set and width and converted to field elements. 'checkPoseidon2'
-- compares the permutation of @[0 .. t-1]@ against this value.
kat :: Instance -> [F]
kat (MkInstance w ps) = map toF $ case (ps, w) of
  (OldParams, T2) -> Old.kat_T2
  (OldParams, T3) -> Old.kat_T3
  (OldParams, T4) -> Old.kat_T4
  (NewParams, T2) -> New.kat_T2
  (NewParams, T3) -> New.kat_T3
  (NewParams, T4) -> New.kat_T4
|
|
|
|
-- | The table of round constants for the given instance (one inner list per
-- round), looked up by parameter set and width and converted to field
-- elements. See 'splitRoundConsts' for how the rows are consumed.
roundConsts :: Instance -> [[F]]
roundConsts (MkInstance w ps) = map (map toF) $ case (ps, w) of
  (OldParams, T2) -> Old.roundConst_T2
  (OldParams, T3) -> Old.roundConst_T3
  (OldParams, T4) -> Old.roundConst_T4
  (NewParams, T2) -> New.roundConst_T2
  (NewParams, T3) -> New.roundConst_T3
  (NewParams, T4) -> New.roundConst_T4
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | The permutation state: a list of field elements.
type State = [F]

-- | A matrix, stored as a list of rows.
type Mtx   = [[F]]

-- | The diagonal entries used by the internal-round linear layer.
type Diag  = [F]

-- | The round constants of one round (a row of the constants table).
type RCs   = [F]

-- | A single round constant.
type RC    = F
|
|
|
|
-- | The S-box: raises its argument to the fifth power, using three
-- multiplications (@x^2@, then @x^4@, then @x * x^4@).
sbox :: F -> F
sbox x =
  let squared = x * x
      fourth  = squared * squared
  in  x * fourth
|
|
|
|
-- | Matrix-vector product. Each inner list of the matrix is treated as one
-- row; the i-th output element is the dot product of the i-th row with the
-- vector. Length mismatches are not checked: 'zipWith' silently truncates
-- to the shorter of the two lists.
mdsMul :: Mtx -> State -> State
mdsMul matrix vector = map dotWithVector matrix
  where
    dotWithVector row = sum (zipWith (*) row vector)
|
|
|
|
-- | The linear layer of the internal rounds: the i-th output element is
-- @sum vector + d_i * x_i@. This is the product with the matrix that has
-- @1 + d_i@ on the diagonal and @1@ everywhere else, computed without
-- materialising that matrix.
diagMul :: Diag -> State -> State
diagMul diag vector = zipWith scaleAndShift diag vector
  where
    total = sum vector
    scaleAndShift d x = total + d * x
|
|
|
|
-- | Splits the table of round constants into the three phases consumed by
-- 'permute':
--
--   * the first 4 rows, used by the initial external rounds;
--   * the rows in between, used by the internal rounds. An internal round
--     only adds a constant to the first state element (see 'internalRound'),
--     so only the first entry of each of these rows is kept;
--   * the last 4 rows, used by the final external rounds.
--
-- The table is expected to contain at least 8 rows. With fewer rows the
-- initial and final parts overlap and there are no internal constants.
-- An empty row among the internal rounds is reported with a descriptive
-- error instead of the uninformative failure of 'head'.
splitRoundConsts :: [RCs] -> ( [RCs] , [RC] , [RCs] )
splitRoundConsts rows = ( initial , internal , final ) where
  -- number of external rounds on each side of the internal rounds
  nExt :: Int
  nExt = 4

  n        = length rows
  initial  = take nExt rows
  final    = drop (n - nExt) rows
  internal = map firstConst $ take (n - 2 * nExt) $ drop nExt rows

  -- total replacement for 'head' with a meaningful error message
  firstConst (c:_) = c
  firstConst []    = error "splitRoundConsts: empty round-constant row for an internal round"
|
|
|
|
-- | One internal (partial) round: adds the round constant to the first state
-- element only, applies 'sbox' to that element, then mixes the whole state
-- with 'diagMul'. The state is expected to be non-empty.
internalRound :: Diag -> State -> RC -> State
internalRound diag old rc =
  let (first : rest) = old
      first'         = sbox (first + rc)
  in  diagMul diag (first' : rest)
|
|
|
|
-- | One external (full) round: adds the matching round constant to every
-- state element, applies 'sbox' to each of them, then multiplies the result
-- by the given matrix using 'mdsMul'.
externalRound :: Mtx -> State -> RCs -> State
externalRound mds old rcs = mdsMul mds (zipWith addThenSbox old rcs)
  where
    addThenSbox x rc = sbox (x + rc)
|
|
|
|
-- | A strict left fold ('foldl'') taking the list before the initial
-- accumulator, so that a partially applied fold can be used as a
-- state-to-state function in a pipeline.
flipFoldl :: (b -> a -> b) -> [a] -> b -> b
flipFoldl step = flip (foldl' step)
|
|
|
|
-- | The full permutation for the given instance. The input state is first
-- multiplied by the external matrix, then goes through the initial external
-- rounds, the internal rounds and the final external rounds, each phase
-- consuming its own part of the round constants (see 'splitRoundConsts').
permute :: Instance -> State -> State
permute which input = afterFinal
  where
    mds  = externalMDS  which
    diag = internalDiag which

    (rcsInitial, rcsInternal, rcsFinal) = splitRoundConsts (roundConsts which)

    -- the state after each successive phase
    mixed         = mdsMul mds input
    afterInitial  = foldl' (externalRound mds ) mixed         rcsInitial
    afterInternal = foldl' (internalRound diag) afterInitial  rcsInternal
    afterFinal    = foldl' (externalRound mds ) afterInternal rcsFinal
|
|
|
|
--------------------------------------------------------------------------------
|