incorporate C monolith hash implementation (TODO: fast hashing of flat arrays of field elements)

This commit is contained in:
Balazs Komuves 2025-10-14 20:21:19 +02:00
parent 58756dd824
commit 41b2f60357
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
17 changed files with 259 additions and 35 deletions

View File

@ -16,6 +16,19 @@ We could significantly improve the speed of the Haskell implementation by bindin
for some of the critical routines: Goldilocks field and extension, hashing, fast Fourier
transform.
### Implementation status
- [x] FRI prover
- [x] FRI verifier
- [ ] proof serialization
- [ ] serious testing of the FRI verifier
- [ ] full outsourcing protocol
- [ ] command line interface
- [x] faster Goldilocks field operations via C FFI
- [ ] quadratic field extension in C too
- [ ] faster hashing via C FFI
- [ ] faster NTT via C FFI
### References
- E. Ben-Sasson, L. Goldberg, S. Kopparty, and S. Saraf: _"DEEP-FRI: Sampling outside the box improves soundness"_ - https://eprint.iacr.org/2019/336

View File

@ -14,6 +14,9 @@ import Data.Ratio
import System.Random
import Foreign.Ptr
import Foreign.Storable
import Data.Binary
import Field.Goldilocks ( F )
@ -57,6 +60,17 @@ instance Random F2 where
in (F2 x y, g'')
randomR = error "randomR/F2: doesn't make any sense"
instance Storable F2 where
peek ptr = do
r <- peek (castPtr ptr)
i <- peek (castPtr ptr `plusPtr` 8)
return (F2 r i)
poke ptr (F2 r i) = do
poke (castPtr ptr) r
poke (castPtr ptr `plusPtr` 8) i
sizeOf _ = 16
alignment _ = 8
--------------------------------------------------------------------------------
zero, one, two :: F2

View File

@ -14,6 +14,8 @@ import Data.Word
import Data.Ratio
import Foreign.C
import Foreign.Ptr
import Foreign.Storable
import System.Random
@ -43,6 +45,12 @@ instance Binary F where
put x = putWord64le (fromF x)
get = toF <$> getWord64le
instance Storable F where
peek ptr = MkGoldilocks <$> peek (castPtr ptr)
poke ptr (MkGoldilocks x) = poke (castPtr ptr) x
sizeOf _ = 8
alignment _ = 8
--------------------------------------------------------------------------------
newtype Goldilocks

View File

@ -34,22 +34,6 @@ hashRate hash = case hash of
--------------------------------------------------------------------------------
type State = Array Int F
listToState' :: Int -> [F] -> State
listToState' n = listArray (0,n-1)
listToState :: Hash -> [F] -> State
listToState hash = listToState' (hashT hash)
zeroState' :: Int -> State
zeroState' n = listToState' n (replicate n 0)
zeroState :: Hash -> State
zeroState hash = zeroState' (hashT hash)
--------------------------------------------------------------------------------
data Digest
= MkDigest !F !F !F !F
deriving (Eq,Show)
@ -64,10 +48,6 @@ instance FieldEncode Digest where
zeroDigest :: Digest
zeroDigest = MkDigest 0 0 0 0
extractDigest :: State -> Digest
extractDigest state = case elems state of
(a:b:c:d:_) -> MkDigest a b c d
listToDigest :: [F] -> Digest
listToDigest [a,b,c,d] = MkDigest a b c d

View File

@ -16,7 +16,9 @@ import Control.Monad.IO.Class
import Text.Show.Pretty
import Field.Goldilocks
import Hash.Common
import Hash.State
import Hash.Duplex.Pure ( DuplexState, Squeeze, Absorb , theHashFunction )
import qualified Hash.Duplex.Pure as Pure

View File

@ -18,7 +18,9 @@ import Data.Array
import Field.Goldilocks ( F )
import Field.Goldilocks.Extension ( FExt , F2(..) )
import Hash.Permutations
import Hash.State
import Hash.Common
--------------------------------------------------------------------------------
@ -37,14 +39,11 @@ data DuplexState
duplexInitialState :: State -> DuplexState
duplexInitialState state = Absorbing state []
overwrite :: [F] -> State -> State
overwrite new old = listToState theHashFunction $ new ++ drop (length new) (elems old)
duplex :: [F] -> State -> State
duplex inp old = permute theHashFunction (overwrite inp old)
extract :: State -> [F]
extract state = reverse $ take rate (elems state) where
extract state = reverse $ take rate (stateToList state) where
rate = 8
freshSqueezing :: State -> DuplexState

View File

@ -33,6 +33,7 @@ import Field.Encode
import Hash.Permutations
import Hash.Common
import Hash.State
import Hash.Sponge
import Misc
@ -262,13 +263,13 @@ reconstructMerkleRoot (MkMerkleProof idx leaf (MkRawMerklePath path) size) = dig
compress :: Hash -> Digest -> Digest -> Digest
compress which (MkDigest a b c d) (MkDigest p q r s) = extractDigest output where
input = listArray (0,11) [ a,b,c,d , p,q,r,s , 0,0,0,0 ]
input = listToState' 12 [ a,b,c,d , p,q,r,s , 0,0,0,0 ]
output = permute which input
keyedCompress :: Hash -> Key -> Digest -> Digest -> Digest
keyedCompress which key (MkDigest a b c d) (MkDigest p q r s) = extractDigest output where
k = fromIntegral key :: F
input = listArray (0,11) [ a,b,c,d , p,q,r,s , k,0,0,0 ]
input = listToState' 12 [ a,b,c,d , p,q,r,s , k,0,0,0 ]
output = permute which input
--------------------------------------------------------------------------------

View File

@ -12,6 +12,8 @@ import Data.Bits
import Data.Word
import Field.Goldilocks
import Hash.State.Naive
import Hash.Monolith.Constants
import Hash.Common

View File

@ -1,16 +1,52 @@
{-# LANGUAGE CPP, ForeignFunctionInterface #-}
module Hash.Permutations where
--------------------------------------------------------------------------------
import Data.Word
import Foreign.Ptr
import Foreign.ForeignPtr
import System.IO.Unsafe
import qualified Hash.Monolith.Permutation as Monolith
import Hash.Common
import Hash.State
--------------------------------------------------------------------------------
#ifdef USE_NAIVE_HASKELL
permute :: Hash -> State -> State
permute hash = case hash of
Monolith -> Monolith.permutation
--------------------------------------------------------------------------------
#else
foreign import ccall unsafe "goldilocks_monolith_permutation" c_monolith_permutation :: Ptr Word64 -> IO ()
foreign import ccall unsafe "goldilocks_monolith_permutation_into" c_monolith_permutation_into :: Ptr Word64 -> Ptr Word64 -> IO ()
permuteInPlace :: State -> IO ()
permuteInPlace fptr = withForeignPtr fptr $ \ptr -> c_monolith_permutation (castPtr ptr)
{-# NOINLINE permuteIO #-}
permuteIO :: State -> IO State
permuteIO src = do
tgt <- mallocForeignPtrArray 12
withForeignPtr src $ \ptr1 ->
withForeignPtr tgt $ \ptr2 ->
c_monolith_permutation_into (castPtr ptr1) (castPtr ptr2)
return tgt
permute :: Hash -> State -> State
permute _ what = unsafePerformIO (permuteIO what)
--------------------------------------------------------------------------------
#endif

View File

@ -31,7 +31,9 @@ import Data.Word
import Data.List
import Field.Goldilocks
import Hash.Permutations
import Hash.State
import Hash.Common
--------------------------------------------------------------------------------
@ -69,7 +71,7 @@ hashFieldElems' which rate@(Rate r) fels
--
internalSponge :: Hash -> Int -> Rate -> [[F]] -> Digest
internalSponge which nbits (Rate r) blocks = extractDigest (loop blocks iv) where
iv = listArray (0,11) $ [ 0,0,0,0 , 0,0,0,0 , domSep,0,0,0 ] :: State
iv = listToState' 12 $ [ 0,0,0,0 , 0,0,0,0 , domSep,0,0,0 ] :: State
domSep = fromIntegral (65536*nbits + 256*t + r) :: F
t = 12
@ -81,9 +83,6 @@ internalSponge which nbits (Rate r) blocks = extractDigest (loop blocks iv) wher
(this:rest) -> loop rest (step this state)
[] -> state
addToState :: [F] -> State -> State
addToState xs arr = listArray (0,11) $ zipWith (+) (xs ++ repeat 0) (elems arr)
--------------------------------------------------------------------------------
hashBytes :: Hash -> [Word8] -> Digest

View File

@ -0,0 +1,14 @@
{-# LANGUAGE CPP #-}
#ifdef USE_NAIVE_HASKELL
module Hash.State ( module Hash.State.Naive ) where
import Hash.State.Naive
#else
module Hash.State ( module Hash.State.FastC ) where
import Hash.State.FastC
#endif

View File

@ -0,0 +1,97 @@
module Hash.State.FastC where
--------------------------------------------------------------------------------
import Data.Bits
import Data.Word
import Control.Monad
import System.IO.Unsafe
import Foreign.C
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Marshal
import Foreign.Storable
import Field.Goldilocks
import Hash.Common
--------------------------------------------------------------------------------
type State = ForeignPtr F
{-# NOINLINE listToStateIO #-}
listToStateIO :: Int -> [F] -> IO State
listToStateIO n xs = do
fptr <- mallocForeignPtrArray n :: IO (ForeignPtr F)
withForeignPtr fptr $ \ptr -> pokeArray ptr xs
return fptr
listToState' :: Int -> [F] -> State
listToState' n xs = unsafePerformIO (listToStateIO n xs)
listToState :: Hash -> [F] -> State
listToState hash = listToState' (hashT hash)
zeroState' :: Int -> State
zeroState' n = listToState' n (replicate n 0)
zeroState :: Hash -> State
zeroState hash = zeroState' (hashT hash)
--------------------------------------------------------------------------------
{-# NOINLINE stateToListIO #-}
stateToListIO :: State -> IO [F]
stateToListIO fptr = do
withForeignPtr fptr $ \ptr -> do
peekArray 12 ptr
stateToList :: State -> [F]
stateToList state = unsafePerformIO (stateToListIO state)
{-# NOINLINE extractDigestIO #-}
extractDigestIO :: State -> IO Digest
extractDigestIO fptr =
withForeignPtr fptr $ \ptr -> do
a <- peek (ptr )
b <- peek (ptr `plusPtr` 8)
c <- peek (ptr `plusPtr` 16)
d <- peek (ptr `plusPtr` 24)
return (MkDigest a b c d)
extractDigest :: State -> Digest
extractDigest state = unsafePerformIO (extractDigestIO state)
{-# NOINLINE overwriteIO #-}
overwriteIO :: [F] -> State -> IO State
overwriteIO xs src = do
tgt <- mallocForeignPtrArray 12
withForeignPtr src $ \ptr1 -> do
withForeignPtr tgt $ \ptr2 -> do
copyArray ptr2 ptr1 12
pokeArray ptr2 xs
return tgt
overwrite :: [F] -> State -> State
overwrite new old = unsafePerformIO (overwriteIO new old)
{-# NOINLINE addToStateIO #-}
addToStateIO :: [F] -> State -> IO State
addToStateIO xs src = do
tgt <- mallocForeignPtrArray 12
withForeignPtr src $ \ptr1 -> do
withForeignPtr tgt $ \ptr2 -> do
copyArray ptr2 ptr1 12
forM_ (zip [0..] xs) $ \(i,x) -> do
a <- peekElemOff ptr1 i
pokeElemOff ptr2 i (a + x)
return tgt
addToState :: [F] -> State -> State
addToState new old = unsafePerformIO (addToStateIO new old)
--------------------------------------------------------------------------------

View File

@ -0,0 +1,49 @@
module Hash.State.Naive where
--------------------------------------------------------------------------------
import Data.Array
import Data.Bits
import Data.Word
import Data.Binary
import Field.Goldilocks
import Field.Encode
import Hash.Common
--------------------------------------------------------------------------------
type State = Array Int F
listToState' :: Int -> [F] -> State
listToState' n = listArray (0,n-1)
listToState :: Hash -> [F] -> State
listToState hash = listToState' (hashT hash)
zeroState' :: Int -> State
zeroState' n = listToState' n (replicate n 0)
zeroState :: Hash -> State
zeroState hash = zeroState' (hashT hash)
--------------------------------------------------------------------------------
stateToList :: State -> [F]
stateToList = elems
extractDigest :: State -> Digest
extractDigest state = case elems state of
(a:b:c:d:_) -> MkDigest a b c d
overwrite :: [F] -> State -> State
overwrite new old = listToState' 12 $ new ++ drop (length new) (elems old)
addToState :: [F] -> State -> State
addToState xs arr = listArray (0,11) $ zipWith (+) (xs ++ repeat 0) (elems arr)
--------------------------------------------------------------------------------

View File

@ -1,5 +1,6 @@
#include <assert.h>
#include <string.h>
#include "goldilocks.h"
#include "monolith.h"
@ -141,6 +142,11 @@ void goldilocks_monolith_permutation(uint64_t *state) {
goldilocks_monolith_concrete(state);
}
void goldilocks_monolith_permutation_into(uint64_t *src, uint64_t *tgt) {
memcpy( tgt , src , 12*8 );
goldilocks_monolith_permutation( tgt );
}
//------------------------------------------------------------------------------
// compression function: input is two 4-element vector of field elements,

View File

@ -3,10 +3,11 @@
//------------------------------------------------------------------------------
void goldilocks_monolith_permutation (uint64_t *state);
void goldilocks_monolith_keyed_compress(const uint64_t *x, const uint64_t *y, uint64_t key, uint64_t *out);
void goldilocks_monolith_compress (const uint64_t *x, const uint64_t *y, uint64_t *out);
void goldilocks_monolith_bytes_digest (int rate, int N, const uint8_t *input, uint64_t *hash);
void goldilocks_monolith_felts_digest (int rate, int N, const uint64_t *input, uint64_t *hash);
void goldilocks_monolith_permutation (uint64_t *state);
void goldilocks_monolith_permutation_into(uint64_t *src, uint64_t *tgt);
void goldilocks_monolith_keyed_compress (const uint64_t *x, const uint64_t *y, uint64_t key, uint64_t *out);
void goldilocks_monolith_compress (const uint64_t *x, const uint64_t *y, uint64_t *out);
void goldilocks_monolith_bytes_digest (int rate, int N, const uint8_t *input, uint64_t *hash);
void goldilocks_monolith_felts_digest (int rate, int N, const uint64_t *input, uint64_t *hash);
//------------------------------------------------------------------------------

Binary file not shown.

3
reference/src/runi.sh Executable file
View File

@ -0,0 +1,3 @@
#!/bin/bash
ghci testMain.hs cbits/goldilocks.o cbits/monolith.o