From 605dd0a0e9969633acd169881fc1f408a4089643 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Wed, 3 Apr 2019 09:46:22 -0600 Subject: [PATCH] Some speedups (#226) * ssz: avoid memory allocations * a bit fishy with the 32-item stack.. this should be a smallvector * digest: avoid another burnmem * avoid a few allocations by using iterator --- beacon_chain/spec/beaconstate.nim | 39 ++++---- beacon_chain/spec/digest.nim | 6 +- beacon_chain/spec/validator.nim | 9 +- beacon_chain/ssz.nim | 161 +++++++++++++++++------------- tests/test_ssz.nim | 13 ++- 5 files changed, 128 insertions(+), 100 deletions(-) diff --git a/beacon_chain/spec/beaconstate.nim b/beacon_chain/spec/beaconstate.nim index ca9ea9b3c..fcf150df1 100644 --- a/beacon_chain/spec/beaconstate.nim +++ b/beacon_chain/spec/beaconstate.nim @@ -380,10 +380,10 @@ func get_attestation_participants*(state: BeaconState, if aggregation_bit: result.add(validator_index) -func get_attestation_participants_cached*(state: BeaconState, +iterator get_attestation_participants_cached*(state: BeaconState, attestation_data: AttestationData, bitfield: BitField, - crosslink_committees_cached: var auto): seq[ValidatorIndex] = + crosslink_committees_cached: var auto): ValidatorIndex = ## Return the participant indices at for the ``attestation_data`` and ## ``bitfield``. ## Attestation participants in the attestation data are called out in a @@ -397,26 +397,25 @@ func get_attestation_participants_cached*(state: BeaconState, # TODO iterator candidate # Find the committee in the list with the desired shard - let crosslink_committees = get_crosslink_committees_at_slot_cached( - state, attestation_data.slot, false, crosslink_committees_cached) + # let crosslink_committees = get_crosslink_committees_at_slot_cached( + # state, attestation_data.slot, false, crosslink_committees_cached) - doAssert anyIt( - crosslink_committees, - it[1] == attestation_data.shard) - let crosslink_committee = mapIt( - filterIt(crosslink_committees, it.shard == attestation_data.shard), - it.committee)[0] + var found = false + for crosslink_committee in get_crosslink_committees_at_slot_cached( + state, attestation_data.slot, false, crosslink_committees_cached): + if crosslink_committee.shard == attestation_data.shard: + # TODO this and other attestation-based fields need validation so we don't + # crash on a malicious attestation! + doAssert verify_bitfield(bitfield, len(crosslink_committee.committee)) - # TODO this and other attestation-based fields need validation so we don't - # crash on a malicious attestation! - doAssert verify_bitfield(bitfield, len(crosslink_committee)) - - # Find the participating attesters in the committee - result = @[] - for i, validator_index in crosslink_committee: - let aggregation_bit = get_bitfield_bit(bitfield, i) - if aggregation_bit: - result.add(validator_index) + # Find the participating attesters in the committee + for i, validator_index in crosslink_committee.committee: + let aggregation_bit = get_bitfield_bit(bitfield, i) + if aggregation_bit: + yield validator_index + found = true + break + doAssert found, "Couldn't find crosslink committee" # https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#ejections func process_ejections*(state: var BeaconState) = diff --git a/beacon_chain/spec/digest.nim b/beacon_chain/spec/digest.nim index 9bd098eaf..c42eb0f19 100644 --- a/beacon_chain/spec/digest.nim +++ b/beacon_chain/spec/digest.nim @@ -36,7 +36,7 @@ func shortLog*(x: Eth2Digest): string = result = ($x)[0..7] func eth2hash*(v: openArray[byte]): Eth2Digest = - var ctx: Eth2Hash + var ctx: keccak256 # use explicit type so we can rely on init being useless # We can avoid this step for Keccak/SHA3 digests because `ctx` is already # empty, but if digest will be changed next line must be enabled. # ctx.init() @@ -47,8 +47,8 @@ template withEth2Hash*(body: untyped): Eth2Digest = ## This little helper will init the hash function and return the sliced ## hash: ## let hashOfData = withHash: h.update(data) - var h {.inject.}: Eth2Hash - h.init() + var h {.inject.}: keccak256 + # TODO no need, as long as using keccak256: h.init() body var res = h.finish() res diff --git a/beacon_chain/spec/validator.nim b/beacon_chain/spec/validator.nim index 3e9a0dc47..bfbf3f62f 100644 --- a/beacon_chain/spec/validator.nim +++ b/beacon_chain/spec/validator.nim @@ -229,16 +229,17 @@ func get_crosslink_committees_at_slot*(state: BeaconState, slot: Slot|uint64, (slot_start_shard + i.uint64) mod SHARD_COUNT ) -func get_crosslink_committees_at_slot_cached*( +iterator get_crosslink_committees_at_slot_cached*( state: BeaconState, slot: Slot|uint64, registry_change: bool = false, cache: var auto): - seq[CrosslinkCommittee] = + CrosslinkCommittee = let key = (slot.uint64, registry_change) if key in cache: - return cache[key] + for v in cache[key]: yield v #debugEcho "get_crosslink_committees_at_slot_cached: MISS" - result = get_crosslink_committees_at_slot(state, slot, registry_change) + let result = get_crosslink_committees_at_slot(state, slot, registry_change) cache[key] = result + for v in result: yield v # https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#get_beacon_proposer_index func get_beacon_proposer_index*(state: BeaconState, slot: Slot): ValidatorIndex = diff --git a/beacon_chain/ssz.nim b/beacon_chain/ssz.nim index c77675bdc..082674ade 100644 --- a/beacon_chain/ssz.nim +++ b/beacon_chain/ssz.nim @@ -297,71 +297,99 @@ func mix_in_length(root: Chunk, length: int): Chunk = hash(root, dataLen) -proc pack(values: seq|array): iterator(): Chunk = - result = iterator (): Chunk = - # TODO should be trivial to avoid this seq also.. - # TODO I get a feeling a copy of the array is taken to the closure, which - # also needs fixing - # TODO avoid closure iterators that involve GC - var tmp = - newSeqOfCap[byte](values.len() * sizeof(toBytesSSZ(values[0].toSSZType()))) +template padEmptyChunks(chunks: int) = + for i in chunks.. sizeof(Chunk) - while true: - let item = iter() - if finished(iter): break - count += 1 - yield item + let left = min(tmp.len - tmpPos, vssz.len - vPos) + copyMem(addr tmp[tmpPos], addr vssz[vPos], left) + vPos += left + tmpPos += left - doAssert nextPowerOfTwo(0) == 1, - "Usefully, empty lists will be padded to one empty block" + if tmpPos == tmp.len: + # When vssz.len < sizeof(Chunk), multiple values will fit in a chunk + yield tmp + tmpPos = 0 + chunks += 1 - for _ in count.. 0: + # If vssz.len is not a multiple of Chunk, we might need to pad the last + # chunk with zeroes and return it + for i in tmpPos.. 1 and stack[^1].height == stack[^2].height: + while stackPos > 1 and stack[stackPos - 1].height == stack[stackPos - 2].height: # As tradition dictates - one feature, at least one nim bug: # https://github.com/nim-lang/Nim/issues/9684 - let tmp = hash(stack[^2].chunk, stack[^1].chunk) - stack[^2].height += 1 - stack[^2].chunk = tmp - discard stack.pop + let tmp = hash(stack[stackPos - 2].chunk, stack[stackPos - 1].chunk) + stack[stackPos - 2].height += 1 + stack[stackPos - 2].chunk = tmp + stackPos -= 1 - doAssert stack.len == 1, + doAssert stackPos == 1, "With power-of-two leaves, we should end up with a single root" stack[0].chunk @@ -373,40 +401,35 @@ func hash_tree_root*[T](value: T): Eth2Digest = # Merkle tree Eth2Digest(data: when T is BasicType: - merkleize(pack([value])) + merkleize(packAndPad([value])) elif T is array|seq: when T.elementType() is BasicType: - mix_in_length(merkleize(pack(value)), len(value)) + mix_in_length(merkleize(packAndPad(value)), len(value)) else: - var roots = iterator(): Chunk = - for v in value: - yield hash_tree_root(v).data - mix_in_length(merkleize(roots), len(value)) + mix_in_length(merkleize(hash_tree_collection(value)), len(value)) elif T is object: - var roots = iterator(): Chunk = - for v in value.fields: - yield hash_tree_root(v).data - merkleize(roots) + merkleize(hash_tree_fields(value)) else: static: doAssert false, "Unexpected type: " & T.name ) +iterator hash_tree_most(v: object): Chunk = + var found_field_name = false + + for name, field in v.fieldPairs: + # TODO we should truncate the last field, regardless of its name.. this + # hack works for now - how to skip the last fieldPair though?? + if name == "signature": + found_field_name = true + break + yield hash_tree_root(field).data + + doAssert found_field_name + # https://github.com/ethereum/eth2.0-specs/blob/0.4.0/specs/simple-serialize.md#signed-roots func signed_root*[T: object](x: T): Eth2Digest = # TODO write tests for this (check vs hash_tree_root) - var found_field_name = false - var roots = iterator(): Chunk = - for name, field in x.fieldPairs: - # TODO we should truncate the last field, regardless of its name.. this - # hack works for now - how to skip the last fieldPair though?? - if name == "signature": - found_field_name = true - break - yield hash_tree_root(field).data - - let root = merkleize(roots) - - doAssert found_field_name + let root = merkleize(hash_tree_most(x)) Eth2Digest(data: root) diff --git a/tests/test_ssz.nim b/tests/test_ssz.nim index 6cc2e4c74..5a76c1690 100644 --- a/tests/test_ssz.nim +++ b/tests/test_ssz.nim @@ -79,12 +79,17 @@ suite "Simple serialization": SSZ.roundripTest BeaconState(slot: 42.Slot) suite "Tree hashing": - # TODO Nothing but smoke tests for now.. + # TODO The test values are taken from an earlier version of SSZ and have + # nothing to do with upstream - needs verification and proper test suite test "Hash BeaconBlock": let vr = BeaconBlock() - check: hash_tree_root(vr) != Eth2Digest() + check: + $hash_tree_root(vr) == + "1BD5D8577A7806CC524C367808C53AE2480F35A3C4BB11A90D6E1AC304E27201" test "Hash BeaconState": - let vr = BeaconBlock() - check: hash_tree_root(vr) != Eth2Digest() + let vr = BeaconState() + check: + $hash_tree_root(vr) == + "DC751EF09987283D52483C75690234DDD75FFDAF1A844CD56FE1173465B5597A"