state_transition part of get_attesting_indices/get_attesting_indices_cached perf tweaks

This commit is contained in:
Dustin Brody 2019-03-28 14:27:50 -07:00 committed by zah
parent d74fa2027a
commit 6cbf355c5a
1 changed files with 16 additions and 22 deletions

View File

@ -35,10 +35,6 @@ import
./extras, ./ssz, ./extras, ./ssz,
./spec/[beaconstate, bitfield, crypto, datatypes, digest, helpers, validator] ./spec/[beaconstate, bitfield, crypto, datatypes, digest, helpers, validator]
func flatten[T](v: openArray[seq[T]]): seq[T] =
# TODO not in nim - doh.
for x in v: result.add x
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.1/specs/core/0_beacon-chain.md#block-header # https://github.com/ethereum/eth2.0-specs/blob/v0.5.1/specs/core/0_beacon-chain.md#block-header
proc processBlockHeader( proc processBlockHeader(
state: var BeaconState, blck: BeaconBlock, flags: UpdateFlags): bool = state: var BeaconState, blck: BeaconBlock, flags: UpdateFlags): bool =
@ -529,24 +525,23 @@ func get_attesting_indices(
state: BeaconState, state: BeaconState,
attestations: openArray[PendingAttestation]): HashSet[ValidatorIndex] = attestations: openArray[PendingAttestation]): HashSet[ValidatorIndex] =
# Union of attesters that participated in some attestations # Union of attesters that participated in some attestations
attestations. result = initSet[ValidatorIndex]()
mapIt( for attestation in attestations:
get_attestation_participants(state, it.data, it.aggregation_bitfield)). for validator_index in get_attestation_participants(
flatten(). state, attestation.data, attestation.aggregation_bitfield):
toSet() result.incl validator_index
# sorted(system.cmp) unnecessary
func get_attesting_indices_cached( func get_attesting_indices_cached(
state: BeaconState, state: BeaconState,
attestations: openArray[PendingAttestation], attestations: openArray[PendingAttestation],
crosslink_committee_cache: var auto): seq[ValidatorIndex] = crosslink_committee_cache: var auto): HashSet[ValidatorIndex] =
# Union of attesters that participated in some attestations # Union of attesters that participated in some attestations
attestations. result = initSet[ValidatorIndex]()
mapIt( for attestation in attestations:
get_attestation_participants_cached(state, it.data, it.aggregation_bitfield, crosslink_committee_cache)). for validator_index in get_attestation_participants_cached(
flatten(). state, attestation.data, attestation.aggregation_bitfield,
deduplicate(). crosslink_committee_cache):
sorted(system.cmp) result.incl validator_index
func get_attesting_balance(state: BeaconState, func get_attesting_balance(state: BeaconState,
attestations: seq[PendingAttestation]): Gwei = attestations: seq[PendingAttestation]): Gwei =
@ -589,7 +584,7 @@ func lowerThan(candidate, current: Eth2Digest): bool =
func get_winning_root_and_participants( func get_winning_root_and_participants(
state: BeaconState, shard: Shard, crosslink_committees_cache: var auto): state: BeaconState, shard: Shard, crosslink_committees_cache: var auto):
tuple[a: Eth2Digest, b: seq[ValidatorIndex]] = tuple[a: Eth2Digest, b: HashSet[ValidatorIndex]] =
let let
all_attestations = all_attestations =
concat(state.current_epoch_attestations, concat(state.current_epoch_attestations,
@ -602,7 +597,7 @@ func get_winning_root_and_participants(
# handle when no attestations for shard available # handle when no attestations for shard available
if len(all_roots) == 0: if len(all_roots) == 0:
return (ZERO_HASH, @[]) return (ZERO_HASH, initSet[ValidatorIndex]())
# 0.5.1 spec has less-than-ideal get_attestations_for nested function. # 0.5.1 spec has less-than-ideal get_attestations_for nested function.
var attestations_for = initTable[Eth2Digest, seq[PendingAttestation]]() var attestations_for = initTable[Eth2Digest, seq[PendingAttestation]]()
@ -967,11 +962,10 @@ func get_crosslink_deltas(
state, shard, crosslink_committees_cache) state, shard, crosslink_committees_cache)
else: else:
(ZERO_HASH, winning_root_participants_cache[shard]) (ZERO_HASH, winning_root_participants_cache[shard])
nonquadraticParticipants = toSet(participants)
participating_balance = get_total_balance(state, participants) participating_balance = get_total_balance(state, participants)
total_balance = get_total_balance(state, crosslink_committee) total_balance = get_total_balance(state, crosslink_committee)
for index in crosslink_committee: for index in crosslink_committee:
if index in nonquadraticParticipants: if index in participants:
deltas[0][index] += deltas[0][index] +=
get_base_reward(state, index) * participating_balance div get_base_reward(state, index) * participating_balance div
total_balance total_balance
@ -1105,7 +1099,7 @@ func processEpoch(state: var BeaconState) =
crosslink_committee_cache = crosslink_committee_cache =
initTable[tuple[a: uint64, b: bool], seq[CrosslinkCommittee]]() initTable[tuple[a: uint64, b: bool], seq[CrosslinkCommittee]]()
winning_root_participants_cache = winning_root_participants_cache =
initTable[Shard, seq[ValidatorIndex]]() initTable[Shard, HashSet[ValidatorIndex]]()
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#crosslinks # https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#crosslinks
process_crosslinks( process_crosslinks(
state, crosslink_committee_cache, winning_root_participants_cache) state, crosslink_committee_cache, winning_root_participants_cache)