diff --git a/reference/haskell/Poseidon2.hs b/reference/haskell/Poseidon2.hs index 0c37cd9..ca0d399 100644 --- a/reference/haskell/Poseidon2.hs +++ b/reference/haskell/Poseidon2.hs @@ -7,7 +7,7 @@ module Poseidon2 , calcMerkleRoot , calcMerkleTree , MerkleTree(..) , depthOf , merkleRootOf , MerkleProof(..) , extractMerkleProof , extractMerkleProof_ , reconstructMerkleRoot - , compression + , compressPair, keyedCompressPair , permutation ) where diff --git a/reference/haskell/Poseidon2/Merkle.hs b/reference/haskell/Poseidon2/Merkle.hs index f78b02b..aa71c8d 100644 --- a/reference/haskell/Poseidon2/Merkle.hs +++ b/reference/haskell/Poseidon2/Merkle.hs @@ -1,5 +1,14 @@ -- | Merkle tree built from Poseidon2 permutation +-- +-- Note: to avoid second preimage attacks, we use a keyed permutations +-- with 2 bits of key: +-- +-- * The lowest bit is set to 1 if it's the bottom layer and 0 otherwise +-- +-- * The next bit is set to 1 if it's an odd node (1 child) and 0 if +-- if it's an even node (2 children) +-- {-# LANGUAGE BangPatterns #-} module Poseidon2.Merkle where @@ -9,6 +18,8 @@ module Poseidon2.Merkle where import Data.Array import Data.Bits +import Control.Monad + import ZK.Algebra.Curves.BN128.Fr.Mont (Fr) import Poseidon2.Permutation @@ -43,29 +54,49 @@ depthOf :: MerkleTree -> Int depthOf (MkMerkleTree outer) = b-a where (a,b) = bounds outer +{- calcMerkleTree' :: [Fr] -> [[Fr]] calcMerkleTree' = go where - go :: [Fr] -> [[Fr]] - go [] = error "calcMerkleTree': input is empty" - go [x] = [[x]] - go xs = xs : go (map compressPair $ pairs xs) + go :: [Fr] -> [[Fr]] + go [] = error "calcMerkleTree': input is empty" + go [x] = [[x]] + go xs = xs : go (map compressPair $ pairs xs) +-} + +calcMerkleTree' :: [Fr] -> [[Fr]] +calcMerkleTree' input = + case input of + [] -> error "calcMerkleRoot: input is empty" + [z] -> [[keyedCompression (nodeKey BottomLayer OddNode) z 0]] + zs -> go layerFlags zs + where + go :: [LayerFlag] -> [Fr] -> [[Fr]] + go _ [x] = [[x]] + go (f:fs) xs = xs : go fs (map (evenOddCompressPair f) $ eiPairs xs) calcMerkleTree :: [Fr] -> MerkleTree calcMerkleTree = MkMerkleTree . go1 . calcMerkleTree' where go1 outer = listArray (0, length outer-1) (map go2 outer) go2 inner = listArray (0, length inner-1) inner +-------------------------------------------------------------------------------- + data MerkleProof = MkMerkleProof - { _leafIndex :: Int - , _leafHash :: Fr - , _merklePath :: [Fr] + { _leafIndex :: Int -- ^ linear index of the leaf we prove, 0..dataSize-1 + , _leafData :: Fr -- ^ the data on the leaf + , _merklePath :: [Fr] -- ^ the path up the root + , _dataSize :: Int -- ^ number of leaves in the tree } deriving (Eq,Show) +arrayLength :: Array Int t -> Int +arrayLength arr = (b - a + 1) where (a,b) = bounds arr + -- | Returns the leaf and Merkle path of the given leaf extractMerkleProof :: MerkleTree -> Int -> MerkleProof -extractMerkleProof tree@(MkMerkleTree outer) idx = MkMerkleProof idx leaf path where +extractMerkleProof tree@(MkMerkleTree outer) idx = MkMerkleProof idx leaf path size where leaf = (outer!0)!idx + size = arrayLength (outer!0) depth = depthOf tree path = worker depth idx @@ -77,20 +108,103 @@ extractMerkleProof tree@(MkMerkleTree outer) idx = MkMerkleProof idx leaf path w extractMerkleProof_ :: MerkleTree -> Int -> [Fr] extractMerkleProof_ tree idx = _merklePath (extractMerkleProof tree idx) +reconstructMerkleRoot :: MerkleProof -> Fr +reconstructMerkleRoot (MkMerkleProof idx leaf path size) = go layerFlags size idx leaf path where + go _ !sz 0 !h [] = h + go (f:fs) !sz !j !h !(p:ps) = case (j.&.1, j==sz-1) of + (0, False) -> go fs sz' j' (evenOddCompressPair f $ Right (h,p)) ps + (0, True ) -> go fs sz' j' (evenOddCompressPair f $ Left h ) ps + (1, _ ) -> go fs sz' j' (evenOddCompressPair f $ Right (p,h)) ps + where + sz' = shiftR (sz+1) 1 + j' = shiftR j 1 + +{- reconstructMerkleRoot :: MerkleProof -> Fr reconstructMerkleRoot (MkMerkleProof idx leaf path) = go idx leaf path where go 0 !h [] = h go !j !h !(p:ps) = case j .&. 1 of 0 -> go (shiftR j 1) (compression h p) ps 1 -> go (shiftR j 1) (compression p h) ps +-} -------------------------------------------------------------------------------- +testAllMerkleProofs :: Int -> IO () +testAllMerkleProofs nn = forM_ [1..nn] $ \k -> do + let ok = if testMerkleProofs k then "OK." else "FAILED!!" + putStrLn $ "testing Merkle proofs for a tree with " ++ show k ++ " leaves: " ++ ok + +testMerkleProofs :: Int -> Bool +testMerkleProofs = and . testMerkleProofs' + +testMerkleProofs' :: Int -> [Bool] +testMerkleProofs' n = oks where + input = map fromIntegral [1001..1000+n] :: [Fr] + tree = calcMerkleTree input + root = merkleRootOf tree + oks = [ reconstructMerkleRoot prf == root + | j<-[0..n-1] + , let prf = extractMerkleProof tree j + ] + +-------------------------------------------------------------------------------- + +data LayerFlag + = BottomLayer -- ^ it's the bottom (initial, widest) layer + | OtherLayer -- ^ it's not the bottom layer + deriving (Eq,Show) + +data NodeParity + = EvenNode -- ^ it has 2 children + | OddNode -- ^ it has 1 child + deriving (Eq,Show) + +-- | Key based on the node type: +-- +-- > bit0 := 1 if bottom layer, 0 otherwise +-- > bit1 := 1 if odd, 0 if even +-- +nodeKey :: LayerFlag -> NodeParity -> Fr +nodeKey OtherLayer EvenNode = 0x00 +nodeKey BottomLayer EvenNode = 0x01 +nodeKey OtherLayer OddNode = 0x02 +nodeKey BottomLayer OddNode = 0x03 + +evenOddCompressPair :: LayerFlag -> Either Fr (Fr,Fr) -> Fr +evenOddCompressPair !lf (Right (x,y)) = keyedCompression (nodeKey lf EvenNode) x y +evenOddCompressPair !lf (Left x ) = keyedCompression (nodeKey lf OddNode ) x 0 + +layerFlags :: [LayerFlag] +layerFlags = BottomLayer : repeat OtherLayer + calcMerkleRoot :: [Fr] -> Fr -calcMerkleRoot = go where - go [] = error "calcMerkleRoot: input is empty" - go [x] = x - go xs = go (map compressPair $ pairs xs) +calcMerkleRoot input = + case input of + [] -> error "calcMerkleRoot: input is empty" + [z] -> keyedCompression (nodeKey BottomLayer OddNode) z 0 + zs -> go layerFlags zs + where + go :: [LayerFlag] -> [Fr] -> Fr + go _ [x] = x + go (f:fs) xs = go fs (map (evenOddCompressPair f) $ eiPairs xs) + +-------------------------------------------------------------------------------- + +type Key = Fr + +keyedCompressPair :: Key -> (Fr,Fr) -> Fr +keyedCompressPair !key (!x,!y) = keyedCompression key x y + +keyedCompression :: Key -> Fr -> Fr -> Fr +keyedCompression !key !x !y = case permutation (x,y,key) of (z,_,_) -> z + +eiPairs :: [Fr] -> [Either Fr (Fr,Fr)] +eiPairs [] = [] +eiPairs [x] = Left x : [] +eiPairs (x:y:rest) = Right (x,y) : eiPairs rest + +-------------------------------------------------------------------------------- compressPair :: (Fr,Fr) -> Fr compressPair (x,y) = compression x y @@ -98,10 +212,12 @@ compressPair (x,y) = compression x y compression :: Fr -> Fr -> Fr compression x y = case permutation (x,y,0) of (z,_,_) -> z +{- pairs :: [Fr] -> [(Fr,Fr)] pairs [] = [] -pairs [x] = (x,x) : [] +pairs [x] = (x,0) : [] pairs (x:y:rest) = (x,y) : pairs rest +-} --------------------------------------------------------------------------------