Merge pull request #1970 from ethereum/shard-fork-choice-fix

Fix shard fork choice
This commit is contained in:
Hsiao-Wei Wang 2020-07-30 01:27:11 +08:00 committed by GitHub
commit a609320ad4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 415 additions and 187 deletions

View File

@ -1,3 +1,4 @@
from enum import Enum, auto
from setuptools import setup, find_packages, Command from setuptools import setup, find_packages, Command
from setuptools.command.build_py import build_py from setuptools.command.build_py import build_py
from distutils import dir_util from distutils import dir_util
@ -14,6 +15,13 @@ class SpecObject(NamedTuple):
custom_types: Dict[str, str] custom_types: Dict[str, str]
constants: Dict[str, str] constants: Dict[str, str]
ssz_objects: Dict[str, str] ssz_objects: Dict[str, str]
dataclasses: Dict[str, str]
class CodeBlockType(Enum):
SSZ = auto()
DATACLASS = auto()
FUNCTION = auto()
def get_spec(file_name: str) -> SpecObject: def get_spec(file_name: str) -> SpecObject:
@ -28,8 +36,9 @@ def get_spec(file_name: str) -> SpecObject:
functions: Dict[str, str] = {} functions: Dict[str, str] = {}
constants: Dict[str, str] = {} constants: Dict[str, str] = {}
ssz_objects: Dict[str, str] = {} ssz_objects: Dict[str, str] = {}
dataclasses: Dict[str, str] = {}
function_matcher = re.compile(FUNCTION_REGEX) function_matcher = re.compile(FUNCTION_REGEX)
is_ssz = False block_type = CodeBlockType.FUNCTION
custom_types: Dict[str, str] = {} custom_types: Dict[str, str] = {}
for linenum, line in enumerate(open(file_name).readlines()): for linenum, line in enumerate(open(file_name).readlines()):
line = line.rstrip() line = line.rstrip()
@ -43,20 +52,26 @@ def get_spec(file_name: str) -> SpecObject:
else: else:
# Handle function definitions & ssz_objects # Handle function definitions & ssz_objects
if pulling_from is not None: if pulling_from is not None:
# SSZ Object
if len(line) > 18 and line[:6] == 'class ' and line[-12:] == '(Container):': if len(line) > 18 and line[:6] == 'class ' and line[-12:] == '(Container):':
name = line[6:-12] name = line[6:-12]
# Check consistency with markdown header # Check consistency with markdown header
assert name == current_name assert name == current_name
is_ssz = True block_type = CodeBlockType.SSZ
# function definition elif line[:10] == '@dataclass':
block_type = CodeBlockType.DATACLASS
elif function_matcher.match(line) is not None: elif function_matcher.match(line) is not None:
current_name = function_matcher.match(line).group(0) current_name = function_matcher.match(line).group(0)
is_ssz = False block_type = CodeBlockType.FUNCTION
if is_ssz:
if block_type == CodeBlockType.SSZ:
ssz_objects[current_name] = ssz_objects.get(current_name, '') + line + '\n' ssz_objects[current_name] = ssz_objects.get(current_name, '') + line + '\n'
else: elif block_type == CodeBlockType.DATACLASS:
dataclasses[current_name] = dataclasses.get(current_name, '') + line + '\n'
elif block_type == CodeBlockType.FUNCTION:
functions[current_name] = functions.get(current_name, '') + line + '\n' functions[current_name] = functions.get(current_name, '') + line + '\n'
else:
pass
# Handle constant and custom types table entries # Handle constant and custom types table entries
elif pulling_from is None and len(line) > 0 and line[0] == '|': elif pulling_from is None and len(line) > 0 and line[0] == '|':
row = line[1:].split('|') row = line[1:].split('|')
@ -75,7 +90,7 @@ def get_spec(file_name: str) -> SpecObject:
constants[row[0]] = row[1].replace('**TBD**', '2**32') constants[row[0]] = row[1].replace('**TBD**', '2**32')
elif row[1].startswith('uint') or row[1].startswith('Bytes'): elif row[1].startswith('uint') or row[1].startswith('Bytes'):
custom_types[row[0]] = row[1] custom_types[row[0]] = row[1]
return SpecObject(functions, custom_types, constants, ssz_objects) return SpecObject(functions, custom_types, constants, ssz_objects, dataclasses)
CONFIG_LOADER = ''' CONFIG_LOADER = '''
@ -220,7 +235,7 @@ get_start_shard = cache_this(
_get_start_shard, lru_size=SLOTS_PER_EPOCH * 3)''' _get_start_shard, lru_size=SLOTS_PER_EPOCH * 3)'''
def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str: def objects_to_spec(spec_object: SpecObject, imports: str, fork: str, ordered_class_objects: Dict[str, str]) -> str:
""" """
Given all the objects that constitute a spec, combine them into a single pyfile. Given all the objects that constitute a spec, combine them into a single pyfile.
""" """
@ -240,7 +255,7 @@ def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str:
if k == "BLS12_381_Q": if k == "BLS12_381_Q":
spec_object.constants[k] += " # noqa: E501" spec_object.constants[k] += " # noqa: E501"
constants_spec = '\n'.join(map(lambda x: '%s = %s' % (x, spec_object.constants[x]), spec_object.constants)) constants_spec = '\n'.join(map(lambda x: '%s = %s' % (x, spec_object.constants[x]), spec_object.constants))
ssz_objects_instantiation_spec = '\n\n'.join(spec_object.ssz_objects.values()) ordered_class_objects_spec = '\n\n'.join(ordered_class_objects.values())
spec = ( spec = (
imports imports
+ '\n\n' + f"fork = \'{fork}\'\n" + '\n\n' + f"fork = \'{fork}\'\n"
@ -248,7 +263,7 @@ def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str:
+ '\n' + SUNDRY_CONSTANTS_FUNCTIONS + '\n' + SUNDRY_CONSTANTS_FUNCTIONS
+ '\n\n' + constants_spec + '\n\n' + constants_spec
+ '\n\n' + CONFIG_LOADER + '\n\n' + CONFIG_LOADER
+ '\n\n' + ssz_objects_instantiation_spec + '\n\n' + ordered_class_objects_spec
+ '\n\n' + functions_spec + '\n\n' + functions_spec
+ '\n' + PHASE0_SUNDRY_FUNCTIONS + '\n' + PHASE0_SUNDRY_FUNCTIONS
) )
@ -274,11 +289,12 @@ ignored_dependencies = [
'bit', 'boolean', 'Vector', 'List', 'Container', 'BLSPubkey', 'BLSSignature', 'bit', 'boolean', 'Vector', 'List', 'Container', 'BLSPubkey', 'BLSSignature',
'Bytes1', 'Bytes4', 'Bytes32', 'Bytes48', 'Bytes96', 'Bitlist', 'Bitvector', 'Bytes1', 'Bytes4', 'Bytes32', 'Bytes48', 'Bytes96', 'Bitlist', 'Bitvector',
'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'bytes', 'byte', 'ByteList', 'ByteVector' 'bytes', 'byte', 'ByteList', 'ByteVector',
'Dict', 'dict', 'field',
] ]
def dependency_order_ssz_objects(objects: Dict[str, str], custom_types: Dict[str, str]) -> None: def dependency_order_class_objects(objects: Dict[str, str], custom_types: Dict[str, str]) -> None:
""" """
Determines which SSZ Object is dependent on which other and orders them appropriately Determines which SSZ Object is dependent on which other and orders them appropriately
""" """
@ -315,13 +331,14 @@ def combine_spec_objects(spec0: SpecObject, spec1: SpecObject) -> SpecObject:
""" """
Takes in two spec variants (as tuples of their objects) and combines them using the appropriate combiner function. Takes in two spec variants (as tuples of their objects) and combines them using the appropriate combiner function.
""" """
functions0, custom_types0, constants0, ssz_objects0 = spec0 functions0, custom_types0, constants0, ssz_objects0, dataclasses0 = spec0
functions1, custom_types1, constants1, ssz_objects1 = spec1 functions1, custom_types1, constants1, ssz_objects1, dataclasses1 = spec1
functions = combine_functions(functions0, functions1) functions = combine_functions(functions0, functions1)
custom_types = combine_constants(custom_types0, custom_types1) custom_types = combine_constants(custom_types0, custom_types1)
constants = combine_constants(constants0, constants1) constants = combine_constants(constants0, constants1)
ssz_objects = combine_ssz_objects(ssz_objects0, ssz_objects1, custom_types) ssz_objects = combine_ssz_objects(ssz_objects0, ssz_objects1, custom_types)
return SpecObject(functions, custom_types, constants, ssz_objects) dataclasses = combine_functions(dataclasses0, dataclasses1)
return SpecObject(functions, custom_types, constants, ssz_objects, dataclasses)
fork_imports = { fork_imports = {
@ -337,9 +354,10 @@ def build_spec(fork: str, source_files: List[str]) -> str:
for value in all_specs[1:]: for value in all_specs[1:]:
spec_object = combine_spec_objects(spec_object, value) spec_object = combine_spec_objects(spec_object, value)
dependency_order_ssz_objects(spec_object.ssz_objects, spec_object.custom_types) class_objects = {**spec_object.ssz_objects, **spec_object.dataclasses}
dependency_order_class_objects(class_objects, spec_object.custom_types)
return objects_to_spec(spec_object, fork_imports[fork], fork) return objects_to_spec(spec_object, fork_imports[fork], fork, class_objects)
class PySpecCommand(Command): class PySpecCommand(Command):

View File

@ -9,8 +9,13 @@
- [Introduction](#introduction) - [Introduction](#introduction)
- [Helpers](#helpers) - [Updated data structures](#updated-data-structures)
- [Extended `LatestMessage`](#extended-latestmessage) - [Extended `Store`](#extended-store)
- [New data structures](#new-data-structures)
- [`ShardLatestMessage`](#shardlatestmessage)
- [`ShardStore`](#shardstore)
- [Updated helpers](#updated-helpers)
- [Updated `get_forkchoice_store`](#updated-get_forkchoice_store)
- [Updated `update_latest_messages`](#updated-update_latest_messages) - [Updated `update_latest_messages`](#updated-update_latest_messages)
<!-- END doctoc generated TOC please keep comment here to allow auto update --> <!-- END doctoc generated TOC please keep comment here to allow auto update -->
@ -20,17 +25,74 @@
This document is the beacon chain fork choice spec for part of Ethereum 2.0 Phase 1. This document is the beacon chain fork choice spec for part of Ethereum 2.0 Phase 1.
### Helpers ### Updated data structures
#### Extended `LatestMessage` #### Extended `Store`
```python
@dataclass
class Store(object):
time: uint64
genesis_time: uint64
justified_checkpoint: Checkpoint
finalized_checkpoint: Checkpoint
best_justified_checkpoint: Checkpoint
blocks: Dict[Root, BeaconBlock] = field(default_factory=dict)
block_states: Dict[Root, BeaconState] = field(default_factory=dict)
checkpoint_states: Dict[Checkpoint, BeaconState] = field(default_factory=dict)
latest_messages: Dict[ValidatorIndex, LatestMessage] = field(default_factory=dict)
shard_stores: Dict[Shard, ShardStore] = field(default_factory=dict)
```
### New data structures
#### `ShardLatestMessage`
```python ```python
@dataclass(eq=True, frozen=True) @dataclass(eq=True, frozen=True)
class LatestMessage(object): class ShardLatestMessage(object):
epoch: Epoch epoch: Epoch
root: Root root: Root
```
#### `ShardStore`
```python
@dataclass
class ShardStore:
shard: Shard shard: Shard
shard_root: Root signed_blocks: Dict[Root, SignedShardBlock] = field(default_factory=dict)
block_states: Dict[Root, ShardState] = field(default_factory=dict)
latest_messages: Dict[ValidatorIndex, ShardLatestMessage] = field(default_factory=dict)
```
### Updated helpers
#### Updated `get_forkchoice_store`
```python
def get_forkchoice_store(anchor_state: BeaconState) -> Store:
anchor_block_header = anchor_state.latest_block_header.copy()
if anchor_block_header.state_root == Bytes32():
anchor_block_header.state_root = hash_tree_root(anchor_state)
anchor_root = hash_tree_root(anchor_block_header)
anchor_epoch = get_current_epoch(anchor_state)
justified_checkpoint = Checkpoint(epoch=anchor_epoch, root=anchor_root)
finalized_checkpoint = Checkpoint(epoch=anchor_epoch, root=anchor_root)
return Store(
time=anchor_state.genesis_time + SECONDS_PER_SLOT * anchor_state.slot,
genesis_time=anchor_state.genesis_time,
justified_checkpoint=justified_checkpoint,
finalized_checkpoint=finalized_checkpoint,
best_justified_checkpoint=justified_checkpoint,
blocks={anchor_root: anchor_block_header},
block_states={anchor_root: anchor_state.copy()},
checkpoint_states={justified_checkpoint: anchor_state.copy()},
shard_stores={
Shard(shard): get_forkchoice_shard_store(anchor_state, Shard(shard))
for shard in range(get_active_shard_count(anchor_state))
}
)
``` ```
#### Updated `update_latest_messages` #### Updated `update_latest_messages`
@ -43,7 +105,7 @@ def update_latest_messages(store: Store, attesting_indices: Sequence[ValidatorIn
shard = attestation.data.shard shard = attestation.data.shard
for i in attesting_indices: for i in attesting_indices:
if i not in store.latest_messages or target.epoch > store.latest_messages[i].epoch: if i not in store.latest_messages or target.epoch > store.latest_messages[i].epoch:
store.latest_messages[i] = LatestMessage( store.latest_messages[i] = LatestMessage(epoch=target.epoch, root=beacon_block_root)
epoch=target.epoch, root=beacon_block_root, shard=shard, shard_root=attestation.data.shard_head_root shard_latest_message = ShardLatestMessage(epoch=target.epoch, root=attestation.data.shard_head_root)
) store.shard_stores[shard].latest_messages[i] = shard_latest_message
``` ```

View File

@ -11,7 +11,6 @@
- [Introduction](#introduction) - [Introduction](#introduction)
- [Fork choice](#fork-choice) - [Fork choice](#fork-choice)
- [Helpers](#helpers) - [Helpers](#helpers)
- [`ShardStore`](#shardstore)
- [`get_forkchoice_shard_store`](#get_forkchoice_shard_store) - [`get_forkchoice_shard_store`](#get_forkchoice_shard_store)
- [`get_shard_latest_attesting_balance`](#get_shard_latest_attesting_balance) - [`get_shard_latest_attesting_balance`](#get_shard_latest_attesting_balance)
- [`get_shard_head`](#get_shard_head) - [`get_shard_head`](#get_shard_head)
@ -30,16 +29,6 @@ This document is the shard chain fork choice spec for part of Ethereum 2.0 Phase
### Helpers ### Helpers
#### `ShardStore`
```python
@dataclass
class ShardStore:
shard: Shard
signed_blocks: Dict[Root, SignedShardBlock] = field(default_factory=dict)
block_states: Dict[Root, ShardState] = field(default_factory=dict)
```
#### `get_forkchoice_shard_store` #### `get_forkchoice_shard_store`
```python ```python
@ -58,18 +47,21 @@ def get_forkchoice_shard_store(anchor_state: BeaconState, shard: Shard) -> Shard
#### `get_shard_latest_attesting_balance` #### `get_shard_latest_attesting_balance`
```python ```python
def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, root: Root) -> Gwei: def get_shard_latest_attesting_balance(store: Store, shard: Shard, root: Root) -> Gwei:
shard_store = store.shard_stores[shard]
state = store.checkpoint_states[store.justified_checkpoint] state = store.checkpoint_states[store.justified_checkpoint]
active_indices = get_active_validator_indices(state, get_current_epoch(state)) active_indices = get_active_validator_indices(state, get_current_epoch(state))
return Gwei(sum( return Gwei(sum(
state.validators[i].effective_balance for i in active_indices state.validators[i].effective_balance for i in active_indices
if ( if (
i in store.latest_messages i in shard_store.latest_messages
# TODO: check the latest message logic: currently, validator's previous vote of another shard # TODO: check the latest message logic: currently, validator's previous vote of another shard
# would be ignored once their newer vote is accepted. Check if it makes sense. # would be ignored once their newer vote is accepted. Check if it makes sense.
and store.latest_messages[i].shard == shard_store.shard
and get_shard_ancestor( and get_shard_ancestor(
store, shard_store, store.latest_messages[i].shard_root, shard_store.signed_blocks[root].message.slot store,
shard,
shard_store.latest_messages[i].root,
shard_store.signed_blocks[root].message.slot,
) == root ) == root
) )
)) ))
@ -78,10 +70,14 @@ def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, ro
#### `get_shard_head` #### `get_shard_head`
```python ```python
def get_shard_head(store: Store, shard_store: ShardStore) -> Root: def get_shard_head(store: Store, shard: Shard) -> Root:
# Execute the LMD-GHOST fork choice # Execute the LMD-GHOST fork choice
"""
Execute the LMD-GHOST fork choice.
"""
shard_store = store.shard_stores[shard]
beacon_head_root = get_head(store) beacon_head_root = get_head(store)
shard_head_state = store.block_states[beacon_head_root].shard_states[shard_store.shard] shard_head_state = store.block_states[beacon_head_root].shard_states[shard]
shard_head_root = shard_head_state.latest_block_root shard_head_root = shard_head_state.latest_block_root
shard_blocks = { shard_blocks = {
root: signed_shard_block.message for root, signed_shard_block in shard_store.signed_blocks.items() root: signed_shard_block.message for root, signed_shard_block in shard_store.signed_blocks.items()
@ -97,17 +93,18 @@ def get_shard_head(store: Store, shard_store: ShardStore) -> Root:
return shard_head_root return shard_head_root
# Sort by latest attesting balance with ties broken lexicographically # Sort by latest attesting balance with ties broken lexicographically
shard_head_root = max( shard_head_root = max(
children, key=lambda root: (get_shard_latest_attesting_balance(store, shard_store, root), root) children, key=lambda root: (get_shard_latest_attesting_balance(store, shard, root), root)
) )
``` ```
#### `get_shard_ancestor` #### `get_shard_ancestor`
```python ```python
def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot: Slot) -> Root: def get_shard_ancestor(store: Store, shard: Shard, root: Root, slot: Slot) -> Root:
shard_store = store.shard_stores[shard]
block = shard_store.signed_blocks[root].message block = shard_store.signed_blocks[root].message
if block.slot > slot: if block.slot > slot:
return get_shard_ancestor(store, shard_store, block.shard_parent_root, slot) return get_shard_ancestor(store, shard, block.shard_parent_root, slot)
elif block.slot == slot: elif block.slot == slot:
return root return root
else: else:
@ -118,17 +115,17 @@ def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot:
#### `get_pending_shard_blocks` #### `get_pending_shard_blocks`
```python ```python
def get_pending_shard_blocks(store: Store, shard_store: ShardStore) -> Sequence[SignedShardBlock]: def get_pending_shard_blocks(store: Store, shard: Shard) -> Sequence[SignedShardBlock]:
""" """
Return the canonical shard block branch that has not yet been crosslinked. Return the canonical shard block branch that has not yet been crosslinked.
""" """
shard = shard_store.shard shard_store = store.shard_stores[shard]
beacon_head_root = get_head(store) beacon_head_root = get_head(store)
beacon_head_state = store.block_states[beacon_head_root] beacon_head_state = store.block_states[beacon_head_root]
latest_shard_block_root = beacon_head_state.shard_states[shard].latest_block_root latest_shard_block_root = beacon_head_state.shard_states[shard].latest_block_root
shard_head_root = get_shard_head(store, shard_store) shard_head_root = get_shard_head(store, shard)
root = shard_head_root root = shard_head_root
signed_shard_blocks = [] signed_shard_blocks = []
while root != latest_shard_block_root: while root != latest_shard_block_root:
@ -145,13 +142,10 @@ def get_pending_shard_blocks(store: Store, shard_store: ShardStore) -> Sequence[
#### `on_shard_block` #### `on_shard_block`
```python ```python
def on_shard_block(store: Store, shard_store: ShardStore, signed_shard_block: SignedShardBlock) -> None: def on_shard_block(store: Store, signed_shard_block: SignedShardBlock) -> None:
shard_block = signed_shard_block.message shard_block = signed_shard_block.message
shard = shard_store.shard shard = shard_block.shard
shard_store = store.shard_stores[shard]
# Check shard
# TODO: check it in networking spec
assert shard_block.shard == shard
# Check shard parent exists # Check shard parent exists
assert shard_block.shard_parent_root in shard_store.block_states assert shard_block.shard_parent_root in shard_store.block_states

View File

@ -26,9 +26,12 @@ def run_on_attestation(spec, state, store, attestation, valid=True):
latest_message = spec.LatestMessage( latest_message = spec.LatestMessage(
epoch=attestation.data.target.epoch, epoch=attestation.data.target.epoch,
root=attestation.data.beacon_block_root, root=attestation.data.beacon_block_root,
shard=attestation.data.shard,
shard_root=attestation.data.shard_head_root,
) )
shard_latest_message = spec.ShardLatestMessage(
epoch=attestation.data.target.epoch,
root=attestation.data.shard_head_root,
)
assert store.shard_stores[attestation.data.shard].latest_messages[sample_index] == shard_latest_message
assert ( assert (
store.latest_messages[sample_index] == latest_message store.latest_messages[sample_index] == latest_message

View File

@ -0,0 +1,280 @@
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.test.context import PHASE0, spec_state_test, with_all_phases_except, never_bls
from eth2spec.test.helpers.attestations import get_valid_on_time_attestation
from eth2spec.test.helpers.shard_block import (
build_shard_block,
get_shard_transitions,
get_committee_index_of_shard,
)
from eth2spec.test.helpers.fork_choice import add_block_to_store, get_anchor_root
from eth2spec.test.helpers.shard_transitions import is_full_crosslink
from eth2spec.test.helpers.state import state_transition_and_sign_block
from eth2spec.test.helpers.block import build_empty_block
def run_on_shard_block(spec, store, signed_block, valid=True):
shard = signed_block.message.shard
if not valid:
try:
spec.on_shard_block(store, signed_block)
except AssertionError:
return
else:
assert False
spec.on_shard_block(store, signed_block)
shard_store = store.shard_stores[shard]
assert shard_store.signed_blocks[hash_tree_root(signed_block.message)] == signed_block
def initialize_store(spec, state, shards):
store = spec.get_forkchoice_store(state)
anchor_root = get_anchor_root(spec, state)
assert spec.get_head(store) == anchor_root
for shard in shards:
shard_head_root = spec.get_shard_head(store, shard)
assert shard_head_root == state.shard_states[shard].latest_block_root
shard_store = store.shard_stores[shard]
assert shard_store.block_states[shard_head_root].slot == 0
assert shard_store.block_states[shard_head_root] == state.shard_states[shard]
return store
def create_and_apply_shard_block(spec, store, shard, beacon_parent_state, shard_blocks_buffer):
body = b'\x56' * 4
shard_head_root = spec.get_shard_head(store, shard)
shard_store = store.shard_stores[shard]
shard_parent_state = shard_store.block_states[shard_head_root]
assert shard_parent_state.slot != beacon_parent_state.slot
shard_block = build_shard_block(
spec, beacon_parent_state, shard,
shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True
)
shard_blocks_buffer.append(shard_block)
run_on_shard_block(spec, store, shard_block)
assert spec.get_shard_head(store, shard) == shard_block.message.hash_tree_root()
def check_pending_shard_blocks(spec, store, shard, shard_blocks_buffer):
pending_shard_blocks = spec.get_pending_shard_blocks(store, shard)
assert pending_shard_blocks == shard_blocks_buffer
def is_in_offset_sets(spec, beacon_head_state, shard):
offset_slots = spec.compute_offset_slots(
beacon_head_state.shard_states[shard].slot, beacon_head_state.slot + 1
)
return beacon_head_state.slot in offset_slots
def create_attestation_for_shard_blocks(spec, beacon_parent_state, shard, committee_index, blocks,
filter_participant_set=None):
shard_transition = spec.get_shard_transition(beacon_parent_state, shard, blocks)
attestation = get_valid_on_time_attestation(
spec,
beacon_parent_state,
index=committee_index,
shard_transition=shard_transition,
signed=True,
)
return attestation
def create_beacon_block_with_shard_transition(
spec, state, store, shard, shard_blocks_buffer, is_checking_pending_shard_blocks=True):
beacon_block = build_empty_block(spec, state, slot=state.slot + 1)
committee_index = get_committee_index_of_shard(spec, state, state.slot, shard)
has_shard_committee = committee_index is not None # has committee of `shard` at this slot
beacon_block = build_empty_block(spec, state, slot=state.slot + 1)
# If next slot has committee of `shard`, add `shard_transtion` to the proposing beacon block
if has_shard_committee and len(shard_blocks_buffer) > 0:
# Sanity check `get_pending_shard_blocks`
# Assert that the pending shard blocks set in the store equal to shard_blocks_buffer
if is_checking_pending_shard_blocks:
check_pending_shard_blocks(spec, store, shard, shard_blocks_buffer)
# Use temporary next state to get ShardTransition of shard block
shard_transitions = get_shard_transitions(spec, state, shard_block_dict={shard: shard_blocks_buffer})
shard_transition = shard_transitions[shard]
attestation = get_valid_on_time_attestation(
spec,
state,
index=committee_index,
shard_transition=shard_transition,
signed=True,
)
assert attestation.data.shard == shard
beacon_block.body.attestations = [attestation]
beacon_block.body.shard_transitions = shard_transitions
# Clear buffer
shard_blocks_buffer.clear()
return beacon_block
def apply_all_attestation_to_store(spec, store, attestations):
for attestation in attestations:
spec.on_attestation(store, attestation)
def apply_beacon_block_to_store(spec, state, store, beacon_block):
signed_beacon_block = state_transition_and_sign_block(spec, state, beacon_block) # transition!
store.time = store.time + spec.SECONDS_PER_SLOT
add_block_to_store(spec, store, signed_beacon_block)
apply_all_attestation_to_store(spec, store, signed_beacon_block.message.body.attestations)
def create_and_apply_beacon_and_shard_blocks(spec, state, store, shard, shard_blocks_buffer,
is_checking_pending_shard_blocks=True):
beacon_block = create_beacon_block_with_shard_transition(
spec, state, store, shard, shard_blocks_buffer,
is_checking_pending_shard_blocks=is_checking_pending_shard_blocks
)
apply_beacon_block_to_store(spec, state, store, beacon_block)
# On shard block at the transitioned `state.slot`
if is_in_offset_sets(spec, state, shard):
# The created shard block would be appended to `shard_blocks_buffer`
create_and_apply_shard_block(spec, store, shard, state, shard_blocks_buffer)
has_shard_committee = get_committee_index_of_shard(spec, state, state.slot, shard) is not None
return has_shard_committee
@with_all_phases_except([PHASE0])
@spec_state_test
@never_bls # Set to never_bls for testing `check_pending_shard_blocks`
def test_basic(spec, state):
spec.PHASE_1_GENESIS_SLOT = 0 # NOTE: mock genesis slot here
state = spec.upgrade_to_phase1(state)
shard = spec.Shard(1)
# Initialization
store = initialize_store(spec, state, [shard])
# For mainnet config, it's possible that only one committee of `shard` per epoch.
# we set this counter to test more rounds.
shard_committee_counter = 2
shard_blocks_buffer = [] # the accumulated shard blocks that haven't been crosslinked yet
while shard_committee_counter > 0:
has_shard_committee = create_and_apply_beacon_and_shard_blocks(spec, state, store, shard, shard_blocks_buffer)
if has_shard_committee:
shard_committee_counter -= 1
def create_simple_fork(spec, state, store, shard):
# Beacon block
beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard, [])
apply_beacon_block_to_store(spec, state, store, beacon_block)
beacon_head_root = spec.get_head(store)
assert beacon_head_root == beacon_block.hash_tree_root()
beacon_parent_state = store.block_states[beacon_head_root]
shard_store = store.shard_stores[shard]
shard_parent_state = shard_store.block_states[spec.get_shard_head(store, shard)]
# Shard block A
body = b'\x56' * 4
forking_block_child = build_shard_block(
spec, beacon_parent_state, shard,
shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True
)
run_on_shard_block(spec, store, forking_block_child)
# Shard block B
body = b'\x78' * 4 # different body
shard_block_b = build_shard_block(
spec, beacon_parent_state, shard,
shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True
)
run_on_shard_block(spec, store, shard_block_b)
# Set forking_block
current_head = spec.get_shard_head(store, shard)
if current_head == forking_block_child.message.hash_tree_root():
head_block = forking_block_child
forking_block = shard_block_b
else:
assert current_head == shard_block_b.message.hash_tree_root()
head_block = shard_block_b
forking_block = forking_block_child
return head_block, forking_block
@with_all_phases_except([PHASE0])
@spec_state_test
def test_shard_simple_fork(spec, state):
if not is_full_crosslink(spec, state):
# skip
return
spec.PHASE_1_GENESIS_SLOT = 0 # NOTE: mock genesis slot here
state = spec.upgrade_to_phase1(state)
shard = spec.Shard(1)
# Initialization
store = initialize_store(spec, state, [shard])
# Create fork
_, forking_block = create_simple_fork(spec, state, store, shard)
# Vote for forking_block
state = store.block_states[spec.get_head(store)].copy()
beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard, [forking_block],
is_checking_pending_shard_blocks=False)
store.time = store.time + spec.SECONDS_PER_SLOT
apply_all_attestation_to_store(spec, store, beacon_block.body.attestations)
# Head block has been changed
assert spec.get_shard_head(store, shard) == forking_block.message.hash_tree_root()
@with_all_phases_except([PHASE0])
@spec_state_test
def test_shard_latest_messages_for_different_shards(spec, state):
if not is_full_crosslink(spec, state):
# skip
return
spec.PHASE_1_GENESIS_SLOT = 0 # NOTE: mock genesis slot here
state = spec.upgrade_to_phase1(state)
shard_0 = spec.Shard(0)
shard_1 = spec.Shard(1)
# Initialization
store = initialize_store(spec, state, [shard_0, shard_1])
# Shard 0 ----------------------------------
# Create fork on shard 0
_, forking_block = create_simple_fork(spec, state, store, shard_0)
# Vote for forking_block on shard 0
state = store.block_states[spec.get_head(store)].copy()
beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard_0, [forking_block],
is_checking_pending_shard_blocks=False)
store.time = store.time + spec.SECONDS_PER_SLOT
apply_all_attestation_to_store(spec, store, beacon_block.body.attestations)
# Head block of shard 0 has been changed due to the shard latest messages
assert spec.get_shard_head(store, shard_0) == forking_block.message.hash_tree_root()
# Shard 1 ----------------------------------
# Run shard 1 after 1~2 epochs
shard_committee_counter = 2
shard_blocks_buffer = [] # the accumulated shard blocks that haven't been crosslinked yet
while shard_committee_counter > 0:
has_shard_committee = create_and_apply_beacon_and_shard_blocks(
spec, state, store, shard_1, shard_blocks_buffer
)
if has_shard_committee:
shard_committee_counter -= 1
# Go back to see shard 0 ----------------------------------
# The head block of shard 0 should be unchanged.
assert spec.get_shard_head(store, shard_0) == forking_block.message.hash_tree_root()

View File

@ -1,129 +0,0 @@
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.test.context import PHASE0, spec_state_test, with_all_phases_except, never_bls
from eth2spec.test.helpers.attestations import get_valid_on_time_attestation
from eth2spec.test.helpers.shard_block import (
build_shard_block,
get_shard_transitions,
get_committee_index_of_shard,
)
from eth2spec.test.helpers.fork_choice import add_block_to_store, get_anchor_root
from eth2spec.test.helpers.state import (
state_transition_and_sign_block,
transition_to_valid_shard_slot,
)
from eth2spec.test.helpers.block import build_empty_block
def run_on_shard_block(spec, store, shard_store, signed_block, valid=True):
if not valid:
try:
spec.on_shard_block(store, shard_store, signed_block)
except AssertionError:
return
else:
assert False
spec.on_shard_block(store, shard_store, signed_block)
assert shard_store.signed_blocks[hash_tree_root(signed_block.message)] == signed_block
def apply_shard_block(spec, store, shard_store, beacon_parent_state, shard_blocks_buffer):
shard = shard_store.shard
body = b'\x56' * 4
shard_head_root = spec.get_shard_head(store, shard_store)
shard_parent_state = shard_store.block_states[shard_head_root]
assert shard_parent_state.slot != beacon_parent_state.slot
shard_block = build_shard_block(
spec, beacon_parent_state, shard,
shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True
)
shard_blocks_buffer.append(shard_block)
run_on_shard_block(spec, store, shard_store, shard_block)
assert spec.get_shard_head(store, shard_store) == shard_block.message.hash_tree_root()
def check_pending_shard_blocks(spec, store, shard_store, shard_blocks_buffer):
pending_shard_blocks = spec.get_pending_shard_blocks(store, shard_store)
assert pending_shard_blocks == shard_blocks_buffer
def is_in_offset_sets(spec, beacon_head_state, shard):
offset_slots = spec.compute_offset_slots(
beacon_head_state.shard_states[shard].slot, beacon_head_state.slot + 1
)
return beacon_head_state.slot in offset_slots
def apply_shard_and_beacon(spec, state, store, shard_store, shard_blocks_buffer):
store.time = store.time + spec.SECONDS_PER_SLOT * spec.SLOTS_PER_EPOCH
shard = shard_store.shard
committee_index = get_committee_index_of_shard(spec, state, state.slot, shard)
has_shard_committee = committee_index is not None # has committee of `shard` at this slot
beacon_block = build_empty_block(spec, state, slot=state.slot + 1)
# If next slot has committee of `shard`, add `shard_transtion` to the proposing beacon block
if has_shard_committee and len(shard_blocks_buffer) > 0:
# Sanity check `get_pending_shard_blocks` function
check_pending_shard_blocks(spec, store, shard_store, shard_blocks_buffer)
# Use temporary next state to get ShardTransition of shard block
shard_transitions = get_shard_transitions(
spec,
state,
shard_block_dict={shard: shard_blocks_buffer},
)
shard_transition = shard_transitions[shard]
attestation = get_valid_on_time_attestation(
spec,
state,
index=committee_index,
shard_transition=shard_transition,
signed=False,
)
assert attestation.data.shard == shard
beacon_block.body.attestations = [attestation]
beacon_block.body.shard_transitions = shard_transitions
# Clear buffer
shard_blocks_buffer.clear()
signed_beacon_block = state_transition_and_sign_block(spec, state, beacon_block) # transition!
add_block_to_store(spec, store, signed_beacon_block)
assert spec.get_head(store) == beacon_block.hash_tree_root()
# On shard block at transitioned `state.slot`
if is_in_offset_sets(spec, state, shard):
# The created shard block would be appended to `shard_blocks_buffer`
apply_shard_block(spec, store, shard_store, state, shard_blocks_buffer)
return has_shard_committee
@with_all_phases_except([PHASE0])
@spec_state_test
@never_bls # Set to never_bls for testing `check_pending_shard_blocks`
def test_basic(spec, state):
transition_to_valid_shard_slot(spec, state)
# Initialization
store = spec.get_forkchoice_store(state)
anchor_root = get_anchor_root(spec, state)
assert spec.get_head(store) == anchor_root
shard = spec.Shard(1)
shard_store = spec.get_forkchoice_shard_store(state, shard)
shard_head_root = spec.get_shard_head(store, shard_store)
assert shard_head_root == state.shard_states[shard].latest_block_root
assert shard_store.block_states[shard_head_root] == state.shard_states[shard]
# For mainnet config, it's possible that only one committee of `shard` per epoch.
# we set this counter to test more rounds.
shard_committee_counter = 2
shard_blocks_buffer = []
while shard_committee_counter > 0:
has_shard_committee = apply_shard_and_beacon(
spec, state, store, shard_store, shard_blocks_buffer
)
if has_shard_committee:
shard_committee_counter -= 1