363 lines
12 KiB
Haskell

{-| Merkle tree construction (using a T=12 hash)
Conventions:
* we use a "keyed compression function" to avoid collisions for different inputs
* when hashing the bottom-most layer, we use the key bit 0x01
* when hashing an odd layer, we pad with a single 0 hash and use the key bit 0x02
* when building a tree on a singleton input, we apply 1 round of compression
(with key 0x03, as it's both the bottom-most layer and odd)
-}
{-# LANGUAGE StrictData, RecordWildCards #-}
module Hash.Merkle where
--------------------------------------------------------------------------------
import Data.Array
import Data.Bits
import Control.Monad
import Data.Binary
import Text.Show.Pretty
import Field.Goldilocks
import Field.Goldilocks.Extension ( FExt , F2(..) )
import Field.Encode
import Hash.Permutations
import Hash.Common
import Hash.Sponge
import Misc
--------------------------------------------------------------------------------
type Key = Int
theHashFunction :: Hash
theHashFunction = Monolith
--------------------------------------------------------------------------------
type FRow = Array Int F
hashFRow :: FRow -> Digest
hashFRow farr = hashFieldElems theHashFunction (elems farr)
hashFExt :: FExt -> Digest
hashFExt (F2 x y) = hashFieldElems theHashFunction [x,y]
{-
data LeafData
= RowData FRow
| Singleton FExt
deriving (Eq,Show)
instance FieldEncode LeafData where
fieldEncode (RowData farr) = elems farr
fieldEncode (Singleton fext) = fieldEncode fext
hashLeafData :: LeafData -> Digest
hashLeafData leaf = case leaf of
RowData frow -> hashFRow frow
Singleton fext -> hashFExt fext
-}
hashAny :: FieldEncode a => a -> Digest
hashAny = hashFieldElems theHashFunction . fieldEncode
--------------------------------------------------------------------------------
newtype MerkleCap
= MkMerkleCap { fromMerkleCap :: Array Int Digest }
deriving (Eq,Show)
instance Binary MerkleCap where
put = putSmallArray . fromMerkleCap
get = MkMerkleCap <$> getSmallArray
instance FieldEncode MerkleCap where
fieldEncode (MkMerkleCap arr) = concatMap fieldEncode (elems arr)
merkleCapSize :: MerkleCap -> Int
merkleCapSize (MkMerkleCap ds) = (arrayLength ds)
merkleCapLogSize :: MerkleCap -> Log2
merkleCapLogSize (MkMerkleCap ds) = exactLog2__ (arrayLength ds)
-- | Computes the root of a Merkle cap
--
-- (we implicitly assume that the cap was not the bottom layer)
merkleCapRoot :: MerkleCap -> Digest
merkleCapRoot (MkMerkleCap hashArray) =
case elems hashArray of
[] -> error "merkleCapRoot: fatal: input is empty"
[z] -> keyedCompress theHashFunction (nodeKey BottomLayer OddNode) z zeroDigest
zs -> go zs
where
go :: [Digest] -> Digest
go [x] = x
go xs = go (map (evenOddCompressPair OtherLayer) $ eiPairs xs)
--------------------------------------------------------------------------------
-- | Note: index 0 is the bottom (widest) layer
data MerkleTree a = MkMerkleTree
{ _merkleTree :: Array Int (Array Int Digest)
, _merkleLeaves :: Array Int a -- LeafData
}
deriving Show
-- | @log2( number-of-leaves )@.
--
-- NOTE: this is one less than the actual number of layers!
-- However it equals to the length of a Merkle path
--
merkleTreeDepth :: MerkleTree a -> Log2
merkleTreeDepth = Log2 . merkleTreeDepth_
merkleTreeDepth_ :: MerkleTree a -> Int
merkleTreeDepth_ (MkMerkleTree outer _) = (b - a) where
(a,b) = bounds outer
extractMerkleCap :: Log2 -> MerkleTree a -> MerkleCap
extractMerkleCap (Log2 capdepth) (MkMerkleTree layers _) = cap where
(0,n) = bounds layers
cap = MkMerkleCap (layers ! (n-capdepth))
treeBottomLayer :: MerkleTree a -> Array Int Digest
treeBottomLayer (MkMerkleTree outer _) = outer!0
--------------------------------------------------------------------------------
-- | Only the Merkle path (siblings)
newtype RawMerklePath
= MkRawMerklePath [Digest]
deriving (Eq,Show)
fromRawMerklePath :: RawMerklePath -> [Digest]
fromRawMerklePath (MkRawMerklePath ds) = ds
rawMerklePathLength :: RawMerklePath -> Int
rawMerklePathLength (MkRawMerklePath ds) = length ds
rawMerklePathLength2 :: RawMerklePath -> Log2
rawMerklePathLength2 path = Log2 (rawMerklePathLength path)
instance Binary RawMerklePath where
put = putSmallList . fromRawMerklePath
get = MkRawMerklePath <$> getSmallList
instance FieldEncode RawMerklePath where
fieldEncode (MkRawMerklePath ds) = concatMap fieldEncode ds
data MerkleProof a = MkMerkleProof
{ _leafIndex :: Int -- ^ linear index of the leaf we prove, 0..dataSize-1
, _leafData :: a -- ^ the data on the leaf
, _merklePath :: RawMerklePath -- ^ the path up the root
, _dataSize :: Int -- ^ number of leaves in the tree
}
deriving (Eq,Show)
-- | Returns the leaf and Merkle path of the given leaf
extractMerkleProof :: MerkleTree a -> Int -> MerkleProof a
extractMerkleProof = extractMerkleProof' (Log2 0)
-- | Returns the leaf and Merkle path of the given leaf, up to a given Merkle cap depth
extractMerkleProof' :: Log2 -> MerkleTree a -> Int -> MerkleProof a
extractMerkleProof' (Log2 capDepth) tree@(MkMerkleTree outer leaves) idx = MkMerkleProof idx leaf path size where
leaf = leaves!idx
size = arrayLength (outer!0)
depth = merkleTreeDepth_ tree
path = MkRawMerklePath $ takePrefix (worker depth idx)
worker 0 0 = []
worker 0 _ = error "extractMerkleProof: this should not happen"
worker level j = this : worker (level-1) (shiftR j 1) where
this = outer ! (depth - level) ! (j `xor` 1)
takePrefix = take (depth - capDepth)
--------------------------------------------------------------------------------
checkMerkleRootProof :: FieldEncode a => Digest -> MerkleProof a -> Bool
checkMerkleRootProof root proof = reconstructMerkleRoot proof == root
checkMerkleCapProof :: FieldEncode a => MerkleCap -> MerkleProof a -> Bool
-- checkMerkleCapProof = error "checkMerkleCapProof: not yet implemented"
checkMerkleCapProof (MkMerkleCap cap) (MkMerkleProof{..}) = result where
result = go _dataSize _leafIndex (hashAny _leafData) (fromRawMerklePath _merklePath) layerFlags
capSize = arraySize cap
go :: Int -> Int -> Digest -> [Digest] -> [LayerFlag] -> Bool
go n j hash [] _layerFlags = if n /= capSize
then error "checkMerkleCapProof: fatal error: cap size doesn't match"
else (hash == cap!j)
go n j hash (sibling:siblings') (layerf:layerfs') = go n' j' hash' siblings' layerfs'
where
j' = shiftR j 1
n' = shiftR (n+1) 1
jparity = j .&. 1
oddflag = if (j+1 == n) && (jparity == 0) then OddNode else EvenNode
key = nodeKey layerf oddflag
hash' = case jparity of
0 -> keyedCompress theHashFunction key hash sibling
1 -> keyedCompress theHashFunction key sibling hash
--------------------------------------------------------------------------------
calcMerkleTree' :: [Digest] -> [Array Int Digest]
calcMerkleTree' input =
case input of
[] -> error "calcMerkleTree': input is empty"
[z] -> [ singletonArray $ keyedCompress theHashFunction (nodeKey BottomLayer OddNode) z zeroDigest ]
zs -> map listToArray (go layerFlags zs)
where
go :: [LayerFlag] -> [Digest] -> [[Digest]]
go _ [x] = [[x]]
go (f:fs) xs = xs : go fs (map (evenOddCompressPair f) $ eiPairs xs)
calcMerkleTree :: FieldEncode a => [a] -> MerkleTree a
calcMerkleTree input = MkMerkleTree tree leafData where
tree = listToArray (calcMerkleTree' $ map hashAny input)
leafData = listToArray input
calcArrayMerkleTree :: FieldEncode a => Array Int a -> MerkleTree a
calcArrayMerkleTree = calcMerkleTree . elems
-- | Applies a permutation of the rows.
--
-- We need the backward mapping (from Merkle tree indices to array indices)
calcArrayMerkleTree' :: FieldEncode a => (Int -> Int) -> Array Int a -> MerkleTree a
calcArrayMerkleTree' bwd arr = calcMerkleTree [ arr!(bwd i) | i<-[0..n-1] ] where
n = arraySize arr
--------------------------------------------------------------------------------
reconstructMerkleRoot :: FieldEncode a => MerkleProof a -> Digest
reconstructMerkleRoot (MkMerkleProof idx leaf (MkRawMerklePath path) size) = digest where
digest = go layerFlags size idx (hashAny leaf) path
go :: [LayerFlag] -> Int -> Int -> Digest -> [Digest] -> Digest
go _ !sz 0 !h [] = h
go (f:fs) !sz !j !h !(p:ps) = case (j.&.1, j==sz-1) of
(0, False) -> go fs sz' j' (evenOddCompressPair f $ Right (h,p)) ps
(0, True ) -> go fs sz' j' (evenOddCompressPair f $ Left h ) ps
(1, _ ) -> go fs sz' j' (evenOddCompressPair f $ Right (p,h)) ps
where
sz' = shiftR (sz+1) 1
j' = shiftR j 1
--------------------------------------------------------------------------------
compress :: Hash -> Digest -> Digest -> Digest
compress which (MkDigest a b c d) (MkDigest p q r s) = extractDigest output where
input = listArray (0,11) [ a,b,c,d , p,q,r,s , 0,0,0,0 ]
output = permute which input
keyedCompress :: Hash -> Key -> Digest -> Digest -> Digest
keyedCompress which key (MkDigest a b c d) (MkDigest p q r s) = extractDigest output where
k = fromIntegral key :: F
input = listArray (0,11) [ a,b,c,d , p,q,r,s , k,0,0,0 ]
output = permute which input
--------------------------------------------------------------------------------
-- | bit masks
keyBottom = 1 :: Key
keyOdd = 2 :: Key
--------------------------------------------------------------------------------
data LayerFlag
= BottomLayer -- ^ it's the bottom (initial, widest) layer
| OtherLayer -- ^ it's not the bottom layer
deriving (Eq,Show)
data NodeParity
= EvenNode -- ^ it has 2 children
| OddNode -- ^ it has 1 child
deriving (Eq,Show)
-- | Key based on the node type:
--
-- > bit0 := 1 if bottom layer, 0 otherwise
-- > bit1 := 1 if odd, 0 if even
--
nodeKey :: LayerFlag -> NodeParity -> Key
nodeKey OtherLayer EvenNode = 0x00
nodeKey BottomLayer EvenNode = 0x01
nodeKey OtherLayer OddNode = 0x02
nodeKey BottomLayer OddNode = 0x03
evenOddCompressPair :: LayerFlag -> Either Digest (Digest,Digest) -> Digest
evenOddCompressPair !lf (Right (x,y)) = keyedCompress theHashFunction (nodeKey lf EvenNode) x y
evenOddCompressPair !lf (Left x ) = keyedCompress theHashFunction (nodeKey lf OddNode ) x zeroDigest
eiPairs :: [a] -> [Either a (a,a)]
eiPairs [] = []
eiPairs [x] = Left x : []
eiPairs (x:y:rest) = Right (x,y) : eiPairs rest
layerFlags :: [LayerFlag]
layerFlags = BottomLayer : repeat OtherLayer
calcMerkleRoot' :: [Digest] -> Digest
calcMerkleRoot' input =
case input of
[] -> error "calcMerkleRoot: input is empty"
[z] -> keyedCompress theHashFunction (nodeKey BottomLayer OddNode) z zeroDigest
zs -> go layerFlags zs
where
go :: [LayerFlag] -> [Digest] -> Digest
go _ [x] = x
go (f:fs) xs = go fs (map (evenOddCompressPair f) $ eiPairs xs)
calcMerkleRoot :: FieldEncode a => [a] -> Digest
calcMerkleRoot = calcMerkleRoot' . map hashAny -- hashLeafData
--------------------------------------------------------------------------------
calcMerkleCap' :: Log2 -> [Digest] -> MerkleCap
calcMerkleCap' (Log2 capDepth) input =
case input of
[] -> error "calcMerkleRoot: input is empty"
[z] -> MkMerkleCap $ listToArray $ [ keyedCompress theHashFunction (nodeKey BottomLayer OddNode) z zeroDigest ]
zs -> MkMerkleCap $ listToArray $ select $ go layerFlags zs
where
go :: [LayerFlag] -> [Digest] -> [[Digest]]
go _ [x] = [[x]]
go (f:fs) xs = xs : go fs (map (evenOddCompressPair f) $ eiPairs xs)
select :: [[Digest]] -> [Digest]
select xs = xs !! (length xs - 1 - capDepth)
calcMerkleCap :: FieldEncode a => Log2 -> [a] -> MerkleCap
calcMerkleCap capDepth = calcMerkleCap' capDepth . map hashAny
--------------------------------------------------------------------------------
{-
exLeaves :: [[F]]
exLeaves = [ leaf (fromInteger i) | i<-[1..8] ] where
leaf :: F -> [F]
leaf i = [ 10*i , 10*i+1 , 10*i+2 ]
exTree = calcMerkleTree exLeaves
exCap = extractMerkleCap (Log2 1) exTree
exIdx = 2
exProof = extractMerkleProof' (Log2 1) exTree exIdx
exSanity = checkMerkleCapProof exCap exProof
-}