Some speedups (#226)

* ssz: avoid memory allocations

* a bit fishy with the 32-item stack.. this should be a smallvector

* digest: avoid another burnmem

* avoid a few allocations by using iterator
Jacek Sieka 2019-04-03 09:46:22 -06:00 committed by Dustin Brody
parent ad133a0222
commit 605dd0a0e9
5 changed files with 128 additions and 100 deletions
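
The recurring idea in these changes is to replace procs that build and return a seq with inline iterators that yield elements as they are found, so no intermediate seq is allocated per call. A minimal sketch of the pattern, with illustrative names that are not from the diff:

func selectEvenSeq(xs: openArray[int]): seq[int] =
  # allocates a fresh seq on every call
  result = @[]
  for x in xs:
    if x mod 2 == 0:
      result.add x

iterator selectEven(xs: openArray[int]): int =
  # inline iterator: expanded into the caller's loop, no seq materialised
  for x in xs:
    if x mod 2 == 0:
      yield x

when isMainModule:
  for v in selectEven([1, 2, 3, 4]):
    echo v  # call sites read the same as with the seq version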

View File

@@ -380,10 +380,10 @@ func get_attestation_participants*(state: BeaconState,
if aggregation_bit:
result.add(validator_index)
func get_attestation_participants_cached*(state: BeaconState,
iterator get_attestation_participants_cached*(state: BeaconState,
attestation_data: AttestationData,
bitfield: BitField,
crosslink_committees_cached: var auto): seq[ValidatorIndex] =
crosslink_committees_cached: var auto): ValidatorIndex =
## Return the participant indices for the ``attestation_data`` and
## ``bitfield``.
## Attestation participants in the attestation data are called out in a
@@ -397,26 +397,25 @@ func get_attestation_participants_cached*(state: BeaconState,
# TODO iterator candidate
# Find the committee in the list with the desired shard
let crosslink_committees = get_crosslink_committees_at_slot_cached(
state, attestation_data.slot, false, crosslink_committees_cached)
doAssert anyIt(
crosslink_committees,
it[1] == attestation_data.shard)
let crosslink_committee = mapIt(
filterIt(crosslink_committees, it.shard == attestation_data.shard),
it.committee)[0]
# let crosslink_committees = get_crosslink_committees_at_slot_cached(
# state, attestation_data.slot, false, crosslink_committees_cached)
var found = false
for crosslink_committee in get_crosslink_committees_at_slot_cached(
state, attestation_data.slot, false, crosslink_committees_cached):
if crosslink_committee.shard == attestation_data.shard:
# TODO this and other attestation-based fields need validation so we don't
# crash on a malicious attestation!
doAssert verify_bitfield(bitfield, len(crosslink_committee))
doAssert verify_bitfield(bitfield, len(crosslink_committee.committee))
# Find the participating attesters in the committee
result = @[]
for i, validator_index in crosslink_committee:
for i, validator_index in crosslink_committee.committee:
let aggregation_bit = get_bitfield_bit(bitfield, i)
if aggregation_bit:
result.add(validator_index)
yield validator_index
found = true
break
doAssert found, "Couldn't find crosslink committee"
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#ejections
func process_ejections*(state: var BeaconState) =
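
For reference, the old body found the matching committee with anyIt/filterIt/mapIt, each of which builds a temporary seq, while the new iterator walks the committees once and breaks on the first shard match. A rough sketch of that difference on toy data (Committee and the findCommittee names are invented for illustration):

import sequtils

type Committee = tuple[members: seq[int], shard: uint64]

proc findCommitteeAlloc(cs: seq[Committee], shard: uint64): seq[int] =
  # filterIt and mapIt each allocate an intermediate seq per call
  result = mapIt(filterIt(cs, it.shard == shard), it.members)[0]

proc findCommittee(cs: seq[Committee], shard: uint64): seq[int] =
  # single pass, no temporaries; mirrors the found/break shape above
  for c in cs:
    if c.shard == shard:
      return c.members
  doAssert false, "Couldn't find committee for shard"

when isMainModule:
  let cs = @[(members: @[1, 2], shard: 0'u64), (members: @[3], shard: 1'u64)]
  doAssert findCommittee(cs, 1'u64) == @[3]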

View File

@@ -36,7 +36,7 @@ func shortLog*(x: Eth2Digest): string =
result = ($x)[0..7]
func eth2hash*(v: openArray[byte]): Eth2Digest =
var ctx: Eth2Hash
var ctx: keccak256 # use explicit type so we can rely on init being useless
# We can avoid this step for Keccak/SHA3 digests because `ctx` is already
# empty, but if the digest type is changed, the next line must be enabled.
# ctx.init()
@@ -47,8 +47,8 @@ template withEth2Hash*(body: untyped): Eth2Digest =
## This little helper will init the hash function and return the sliced
## hash:
## let hashOfData = withHash: h.update(data)
var h {.inject.}: Eth2Hash
h.init()
var h {.inject.}: keccak256
# TODO no need, as long as using keccak256: h.init()
body
var res = h.finish()
res
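
The point of the explicit keccak256 type is that Keccak's initial sponge state is all zeroes, so a zero-initialised context is already valid and init() can be skipped. A small check of that claim, assuming nimcrypto's context API (init/update/finish) as used elsewhere in this repo:

import nimcrypto

var a: keccak256              # zero-initialised, init() deliberately skipped
a.update([byte 1, 2, 3])

var b: keccak256
b.init()                      # the explicit init the generic Eth2Hash path did
b.update([byte 1, 2, 3])

doAssert a.finish() == b.finish()  # both contexts produce the same digest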

View File

@@ -229,16 +229,17 @@ func get_crosslink_committees_at_slot*(state: BeaconState, slot: Slot|uint64,
(slot_start_shard + i.uint64) mod SHARD_COUNT
)
func get_crosslink_committees_at_slot_cached*(
iterator get_crosslink_committees_at_slot_cached*(
state: BeaconState, slot: Slot|uint64,
registry_change: bool = false, cache: var auto):
seq[CrosslinkCommittee] =
CrosslinkCommittee =
let key = (slot.uint64, registry_change)
if key in cache:
return cache[key]
for v in cache[key]: yield v
#debugEcho "get_crosslink_committees_at_slot_cached: MISS"
result = get_crosslink_committees_at_slot(state, slot, registry_change)
let result = get_crosslink_committees_at_slot(state, slot, registry_change)
cache[key] = result
for v in result: yield v
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#get_beacon_proposer_index
func get_beacon_proposer_index*(state: BeaconState, slot: Slot): ValidatorIndex =
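
The shape of the iterator above is a cache-or-compute pattern: yield cached values on a hit, otherwise compute once, store, and yield. A self-contained sketch with a toy computation and a stdlib Table as the cache (the real cache is keyed on (slot, registry_change)):

import tables

proc computeExpensive(n: int): seq[int] =
  for i in 0 ..< n:
    result.add i * i

iterator cachedSquares(n: int, cache: var Table[int, seq[int]]): int =
  if n in cache:
    for v in cache[n]:              # hit: nothing recomputed
      yield v
  else:
    let computed = computeExpensive(n)
    cache[n] = computed             # miss: compute once and remember
    for v in computed:
      yield v

when isMainModule:
  var cache = initTable[int, seq[int]]()
  for v in cachedSquares(4, cache): echo v  # fills the cache
  for v in cachedSquares(4, cache): echo v  # served from the cache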

View File

@@ -297,71 +297,99 @@ func mix_in_length(root: Chunk, length: int): Chunk =
hash(root, dataLen)
proc pack(values: seq|array): iterator(): Chunk =
result = iterator (): Chunk =
# TODO should be trivial to avoid this seq also..
# TODO I get a feeling a copy of the array is taken to the closure, which
# also needs fixing
# TODO avoid closure iterators that involve GC
var tmp =
newSeqOfCap[byte](values.len() * sizeof(toBytesSSZ(values[0].toSSZType())))
for v in values:
tmp.add toBytesSSZ(v.toSSZType)
for v in 0..<tmp.len div sizeof(Chunk):
var c: Chunk
copyMem(addr c, addr tmp[v * sizeof(Chunk)], sizeof(Chunk))
yield c
let remains = tmp.len mod sizeof(Chunk)
if remains != 0:
var c: Chunk
copyMem(addr c, addr tmp[tmp.len - remains], remains)
yield c
proc pad(iter: iterator(): Chunk): iterator(): Chunk =
# Pad a list of chunks to the next power-of-two length with empty chunks -
# this includes ensuring there's at least one chunk returned
result = iterator(): Chunk =
var count = 0
while true:
let item = iter()
if finished(iter): break
count += 1
yield item
doAssert nextPowerOfTwo(0) == 1,
"Usefully, empty lists will be padded to one empty block"
for _ in count..<nextPowerOfTwo(count):
template padEmptyChunks(chunks: int) =
for i in chunks..<nextPowerOfTwo(chunks):
yield emptyChunk
func merkleize(chunker: iterator(): Chunk): Chunk =
iterator packAndPad(values: seq|array): Chunk =
## Produce a stream of chunks that are packed and padded such that they number
## a power of two
when sizeof(values[0].toSSZType().toBytesSSZ()) == sizeof(Chunk):
# When chunks and value lengths coincide, do the simple thing
for v in values:
yield v.toSSZType().toBytesSSZ()
padEmptyChunks(values.len)
else:
var
stack: seq[tuple[height: int, chunk: Chunk]]
paddedChunker = pad(chunker)
chunks: int
tmp: Chunk
tmpPos: int # how many bytes of tmp we've filled with ssz values
while true:
let chunk = paddedChunker()
if finished(paddedChunker): break
for v in values:
var
vssz = toBytesSSZ(v.toSSZType)
vPos = 0 # how many bytes of vssz that we've consumed
while vPos < vssz.len:
# there are still bytes of vssz left to consume - looping happens when
# vssz.len > sizeof(Chunk)
let left = min(tmp.len - tmpPos, vssz.len - vPos)
copyMem(addr tmp[tmpPos], addr vssz[vPos], left)
vPos += left
tmpPos += left
if tmpPos == tmp.len:
# When vssz.len < sizeof(Chunk), multiple values will fit in a chunk
yield tmp
tmpPos = 0
chunks += 1
if tmpPos > 0:
# If vssz.len is not a multiple of Chunk, we might need to pad the last
# chunk with zeroes and return it
for i in tmpPos..<tmp.len:
tmp[i] = 0'u8
yield tmp
tmpPos = 0
chunks += 1
padEmptyChunks(chunks)
iterator hash_tree_collection(value: array|seq): Chunk =
mixin hash_tree_root
var chunks = 0
for v in value:
yield hash_tree_root(v).data
chunks += 1
padEmptyChunks(chunks)
iterator hash_tree_fields(value: object): Chunk =
mixin hash_tree_root
var chunks = 0
for v in value.fields:
yield hash_tree_root(v).data
chunks += 1
padEmptyChunks(chunks)
template merkleize(chunker: untyped): Chunk =
var
# a depth of 32 here should give us capability to handle 2^32 chunks,
# more than enough
# TODO replace with SmallVector-like thing..
stack: array[32, tuple[height: int, chunk: Chunk]]
stackPos = 0
for chunk in chunker:
# Leaves start at height 0 - every time they move up, height is increased
# allowing us to detect two chunks at the same height ready for
# consolidation
# See also: http://szydlo.com/logspacetime03.pdf
stack.add (0, chunk)
stack[stackPos] = (0, chunk)
inc stackPos
# Consolidate items of the same height - this keeps stack size at log N
while stack.len > 1 and stack[^1].height == stack[^2].height:
while stackPos > 1 and stack[stackPos - 1].height == stack[stackPos - 2].height:
# As tradition dictates - one feature, at least one nim bug:
# https://github.com/nim-lang/Nim/issues/9684
let tmp = hash(stack[^2].chunk, stack[^1].chunk)
stack[^2].height += 1
stack[^2].chunk = tmp
discard stack.pop
let tmp = hash(stack[stackPos - 2].chunk, stack[stackPos - 1].chunk)
stack[stackPos - 2].height += 1
stack[stackPos - 2].chunk = tmp
stackPos -= 1
doAssert stack.len == 1,
doAssert stackPos == 1,
"With power-of-two leaves, we should end up with a single root"
stack[0].chunk
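
To see why a fixed 32-entry array suffices here: with a power-of-two number of padded leaves, the stack never holds more than one partial root per tree level, so 32 slots cover 2^32 chunks. A toy version of the same consolidation loop, with string concatenation standing in for hash() (merkleizeToy and combine are invented names):

proc combine(a, b: string): string =
  "(" & a & b & ")"

proc merkleizeToy(leaves: openArray[string]): string =
  var
    stack: array[32, tuple[height: int, node: string]]
    stackPos = 0
  for leaf in leaves:
    stack[stackPos] = (0, leaf)
    inc stackPos
    # consolidate equal heights, exactly as in the template above
    while stackPos > 1 and stack[stackPos - 1].height == stack[stackPos - 2].height:
      stack[stackPos - 2] = (stack[stackPos - 2].height + 1,
                             combine(stack[stackPos - 2].node, stack[stackPos - 1].node))
      stackPos -= 1
  doAssert stackPos == 1, "leaf count must be a power of two"
  result = stack[0].node

when isMainModule:
  echo merkleizeToy(["a", "b", "c", "d"])  # prints ((ab)(cd))
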
@@ -373,31 +401,22 @@ func hash_tree_root*[T](value: T): Eth2Digest =
# Merkle tree
Eth2Digest(data:
when T is BasicType:
merkleize(pack([value]))
merkleize(packAndPad([value]))
elif T is array|seq:
when T.elementType() is BasicType:
mix_in_length(merkleize(pack(value)), len(value))
mix_in_length(merkleize(packAndPad(value)), len(value))
else:
var roots = iterator(): Chunk =
for v in value:
yield hash_tree_root(v).data
mix_in_length(merkleize(roots), len(value))
mix_in_length(merkleize(hash_tree_collection(value)), len(value))
elif T is object:
var roots = iterator(): Chunk =
for v in value.fields:
yield hash_tree_root(v).data
merkleize(roots)
merkleize(hash_tree_fields(value))
else:
static: doAssert false, "Unexpected type: " & T.name
)
# https://github.com/ethereum/eth2.0-specs/blob/0.4.0/specs/simple-serialize.md#signed-roots
func signed_root*[T: object](x: T): Eth2Digest =
# TODO write tests for this (check vs hash_tree_root)
iterator hash_tree_most(v: object): Chunk =
var found_field_name = false
var roots = iterator(): Chunk =
for name, field in x.fieldPairs:
for name, field in v.fieldPairs:
# TODO we should truncate the last field, regardless of its name.. this
# hack works for now - how to skip the last fieldPair though??
if name == "signature":
@@ -405,8 +424,12 @@ func signed_root*[T: object](x: T): Eth2Digest =
break
yield hash_tree_root(field).data
let root = merkleize(roots)
doAssert found_field_name
# https://github.com/ethereum/eth2.0-specs/blob/0.4.0/specs/simple-serialize.md#signed-roots
func signed_root*[T: object](x: T): Eth2Digest =
# TODO write tests for this (check vs hash_tree_root)
let root = merkleize(hash_tree_most(x))
Eth2Digest(data: root)
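
hash_tree_most above relies on fieldPairs visiting fields in declaration order and on breaking out once the field named "signature" is reached. A toy version of that truncation trick (Signed and its fields are invented; the real iterator yields hash_tree_root(field).data rather than the field itself):

type Signed = object
  slot: uint64
  root: string
  signature: string

iterator fieldsBeforeSignature(v: Signed): string =
  var found = false
  for name, field in v.fieldPairs:
    if name == "signature":
      found = true
      break
    yield $field        # everything declared before the signature field
  doAssert found, "expected a trailing signature field"

when isMainModule:
  for f in fieldsBeforeSignature(Signed(slot: 1, root: "r", signature: "sig")):
    echo f              # prints 1 then r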

View File

@@ -79,12 +79,17 @@ suite "Simple serialization":
SSZ.roundripTest BeaconState(slot: 42.Slot)
suite "Tree hashing":
# TODO Nothing but smoke tests for now..
# TODO The test values are taken from an earlier version of SSZ and have
# nothing to do with upstream - needs verification and proper test suite
test "Hash BeaconBlock":
let vr = BeaconBlock()
check: hash_tree_root(vr) != Eth2Digest()
check:
$hash_tree_root(vr) ==
"1BD5D8577A7806CC524C367808C53AE2480F35A3C4BB11A90D6E1AC304E27201"
test "Hash BeaconState":
let vr = BeaconBlock()
check: hash_tree_root(vr) != Eth2Digest()
let vr = BeaconState()
check:
$hash_tree_root(vr) ==
"DC751EF09987283D52483C75690234DDD75FFDAF1A844CD56FE1173465B5597A"