diff --git a/eth/ssz/merkleization.nim b/eth/ssz/merkleization.nim new file mode 100644 index 0000000..84f365d --- /dev/null +++ b/eth/ssz/merkleization.nim @@ -0,0 +1,609 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +# This module contains the parts necessary to create a merkle hash from the core +# SSZ types outlined in the spec: +# https://github.com/ethereum/eth2.0-specs/blob/v1.0.1/ssz/simple-serialize.md#merkleization + +{.push raises: [Defect].} + +import + stew/[bitops2, endians2, ptrops], + stew/ranges/ptr_arith, nimcrypto/[hash, sha2], + serialization/testing/tracing, + "."/[bitseqs, types] + +export + types + +when hasSerializationTracing: + import stew/byteutils, typetraits + +const + zero64 = default array[64, byte] + bitsPerChunk = bytesPerChunk * 8 + +func binaryTreeHeight*(totalElements: Limit): int = + bitWidth nextPow2(uint64 totalElements) + +type + SszMerkleizerImpl = object + combinedChunks: ptr UncheckedArray[Digest] + totalChunks: uint64 + topIndex: int + + SszMerkleizer*[limit: static[Limit]] = object + combinedChunks: ref array[binaryTreeHeight limit, Digest] + impl: SszMerkleizerImpl + +template chunks*(m: SszMerkleizerImpl): openArray[Digest] = + m.combinedChunks.toOpenArray(0, m.topIndex) + +template getChunkCount*(m: SszMerkleizer): uint64 = + m.impl.totalChunks + +template getCombinedChunks*(m: SszMerkleizer): openArray[Digest] = + toOpenArray(m.impl.combinedChunks, 0, m.impl.topIndex) + +type DigestCtx* = sha2.sha256 + +template computeDigest*(body: untyped): Digest = + ## This little helper will init the hash function and return the sliced + ## hash: + ## let hashOfData = withHash: h.update(data) + when nimvm: + # In SSZ, computeZeroHashes require compile-time SHA256 + block: + var h {.inject.}: sha256 + init(h) + body + finish(h) + else: + block: + var h {.inject, noInit.}: DigestCtx + init(h) + body + finish(h) + +func digest(a, b: openArray[byte]): Digest = + result = computeDigest: + trs "DIGESTING ARRAYS ", toHex(a), " ", toHex(b) + trs toHex(a) + trs toHex(b) + + h.update a + h.update b + trs "HASH RESULT ", result + +func digest(a, b, c: openArray[byte]): Digest = + result = computeDigest: + trs "DIGESTING ARRAYS ", toHex(a), " ", toHex(b), " ", toHex(c) + + h.update a + h.update b + h.update c + trs "HASH RESULT ", result + +func mergeBranches(existing: Digest, newData: openArray[byte]): Digest = + trs "MERGING BRANCHES OPEN ARRAY" + + let paddingBytes = bytesPerChunk - newData.len + digest(existing.data, newData, zero64.toOpenArray(0, paddingBytes - 1)) + +template mergeBranches(existing: Digest, newData: array[32, byte]): Digest = + trs "MERGING BRANCHES ARRAY" + digest(existing.data, newData) + +template mergeBranches(a, b: Digest): Digest = + trs "MERGING BRANCHES DIGEST" + digest(a.data, b.data) + +func computeZeroHashes: array[sizeof(Limit) * 8, Digest] = + result[0] = Digest() + for i in 1 .. result.high: + result[i] = mergeBranches(result[i - 1], result[i - 1]) + +const zeroHashes* = computeZeroHashes() + +func addChunk*(merkleizer: var SszMerkleizerImpl, data: openArray[byte]) = + doAssert data.len > 0 and data.len <= bytesPerChunk + + if getBitLE(merkleizer.totalChunks, 0): + var hash = mergeBranches(merkleizer.combinedChunks[0], data) + + for i in 1 .. merkleizer.topIndex: + trs "ITERATING" + if getBitLE(merkleizer.totalChunks, i): + trs "CALLING MERGE BRANCHES" + hash = mergeBranches(merkleizer.combinedChunks[i], hash) + else: + trs "WRITING FRESH CHUNK AT ", i, " = ", hash + merkleizer.combinedChunks[i] = hash + break + else: + let paddingBytes = bytesPerChunk - data.len + + merkleizer.combinedChunks[0].data[0.. 0 and merkleizer.topIndex > 0 + + let proofHeight = merkleizer.topIndex + 1 + result = newSeq[Digest](chunks.len * proofHeight) + + if chunks.len == 1: + merkleizer.addChunkAndGenMerkleProof(chunks[0], result) + return + + let newTotalChunks = merkleizer.totalChunks + chunks.len.uint64 + + var + # A perfect binary tree will take either `chunks.len * 2` values if the + # number of elements in the base layer is odd and `chunks.len * 2 - 1` + # otherwise. Each row may also need a single extra element at most if + # it must be combined with the existing values in the Merkleizer: + merkleTree = newSeqOfCap[Digest](chunks.len + merkleizer.topIndex) + inRowIdx = merkleizer.totalChunks + postUpdateInRowIdx = newTotalChunks + zeroMixed = false + + template writeResult(chunkIdx, level: int, chunk: Digest) = + result[chunkIdx * proofHeight + level] = chunk + + # We'll start by generating the first row of the merkle tree. + var currPairEnd = if inRowIdx.isOdd: + # an odd chunk number means that we must combine the + # hash with the existing pending sibling hash in the + # merkleizer. + writeResult(0, 0, merkleizer.combinedChunks[0]) + merkleTree.add mergeBranches(merkleizer.combinedChunks[0], chunks[0]) + + # TODO: can we immediately write this out? + merkleizer.completeStartedChunk(merkleTree[^1], 1) + 2 + else: + 1 + + if postUpdateInRowIdx.isOdd: + merkleizer.combinedChunks[0] = chunks[^1] + + while currPairEnd < chunks.len: + writeResult(currPairEnd - 1, 0, chunks[currPairEnd]) + writeResult(currPairEnd, 0, chunks[currPairEnd - 1]) + merkleTree.add mergeBranches(chunks[currPairEnd - 1], + chunks[currPairEnd]) + currPairEnd += 2 + + if currPairEnd - 1 < chunks.len: + zeroMixed = true + writeResult(currPairEnd - 1, 0, zeroHashes[0]) + merkleTree.add mergeBranches(chunks[currPairEnd - 1], + zeroHashes[0]) + var + level = 0 + baseChunksPerElement = 1 + treeRowStart = 0 + rowLen = merkleTree.len + + template writeProofs(rowChunkIdx: int, hash: Digest) = + let + startAbsIdx = (inRowIdx.int + rowChunkIdx) * baseChunksPerElement + endAbsIdx = startAbsIdx + baseChunksPerElement + startResIdx = max(startAbsIdx - merkleizer.totalChunks.int, 0) + endResIdx = min(endAbsIdx - merkleizer.totalChunks.int, chunks.len) + + for resultPos in startResIdx ..< endResIdx: + writeResult(resultPos, level, hash) + + if rowLen > 1: + while level < merkleizer.topIndex: + inc level + baseChunksPerElement *= 2 + inRowIdx = inRowIdx div 2 + postUpdateInRowIdx = postUpdateInRowIdx div 2 + + var currPairEnd = if inRowIdx.isOdd: + # an odd chunk number means that we must combine the + # hash with the existing pending sibling hash in the + # merkleizer. + writeProofs(0, merkleizer.combinedChunks[level]) + merkleTree.add mergeBranches(merkleizer.combinedChunks[level], + merkleTree[treeRowStart]) + + # TODO: can we immediately write this out? + merkleizer.completeStartedChunk(merkleTree[^1], level + 1) + 2 + else: + 1 + + if postUpdateInRowIdx.isOdd: + merkleizer.combinedChunks[level] = merkleTree[treeRowStart + rowLen - + ord(zeroMixed) - 1] + while currPairEnd < rowLen: + writeProofs(currPairEnd - 1, merkleTree[treeRowStart + currPairEnd]) + writeProofs(currPairEnd, merkleTree[treeRowStart + currPairEnd - 1]) + merkleTree.add mergeBranches(merkleTree[treeRowStart + currPairEnd - 1], + merkleTree[treeRowStart + currPairEnd]) + currPairEnd += 2 + + if currPairEnd - 1 < rowLen: + zeroMixed = true + writeProofs(currPairEnd - 1, zeroHashes[level]) + merkleTree.add mergeBranches(merkleTree[treeRowStart + currPairEnd - 1], + zeroHashes[level]) + + treeRowStart += rowLen + rowLen = merkleTree.len - treeRowStart + + if rowLen == 1: + break + + doAssert rowLen == 1 + + if (inRowIdx and 2) != 0: + merkleizer.completeStartedChunk( + mergeBranches(merkleizer.combinedChunks[level + 1], merkleTree[^1]), + level + 2) + + if (not zeroMixed) and (postUpdateInRowIdx and 2) != 0: + merkleizer.combinedChunks[level + 1] = merkleTree[^1] + + while level < merkleizer.topIndex: + inc level + baseChunksPerElement *= 2 + inRowIdx = inRowIdx div 2 + + let hash = if getBitLE(merkleizer.totalChunks, level): + merkleizer.combinedChunks[level] + else: + zeroHashes[level] + + writeProofs(0, hash) + + merkleizer.totalChunks = newTotalChunks + +proc init*(S: type SszMerkleizer): S = + new result.combinedChunks + result.impl = SszMerkleizerImpl( + combinedChunks: cast[ptr UncheckedArray[Digest]]( + addr result.combinedChunks[][0]), + topIndex: binaryTreeHeight(result.limit) - 1, + totalChunks: 0) + +proc init*(S: type SszMerkleizer, + combinedChunks: openArray[Digest], + totalChunks: uint64): S = + new result.combinedChunks + result.combinedChunks[][0 ..< combinedChunks.len] = combinedChunks + result.impl = SszMerkleizerImpl( + combinedChunks: cast[ptr UncheckedArray[Digest]]( + addr result.combinedChunks[][0]), + topIndex: binaryTreeHeight(result.limit) - 1, + totalChunks: totalChunks) + +proc copy*[L: static[Limit]](cloned: SszMerkleizer[L]): SszMerkleizer[L] = + new result.combinedChunks + result.combinedChunks[] = cloned.combinedChunks[] + result.impl = SszMerkleizerImpl( + combinedChunks: cast[ptr UncheckedArray[Digest]]( + addr result.combinedChunks[][0]), + topIndex: binaryTreeHeight(L) - 1, + totalChunks: cloned.totalChunks) + +template addChunksAndGenMerkleProofs*( + merkleizer: var SszMerkleizer, + chunks: openArray[Digest]): seq[Digest] = + addChunksAndGenMerkleProofs(merkleizer.impl, chunks) + +template addChunk*(merkleizer: var SszMerkleizer, data: openArray[byte]) = + addChunk(merkleizer.impl, data) + +template totalChunks*(merkleizer: SszMerkleizer): uint64 = + merkleizer.impl.totalChunks + +template getFinalHash*(merkleizer: SszMerkleizer): Digest = + merkleizer.impl.getFinalHash + +template createMerkleizer*(totalElements: static Limit): SszMerkleizerImpl = + trs "CREATING A MERKLEIZER FOR ", totalElements + + const treeHeight = binaryTreeHeight totalElements + var combinedChunks {.noInit.}: array[treeHeight, Digest] + + SszMerkleizerImpl( + combinedChunks: cast[ptr UncheckedArray[Digest]](addr combinedChunks), + topIndex: treeHeight - 1, + totalChunks: 0) + +func getFinalHash*(merkleizer: SszMerkleizerImpl): Digest = + if merkleizer.totalChunks == 0: + return zeroHashes[merkleizer.topIndex] + + let + bottomHashIdx = firstOne(merkleizer.totalChunks) - 1 + submittedChunksHeight = bitWidth(merkleizer.totalChunks - 1) + topHashIdx = merkleizer.topIndex + + trs "BOTTOM HASH ", bottomHashIdx + trs "SUBMITTED HEIGHT ", submittedChunksHeight + trs "TOP HASH IDX ", topHashIdx + + if bottomHashIdx != submittedChunksHeight: + # Our tree is not finished. We must complete the work in progress + # branches and then extend the tree to the right height. + result = mergeBranches(merkleizer.combinedChunks[bottomHashIdx], + zeroHashes[bottomHashIdx]) + + for i in bottomHashIdx + 1 ..< topHashIdx: + if getBitLE(merkleizer.totalChunks, i): + result = mergeBranches(merkleizer.combinedChunks[i], result) + trs "COMBINED" + else: + result = mergeBranches(result, zeroHashes[i]) + trs "COMBINED WITH ZERO" + + elif bottomHashIdx == topHashIdx: + # We have a perfect tree (chunks == 2**n) at just the right height! + result = merkleizer.combinedChunks[bottomHashIdx] + else: + # We have a perfect tree of user chunks, but we have more work to + # do - we must extend it to reach the desired height + result = mergeBranches(merkleizer.combinedChunks[bottomHashIdx], + zeroHashes[bottomHashIdx]) + + for i in bottomHashIdx + 1 ..< topHashIdx: + result = mergeBranches(result, zeroHashes[i]) + +func mixInLength*(root: Digest, length: int): Digest = + var dataLen: array[32, byte] + dataLen[0..<8] = uint64(length).toBytesLE() + mergeBranches(root, dataLen) + +func hash_tree_root*(x: auto): Digest {.gcsafe, raises: [Defect].} + +template merkleizeFields(totalElements: static Limit, body: untyped): Digest = + var merkleizer {.inject.} = createMerkleizer(totalElements) + + template addField(field) = + let hash = hash_tree_root(field) + trs "MERKLEIZING FIELD ", astToStr(field), " = ", hash + addChunk(merkleizer, hash.data) + trs "CHUNK ADDED" + + body + + getFinalHash(merkleizer) + +template writeBytesLE(chunk: var array[bytesPerChunk, byte], atParam: int, + val: SomeUnsignedInt) = + let at = atParam + chunk[at ..< at + sizeof(val)] = toBytesLE(val) + +func chunkedHashTreeRootForBasicTypes[T](merkleizer: var SszMerkleizerImpl, + arr: openArray[T]): Digest = + static: + doAssert T is BasicType + + if arr.len == 0: + return getFinalHash(merkleizer) + + when T is byte: + var + remainingBytes = arr.len + pos = cast[ptr byte](unsafeAddr arr[0]) + + while remainingBytes >= bytesPerChunk: + merkleizer.addChunk(makeOpenArray(pos, bytesPerChunk)) + pos = offset(pos, bytesPerChunk) + remainingBytes -= bytesPerChunk + + if remainingBytes > 0: + merkleizer.addChunk(makeOpenArray(pos, remainingBytes)) + + elif T is bool or cpuEndian == littleEndian: + let + baseAddr = cast[ptr byte](unsafeAddr arr[0]) + len = arr.len * sizeof(T) + return chunkedHashTreeRootForBasicTypes(merkleizer, makeOpenArray(baseAddr, len)) + + else: + static: + doAssert T is UintN + doAssert bytesPerChunk mod sizeof(Т) == 0 + + const valuesPerChunk = bytesPerChunk div sizeof(Т) + + var writtenValues = 0 + + var chunk: array[bytesPerChunk, byte] + while writtenValues < arr.len - valuesPerChunk: + for i in 0 ..< valuesPerChunk: + chunk.writeBytesLE(i * sizeof(T), arr[writtenValues + i]) + merkleizer.addChunk chunk + inc writtenValues, valuesPerChunk + + let remainingValues = arr.len - writtenValues + if remainingValues > 0: + var lastChunk: array[bytesPerChunk, byte] + for i in 0 ..< remainingValues: + chunk.writeBytesLE(i * sizeof(T), arr[writtenValues + i]) + merkleizer.addChunk lastChunk + + getFinalHash(merkleizer) + +func bitListHashTreeRoot(merkleizer: var SszMerkleizerImpl, x: BitSeq): Digest = + # TODO: Switch to a simpler BitList representation and + # replace this with `chunkedHashTreeRoot` + var + totalBytes = bytes(x).len + lastCorrectedByte = bytes(x)[^1] + + if lastCorrectedByte == byte(1): + if totalBytes == 1: + # This is an empty bit list. + # It should be hashed as a tree containing all zeros: + return mergeBranches(zeroHashes[merkleizer.topIndex], + zeroHashes[0]) # this is the mixed length + + totalBytes -= 1 + lastCorrectedByte = bytes(x)[^2] + else: + let markerPos = log2trunc(lastCorrectedByte) + lastCorrectedByte.clearBit(markerPos) + + var + bytesInLastChunk = totalBytes mod bytesPerChunk + fullChunks = totalBytes div bytesPerChunk + + if bytesInLastChunk == 0: + fullChunks -= 1 + bytesInLastChunk = 32 + + for i in 0 ..< fullChunks: + let + chunkStartPos = i * bytesPerChunk + chunkEndPos = chunkStartPos + bytesPerChunk - 1 + + merkleizer.addChunk bytes(x).toOpenArray(chunkStartPos, chunkEndPos) + + var + lastChunk: array[bytesPerChunk, byte] + chunkStartPos = fullChunks * bytesPerChunk + + for i in 0 .. bytesInLastChunk - 2: + lastChunk[i] = bytes(x)[chunkStartPos + i] + + lastChunk[bytesInLastChunk - 1] = lastCorrectedByte + + merkleizer.addChunk lastChunk.toOpenArray(0, bytesInLastChunk - 1) + let contentsHash = merkleizer.getFinalHash + mixInLength contentsHash, x.len + +func maxChunksCount(T: type, maxLen: Limit): Limit = + when T is BitList|BitArray: + (maxLen + bitsPerChunk - 1) div bitsPerChunk + elif T is array|List: + maxChunkIdx(ElemType(T), maxLen) + else: + unsupported T # This should never happen + +func hashTreeRootAux[T](x: T): Digest = + when T is bool|char: + result.data[0] = byte(x) + elif T is SomeUnsignedInt: + when cpuEndian == bigEndian: + result.data[0.. 8: + v = v or v shr 8 + when bitsof(x) > 16: + v = v or v shr 16 + when bitsof(x) > 32: + v = v or v shr 32 + + v + 1 + +template dataPerChunk(T: type): int = + # How many data items fit in a chunk + when T is BasicType: + bytesPerChunk div sizeof(T) + else: + 1 + +template chunkIdx*(T: type, dataIdx: int64): int64 = + # Given a data index, which chunk does it belong to? + dataIdx div dataPerChunk(T) + +template maxChunkIdx*(T: type, maxLen: Limit): int64 = + # Given a number of data items, how many chunks are needed? + # TODO compiler bug: + # beacon_chain/ssz/types.nim(75, 53) Error: cannot generate code for: maxLen + # nextPow2(chunkIdx(T, maxLen + dataPerChunk(T) - 1).uint64).int64 + nextPow2Int64(chunkIdx(T, maxLen.int64 + dataPerChunk(T) - 1)) + template asSeq*(x: List): auto = distinctBase(x) template init*[T](L: type List, x: seq[T], N: static Limit): auto =