4x speedup on epoch processing for 64 validators

This commit is contained in:
Dustin Brody 2019-03-27 14:10:15 -07:00 committed by zah
parent 3ad6053298
commit ced4dbe859
3 changed files with 106 additions and 26 deletions

View File

@ -8,7 +8,8 @@
import
chronicles, math, options, sequtils,
../extras, ../ssz,
./bitfield, ./crypto, ./datatypes, ./digest, ./helpers, ./validator
./bitfield, ./crypto, ./datatypes, ./digest, ./helpers, ./validator,
tables
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#get_effective_balance
func get_effective_balance*(state: BeaconState, index: ValidatorIndex): Gwei =
@ -379,6 +380,44 @@ func get_attestation_participants*(state: BeaconState,
if aggregation_bit:
result.add(validator_index)
func get_attestation_participants_cached*(state: BeaconState,
attestation_data: AttestationData,
bitfield: BitField,
crosslink_committees_cached: var auto): seq[ValidatorIndex] =
## Return the participant indices at for the ``attestation_data`` and
## ``bitfield``.
## Attestation participants in the attestation data are called out in a
## bit field that corresponds to the committee of the shard at the time;
## this function converts it to list of indices in to BeaconState.validators
##
## Returns empty list if the shard is not found
## Return the participant indices at for the ``attestation_data`` and ``bitfield``.
##
# TODO Linear search through shard list? borderline ok, it's a small list
# TODO iterator candidate
# Find the committee in the list with the desired shard
let crosslink_committees = get_crosslink_committees_at_slot_cached(
state, attestation_data.slot, false, crosslink_committees_cached)
doAssert anyIt(
crosslink_committees,
it[1] == attestation_data.shard)
let crosslink_committee = mapIt(
filterIt(crosslink_committees, it.shard == attestation_data.shard),
it.committee)[0]
# TODO this and other attestation-based fields need validation so we don't
# crash on a malicious attestation!
doAssert verify_bitfield(bitfield, len(crosslink_committee))
# Find the participating attesters in the committee
result = @[]
for i, validator_index in crosslink_committee:
let aggregation_bit = get_bitfield_bit(bitfield, i)
if aggregation_bit:
result.add(validator_index)
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#ejections
func process_ejections*(state: var BeaconState) =
## Iterate through the validator registry and eject active validators with

View File

@ -229,6 +229,17 @@ func get_crosslink_committees_at_slot*(state: BeaconState, slot: Slot|uint64,
(slot_start_shard + i.uint64) mod SHARD_COUNT
)
func get_crosslink_committees_at_slot_cached*(
state: BeaconState, slot: Slot|uint64,
registry_change: bool = false, cache: var auto):
seq[CrosslinkCommittee] =
let key = (slot.uint64, registry_change)
if key in cache:
return cache[key]
debugEcho "get_crosslink_committees_at_slot_cached: MISS"
result = get_crosslink_committees_at_slot(state, slot, registry_change)
cache[key] = result
# https://github.com/ethereum/eth2.0-specs/blob/v0.5.0/specs/core/0_beacon-chain.md#get_beacon_proposer_index
func get_beacon_proposer_index*(state: BeaconState, slot: Slot): ValidatorIndex =
## From Casper RPJ mini-spec:

View File

@ -536,10 +536,38 @@ func get_attesting_indices(
deduplicate().
sorted(system.cmp)
func get_attesting_indices_cached(
state: BeaconState,
attestations: openArray[PendingAttestation],
crosslink_committee_cache: var auto): seq[ValidatorIndex] =
# Union of attesters that participated in some attestations
attestations.
mapIt(
get_attestation_participants_cached(state, it.data, it.aggregation_bitfield, crosslink_committee_cache)).
flatten().
deduplicate().
sorted(system.cmp)
func get_attesting_balance(state: BeaconState,
attestations: seq[PendingAttestation]): Gwei =
get_total_balance(state, get_attesting_indices(state, attestations))
func get_attesting_balance_cached(
state: BeaconState,
attestations: seq[PendingAttestation],
cache: var auto,
crosslink_committees_cache: var auto): Gwei =
return get_total_balance(state, get_attesting_indices_cached(state, attestations, crosslink_committees_cache))
#hash_tree_root's overhead's too big; could do custom hashing
#when false:
# let key = hash_tree_root(attestations)
# # caches only held within epoch processing stage, or shorter; ignore state
# if key in cache:
# return cache[key]
#
# result = get_total_balance(state, get_attesting_indices_cached(state, attestations, crosslink_committees_cache))
# cache[key] = result
func get_current_epoch_boundary_attestations(state: BeaconState):
seq[PendingAttestation] =
filterIt(
@ -568,7 +596,7 @@ func lowerThan(candidate, current: Eth2Digest): bool =
if v > candidate.data[i]: return true
false
func get_winning_root_and_participants(state: BeaconState, shard: Shard):
func get_winning_root_and_participants(state: BeaconState, shard: Shard, cache: var auto, crosslink_committees_cache: var auto):
tuple[a: Eth2Digest, b: seq[ValidatorIndex]] =
let
all_attestations =
@ -584,10 +612,15 @@ func get_winning_root_and_participants(state: BeaconState, shard: Shard):
if len(all_roots) == 0:
return (ZERO_HASH, @[])
func get_attestations_for(root: Eth2Digest): seq[PendingAttestation] =
filterIt(
valid_attestations,
it.data.crosslink_data_root == root)
# 0.5.1 spec has less-than-ideal get_attestations_for nested function.
var attestations_for = initTable[Eth2Digest, seq[PendingAttestation]]()
for valid_attestation in valid_attestations:
if valid_attestation.data.crosslink_data_root in attestations_for:
attestations_for[valid_attestation.data.crosslink_data_root].add(
valid_attestation)
else:
attestations_for[valid_attestation.data.crosslink_data_root] =
@[valid_attestation]
## Winning crosslink root is the root with the most votes for it, ties broken
## in favor of lexicographically higher hash
@ -597,7 +630,7 @@ func get_winning_root_and_participants(state: BeaconState, shard: Shard):
for r in all_roots:
let root_balance =
get_attesting_balance(state, get_attestations_for(r))
get_attesting_balance_cached(state, attestations_for.getOrDefault(r), cache, crosslink_committees_cache)
if (root_balance > winning_root_balance or
(root_balance == winning_root_balance and
lowerThan(winning_root, r))):
@ -605,20 +638,14 @@ func get_winning_root_and_participants(state: BeaconState, shard: Shard):
winning_root_balance = root_balance
(winning_root,
get_attesting_indices(state, get_attestations_for(winning_root)))
get_attesting_indices(state, attestations_for.getOrDefault(winning_root)))
# Combination of earliest_attestation and inclusion_slot avoiding O(n^2)
# TODO merge/refactor these two functions, which differ only very slightly.
func inclusion_slots(state: BeaconState): auto =
result = initTable[ValidatorIndex, Slot]()
let previous_epoch_attestations =
state.previous_epoch_attestations.filterIt(
get_previous_epoch(state) == slot_to_epoch(it.data.slot))
## TODO switch previous_epoch_attestations to state.foo,
## when implemented finish_epoch_update
for a in sorted(previous_epoch_attestations,
for a in sorted(state.previous_epoch_attestations,
func (x, y: PendingAttestation): auto =
system.cmp(x.inclusion_slot, y.inclusion_slot)):
for v in get_attestation_participants(
@ -630,13 +657,7 @@ func inclusion_slots(state: BeaconState): auto =
func inclusion_distances(state: BeaconState): auto =
result = initTable[ValidatorIndex, Slot]()
let previous_epoch_attestations =
state.previous_epoch_attestations.filterIt(
get_previous_epoch(state) == slot_to_epoch(it.data.slot))
## TODO switch previous_epoch_attestations to state.foo,
## when implemented finish_epoch_update
for a in sorted(previous_epoch_attestations,
for a in sorted(state.previous_epoch_attestations,
func (x, y: PendingAttestation): auto =
system.cmp(x.inclusion_slot, y.inclusion_slot)):
for v in get_attestation_participants(
@ -719,6 +740,10 @@ func process_crosslinks(state: var BeaconState) =
current_epoch = get_current_epoch(state)
previous_epoch = current_epoch - 1
next_epoch = current_epoch + 1
var
attester_balance_cache = initTable[Eth2Digest, uint64]()
crosslink_committee_cache = initTable[tuple[a: uint64, b: bool], seq[CrosslinkCommittee]]()
## TODO is it actually correct to be setting state.latest_crosslinks[shard]
## to something pre-GENESIS_EPOCH, ever? I guess the intent is if there are
## a quorum of participants for get_epoch_start_slot(previous_epoch), when
@ -728,11 +753,12 @@ func process_crosslinks(state: var BeaconState) =
for slot in max(
GENESIS_SLOT.uint64, get_epoch_start_slot(previous_epoch).uint64) ..<
get_epoch_start_slot(next_epoch).uint64:
for cas in get_crosslink_committees_at_slot(state, slot):
for cas in get_crosslink_committees_at_slot_cached(state, slot, false, crosslink_committee_cache):
let
(crosslink_committee, shard) = cas
(winning_root, participants) =
get_winning_root_and_participants(state, shard)
get_winning_root_and_participants(
state, shard, attester_balance_cache, crosslink_committee_cache)
participating_balance = get_total_balance(state, participants)
total_balance = get_total_balance(state, crosslink_committee)
@ -912,6 +938,9 @@ func get_crosslink_deltas(state: BeaconState):
repeat(0'u64, len(state.validator_registry)),
repeat(0'u64, len(state.validator_registry))
)
var
attester_balance_cache = initTable[Eth2Digest, uint64]()
crosslink_committee_cache = initTable[tuple[a: uint64, b: bool], seq[CrosslinkCommittee]]()
let
previous_epoch_start_slot =
get_epoch_start_slot(get_previous_epoch(state))
@ -919,11 +948,12 @@ func get_crosslink_deltas(state: BeaconState):
get_epoch_start_slot(get_current_epoch(state))
for slot in previous_epoch_start_slot.uint64 ..<
current_epoch_start_slot.uint64:
for cas in get_crosslink_committees_at_slot(state, slot):
for cas in get_crosslink_committees_at_slot_cached(state, slot, false, crosslink_committee_cache):
let
(crosslink_committee, shard) = cas
(winning_root, participants) =
get_winning_root_and_participants(state, shard)
get_winning_root_and_participants(
state, shard, attester_balance_cache, crosslink_committee_cache)
participating_balance = get_total_balance(state, participants)
total_balance = get_total_balance(state, crosslink_committee)
for index in crosslink_committee: