From ced4dbe859c39068230e8f86c8c570322ef7343d Mon Sep 17 00:00:00 2001 From: Dustin Brody Date: Wed, 27 Mar 2019 14:10:15 -0700 Subject: [PATCH] 4x speedup on epoch processing for 64 validators --- beacon_chain/spec/beaconstate.nim | 41 +++++++++++++++- beacon_chain/spec/validator.nim | 11 +++++ beacon_chain/state_transition.nim | 80 +++++++++++++++++++++---------- 3 files changed, 106 insertions(+), 26 deletions(-) diff --git a/beacon_chain/spec/beaconstate.nim b/beacon_chain/spec/beaconstate.nim index 254f7c1b6..5fad3fa92 100644 --- a/beacon_chain/spec/beaconstate.nim +++ b/beacon_chain/spec/beaconstate.nim @@ -8,7 +8,8 @@ import chronicles, math, options, sequtils, ../extras, ../ssz, - ./bitfield, ./crypto, ./datatypes, ./digest, ./helpers, ./validator + ./bitfield, ./crypto, ./datatypes, ./digest, ./helpers, ./validator, + tables # https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#get_effective_balance func get_effective_balance*(state: BeaconState, index: ValidatorIndex): Gwei = @@ -379,6 +380,44 @@ func get_attestation_participants*(state: BeaconState, if aggregation_bit: result.add(validator_index) +func get_attestation_participants_cached*(state: BeaconState, + attestation_data: AttestationData, + bitfield: BitField, + crosslink_committees_cached: var auto): seq[ValidatorIndex] = + ## Return the participant indices at for the ``attestation_data`` and + ## ``bitfield``. + ## Attestation participants in the attestation data are called out in a + ## bit field that corresponds to the committee of the shard at the time; + ## this function converts it to list of indices in to BeaconState.validators + ## + ## Returns empty list if the shard is not found + ## Return the participant indices at for the ``attestation_data`` and ``bitfield``. + ## + # TODO Linear search through shard list? borderline ok, it's a small list + # TODO iterator candidate + + # Find the committee in the list with the desired shard + let crosslink_committees = get_crosslink_committees_at_slot_cached( + state, attestation_data.slot, false, crosslink_committees_cached) + + doAssert anyIt( + crosslink_committees, + it[1] == attestation_data.shard) + let crosslink_committee = mapIt( + filterIt(crosslink_committees, it.shard == attestation_data.shard), + it.committee)[0] + + # TODO this and other attestation-based fields need validation so we don't + # crash on a malicious attestation! + doAssert verify_bitfield(bitfield, len(crosslink_committee)) + + # Find the participating attesters in the committee + result = @[] + for i, validator_index in crosslink_committee: + let aggregation_bit = get_bitfield_bit(bitfield, i) + if aggregation_bit: + result.add(validator_index) + # https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#ejections func process_ejections*(state: var BeaconState) = ## Iterate through the validator registry and eject active validators with diff --git a/beacon_chain/spec/validator.nim b/beacon_chain/spec/validator.nim index 2550c2c2e..39e6621f5 100644 --- a/beacon_chain/spec/validator.nim +++ b/beacon_chain/spec/validator.nim @@ -229,6 +229,17 @@ func get_crosslink_committees_at_slot*(state: BeaconState, slot: Slot|uint64, (slot_start_shard + i.uint64) mod SHARD_COUNT ) +func get_crosslink_committees_at_slot_cached*( + state: BeaconState, slot: Slot|uint64, + registry_change: bool = false, cache: var auto): + seq[CrosslinkCommittee] = + let key = (slot.uint64, registry_change) + if key in cache: + return cache[key] + debugEcho "get_crosslink_committees_at_slot_cached: MISS" + result = get_crosslink_committees_at_slot(state, slot, registry_change) + cache[key] = result + # https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#get_beacon_proposer_index func get_beacon_proposer_index*(state: BeaconState, slot: Slot): ValidatorIndex = ## From Casper RPJ mini-spec: diff --git a/beacon_chain/state_transition.nim b/beacon_chain/state_transition.nim index d6f17af3c..fadca710d 100644 --- a/beacon_chain/state_transition.nim +++ b/beacon_chain/state_transition.nim @@ -536,10 +536,38 @@ func get_attesting_indices( deduplicate(). sorted(system.cmp) +func get_attesting_indices_cached( + state: BeaconState, + attestations: openArray[PendingAttestation], + crosslink_committee_cache: var auto): seq[ValidatorIndex] = + # Union of attesters that participated in some attestations + attestations. + mapIt( + get_attestation_participants_cached(state, it.data, it.aggregation_bitfield, crosslink_committee_cache)). + flatten(). + deduplicate(). + sorted(system.cmp) + func get_attesting_balance(state: BeaconState, attestations: seq[PendingAttestation]): Gwei = get_total_balance(state, get_attesting_indices(state, attestations)) +func get_attesting_balance_cached( + state: BeaconState, + attestations: seq[PendingAttestation], + cache: var auto, + crosslink_committees_cache: var auto): Gwei = + return get_total_balance(state, get_attesting_indices_cached(state, attestations, crosslink_committees_cache)) + #hash_tree_root's overhead's too big; could do custom hashing + #when false: + # let key = hash_tree_root(attestations) + # # caches only held within epoch processing stage, or shorter; ignore state + # if key in cache: + # return cache[key] + # + # result = get_total_balance(state, get_attesting_indices_cached(state, attestations, crosslink_committees_cache)) + # cache[key] = result + func get_current_epoch_boundary_attestations(state: BeaconState): seq[PendingAttestation] = filterIt( @@ -568,7 +596,7 @@ func lowerThan(candidate, current: Eth2Digest): bool = if v > candidate.data[i]: return true false -func get_winning_root_and_participants(state: BeaconState, shard: Shard): +func get_winning_root_and_participants(state: BeaconState, shard: Shard, cache: var auto, crosslink_committees_cache: var auto): tuple[a: Eth2Digest, b: seq[ValidatorIndex]] = let all_attestations = @@ -584,10 +612,15 @@ func get_winning_root_and_participants(state: BeaconState, shard: Shard): if len(all_roots) == 0: return (ZERO_HASH, @[]) - func get_attestations_for(root: Eth2Digest): seq[PendingAttestation] = - filterIt( - valid_attestations, - it.data.crosslink_data_root == root) + # 0.5.1 spec has less-than-ideal get_attestations_for nested function. + var attestations_for = initTable[Eth2Digest, seq[PendingAttestation]]() + for valid_attestation in valid_attestations: + if valid_attestation.data.crosslink_data_root in attestations_for: + attestations_for[valid_attestation.data.crosslink_data_root].add( + valid_attestation) + else: + attestations_for[valid_attestation.data.crosslink_data_root] = + @[valid_attestation] ## Winning crosslink root is the root with the most votes for it, ties broken ## in favor of lexicographically higher hash @@ -597,7 +630,7 @@ func get_winning_root_and_participants(state: BeaconState, shard: Shard): for r in all_roots: let root_balance = - get_attesting_balance(state, get_attestations_for(r)) + get_attesting_balance_cached(state, attestations_for.getOrDefault(r), cache, crosslink_committees_cache) if (root_balance > winning_root_balance or (root_balance == winning_root_balance and lowerThan(winning_root, r))): @@ -605,20 +638,14 @@ func get_winning_root_and_participants(state: BeaconState, shard: Shard): winning_root_balance = root_balance (winning_root, - get_attesting_indices(state, get_attestations_for(winning_root))) + get_attesting_indices(state, attestations_for.getOrDefault(winning_root))) # Combination of earliest_attestation and inclusion_slot avoiding O(n^2) # TODO merge/refactor these two functions, which differ only very slightly. func inclusion_slots(state: BeaconState): auto = result = initTable[ValidatorIndex, Slot]() - let previous_epoch_attestations = - state.previous_epoch_attestations.filterIt( - get_previous_epoch(state) == slot_to_epoch(it.data.slot)) - - ## TODO switch previous_epoch_attestations to state.foo, - ## when implemented finish_epoch_update - for a in sorted(previous_epoch_attestations, + for a in sorted(state.previous_epoch_attestations, func (x, y: PendingAttestation): auto = system.cmp(x.inclusion_slot, y.inclusion_slot)): for v in get_attestation_participants( @@ -630,13 +657,7 @@ func inclusion_slots(state: BeaconState): auto = func inclusion_distances(state: BeaconState): auto = result = initTable[ValidatorIndex, Slot]() - let previous_epoch_attestations = - state.previous_epoch_attestations.filterIt( - get_previous_epoch(state) == slot_to_epoch(it.data.slot)) - - ## TODO switch previous_epoch_attestations to state.foo, - ## when implemented finish_epoch_update - for a in sorted(previous_epoch_attestations, + for a in sorted(state.previous_epoch_attestations, func (x, y: PendingAttestation): auto = system.cmp(x.inclusion_slot, y.inclusion_slot)): for v in get_attestation_participants( @@ -719,6 +740,10 @@ func process_crosslinks(state: var BeaconState) = current_epoch = get_current_epoch(state) previous_epoch = current_epoch - 1 next_epoch = current_epoch + 1 + + var + attester_balance_cache = initTable[Eth2Digest, uint64]() + crosslink_committee_cache = initTable[tuple[a: uint64, b: bool], seq[CrosslinkCommittee]]() ## TODO is it actually correct to be setting state.latest_crosslinks[shard] ## to something pre-GENESIS_EPOCH, ever? I guess the intent is if there are ## a quorum of participants for get_epoch_start_slot(previous_epoch), when @@ -728,11 +753,12 @@ func process_crosslinks(state: var BeaconState) = for slot in max( GENESIS_SLOT.uint64, get_epoch_start_slot(previous_epoch).uint64) ..< get_epoch_start_slot(next_epoch).uint64: - for cas in get_crosslink_committees_at_slot(state, slot): + for cas in get_crosslink_committees_at_slot_cached(state, slot, false, crosslink_committee_cache): let (crosslink_committee, shard) = cas (winning_root, participants) = - get_winning_root_and_participants(state, shard) + get_winning_root_and_participants( + state, shard, attester_balance_cache, crosslink_committee_cache) participating_balance = get_total_balance(state, participants) total_balance = get_total_balance(state, crosslink_committee) @@ -912,6 +938,9 @@ func get_crosslink_deltas(state: BeaconState): repeat(0'u64, len(state.validator_registry)), repeat(0'u64, len(state.validator_registry)) ) + var + attester_balance_cache = initTable[Eth2Digest, uint64]() + crosslink_committee_cache = initTable[tuple[a: uint64, b: bool], seq[CrosslinkCommittee]]() let previous_epoch_start_slot = get_epoch_start_slot(get_previous_epoch(state)) @@ -919,11 +948,12 @@ func get_crosslink_deltas(state: BeaconState): get_epoch_start_slot(get_current_epoch(state)) for slot in previous_epoch_start_slot.uint64 ..< current_epoch_start_slot.uint64: - for cas in get_crosslink_committees_at_slot(state, slot): + for cas in get_crosslink_committees_at_slot_cached(state, slot, false, crosslink_committee_cache): let (crosslink_committee, shard) = cas (winning_root, participants) = - get_winning_root_and_participants(state, shard) + get_winning_root_and_participants( + state, shard, attester_balance_cache, crosslink_committee_cache) participating_balance = get_total_balance(state, participants) total_balance = get_total_balance(state, crosslink_committee) for index in crosslink_committee: