Fix base-reward memoization bug, improve memoization with LRU, and improve misc rewards test

This commit is contained in:
protolambda 2020-03-20 20:38:36 +01:00
parent e429030ded
commit 33f8f4936d
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
4 changed files with 58 additions and 24 deletions

View File

@ -77,6 +77,10 @@ test: pyspec
. venv/bin/activate; cd $(PY_SPEC_DIR); \ . venv/bin/activate; cd $(PY_SPEC_DIR); \
python -m pytest -n 4 --cov=eth2spec.phase0.spec --cov=eth2spec.phase1.spec --cov-report="html:$(COV_HTML_OUT)" --cov-branch eth2spec python -m pytest -n 4 --cov=eth2spec.phase0.spec --cov=eth2spec.phase1.spec --cov-report="html:$(COV_HTML_OUT)" --cov-branch eth2spec
find_test: pyspec
. venv/bin/activate; cd $(PY_SPEC_DIR); \
python -m pytest -k=$(K) --cov=eth2spec.phase0.spec --cov=eth2spec.phase1.spec --cov-report="html:$(COV_HTML_OUT)" --cov-branch eth2spec
citest: pyspec citest: pyspec
mkdir -p tests/core/pyspec/test-reports/eth2spec; . venv/bin/activate; cd $(PY_SPEC_DIR); \ mkdir -p tests/core/pyspec/test-reports/eth2spec; . venv/bin/activate; cd $(PY_SPEC_DIR); \
python -m pytest -n 4 --junitxml=eth2spec/test_results.xml eth2spec python -m pytest -n 4 --junitxml=eth2spec/test_results.xml eth2spec

View File

@ -92,6 +92,8 @@ from dataclasses import (
field, field,
) )
from lru import LRU
from eth2spec.utils.ssz.ssz_impl import hash_tree_root from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.utils.ssz.ssz_typing import ( from eth2spec.utils.ssz.ssz_typing import (
View, boolean, Container, List, Vector, uint64, View, boolean, Container, List, Vector, uint64,
@ -114,6 +116,8 @@ from dataclasses import (
field, field,
) )
from lru import LRU
from eth2spec.utils.ssz.ssz_impl import hash_tree_root from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.utils.ssz.ssz_typing import ( from eth2spec.utils.ssz.ssz_typing import (
View, boolean, Container, List, Vector, uint64, uint8, bit, View, boolean, Container, List, Vector, uint64, uint8, bit,
@ -152,8 +156,8 @@ def hash(x: bytes) -> Bytes32: # type: ignore
return hash_cache[x] return hash_cache[x]
def cache_this(key_fn, value_fn): # type: ignore def cache_this(key_fn, value_fn, lru_size): # type: ignore
cache_dict = {} # type: ignore cache_dict = LRU(size=lru_size)
def wrapper(*args, **kw): # type: ignore def wrapper(*args, **kw): # type: ignore
key = key_fn(*args, **kw) key = key_fn(*args, **kw)
@ -164,35 +168,50 @@ def cache_this(key_fn, value_fn): # type: ignore
return wrapper return wrapper
_compute_shuffled_index = compute_shuffled_index
compute_shuffled_index = cache_this(
lambda index, index_count, seed: (index, index_count, seed),
_compute_shuffled_index, lru_size=SLOTS_PER_EPOCH * 3)
_get_total_active_balance = get_total_active_balance
get_total_active_balance = cache_this(
lambda state: (state.validators.hash_tree_root(), state.slot),
_get_total_active_balance, lru_size=10)
_get_base_reward = get_base_reward _get_base_reward = get_base_reward
get_base_reward = cache_this( get_base_reward = cache_this(
lambda state, index: (state.validators.hash_tree_root(), state.slot), lambda state, index: (state.validators.hash_tree_root(), state.slot, index),
_get_base_reward) _get_base_reward, lru_size=10)
_get_committee_count_at_slot = get_committee_count_at_slot _get_committee_count_at_slot = get_committee_count_at_slot
get_committee_count_at_slot = cache_this( get_committee_count_at_slot = cache_this(
lambda state, epoch: (state.validators.hash_tree_root(), epoch), lambda state, epoch: (state.validators.hash_tree_root(), epoch),
_get_committee_count_at_slot) _get_committee_count_at_slot, lru_size=SLOTS_PER_EPOCH * 3)
_get_active_validator_indices = get_active_validator_indices _get_active_validator_indices = get_active_validator_indices
get_active_validator_indices = cache_this( get_active_validator_indices = cache_this(
lambda state, epoch: (state.validators.hash_tree_root(), epoch), lambda state, epoch: (state.validators.hash_tree_root(), epoch),
_get_active_validator_indices) _get_active_validator_indices, lru_size=3)
_get_beacon_committee = get_beacon_committee _get_beacon_committee = get_beacon_committee
get_beacon_committee = cache_this( get_beacon_committee = cache_this(
lambda state, slot, index: (state.validators.hash_tree_root(), state.randao_mixes.hash_tree_root(), slot, index), lambda state, slot, index: (state.validators.hash_tree_root(), state.randao_mixes.hash_tree_root(), slot, index),
_get_beacon_committee) _get_beacon_committee, lru_size=SLOTS_PER_EPOCH * MAX_COMMITTEES_PER_SLOT * 3)
_get_matching_target_attestations = get_matching_target_attestations _get_matching_target_attestations = get_matching_target_attestations
get_matching_target_attestations = cache_this( get_matching_target_attestations = cache_this(
lambda state, epoch: (state.hash_tree_root(), epoch), lambda state, epoch: (state.hash_tree_root(), epoch),
_get_matching_target_attestations) _get_matching_target_attestations, lru_size=10)
_get_matching_head_attestations = get_matching_head_attestations _get_matching_head_attestations = get_matching_head_attestations
get_matching_head_attestations = cache_this( get_matching_head_attestations = cache_this(
lambda state, epoch: (state.hash_tree_root(), epoch), lambda state, epoch: (state.hash_tree_root(), epoch),
_get_matching_head_attestations)''' _get_matching_head_attestations, lru_size=10)
_get_attesting_indices = get_attesting_indices
get_attesting_indices = cache_this(lambda state, data, bits:
(state.validators.hash_tree_root(), data.hash_tree_root(), bits.hash_tree_root()),
_get_attesting_indices, lru_size=SLOTS_PER_EPOCH * MAX_COMMITTEES_PER_SLOT * 3)'''
def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str: def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str:
@ -481,6 +500,7 @@ setup(
"py_ecc==2.0.0", "py_ecc==2.0.0",
"dataclasses==0.6", "dataclasses==0.6",
"remerkleable==0.1.12", "remerkleable==0.1.12",
"ruamel.yaml==0.16.5" "ruamel.yaml==0.16.5",
"lru-dict==1.1.6"
] ]
) )

View File

@ -6,6 +6,7 @@ from .helpers.genesis import create_genesis_state
from .utils import vector_test, with_meta_tags from .utils import vector_test, with_meta_tags
from random import Random
from typing import Any, Callable, Sequence, TypedDict, Protocol from typing import Any, Callable, Sequence, TypedDict, Protocol
from importlib import reload from importlib import reload
@ -100,8 +101,10 @@ def misc_balances(spec):
Usage: `@with_custom_state(balances_fn=misc_balances, ...)` Usage: `@with_custom_state(balances_fn=misc_balances, ...)`
""" """
num_validators = spec.SLOTS_PER_EPOCH * 8 num_validators = spec.SLOTS_PER_EPOCH * 8
num_misc_validators = spec.SLOTS_PER_EPOCH balances = [spec.MAX_EFFECTIVE_BALANCE * 2 * i // num_validators for i in range(num_validators)]
return [spec.MAX_EFFECTIVE_BALANCE] * num_validators + [spec.MIN_DEPOSIT_AMOUNT] * num_misc_validators rng = Random(1234)
rng.shuffle(balances)
return balances
def single_phase(fn): def single_phase(fn):

View File

@ -1,8 +1,6 @@
from copy import deepcopy
from eth2spec.test.context import ( from eth2spec.test.context import (
spec_state_test, with_all_phases, spec_test, spec_state_test, with_all_phases, spec_test,
misc_balances, with_custom_state, default_activation_threshold, misc_balances, with_custom_state,
single_phase, single_phase,
) )
from eth2spec.test.helpers.state import ( from eth2spec.test.helpers.state import (
@ -24,7 +22,7 @@ def run_process_rewards_and_penalties(spec, state):
@with_all_phases @with_all_phases
@spec_state_test @spec_state_test
def test_genesis_epoch_no_attestations_no_penalties(spec, state): def test_genesis_epoch_no_attestations_no_penalties(spec, state):
pre_state = deepcopy(state) pre_state = state.copy()
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH
@ -52,7 +50,7 @@ def test_genesis_epoch_full_attestations_no_rewards(spec, state):
# ensure has not cross the epoch boundary # ensure has not cross the epoch boundary
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH
pre_state = deepcopy(state) pre_state = state.copy()
yield from run_process_rewards_and_penalties(spec, state) yield from run_process_rewards_and_penalties(spec, state)
@ -84,7 +82,7 @@ def prepare_state_with_full_attestations(spec, state):
def test_full_attestations(spec, state): def test_full_attestations(spec, state):
attestations = prepare_state_with_full_attestations(spec, state) attestations = prepare_state_with_full_attestations(spec, state)
pre_state = deepcopy(state) pre_state = state.copy()
yield from run_process_rewards_and_penalties(spec, state) yield from run_process_rewards_and_penalties(spec, state)
@ -122,18 +120,19 @@ def test_full_attestations_random_incorrect_fields(spec, state):
@with_all_phases @with_all_phases
@spec_test @spec_test
@with_custom_state(balances_fn=misc_balances, threshold_fn=default_activation_threshold) @with_custom_state(balances_fn=misc_balances, threshold_fn=lambda spec: spec.MAX_EFFECTIVE_BALANCE // 2)
@single_phase @single_phase
def test_full_attestations_misc_balances(spec, state): def test_full_attestations_misc_balances(spec, state):
attestations = prepare_state_with_full_attestations(spec, state) attestations = prepare_state_with_full_attestations(spec, state)
pre_state = deepcopy(state) pre_state = state.copy()
yield from run_process_rewards_and_penalties(spec, state) yield from run_process_rewards_and_penalties(spec, state)
attesting_indices = spec.get_unslashed_attesting_indices(state, attestations) attesting_indices = spec.get_unslashed_attesting_indices(state, attestations)
assert len(attesting_indices) > 0 assert len(attesting_indices) > 0
assert len(attesting_indices) != len(pre_state.validators) assert len(attesting_indices) != len(pre_state.validators)
assert any(v.effective_balance != spec.MAX_EFFECTIVE_BALANCE for v in state.validators)
for index in range(len(pre_state.validators)): for index in range(len(pre_state.validators)):
if index in attesting_indices: if index in attesting_indices:
assert state.balances[index] > pre_state.balances[index] assert state.balances[index] > pre_state.balances[index]
@ -141,13 +140,21 @@ def test_full_attestations_misc_balances(spec, state):
assert state.balances[index] < pre_state.balances[index] assert state.balances[index] < pre_state.balances[index]
else: else:
assert state.balances[index] == pre_state.balances[index] assert state.balances[index] == pre_state.balances[index]
# Check if base rewards are consistent with effective balance.
brs = {}
for index in attesting_indices:
br = spec.get_base_reward(state, index)
if br in brs:
assert brs[br] == state.validators[index].effective_balance
else:
brs[br] = state.validators[index].effective_balance
@with_all_phases @with_all_phases
@spec_state_test @spec_state_test
def test_no_attestations_all_penalties(spec, state): def test_no_attestations_all_penalties(spec, state):
next_epoch(spec, state) next_epoch(spec, state)
pre_state = deepcopy(state) pre_state = state.copy()
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1 assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1
@ -173,8 +180,8 @@ def test_duplicate_attestation(spec, state):
assert len(participants) > 0 assert len(participants) > 0
single_state = deepcopy(state) single_state = state.copy()
dup_state = deepcopy(state) dup_state = state.copy()
inclusion_slot = state.slot + spec.MIN_ATTESTATION_INCLUSION_DELAY inclusion_slot = state.slot + spec.MIN_ATTESTATION_INCLUSION_DELAY
add_attestations_to_state(spec, single_state, [attestation], inclusion_slot) add_attestations_to_state(spec, single_state, [attestation], inclusion_slot)
@ -220,7 +227,7 @@ def test_attestations_some_slashed(spec, state):
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1 assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1
assert len(state.previous_epoch_attestations) == spec.SLOTS_PER_EPOCH assert len(state.previous_epoch_attestations) == spec.SLOTS_PER_EPOCH
pre_state = deepcopy(state) pre_state = state.copy()
yield from run_process_rewards_and_penalties(spec, state) yield from run_process_rewards_and_penalties(spec, state)