Some speedups (#226)

* ssz: avoid memory allocations

* a bit fishy with the fixed 32-item stack... this should really be a smallvector (a standalone trace of the stack scheme follows the merkleize hunk below)

* digest: avoid another burnMem

* avoid a few allocations by using iterators (the pattern is sketched below)
Jacek Sieka 2019-04-03 09:46:22 -06:00 committed by Dustin Brody
parent ad133a0222
commit 605dd0a0e9
5 changed files with 128 additions and 100 deletions
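
The recurring pattern in these changes is replacing procs that build and return a seq with inline iterators that yield values one at a time, so no intermediate sequence (and no GC-managed closure-iterator environment) is allocated. A minimal, hypothetical Nim sketch of the pattern; collectEven and evens are illustrative names, not code from this repository:

# Hypothetical illustration of the allocation-avoiding pattern used in this
# commit; not code from the repository.

# Before: builds a seq on the heap just so the caller can loop over it.
proc collectEven(xs: openArray[int]): seq[int] =
  result = @[]
  for x in xs:
    if x mod 2 == 0:
      result.add x

# After: an inline iterator - its body is inlined at the call site and values
# are yielded one by one, with no intermediate seq allocation.
iterator evens(xs: openArray[int]): int =
  for x in xs:
    if x mod 2 == 0:
      yield x

when isMainModule:
  for x in evens([1, 2, 3, 4, 5, 6]):   # no seq is materialized here
    echo x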


@@ -380,10 +380,10 @@ func get_attestation_participants*(state: BeaconState,
     if aggregation_bit:
       result.add(validator_index)
 
-func get_attestation_participants_cached*(state: BeaconState,
+iterator get_attestation_participants_cached*(state: BeaconState,
                                           attestation_data: AttestationData,
                                           bitfield: BitField,
-                                          crosslink_committees_cached: var auto): seq[ValidatorIndex] =
+                                          crosslink_committees_cached: var auto): ValidatorIndex =
   ## Return the participant indices at for the ``attestation_data`` and
   ## ``bitfield``.
   ## Attestation participants in the attestation data are called out in a
@@ -397,26 +397,25 @@ func get_attestation_participants_cached*(state: BeaconState,
   # TODO iterator candidate
 
   # Find the committee in the list with the desired shard
-  let crosslink_committees = get_crosslink_committees_at_slot_cached(
-    state, attestation_data.slot, false, crosslink_committees_cached)
-
-  doAssert anyIt(
-    crosslink_committees,
-    it[1] == attestation_data.shard)
-  let crosslink_committee = mapIt(
-    filterIt(crosslink_committees, it.shard == attestation_data.shard),
-    it.committee)[0]
-
-  # TODO this and other attestation-based fields need validation so we don't
-  # crash on a malicious attestation!
-  doAssert verify_bitfield(bitfield, len(crosslink_committee))
-
-  # Find the participating attesters in the committee
-  result = @[]
-  for i, validator_index in crosslink_committee:
-    let aggregation_bit = get_bitfield_bit(bitfield, i)
-    if aggregation_bit:
-      result.add(validator_index)
+  # let crosslink_committees = get_crosslink_committees_at_slot_cached(
+  #  state, attestation_data.slot, false, crosslink_committees_cached)
+
+  var found = false
+  for crosslink_committee in get_crosslink_committees_at_slot_cached(
+      state, attestation_data.slot, false, crosslink_committees_cached):
+    if crosslink_committee.shard == attestation_data.shard:
+      # TODO this and other attestation-based fields need validation so we don't
+      # crash on a malicious attestation!
+      doAssert verify_bitfield(bitfield, len(crosslink_committee.committee))
+
+      # Find the participating attesters in the committee
+      for i, validator_index in crosslink_committee.committee:
+        let aggregation_bit = get_bitfield_bit(bitfield, i)
+        if aggregation_bit:
+          yield validator_index
+      found = true
+      break
+  doAssert found, "Couldn't find crosslink committee"
 
 # https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#ejections
 func process_ejections*(state: var BeaconState) =
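
With get_attestation_participants_cached turned into an iterator, a caller walks the participants directly instead of receiving a freshly allocated seq[ValidatorIndex]. A call-site sketch; BeaconState, Attestation and the cache type are assumed from the repository, and countParticipants is an illustrative helper, not part of the diff:

# Call-site sketch only; relies on the repository's BeaconState/Attestation
# types and the committee cache, so it is not compilable on its own.
proc countParticipants(state: BeaconState, attestation: Attestation,
                       cache: var auto): int =
  for validator_index in get_attestation_participants_cached(
      state, attestation.data, attestation.aggregation_bitfield, cache):
    # each participant is consumed as it is produced; no intermediate
    # seq[ValidatorIndex] is allocated
    inc result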


@@ -36,7 +36,7 @@ func shortLog*(x: Eth2Digest): string =
   result = ($x)[0..7]
 
 func eth2hash*(v: openArray[byte]): Eth2Digest =
-  var ctx: Eth2Hash
+  var ctx: keccak256 # use explicit type so we can rely on init being useless
   # We can avoid this step for Keccak/SHA3 digests because `ctx` is already
   # empty, but if digest will be changed next line must be enabled.
   # ctx.init()
@@ -47,8 +47,8 @@ template withEth2Hash*(body: untyped): Eth2Digest =
   ## This little helper will init the hash function and return the sliced
   ## hash:
   ## let hashOfData = withHash: h.update(data)
-  var h {.inject.}: Eth2Hash
-  h.init()
+  var h {.inject.}: keccak256
+  # TODO no need, as long as using keccak256: h.init()
   body
   var res = h.finish()
   res
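
The digest change leans on the fact that nimcrypto's Keccak context starts out all zeroes, so the explicit init() call can be skipped, as the in-code comment notes (the commit message ties this to avoiding another burnMem). A standalone sketch of the same flow using nimcrypto directly; demoHash is an illustrative name:

# Standalone sketch using nimcrypto directly; demoHash is illustrative.
import nimcrypto

proc demoHash(data: openArray[byte]): MDigest[256] =
  var ctx: keccak256  # zero state is a valid starting state for Keccak
  # ctx.init()        # skippable for Keccak/SHA3, per the comment in the diff
  ctx.update(data)
  ctx.finish()

when isMainModule:
  echo demoHash([byte 1, 2, 3])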


@@ -229,16 +229,17 @@ func get_crosslink_committees_at_slot*(state: BeaconState, slot: Slot|uint64,
       (slot_start_shard + i.uint64) mod SHARD_COUNT
     )
 
-func get_crosslink_committees_at_slot_cached*(
+iterator get_crosslink_committees_at_slot_cached*(
     state: BeaconState, slot: Slot|uint64,
     registry_change: bool = false, cache: var auto):
-    seq[CrosslinkCommittee] =
+    CrosslinkCommittee =
   let key = (slot.uint64, registry_change)
   if key in cache:
-    return cache[key]
+    for v in cache[key]: yield v
   #debugEcho "get_crosslink_committees_at_slot_cached: MISS"
-  result = get_crosslink_committees_at_slot(state, slot, registry_change)
+  let result = get_crosslink_committees_at_slot(state, slot, registry_change)
   cache[key] = result
+  for v in result: yield v
 
 # https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#get_beacon_proposer_index
 func get_beacon_proposer_index*(state: BeaconState, slot: Slot): ValidatorIndex =
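
The cache is keyed on (slot, registry_change); making the lookup an iterator lets a cache hit stream the memoized committees without copying the stored seq to the caller. A generic, self-contained sketch of the same memoize-and-yield shape (the int payload and names are illustrative, and it adds an explicit else branch for clarity):

import std/tables

# Illustrative stand-in for the committee computation in the diff above.
proc expensiveCompute(key: int): seq[int] =
  @[key, key * 2, key * 3]

iterator cachedValues(key: int, cache: var Table[int, seq[int]]): int =
  if key in cache:
    for v in cache[key]: yield v       # hit: stream the memoized result
  else:
    let computed = expensiveCompute(key)
    cache[key] = computed              # miss: memoize once...
    for v in computed: yield v         # ...and stream it to the caller

when isMainModule:
  var cache = initTable[int, seq[int]]()
  for v in cachedValues(7, cache): echo v   # miss: computes and caches
  for v in cachedValues(7, cache): echo v   # hit: served from the cache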


@@ -297,71 +297,99 @@ func mix_in_length(root: Chunk, length: int): Chunk =
   hash(root, dataLen)
 
-proc pack(values: seq|array): iterator(): Chunk =
-  result = iterator (): Chunk =
-    # TODO should be trivial to avoid this seq also..
-    # TODO I get a feeling a copy of the array is taken to the closure, which
-    #      also needs fixing
-    # TODO avoid closure iterators that involve GC
-    var tmp =
-      newSeqOfCap[byte](values.len() * sizeof(toBytesSSZ(values[0].toSSZType())))
-    for v in values:
-      tmp.add toBytesSSZ(v.toSSZType)
-
-    for v in 0..<tmp.len div sizeof(Chunk):
-      var c: Chunk
-      copyMem(addr c, addr tmp[v * sizeof(Chunk)], sizeof(Chunk))
-      yield c
-
-    let remains = tmp.len mod sizeof(Chunk)
-    if remains != 0:
-      var c: Chunk
-      copyMem(addr c, addr tmp[tmp.len - remains], remains)
-      yield c
-
-proc pad(iter: iterator(): Chunk): iterator(): Chunk =
-  # Pad a list of chunks to the next power-of-two length with empty chunks -
-  # this includes ensuring there's at least one chunk return
-  result = iterator(): Chunk =
-    var count = 0
-
-    while true:
-      let item = iter()
-      if finished(iter): break
-      count += 1
-      yield item
-
-    doAssert nextPowerOfTwo(0) == 1,
-      "Usefully, empty lists will be padded to one empty block"
-
-    for _ in count..<nextPowerOfTwo(count):
-      yield emptyChunk
-
-func merkleize(chunker: iterator(): Chunk): Chunk =
+template padEmptyChunks(chunks: int) =
+  for i in chunks..<nextPowerOfTwo(chunks):
+    yield emptyChunk
+
+iterator packAndPad(values: seq|array): Chunk =
+  ## Produce a stream of chunks that are packed and padded such that they number
+  ## a power of two
+
+  when sizeof(values[0].toSSZType().toBytesSSZ()) == sizeof(Chunk):
+    # When chunks and value lengths coincide, do the simple thing
+    for v in values:
+      yield v.toSSZType().toBytesSSZ()
+    padEmptyChunks(values.len)
+  else:
+    var
+      chunks: int
+      tmp: Chunk
+      tmpPos: int # how many bytes of tmp we've filled with ssz values
+    for v in values:
+      var
+        vssz = toBytesSSZ(v.toSSZType)
+        vPos = 0 # how many bytes of vssz that we've consumed
+
+      while vPos < vssz.len:
+        # there are still bytes of vssz left to consume - looping happens when
+        # vssz.len > sizeof(Chunk)
+
+        let left = min(tmp.len - tmpPos, vssz.len - vPos)
+        copyMem(addr tmp[tmpPos], addr vssz[vPos], left)
+        vPos += left
+        tmpPos += left
+
+        if tmpPos == tmp.len:
+          # When vssz.len < sizeof(Chunk), multiple values will fit in a chunk
+          yield tmp
+          tmpPos = 0
+          chunks += 1
+
+    if tmpPos > 0:
+      # If vssz.len is not a multiple of Chunk, we might need to pad the last
+      # chunk with zeroes and return it
+      for i in tmpPos..<tmp.len:
+        tmp[i] = 0'u8
+      yield tmp
+      tmpPos = 0
+      chunks += 1
+
+    padEmptyChunks(chunks)
+
+iterator hash_tree_collection(value: array|seq): Chunk =
+  mixin hash_tree_root
+  var chunks = 0
+  for v in value:
+    yield hash_tree_root(v).data
+    chunks += 1
+  padEmptyChunks(chunks)
+
+iterator hash_tree_fields(value: object): Chunk =
+  mixin hash_tree_root
+  var chunks = 0
+  for v in value.fields:
+    yield hash_tree_root(v).data
+    chunks += 1
+  padEmptyChunks(chunks)
+
+template merkleize(chunker: untyped): Chunk =
   var
-    stack: seq[tuple[height: int, chunk: Chunk]]
-    paddedChunker = pad(chunker)
-
-  while true:
-    let chunk = paddedChunker()
-    if finished(paddedChunker): break
-
+    # a depth of 32 here should give us capability to handle 2^32 chunks,
+    # more than enough
+    # TODO replace with SmallVector-like thing..
+    stack: array[32, tuple[height: int, chunk: Chunk]]
+    stackPos = 0
+
+  for chunk in chunker:
     # Leaves start at height 0 - every time they move up, height is increased
     # allowing us to detect two chunks at the same height ready for
     # consolidation
     # See also: http://szydlo.com/logspacetime03.pdf
-    stack.add (0, chunk)
+    stack[stackPos] = (0, chunk)
+    inc stackPos
 
     # Consolidate items of the same height - this keeps stack size at log N
-    while stack.len > 1 and stack[^1].height == stack[^2].height:
+    while stackPos > 1 and stack[stackPos - 1].height == stack[stackPos - 2].height:
       # As tradition dictates - one feature, at least one nim bug:
       # https://github.com/nim-lang/Nim/issues/9684
-      let tmp = hash(stack[^2].chunk, stack[^1].chunk)
-      stack[^2].height += 1
-      stack[^2].chunk = tmp
-      discard stack.pop
-
-  doAssert stack.len == 1,
+      let tmp = hash(stack[stackPos - 2].chunk, stack[stackPos - 1].chunk)
+      stack[stackPos - 2].height += 1
+      stack[stackPos - 2].chunk = tmp
+      stackPos -= 1
+
+  doAssert stackPos == 1,
     "With power-of-two leaves, we should end up with a single root"
   stack[0].chunk
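
The rewritten merkleize keeps at most one pending chunk per tree level, so a fixed array of 32 (height, chunk) slots covers up to 2^32 leaves and no seq needs to grow on the heap. A self-contained toy trace of the same consolidation scheme over ints, with combine() standing in for hash(); toyMerkleize is illustrative, not repository code:

import std/math

# Toy stand-in for hash(a, b) so the consolidation logic can be traced.
proc combine(a, b: int): int = a * 31 + b

proc toyMerkleize(leaves: openArray[int]): int =
  # One pending entry per tree height suffices: height 0 holds at most one
  # unpaired leaf, height 1 at most one unpaired two-leaf node, and so on.
  var
    stack: array[32, tuple[height: int, value: int]]
    stackPos = 0
  doAssert leaves.len == nextPowerOfTwo(leaves.len), "pad the leaves first"
  for leaf in leaves:
    stack[stackPos] = (0, leaf)
    inc stackPos
    # Whenever the two topmost entries sit at the same height, merge them one
    # level up - this is what bounds the stack at log2(N) live entries.
    while stackPos > 1 and stack[stackPos - 1].height == stack[stackPos - 2].height:
      stack[stackPos - 2] = (stack[stackPos - 2].height + 1,
        combine(stack[stackPos - 2].value, stack[stackPos - 1].value))
      stackPos -= 1
  doAssert stackPos == 1
  stack[0].value

when isMainModule:
  echo toyMerkleize([1, 2, 3, 4])   # combine(combine(1, 2), combine(3, 4))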
@@ -373,40 +401,35 @@ func hash_tree_root*[T](value: T): Eth2Digest =
     # Merkle tree
     Eth2Digest(data:
       when T is BasicType:
-        merkleize(pack([value]))
+        merkleize(packAndPad([value]))
       elif T is array|seq:
         when T.elementType() is BasicType:
-          mix_in_length(merkleize(pack(value)), len(value))
+          mix_in_length(merkleize(packAndPad(value)), len(value))
         else:
-          var roots = iterator(): Chunk =
-            for v in value:
-              yield hash_tree_root(v).data
-          mix_in_length(merkleize(roots), len(value))
+          mix_in_length(merkleize(hash_tree_collection(value)), len(value))
       elif T is object:
-        var roots = iterator(): Chunk =
-          for v in value.fields:
-            yield hash_tree_root(v).data
-        merkleize(roots)
+        merkleize(hash_tree_fields(value))
       else:
         static: doAssert false, "Unexpected type: " & T.name
     )
 
+iterator hash_tree_most(v: object): Chunk =
+  var found_field_name = false
+  for name, field in v.fieldPairs:
+    # TODO we should truncate the last field, regardless of its name.. this
+    # hack works for now - how to skip the last fieldPair though??
+    if name == "signature":
+      found_field_name = true
+      break
+    yield hash_tree_root(field).data
+
+  doAssert found_field_name
+
 # https://github.com/ethereum/eth2.0-specs/blob/0.4.0/specs/simple-serialize.md#signed-roots
 func signed_root*[T: object](x: T): Eth2Digest =
   # TODO write tests for this (check vs hash_tree_root)
-  var found_field_name = false
-  var roots = iterator(): Chunk =
-    for name, field in x.fieldPairs:
-      # TODO we should truncate the last field, regardless of its name.. this
-      # hack works for now - how to skip the last fieldPair though??
-      if name == "signature":
-        found_field_name = true
-        break
-      yield hash_tree_root(field).data
-
-  let root = merkleize(roots)
-  doAssert found_field_name
+  let root = merkleize(hash_tree_most(x))
 
   Eth2Digest(data: root)
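
signed_root hashes every field up to, but excluding, the trailing signature field; hash_tree_most just factors out that fieldPairs walk so merkleize can consume it as a chunk stream. As a reading aid, a hypothetical type and a small field walk show which fields contribute; DemoSigned and fieldsBeforeSignature are illustrative only, and the real code breaks out of the loop as shown above:

type
  DemoSigned = object
    slot: uint64
    parent_root: array[32, byte]
    signature: array[96, byte]   # excluded from signed_root

proc fieldsBeforeSignature(v: object): seq[string] =
  # Collect field names until the "signature" field is reached, mirroring the
  # truncation rule used by signed_root/hash_tree_most above.
  var hitSignature = false
  for name, field in v.fieldPairs:
    if name == "signature":
      hitSignature = true
    elif not hitSignature:
      result.add name

when isMainModule:
  echo fieldsBeforeSignature(DemoSigned())   # @["slot", "parent_root"]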


@@ -79,12 +79,17 @@ suite "Simple serialization":
     SSZ.roundripTest BeaconState(slot: 42.Slot)
 
 suite "Tree hashing":
-  # TODO Nothing but smoke tests for now..
+  # TODO The test values are taken from an earlier version of SSZ and have
+  # nothing to do with upstream - needs verification and proper test suite
 
   test "Hash BeaconBlock":
     let vr = BeaconBlock()
-    check: hash_tree_root(vr) != Eth2Digest()
+    check:
+      $hash_tree_root(vr) ==
+        "1BD5D8577A7806CC524C367808C53AE2480F35A3C4BB11A90D6E1AC304E27201"
 
   test "Hash BeaconState":
-    let vr = BeaconBlock()
-    check: hash_tree_root(vr) != Eth2Digest()
+    let vr = BeaconState()
+    check:
+      $hash_tree_root(vr) ==
+        "DC751EF09987283D52483C75690234DDD75FFDAF1A844CD56FE1173465B5597A"