Refactor tests and avoiding passing `shart_store` to helper functions

This commit is contained in:
Hsiao-Wei Wang 2020-07-14 17:36:27 +08:00
parent 2da331a345
commit f6b1fe6172
No known key found for this signature in database
GPG Key ID: 95B070122902DEA4
2 changed files with 159 additions and 49 deletions

View File

@ -47,7 +47,8 @@ def get_forkchoice_shard_store(anchor_state: BeaconState, shard: Shard) -> Shard
#### `get_shard_latest_attesting_balance` #### `get_shard_latest_attesting_balance`
```python ```python
def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, root: Root) -> Gwei: def get_shard_latest_attesting_balance(store: Store, shard: Shard, root: Root) -> Gwei:
shard_store = store.shard_stores[shard]
state = store.checkpoint_states[store.justified_checkpoint] state = store.checkpoint_states[store.justified_checkpoint]
active_indices = get_active_validator_indices(state, get_current_epoch(state)) active_indices = get_active_validator_indices(state, get_current_epoch(state))
return Gwei(sum( return Gwei(sum(
@ -58,7 +59,7 @@ def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, ro
# would be ignored once their newer vote is accepted. Check if it makes sense. # would be ignored once their newer vote is accepted. Check if it makes sense.
and get_shard_ancestor( and get_shard_ancestor(
store, store,
shard_store, shard,
shard_store.latest_messages[i].root, shard_store.latest_messages[i].root,
shard_store.signed_blocks[root].message.slot, shard_store.signed_blocks[root].message.slot,
) == root ) == root
@ -69,10 +70,14 @@ def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, ro
#### `get_shard_head` #### `get_shard_head`
```python ```python
def get_shard_head(store: Store, shard_store: ShardStore) -> Root: def get_shard_head(store: Store, shard: Shard) -> Root:
# Execute the LMD-GHOST fork choice # Execute the LMD-GHOST fork choice
"""
Execute the LMD-GHOST fork choice.
"""
shard_store = store.shard_stores[shard]
beacon_head_root = get_head(store) beacon_head_root = get_head(store)
shard_head_state = store.block_states[beacon_head_root].shard_states[shard_store.shard] shard_head_state = store.block_states[beacon_head_root].shard_states[shard]
shard_head_root = shard_head_state.latest_block_root shard_head_root = shard_head_state.latest_block_root
shard_blocks = { shard_blocks = {
root: signed_shard_block.message for root, signed_shard_block in shard_store.signed_blocks.items() root: signed_shard_block.message for root, signed_shard_block in shard_store.signed_blocks.items()
@ -88,17 +93,18 @@ def get_shard_head(store: Store, shard_store: ShardStore) -> Root:
return shard_head_root return shard_head_root
# Sort by latest attesting balance with ties broken lexicographically # Sort by latest attesting balance with ties broken lexicographically
shard_head_root = max( shard_head_root = max(
children, key=lambda root: (get_shard_latest_attesting_balance(store, shard_store, root), root) children, key=lambda root: (get_shard_latest_attesting_balance(store, shard, root), root)
) )
``` ```
#### `get_shard_ancestor` #### `get_shard_ancestor`
```python ```python
def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot: Slot) -> Root: def get_shard_ancestor(store: Store, shard: Shard, root: Root, slot: Slot) -> Root:
shard_store = store.shard_stores[shard]
block = shard_store.signed_blocks[root].message block = shard_store.signed_blocks[root].message
if block.slot > slot: if block.slot > slot:
return get_shard_ancestor(store, shard_store, block.shard_parent_root, slot) return get_shard_ancestor(store, shard, block.shard_parent_root, slot)
elif block.slot == slot: elif block.slot == slot:
return root return root
else: else:
@ -109,17 +115,17 @@ def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot:
#### `get_pending_shard_blocks` #### `get_pending_shard_blocks`
```python ```python
def get_pending_shard_blocks(store: Store, shard_store: ShardStore) -> Sequence[SignedShardBlock]: def get_pending_shard_blocks(store: Store, shard: Shard) -> Sequence[SignedShardBlock]:
""" """
Return the canonical shard block branch that has not yet been crosslinked. Return the canonical shard block branch that has not yet been crosslinked.
""" """
shard = shard_store.shard shard_store = store.shard_stores[shard]
beacon_head_root = get_head(store) beacon_head_root = get_head(store)
beacon_head_state = store.block_states[beacon_head_root] beacon_head_state = store.block_states[beacon_head_root]
latest_shard_block_root = beacon_head_state.shard_states[shard].latest_block_root latest_shard_block_root = beacon_head_state.shard_states[shard].latest_block_root
shard_head_root = get_shard_head(store, shard_store) shard_head_root = get_shard_head(store, shard)
root = shard_head_root root = shard_head_root
signed_shard_blocks = [] signed_shard_blocks = []
while root != latest_shard_block_root: while root != latest_shard_block_root:
@ -136,9 +142,9 @@ def get_pending_shard_blocks(store: Store, shard_store: ShardStore) -> Sequence[
#### `on_shard_block` #### `on_shard_block`
```python ```python
def on_shard_block(store: Store, shard_store: ShardStore, signed_shard_block: SignedShardBlock) -> None: def on_shard_block(store: Store, shard: Shard, signed_shard_block: SignedShardBlock) -> None:
shard_store = store.shard_stores[shard]
shard_block = signed_shard_block.message shard_block = signed_shard_block.message
shard = shard_store.shard
# Check shard # Check shard
# TODO: check it in networking spec # TODO: check it in networking spec

View File

@ -8,27 +8,43 @@ from eth2spec.test.helpers.shard_block import (
get_committee_index_of_shard, get_committee_index_of_shard,
) )
from eth2spec.test.helpers.fork_choice import add_block_to_store, get_anchor_root from eth2spec.test.helpers.fork_choice import add_block_to_store, get_anchor_root
from eth2spec.test.helpers.shard_transitions import is_full_crosslink
from eth2spec.test.helpers.state import state_transition_and_sign_block from eth2spec.test.helpers.state import state_transition_and_sign_block
from eth2spec.test.helpers.block import build_empty_block from eth2spec.test.helpers.block import build_empty_block
def run_on_shard_block(spec, store, shard_store, signed_block, valid=True): def run_on_shard_block(spec, store, shard, signed_block, valid=True):
if not valid: if not valid:
try: try:
spec.on_shard_block(store, shard_store, signed_block) spec.on_shard_block(store, shard, signed_block)
except AssertionError: except AssertionError:
return return
else: else:
assert False assert False
spec.on_shard_block(store, shard_store, signed_block) spec.on_shard_block(store, shard, signed_block)
shard_store = store.shard_stores[shard]
assert shard_store.signed_blocks[hash_tree_root(signed_block.message)] == signed_block assert shard_store.signed_blocks[hash_tree_root(signed_block.message)] == signed_block
def apply_shard_block(spec, store, shard_store, beacon_parent_state, shard_blocks_buffer): def initialize_store(spec, state, shard):
shard = shard_store.shard store = spec.get_forkchoice_store(state)
anchor_root = get_anchor_root(spec, state)
assert spec.get_head(store) == anchor_root
shard_head_root = spec.get_shard_head(store, shard)
assert shard_head_root == state.shard_states[shard].latest_block_root
shard_store = store.shard_stores[shard]
assert shard_store.block_states[shard_head_root].slot == 1
assert shard_store.block_states[shard_head_root] == state.shard_states[shard]
return store
def create_and_apply_shard_block(spec, store, shard, beacon_parent_state, shard_blocks_buffer):
body = b'\x56' * 4 body = b'\x56' * 4
shard_head_root = spec.get_shard_head(store, shard_store) shard_head_root = spec.get_shard_head(store, shard)
shard_store = store.shard_stores[shard]
shard_parent_state = shard_store.block_states[shard_head_root] shard_parent_state = shard_store.block_states[shard_head_root]
assert shard_parent_state.slot != beacon_parent_state.slot assert shard_parent_state.slot != beacon_parent_state.slot
shard_block = build_shard_block( shard_block = build_shard_block(
@ -36,12 +52,12 @@ def apply_shard_block(spec, store, shard_store, beacon_parent_state, shard_block
shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True
) )
shard_blocks_buffer.append(shard_block) shard_blocks_buffer.append(shard_block)
run_on_shard_block(spec, store, shard_store, shard_block) run_on_shard_block(spec, store, shard, shard_block)
assert spec.get_shard_head(store, shard_store) == shard_block.message.hash_tree_root() assert spec.get_shard_head(store, shard) == shard_block.message.hash_tree_root()
def check_pending_shard_blocks(spec, store, shard_store, shard_blocks_buffer): def check_pending_shard_blocks(spec, store, shard, shard_blocks_buffer):
pending_shard_blocks = spec.get_pending_shard_blocks(store, shard_store) pending_shard_blocks = spec.get_pending_shard_blocks(store, shard)
assert pending_shard_blocks == shard_blocks_buffer assert pending_shard_blocks == shard_blocks_buffer
@ -52,10 +68,22 @@ def is_in_offset_sets(spec, beacon_head_state, shard):
return beacon_head_state.slot in offset_slots return beacon_head_state.slot in offset_slots
def apply_shard_and_beacon(spec, state, store, shard_store, shard_blocks_buffer): def create_attestation_for_shard_blocks(spec, beacon_parent_state, shard, committee_index, blocks,
store.time = store.time + spec.SECONDS_PER_SLOT * spec.SLOTS_PER_EPOCH filter_participant_set=None):
shard_transition = spec.get_shard_transition(beacon_parent_state, shard, blocks)
attestation = get_valid_on_time_attestation(
spec,
beacon_parent_state,
index=committee_index,
shard_transition=shard_transition,
signed=False,
)
return attestation
shard = shard_store.shard
def create_beacon_block_with_shard_transition(
spec, state, store, shard, shard_blocks_buffer, is_checking_pending_shard_blocks=True):
beacon_block = build_empty_block(spec, state, slot=state.slot + 1)
committee_index = get_committee_index_of_shard(spec, state, state.slot, shard) committee_index = get_committee_index_of_shard(spec, state, state.slot, shard)
has_shard_committee = committee_index is not None # has committee of `shard` at this slot has_shard_committee = committee_index is not None # has committee of `shard` at this slot
@ -63,14 +91,12 @@ def apply_shard_and_beacon(spec, state, store, shard_store, shard_blocks_buffer)
# If next slot has committee of `shard`, add `shard_transtion` to the proposing beacon block # If next slot has committee of `shard`, add `shard_transtion` to the proposing beacon block
if has_shard_committee and len(shard_blocks_buffer) > 0: if has_shard_committee and len(shard_blocks_buffer) > 0:
# Sanity check `get_pending_shard_blocks` function # Sanity check `get_pending_shard_blocks`
check_pending_shard_blocks(spec, store, shard_store, shard_blocks_buffer) # Assert that the pending shard blocks set in the store equal to shard_blocks_buffer
if is_checking_pending_shard_blocks:
check_pending_shard_blocks(spec, store, shard, shard_blocks_buffer)
# Use temporary next state to get ShardTransition of shard block # Use temporary next state to get ShardTransition of shard block
shard_transitions = get_shard_transitions( shard_transitions = get_shard_transitions(spec, state, shard_block_dict={shard: shard_blocks_buffer})
spec,
state,
shard_block_dict={shard: shard_blocks_buffer},
)
shard_transition = shard_transitions[shard] shard_transition = shard_transitions[shard]
attestation = get_valid_on_time_attestation( attestation = get_valid_on_time_attestation(
spec, spec,
@ -86,15 +112,31 @@ def apply_shard_and_beacon(spec, state, store, shard_store, shard_blocks_buffer)
# Clear buffer # Clear buffer
shard_blocks_buffer.clear() shard_blocks_buffer.clear()
signed_beacon_block = state_transition_and_sign_block(spec, state, beacon_block) # transition! return beacon_block
add_block_to_store(spec, store, signed_beacon_block)
assert spec.get_head(store) == beacon_block.hash_tree_root()
# On shard block at transitioned `state.slot`
def apply_all_attestation_to_store(spec, store, attestations):
for attestation in attestations:
spec.on_attestation(store, attestation)
def apply_beacon_block_to_store(spec, state, store, beacon_block):
signed_beacon_block = state_transition_and_sign_block(spec, state, beacon_block) # transition!
store.time = store.time + spec.SECONDS_PER_SLOT
add_block_to_store(spec, store, signed_beacon_block)
apply_all_attestation_to_store(spec, store, signed_beacon_block.message.body.attestations)
def create_and_apply_beacon_and_shard_blocks(spec, state, store, shard, shard_blocks_buffer):
beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard, shard_blocks_buffer)
apply_beacon_block_to_store(spec, state, store, beacon_block)
# On shard block at the transitioned `state.slot`
if is_in_offset_sets(spec, state, shard): if is_in_offset_sets(spec, state, shard):
# The created shard block would be appended to `shard_blocks_buffer` # The created shard block would be appended to `shard_blocks_buffer`
apply_shard_block(spec, store, shard_store, state, shard_blocks_buffer) create_and_apply_shard_block(spec, store, shard, state, shard_blocks_buffer)
has_shard_committee = get_committee_index_of_shard(spec, state, state.slot, shard) is not None
return has_shard_committee return has_shard_committee
@ -107,23 +149,85 @@ def test_basic(spec, state):
shard = spec.Shard(1) shard = spec.Shard(1)
# Initialization # Initialization
store = spec.get_forkchoice_store(state) store = initialize_store(spec, state, shard)
anchor_root = get_anchor_root(spec, state)
assert spec.get_head(store) == anchor_root
shard_store = store.shard_stores[shard]
shard_head_root = spec.get_shard_head(store, shard_store)
assert shard_head_root == state.shard_states[shard].latest_block_root
assert shard_store.block_states[shard_head_root].slot == 1
assert shard_store.block_states[shard_head_root] == state.shard_states[shard]
# For mainnet config, it's possible that only one committee of `shard` per epoch. # For mainnet config, it's possible that only one committee of `shard` per epoch.
# we set this counter to test more rounds. # we set this counter to test more rounds.
shard_committee_counter = 2 shard_committee_counter = 2
shard_blocks_buffer = [] shard_blocks_buffer = [] # the accumulated shard blocks that haven't been crosslinked yet
while shard_committee_counter > 0: while shard_committee_counter > 0:
has_shard_committee = apply_shard_and_beacon( has_shard_committee = create_and_apply_beacon_and_shard_blocks(
spec, state, store, shard_store, shard_blocks_buffer spec, state, store, shard, shard_blocks_buffer
) )
if has_shard_committee: if has_shard_committee:
shard_committee_counter -= 1 shard_committee_counter -= 1
def create_simple_fork(spec, state, store, shard):
# Beacon block
beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard, [])
apply_beacon_block_to_store(spec, state, store, beacon_block)
beacon_head_root = spec.get_head(store)
assert beacon_head_root == beacon_block.hash_tree_root()
beacon_parent_state = store.block_states[beacon_head_root]
shard_store = store.shard_stores[shard]
shard_parent_state = shard_store.block_states[spec.get_shard_head(store, shard)]
# Shard block A
body = b'\x56' * 4
forking_block_child = build_shard_block(
spec, beacon_parent_state, shard,
shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True
)
run_on_shard_block(spec, store, shard, forking_block_child)
# Shard block B
body = b'\x78' * 4 # different body
shard_block_b = build_shard_block(
spec, beacon_parent_state, shard,
shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True
)
run_on_shard_block(spec, store, shard, shard_block_b)
# Set forking_block
current_head = spec.get_shard_head(store, shard)
if current_head == forking_block_child.message.hash_tree_root():
head_block = forking_block_child
forking_block = shard_block_b
else:
assert current_head == shard_block_b.message.hash_tree_root()
head_block = shard_block_b
forking_block = forking_block_child
return head_block, forking_block
@with_all_phases_except([PHASE0])
@spec_state_test
@never_bls # Set to never_bls for testing `check_pending_shard_blocks`
def test_shard_simple_fork(spec, state):
if not is_full_crosslink(spec, state):
# skip
return
spec.PHASE_1_GENESIS_SLOT = 0 # NOTE: mock genesis slot here
state = spec.upgrade_to_phase1(state)
shard = spec.Shard(1)
# Initialization
store = initialize_store(spec, state, shard)
# Create fork
_, forking_block = create_simple_fork(spec, state, store, shard)
# Vote for forking_block
state = store.block_states[spec.get_head(store)].copy()
beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard, [forking_block],
is_checking_pending_shard_blocks=False)
# apply_beacon_block_to_store(spec, state, store, beacon_block)
store.time = store.time + spec.SECONDS_PER_SLOT
apply_all_attestation_to_store(spec, store, beacon_block.body.attestations)
# Head block has been changed
assert spec.get_shard_head(store, shard) == forking_block.message.hash_tree_root()