Consider fork transitions when using `get_sync_aggregate` helper func

When calling `get_sync_aggregate` with a `signature_slot - 1` in a
future fork, i.e., the first slot of the new fork is missed, it uses
a wrong fork version for the signature. Fix that by correctly applying
fork transitions to the `signature_state`, if a fork schedule is given.
This commit is contained in:
Etan Kissling 2024-01-09 16:03:38 +01:00
parent e52594634c
commit 240a127f9f
No known key found for this signature in database
GPG Key ID: B21DA824C5A3D03D
4 changed files with 76 additions and 32 deletions

View File

@ -16,16 +16,15 @@ from eth2spec.test.helpers.attestations import (
state_transition_with_full_block, state_transition_with_full_block,
) )
from eth2spec.test.helpers.constants import ( from eth2spec.test.helpers.constants import (
PHASE0, ALTAIR, BELLATRIX, CAPELLA, DENEB, ALTAIR, BELLATRIX, CAPELLA, DENEB,
MINIMAL, MINIMAL,
ALL_PHASES,
) )
from eth2spec.test.helpers.fork_transition import ( from eth2spec.test.helpers.fork_transition import (
do_fork, do_fork,
) )
from eth2spec.test.helpers.forks import ( from eth2spec.test.helpers.forks import (
get_spec_for_fork_version,
is_post_capella, is_post_deneb, is_post_capella, is_post_deneb,
is_post_fork,
) )
from eth2spec.test.helpers.light_client import ( from eth2spec.test.helpers.light_client import (
get_sync_aggregate, get_sync_aggregate,
@ -36,19 +35,6 @@ from eth2spec.test.helpers.state import (
) )
def get_spec_for_fork_version(spec, fork_version, phases):
if phases is None:
return spec
for fork in [fork for fork in ALL_PHASES if is_post_fork(spec.fork, fork)]:
if fork == PHASE0:
fork_version_field = 'GENESIS_FORK_VERSION'
else:
fork_version_field = fork.upper() + '_FORK_VERSION'
if fork_version == getattr(spec.config, fork_version_field):
return phases[fork]
raise ValueError("Unknown fork version %s" % fork_version)
def needs_upgrade_to_capella(d_spec, s_spec): def needs_upgrade_to_capella(d_spec, s_spec):
return is_post_capella(s_spec) and not is_post_capella(d_spec) return is_post_capella(s_spec) and not is_post_capella(d_spec)
@ -628,7 +614,7 @@ def run_test_single_fork(spec, phases, state, fork):
finalized_state = state.copy() finalized_state = state.copy()
attested_block = state_transition_with_full_block(spec, state, True, True) attested_block = state_transition_with_full_block(spec, state, True, True)
attested_state = state.copy() attested_state = state.copy()
sync_aggregate, _ = get_sync_aggregate(spec, state) sync_aggregate, _ = get_sync_aggregate(spec, state, phases=phases)
block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate) block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate)
yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases) yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases)
assert test.store.finalized_header.beacon.slot == finalized_state.slot assert test.store.finalized_header.beacon.slot == finalized_state.slot
@ -641,7 +627,7 @@ def run_test_single_fork(spec, phases, state, fork):
transition_to(spec, state, spec.compute_start_slot_at_epoch(fork_epoch) - 4) transition_to(spec, state, spec.compute_start_slot_at_epoch(fork_epoch) - 4)
attested_block = state_transition_with_full_block(spec, state, True, True) attested_block = state_transition_with_full_block(spec, state, True, True)
attested_state = state.copy() attested_state = state.copy()
sync_aggregate, _ = get_sync_aggregate(spec, state) sync_aggregate, _ = get_sync_aggregate(spec, state, phases=phases)
block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate) block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate)
update = yield from emit_update( update = yield from emit_update(
test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases) test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases)
@ -657,7 +643,7 @@ def run_test_single_fork(spec, phases, state, fork):
# Final slot before fork, check that importing the pre-fork format still works # Final slot before fork, check that importing the pre-fork format still works
attested_block = block.copy() attested_block = block.copy()
attested_state = state.copy() attested_state = state.copy()
sync_aggregate, _ = get_sync_aggregate(spec, state) sync_aggregate, _ = get_sync_aggregate(spec, state, phases=phases)
block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate) block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate)
yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases) yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases)
assert test.store.finalized_header.beacon.slot == finalized_state.slot assert test.store.finalized_header.beacon.slot == finalized_state.slot
@ -668,7 +654,7 @@ def run_test_single_fork(spec, phases, state, fork):
# Upgrade to post-fork spec, attested block is still before the fork # Upgrade to post-fork spec, attested block is still before the fork
attested_block = block.copy() attested_block = block.copy()
attested_state = state.copy() attested_state = state.copy()
sync_aggregate, _ = get_sync_aggregate(phases[fork], state) sync_aggregate, _ = get_sync_aggregate(phases[fork], state, phases=phases)
state, block = do_fork(state, spec, phases[fork], fork_epoch, sync_aggregate=sync_aggregate) state, block = do_fork(state, spec, phases[fork], fork_epoch, sync_aggregate=sync_aggregate)
spec = phases[fork] spec = phases[fork]
yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases) yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases)
@ -680,7 +666,7 @@ def run_test_single_fork(spec, phases, state, fork):
# Another block after the fork, this time attested block is after the fork # Another block after the fork, this time attested block is after the fork
attested_block = block.copy() attested_block = block.copy()
attested_state = state.copy() attested_state = state.copy()
sync_aggregate, _ = get_sync_aggregate(spec, state) sync_aggregate, _ = get_sync_aggregate(spec, state, phases=phases)
block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate) block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate)
yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases) yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases)
assert test.store.finalized_header.beacon.slot == finalized_state.slot assert test.store.finalized_header.beacon.slot == finalized_state.slot
@ -692,7 +678,7 @@ def run_test_single_fork(spec, phases, state, fork):
transition_to(spec, state, spec.compute_start_slot_at_epoch(fork_epoch + 1) - 2) transition_to(spec, state, spec.compute_start_slot_at_epoch(fork_epoch + 1) - 2)
attested_block = state_transition_with_full_block(spec, state, True, True) attested_block = state_transition_with_full_block(spec, state, True, True)
attested_state = state.copy() attested_state = state.copy()
sync_aggregate, _ = get_sync_aggregate(spec, state) sync_aggregate, _ = get_sync_aggregate(spec, state, phases=phases)
block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate) block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate)
yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases) yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases)
assert test.store.finalized_header.beacon.slot == finalized_state.slot assert test.store.finalized_header.beacon.slot == finalized_state.slot
@ -706,7 +692,7 @@ def run_test_single_fork(spec, phases, state, fork):
_, _, state = next_slots_with_attestations(spec, state, 2 * spec.SLOTS_PER_EPOCH - 1, True, True) _, _, state = next_slots_with_attestations(spec, state, 2 * spec.SLOTS_PER_EPOCH - 1, True, True)
attested_block = state_transition_with_full_block(spec, state, True, True) attested_block = state_transition_with_full_block(spec, state, True, True)
attested_state = state.copy() attested_state = state.copy()
sync_aggregate, _ = get_sync_aggregate(spec, state) sync_aggregate, _ = get_sync_aggregate(spec, state, phases=phases)
block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate) block = state_transition_with_full_block(spec, state, True, True, sync_aggregate=sync_aggregate)
yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases) yield from emit_update(test, spec, state, block, attested_state, attested_block, finalized_block, phases=phases)
assert test.store.finalized_header.beacon.slot == finalized_state.slot assert test.store.finalized_header.beacon.slot == finalized_state.slot

View File

@ -3,7 +3,10 @@ from enum import Enum, auto
from eth2spec.test.helpers.attester_slashings import ( from eth2spec.test.helpers.attester_slashings import (
get_valid_attester_slashing_by_indices, get_valid_attester_slashing_by_indices,
) )
from eth2spec.test.helpers.attestations import next_slots_with_attestations from eth2spec.test.helpers.attestations import (
next_slots_with_attestations,
state_transition_with_full_block,
)
from eth2spec.test.helpers.block import ( from eth2spec.test.helpers.block import (
build_empty_block_for_next_slot, build_empty_block_for_next_slot,
build_empty_block, build_empty_block,
@ -21,6 +24,9 @@ from eth2spec.test.helpers.deposits import (
from eth2spec.test.helpers.proposer_slashings import ( from eth2spec.test.helpers.proposer_slashings import (
get_valid_proposer_slashing, get_valid_proposer_slashing,
) )
from eth2spec.test.helpers.forks import (
get_next_fork_transition,
)
from eth2spec.test.helpers.state import ( from eth2spec.test.helpers.state import (
next_slot, next_slot,
state_transition_and_sign_block, state_transition_and_sign_block,
@ -196,6 +202,34 @@ def _transition_until_fork_minus_one(spec, state, fork_epoch):
transition_to(spec, state, to_slot) transition_to(spec, state, to_slot)
def transition_across_forks(spec, state, to_slot, phases=None, with_block=False, sync_aggregate=None):
assert to_slot > state.slot
state = state.copy()
block = None
to_epoch = spec.compute_epoch_at_slot(to_slot)
while state.slot < to_slot:
assert block is None
epoch = spec.compute_epoch_at_slot(state.slot)
post_spec, fork_epoch = get_next_fork_transition(spec, epoch, phases)
if fork_epoch is None or to_epoch < fork_epoch:
if with_block and (to_slot == state.slot + 1):
transition_to(spec, state, to_slot - 1)
block = state_transition_with_full_block(
spec, state, True, True,
sync_aggregate=sync_aggregate)
else:
transition_to(spec, state, to_slot)
else:
transition_until_fork(spec, state, fork_epoch)
state, block = do_fork(
state, spec, post_spec, fork_epoch,
with_block=with_block and (to_slot == state.slot + 1),
sync_aggregate=sync_aggregate,
)
spec = post_spec
return spec, state, block
def transition_to_next_epoch_and_append_blocks(spec, def transition_to_next_epoch_and_append_blocks(spec,
state, state,
post_tag, post_tag,

View File

@ -1,5 +1,5 @@
from .constants import ( from .constants import (
ALTAIR, BELLATRIX, CAPELLA, DENEB, PHASE0, ALTAIR, BELLATRIX, CAPELLA, DENEB,
EIP6110, EIP7002, WHISK, EIP6110, EIP7002, WHISK,
PREVIOUS_FORK_OF, PREVIOUS_FORK_OF,
) )
@ -47,3 +47,27 @@ def is_post_eip7002(spec):
def is_post_whisk(spec): def is_post_whisk(spec):
return is_post_fork(spec.fork, WHISK) return is_post_fork(spec.fork, WHISK)
def get_spec_for_fork_version(spec, fork_version, phases):
if phases is None:
return spec
for fork in [fork for fork in phases if is_post_fork(spec.fork, fork)]:
if fork == PHASE0:
fork_version_field = 'GENESIS_FORK_VERSION'
else:
fork_version_field = fork.upper() + '_FORK_VERSION'
if fork_version == getattr(spec.config, fork_version_field):
return phases[fork]
raise ValueError("Unknown fork version %s" % fork_version)
def get_next_fork_transition(spec, epoch, phases):
if phases is None:
return None, None
for fork in [fork for fork in phases if PREVIOUS_FORK_OF[fork] == spec.fork]:
assert fork != PHASE0 # PHASE0 does not have previous fork
fork_epoch = getattr(phases[fork].config, fork.upper() + '_FORK_EPOCH')
assert fork_epoch > epoch # Forks through given epoch already applied
return phases[fork], fork_epoch
return None, None # Already at latest fork

View File

@ -1,5 +1,5 @@
from eth2spec.test.helpers.state import ( from eth2spec.test.helpers.fork_transition import (
transition_to, transition_across_forks,
) )
from eth2spec.test.helpers.sync_committee import ( from eth2spec.test.helpers.sync_committee import (
compute_aggregate_sync_committee_signature, compute_aggregate_sync_committee_signature,
@ -8,14 +8,14 @@ from eth2spec.test.helpers.sync_committee import (
from math import floor from math import floor
def get_sync_aggregate(spec, state, num_participants=None, signature_slot=None): def get_sync_aggregate(spec, state, num_participants=None, signature_slot=None, phases=None):
# By default, the sync committee signs the previous slot # By default, the sync committee signs the previous slot
if signature_slot is None: if signature_slot is None:
signature_slot = state.slot + 1 signature_slot = state.slot + 1
assert signature_slot > state.slot
# Ensure correct sync committee and fork version are selected # Ensure correct sync committee and fork version are selected
signature_state = state.copy() signature_spec, signature_state, _ = transition_across_forks(spec, state, signature_slot, phases)
transition_to(spec, signature_state, signature_slot)
# Fetch sync committee # Fetch sync committee
committee_indices = compute_committee_indices(signature_state) committee_indices = compute_committee_indices(signature_state)
@ -29,12 +29,12 @@ def get_sync_aggregate(spec, state, num_participants=None, signature_slot=None):
# Compute sync aggregate # Compute sync aggregate
sync_committee_bits = [True] * num_participants + [False] * (committee_size - num_participants) sync_committee_bits = [True] * num_participants + [False] * (committee_size - num_participants)
sync_committee_signature = compute_aggregate_sync_committee_signature( sync_committee_signature = compute_aggregate_sync_committee_signature(
spec, signature_spec,
signature_state, signature_state,
max(signature_slot, 1) - 1, max(signature_slot, 1) - 1,
committee_indices[:num_participants], committee_indices[:num_participants],
) )
sync_aggregate = spec.SyncAggregate( sync_aggregate = signature_spec.SyncAggregate(
sync_committee_bits=sync_committee_bits, sync_committee_bits=sync_committee_bits,
sync_committee_signature=sync_committee_signature, sync_committee_signature=sync_committee_signature,
) )