From 2dc041807af913f16e54cf44a46fef07f80dbe2e Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Fri, 29 May 2020 23:50:18 +0800 Subject: [PATCH] Implement `get_start_shard` --- setup.py | 15 ++++- specs/phase1/beacon-chain.md | 61 ++++++++++++++++++- specs/phase1/phase1-fork.md | 1 + .../test_process_crosslink.py | 2 +- .../test/phase_1/sanity/test_blocks.py | 4 +- 5 files changed, 74 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 4f3f2b872..d0f063b33 100644 --- a/setup.py +++ b/setup.py @@ -140,7 +140,7 @@ SUNDRY_CONSTANTS_FUNCTIONS = ''' def ceillog2(x: uint64) -> int: return (x - 1).bit_length() ''' -SUNDRY_FUNCTIONS = ''' +PHASE0_SUNDRY_FUNCTIONS = ''' # Monkey patch hash cache _hash = hash hash_cache: Dict[bytes, Bytes32] = {} @@ -220,6 +220,13 @@ get_attesting_indices = cache_this( _get_attesting_indices, lru_size=SLOTS_PER_EPOCH * MAX_COMMITTEES_PER_SLOT * 3)''' +PHASE1_SUNDRY_FUNCTIONS = ''' +_get_start_shard = get_start_shard +get_start_shard = cache_this( + lambda state, slot: (state.validators.hash_tree_root(), slot), + _get_start_shard, lru_size=SLOTS_PER_EPOCH * 3)''' + + def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str: """ Given all the objects that constitute a spec, combine them into a single pyfile. @@ -250,9 +257,11 @@ def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str: + '\n\n' + CONFIG_LOADER + '\n\n' + ssz_objects_instantiation_spec + '\n\n' + functions_spec - + '\n' + SUNDRY_FUNCTIONS - + '\n' + + '\n' + PHASE0_SUNDRY_FUNCTIONS ) + if fork == 'phase1': + spec += '\n' + PHASE1_SUNDRY_FUNCTIONS + spec += '\n' return spec diff --git a/specs/phase1/beacon-chain.md b/specs/phase1/beacon-chain.md index 39a73c0aa..124442b59 100644 --- a/specs/phase1/beacon-chain.md +++ b/specs/phase1/beacon-chain.md @@ -47,6 +47,7 @@ - [`get_light_client_committee`](#get_light_client_committee) - [`get_shard_proposer_index`](#get_shard_proposer_index) - [`get_indexed_attestation`](#get_indexed_attestation) + - [`get_committee_count_delta`](#get_committee_count_delta) - [`get_start_shard`](#get_start_shard) - [`get_shard`](#get_shard) - [`get_latest_slot_for_shard`](#get_latest_slot_for_shard) @@ -69,6 +70,7 @@ - [Shard transition false positives](#shard-transition-false-positives) - [Light client processing](#light-client-processing) - [Epoch transition](#epoch-transition) + - [Phase 1 final updates](#phase-1-final-updates) - [Custody game updates](#custody-game-updates) - [Online-tracking](#online-tracking) - [Light client committee updates](#light-client-committee-updates) @@ -280,6 +282,7 @@ class BeaconState(Container): current_justified_checkpoint: Checkpoint finalized_checkpoint: Checkpoint # Phase 1 + current_epoch_start_shard: Shard shard_states: List[ShardState, MAX_SHARDS] online_countdown: List[OnlineEpochs, VALIDATOR_REGISTRY_LIMIT] # not a raw byte array, considered its large size. current_light_committee: CompactCommittee @@ -530,18 +533,53 @@ def get_indexed_attestation(beacon_state: BeaconState, attestation: Attestation) ) ``` +#### `get_committee_count_delta` + +```python +def get_committee_count_delta(state: BeaconState, start_slot: Slot, stop_slot: Slot) -> uint64: + """ + Return the sum of committee counts between ``[start_slot, stop_slot)``. + """ + committee_sum = 0 + for slot in range(start_slot, stop_slot): + count = get_committee_count_at_slot(state, Slot(slot)) + committee_sum += count + return committee_sum +``` + #### `get_start_shard` ```python def get_start_shard(state: BeaconState, slot: Slot) -> Shard: - # TODO: implement start shard logic - return Shard(0) + """ + Return the start shard at ``slot``. + """ + current_epoch_start_slot = compute_start_slot_at_epoch(get_current_epoch(state)) + active_shard_count = get_active_shard_count(state) + if current_epoch_start_slot == slot: + return state.current_epoch_start_shard + elif current_epoch_start_slot > slot: + # Current epoch or the next epoch lookahead + shard_delta = get_committee_count_delta(state, start_slot=current_epoch_start_slot, stop_slot=slot) + return Shard((state.current_epoch_start_shard + shard_delta) % active_shard_count) + else: + # Previous epoch + shard_delta = get_committee_count_delta(state, start_slot=slot, stop_slot=current_epoch_start_slot) + max_committees_per_epoch = MAX_COMMITTEES_PER_SLOT * SLOTS_PER_EPOCH + return Shard( + # Ensure positive + (state.current_epoch_start_shard + max_committees_per_epoch * active_shard_count - shard_delta) + % active_shard_count + ) ``` #### `get_shard` ```python def get_shard(state: BeaconState, attestation: Attestation) -> Shard: + """ + Return the shard that the given attestation is attesting. + """ return compute_shard_from_committee_index(state, attestation.data.index, attestation.data.slot) ``` @@ -549,6 +587,9 @@ def get_shard(state: BeaconState, attestation: Attestation) -> Shard: ```python def get_latest_slot_for_shard(state: BeaconState, shard: Shard) -> Slot: + """ + Return the latest slot number of the given shard. + """ return state.shard_states[shard].slot ``` @@ -556,7 +597,11 @@ def get_latest_slot_for_shard(state: BeaconState, shard: Shard) -> Slot: ```python def get_offset_slots(state: BeaconState, shard: Shard) -> Sequence[Slot]: - return compute_offset_slots(state.shard_states[shard].slot, state.slot) + """ + Return the offset slots of the given shard. + The offset slot are after the latest slot and before current slot. + """ + return compute_offset_slots(get_latest_slot_for_shard(state, shard), state.slot) ``` ### Predicates @@ -993,9 +1038,19 @@ def process_epoch(state: BeaconState) -> None: process_reveal_deadlines(state) process_slashings(state) process_final_updates(state) + process_phase_1_final_updates(state) +``` + +#### Phase 1 final updates + +```python +def process_phase_1_final_updates(state: BeaconState) -> None: process_custody_final_updates(state) process_online_tracking(state) process_light_client_committee_updates(state) + + # Update current_epoch_start_shard + state.current_epoch_start_shard = get_start_shard(state, state.slot) ``` #### Custody game updates diff --git a/specs/phase1/phase1-fork.md b/specs/phase1/phase1-fork.md index cc7d8f33e..d362ed633 100644 --- a/specs/phase1/phase1-fork.md +++ b/specs/phase1/phase1-fork.md @@ -99,6 +99,7 @@ def upgrade_to_phase1(pre: phase0.BeaconState) -> BeaconState: current_justified_checkpoint=pre.current_justified_checkpoint, finalized_checkpoint=pre.finalized_checkpoint, # Phase 1 + current_epoch_start_shard=Shard(0), shard_states=List[ShardState, MAX_SHARDS]( ShardState( slot=pre.slot, diff --git a/tests/core/pyspec/eth2spec/test/phase_1/block_processing/test_process_crosslink.py b/tests/core/pyspec/eth2spec/test/phase_1/block_processing/test_process_crosslink.py index 1f066b344..af0ff1a90 100644 --- a/tests/core/pyspec/eth2spec/test/phase_1/block_processing/test_process_crosslink.py +++ b/tests/core/pyspec/eth2spec/test/phase_1/block_processing/test_process_crosslink.py @@ -18,7 +18,7 @@ def run_basic_crosslink_tests(spec, state, target_len_offset_slot, valid=True): # At the beginning, let `x = state.slot`, `state.shard_states[shard].slot == x - 1` slot_x = state.slot committee_index = spec.CommitteeIndex(0) - shard = spec.compute_shard_from_committee_index(state, committee_index, state.slot) + shard = spec.compute_shard_from_committee_index(state, committee_index, state.slot + target_len_offset_slot) assert state.shard_states[shard].slot == slot_x - 1 # Create SignedShardBlock diff --git a/tests/core/pyspec/eth2spec/test/phase_1/sanity/test_blocks.py b/tests/core/pyspec/eth2spec/test/phase_1/sanity/test_blocks.py index 60af35d45..1499d34cd 100644 --- a/tests/core/pyspec/eth2spec/test/phase_1/sanity/test_blocks.py +++ b/tests/core/pyspec/eth2spec/test/phase_1/sanity/test_blocks.py @@ -67,7 +67,7 @@ def test_process_beacon_block_with_normal_shard_transition(spec, state): target_len_offset_slot = 1 committee_index = spec.CommitteeIndex(0) - shard = spec.compute_shard_from_committee_index(state, committee_index, state.slot) + shard = spec.compute_shard_from_committee_index(state, committee_index, state.slot + target_len_offset_slot) assert state.shard_states[shard].slot == state.slot - 1 pre_gasprice = state.shard_states[shard].gasprice @@ -93,7 +93,7 @@ def test_process_beacon_block_with_empty_proposal_transition(spec, state): target_len_offset_slot = 1 committee_index = spec.CommitteeIndex(0) - shard = spec.compute_shard_from_committee_index(state, committee_index, state.slot) + shard = spec.compute_shard_from_committee_index(state, committee_index, state.slot + target_len_offset_slot) assert state.shard_states[shard].slot == state.slot - 1 # No new shard block