60% state replay speedup (#4434)

* 60% state replay speedup

* don't use HashList for epoch participation - in addition to the code
currently clearing the caches several times redundantly, clearing has to
be done each block nullifying the benefit (35%)
* introduce active balance cache - computing it is slow due to cache
unfriendliness in the random access pattern and bounds checking and we
do it for every block - this cache follows the same update pattern as
the active validator index cache (20%)
* avoid recomputing base reward several times per attestation (5%)

Applying 1024 blocks goes from 20s to ~8s on my laptop - these kinds of
requests happen on historical REST queries but also whenever there's a
reorg.

* fix test and diffs
This commit is contained in:
Jacek Sieka 2022-12-19 13:01:49 +01:00 committed by GitHub
parent 064d164a88
commit 7501f10587
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 67 additions and 39 deletions

View File

@ -637,8 +637,13 @@ func get_total_active_balance*(state: ForkyBeaconState, cache: var StateCache):
let epoch = state.get_current_epoch()
get_total_balance(
state, cache.get_shuffled_active_validator_indices(state, epoch))
cache.total_active_balance.withValue(epoch, tab) do:
return tab[]
do:
let tab = get_total_balance(
state, cache.get_shuffled_active_validator_indices(state, epoch))
cache.total_active_balance[epoch] = tab
return tab
# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#get_base_reward_per_increment
func get_base_reward_per_increment_sqrt*(
@ -704,15 +709,15 @@ func get_proposer_reward*(state: ForkyBeaconState,
state, attestation.data, state.slot - attestation.data.slot)
for index in get_attesting_indices(
state, attestation.data, attestation.aggregation_bits, cache):
let
base_reward = get_base_reward(state, index, base_reward_per_increment)
for flag_index, weight in PARTICIPATION_FLAG_WEIGHTS:
if flag_index in participation_flag_indices and
not has_flag(epoch_participation.item(index), flag_index):
epoch_participation[index] =
asList(epoch_participation)[index] =
add_flag(epoch_participation.item(index), flag_index)
# these are all valid; TODO statically verify or do it type-safely
result += get_base_reward(
state, index, base_reward_per_increment) * weight.uint64
epoch_participation.asHashList.clearCache()
result += base_reward * weight.uint64
let proposer_reward_denominator =
(WEIGHT_DENOMINATOR.uint64 - PROPOSER_WEIGHT.uint64) *
@ -860,8 +865,7 @@ func upgrade_to_altair*(cfg: RuntimeConfig, pre: phase0.BeaconState):
empty_participation: EpochParticipationFlags
inactivity_scores = HashList[uint64, Limit VALIDATOR_REGISTRY_LIMIT]()
doAssert empty_participation.data.setLen(pre.validators.len)
empty_participation.asHashList.resetCache()
doAssert empty_participation.asList.setLen(pre.validators.len)
doAssert inactivity_scores.data.setLen(pre.validators.len)
inactivity_scores.resetCache()

View File

@ -78,7 +78,7 @@ type
ParticipationFlags* = uint8
EpochParticipationFlags* =
distinct HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
distinct List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#syncaggregate
SyncAggregate* = object
@ -558,10 +558,8 @@ type
# Represent in full; for the next epoch, current_epoch_participation in
# epoch n is previous_epoch_participation in epoch n+1 but this doesn't
# generalize.
previous_epoch_participation*:
List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
current_epoch_participation*:
List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
previous_epoch_participation*: EpochParticipationFlags
current_epoch_participation*: EpochParticipationFlags
justification_bits*: JustificationBits
previous_justified_checkpoint*: Checkpoint
@ -589,26 +587,44 @@ template `[]`*(arr: array[SYNC_COMMITTEE_SIZE, auto] | seq;
makeLimitedU8(SyncSubcommitteeIndex, SYNC_COMMITTEE_SUBNET_COUNT)
makeLimitedU16(IndexInSyncCommittee, SYNC_COMMITTEE_SIZE)
template asHashList*(epochFlags: EpochParticipationFlags): untyped =
HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT] epochFlags
template asList*(epochFlags: EpochParticipationFlags): untyped =
List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT] epochFlags
template asList*(epochFlags: var EpochParticipationFlags): untyped =
let tmp = cast[ptr List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]](addr epochFlags)
tmp[]
template asSeq*(epochFlags: EpochParticipationFlags): untyped =
seq[ParticipationFlags] asList(epochFlags)
template asSeq*(epochFlags: var EpochParticipationFlags): untyped =
let tmp = cast[ptr seq[ParticipationFlags]](addr epochFlags)
tmp[]
template item*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex): ParticipationFlags =
asHashList(epochFlags).item(idx)
asList(epochFlags)[idx]
template `[]`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex|uint64): ParticipationFlags =
asHashList(epochFlags)[idx]
template `[]`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex|uint64|int): ParticipationFlags =
asList(epochFlags)[idx]
template `[]=`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex, flags: ParticipationFlags) =
asHashList(epochFlags)[idx] = flags
asList(epochFlags)[idx] = flags
template add*(epochFlags: var EpochParticipationFlags, flags: ParticipationFlags): bool =
asHashList(epochFlags).add flags
asList(epochFlags).add flags
template len*(epochFlags: EpochParticipationFlags): int =
asHashList(epochFlags).len
asList(epochFlags).len
template data*(epochFlags: EpochParticipationFlags): untyped =
asHashList(epochFlags).data
template low*(epochFlags: EpochParticipationFlags): int =
asSeq(epochFlags).low
template high*(epochFlags: EpochParticipationFlags): int =
asSeq(epochFlags).high
template assign*(v: var EpochParticipationFlags, src: EpochParticipationFlags) =
# TODO https://github.com/nim-lang/Nim/issues/21123
mixin assign
var tmp = cast[ptr seq[ParticipationFlags]](addr v)
assign(tmp[], distinctBase src)
func shortLog*(v: SomeBeaconBlock): auto =
(

View File

@ -406,6 +406,7 @@ type
# This doesn't know about forks or branches in the DAG. It's for straight,
# linear chunks of the chain.
StateCache* = object
total_active_balance*: Table[Epoch, Gwei]
shuffled_active_validator_indices*: Table[Epoch, seq[ValidatorIndex]]
beacon_proposer_indices*: Table[Slot, Option[ValidatorIndex]]
sync_committees*: Table[SyncCommitteePeriod, SyncCommitteeCache]
@ -923,6 +924,14 @@ func prune*(cache: var StateCache, epoch: Epoch) =
pruneEpoch = epoch - 2
var drops: seq[Slot]
block:
for k in cache.total_active_balance.keys:
if k < pruneEpoch:
drops.add pruneEpoch.start_slot
for drop in drops:
cache.total_active_balance.del drop.epoch
drops.setLen(0)
block:
for k in cache.shuffled_active_validator_indices.keys:
if k < pruneEpoch:
@ -948,6 +957,7 @@ func prune*(cache: var StateCache, epoch: Epoch) =
drops.setLen(0)
func clear*(cache: var StateCache) =
cache.total_active_balance.clear
cache.shuffled_active_validator_indices.clear
cache.beacon_proposer_indices.clear
cache.sync_committees.clear

View File

@ -612,15 +612,12 @@ proc readValue*(reader: var JsonReader[RestJson], value: var Epoch) {.
proc writeValue*(writer: var JsonWriter[RestJson],
epochFlags: EpochParticipationFlags)
{.raises: [IOError, Defect].} =
for e in writer.stepwiseArrayCreation(epochFlags.asHashList):
for e in writer.stepwiseArrayCreation(epochFlags.asList):
writer.writeValue $e
proc readValue*(reader: var JsonReader[RestJson],
epochFlags: var EpochParticipationFlags)
{.raises: [SerializationError, IOError, Defect].} =
# Please note that this function won't compute the cached hash tree roots
# immediately. They will be computed on the first HTR attempt.
for e in reader.readArray(string):
let parsed = try:
parseBiggestUInt(e)
@ -632,7 +629,7 @@ proc readValue*(reader: var JsonReader[RestJson],
reader.raiseUnexpectedValue(
"The usigned integer value should fit in 8 bits")
if not epochFlags.data.add(uint8(parsed)):
if not epochFlags.asList.add(uint8(parsed)):
reader.raiseUnexpectedValue(
"The participation flags list size exceeds limit")

View File

@ -17,7 +17,7 @@ import
./datatypes/base
from ./datatypes/altair import
ParticipationFlags, EpochParticipationFlags, asHashList
ParticipationFlags, EpochParticipationFlags
export codec, base, typetraits, EpochParticipationFlags
@ -28,7 +28,7 @@ template toSszType*(v: BlsCurveType): auto = toRaw(v)
template toSszType*(v: ForkDigest|GraffitiBytes): auto = distinctBase(v)
template toSszType*(v: Version): auto = distinctBase(v)
template toSszType*(v: JustificationBits): auto = distinctBase(v)
template toSszType*(epochFlags: EpochParticipationFlags): auto = asHashList epochFlags
template toSszType*(v: EpochParticipationFlags): auto = asList v
func fromSszBytes*(T: type GraffitiBytes, data: openArray[byte]): T {.raisesssz.} =
if data.len != sizeof(result):
@ -60,4 +60,6 @@ func fromSszBytes*(T: type JustificationBits, bytes: openArray[byte]): T {.raise
copyMem(result.addr, unsafeAddr bytes[0], sizeof(result))
func fromSszBytes*(T: type EpochParticipationFlags, bytes: openArray[byte]): T {.raisesssz.} =
readSszValue(bytes, HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT] result)
# TODO https://github.com/nim-lang/Nim/issues/21123
let tmp = cast[ptr List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]](addr result)
readSszValue(bytes, tmp[])

View File

@ -120,6 +120,7 @@ func process_slot*(
hash_tree_root(state.latest_block_header)
func clear_epoch_from_cache(cache: var StateCache, epoch: Epoch) =
cache.total_active_balance.del epoch
cache.shuffled_active_validator_indices.del epoch
for slot in epoch.slots():

View File

@ -1005,13 +1005,11 @@ func process_participation_flag_updates*(
const zero = 0.ParticipationFlags
for i in 0 ..< state.current_epoch_participation.len:
state.current_epoch_participation.data[i] = zero
asList(state.current_epoch_participation)[i] = zero
# Shouldn't be wasted zeroing, because state.current_epoch_participation only
# grows. New elements are automatically initialized to 0, as required.
doAssert state.current_epoch_participation.data.setLen(state.validators.len)
state.current_epoch_participation.asHashList.resetCache()
doAssert state.current_epoch_participation.asList.setLen(state.validators.len)
# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#sync-committee-updates
func process_sync_committee_updates*(

View File

@ -143,8 +143,8 @@ func diffStates*(state0, state1: bellatrix.BeaconState): BeaconStateDiff =
slashing: state1.slashings[state0.slot.epoch.uint64 mod
EPOCHS_PER_HISTORICAL_VECTOR.uint64],
previous_epoch_participation: state1.previous_epoch_participation.data,
current_epoch_participation: state1.current_epoch_participation.data,
previous_epoch_participation: state1.previous_epoch_participation,
current_epoch_participation: state1.current_epoch_participation,
justification_bits: state1.justification_bits,
previous_justified_checkpoint: state1.previous_justified_checkpoint,
@ -192,9 +192,9 @@ func applyDiff*(
assign(state.slashings.mitem(epochIndex), stateDiff.slashing)
assign(
state.previous_epoch_participation.data, stateDiff.previous_epoch_participation)
state.previous_epoch_participation, stateDiff.previous_epoch_participation)
assign(
state.current_epoch_participation.data, stateDiff.current_epoch_participation)
state.current_epoch_participation, stateDiff.current_epoch_participation)
state.justification_bits = stateDiff.justification_bits
assign(