diff --git a/reference/src/FRI/Prover.hs b/reference/src/FRI/Prover.hs index ae60365..74bef8a 100644 --- a/reference/src/FRI/Prover.hs +++ b/reference/src/FRI/Prover.hs @@ -187,8 +187,8 @@ data CommitPhaseData = MkCommitPhaseData -- * fold the polynomial and also the domain -- -repeatedlyFoldPoly :: FriConfig -> [Arity] -> Coset F -> Poly FExt -> DuplexIO ( [CommitPhaseData] , Poly FExt ) -repeatedlyFoldPoly (MkFriConfig{..}) arities domain poly = go arities domain poly where +repeatedlyFoldPoly :: FriConfig -> ReductionStrategy -> Coset F -> Poly FExt -> DuplexIO ( [CommitPhaseData] , Poly FExt ) +repeatedlyFoldPoly (MkFriConfig{..}) (MkRedStrategy arities) domain poly = go arities domain poly where go [] domain poly = return ( [] , poly ) go (arity:rest) domain poly = do let intArity = exp2_ arity -- size of the folded coset diff --git a/reference/src/FRI/Types.hs b/reference/src/FRI/Types.hs index 7d4f205..93e80b1 100644 --- a/reference/src/FRI/Types.hs +++ b/reference/src/FRI/Types.hs @@ -4,6 +4,10 @@ module FRI.Types where -------------------------------------------------------------------------------- +import Control.Monad +import Data.Binary +import qualified Data.ByteString.Lazy as L + import Field.Goldilocks import Field.Goldilocks.Extension ( FExt , F2(..) ) import Field.Encode @@ -43,6 +47,13 @@ data RSConfig = MkRSConfig } deriving (Eq,Show) +instance Binary RSConfig where + put (MkRSConfig{..}) = do + put rsRateBits + put rsDataSize + put rsCosetShift + get = MkRSConfig <$> get <*> get <*> get + exampleRSConfig :: RSConfig exampleRSConfig = MkRSConfig 8 3 theMultiplicativeGenerator @@ -71,7 +82,17 @@ instance Print RSConfig where -- | Folding arity type Arity = Log2 -type ReductionStrategy = [Arity] +newtype ReductionStrategy = MkRedStrategy + { fromReductionStrategy :: [Arity] + } + deriving (Eq,Show) + +instance Binary ReductionStrategy where + put = putSmallList . fromReductionStrategy + get = MkRedStrategy <$> getSmallList + +instance FieldEncode ReductionStrategy where + fieldEncode = fieldEncode . fromReductionStrategy -- | FRI configuration data FriConfig = MkFriConfig @@ -84,12 +105,22 @@ data FriConfig = MkFriConfig } deriving (Eq,Show) +instance Binary FriConfig where + put (MkFriConfig{..}) = do + put friRSConfig + put friNColumns + put friMerkleCapSize + put friReductionStrategy + put friNQueryRounds + put friGrindingBits + get = MkFriConfig <$> get <*> get <*> get <*> get <*> get <*> get + instance Print FriConfig where showWithIndent indent (MkFriConfig{..}) = [ " - friRSConfig\n" ++ unlines1 (showWithIndent (indent+2) friRSConfig) , " - friNColumns = " ++ show friNColumns , " - friMerkleCapSize = " ++ show friMerkleCapSize - , " - friReductionStrategy = " ++ show (map fromLog2 friReductionStrategy) + , " - friReductionStrategy = " ++ show (map fromLog2 $ fromReductionStrategy friReductionStrategy) , " - friNQueryRounds = " ++ show friNQueryRounds , " - friGrindingBits = " ++ show friGrindingBits ] @@ -133,7 +164,7 @@ defaultReductionStrategyParams = MkRedStratPars } findReductionStrategy :: ReductionStrategyParams -> RSConfig -> ReductionStrategy -findReductionStrategy (MkRedStratPars{..}) (MkRSConfig{..}) = worker (rsDataSize + rsRateBits) where +findReductionStrategy (MkRedStratPars{..}) (MkRSConfig{..}) = MkRedStrategy $ worker (rsDataSize + rsRateBits) where worker k | k <= redStoppingDegree = [] | k >= redStoppingDegree + redFoldingArity = redFoldingArity : worker (k - redFoldingArity) @@ -161,6 +192,43 @@ data FriProof = MkFriProof , proofQueryRounds :: [FriQueryRound] -- ^ query rounds , proofPowWitness :: F -- ^ witness showing that the prover did PoW } - deriving Show + deriving (Eq,Show) + +---------------------------------------- + +friProofSizeInBytes :: FriProof -> Int +friProofSizeInBytes friProof = fromIntegral $ L.length (encode friProof) + +instance Binary FriQueryStep where + put (MkFriQueryStep{..}) = do + putSmallList queryEvals + put queryMerklePath + get = MkFriQueryStep + <$> getSmallList + <*> get + +instance Binary FriQueryRound where + put (MkFriQueryRound{..}) = do + putSmallArray queryRow + put queryInitialTreeProof + putSmallList querySteps + get = MkFriQueryRound + <$> getSmallArray + <*> get + <*> getSmallList + +instance Binary FriProof where + put (MkFriProof{..}) = do + put proofFriConfig + putSmallList proofCommitPhaseCaps + put proofFinalPoly + putSmallList proofQueryRounds + put proofPowWitness + get = MkFriProof + <$> get + <*> getSmallList + <*> get + <*> getSmallList + <*> get -------------------------------------------------------------------------------- diff --git a/reference/src/Field/Goldilocks/Extension.hs b/reference/src/Field/Goldilocks/Extension.hs index dbe4cfd..f4bc930 100644 --- a/reference/src/Field/Goldilocks/Extension.hs +++ b/reference/src/Field/Goldilocks/Extension.hs @@ -14,6 +14,8 @@ import Data.Ratio import System.Random +import Data.Binary + import Field.Goldilocks ( F ) -------------------------------------------------------------------------------- @@ -26,6 +28,10 @@ data F2 = F2 } deriving (Eq) +instance Binary F2 where + put (F2 x y) = put x >> put y + get = F2 <$> get <*> get + instance Show F2 where show (F2 r i) = "[ " ++ show r ++ " + j * " ++ show i ++ " ]" diff --git a/reference/src/Field/Goldilocks/Slow.hs b/reference/src/Field/Goldilocks/Slow.hs index d82a15a..97dbb97 100644 --- a/reference/src/Field/Goldilocks/Slow.hs +++ b/reference/src/Field/Goldilocks/Slow.hs @@ -15,6 +15,10 @@ import Data.Ratio import System.Random +import Data.Binary +import Data.Binary.Get ( getWord64le ) +import Data.Binary.Put ( putWord64le ) + import Text.Printf -------------------------------------------------------------------------------- @@ -30,6 +34,10 @@ toF = mkGoldilocks . fromIntegral intToF :: Int -> F intToF = mkGoldilocks . fromIntegral +instance Binary F where + put x = putWord64le (fromF x) + get = toF <$> getWord64le + -------------------------------------------------------------------------------- newtype Goldilocks diff --git a/reference/src/Hash/Common.hs b/reference/src/Hash/Common.hs index a7dfef8..1efde76 100644 --- a/reference/src/Hash/Common.hs +++ b/reference/src/Hash/Common.hs @@ -7,6 +7,8 @@ import Data.Array import Data.Bits import Data.Word +import Data.Binary + import Field.Goldilocks import Field.Encode @@ -52,6 +54,10 @@ data Digest = MkDigest !F !F !F !F deriving (Eq,Show) +instance Binary Digest where + put (MkDigest a b c d) = put a >> put b >> put c >> put d + get = MkDigest <$> get <*> get <*> get <*> get + instance FieldEncode Digest where fieldEncode (MkDigest a b c d) = [a,b,c,d] diff --git a/reference/src/Hash/Merkle.hs b/reference/src/Hash/Merkle.hs index 3126900..a392356 100644 --- a/reference/src/Hash/Merkle.hs +++ b/reference/src/Hash/Merkle.hs @@ -22,6 +22,9 @@ module Hash.Merkle where import Data.Array import Data.Bits +import Control.Monad +import Data.Binary + import Field.Goldilocks import Field.Goldilocks.Extension ( FExt , F2(..) ) import Field.Encode @@ -74,6 +77,10 @@ newtype MerkleCap = MkMerkleCap { fromMerkleCap :: Array Int Digest } deriving (Eq,Show) +instance Binary MerkleCap where + put = putSmallArray . fromMerkleCap + get = MkMerkleCap <$> getSmallArray + instance FieldEncode MerkleCap where fieldEncode (MkMerkleCap arr) = concatMap fieldEncode (elems arr) @@ -133,6 +140,13 @@ newtype RawMerklePath = MkRawMerklePath [Digest] deriving (Eq,Show) +fromRawMerklePath :: RawMerklePath -> [Digest] +fromRawMerklePath (MkRawMerklePath ds) = ds + +instance Binary RawMerklePath where + put = putSmallList . fromRawMerklePath + get = MkRawMerklePath <$> getSmallList + instance FieldEncode RawMerklePath where fieldEncode (MkRawMerklePath ds) = concatMap fieldEncode ds diff --git a/reference/src/Misc.hs b/reference/src/Misc.hs index 1a45253..630068e 100644 --- a/reference/src/Misc.hs +++ b/reference/src/Misc.hs @@ -9,6 +9,9 @@ import Data.List import qualified Data.Set as Set ; import Data.Set (Set) +import Control.Monad +import Data.Binary + import Debug.Trace -------------------------------------------------------------------------------- @@ -175,3 +178,38 @@ extractSubgroupArray :: Int -> Array Int a -> Array Int a extractSubgroupArray stride = extractCosetArray 0 stride -------------------------------------------------------------------------------- +-- * Binary encoding + +instance Binary Log2 where + put (Log2 k) = putWord8 (fromIntegral k) + get = (Log2 . fromIntegral) <$> getWord8 + +putSmallList :: Binary a => [a] -> Put +putSmallList list = do + let n = length list + if (n < 256) + then do + putWord8 (fromIntegral n) + mapM_ put list + else error "putSmallList: array length >= 256" + +getSmallList :: Binary a => Get [a] +getSmallList = do + len <- fromIntegral <$> getWord8 :: Get Int + replicateM len get + +putSmallArray :: Binary a => Array Int a -> Put +putSmallArray list = do + let n = arrayLength list + if (n < 256) + then do + putWord8 (fromIntegral n) + mapM_ put list + else error "putSmallArray: array length >= 256" + +getSmallArray :: Binary a => Get (Array Int a) +getSmallArray = do + len <- fromIntegral <$> getWord8 :: Get Int + listToArray <$> replicateM len get + +-------------------------------------------------------------------------------- diff --git a/reference/src/NTT/Poly.hs b/reference/src/NTT/Poly.hs index 8a8f45d..ceac044 100644 --- a/reference/src/NTT/Poly.hs +++ b/reference/src/NTT/Poly.hs @@ -16,6 +16,8 @@ import Control.Monad.ST.Strict import System.Random +import Data.Binary + import Field.Goldilocks import Field.Goldilocks.Extension ( FExt ) import Field.Encode @@ -30,6 +32,10 @@ newtype Poly a = Poly (Array Int a) deriving (Show,Functor) +instance Binary a => Binary (Poly a) where + put (Poly arr) = putSmallArray arr + get = Poly <$> getSmallArray + instance (Num a, Eq a) => Eq (Poly a) where p == q = polyIsZero (polySub p q) diff --git a/reference/src/testMain.hs b/reference/src/testMain.hs index 233408e..f816626 100644 --- a/reference/src/testMain.hs +++ b/reference/src/testMain.hs @@ -7,6 +7,9 @@ import Data.Array import Text.Show.Pretty import System.Random +import Data.Binary +import qualified Data.ByteString.Lazy as L + import Hash.Duplex.Monad import FRI import Misc @@ -89,7 +92,12 @@ main = do pPrint commits pPrint friProof - ok <- runDuplexIO_ (verifyFRI (_ldeCommitment commits) friProof) - putStrLn $ "verify FRI succeed = " ++ show ok + let lbs = encode friProof + let friProof' = decode lbs + putStrLn $ "size of the serialized proof = " ++ show (L.length lbs) + putStrLn $ "could serialize proof and then load back unchanged = " ++ show (friProof == friProof') + +-- ok <- runDuplexIO_ (verifyFRI (_ldeCommitment commits) friProof) +-- putStrLn $ "verify FRI succeed = " ++ show ok --------------------------------------------------------------------------------