update the Haskell reference to improved domain separation (TODO: do the same for the Nim implementation)

This commit is contained in:
Balazs Komuves 2026-04-22 22:42:00 +02:00
parent 84b23e7ba7
commit 33edaaf9e7
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
7 changed files with 215 additions and 69 deletions

View File

@ -130,7 +130,21 @@ elements, not bytes!), and an initialization vector `(0,0,domSep)` where `domSep
(short for "domain separation"), the initial value for the "capacity" part of the
sponge, is defined as
domSep := 2^64 + 256*t + rate
domSep (old) := 2^64 + 256*t + rate
domSep (new) := 2^64 + 2^24*padding + 2^16*inputType + 2^8*t + rate
Here `inputType` can be:
- 1: input is a sequence of bits
- 8: input is a sequence of bytes
- 254: input is a sequence of BN254 field elements
And `padding` (the padding strategy can be):
- 1: the `10*` padding strategy, applied to a sequence of field elements
- 16: the `10*` padding strategy, applied to a sequence of bytes (or bits)
- 17: the `10*` padding strategy applied first to bytes (to a multiple of 31 bytes), then also to the resulting field element sequence
- 255: no padding
Parameters
@ -264,6 +278,7 @@ In case of SHA256, we could use a compression functions of the form
byte. Since SHA256 already does some padding internally, this has the same
cost as computing just `SHA256(x|y)`.
Network blocks vs. cells
------------------------

View File

@ -4,8 +4,9 @@
module Poseidon2
( Fr
, Flavour(..)
, sponge1 , sponge2
, spongeFelts , spongeBytes
, calcMerkleRoot , calcMerkleTree
, calcMerkleTreeFeltSeqs , calcMerkleTreeByteStrings
, MerkleTree(..) , depthOf , merkleRootOf , treeBottomLayer
, MerkleProof(..) , extractMerkleProof , extractMerkleProof_ , reconstructMerkleRoot
, compressPair, keyedCompressPair

View File

@ -18,11 +18,14 @@ module Poseidon2.Merkle where
import Data.Array
import Data.Bits
import Data.ByteString (ByteString)
import Control.Monad
import ZK.Algebra.Curves.BN128.Fr.Mont (Fr)
import Poseidon2.Permutation
import Poseidon2.Sponge
-- import Debug.Trace
-- debug s x y = trace (s ++ " ~> " ++ show x) y
@ -66,6 +69,7 @@ calcMerkleTree' = go where
go xs = xs : go (map compressPair $ pairs xs)
-}
calcMerkleTree' :: Flavour -> [Fr] -> [[Fr]]
calcMerkleTree' flavour input =
case input of
@ -78,12 +82,20 @@ calcMerkleTree' flavour input =
go (f:fs) xs = xs : go fs (map (evenOddCompressPair flavour f) $ eiPairs xs)
calcMerkleTree :: Flavour -> [Fr] -> MerkleTree
calcMerkleTree flavour = MkMerkleTree flavour . go1 . calcMerkleTree' flavour where
calcMerkleTree flavour leaves = MkMerkleTree flavour $ go1 (calcMerkleTree' flavour leaves) where
go1 outer = listArray (0, length outer-1) (map go2 outer)
go2 inner = listArray (0, length inner-1) inner
--------------------------------------------------------------------------------
calcMerkleTreeFeltSeqs :: Flavour -> [[Fr]] -> MerkleTree
calcMerkleTreeFeltSeqs flavour xss = calcMerkleTree flavour (map (spongeFelts SpongeRate2 flavour) xss)
calcMerkleTreeByteStrings :: Flavour -> [ByteString] -> MerkleTree
calcMerkleTreeByteStrings flavour bss = calcMerkleTree flavour (map (spongeBytes SpongeRate2 flavour) bss)
--------------------------------------------------------------------------------
data MerkleProof = MkMerkleProof
{ _flavour :: !Flavour -- ^ which hash function
, _leafIndex :: !Int -- ^ linear index of the leaf we prove, 0..dataSize-1

View File

@ -2,25 +2,97 @@
{-# LANGUAGE BangPatterns #-}
module Poseidon2.Sponge
( Flavour(..)
, sponge1
, sponge2
, SpongeRate(..)
, InputFormat(..)
, PaddingStrategy(..)
, computeDomainSeparator
, spongeFelts , spongeBytes
, spongeFelts1 , spongeFelts2
, sponge1' , sponge2'
, byteStringToFieldElements
)
where
--------------------------------------------------------------------------------
import Data.Bits
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import ZK.Algebra.Curves.BN128.Fr.Mont (Fr)
import Poseidon2.Permutation
--------------------------------------------------------------------------------
-- | Sponge construction with rate=1 (capacity=2), zero IV and 10* padding
sponge1 :: Flavour -> [Fr] -> Fr
sponge1 !flavour input = go (0,0,civ) (pad input) where
data SpongeRate
= SpongeRate1
| SpongeRate2
deriving (Eq,Show)
-- domain separation: capacity IV = 2^64 + 256*t + rate
civ = fromInteger (2^64 + 0x0301)
data InputFormat
= BitSequence -- ^ sequence of bits
| ByteSequence -- ^ sequence of bytes
| FeltSequenceBN254 -- ^ sequence of BN254 field elements
deriving (Eq,Show)
data PaddingStrategy
= NoPadding -- ^ no padding
| Padding_Felts_10Star -- ^ padding field elements with @10*@ (to a multiple of rate)
| Padding_Bytes_10Star -- ^ padding bytes with @10*@ (so that the result length is divisible by @(31*rate)@, eg. 62)
| Padding_Felts_Bytes_10Star -- ^ padding bytes with @10*@ to be divisible by 31, and then padding the resulting field element sequence too
deriving (Eq,Show)
newtype DomSep = DomSep Fr
-- | domain separation:
--
-- > capacity IV = 2^64 + 2^24*padding + 2^16*inputfmt + 256*t + rate
--
computeDomainSeparator :: SpongeRate -> InputFormat -> PaddingStrategy -> DomSep
computeDomainSeparator spongRate inputFormat paddingStrategy = DomSep (fromInteger domsep) where
domsep :: Integer
domsep = (2^64 + 2^24*padding + 2^16*inputfmt + 2^8*width + rate)
width :: Integer
width = 3
rate = case spongRate of
SpongeRate1 -> 1
SpongeRate2 -> 2
inputfmt = case inputFormat of
BitSequence -> 1
ByteSequence -> 8
FeltSequenceBN254 -> 254
padding = case paddingStrategy of
NoPadding -> 255
Padding_Felts_10Star -> 1
Padding_Bytes_10Star -> 16
Padding_Felts_Bytes_10Star -> 17
--------------------------------------------------------------------------------
spongeFelts :: SpongeRate -> Flavour -> [Fr] -> Fr
spongeFelts rate = case rate of
SpongeRate1 -> spongeFelts1
SpongeRate2 -> spongeFelts2
spongeBytes :: SpongeRate -> Flavour -> ByteString -> Fr
spongeBytes rate flavour bytes = case rate of
SpongeRate1 -> sponge1' flavour (computeDomainSeparator rate ByteSequence Padding_Felts_Bytes_10Star) (byteStringToFieldElements bytes)
SpongeRate2 -> sponge2' flavour (computeDomainSeparator rate ByteSequence Padding_Felts_Bytes_10Star) (byteStringToFieldElements bytes)
--------------------------------------------------------------------------------
-- | Sponge construction with rate=1 (capacity=2), and 10* padding
spongeFelts1 :: Flavour -> [Fr] -> Fr
spongeFelts1 flavour = sponge1' flavour (computeDomainSeparator SpongeRate1 FeltSequenceBN254 Padding_Felts_10Star)
sponge1' :: Flavour -> DomSep -> [Fr] -> Fr
sponge1' !flavour (DomSep civ) input = go (0,0,civ) (pad input) where
pad :: [Fr] -> [Fr]
pad (x:xs) = x : pad xs
@ -32,12 +104,12 @@ sponge1 !flavour input = go (0,0,civ) (pad input) where
--------------------------------------------------------------------------------
-- | Sponge construction with rate=2 (capacity=1), zero IV and 10* padding
sponge2 :: Flavour -> [Fr] -> Fr
sponge2 !flavour input = go (0,0,civ) (pad input) where
-- | Sponge construction with rate=2 (capacity=1), and 10* padding
spongeFelts2 :: Flavour -> [Fr] -> Fr
spongeFelts2 flavour = sponge2' flavour (computeDomainSeparator SpongeRate2 FeltSequenceBN254 Padding_Felts_10Star)
-- domain separation: capacity IV = 2^64 + 256*t + rate
civ = fromInteger (2^64 + 0x0302)
sponge2' :: Flavour -> DomSep -> [Fr] -> Fr
sponge2' !flavour (DomSep civ) input = go (0,0,civ) (pad input) where
pad :: [Fr] -> [Fr]
pad (x:y:rest) = x : y : pad rest
@ -49,4 +121,47 @@ sponge2 !flavour input = go (0,0,civ) (pad input) where
state' = permutation flavour (sx+a, sy+b, sz)
--------------------------------------------------------------------------------
-- * dealing with bytes
-- | A 31-byte long chunk
newtype Chunk
= Chunk ByteString
deriving Show
-- | Split bytestring into samller pieces, applying the @10*@ padding strategy.
--
-- That is, always add a single @0x01@ byte, and then add the necessary
-- number (in the interval @[0..k-1]@) of @0x00@ bytes to be a multiple of the
-- given chunk length
--
padAndSplitByteString :: Int -> ByteString -> [Chunk]
padAndSplitByteString k orig = go (B.snoc orig 0x01) where
go bs
| m == 0 = []
| m < k = [Chunk $ B.append bs (B.replicate (k-m) 0x00)]
| otherwise = (Chunk $ B.take k bs) : go (B.drop k bs)
where
m = B.length bs
-- | Chunk a ByteString into a sequence of field elements
byteStringToFieldElements :: ByteString -> [Fr]
byteStringToFieldElements rawdata = map chunkToField pieces where
chunkSize = 31
pieces = padAndSplitByteString chunkSize rawdata
chunkToField :: Chunk -> Fr
chunkToField chunk@(Chunk bs)
| l == 31 = fromInteger (chunkToIntegerLE chunk)
| l < 31 = error "chunkToField: chunk is too small (expecting exactly 31 bytes)"
| l > 31 = error "chunkToField: chunk is too big (expecting exactly 31 bytes)"
where
l = B.length bs
-- | Interpret a ByteString as an integer (little-endian)
chunkToIntegerLE :: Chunk -> Integer
chunkToIntegerLE (Chunk chunk) = go (B.unpack chunk) where
go [] = 0
go (w:ws) = fromIntegral w + shiftL (go ws) 8
--------------------------------------------------------------------------------

View File

@ -12,6 +12,7 @@ import qualified Data.ByteString as B
import Slot as Slot
import DataSet as DataSet
import Poseidon2
import Poseidon2.Sponge
import qualified ZK.Algebra.Curves.BN128.Fr.Mont as Fr
@ -30,7 +31,7 @@ type Entropy = Fr
-- cell index to sample
sampleCellIndex :: SlotConfig -> Entropy -> Hash -> Int -> CellIdx
sampleCellIndex cfg entropy slotRoot counter = CellIdx (fromInteger idx) where
u = sponge2 (Slot._hashFlavour cfg) [entropy , slotRoot , fromIntegral counter] :: Fr
u = spongeFelts2 (Slot._hashFlavour cfg) [entropy , slotRoot , fromIntegral counter] :: Fr
idx = (Fr.from u) `mod` n :: Integer
n = (fromIntegral $ Slot._nCells cfg) :: Integer

View File

@ -16,6 +16,7 @@ import Control.Monad
import System.IO
import Poseidon2
import Poseidon2.Sponge
import Misc
--------------------------------------------------------------------------------
@ -231,48 +232,11 @@ hashCell cfg (CellData rawdata)
flavour = _hashFlavour cfg
hashCell_ :: Flavour -> ByteString -> Hash
hashCell_ flavour rawdata = sponge2 flavour (cellDataToFieldElements $ CellData rawdata)
hashCell_ flavour rawdata = spongeBytes SpongeRate2 flavour rawdata
-- sponge2 flavour (cellDataToFieldElements $ CellData rawdata)
--------------------------------------------------------------------------------
-- | A 31-byte long chunk
newtype Chunk
= Chunk ByteString
deriving Show
-- | Split bytestring into samller pieces, applying the @10*@ padding strategy.
--
-- That is, always add a single @0x01@ byte, and then add the necessary
-- number (in the interval @[0..k-1]@) of @0x00@ bytes to be a multiple of the
-- given chunk length
--
padAndSplitByteString :: Int -> ByteString -> [Chunk]
padAndSplitByteString k orig = go (B.snoc orig 0x01) where
go bs
| m == 0 = []
| m < k = [Chunk $ B.append bs (B.replicate (k-m) 0x00)]
| otherwise = (Chunk $ B.take k bs) : go (B.drop k bs)
where
m = B.length bs
-- | Chunk a ByteString into a sequence of field elements
cellDataToFieldElements :: CellData -> [Fr]
cellDataToFieldElements (CellData rawdata) = map chunkToField pieces where
chunkSize = 31
pieces = padAndSplitByteString chunkSize rawdata
chunkToField :: Chunk -> Fr
chunkToField chunk@(Chunk bs)
| l == 31 = fromInteger (chunkToIntegerLE chunk)
| l < 31 = error "chunkToField: chunk is too small (expecting exactly 31 bytes)"
| l > 31 = error "chunkToField: chunk is too big (expecting exactly 31 bytes)"
where
l = B.length bs
-- | Interpret a ByteString as an integer (little-endian)
chunkToIntegerLE :: Chunk -> Integer
chunkToIntegerLE (Chunk chunk) = go (B.unpack chunk) where
go [] = 0
go (w:ws) = fromIntegral w + shiftL (go ws) 8
--------------------------------------------------------------------------------
cellDataToFieldElements (CellData rawdata) = byteStringToFieldElements rawdata

View File

@ -8,6 +8,7 @@ module TestVectors where
import Control.Monad
import Data.Word
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Poseidon2.Merkle
@ -29,9 +30,20 @@ allTestVectors = do
allTestVectors' :: Flavour -> IO ()
allTestVectors' flavour = do
testVectorsSponge flavour
testVectorsHash flavour
testVectorsMerkle flavour
testVectorsSponge flavour
testVectorsHash flavour
testVectorsMerkleAsHash flavour
testVectorsMerkleFull flavour
--------------------------------------------------------------------------------
showFelt :: Fr -> String
showFelt x
| n0 > 77 = error "showFelt: should not happen"
| otherwise = replicate (77-n0) '0' ++ s0
where
s0 = show x
n0 = length s0
--------------------------------------------------------------------------------
@ -42,14 +54,14 @@ testVectorsSponge flavour = do
putStrLn "-------------------------------------------------------------------"
forM_ [0..8] $ \n -> do
let input = map fromIntegral [1..n] :: [Fr]
putStrLn $ "hash of [1.." ++ show n ++ "] :: [Fr] = " ++ show (sponge1 flavour input)
putStrLn $ "hash of [1.." ++ show n ++ "] :: [Fr] = " ++ showFelt (spongeFelts SpongeRate1 flavour input)
putStrLn ""
putStrLn $ "test vectors for sponge of field elements with rate=2 | " ++ show flavour
putStrLn "-------------------------------------------------------------------"
forM_ [0..8] $ \n -> do
let input = map fromIntegral [1..n] :: [Fr]
putStrLn $ "hash of [1.." ++ show n ++ "] :: [Fr] = " ++ show (sponge2 flavour input)
putStrLn $ "hash of [1.." ++ show n ++ "] :: [Fr] = " ++ showFelt (spongeFelts SpongeRate2 flavour input)
--------------------------------------------------------------------------------
@ -62,26 +74,52 @@ testVectorsHash flavour = do
forM_ [0..80] $ \n -> do
let input = map fromIntegral [1..n] :: [Word8]
let bs = B.pack input
putStrLn $ "hash of [1.." ++ show n ++ "] :: [Byte] = " ++ show (hashCell_ flavour bs)
putStrLn $ "hash of [1.." ++ show n ++ "] :: [Byte] = " ++ showFelt (spongeBytes SpongeRate2 flavour bs)
--------------------------------------------------------------------------------
testVectorsMerkle :: Flavour -> IO ()
testVectorsMerkle flavour = do
testVectorsMerkleAsHash :: Flavour -> IO ()
testVectorsMerkleAsHash flavour = do
putStrLn ""
putStrLn $ "test vectors for Merkle roots of field elements | " ++ show flavour
putStrLn $ "test vectors for Merkle roots (used as a hash function) of sequences of field elements | " ++ show flavour
putStrLn "-----------------------------------------------"
forM_ [1..40] $ \n -> do
let input = map fromIntegral [1..n] :: [Fr]
putStrLn $ "Merkle root of [1.." ++ show n ++ "] :: [Fr] = " ++ show (calcMerkleRoot flavour input)
let root = calcMerkleRoot flavour input
let root' = merkleRootOf $ calcMerkleTree flavour input
if root == root'
then putStrLn $ "Merkle root of [1.." ++ show n ++ "] :: [Fr] = " ++ showFelt root
else fail "testVectorsMerkleAsHash: FATAL"
putStrLn ""
putStrLn $ "test vectors for Merkle roots of sequence of bytes | " ++ show flavour
putStrLn $ "test vectors for Merkle roots (used as a hash function) of sequence of bytes | " ++ show flavour
putStrLn "--------------------------------------------------"
forM_ [0..80] $ \n -> do
let input = map fromIntegral [1..n] :: [Word8]
let bs = B.pack input
let flds = cellDataToFieldElements (CellData bs)
putStrLn $ "Merkle root of [1.." ++ show n ++ "] :: [Byte] = " ++ show (calcMerkleRoot flavour flds)
let flds = byteStringToFieldElements bs
let root = calcMerkleRoot flavour flds
let root' = merkleRootOf $ calcMerkleTree flavour flds
if root == root'
then putStrLn $ "Merkle root of [1.." ++ show n ++ "] :: [Byte] = " ++ showFelt root
else fail "testVectorsMerkleAsHash: FATAL"
--------------------------------------------------------------------------------
testVectorsMerkleFull :: Flavour -> IO ()
testVectorsMerkleFull flavour = do
putStrLn ""
putStrLn $ "test vectors for Merkle roots, where the leaves are sequences of bytes | " ++ show flavour
putStrLn "--------------------------------------------------"
forM_ [1..81] $ \n -> do
let inputs = makeInputs n
putStrLn $ "Merkle root of [ [0..j-1] | j<-[1.." ++ show n ++ "] :: [[Byte]] = " ++ showFelt (merkleRootOf $ calcMerkleTreeByteStrings flavour inputs)
where
makeInput :: Int -> Int -> ByteString
makeInput n j = B.pack $ map (fromIntegral :: Int -> Word8) [0..j-1]
makeInputs :: Int -> [ByteString]
makeInputs n = [ makeInput n j | j<-[1..n] ]
--------------------------------------------------------------------------------