This commit is contained in:
Jacek Sieka 2020-06-02 17:17:07 +02:00
parent 872d7ff493
commit 7e881a4c09
No known key found for this signature in database
GPG Key ID: A1B09461ABB656B8
2 changed files with 69 additions and 66 deletions

View File

@ -295,7 +295,7 @@ proc readValue*[T](r: var SszReader, val: var T) {.raises: [Defect, MalformedSsz
readSszValue(r.stream.read(r.stream.len.get), val) readSszValue(r.stream.read(r.stream.len.get), val)
const const
zeroChunk = default array[32, byte] zero64 = default array[64, byte]
func hash(a, b: openArray[byte]): Eth2Digest = func hash(a, b: openArray[byte]): Eth2Digest =
result = withEth2Hash: result = withEth2Hash:
@ -319,14 +319,14 @@ func mergeBranches(existing: Eth2Digest, newData: openarray[byte]): Eth2Digest =
let paddingBytes = bytesPerChunk - newData.len let paddingBytes = bytesPerChunk - newData.len
if paddingBytes > 0: if paddingBytes > 0:
trs "USING ", paddingBytes, " PADDING BYTES" trs "USING ", paddingBytes, " PADDING BYTES"
h.update zeroChunk.toOpenArray(0, paddingBytes - 1) h.update zero64.toOpenArray(0, paddingBytes - 1)
trs "HASH RESULT ", result trs "HASH RESULT ", result
template mergeBranches(a, b: Eth2Digest): Eth2Digest = template mergeBranches(a, b: Eth2Digest): Eth2Digest =
hash(a.data, b.data) hash(a.data, b.data)
func computeZeroHashes: array[sizeof(Limit) * 8, Eth2Digest] = func computeZeroHashes: array[sizeof(Limit) * 8, Eth2Digest] =
result[0] = Eth2Digest(data: zeroChunk) result[0] = Eth2Digest()
for i in 1 .. result.high: for i in 1 .. result.high:
result[i] = mergeBranches(result[i - 1], result[i - 1]) result[i] = mergeBranches(result[i - 1], result[i - 1])
@ -533,11 +533,7 @@ func maxChunksCount(T: type, maxLen: int64): int64 =
when T is BitList|BitArray: when T is BitList|BitArray:
(maxLen + bitsPerChunk - 1) div bitsPerChunk (maxLen + bitsPerChunk - 1) div bitsPerChunk
elif T is array|List: elif T is array|List:
type E = ElemType(T) maxChunkIdx(ElemType(T), maxLen)
when E is BasicType:
(maxLen * sizeof(E) + bytesPerChunk - 1) div bytesPerChunk
else:
maxLen
else: else:
unsupported T # This should never happen unsupported T # This should never happen
@ -579,8 +575,28 @@ func hashTreeRootAux[T](x: T): Eth2Digest =
else: else:
unsupported T unsupported T
func hashTreeRootList(x: List|BitList): Eth2Digest =
const maxLen = static(x.maxLen)
type T = type(x)
const limit = maxChunksCount(T, maxLen)
var merkleizer = createMerkleizer(limit)
when x is BitList:
merkleizer.bitListHashTreeRoot(BitSeq x)
else:
type E = ElemType(T)
let contentsHash = when E is BasicType:
chunkedHashTreeRootForBasicTypes(merkleizer, asSeq x)
else:
for elem in x:
let elemHash = hash_tree_root(elem)
merkleizer.addChunk(elemHash.data)
merkleizer.getFinalHash()
mixInLength(contentsHash, x.len)
func mergedDataHash(x: HashList|HashArray, chunkIdx: int64): Eth2Digest = func mergedDataHash(x: HashList|HashArray, chunkIdx: int64): Eth2Digest =
# The hash of the two cached # The merged hash of the data at `chunkIdx` and `chunkIdx + 1`
trs "DATA HASH ", chunkIdx, " ", x.data.len trs "DATA HASH ", chunkIdx, " ", x.data.len
when x.T is BasicType: when x.T is BasicType:
@ -592,8 +608,6 @@ func mergedDataHash(x: HashList|HashArray, chunkIdx: int64): Eth2Digest =
byteIdx = chunkIdx * bytesPerChunk byteIdx = chunkIdx * bytesPerChunk
byteLen = x.data.len * sizeof(x.T) byteLen = x.data.len * sizeof(x.T)
const zero64 = default(array[64, byte])
if byteIdx >= byteLen: if byteIdx >= byteLen:
zeroHashes[1] zeroHashes[1]
else: else:
@ -616,8 +630,20 @@ func mergedDataHash(x: HashList|HashArray, chunkIdx: int64): Eth2Digest =
hash_tree_root(x.data[chunkIdx]), hash_tree_root(x.data[chunkIdx]),
hash_tree_root(x.data[chunkIdx + 1])) hash_tree_root(x.data[chunkIdx + 1]))
func cachedHash*(x: HashList, vIdx: int64): Eth2Digest = template mergedHash(x: HashList|HashArray, vIdxParam: int64): Eth2Digest =
doAssert vIdx >= 1 # The merged hash of the data at `vIdx` and `vIdx + 1`
let vIdx = vIdxParam
if vIdx >= x.maxChunks:
let dataIdx = vIdx - x.maxChunks
mergedDataHash(x, dataIdx)
else:
mergeBranches(
hashTreeRootCached(x, vIdx),
hashTreeRootCached(x, vIdx + 1))
func hashTreeRootCached*(x: HashList, vIdx: int64): Eth2Digest =
doAssert vIdx >= 1, "Only valid for flat merkle tree indices"
let let
layer = layer(vIdx) layer = layer(vIdx)
@ -636,73 +662,48 @@ func cachedHash*(x: HashList, vIdx: int64): Eth2Digest =
trs "REFRESHING ", vIdx, " ", layerIdx, " ", layer trs "REFRESHING ", vIdx, " ", layerIdx, " ", layer
px[].hashes[layerIdx] = px[].hashes[layerIdx] = mergedHash(x, vIdx * 2)
if layer == x.maxDepth - 1:
let dataIdx = vIdx * 2 - 1'i64 shl (x.maxDepth)
mergedDataHash(x, dataIdx)
else:
mergeBranches(
cachedHash(x, vIdx * 2),
cachedHash(x, vIdx * 2 + 1))
else: else:
trs "CACHED ", layerIdx trs "CACHED ", layerIdx
x.hashes[layerIdx] x.hashes[layerIdx]
func cachedHash*(x: HashArray, i: int): Eth2Digest = func hashTreeRootCached*(x: HashArray, vIdx: int): Eth2Digest =
doAssert i > 0, "Only valid for flat merkle tree indices" doAssert vIdx >= 1, "Only valid for flat merkle tree indices"
if not isCached(x.hashes[i]): if not isCached(x.hashes[vIdx]):
# TODO oops. so much for maintaining non-mutability. # TODO oops. so much for maintaining non-mutability.
let px = unsafeAddr x let px = unsafeAddr x
px[].hashes[i] = px[].hashes[vIdx] = mergedHash(x, vIdx * 2)
if i * 2 >= x.hashes.len():
let dataIdx = i * 2 - x.hashes.len()
mergedDataHash(x, dataIdx)
else:
mergeBranches(
cachedHash(x, i * 2),
cachedHash(x, i * 2 + 1))
return x.hashes[i] return x.hashes[vIdx]
func hashTreeRootCached*(x: HashArray): Eth2Digest =
hashTreeRootCached(x, 1) # Array does not use idx 0
func hashTreeRootCached*(x: HashList): Eth2Digest =
if x.data.len == 0:
mixInLength(zeroHashes[x.maxDepth], x.data.len())
else:
if not isCached(x.hashes[0]):
# TODO oops. so much for maintaining non-mutability.
let px = unsafeAddr x
px[].hashes[0] = mixInLength(hashTreeRootCached(x, 1), x.data.len)
x.hashes[0]
func hash_tree_root*(x: auto): Eth2Digest {.raises: [Defect], nbench.} = func hash_tree_root*(x: auto): Eth2Digest {.raises: [Defect], nbench.} =
trs "STARTING HASH TREE ROOT FOR TYPE ", name(type(x)) trs "STARTING HASH TREE ROOT FOR TYPE ", name(type(x))
mixin toSszType mixin toSszType
result = when x is HashArray: result =
cachedHash(x, 1) when x is HashArray|HashList:
elif x is HashList: hashTreeRootCached(x)
if x.hashes.len < 2: elif x is List|BitList:
mixInLength(zeroHashes[x.maxDepth], x.data.len()) hashTreeRootList(x)
else:
if not isCached(x.hashes[0]):
# TODO oops. so much for maintaining non-mutability.
let px = unsafeAddr x
px[].hashes[0] = mixInLength(cachedHash(x, 1), x.data.len)
x.hashes[0]
elif x is List|BitList:
const maxLen = static(x.maxLen)
type T = type(x)
const limit = maxChunksCount(T, maxLen)
var merkleizer = createMerkleizer(limit)
when x is BitList:
merkleizer.bitListHashTreeRoot(BitSeq x)
else: else:
type E = ElemType(T) hashTreeRootAux toSszType(x)
let contentsHash = when E is BasicType:
chunkedHashTreeRootForBasicTypes(merkleizer, asSeq x)
else:
for elem in x:
let elemHash = hash_tree_root(elem)
merkleizer.addChunk(elemHash.data)
merkleizer.getFinalHash()
mixInLength(contentsHash, x.len)
else:
hashTreeRootAux toSszType(x)
trs "HASH TREE ROOT FOR ", name(type x), " = ", "0x", $result trs "HASH TREE ROOT FOR ", name(type x), " = ", "0x", $result

View File

@ -8,6 +8,8 @@ import
const const
offsetSize* = 4 offsetSize* = 4
bytesPerChunk* = 32
# A few index types from here onwards: # A few index types from here onwards:
# * dataIdx - leaf index starting from 0 to maximum length of collection # * dataIdx - leaf index starting from 0 to maximum length of collection
@ -18,7 +20,7 @@ const
proc dataPerChunk(T: type): int = proc dataPerChunk(T: type): int =
# How many data items fit in a chunk # How many data items fit in a chunk
when T is bool|SomeUnsignedInt: # BasicType when T is bool|SomeUnsignedInt: # BasicType
32 div sizeof(T) bytesPerChunk div sizeof(T)
else: else:
1 1
@ -26,7 +28,7 @@ template chunkIdx*(T: type, dataIdx: int64): int64 =
# Given a data index, which chunk does it belong to? # Given a data index, which chunk does it belong to?
dataIdx div dataPerChunk(T) dataIdx div dataPerChunk(T)
template maxChunkIdx(T: type, maxLen: int64): int64 = template maxChunkIdx*(T: type, maxLen: int64): int64 =
# Given a number of data items, how many chunks are needed? # Given a number of data items, how many chunks are needed?
chunkIdx(T, maxLen + dataPerChunk(T) - 1) chunkIdx(T, maxLen + dataPerChunk(T) - 1)