Merge pull request #1672 from ethereum/fix-memoization
Fix memoization (base reward cache bug + add LRU)
This commit is contained in:
commit
9e137a6404
4
Makefile
4
Makefile
|
@ -77,6 +77,10 @@ test: pyspec
|
||||||
. venv/bin/activate; cd $(PY_SPEC_DIR); \
|
. venv/bin/activate; cd $(PY_SPEC_DIR); \
|
||||||
python -m pytest -n 4 --cov=eth2spec.phase0.spec --cov=eth2spec.phase1.spec --cov-report="html:$(COV_HTML_OUT)" --cov-branch eth2spec
|
python -m pytest -n 4 --cov=eth2spec.phase0.spec --cov=eth2spec.phase1.spec --cov-report="html:$(COV_HTML_OUT)" --cov-branch eth2spec
|
||||||
|
|
||||||
|
find_test: pyspec
|
||||||
|
. venv/bin/activate; cd $(PY_SPEC_DIR); \
|
||||||
|
python -m pytest -k=$(K) --cov=eth2spec.phase0.spec --cov=eth2spec.phase1.spec --cov-report="html:$(COV_HTML_OUT)" --cov-branch eth2spec
|
||||||
|
|
||||||
citest: pyspec
|
citest: pyspec
|
||||||
mkdir -p tests/core/pyspec/test-reports/eth2spec; . venv/bin/activate; cd $(PY_SPEC_DIR); \
|
mkdir -p tests/core/pyspec/test-reports/eth2spec; . venv/bin/activate; cd $(PY_SPEC_DIR); \
|
||||||
python -m pytest -n 4 --junitxml=eth2spec/test_results.xml eth2spec
|
python -m pytest -n 4 --junitxml=eth2spec/test_results.xml eth2spec
|
||||||
|
|
40
setup.py
40
setup.py
|
@ -92,6 +92,8 @@ from dataclasses import (
|
||||||
field,
|
field,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from lru import LRU
|
||||||
|
|
||||||
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
|
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
|
||||||
from eth2spec.utils.ssz.ssz_typing import (
|
from eth2spec.utils.ssz.ssz_typing import (
|
||||||
View, boolean, Container, List, Vector, uint64,
|
View, boolean, Container, List, Vector, uint64,
|
||||||
|
@ -114,6 +116,8 @@ from dataclasses import (
|
||||||
field,
|
field,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from lru import LRU
|
||||||
|
|
||||||
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
|
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
|
||||||
from eth2spec.utils.ssz.ssz_typing import (
|
from eth2spec.utils.ssz.ssz_typing import (
|
||||||
View, boolean, Container, List, Vector, uint64, uint8, bit,
|
View, boolean, Container, List, Vector, uint64, uint8, bit,
|
||||||
|
@ -152,8 +156,8 @@ def hash(x: bytes) -> Bytes32: # type: ignore
|
||||||
return hash_cache[x]
|
return hash_cache[x]
|
||||||
|
|
||||||
|
|
||||||
def cache_this(key_fn, value_fn): # type: ignore
|
def cache_this(key_fn, value_fn, lru_size): # type: ignore
|
||||||
cache_dict = {} # type: ignore
|
cache_dict = LRU(size=lru_size)
|
||||||
|
|
||||||
def wrapper(*args, **kw): # type: ignore
|
def wrapper(*args, **kw): # type: ignore
|
||||||
key = key_fn(*args, **kw)
|
key = key_fn(*args, **kw)
|
||||||
|
@ -164,35 +168,50 @@ def cache_this(key_fn, value_fn): # type: ignore
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
_compute_shuffled_index = compute_shuffled_index
|
||||||
|
compute_shuffled_index = cache_this(
|
||||||
|
lambda index, index_count, seed: (index, index_count, seed),
|
||||||
|
_compute_shuffled_index, lru_size=SLOTS_PER_EPOCH * 3)
|
||||||
|
|
||||||
|
_get_total_active_balance = get_total_active_balance
|
||||||
|
get_total_active_balance = cache_this(
|
||||||
|
lambda state: (state.validators.hash_tree_root(), compute_epoch_at_slot(state.slot)),
|
||||||
|
_get_total_active_balance, lru_size=10)
|
||||||
|
|
||||||
_get_base_reward = get_base_reward
|
_get_base_reward = get_base_reward
|
||||||
get_base_reward = cache_this(
|
get_base_reward = cache_this(
|
||||||
lambda state, index: (state.validators.hash_tree_root(), state.slot),
|
lambda state, index: (state.validators.hash_tree_root(), state.slot, index),
|
||||||
_get_base_reward)
|
_get_base_reward, lru_size=2048)
|
||||||
|
|
||||||
_get_committee_count_at_slot = get_committee_count_at_slot
|
_get_committee_count_at_slot = get_committee_count_at_slot
|
||||||
get_committee_count_at_slot = cache_this(
|
get_committee_count_at_slot = cache_this(
|
||||||
lambda state, epoch: (state.validators.hash_tree_root(), epoch),
|
lambda state, epoch: (state.validators.hash_tree_root(), epoch),
|
||||||
_get_committee_count_at_slot)
|
_get_committee_count_at_slot, lru_size=SLOTS_PER_EPOCH * 3)
|
||||||
|
|
||||||
_get_active_validator_indices = get_active_validator_indices
|
_get_active_validator_indices = get_active_validator_indices
|
||||||
get_active_validator_indices = cache_this(
|
get_active_validator_indices = cache_this(
|
||||||
lambda state, epoch: (state.validators.hash_tree_root(), epoch),
|
lambda state, epoch: (state.validators.hash_tree_root(), epoch),
|
||||||
_get_active_validator_indices)
|
_get_active_validator_indices, lru_size=3)
|
||||||
|
|
||||||
_get_beacon_committee = get_beacon_committee
|
_get_beacon_committee = get_beacon_committee
|
||||||
get_beacon_committee = cache_this(
|
get_beacon_committee = cache_this(
|
||||||
lambda state, slot, index: (state.validators.hash_tree_root(), state.randao_mixes.hash_tree_root(), slot, index),
|
lambda state, slot, index: (state.validators.hash_tree_root(), state.randao_mixes.hash_tree_root(), slot, index),
|
||||||
_get_beacon_committee)
|
_get_beacon_committee, lru_size=SLOTS_PER_EPOCH * MAX_COMMITTEES_PER_SLOT * 3)
|
||||||
|
|
||||||
_get_matching_target_attestations = get_matching_target_attestations
|
_get_matching_target_attestations = get_matching_target_attestations
|
||||||
get_matching_target_attestations = cache_this(
|
get_matching_target_attestations = cache_this(
|
||||||
lambda state, epoch: (state.hash_tree_root(), epoch),
|
lambda state, epoch: (state.hash_tree_root(), epoch),
|
||||||
_get_matching_target_attestations)
|
_get_matching_target_attestations, lru_size=10)
|
||||||
|
|
||||||
_get_matching_head_attestations = get_matching_head_attestations
|
_get_matching_head_attestations = get_matching_head_attestations
|
||||||
get_matching_head_attestations = cache_this(
|
get_matching_head_attestations = cache_this(
|
||||||
lambda state, epoch: (state.hash_tree_root(), epoch),
|
lambda state, epoch: (state.hash_tree_root(), epoch),
|
||||||
_get_matching_head_attestations)'''
|
_get_matching_head_attestations, lru_size=10)
|
||||||
|
|
||||||
|
_get_attesting_indices = get_attesting_indices
|
||||||
|
get_attesting_indices = cache_this(
|
||||||
|
lambda state, data, bits: (state.validators.hash_tree_root(), data.hash_tree_root(), bits.hash_tree_root()),
|
||||||
|
_get_attesting_indices, lru_size=SLOTS_PER_EPOCH * MAX_COMMITTEES_PER_SLOT * 3)'''
|
||||||
|
|
||||||
|
|
||||||
def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str:
|
def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str:
|
||||||
|
@ -481,6 +500,7 @@ setup(
|
||||||
"py_ecc==2.0.0",
|
"py_ecc==2.0.0",
|
||||||
"dataclasses==0.6",
|
"dataclasses==0.6",
|
||||||
"remerkleable==0.1.12",
|
"remerkleable==0.1.12",
|
||||||
"ruamel.yaml==0.16.5"
|
"ruamel.yaml==0.16.5",
|
||||||
|
"lru-dict==1.1.6"
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,6 +6,7 @@ from .helpers.genesis import create_genesis_state
|
||||||
|
|
||||||
from .utils import vector_test, with_meta_tags
|
from .utils import vector_test, with_meta_tags
|
||||||
|
|
||||||
|
from random import Random
|
||||||
from typing import Any, Callable, Sequence, TypedDict, Protocol
|
from typing import Any, Callable, Sequence, TypedDict, Protocol
|
||||||
|
|
||||||
from importlib import reload
|
from importlib import reload
|
||||||
|
@ -108,8 +109,10 @@ def misc_balances(spec):
|
||||||
Usage: `@with_custom_state(balances_fn=misc_balances, ...)`
|
Usage: `@with_custom_state(balances_fn=misc_balances, ...)`
|
||||||
"""
|
"""
|
||||||
num_validators = spec.SLOTS_PER_EPOCH * 8
|
num_validators = spec.SLOTS_PER_EPOCH * 8
|
||||||
num_misc_validators = spec.SLOTS_PER_EPOCH
|
balances = [spec.MAX_EFFECTIVE_BALANCE * 2 * i // num_validators for i in range(num_validators)]
|
||||||
return [spec.MAX_EFFECTIVE_BALANCE] * num_validators + [spec.MIN_DEPOSIT_AMOUNT] * num_misc_validators
|
rng = Random(1234)
|
||||||
|
rng.shuffle(balances)
|
||||||
|
return balances
|
||||||
|
|
||||||
|
|
||||||
def low_single_balance(spec):
|
def low_single_balance(spec):
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
from eth2spec.test.context import (
|
from eth2spec.test.context import (
|
||||||
spec_state_test, with_all_phases, spec_test,
|
spec_state_test, with_all_phases, spec_test,
|
||||||
misc_balances, low_single_balance,
|
misc_balances, with_custom_state,
|
||||||
with_custom_state,
|
low_single_balance, zero_activation_threshold,
|
||||||
default_activation_threshold, zero_activation_threshold,
|
|
||||||
single_phase,
|
single_phase,
|
||||||
)
|
)
|
||||||
from eth2spec.test.helpers.state import (
|
from eth2spec.test.helpers.state import (
|
||||||
|
@ -26,7 +23,7 @@ def run_process_rewards_and_penalties(spec, state):
|
||||||
@with_all_phases
|
@with_all_phases
|
||||||
@spec_state_test
|
@spec_state_test
|
||||||
def test_genesis_epoch_no_attestations_no_penalties(spec, state):
|
def test_genesis_epoch_no_attestations_no_penalties(spec, state):
|
||||||
pre_state = deepcopy(state)
|
pre_state = state.copy()
|
||||||
|
|
||||||
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH
|
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH
|
||||||
|
|
||||||
|
@ -54,7 +51,7 @@ def test_genesis_epoch_full_attestations_no_rewards(spec, state):
|
||||||
# ensure has not cross the epoch boundary
|
# ensure has not cross the epoch boundary
|
||||||
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH
|
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH
|
||||||
|
|
||||||
pre_state = deepcopy(state)
|
pre_state = state.copy()
|
||||||
|
|
||||||
yield from run_process_rewards_and_penalties(spec, state)
|
yield from run_process_rewards_and_penalties(spec, state)
|
||||||
|
|
||||||
|
@ -86,7 +83,7 @@ def prepare_state_with_full_attestations(spec, state, empty=False):
|
||||||
def test_full_attestations(spec, state):
|
def test_full_attestations(spec, state):
|
||||||
attestations = prepare_state_with_full_attestations(spec, state)
|
attestations = prepare_state_with_full_attestations(spec, state)
|
||||||
|
|
||||||
pre_state = deepcopy(state)
|
pre_state = state.copy()
|
||||||
|
|
||||||
yield from run_process_rewards_and_penalties(spec, state)
|
yield from run_process_rewards_and_penalties(spec, state)
|
||||||
|
|
||||||
|
@ -124,18 +121,19 @@ def test_full_attestations_random_incorrect_fields(spec, state):
|
||||||
|
|
||||||
@with_all_phases
|
@with_all_phases
|
||||||
@spec_test
|
@spec_test
|
||||||
@with_custom_state(balances_fn=misc_balances, threshold_fn=default_activation_threshold)
|
@with_custom_state(balances_fn=misc_balances, threshold_fn=lambda spec: spec.MAX_EFFECTIVE_BALANCE // 2)
|
||||||
@single_phase
|
@single_phase
|
||||||
def test_full_attestations_misc_balances(spec, state):
|
def test_full_attestations_misc_balances(spec, state):
|
||||||
attestations = prepare_state_with_full_attestations(spec, state)
|
attestations = prepare_state_with_full_attestations(spec, state)
|
||||||
|
|
||||||
pre_state = deepcopy(state)
|
pre_state = state.copy()
|
||||||
|
|
||||||
yield from run_process_rewards_and_penalties(spec, state)
|
yield from run_process_rewards_and_penalties(spec, state)
|
||||||
|
|
||||||
attesting_indices = spec.get_unslashed_attesting_indices(state, attestations)
|
attesting_indices = spec.get_unslashed_attesting_indices(state, attestations)
|
||||||
assert len(attesting_indices) > 0
|
assert len(attesting_indices) > 0
|
||||||
assert len(attesting_indices) != len(pre_state.validators)
|
assert len(attesting_indices) != len(pre_state.validators)
|
||||||
|
assert any(v.effective_balance != spec.MAX_EFFECTIVE_BALANCE for v in state.validators)
|
||||||
for index in range(len(pre_state.validators)):
|
for index in range(len(pre_state.validators)):
|
||||||
if index in attesting_indices:
|
if index in attesting_indices:
|
||||||
assert state.balances[index] > pre_state.balances[index]
|
assert state.balances[index] > pre_state.balances[index]
|
||||||
|
@ -143,6 +141,14 @@ def test_full_attestations_misc_balances(spec, state):
|
||||||
assert state.balances[index] < pre_state.balances[index]
|
assert state.balances[index] < pre_state.balances[index]
|
||||||
else:
|
else:
|
||||||
assert state.balances[index] == pre_state.balances[index]
|
assert state.balances[index] == pre_state.balances[index]
|
||||||
|
# Check if base rewards are consistent with effective balance.
|
||||||
|
brs = {}
|
||||||
|
for index in attesting_indices:
|
||||||
|
br = spec.get_base_reward(state, index)
|
||||||
|
if br in brs:
|
||||||
|
assert brs[br] == state.validators[index].effective_balance
|
||||||
|
else:
|
||||||
|
brs[br] = state.validators[index].effective_balance
|
||||||
|
|
||||||
|
|
||||||
@with_all_phases
|
@with_all_phases
|
||||||
|
@ -163,7 +169,7 @@ def test_full_attestations_one_validaor_one_gwei(spec, state):
|
||||||
@spec_state_test
|
@spec_state_test
|
||||||
def test_no_attestations_all_penalties(spec, state):
|
def test_no_attestations_all_penalties(spec, state):
|
||||||
next_epoch(spec, state)
|
next_epoch(spec, state)
|
||||||
pre_state = deepcopy(state)
|
pre_state = state.copy()
|
||||||
|
|
||||||
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1
|
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1
|
||||||
|
|
||||||
|
@ -178,7 +184,7 @@ def test_no_attestations_all_penalties(spec, state):
|
||||||
def test_empty_attestations(spec, state):
|
def test_empty_attestations(spec, state):
|
||||||
attestations = prepare_state_with_full_attestations(spec, state, empty=True)
|
attestations = prepare_state_with_full_attestations(spec, state, empty=True)
|
||||||
|
|
||||||
pre_state = deepcopy(state)
|
pre_state = state.copy()
|
||||||
|
|
||||||
yield from run_process_rewards_and_penalties(spec, state)
|
yield from run_process_rewards_and_penalties(spec, state)
|
||||||
|
|
||||||
|
@ -205,8 +211,8 @@ def test_duplicate_attestation(spec, state):
|
||||||
|
|
||||||
assert len(participants) > 0
|
assert len(participants) > 0
|
||||||
|
|
||||||
single_state = deepcopy(state)
|
single_state = state.copy()
|
||||||
dup_state = deepcopy(state)
|
dup_state = state.copy()
|
||||||
|
|
||||||
inclusion_slot = state.slot + spec.MIN_ATTESTATION_INCLUSION_DELAY
|
inclusion_slot = state.slot + spec.MIN_ATTESTATION_INCLUSION_DELAY
|
||||||
add_attestations_to_state(spec, single_state, [attestation], inclusion_slot)
|
add_attestations_to_state(spec, single_state, [attestation], inclusion_slot)
|
||||||
|
@ -252,7 +258,7 @@ def test_attestations_some_slashed(spec, state):
|
||||||
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1
|
assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1
|
||||||
assert len(state.previous_epoch_attestations) == spec.SLOTS_PER_EPOCH
|
assert len(state.previous_epoch_attestations) == spec.SLOTS_PER_EPOCH
|
||||||
|
|
||||||
pre_state = deepcopy(state)
|
pre_state = state.copy()
|
||||||
|
|
||||||
yield from run_process_rewards_and_penalties(spec, state)
|
yield from run_process_rewards_and_penalties(spec, state)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue