reduce memory allocations during state transition (#5235)

This PR removes a few hundred thousand temporary seq allocations during
state transition - in particular, the flag seq was allocated per
validator while committees are computed per attestation.
This commit is contained in:
Jacek Sieka 2023-08-03 01:03:40 +02:00 committed by GitHub
parent 28194468c9
commit 5bc48acc36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 71 additions and 61 deletions

View File

@ -16,7 +16,7 @@ import
"."/[eth2_merkleization, forks, signatures, validator]
from std/algorithm import fill
from std/sequtils import anyIt, mapIt
from std/sequtils import anyIt, mapIt, toSeq
from ./datatypes/capella import BeaconState, ExecutionPayloadHeader, Withdrawal
@ -381,14 +381,15 @@ proc is_valid_indexed_attestation*(
ok()
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/phase0/beacon-chain.md#get_attesting_indices
func get_attesting_indices*(state: ForkyBeaconState,
data: AttestationData,
bits: CommitteeValidatorsBits,
cache: var StateCache): seq[ValidatorIndex] =
iterator get_attesting_indices_iter*(state: ForkyBeaconState,
data: AttestationData,
bits: CommitteeValidatorsBits,
cache: var StateCache): ValidatorIndex =
## Return the set of attesting indices corresponding to ``data`` and ``bits``
## or nothing if `data` is invalid
## This iterator must not be called in functions using a
## ForkedHashedBeaconState due to https://github.com/nim-lang/Nim/issues/18188
var res: seq[ValidatorIndex]
# Can't be an iterator due to https://github.com/nim-lang/Nim/issues/18188
let committee_index = CommitteeIndex.init(data.index)
if committee_index.isErr() or bits.lenu64 != get_beacon_committee_len(
@ -398,9 +399,17 @@ func get_attesting_indices*(state: ForkyBeaconState,
for index_in_committee, validator_index in get_beacon_committee(
state, data.slot, committee_index.get(), cache):
if bits[index_in_committee]:
res.add validator_index
yield validator_index
res
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/phase0/beacon-chain.md#get_attesting_indices
func get_attesting_indices*(state: ForkyBeaconState,
data: AttestationData,
bits: CommitteeValidatorsBits,
cache: var StateCache): seq[ValidatorIndex] =
## Return the set of attesting indices corresponding to ``data`` and ``bits``
## or nothing if `data` is invalid
toSeq(get_attesting_indices_iter(state, data, bits, cache))
func get_attesting_indices*(state: ForkedHashedBeaconState;
data: AttestationData;
@ -432,7 +441,7 @@ proc is_valid_indexed_attestation*(
if not (skipBlsValidation in flags or attestation.signature is TrustedSig):
var
pubkeys = newSeqOfCap[ValidatorPubKey](sigs)
for index in get_attesting_indices(
for index in get_attesting_indices_iter(
state, attestation.data, attestation.aggregation_bits, cache):
pubkeys.add(state.validators[index].pubkey)
@ -496,7 +505,7 @@ func check_attestation_index*(
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/altair/beacon-chain.md#get_attestation_participation_flag_indices
func get_attestation_participation_flag_indices(
state: altair.BeaconState | bellatrix.BeaconState | capella.BeaconState,
data: AttestationData, inclusion_delay: uint64): seq[int] =
data: AttestationData, inclusion_delay: uint64): set[TimelyFlag] =
## Return the flag indices that are satisfied by an attestation.
let justified_checkpoint =
if data.target.epoch == get_current_epoch(state):
@ -517,20 +526,20 @@ func get_attestation_participation_flag_indices(
# Checked by check_attestation()
doAssert is_matching_source
var participation_flag_indices: seq[int]
var participation_flag_indices: set[TimelyFlag]
if is_matching_source and inclusion_delay <= integer_squareroot(SLOTS_PER_EPOCH):
participation_flag_indices.add(TIMELY_SOURCE_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_SOURCE_FLAG_INDEX)
if is_matching_target and inclusion_delay <= SLOTS_PER_EPOCH:
participation_flag_indices.add(TIMELY_TARGET_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_TARGET_FLAG_INDEX)
if is_matching_head and inclusion_delay == MIN_ATTESTATION_INCLUSION_DELAY:
participation_flag_indices.add(TIMELY_HEAD_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_HEAD_FLAG_INDEX)
participation_flag_indices
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/deneb/beacon-chain.md#modified-get_attestation_participation_flag_indices
func get_attestation_participation_flag_indices(
state: deneb.BeaconState,
data: AttestationData, inclusion_delay: uint64): seq[int] =
data: AttestationData, inclusion_delay: uint64): set[TimelyFlag] =
## Return the flag indices that are satisfied by an attestation.
let justified_checkpoint =
if data.target.epoch == get_current_epoch(state):
@ -551,13 +560,13 @@ func get_attestation_participation_flag_indices(
# Checked by check_attestation
doAssert is_matching_source
var participation_flag_indices: seq[int]
var participation_flag_indices: set[TimelyFlag]
if is_matching_source and inclusion_delay <= integer_squareroot(SLOTS_PER_EPOCH):
participation_flag_indices.add(TIMELY_SOURCE_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_SOURCE_FLAG_INDEX)
if is_matching_target: # [Modified in Deneb:EIP7045]
participation_flag_indices.add(TIMELY_TARGET_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_TARGET_FLAG_INDEX)
if is_matching_head and inclusion_delay == MIN_ATTESTATION_INCLUSION_DELAY:
participation_flag_indices.add(TIMELY_HEAD_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_HEAD_FLAG_INDEX)
participation_flag_indices
@ -672,7 +681,7 @@ func get_proposer_reward*(state: ForkyBeaconState,
epoch_participation: var EpochParticipationFlags): uint64 =
let participation_flag_indices = get_attestation_participation_flag_indices(
state, attestation.data, state.slot - attestation.data.slot)
for index in get_attesting_indices(
for index in get_attesting_indices_iter(
state, attestation.data, attestation.aggregation_bits, cache):
let
base_reward = get_base_reward(state, index, base_reward_per_increment)
@ -1115,7 +1124,7 @@ func translate_participation(
get_attestation_participation_flag_indices(state, data, inclusion_delay)
# Apply flags to all attesting validators
for index in get_attesting_indices(
for index in get_attesting_indices_iter(
state, data, attestation.aggregation_bits, cache):
for flag_index in participation_flag_indices:
state.previous_epoch_participation[index] =

View File

@ -26,6 +26,19 @@ export base, sets
from ssz_serialization/proofs import GeneralizedIndex
export proofs.GeneralizedIndex
type
TimelyFlag* {.pure.} = enum
TIMELY_SOURCE_FLAG_INDEX
TIMELY_TARGET_FLAG_INDEX
TIMELY_HEAD_FLAG_INDEX
static:
# Verify that ordinals follow spec values (the spec uses these as shifts for bit flags)
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#participation-flag-indices
doAssert ord(TIMELY_SOURCE_FLAG_INDEX) == 0
doAssert ord(TIMELY_TARGET_FLAG_INDEX) == 1
doAssert ord(TIMELY_HEAD_FLAG_INDEX) == 2
const
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#incentivization-weights
TIMELY_SOURCE_WEIGHT* = 14
@ -35,8 +48,8 @@ const
PROPOSER_WEIGHT* = 8
WEIGHT_DENOMINATOR* = 64
PARTICIPATION_FLAG_WEIGHTS* =
[TIMELY_SOURCE_WEIGHT, TIMELY_TARGET_WEIGHT, TIMELY_HEAD_WEIGHT]
PARTICIPATION_FLAG_WEIGHTS*: array[TimelyFlag, uint64] =
[uint64 TIMELY_SOURCE_WEIGHT, TIMELY_TARGET_WEIGHT, TIMELY_HEAD_WEIGHT]
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/altair/validator.md#misc
TARGET_AGGREGATORS_PER_SYNC_SUBCOMMITTEE* = 16
@ -52,11 +65,6 @@ const
CURRENT_SYNC_COMMITTEE_INDEX* = 54.GeneralizedIndex # `current_sync_committee`
NEXT_SYNC_COMMITTEE_INDEX* = 55.GeneralizedIndex # `next_sync_committee`
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#participation-flag-indices
TIMELY_SOURCE_FLAG_INDEX* = 0
TIMELY_TARGET_FLAG_INDEX* = 1
TIMELY_HEAD_FLAG_INDEX* = 2
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#inactivity-penalties
INACTIVITY_SCORE_BIAS* = 4
INACTIVITY_SCORE_RECOVERY_RATE* = 16
@ -310,7 +318,7 @@ type
next_sync_committee*: SyncCommittee # [New in Altair]
UnslashedParticipatingBalances* = object
previous_epoch*: array[PARTICIPATION_FLAG_WEIGHTS.len, Gwei]
previous_epoch*: array[TimelyFlag, Gwei]
current_epoch_TIMELY_TARGET*: Gwei
current_epoch*: Gwei # aka total_active_balance

View File

@ -200,13 +200,13 @@ func get_seed*(state: ForkyBeaconState, epoch: Epoch, domain_type: DomainType):
state.get_seed(epoch, domain_type, mix)
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#add_flag
func add_flag*(flags: ParticipationFlags, flag_index: int): ParticipationFlags =
let flag = ParticipationFlags(1'u8 shl flag_index)
func add_flag*(flags: ParticipationFlags, flag_index: TimelyFlag): ParticipationFlags =
let flag = ParticipationFlags(1'u8 shl ord(flag_index))
flags or flag
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#has_flag
func has_flag*(flags: ParticipationFlags, flag_index: int): bool =
let flag = ParticipationFlags(1'u8 shl flag_index)
func has_flag*(flags: ParticipationFlags, flag_index: TimelyFlag): bool =
let flag = ParticipationFlags(1'u8 shl ord(flag_index))
(flags and flag) == flag
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/altair/light-client/sync-protocol.md#is_sync_committee_update

View File

@ -113,7 +113,7 @@ func process_attestation(
flags.incl RewardFlags.isPreviousEpochHeadAttester
# Update the cache for all participants
for validator_index in get_attesting_indices(
for validator_index in get_attesting_indices_iter(
state, a.data, a.aggregation_bits, cache):
template v(): untyped = info.validators[validator_index]
@ -205,7 +205,7 @@ func get_unslashed_participating_balances*(
state.previous_epoch_participation[validator_index]
if is_active_previous_epoch:
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
if has_flag(previous_epoch_participation, flag_index):
res.previous_epoch[flag_index] += validator_effective_balance
@ -216,7 +216,7 @@ func get_unslashed_participating_balances*(
TIMELY_TARGET_FLAG_INDEX):
res.current_epoch_TIMELY_TARGET += validator_effective_balance
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
res.previous_epoch[flag_index] =
max(EFFECTIVE_BALANCE_INCREMENT, res.previous_epoch[flag_index])
@ -230,7 +230,7 @@ func get_unslashed_participating_balances*(
func is_unslashed_participating_index(
state: altair.BeaconState | bellatrix.BeaconState | capella.BeaconState |
deneb.BeaconState,
flag_index: int, epoch: Epoch, validator_index: ValidatorIndex): bool =
flag_index: TimelyFlag, epoch: Epoch, validator_index: ValidatorIndex): bool =
doAssert epoch in [get_previous_epoch(state), get_current_epoch(state)]
# TODO hoist this conditional
let epoch_participation =
@ -658,7 +658,7 @@ func get_flag_index_reward*(
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/altair/beacon-chain.md#get_flag_index_deltas
func get_unslashed_participating_increment*(
info: altair.EpochInfo | bellatrix.BeaconState, flag_index: int): Gwei =
info: altair.EpochInfo | bellatrix.BeaconState, flag_index: TimelyFlag): Gwei =
info.balances.previous_epoch[flag_index] div EFFECTIVE_BALANCE_INCREMENT
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#get_flag_index_deltas
@ -670,14 +670,14 @@ func get_active_increments*(
iterator get_flag_index_deltas*(
state: altair.BeaconState | bellatrix.BeaconState | capella.BeaconState |
deneb.BeaconState,
flag_index: int, base_reward_per_increment: Gwei,
flag_index: TimelyFlag, base_reward_per_increment: Gwei,
info: var altair.EpochInfo, finality_delay: uint64):
(ValidatorIndex, RewardDelta) =
## Return the deltas for a given ``flag_index`` by scanning through the
## participation flags.
let
previous_epoch = get_previous_epoch(state)
weight = PARTICIPATION_FLAG_WEIGHTS[flag_index].uint64 # safe
weight = PARTICIPATION_FLAG_WEIGHTS[flag_index]
unslashed_participating_increments = get_unslashed_participating_increment(
info, flag_index)
active_increments = get_active_increments(info)
@ -695,7 +695,6 @@ iterator get_flag_index_deltas*(
of TIMELY_SOURCE_FLAG_INDEX: ParticipationFlag.timelySourceAttester
of TIMELY_TARGET_FLAG_INDEX: ParticipationFlag.timelyTargetAttester
of TIMELY_HEAD_FLAG_INDEX: ParticipationFlag.timelyHeadAttester
else: raiseAssert "Unknown flag index " & $flag_index
info.validators[vidx].flags.incl pflag
@ -796,7 +795,7 @@ func process_rewards_and_penalties*(
finality_delay = get_finality_delay(state)
doAssert state.validators.len() == info.validators.len()
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state, flag_index, base_reward_per_increment, info, finality_delay):
info.validators[validator_index].delta.add(delta)

View File

@ -288,7 +288,7 @@ proc collectEpochRewardsAndPenalties*(
total_active_balance)
finality_delay = get_finality_delay(state)
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state, flag_index, base_reward_per_increment, info, finality_delay):
template rp: untyped = rewardsAndPenalties[validator_index]
@ -302,7 +302,7 @@ proc collectEpochRewardsAndPenalties*(
max_flag_index_reward = get_flag_index_reward(
state, base_reward, active_increments,
unslashed_participating_increment,
PARTICIPATION_FLAG_WEIGHTS[flag_index].uint64,
PARTICIPATION_FLAG_WEIGHTS[flag_index],
finality_delay)
case flag_index
@ -315,8 +315,6 @@ proc collectEpochRewardsAndPenalties*(
of TIMELY_HEAD_FLAG_INDEX:
rp.head_outcome = delta.getOutcome
rp.max_head_reward = max_flag_index_reward
else:
raiseAssert(&"Unknown flag index {flag_index}.")
for validator_index, penalty in get_inactivity_penalty_deltas(
cfg, state, info):

View File

@ -1,5 +1,5 @@
# beacon_chain
# Copyright (c) 2020-2022 Status Research & Development GmbH
# Copyright (c) 2020-2023 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
@ -49,9 +49,8 @@ proc runTest(rewardsDir, identifier: string) =
total_balance = info.balances.current_epoch
base_reward_per_increment = get_base_reward_per_increment(total_balance)
static: doAssert PARTICIPATION_FLAG_WEIGHTS.len == 3
var
flagDeltas2 = [
flagDeltas2: array[TimelyFlag, Deltas] = [
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len)]
@ -59,7 +58,7 @@ proc runTest(rewardsDir, identifier: string) =
let finality_delay = get_finality_delay(state[])
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state[], flag_index, base_reward_per_increment, info, finality_delay):
if not is_eligible_validator(info.validators[validator_index]):

View File

@ -1,5 +1,5 @@
# beacon_chain
# Copyright (c) 2020-2022 Status Research & Development GmbH
# Copyright (c) 2020-2023 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
@ -49,9 +49,8 @@ proc runTest(rewardsDir, identifier: string) =
total_balance = info.balances.current_epoch
base_reward_per_increment = get_base_reward_per_increment(total_balance)
static: doAssert PARTICIPATION_FLAG_WEIGHTS.len == 3
var
flagDeltas2 = [
flagDeltas2: array[TimelyFlag, Deltas] = [
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len)]
@ -59,7 +58,7 @@ proc runTest(rewardsDir, identifier: string) =
let finality_delay = get_finality_delay(state[])
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state[], flag_index, base_reward_per_increment, info, finality_delay):
if not is_eligible_validator(info.validators[validator_index]):

View File

@ -1,5 +1,5 @@
# beacon_chain
# Copyright (c) 2020-2022 Status Research & Development GmbH
# Copyright (c) 2020-2023 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
@ -49,9 +49,8 @@ proc runTest(rewardsDir, identifier: string) =
total_balance = info.balances.current_epoch
base_reward_per_increment = get_base_reward_per_increment(total_balance)
static: doAssert PARTICIPATION_FLAG_WEIGHTS.len == 3
var
flagDeltas2 = [
flagDeltas2: array[TimelyFlag, Deltas] = [
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len)]
@ -59,7 +58,7 @@ proc runTest(rewardsDir, identifier: string) =
let finality_delay = get_finality_delay(state[])
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state[], flag_index, base_reward_per_increment, info, finality_delay):
if not is_eligible_validator(info.validators[validator_index]):

View File

@ -49,9 +49,8 @@ proc runTest(rewardsDir, identifier: string) =
total_balance = info.balances.current_epoch
base_reward_per_increment = get_base_reward_per_increment(total_balance)
static: doAssert PARTICIPATION_FLAG_WEIGHTS.len == 3
var
flagDeltas2 = [
flagDeltas2: array[TimelyFlag, Deltas] = [
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len)]
@ -59,7 +58,7 @@ proc runTest(rewardsDir, identifier: string) =
let finality_delay = get_finality_delay(state[])
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state[], flag_index, base_reward_per_increment, info, finality_delay):
if not is_eligible_validator(info.validators[validator_index]):