allow withXxx to access fork-specific fields (#2943)

So far, `withState` and `withBlck` templates could only be used to have
convenience access to fork-agnostic BeaconState and BeaconBlock fields.
This patch:
- injects an additional `stateFork` constant that allows to use
  `when` expressions to also access Altair and Merge-specific fields.
- introduces a `withStateAndBlck` template to support operating on both
  a `BeaconState` and `BeaconBlock` at a time.
- makes sync committee related functions Merge aware.
- changes a couple if-else trees for forks into case statements so that
  forgotten future forks are promoted to compile-time errors.
This commit is contained in:
Etan Kissling 2021-10-06 19:05:06 +02:00 committed by GitHub
parent a017e4a817
commit 9ee134324b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 128 additions and 118 deletions

View File

@ -1049,21 +1049,22 @@ func syncCommitteeParticipants*(dagParam: ChainDAGRef,
dag = dagParam
slot = slotParam
if dag.headState.data.beaconStateFork == forkAltair:
let
headSlot = dag.headState.data.hbsAltair.data.slot
headCommitteePeriod = syncCommitteePeriod(headSlot)
periodStart = syncCommitteePeriodStartSlot(headCommitteePeriod)
nextPeriodStart = periodStart + SLOTS_PER_SYNC_COMMITTEE_PERIOD
withState(dag.headState.data):
when stateFork >= forkAltair:
let
headSlot = state.data.slot
headCommitteePeriod = syncCommitteePeriod(headSlot)
periodStart = syncCommitteePeriodStartSlot(headCommitteePeriod)
nextPeriodStart = periodStart + SLOTS_PER_SYNC_COMMITTEE_PERIOD
if slot >= nextPeriodStart:
@(dag.headState.data.hbsAltair.data.next_sync_committee.pubkeys.data)
elif slot >= periodStart:
@(dag.headState.data.hbsAltair.data.current_sync_committee.pubkeys.data)
if slot >= nextPeriodStart:
@(state.data.next_sync_committee.pubkeys.data)
elif slot >= periodStart:
@(state.data.current_sync_committee.pubkeys.data)
else:
@[]
else:
@[]
else:
@[]
func getSubcommitteePositionsAux(
dag: ChainDAGRef,
@ -1086,24 +1087,25 @@ func getSubcommitteePositions*(dag: ChainDAGRef,
slot: Slot,
committeeIdx: SyncCommitteeIndex,
validatorIdx: uint64): seq[uint64] =
if dag.headState.data.beaconStateFork == forkPhase0:
return @[]
withState(dag.headState.data):
when stateFork >= forkAltair:
let
headSlot = state.data.slot
headCommitteePeriod = syncCommitteePeriod(headSlot)
periodStart = syncCommitteePeriodStartSlot(headCommitteePeriod)
nextPeriodStart = periodStart + SLOTS_PER_SYNC_COMMITTEE_PERIOD
let
headSlot = dag.headState.data.hbsAltair.data.slot
headCommitteePeriod = syncCommitteePeriod(headSlot)
periodStart = syncCommitteePeriodStartSlot(headCommitteePeriod)
nextPeriodStart = periodStart + SLOTS_PER_SYNC_COMMITTEE_PERIOD
template search(syncCommittee: openarray[ValidatorPubKey]): seq[uint64] =
dag.getSubcommitteePositionsAux(syncCommittee, committeeIdx, validatorIdx)
template search(syncCommittee: openarray[ValidatorPubKey]): seq[uint64] =
dag.getSubcommitteePositionsAux(syncCommittee, committeeIdx, validatorIdx)
if slot < periodStart:
return @[]
elif slot >= nextPeriodStart:
return search(dag.headState.data.hbsAltair.data.next_sync_committee.pubkeys.data)
else:
return search(dag.headState.data.hbsAltair.data.current_sync_committee.pubkeys.data)
if slot < periodStart:
@[]
elif slot >= nextPeriodStart:
search(state.data.next_sync_committee.pubkeys.data)
else:
search(state.data.current_sync_committee.pubkeys.data)
else:
@[]
template syncCommitteeParticipants*(
dag: ChainDAGRef,

View File

@ -1,8 +1,10 @@
# beacon_chain
# Copyright (c) 2018-2021 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
std/[typetraits, sequtils, strutils, sets],
stew/[results, base10],
@ -99,33 +101,21 @@ proc toString*(kind: ValidatorFilterKind): string =
func syncCommitteeParticipants*(forkedState: ForkedHashedBeaconState,
epoch: Epoch): Result[seq[ValidatorPubKey], cstring] =
case forkedState.beaconStateFork
of BeaconStateFork.forkPhase0:
err("State's fork do not support sync committees")
of BeaconStateFork.forkAltair:
let
headSlot = forkedState.hbsAltair.data.slot
epochPeriod = syncCommitteePeriod(epoch.compute_start_slot_at_epoch())
currentPeriod = syncCommitteePeriod(headSlot)
nextPeriod = currentPeriod + 1'u64
if epochPeriod == currentPeriod:
ok(@(forkedState.hbsAltair.data.current_sync_committee.pubkeys.data))
elif epochPeriod == nextPeriod:
ok(@(forkedState.hbsAltair.data.next_sync_committee.pubkeys.data))
withState(forkedState):
when stateFork >= forkAltair:
let
headSlot = state.data.slot
epochPeriod = syncCommitteePeriod(epoch.compute_start_slot_at_epoch())
currentPeriod = syncCommitteePeriod(headSlot)
nextPeriod = currentPeriod + 1'u64
if epochPeriod == currentPeriod:
ok(@(state.data.current_sync_committee.pubkeys.data))
elif epochPeriod == nextPeriod:
ok(@(state.data.next_sync_committee.pubkeys.data))
else:
err("Epoch is outside the sync committee period of the state")
else:
err("Epoch is outside the sync committee period of the state")
of BeaconStateFork.forkMerge:
let
headSlot = forkedState.hbsMerge.data.slot
epochPeriod = syncCommitteePeriod(epoch.compute_start_slot_at_epoch())
currentPeriod = syncCommitteePeriod(headSlot)
nextPeriod = currentPeriod + 1'u64
if epochPeriod == currentPeriod:
ok(@(forkedState.hbsMerge.data.current_sync_committee.pubkeys.data))
elif epochPeriod == nextPeriod:
ok(@(forkedState.hbsMerge.data.next_sync_committee.pubkeys.data))
else:
err("Epoch is outside the sync committee period of the state")
err("State's fork do not support sync committees")
proc installBeaconApiHandlers*(router: var RestRouter, node: BeaconNode) =
# https://ethereum.github.io/beacon-APIs/#/Beacon/getGenesis

View File

@ -1,3 +1,10 @@
# beacon_chain
# Copyright (c) 2021 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import std/sequtils
import chronicles
import ".."/[version, beacon_node_common],
@ -71,13 +78,8 @@ proc installDebugApiHandlers*(router: var RestRouter, node: BeaconNode) =
RestApiResponse.jsonResponsePlain(
ForkedBeaconState.init(stateData.data))
of "application/octet-stream":
case stateData.data.beaconStateFork
of BeaconStateFork.forkPhase0:
RestApiResponse.sszResponse(stateData.data.hbsPhase0.data)
of BeaconStateFork.forkAltair:
RestApiResponse.sszResponse(stateData.data.hbsAltair.data)
of BeaconStateFork.forkMerge:
RestApiResponse.sszResponse(stateData.data.hbsMerge.data)
withState(stateData.data):
RestApiResponse.sszResponse(state.data)
else:
RestApiResponse.jsonError(Http500, InvalidAcceptError)
return RestApiResponse.jsonError(Http404, StateNotFoundError)

View File

@ -137,27 +137,29 @@ template init*(T: type ForkedTrustedSignedBeaconBlock, blck: merge.TrustedSigned
template withState*(x: ForkedHashedBeaconState, body: untyped): untyped =
case x.beaconStateFork
of forkPhase0:
template state: untyped {.inject.} = x.hbsPhase0
of forkMerge:
const stateFork {.inject.} = forkMerge
template state: untyped {.inject.} = x.hbsMerge
body
of forkAltair:
const stateFork {.inject.} = forkAltair
template state: untyped {.inject.} = x.hbsAltair
body
of forkMerge:
template state: untyped {.inject.} = x.hbsMerge
of forkPhase0:
const stateFork {.inject.} = forkPhase0
template state: untyped {.inject.} = x.hbsPhase0
body
# Dispatch functions
func assign*(tgt: var ForkedHashedBeaconState, src: ForkedHashedBeaconState) =
if tgt.beaconStateFork == src.beaconStateFork:
if tgt.beaconStateFork == forkPhase0:
assign(tgt.hbsPhase0, src.hbsPhase0):
elif tgt.beaconStateFork == forkAltair:
assign(tgt.hbsAltair, src.hbsAltair):
elif tgt.beaconStateFork == forkMerge:
case tgt.beaconStateFork
of forkMerge:
assign(tgt.hbsMerge, src.hbsMerge):
else:
doAssert false
of forkAltair:
assign(tgt.hbsAltair, src.hbsAltair):
of forkPhase0:
assign(tgt.hbsPhase0, src.hbsPhase0):
else:
# Ensure case object and discriminator get updated simultaneously, even
# with nimOldCaseObjects. This is infrequent.
@ -170,9 +172,9 @@ template getStateField*(x: ForkedHashedBeaconState, y: untyped): untyped =
# ```
# Without `unsafeAddr`, the `validators` list would be copied to a temporary variable.
(case x.beaconStateFork
of forkPhase0: unsafeAddr x.hbsPhase0.data.y
of forkMerge: unsafeAddr x.hbsMerge.data.y
of forkAltair: unsafeAddr x.hbsAltair.data.y
of forkMerge: unsafeAddr x.hbsMerge.data.y)[]
of forkPhase0: unsafeAddr x.hbsPhase0.data.y)[]
func getStateRoot*(x: ForkedHashedBeaconState): Eth2Digest =
withState(x): state.root
@ -241,19 +243,9 @@ proc get_attesting_indices*(state: ForkedHashedBeaconState;
# iterator
var idxBuf: seq[ValidatorIndex]
if state.beaconStateFork == forkPhase0:
for vidx in state.hbsPhase0.data.get_attesting_indices(data, bits, cache):
withState(state):
for vidx in state.data.get_attesting_indices(data, bits, cache):
idxBuf.add vidx
elif state.beaconStateFork == forkAltair:
for vidx in state.hbsAltair.data.get_attesting_indices(data, bits, cache):
idxBuf.add vidx
elif state.beaconStateFork == forkMerge:
for vidx in state.hbsMerge.data.get_attesting_indices(data, bits, cache):
idxBuf.add vidx
else:
doAssert false
idxBuf
proc check_attester_slashing*(
@ -340,15 +332,21 @@ template asTrusted*(x: merge.SignedBeaconBlock or merge.SigVerifiedBeaconBlock):
template asTrusted*(x: ForkedSignedBeaconBlock): ForkedTrustedSignedBeaconBlock =
isomorphicCast[ForkedTrustedSignedBeaconBlock](x)
template withBlck*(x: ForkedBeaconBlock | ForkedSignedBeaconBlock | ForkedTrustedSignedBeaconBlock, body: untyped): untyped =
template withBlck*(
x: ForkedBeaconBlock | ForkedSignedBeaconBlock |
ForkedTrustedSignedBeaconBlock,
body: untyped): untyped =
case x.kind
of BeaconBlockFork.Phase0:
const stateFork {.inject.} = forkPhase0
template blck: untyped {.inject.} = x.phase0Block
body
of BeaconBlockFork.Altair:
const stateFork {.inject.} = forkAltair
template blck: untyped {.inject.} = x.altairBlock
body
of BeaconBlockFork.Merge:
const stateFork {.inject.} = forkMerge
template blck: untyped {.inject.} = x.mergeBlock
body
@ -387,6 +385,28 @@ chronicles.formatIt ForkedBeaconBlock: it.shortLog
chronicles.formatIt ForkedSignedBeaconBlock: it.shortLog
chronicles.formatIt ForkedTrustedSignedBeaconBlock: it.shortLog
template withStateAndBlck*(
s: ForkedHashedBeaconState,
b: ForkedBeaconBlock | ForkedSignedBeaconBlock |
ForkedTrustedSignedBeaconBlock,
body: untyped): untyped =
case s.beaconStateFork
of forkMerge:
const stateFork {.inject.} = forkMerge
template state: untyped {.inject.} = s.hbsMerge
template blck: untyped {.inject.} = b.mergeBlock
body
of forkAltair:
const stateFork {.inject.} = forkAltair
template state: untyped {.inject.} = s.hbsAltair
template blck: untyped {.inject.} = b.altairBlock
body
of forkPhase0:
const stateFork {.inject.} = forkPhase0
template state: untyped {.inject.} = s.hbsPhase0
template blck: untyped {.inject.} = b.phase0Block
body
proc forkAtEpoch*(cfg: RuntimeConfig, epoch: Epoch): Fork =
case cfg.stateForkAtEpoch(epoch)
of forkMerge: cfg.mergeFork

View File

@ -255,30 +255,30 @@ proc sendSyncCommitteeMessages*(node: BeaconNode,
return statuses.mapIt(it.get())
(resCur, resNxt)
template curParticipants(): untyped =
node.dag.headState.data.hbsAltair.data.current_sync_committee.pubkeys.data
template nxtParticipants(): untyped =
node.dag.headState.data.hbsAltair.data.next_sync_committee.pubkeys.data
let (pending, indices) =
block:
var resFutures: seq[Future[SendResult]]
var resIndices: seq[int]
for committeeIdx in allSyncCommittees():
for valKey in syncSubcommittee(curParticipants(), committeeIdx):
let index = keysCur.getOrDefault(valKey, -1)
if index >= 0:
resIndices.add(index)
resFutures.add(node.sendSyncCommitteeMessage(msgs[index],
committeeIdx, true))
for committeeIdx in allSyncCommittees():
for valKey in syncSubcommittee(nxtParticipants(), committeeIdx):
let index = keysNxt.getOrDefault(valKey, -1)
if index >= 0:
resIndices.add(index)
resFutures.add(node.sendSyncCommitteeMessage(msgs[index],
committeeIdx, true))
(resFutures, resIndices)
withState(node.dag.headState.data):
when stateFork >= forkAltair:
var resFutures: seq[Future[SendResult]]
var resIndices: seq[int]
for committeeIdx in allSyncCommittees():
for valKey in syncSubcommittee(
state.data.current_sync_committee.pubkeys.data, committeeIdx):
let index = keysCur.getOrDefault(valKey, -1)
if index >= 0:
resIndices.add(index)
resFutures.add(node.sendSyncCommitteeMessage(msgs[index],
committeeIdx, true))
for committeeIdx in allSyncCommittees():
for valKey in syncSubcommittee(
state.data.next_sync_committee.pubkeys.data, committeeIdx):
let index = keysNxt.getOrDefault(valKey, -1)
if index >= 0:
resIndices.add(index)
resFutures.add(node.sendSyncCommitteeMessage(msgs[index],
committeeIdx, true))
(resFutures, resIndices)
else:
raiseAssert "Sync committee not available in Phase0"
await allFutures(pending)
@ -444,8 +444,8 @@ proc makeBeaconBlockForHeadAndSlot*(node: BeaconNode,
node.exitPool[].getProposerSlashingsForBlock(),
node.exitPool[].getAttesterSlashingsForBlock(),
node.exitPool[].getVoluntaryExitsForBlock(),
if slot.epoch < node.dag.cfg.ALTAIR_FORK_EPOCH:
SyncAggregate(sync_committee_signature: ValidatorSig.infinity)
if slot.epoch < node.dag.cfg.ALTAIR_FORK_EPOCH:
SyncAggregate.init()
else:
node.sync_committee_msg_pool[].produceSyncAggregate(head.root),
default(merge.ExecutionPayload),

View File

@ -76,12 +76,8 @@ proc getTestStates*(
cfg, tmpState[], slot, cache, rewards, {})
if i mod 3 == 0:
if tmpState[].beaconStateFork == forkPhase0:
valid_deposit(tmpState[].hbsPhase0.data)
elif tmpState[].beaconStateFork == forkAltair:
valid_deposit(tmpState[].hbsAltair.data)
else:
valid_deposit(tmpState[].hbsMerge.data)
withState(tmpState[]):
valid_deposit(state.data)
doAssert getStateField(tmpState[], slot) == slot
if tmpState[].beaconStateFork == stateFork: