make the Merkle tree construction safe (no second preimages) using keyed compression at the nodes

This commit is contained in:
Balazs Komuves 2023-11-16 11:53:08 +01:00
parent cada45df18
commit d0cb4e1026
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
2 changed files with 130 additions and 14 deletions

View File

@ -7,7 +7,7 @@ module Poseidon2
, calcMerkleRoot , calcMerkleTree
, MerkleTree(..) , depthOf , merkleRootOf
, MerkleProof(..) , extractMerkleProof , extractMerkleProof_ , reconstructMerkleRoot
, compression
, compressPair, keyedCompressPair
, permutation
)
where

View File

@ -1,5 +1,14 @@
-- | Merkle tree built from Poseidon2 permutation
--
-- Note: to avoid second preimage attacks, we use a keyed permutations
-- with 2 bits of key:
--
-- * The lowest bit is set to 1 if it's the bottom layer and 0 otherwise
--
-- * The next bit is set to 1 if it's an odd node (1 child) and 0 if
-- if it's an even node (2 children)
--
{-# LANGUAGE BangPatterns #-}
module Poseidon2.Merkle where
@ -9,6 +18,8 @@ module Poseidon2.Merkle where
import Data.Array
import Data.Bits
import Control.Monad
import ZK.Algebra.Curves.BN128.Fr.Mont (Fr)
import Poseidon2.Permutation
@ -43,29 +54,49 @@ depthOf :: MerkleTree -> Int
depthOf (MkMerkleTree outer) = b-a where
(a,b) = bounds outer
{-
calcMerkleTree' :: [Fr] -> [[Fr]]
calcMerkleTree' = go where
go :: [Fr] -> [[Fr]]
go [] = error "calcMerkleTree': input is empty"
go [x] = [[x]]
go xs = xs : go (map compressPair $ pairs xs)
go :: [Fr] -> [[Fr]]
go [] = error "calcMerkleTree': input is empty"
go [x] = [[x]]
go xs = xs : go (map compressPair $ pairs xs)
-}
calcMerkleTree' :: [Fr] -> [[Fr]]
calcMerkleTree' input =
case input of
[] -> error "calcMerkleRoot: input is empty"
[z] -> [[keyedCompression (nodeKey BottomLayer OddNode) z 0]]
zs -> go layerFlags zs
where
go :: [LayerFlag] -> [Fr] -> [[Fr]]
go _ [x] = [[x]]
go (f:fs) xs = xs : go fs (map (evenOddCompressPair f) $ eiPairs xs)
calcMerkleTree :: [Fr] -> MerkleTree
calcMerkleTree = MkMerkleTree . go1 . calcMerkleTree' where
go1 outer = listArray (0, length outer-1) (map go2 outer)
go2 inner = listArray (0, length inner-1) inner
--------------------------------------------------------------------------------
data MerkleProof = MkMerkleProof
{ _leafIndex :: Int
, _leafHash :: Fr
, _merklePath :: [Fr]
{ _leafIndex :: Int -- ^ linear index of the leaf we prove, 0..dataSize-1
, _leafData :: Fr -- ^ the data on the leaf
, _merklePath :: [Fr] -- ^ the path up the root
, _dataSize :: Int -- ^ number of leaves in the tree
}
deriving (Eq,Show)
arrayLength :: Array Int t -> Int
arrayLength arr = (b - a + 1) where (a,b) = bounds arr
-- | Returns the leaf and Merkle path of the given leaf
extractMerkleProof :: MerkleTree -> Int -> MerkleProof
extractMerkleProof tree@(MkMerkleTree outer) idx = MkMerkleProof idx leaf path where
extractMerkleProof tree@(MkMerkleTree outer) idx = MkMerkleProof idx leaf path size where
leaf = (outer!0)!idx
size = arrayLength (outer!0)
depth = depthOf tree
path = worker depth idx
@ -77,20 +108,103 @@ extractMerkleProof tree@(MkMerkleTree outer) idx = MkMerkleProof idx leaf path w
extractMerkleProof_ :: MerkleTree -> Int -> [Fr]
extractMerkleProof_ tree idx = _merklePath (extractMerkleProof tree idx)
reconstructMerkleRoot :: MerkleProof -> Fr
reconstructMerkleRoot (MkMerkleProof idx leaf path size) = go layerFlags size idx leaf path where
go _ !sz 0 !h [] = h
go (f:fs) !sz !j !h !(p:ps) = case (j.&.1, j==sz-1) of
(0, False) -> go fs sz' j' (evenOddCompressPair f $ Right (h,p)) ps
(0, True ) -> go fs sz' j' (evenOddCompressPair f $ Left h ) ps
(1, _ ) -> go fs sz' j' (evenOddCompressPair f $ Right (p,h)) ps
where
sz' = shiftR (sz+1) 1
j' = shiftR j 1
{-
reconstructMerkleRoot :: MerkleProof -> Fr
reconstructMerkleRoot (MkMerkleProof idx leaf path) = go idx leaf path where
go 0 !h [] = h
go !j !h !(p:ps) = case j .&. 1 of
0 -> go (shiftR j 1) (compression h p) ps
1 -> go (shiftR j 1) (compression p h) ps
-}
--------------------------------------------------------------------------------
testAllMerkleProofs :: Int -> IO ()
testAllMerkleProofs nn = forM_ [1..nn] $ \k -> do
let ok = if testMerkleProofs k then "OK." else "FAILED!!"
putStrLn $ "testing Merkle proofs for a tree with " ++ show k ++ " leaves: " ++ ok
testMerkleProofs :: Int -> Bool
testMerkleProofs = and . testMerkleProofs'
testMerkleProofs' :: Int -> [Bool]
testMerkleProofs' n = oks where
input = map fromIntegral [1001..1000+n] :: [Fr]
tree = calcMerkleTree input
root = merkleRootOf tree
oks = [ reconstructMerkleRoot prf == root
| j<-[0..n-1]
, let prf = extractMerkleProof tree j
]
--------------------------------------------------------------------------------
data LayerFlag
= BottomLayer -- ^ it's the bottom (initial, widest) layer
| OtherLayer -- ^ it's not the bottom layer
deriving (Eq,Show)
data NodeParity
= EvenNode -- ^ it has 2 children
| OddNode -- ^ it has 1 child
deriving (Eq,Show)
-- | Key based on the node type:
--
-- > bit0 := 1 if bottom layer, 0 otherwise
-- > bit1 := 1 if odd, 0 if even
--
nodeKey :: LayerFlag -> NodeParity -> Fr
nodeKey OtherLayer EvenNode = 0x00
nodeKey BottomLayer EvenNode = 0x01
nodeKey OtherLayer OddNode = 0x02
nodeKey BottomLayer OddNode = 0x03
evenOddCompressPair :: LayerFlag -> Either Fr (Fr,Fr) -> Fr
evenOddCompressPair !lf (Right (x,y)) = keyedCompression (nodeKey lf EvenNode) x y
evenOddCompressPair !lf (Left x ) = keyedCompression (nodeKey lf OddNode ) x 0
layerFlags :: [LayerFlag]
layerFlags = BottomLayer : repeat OtherLayer
calcMerkleRoot :: [Fr] -> Fr
calcMerkleRoot = go where
go [] = error "calcMerkleRoot: input is empty"
go [x] = x
go xs = go (map compressPair $ pairs xs)
calcMerkleRoot input =
case input of
[] -> error "calcMerkleRoot: input is empty"
[z] -> keyedCompression (nodeKey BottomLayer OddNode) z 0
zs -> go layerFlags zs
where
go :: [LayerFlag] -> [Fr] -> Fr
go _ [x] = x
go (f:fs) xs = go fs (map (evenOddCompressPair f) $ eiPairs xs)
--------------------------------------------------------------------------------
type Key = Fr
keyedCompressPair :: Key -> (Fr,Fr) -> Fr
keyedCompressPair !key (!x,!y) = keyedCompression key x y
keyedCompression :: Key -> Fr -> Fr -> Fr
keyedCompression !key !x !y = case permutation (x,y,key) of (z,_,_) -> z
eiPairs :: [Fr] -> [Either Fr (Fr,Fr)]
eiPairs [] = []
eiPairs [x] = Left x : []
eiPairs (x:y:rest) = Right (x,y) : eiPairs rest
--------------------------------------------------------------------------------
compressPair :: (Fr,Fr) -> Fr
compressPair (x,y) = compression x y
@ -98,10 +212,12 @@ compressPair (x,y) = compression x y
compression :: Fr -> Fr -> Fr
compression x y = case permutation (x,y,0) of (z,_,_) -> z
{-
pairs :: [Fr] -> [(Fr,Fr)]
pairs [] = []
pairs [x] = (x,x) : []
pairs [x] = (x,0) : []
pairs (x:y:rest) = (x,y) : pairs rest
-}
--------------------------------------------------------------------------------