switch state transition caching to match EpochRef (#1089)

* switch state transition caching usage to shuffled active validator indices to match EpochRef

* refactor the EpochRef -> StateCache transformation; elide pointless mapIt

* limit state passed between get_beacon_committee(...) and compute_committee(...)

* tweaks
This commit is contained in:
tersec 2020-06-01 07:44:50 +00:00 committed by GitHub
parent 0bcdabfcdf
commit a327e8581b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 47 additions and 66 deletions

View File

@ -139,7 +139,8 @@ type
epochsInfo*: seq[EpochRef]
## Could be multiple, since blocks could skip slots, but usually, not many
## Even if competing forks happen later during this epoch, potential empty
## slots beforehand must all be from this fork.
## slots beforehand must all be from this fork. getEpochInfo() is the only
## supported way of accesssing these.
BlockData* = object
## Body and graph in one

View File

@ -156,6 +156,13 @@ func getEpochInfo*(blck: BlockRef, state: BeaconState): EpochRef =
else:
raiseAssert "multiple EpochRefs per epoch per BlockRef invalid"
func getEpochCache*(blck: BlockRef, state: BeaconState): StateCache =
let epochInfo = getEpochInfo(blck, state)
result = get_empty_per_epoch_cache()
result.shuffled_active_validator_indices[
state.slot.compute_epoch_at_slot] =
epochInfo.shuffled_active_validator_indices
func init(T: type BlockRef, root: Eth2Digest, slot: Slot): BlockRef =
BlockRef(
root: root,
@ -447,7 +454,8 @@ proc skipAndUpdateState(
# save and reuse
# TODO possibly we should keep this in memory for the hot blocks
let nextStateRoot = dag.db.getStateRoot(blck.root, state.data.slot + 1)
advance_slot(state, nextStateRoot, dag.updateFlags)
var stateCache = getEpochCache(blck, state.data)
advance_slot(state, nextStateRoot, dag.updateFlags, stateCache)
if save:
dag.putState(state, blck)
@ -464,14 +472,7 @@ proc skipAndUpdateState(
doAssert (addr(statePtr.data) == addr v)
statePtr[] = dag.headState
# TODO it's probably not the right way to convey this, but for now, avoids
# death-by-dozens-of-pointless-changes in developing this
let epochInfo = getEpochInfo(blck.refs, state.data.data)
var stateCache = get_empty_per_epoch_cache()
stateCache.shuffled_active_validator_indices[
state.data.data.slot.compute_epoch_at_slot] =
epochInfo.shuffled_active_validator_indices
var stateCache = getEpochCache(blck.refs, state.data.data)
let ok = state_transition(
state.data, blck.data, stateCache, flags + dag.updateFlags, restore)

View File

@ -9,7 +9,7 @@ import
chronicles, tables,
metrics, stew/results,
../ssz, ../state_transition, ../extras,
../spec/[crypto, datatypes, digest, helpers, validator],
../spec/[crypto, datatypes, digest, helpers],
block_pools_types, candidate_chains
@ -177,17 +177,7 @@ proc add*(
doAssert v.addr == addr poolPtr.tmpState.data
poolPtr.tmpState = poolPtr.headState
# TODO it's probably not the right way to convey this, but for now, avoids
# death-by-dozens-of-pointless-changes in developing this
# TODO rename these, since now, the two "state cache"s are juxtaposed
# directly
let epochInfo = getEpochInfo(parent, dag.tmpState.data.data)
var stateCache = get_empty_per_epoch_cache()
stateCache.shuffled_active_validator_indices[
dag.tmpState.data.data.slot.compute_epoch_at_slot] =
epochInfo.shuffled_active_validator_indices
# End of section to refactor/combine
var stateCache = getEpochCache(parent, dag.tmpState.data.data)
if not state_transition(
dag.tmpState.data, signedBlock, stateCache, dag.updateFlags, restore):
# TODO find a better way to log all this block data

View File

@ -401,10 +401,6 @@ type
# TODO remove some of these, or otherwise coordinate with EpochRef
StateCache* = object
beacon_committee_cache*:
Table[tuple[a: int, b: Eth2Digest], seq[ValidatorIndex]]
active_validator_indices_cache*:
Table[Epoch, seq[ValidatorIndex]]
shuffled_active_validator_indices*:
Table[Epoch, seq[ValidatorIndex]]
committee_count_cache*: Table[Epoch, uint64]

View File

@ -423,14 +423,12 @@ func process_final_updates*(state: var BeaconState) {.nbench.}=
state.current_epoch_attestations = default(type state.current_epoch_attestations)
# https://github.com/ethereum/eth2.0-specs/blob/v0.11.3/specs/phase0/beacon-chain.md#epoch-processing
proc process_epoch*(state: var BeaconState, updateFlags: UpdateFlags)
{.nbench.} =
proc process_epoch*(state: var BeaconState, updateFlags: UpdateFlags,
per_epoch_cache: var StateCache) {.nbench.} =
let currentEpoch = get_current_epoch(state)
trace "process_epoch",
current_epoch = currentEpoch
var per_epoch_cache = get_empty_per_epoch_cache()
# https://github.com/ethereum/eth2.0-specs/blob/v0.11.3/specs/phase0/beacon-chain.md#justification-and-finalization
process_justification_and_finalization(state, per_epoch_cache, updateFlags)
@ -450,7 +448,8 @@ proc process_epoch*(state: var BeaconState, updateFlags: UpdateFlags)
## Caching here for get_beacon_committee(...) can break otherwise, since
## get_active_validator_indices(...) usually changes.
clear(per_epoch_cache.beacon_committee_cache)
per_epoch_cache.shuffled_active_validator_indices[currentEpoch] =
get_shuffled_active_validator_indices(state, currentEpoch)
# https://github.com/ethereum/eth2.0-specs/blob/v0.11.3/specs/phase0/beacon-chain.md#slashings
process_slashings(state)

View File

@ -9,7 +9,7 @@
{.push raises: [Defect].}
import
options, nimcrypto, sequtils, math, tables,
options, sequtils, math, tables,
./datatypes, ./digest, ./helpers
# https://github.com/ethereum/eth2.0-specs/blob/v0.11.3/specs/phase0/beacon-chain.md#compute_shuffled_index
@ -101,18 +101,16 @@ func get_previous_epoch*(state: BeaconState): Epoch =
# https://github.com/ethereum/eth2.0-specs/blob/v0.11.3/specs/phase0/beacon-chain.md#compute_committee
func compute_committee(indices: seq[ValidatorIndex], seed: Eth2Digest,
index: uint64, count: uint64, stateCache: var StateCache): seq[ValidatorIndex] =
index: uint64, count: uint64): seq[ValidatorIndex] =
## Return the committee corresponding to ``indices``, ``seed``, ``index``,
## and committee ``count``.
# indices only used here for its length, or for the shuffled version,
# so unlike spec, pass the shuffled version in directly.
try:
let
start = (len(indices).uint64 * index) div count
endIdx = (len(indices).uint64 * (index + 1)) div count
key = (indices.len, seed)
if key notin stateCache.beacon_committee_cache:
stateCache.beacon_committee_cache[key] =
get_shuffled_seq(seed, len(indices).uint64)
# These assertions from compute_shuffled_index(...)
let index_count = indices.len().uint64
@ -120,9 +118,8 @@ func compute_committee(indices: seq[ValidatorIndex], seed: Eth2Digest,
doAssert index_count <= 2'u64^40
# In spec, this calls get_shuffled_index() every time, but that's wasteful
mapIt(
start.int .. (endIdx.int-1),
indices[stateCache.beacon_committee_cache[key][it]])
# Here, get_beacon_committee() gets the shuffled version.
indices[start.int .. (endIdx.int-1)]
except KeyError:
raiseAssert("Cached entries are added before use")
@ -135,12 +132,14 @@ func get_beacon_committee*(
epoch = compute_epoch_at_slot(slot)
try:
## This is a somewhat more fragile, but high-ROI, caching setup --
## get_active_validator_indices() is slow to run in a loop and only
## changes once per epoch.
if epoch notin cache.active_validator_indices_cache:
cache.active_validator_indices_cache[epoch] =
get_active_validator_indices(state, epoch)
# This is a somewhat more fragile, but high-ROI, caching setup --
# get_active_validator_indices() is slow to run in a loop and only
# changes once per epoch. It is not, in the general case, possible
# to precompute these arbitrarily far out so still need to pick up
# missing cases here.
if epoch notin cache.shuffled_active_validator_indices:
cache.shuffled_active_validator_indices[epoch] =
get_shuffledactive_validator_indices(state, epoch)
# Constant throughout an epoch
if epoch notin cache.committee_count_cache:
@ -148,22 +147,17 @@ func get_beacon_committee*(
get_committee_count_at_slot(state, slot)
compute_committee(
cache.active_validator_indices_cache[epoch],
cache.shuffled_active_validator_indices[epoch],
get_seed(state, epoch, DOMAIN_BEACON_ATTESTER),
(slot mod SLOTS_PER_EPOCH) * cache.committee_count_cache[epoch] +
index.uint64,
cache.committee_count_cache[epoch] * SLOTS_PER_EPOCH,
cache
cache.committee_count_cache[epoch] * SLOTS_PER_EPOCH
)
except KeyError:
raiseAssert "values are added to cache before using them"
# Not from spec
func get_empty_per_epoch_cache*(): StateCache =
result.beacon_committee_cache =
initTable[tuple[a: int, b: Eth2Digest], seq[ValidatorIndex]]()
result.active_validator_indices_cache =
initTable[Epoch, seq[ValidatorIndex]]()
result.shuffled_active_validator_indices =
initTable[Epoch, seq[ValidatorIndex]]()
result.committee_count_cache = initTable[Epoch, uint64]()

View File

@ -125,7 +125,8 @@ func process_slot*(state: var HashedBeaconState) {.nbench.} =
# https://github.com/ethereum/eth2.0-specs/blob/v0.11.3/specs/phase0/beacon-chain.md#beacon-chain-state-transition-function
proc advance_slot*(state: var HashedBeaconState,
nextStateRoot: Opt[Eth2Digest], updateFlags: UpdateFlags) {.nbench.} =
nextStateRoot: Opt[Eth2Digest], updateFlags: UpdateFlags,
epochCache: var StateCache) {.nbench.} =
# Special case version of process_slots that moves one slot at a time - can
# run faster if the state root is known already (for example when replaying
# existing slots)
@ -134,7 +135,7 @@ proc advance_slot*(state: var HashedBeaconState,
if is_epoch_transition:
# Note: Genesis epoch = 0, no need to test if before Genesis
beacon_previous_validators.set(get_epoch_validator_count(state.data))
process_epoch(state.data, updateFlags)
process_epoch(state.data, updateFlags, epochCache)
state.data.slot += 1
if is_epoch_transition:
beacon_current_validators.set(get_epoch_validator_count(state.data))
@ -147,12 +148,6 @@ proc advance_slot*(state: var HashedBeaconState,
# https://github.com/ethereum/eth2.0-specs/blob/v0.11.3/specs/phase0/beacon-chain.md#beacon-chain-state-transition-function
proc process_slots*(state: var HashedBeaconState, slot: Slot,
updateFlags: UpdateFlags = {}): bool {.nbench.} =
# TODO: Eth specs strongly assert that state.data.slot <= slot
# This prevents receiving attestation in any order
# (see tests/test_attestation_pool)
# but it maybe an artifact of the test case
# as this was not triggered in the testnet1
# after a hour
# TODO this function is not _really_ necessary: when replaying states, we
# advance slots one by one before calling `state_transition` - this way,
# we avoid the state root calculation - as such, instead of advancing
@ -168,8 +163,9 @@ proc process_slots*(state: var HashedBeaconState, slot: Slot,
return false
# Catch up to the target slot
var cache = get_empty_per_epoch_cache()
while state.data.slot < slot:
advance_slot(state, err(Opt[Eth2Digest]), updateFlags)
advance_slot(state, err(Opt[Eth2Digest]), updateFlags, cache)
true
@ -212,7 +208,8 @@ proc state_transition*(
# the changes in case of failure (look out for `var BeaconState` and
# bool return values...)
doAssert not rollback.isNil, "use noRollback if it's ok to mess up state"
doAssert stateCache.shuffled_active_validator_indices.hasKey(state.data.slot.compute_epoch_at_slot)
doAssert stateCache.shuffled_active_validator_indices.hasKey(
state.data.slot.compute_epoch_at_slot)
if not process_slots(state, signedBlock.message.slot, flags):
rollback(state)
@ -238,6 +235,7 @@ proc state_transition*(
# TODO when creating a new block, state_root is not yet set.. comparing
# with zero hash here is a bit fragile however, but this whole thing
# should go away with proper hash caching
# TODO shouldn't ever have to recalculate; verifyStateRoot() does it
state.root =
if signedBlock.message.state_root == Eth2Digest(): hash_tree_root(state.data)
else: signedBlock.message.state_root
@ -256,6 +254,8 @@ proc state_transition*(
# and fuzzing code should always be coming from blockpool which should
# always be providing cache or equivalent
var cache = get_empty_per_epoch_cache()
# TODO not here, but in blockpool, should fill in as far ahead towards
# block's slot as protocol allows to be known already
cache.shuffled_active_validator_indices[state.data.slot.compute_epoch_at_slot] =
get_shuffled_active_validator_indices(
state.data, state.data.slot.compute_epoch_at_slot)

View File

@ -94,9 +94,9 @@ proc addTestBlock*(
graffiti = Eth2Digest(),
flags: set[UpdateFlag] = {}): SignedBeaconBlock =
# Create and add a block to state - state will advance by one slot!
advance_slot(state, err(Opt[Eth2Digest]), flags)
var cache = get_empty_per_epoch_cache()
advance_slot(state, err(Opt[Eth2Digest]), flags, cache)
let
proposer_index = get_beacon_proposer_index(state.data, cache)
privKey = hackPrivKey(state.data.validators[proposer_index.get])