refactor finalization test helper func

This commit is contained in:
protolambda 2019-06-29 01:22:29 +02:00
parent 384fa8854a
commit 3a60f64b92
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
1 changed files with 27 additions and 17 deletions

View File

@ -1,3 +1,4 @@
import math
from eth2spec.test.context import spec_state_test, with_all_phases from eth2spec.test.context import spec_state_test, with_all_phases
from eth2spec.test.phase_0.epoch_processing.run_epoch_process_base import ( from eth2spec.test.phase_0.epoch_processing.run_epoch_process_base import (
run_epoch_processing_with run_epoch_processing_with
@ -8,41 +9,46 @@ def run_process_just_and_fin(spec, state):
yield from run_epoch_processing_with(spec, state, 'process_justification_and_finalization') yield from run_epoch_processing_with(spec, state, 'process_justification_and_finalization')
def get_committee_size(spec, state, slot): def get_shards_for_slot(spec, state, slot):
epoch = spec.slot_to_epoch(slot) epoch = spec.slot_to_epoch(slot)
epoch_start_shard = spec.get_epoch_start_shard(state, epoch) epoch_start_shard = spec.get_epoch_start_shard(state, epoch)
committees_per_slot = spec.get_epoch_committee_count(state, epoch) // spec.SLOTS_PER_EPOCH committees_per_slot = spec.get_epoch_committee_count(state, epoch) // spec.SLOTS_PER_EPOCH
shard = (epoch_start_shard + committees_per_slot * (slot % spec.SLOTS_PER_EPOCH)) % spec.SHARD_COUNT shard = (epoch_start_shard + committees_per_slot * (slot % spec.SLOTS_PER_EPOCH)) % spec.SHARD_COUNT
committee_index = (shard + spec.SHARD_COUNT - spec.get_epoch_start_shard(state, epoch)) % spec.SHARD_COUNT return [shard + i for i in range(committees_per_slot)]
committee_count = spec.get_epoch_committee_count(state, epoch)
indices = spec.get_active_validator_indices(state, epoch)
def get_committee_size(spec, epoch_start_shard, shard, committee_count, indices):
committee_index = (shard + spec.SHARD_COUNT - epoch_start_shard) % spec.SHARD_COUNT
start = (len(indices) * committee_index) // committee_count start = (len(indices) * committee_index) // committee_count
end = (len(indices) * (committee_index + 1)) // committee_count end = (len(indices) * (committee_index + 1)) // committee_count
size = end - start size = end - start
return size return size,
def add_mock_attestations(spec, state, target_epoch, att_count, att_ratio): def add_mock_attestations(spec, state, epoch, att_count, att_ratio):
# we must be at the end of the epoch # we must be at the end of the epoch
assert (state.slot + 1) % spec.SLOTS_PER_EPOCH == 0 assert (state.slot + 1) % spec.SLOTS_PER_EPOCH == 0
previous_epoch = spec.get_previous_epoch(state) previous_epoch = spec.get_previous_epoch(state)
current_epoch = spec.get_current_epoch(state) current_epoch = spec.get_current_epoch(state)
if current_epoch == target_epoch: if current_epoch == epoch:
attestations = state.current_epoch_attestations attestations = state.current_epoch_attestations
elif previous_epoch == target_epoch: elif previous_epoch == epoch:
attestations = state.previous_epoch_attestations attestations = state.previous_epoch_attestations
else: else:
raise Exception(f"cannot target epoch ${target_epoch} from epoch ${current_epoch}") raise Exception(f"cannot include attestations in epoch ${epoch} from epoch ${current_epoch}")
committee_count = spec.get_epoch_committee_count(state, epoch)
indices = spec.get_active_validator_indices(state, epoch)
epoch_start_shard = spec.get_epoch_start_shard(state, epoch)
total = 0 total = 0
while total < att_count: for i in range(spec.SLOTS_PER_EPOCH):
for i in range(spec.SLOTS_PER_EPOCH): for shard in get_shards_for_slot(spec, state, state.slot + i):
size = get_committee_size(spec, state, state.slot + i) size = get_committee_size(spec, epoch_start_shard, shard, committee_count, indices)
# Create a bitfield filled with the given count per attestation, # Create a bitfield filled with the given count per attestation,
# exactly on the right-most part of the committee field. # exactly on the right-most part of the committee field.
attesting_count = int(size * att_ratio) attesting_count = math.ceil(size * att_ratio)
aggregation_bitfield = ((1 << attesting_count) - 1).to_bytes(length=((size + 7) // 8), byteorder='big') aggregation_bitfield = ((1 << attesting_count) - 1).to_bytes(length=((size + 7) // 8), byteorder='big')
attestations.append(spec.PendingAttestation( attestations.append(spec.PendingAttestation(
@ -54,16 +60,20 @@ def add_mock_attestations(spec, state, target_epoch, att_count, att_ratio):
target_root=b'\xbb' * 32, target_root=b'\xbb' * 32,
crosslink=spec.Crosslink() crosslink=spec.Crosslink()
), ),
inclusion_delay=0, inclusion_delay=1,
)) ))
total += 1 total += 1
if total >= att_count:
return
raise Exception(f"could not fill state with {att_count} attestations for epoch {epoch}")
@with_all_phases @with_all_phases
@spec_state_test @spec_state_test
def test_rule_1(spec, state): def test_rule_1(spec, state):
previous_epoch = spec.get_previous_epoch(state) # previous_epoch = spec.get_previous_epoch(state)
current_epoch = spec.get_current_epoch(state) # current_epoch = spec.get_current_epoch(state)
# TODO # TODO
# add_mock_attestations(spec, state, ...) # add_mock_attestations(spec, state, ...)
@ -71,4 +81,4 @@ def test_rule_1(spec, state):
# set their balances # set their balances
# yield from run_process_just_and_fin(spec, state) # yield from run_process_just_and_fin(spec, state)
# check finalization # check finalization
pass