diff --git a/specs/phase1/shard-fork-choice.md b/specs/phase1/shard-fork-choice.md index e9026b991..4b3f42194 100644 --- a/specs/phase1/shard-fork-choice.md +++ b/specs/phase1/shard-fork-choice.md @@ -31,75 +31,38 @@ This document is the shard chain fork choice spec for part of Ethereum 2.0 Phase ### Helpers -#### Extended `Store` +#### `ShardStore` ```python @dataclass -class Store: - - @dataclass - class ShardStore: - blocks: Dict[Root, ShardBlock] = field(default_factory=dict) - block_states: Dict[Root, ShardState] = field(default_factory=dict) - - time: uint64 - genesis_time: uint64 - justified_checkpoint: Checkpoint - finalized_checkpoint: Checkpoint - best_justified_checkpoint: Checkpoint - blocks: Dict[Root, BeaconBlock] = field(default_factory=dict) - block_states: Dict[Root, BeaconState] = field(default_factory=dict) - checkpoint_states: Dict[Checkpoint, BeaconState] = field(default_factory=dict) - latest_messages: Dict[ValidatorIndex, LatestMessage] = field(default_factory=dict) - # shard chain - shards: Dict[Shard, ShardStore] = field(default_factory=dict) # noqa: F821 +class ShardStore: + shard: Shard + blocks: Dict[Root, ShardBlock] = field(default_factory=dict) + block_states: Dict[Root, ShardState] = field(default_factory=dict) ``` -#### Updated `get_forkchoice_store` +#### Updated `get_forkchoice_shard_store` ```python -def get_forkchoice_store(anchor_state: BeaconState) -> Store: - shard_count = len(anchor_state.shard_states) - anchor_block_header = anchor_state.latest_block_header.copy() - if anchor_block_header.state_root == Bytes32(): - anchor_block_header.state_root = hash_tree_root(anchor_state) - anchor_root = hash_tree_root(anchor_block_header) - anchor_epoch = get_current_epoch(anchor_state) - justified_checkpoint = Checkpoint(epoch=anchor_epoch, root=anchor_root) - finalized_checkpoint = Checkpoint(epoch=anchor_epoch, root=anchor_root) - - shard_stores = {} - for shard in map(Shard, range(shard_count)): - shard_stores[shard] = Store.ShardStore( - blocks={anchor_state.shard_states[shard].latest_block_root: ShardBlock(slot=anchor_state.slot)}, - block_states={anchor_state.shard_states[shard].latest_block_root: anchor_state.copy().shard_states[shard]}, - ) - - return Store( - time=anchor_state.genesis_time + SECONDS_PER_SLOT * anchor_state.slot, - genesis_time=anchor_state.genesis_time, - justified_checkpoint=justified_checkpoint, - finalized_checkpoint=finalized_checkpoint, - best_justified_checkpoint=justified_checkpoint, - blocks={anchor_root: anchor_block_header}, - block_states={anchor_root: anchor_state.copy()}, - checkpoint_states={justified_checkpoint: anchor_state.copy()}, - # shard chain - shards=shard_stores, +def get_forkchoice_shard_store(anchor_state: BeaconState, shard: Shard) -> ShardStore: + return ShardStore( + shard=shard, + blocks={anchor_state.shard_states[shard].latest_block_root: ShardBlock(slot=anchor_state.slot)}, + block_states={anchor_state.shard_states[shard].latest_block_root: anchor_state.copy().shard_states[shard]}, ) ``` #### `get_shard_latest_attesting_balance` ```python -def get_shard_latest_attesting_balance(store: Store, shard: Shard, root: Root) -> Gwei: +def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, root: Root) -> Gwei: state = store.checkpoint_states[store.justified_checkpoint] active_indices = get_active_validator_indices(state, get_current_epoch(state)) return Gwei(sum( state.validators[i].effective_balance for i in active_indices if ( i in store.latest_messages and get_shard_ancestor( - store, shard, store.latest_messages[i].root, store.shards[shard].blocks[root].slot + store, shard_store, store.latest_messages[i].root, shard_store.blocks[root].slot ) == root ) )) @@ -108,13 +71,13 @@ def get_shard_latest_attesting_balance(store: Store, shard: Shard, root: Root) - #### `get_shard_head` ```python -def get_shard_head(store: Store, shard: Shard) -> Root: +def get_shard_head(store: Store, shard_store: ShardStore) -> Root: # Get filtered block tree that only includes viable branches - blocks = get_filtered_shard_block_tree(store, shard) + blocks = get_filtered_shard_block_tree(store, shard_store) # Execute the LMD-GHOST fork choice head_beacon_root = get_head(store) - head_shard_root = store.block_states[head_beacon_root].shard_states[shard].latest_block_root + head_shard_root = store.block_states[head_beacon_root].shard_states[shard_store.shard].latest_block_root while True: children = [ root for root in blocks.keys() @@ -123,16 +86,18 @@ def get_shard_head(store: Store, shard: Shard) -> Root: if len(children) == 0: return head_shard_root # Sort by latest attesting balance with ties broken lexicographically - head_shard_root = max(children, key=lambda root: (get_shard_latest_attesting_balance(store, shard, root), root)) + head_shard_root = max( + children, key=lambda root: (get_shard_latest_attesting_balance(store, shard_store, root), root) + ) ``` #### `get_shard_ancestor` ```python -def get_shard_ancestor(store: Store, shard: Shard, root: Root, slot: Slot) -> Root: - block = store.shards[shard].blocks[root] +def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot: Slot) -> Root: + block = shard_store.blocks[root] if block.slot > slot: - return get_shard_ancestor(store, shard, block.shard_parent_root, slot) + return get_shard_ancestor(store, shard_store, block.shard_parent_root, slot) elif block.slot == slot: return root else: @@ -143,8 +108,10 @@ def get_shard_ancestor(store: Store, shard: Shard, root: Root, slot: Slot) -> Ro #### `filter_shard_block_tree` ```python -def filter_shard_block_tree(store: Store, shard: Shard, block_root: Root, blocks: Dict[Root, ShardBlock]) -> bool: - shard_store = store.shards[shard] +def filter_shard_block_tree(store: Store, + shard_store: ShardStore, + block_root: Root, + blocks: Dict[Root, ShardBlock]) -> bool: block = shard_store.blocks[block_root] children = [ root for root in shard_store.blocks.keys() @@ -152,7 +119,7 @@ def filter_shard_block_tree(store: Store, shard: Shard, block_root: Root, blocks ] if any(children): - filter_block_tree_result = [filter_shard_block_tree(store, shard, child, blocks) for child in children] + filter_block_tree_result = [filter_shard_block_tree(store, shard_store, child, blocks) for child in children] if any(filter_block_tree_result): blocks[block_root] = block return True @@ -164,11 +131,12 @@ def filter_shard_block_tree(store: Store, shard: Shard, block_root: Root, blocks #### `get_filtered_block_tree` ```python -def get_filtered_shard_block_tree(store: Store, shard: Shard) -> Dict[Root, ShardBlock]: +def get_filtered_shard_block_tree(store: Store, shard_store: ShardStore) -> Dict[Root, ShardBlock]: + shard = shard_store.shard base_beacon_block_root = get_head(store) base_shard_block_root = store.block_states[base_beacon_block_root].shard_states[shard].latest_block_root blocks: Dict[Root, ShardBlock] = {} - filter_shard_block_tree(store, shard, base_shard_block_root, blocks) + filter_shard_block_tree(store, shard_store, base_shard_block_root, blocks) return blocks ``` @@ -177,10 +145,9 @@ def get_filtered_shard_block_tree(store: Store, shard: Shard) -> Dict[Root, Shar #### `on_shard_block` ```python -def on_shard_block(store: Store, shard: Shard, signed_shard_block: SignedShardBlock) -> None: +def on_shard_block(store: Store, shard_store: ShardStore, signed_shard_block: SignedShardBlock) -> None: shard_block = signed_shard_block.message - shard_store = store.shards[shard] - + shard = shard_store.shard # 1. Check shard parent exists assert shard_block.shard_parent_root in shard_store.block_states pre_shard_state = shard_store.block_states[shard_block.shard_parent_root] diff --git a/tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_head.py b/tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_head.py index 220c510e7..f4b883f06 100644 --- a/tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_head.py +++ b/tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_head.py @@ -11,20 +11,21 @@ from eth2spec.test.helpers.state import next_slot, state_transition_and_sign_blo from eth2spec.test.helpers.block import build_empty_block -def run_on_shard_block(spec, store, shard, signed_block, valid=True): +def run_on_shard_block(spec, store, shard_store, signed_block, valid=True): if not valid: try: - spec.on_shard_block(store, shard, signed_block) + spec.on_shard_block(store, shard_store, signed_block) except AssertionError: return else: assert False - spec.on_shard_block(store, shard, signed_block) - assert store.shards[shard].blocks[hash_tree_root(signed_block.message)] == signed_block.message + spec.on_shard_block(store, shard_store, signed_block) + assert shard_store.blocks[hash_tree_root(signed_block.message)] == signed_block.message -def run_apply_shard_and_beacon(spec, state, store, shard, committee_index): +def run_apply_shard_and_beacon(spec, state, store, shard_store, committee_index): + shard = shard_store.shard store.time = store.time + spec.SECONDS_PER_SLOT * spec.SLOTS_PER_EPOCH # Create SignedShardBlock @@ -57,11 +58,11 @@ def run_apply_shard_and_beacon(spec, state, store, shard, committee_index): beacon_block.body.shard_transitions = shard_transitions signed_beacon_block = state_transition_and_sign_block(spec, state, beacon_block) - run_on_shard_block(spec, store, shard, shard_block) + run_on_shard_block(spec, store, shard_store, shard_block) add_block_to_store(spec, store, signed_beacon_block) assert spec.get_head(store) == beacon_block.hash_tree_root() - assert spec.get_shard_head(store, shard) == shard_block.message.hash_tree_root() + assert spec.get_shard_head(store, shard_store) == shard_block.message.hash_tree_root() @with_all_phases_except([PHASE0]) @@ -78,6 +79,7 @@ def test_basic(spec, state): committee_index = spec.CommitteeIndex(0) shard = spec.compute_shard_from_committee_index(state, committee_index, state.slot) + shard_store = spec.get_forkchoice_shard_store(state, shard) - run_apply_shard_and_beacon(spec, state, store, shard, committee_index) - run_apply_shard_and_beacon(spec, state, store, shard, committee_index) + run_apply_shard_and_beacon(spec, state, store, shard_store, committee_index) + run_apply_shard_and_beacon(spec, state, store, shard_store, committee_index)