2025-10-06 00:57:05 +02:00

104 lines
3.1 KiB
Haskell

-- | Monadic interface for computing Fiat-Shamir challenges with the duplex construction
{-# LANGUAGE StrictData, GeneralizedNewtypeDeriving #-}
module Hash.Duplex.Monad where
--------------------------------------------------------------------------------
import Data.Array
import Control.Monad
import Control.Monad.Identity
import qualified Control.Monad.State.Strict as S
import Control.Monad.IO.Class
import Text.Show.Pretty
import Field.Goldilocks
import Hash.Common
import Hash.Duplex.Pure ( DuplexState, Squeeze, Absorb , theHashFunction )
import qualified Hash.Duplex.Pure as Pure
--------------------------------------------------------------------------------
-- * Monadic interface
-- | Duplex computations layered over a base monad @m@.
--
-- Internally this is a strict 'S.StateT' threading the 'DuplexState'
-- through every 'absorb' / 'squeeze'.
newtype DuplexT m a = DuplexT (S.StateT DuplexState m a)
  deriving (Functor, Applicative, Monad)

-- | Duplex computations with no underlying effects.
type Duplex a = DuplexT Identity a
-- | Run a duplex computation in the base monad.
--
-- The starting sponge state is built from the given hash 'State' by
-- 'Pure.duplexInitialState'. Only the result is returned; the final sponge
-- state is discarded.
runDuplexT :: Monad m => DuplexT m a -> State -> m a
runDuplexT (DuplexT body) = S.evalStateT body . Pure.duplexInitialState
-- | Run a pure duplex computation, starting from the initial sponge state
-- derived from the given hash 'State'.
runDuplex :: Duplex a -> State -> a
runDuplex action = runIdentity . runDuplexT action
-- | Absorb a value into the sponge state (via 'Pure.absorb').
--
-- The updated state is forced to weak head normal form right away.
-- With the lazy 'S.modify', a long run of 'absorb' calls would leave a
-- chain of unevaluated 'Pure.absorb' applications inside the state until
-- it is next demanded.
absorb :: (Monad m, Absorb a) => a -> DuplexT m ()
absorb x = DuplexT (S.modify' (Pure.absorb x))
-- | Squeeze a single value out of the sponge, updating the state
-- (via 'Pure.squeeze').
squeeze :: (Monad m, Squeeze a) => DuplexT m a
squeeze = DuplexT (S.state Pure.squeeze)
-- | Squeeze the given number of values out of the sponge
-- (via 'Pure.squeezeN').
squeezeN :: (Monad m, Squeeze a) => Int -> DuplexT m [a]
squeezeN = DuplexT . S.state . Pure.squeezeN
-- | Return the current sponge state without changing it.
--
-- For debugging only.
inspectDuplexState :: Monad m => DuplexT m DuplexState
inspectDuplexState = unsafeGetInnerState
--------------------------------------------------------------------------------
-- * Access to the internal state (so that we can implement grinding)
-- | Read the raw sponge state (needed to implement grinding).
unsafeGetInnerState :: Monad m => DuplexT m DuplexState
unsafeGetInnerState = DuplexT S.get

-- | Overwrite the raw sponge state.
--
-- This bypasses 'absorb' / 'squeeze' entirely, hence "unsafe"; it exists
-- so that grinding can restore or replace the state.
unsafeSetInnerState :: Monad m => DuplexState -> DuplexT m ()
unsafeSetInnerState = DuplexT . S.put
--------------------------------------------------------------------------------
-- * Duplex in IO
-- | Duplex computations running over 'IO' (e.g. for debug printing).
type DuplexIO a = DuplexT IO a

-- | Lift 'IO' actions into any 'DuplexT' whose base monad supports them.
--
-- Instantiating @m = IO@ gives the instance used by 'DuplexIO'; other
-- 'MonadIO' base monads (e.g. transformer stacks over 'IO') work too.
instance MonadIO m => MonadIO (DuplexT m) where
  liftIO = DuplexT . liftIO
-- | Print a line to stdout from inside a 'DuplexIO' computation.
duplexPutStrLn :: String -> DuplexIO ()
duplexPutStrLn = liftIO . putStrLn
-- | 'print' a value from inside a 'DuplexIO' computation.
duplexPrint_ :: Show a => a -> DuplexIO ()
duplexPrint_ = liftIO . print
-- | Print @name = value@ on one line, using 'show' for the value.
duplexPrint :: Show a => String -> a -> DuplexIO ()
duplexPrint name value = liftIO . putStrLn $ concat [name, " = ", show value]
-- | Pretty-print a labelled value: the label followed by a colon and a
-- blank line, then the 'ppShow' rendering, then a trailing blank line.
duplexPPrint :: Show a => String -> a -> DuplexIO ()
duplexPPrint name value =
  liftIO . putStrLn $ concat [name, ":\n\n", ppShow value, "\n"]
-- | Debug helper: print the current sponge state as @state = ...@.
printDuplexState :: DuplexIO ()
printDuplexState = inspectDuplexState >>= duplexPrint "state"
-- | 'runDuplexT' specialised to 'IO'.
runDuplexIO :: DuplexIO a -> State -> IO a
runDuplexIO = runDuplexT

-- | Run a 'DuplexIO' computation starting from the sponge built on
-- @zeroState theHashFunction@.
runDuplexIO_ :: DuplexIO a -> IO a
runDuplexIO_ = flip runDuplexIO initial
  where
    initial = zeroState theHashFunction
--------------------------------------------------------------------------------
-- | Smoke test. For each @k@ from 0 to 19, absorb the field elements
-- @1..k@, squeeze @k@ field elements and print them. Everything runs from
-- the sponge built on @zeroState theHashFunction@.
--
-- NOTE(review): the 'Int' argument is ignored and the loop bound 19 is
-- hard-coded; confirm whether the argument was meant to set the range.
duplexTest :: Int -> IO ()
duplexTest _ignored = runDuplexIO_ (mapM_ step [0 .. 19])
  where
    step :: Int -> DuplexIO ()
    step k = do
      absorb (map intToF [1 .. k])
      outputs <- squeezeN k :: DuplexIO [F]
      duplexPrint_ outputs
--------------------------------------------------------------------------------