implement spectest decorator, update attestation tests

This commit is contained in:
protolambda 2019-05-06 00:31:57 +02:00
parent 4e179fb801
commit 8b24abde31
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
4 changed files with 104 additions and 64 deletions

View File

@ -8,8 +8,7 @@ from eth2spec.phase0.state_transition import (
) )
from eth2spec.phase0.spec import ( from eth2spec.phase0.spec import (
get_current_epoch, get_current_epoch,
process_attestation, process_attestation
slot_to_epoch,
) )
from tests.helpers import ( from tests.helpers import (
build_empty_block_for_next_slot, build_empty_block_for_next_slot,
@ -18,63 +17,75 @@ from tests.helpers import (
next_slot, next_slot,
) )
from tests.utils import spectest
# mark entire file as 'attestations' from tests.context import with_state
pytestmark = pytest.mark.attestations
def run_attestation_processing(state, attestation, valid=True): def run_attestation_processing(state, attestation, valid=True):
""" """
Run ``process_attestation`` returning the pre and post state. Run ``process_attestation``, yielding pre-state ('pre'), attestation ('attestation'), and post-state ('post').
If ``valid == False``, run expecting ``AssertionError`` If ``valid == False``, run expecting ``AssertionError``
""" """
post_state = deepcopy(state) # yield pre-state
yield 'pre', state
yield 'attestation', attestation
# If the attestation is invalid, processing is aborted, and there is no post-state.
if not valid: if not valid:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
process_attestation(post_state, attestation) process_attestation(state, attestation)
return state, None yield 'post', None
return
process_attestation(post_state, attestation) current_epoch_count = len(state.current_epoch_attestations)
previous_epoch_count = len(state.previous_epoch_attestations)
current_epoch = get_current_epoch(state) # process attestation
if attestation.data.target_epoch == current_epoch: process_attestation(state, attestation)
assert len(post_state.current_epoch_attestations) == len(state.current_epoch_attestations) + 1
# Make sure the attestation has been processed
if attestation.data.target_epoch == get_current_epoch(state):
assert len(state.current_epoch_attestations) == current_epoch_count + 1
else: else:
assert len(post_state.previous_epoch_attestations) == len(state.previous_epoch_attestations) + 1 assert len(state.previous_epoch_attestations) == previous_epoch_count + 1
return state, post_state # yield post-state
yield 'post', state
# shorthand for decorating @with_state @spectest()
def attestation_test(fn):
return with_state(spectest()(fn))
@attestation_test
def test_success(state): def test_success(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY
pre_state, post_state = run_attestation_processing(state, attestation) yield from run_attestation_processing(state, attestation)
return pre_state, attestation, post_state
@attestation_test
def test_success_prevous_epoch(state): def test_success_prevous_epoch(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
block = build_empty_block_for_next_slot(state) block = build_empty_block_for_next_slot(state)
block.slot = state.slot + spec.SLOTS_PER_EPOCH block.slot = state.slot + spec.SLOTS_PER_EPOCH
state_transition(state, block) state_transition(state, block)
pre_state, post_state = run_attestation_processing(state, attestation) yield from run_attestation_processing(state, attestation)
return pre_state, attestation, post_state
@attestation_test
def test_before_inclusion_delay(state): def test_before_inclusion_delay(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
# do not increment slot to allow for inclusion delay # do not increment slot to allow for inclusion delay
pre_state, post_state = run_attestation_processing(state, attestation, False) yield from run_attestation_processing(state, attestation, False)
return pre_state, attestation, post_state
@attestation_test
def test_after_epoch_slots(state): def test_after_epoch_slots(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
block = build_empty_block_for_next_slot(state) block = build_empty_block_for_next_slot(state)
@ -82,44 +93,40 @@ def test_after_epoch_slots(state):
block.slot = state.slot + spec.SLOTS_PER_EPOCH + 1 block.slot = state.slot + spec.SLOTS_PER_EPOCH + 1
state_transition(state, block) state_transition(state, block)
pre_state, post_state = run_attestation_processing(state, attestation, False) yield from run_attestation_processing(state, attestation, False)
return pre_state, attestation, post_state
@attestation_test
def test_bad_source_epoch(state): def test_bad_source_epoch(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY
attestation.data.source_epoch += 10 attestation.data.source_epoch += 10
pre_state, post_state = run_attestation_processing(state, attestation, False) yield from run_attestation_processing(state, attestation, False)
return pre_state, attestation, post_state
@attestation_test
def test_bad_source_root(state): def test_bad_source_root(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY
attestation.data.source_root = b'\x42' * 32 attestation.data.source_root = b'\x42' * 32
pre_state, post_state = run_attestation_processing(state, attestation, False) yield from run_attestation_processing(state, attestation, False)
return pre_state, attestation, post_state
@attestation_test
def test_non_zero_crosslink_data_root(state): def test_non_zero_crosslink_data_root(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY
attestation.data.crosslink_data_root = b'\x42' * 32 attestation.data.crosslink_data_root = b'\x42' * 32
pre_state, post_state = run_attestation_processing(state, attestation, False) yield from run_attestation_processing(state, attestation, False)
return pre_state, attestation, post_state
@attestation_test
def test_bad_previous_crosslink(state): def test_bad_previous_crosslink(state):
next_epoch(state) next_epoch(state)
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
@ -128,28 +135,24 @@ def test_bad_previous_crosslink(state):
state.current_crosslinks[attestation.data.shard].epoch += 10 state.current_crosslinks[attestation.data.shard].epoch += 10
pre_state, post_state = run_attestation_processing(state, attestation, False) yield from run_attestation_processing(state, attestation, False)
return pre_state, attestation, post_state
@attestation_test
def test_non_empty_custody_bitfield(state): def test_non_empty_custody_bitfield(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY
attestation.custody_bitfield = deepcopy(attestation.aggregation_bitfield) attestation.custody_bitfield = deepcopy(attestation.aggregation_bitfield)
pre_state, post_state = run_attestation_processing(state, attestation, False) yield from run_attestation_processing(state, attestation, False)
return pre_state, attestation, post_state
@attestation_test
def test_empty_aggregation_bitfield(state): def test_empty_aggregation_bitfield(state):
attestation = get_valid_attestation(state) attestation = get_valid_attestation(state)
state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY
attestation.aggregation_bitfield = b'\x00' * len(attestation.aggregation_bitfield) attestation.aggregation_bitfield = b'\x00' * len(attestation.aggregation_bitfield)
pre_state, post_state = run_attestation_processing(state, attestation, False) yield from run_attestation_processing(state, attestation, False)
return pre_state, attestation, post_state

View File

@ -3,10 +3,6 @@ import pytest
from eth2spec.phase0 import spec from eth2spec.phase0 import spec
from preset_loader import loader from preset_loader import loader
from .helpers import (
create_genesis_state,
)
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption( parser.addoption(
@ -19,18 +15,3 @@ def config(request):
config_name = request.config.getoption("--config") config_name = request.config.getoption("--config")
presets = loader.load_presets('../../configs/', config_name) presets = loader.load_presets('../../configs/', config_name)
spec.apply_constants_preset(presets) spec.apply_constants_preset(presets)
@pytest.fixture
def num_validators(config):
return spec.SLOTS_PER_EPOCH * 8
@pytest.fixture
def deposit_data_leaves():
return list()
@pytest.fixture
def state(num_validators, deposit_data_leaves):
return create_genesis_state(num_validators, deposit_data_leaves)

View File

@ -0,0 +1,10 @@
from eth2spec.phase0 import spec
from tests.utils import with_args
from .helpers import (
create_genesis_state,
)
# Provides a genesis state as first argument to the function decorated with this
with_state = with_args(lambda: [create_genesis_state(spec.SLOTS_PER_EPOCH * 8, list())])

View File

@ -0,0 +1,46 @@
from eth2spec.debug.encode import encode
def spectest(description: str = None):
def runner(fn):
# this wraps the function, to hide that the function actually yielding data.
def entry(*args, **kw):
# check generator mode, may be None/else.
# "pop" removes it, so it is not passed to the inner function.
if kw.pop('generator_mode', False) is True:
out = {}
if description is None:
# fall back on function name for test description
name = fn.__name__
if name.startswith('test_'):
name = name[5:]
out['description'] = name
else:
# description can be explicit
out['description'] = description
# put all generated data into a dict.
for data in fn(*args, **kw):
# If there is a type argument, encode it as that type.
if len(data) == 3:
(key, value, typ) = data
out[key] = encode(value, typ)
else:
# Otherwise, just put the raw value.
(key, value) = data
out[key] = value
return out
else:
# just complete the function, ignore all yielded data, we are not using it
for _ in fn(*args, **kw):
continue
return entry
return runner
def with_args(create_args):
def runner(fn):
# this wraps the function, to hide that the function actually yielding data.
def entry(*args, **kw):
return fn(*(create_args() + list(args)), **kw)
return entry
return runner