Some speedups (#226)
* ssz: avoid memory allocations * a bit fishy with the 32-item stack.. this should be a smallvector * digest: avoid another burnmem * avoid a few allocations by using iterator
This commit is contained in:
parent
ad133a0222
commit
605dd0a0e9
|
@ -380,10 +380,10 @@ func get_attestation_participants*(state: BeaconState,
|
|||
if aggregation_bit:
|
||||
result.add(validator_index)
|
||||
|
||||
func get_attestation_participants_cached*(state: BeaconState,
|
||||
iterator get_attestation_participants_cached*(state: BeaconState,
|
||||
attestation_data: AttestationData,
|
||||
bitfield: BitField,
|
||||
crosslink_committees_cached: var auto): seq[ValidatorIndex] =
|
||||
crosslink_committees_cached: var auto): ValidatorIndex =
|
||||
## Return the participant indices at for the ``attestation_data`` and
|
||||
## ``bitfield``.
|
||||
## Attestation participants in the attestation data are called out in a
|
||||
|
@ -397,26 +397,25 @@ func get_attestation_participants_cached*(state: BeaconState,
|
|||
# TODO iterator candidate
|
||||
|
||||
# Find the committee in the list with the desired shard
|
||||
let crosslink_committees = get_crosslink_committees_at_slot_cached(
|
||||
state, attestation_data.slot, false, crosslink_committees_cached)
|
||||
# let crosslink_committees = get_crosslink_committees_at_slot_cached(
|
||||
# state, attestation_data.slot, false, crosslink_committees_cached)
|
||||
|
||||
doAssert anyIt(
|
||||
crosslink_committees,
|
||||
it[1] == attestation_data.shard)
|
||||
let crosslink_committee = mapIt(
|
||||
filterIt(crosslink_committees, it.shard == attestation_data.shard),
|
||||
it.committee)[0]
|
||||
var found = false
|
||||
for crosslink_committee in get_crosslink_committees_at_slot_cached(
|
||||
state, attestation_data.slot, false, crosslink_committees_cached):
|
||||
if crosslink_committee.shard == attestation_data.shard:
|
||||
# TODO this and other attestation-based fields need validation so we don't
|
||||
# crash on a malicious attestation!
|
||||
doAssert verify_bitfield(bitfield, len(crosslink_committee.committee))
|
||||
|
||||
# TODO this and other attestation-based fields need validation so we don't
|
||||
# crash on a malicious attestation!
|
||||
doAssert verify_bitfield(bitfield, len(crosslink_committee))
|
||||
|
||||
# Find the participating attesters in the committee
|
||||
result = @[]
|
||||
for i, validator_index in crosslink_committee:
|
||||
let aggregation_bit = get_bitfield_bit(bitfield, i)
|
||||
if aggregation_bit:
|
||||
result.add(validator_index)
|
||||
# Find the participating attesters in the committee
|
||||
for i, validator_index in crosslink_committee.committee:
|
||||
let aggregation_bit = get_bitfield_bit(bitfield, i)
|
||||
if aggregation_bit:
|
||||
yield validator_index
|
||||
found = true
|
||||
break
|
||||
doAssert found, "Couldn't find crosslink committee"
|
||||
|
||||
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#ejections
|
||||
func process_ejections*(state: var BeaconState) =
|
||||
|
|
|
@ -36,7 +36,7 @@ func shortLog*(x: Eth2Digest): string =
|
|||
result = ($x)[0..7]
|
||||
|
||||
func eth2hash*(v: openArray[byte]): Eth2Digest =
|
||||
var ctx: Eth2Hash
|
||||
var ctx: keccak256 # use explicit type so we can rely on init being useless
|
||||
# We can avoid this step for Keccak/SHA3 digests because `ctx` is already
|
||||
# empty, but if digest will be changed next line must be enabled.
|
||||
# ctx.init()
|
||||
|
@ -47,8 +47,8 @@ template withEth2Hash*(body: untyped): Eth2Digest =
|
|||
## This little helper will init the hash function and return the sliced
|
||||
## hash:
|
||||
## let hashOfData = withHash: h.update(data)
|
||||
var h {.inject.}: Eth2Hash
|
||||
h.init()
|
||||
var h {.inject.}: keccak256
|
||||
# TODO no need, as long as using keccak256: h.init()
|
||||
body
|
||||
var res = h.finish()
|
||||
res
|
||||
|
|
|
@ -229,16 +229,17 @@ func get_crosslink_committees_at_slot*(state: BeaconState, slot: Slot|uint64,
|
|||
(slot_start_shard + i.uint64) mod SHARD_COUNT
|
||||
)
|
||||
|
||||
func get_crosslink_committees_at_slot_cached*(
|
||||
iterator get_crosslink_committees_at_slot_cached*(
|
||||
state: BeaconState, slot: Slot|uint64,
|
||||
registry_change: bool = false, cache: var auto):
|
||||
seq[CrosslinkCommittee] =
|
||||
CrosslinkCommittee =
|
||||
let key = (slot.uint64, registry_change)
|
||||
if key in cache:
|
||||
return cache[key]
|
||||
for v in cache[key]: yield v
|
||||
#debugEcho "get_crosslink_committees_at_slot_cached: MISS"
|
||||
result = get_crosslink_committees_at_slot(state, slot, registry_change)
|
||||
let result = get_crosslink_committees_at_slot(state, slot, registry_change)
|
||||
cache[key] = result
|
||||
for v in result: yield v
|
||||
|
||||
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#get_beacon_proposer_index
|
||||
func get_beacon_proposer_index*(state: BeaconState, slot: Slot): ValidatorIndex =
|
||||
|
|
|
@ -297,71 +297,99 @@ func mix_in_length(root: Chunk, length: int): Chunk =
|
|||
|
||||
hash(root, dataLen)
|
||||
|
||||
proc pack(values: seq|array): iterator(): Chunk =
|
||||
result = iterator (): Chunk =
|
||||
# TODO should be trivial to avoid this seq also..
|
||||
# TODO I get a feeling a copy of the array is taken to the closure, which
|
||||
# also needs fixing
|
||||
# TODO avoid closure iterators that involve GC
|
||||
var tmp =
|
||||
newSeqOfCap[byte](values.len() * sizeof(toBytesSSZ(values[0].toSSZType())))
|
||||
template padEmptyChunks(chunks: int) =
|
||||
for i in chunks..<nextPowerOfTwo(chunks):
|
||||
yield emptyChunk
|
||||
|
||||
iterator packAndPad(values: seq|array): Chunk =
|
||||
## Produce a stream of chunks that are packed and padded such that they number
|
||||
## a power of two
|
||||
|
||||
when sizeof(values[0].toSSZType().toBytesSSZ()) == sizeof(Chunk):
|
||||
# When chunks and value lengths coincide, do the simple thing
|
||||
for v in values:
|
||||
tmp.add toBytesSSZ(v.toSSZType)
|
||||
yield v.toSSZType().toBytesSSZ()
|
||||
padEmptyChunks(values.len)
|
||||
|
||||
for v in 0..<tmp.len div sizeof(Chunk):
|
||||
var c: Chunk
|
||||
copyMem(addr c, addr tmp[v * sizeof(Chunk)], sizeof(Chunk))
|
||||
yield c
|
||||
else:
|
||||
var
|
||||
chunks: int
|
||||
tmp: Chunk
|
||||
tmpPos: int # how many bytes of tmp we've filled with ssz values
|
||||
|
||||
let remains = tmp.len mod sizeof(Chunk)
|
||||
if remains != 0:
|
||||
var c: Chunk
|
||||
copyMem(addr c, addr tmp[tmp.len - remains], remains)
|
||||
yield c
|
||||
for v in values:
|
||||
var
|
||||
vssz = toBytesSSZ(v.toSSZType)
|
||||
vPos = 0 # how many bytes of vssz that we've consumed
|
||||
|
||||
proc pad(iter: iterator(): Chunk): iterator(): Chunk =
|
||||
# Pad a list of chunks to the next power-of-two length with empty chunks -
|
||||
# this includes ensuring there's at least one chunk return
|
||||
result = iterator(): Chunk =
|
||||
var count = 0
|
||||
while vPos < vssz.len:
|
||||
# there are still bytes of vssz left to consume - looping happens when
|
||||
# vssz.len > sizeof(Chunk)
|
||||
|
||||
while true:
|
||||
let item = iter()
|
||||
if finished(iter): break
|
||||
count += 1
|
||||
yield item
|
||||
let left = min(tmp.len - tmpPos, vssz.len - vPos)
|
||||
copyMem(addr tmp[tmpPos], addr vssz[vPos], left)
|
||||
vPos += left
|
||||
tmpPos += left
|
||||
|
||||
doAssert nextPowerOfTwo(0) == 1,
|
||||
"Usefully, empty lists will be padded to one empty block"
|
||||
if tmpPos == tmp.len:
|
||||
# When vssz.len < sizeof(Chunk), multiple values will fit in a chunk
|
||||
yield tmp
|
||||
tmpPos = 0
|
||||
chunks += 1
|
||||
|
||||
for _ in count..<nextPowerOfTwo(count):
|
||||
yield emptyChunk
|
||||
if tmpPos > 0:
|
||||
# If vssz.len is not a multiple of Chunk, we might need to pad the last
|
||||
# chunk with zeroes and return it
|
||||
for i in tmpPos..<tmp.len:
|
||||
tmp[i] = 0'u8
|
||||
yield tmp
|
||||
tmpPos = 0
|
||||
chunks += 1
|
||||
|
||||
func merkleize(chunker: iterator(): Chunk): Chunk =
|
||||
padEmptyChunks(chunks)
|
||||
|
||||
iterator hash_tree_collection(value: array|seq): Chunk =
|
||||
mixin hash_tree_root
|
||||
var chunks = 0
|
||||
for v in value:
|
||||
yield hash_tree_root(v).data
|
||||
chunks += 1
|
||||
padEmptyChunks(chunks)
|
||||
|
||||
iterator hash_tree_fields(value: object): Chunk =
|
||||
mixin hash_tree_root
|
||||
var chunks = 0
|
||||
for v in value.fields:
|
||||
yield hash_tree_root(v).data
|
||||
chunks += 1
|
||||
padEmptyChunks(chunks)
|
||||
|
||||
template merkleize(chunker: untyped): Chunk =
|
||||
var
|
||||
stack: seq[tuple[height: int, chunk: Chunk]]
|
||||
paddedChunker = pad(chunker)
|
||||
|
||||
while true:
|
||||
let chunk = paddedChunker()
|
||||
if finished(paddedChunker): break
|
||||
# a depth of 32 here should give us capability to handle 2^32 chunks,
|
||||
# more than enough
|
||||
# TODO replace with SmallVector-like thing..
|
||||
stack: array[32, tuple[height: int, chunk: Chunk]]
|
||||
stackPos = 0
|
||||
|
||||
for chunk in chunker:
|
||||
# Leaves start at height 0 - every time they move up, height is increased
|
||||
# allowing us to detect two chunks at the same height ready for
|
||||
# consolidation
|
||||
# See also: http://szydlo.com/logspacetime03.pdf
|
||||
stack.add (0, chunk)
|
||||
stack[stackPos] = (0, chunk)
|
||||
inc stackPos
|
||||
|
||||
# Consolidate items of the same height - this keeps stack size at log N
|
||||
while stack.len > 1 and stack[^1].height == stack[^2].height:
|
||||
while stackPos > 1 and stack[stackPos - 1].height == stack[stackPos - 2].height:
|
||||
# As tradition dictates - one feature, at least one nim bug:
|
||||
# https://github.com/nim-lang/Nim/issues/9684
|
||||
let tmp = hash(stack[^2].chunk, stack[^1].chunk)
|
||||
stack[^2].height += 1
|
||||
stack[^2].chunk = tmp
|
||||
discard stack.pop
|
||||
let tmp = hash(stack[stackPos - 2].chunk, stack[stackPos - 1].chunk)
|
||||
stack[stackPos - 2].height += 1
|
||||
stack[stackPos - 2].chunk = tmp
|
||||
stackPos -= 1
|
||||
|
||||
doAssert stack.len == 1,
|
||||
doAssert stackPos == 1,
|
||||
"With power-of-two leaves, we should end up with a single root"
|
||||
|
||||
stack[0].chunk
|
||||
|
@ -373,40 +401,35 @@ func hash_tree_root*[T](value: T): Eth2Digest =
|
|||
# Merkle tree
|
||||
Eth2Digest(data:
|
||||
when T is BasicType:
|
||||
merkleize(pack([value]))
|
||||
merkleize(packAndPad([value]))
|
||||
elif T is array|seq:
|
||||
when T.elementType() is BasicType:
|
||||
mix_in_length(merkleize(pack(value)), len(value))
|
||||
mix_in_length(merkleize(packAndPad(value)), len(value))
|
||||
else:
|
||||
var roots = iterator(): Chunk =
|
||||
for v in value:
|
||||
yield hash_tree_root(v).data
|
||||
mix_in_length(merkleize(roots), len(value))
|
||||
mix_in_length(merkleize(hash_tree_collection(value)), len(value))
|
||||
elif T is object:
|
||||
var roots = iterator(): Chunk =
|
||||
for v in value.fields:
|
||||
yield hash_tree_root(v).data
|
||||
merkleize(roots)
|
||||
merkleize(hash_tree_fields(value))
|
||||
else:
|
||||
static: doAssert false, "Unexpected type: " & T.name
|
||||
)
|
||||
|
||||
iterator hash_tree_most(v: object): Chunk =
|
||||
var found_field_name = false
|
||||
|
||||
for name, field in v.fieldPairs:
|
||||
# TODO we should truncate the last field, regardless of its name.. this
|
||||
# hack works for now - how to skip the last fieldPair though??
|
||||
if name == "signature":
|
||||
found_field_name = true
|
||||
break
|
||||
yield hash_tree_root(field).data
|
||||
|
||||
doAssert found_field_name
|
||||
|
||||
# https://github.com/ethereum/eth2.0-specs/blob/0.4.0/specs/simple-serialize.md#signed-roots
|
||||
func signed_root*[T: object](x: T): Eth2Digest =
|
||||
# TODO write tests for this (check vs hash_tree_root)
|
||||
|
||||
var found_field_name = false
|
||||
var roots = iterator(): Chunk =
|
||||
for name, field in x.fieldPairs:
|
||||
# TODO we should truncate the last field, regardless of its name.. this
|
||||
# hack works for now - how to skip the last fieldPair though??
|
||||
if name == "signature":
|
||||
found_field_name = true
|
||||
break
|
||||
yield hash_tree_root(field).data
|
||||
|
||||
let root = merkleize(roots)
|
||||
|
||||
doAssert found_field_name
|
||||
let root = merkleize(hash_tree_most(x))
|
||||
|
||||
Eth2Digest(data: root)
|
||||
|
|
|
@ -79,12 +79,17 @@ suite "Simple serialization":
|
|||
SSZ.roundripTest BeaconState(slot: 42.Slot)
|
||||
|
||||
suite "Tree hashing":
|
||||
# TODO Nothing but smoke tests for now..
|
||||
# TODO The test values are taken from an earlier version of SSZ and have
|
||||
# nothing to do with upstream - needs verification and proper test suite
|
||||
|
||||
test "Hash BeaconBlock":
|
||||
let vr = BeaconBlock()
|
||||
check: hash_tree_root(vr) != Eth2Digest()
|
||||
check:
|
||||
$hash_tree_root(vr) ==
|
||||
"1BD5D8577A7806CC524C367808C53AE2480F35A3C4BB11A90D6E1AC304E27201"
|
||||
|
||||
test "Hash BeaconState":
|
||||
let vr = BeaconBlock()
|
||||
check: hash_tree_root(vr) != Eth2Digest()
|
||||
let vr = BeaconState()
|
||||
check:
|
||||
$hash_tree_root(vr) ==
|
||||
"DC751EF09987283D52483C75690234DDD75FFDAF1A844CD56FE1173465B5597A"
|
||||
|
|
Loading…
Reference in New Issue