Merge pull request #1855 from ethereum/hwwhww/phase1_refactor_part2

Some phase1 refactoring - part2
This commit is contained in:
Hsiao-Wei Wang 2020-06-01 17:54:35 +08:00 committed by GitHub
commit 09d8636e7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 49 deletions

View File

@ -40,6 +40,7 @@
- [`compute_shard_from_committee_index`](#compute_shard_from_committee_index) - [`compute_shard_from_committee_index`](#compute_shard_from_committee_index)
- [`compute_offset_slots`](#compute_offset_slots) - [`compute_offset_slots`](#compute_offset_slots)
- [`compute_updated_gasprice`](#compute_updated_gasprice) - [`compute_updated_gasprice`](#compute_updated_gasprice)
- [`compute_committee_source_epoch`](#compute_committee_source_epoch)
- [Beacon state accessors](#beacon-state-accessors) - [Beacon state accessors](#beacon-state-accessors)
- [`get_active_shard_count`](#get_active_shard_count) - [`get_active_shard_count`](#get_active_shard_count)
- [`get_online_validator_indices`](#get_online_validator_indices) - [`get_online_validator_indices`](#get_online_validator_indices)
@ -52,6 +53,7 @@
- [`get_latest_slot_for_shard`](#get_latest_slot_for_shard) - [`get_latest_slot_for_shard`](#get_latest_slot_for_shard)
- [`get_offset_slots`](#get_offset_slots) - [`get_offset_slots`](#get_offset_slots)
- [Predicates](#predicates) - [Predicates](#predicates)
- [`verify_attestation_custody`](#verify_attestation_custody)
- [Updated `is_valid_indexed_attestation`](#updated-is_valid_indexed_attestation) - [Updated `is_valid_indexed_attestation`](#updated-is_valid_indexed_attestation)
- [`is_on_time_attestation`](#is_on_time_attestation) - [`is_on_time_attestation`](#is_on_time_attestation)
- [`is_winning_attestation`](#is_winning_attestation) - [`is_winning_attestation`](#is_winning_attestation)
@ -155,6 +157,7 @@ class PendingAttestation(Container):
data: AttestationData data: AttestationData
inclusion_delay: Slot inclusion_delay: Slot
proposer_index: ValidatorIndex proposer_index: ValidatorIndex
# Phase 1
crosslink_success: boolean crosslink_success: boolean
``` ```
@ -417,7 +420,7 @@ def unpack_compact_validator(compact_validator: uint64) -> Tuple[ValidatorIndex,
```python ```python
def committee_to_compact_committee(state: BeaconState, committee: Sequence[ValidatorIndex]) -> CompactCommittee: def committee_to_compact_committee(state: BeaconState, committee: Sequence[ValidatorIndex]) -> CompactCommittee:
""" """
Given a state and a list of validator indices, outputs the CompactCommittee representing them. Given a state and a list of validator indices, outputs the ``CompactCommittee`` representing them.
""" """
validators = [state.validators[i] for i in committee] validators = [state.validators[i] for i in committee]
compact_validators = [ compact_validators = [
@ -449,17 +452,30 @@ def compute_offset_slots(start_slot: Slot, end_slot: Slot) -> Sequence[Slot]:
#### `compute_updated_gasprice` #### `compute_updated_gasprice`
```python ```python
def compute_updated_gasprice(prev_gasprice: Gwei, length: uint8) -> Gwei: def compute_updated_gasprice(prev_gasprice: Gwei, shard_block_length: uint8) -> Gwei:
if length > TARGET_SHARD_BLOCK_SIZE: if shard_block_length > TARGET_SHARD_BLOCK_SIZE:
delta = (prev_gasprice * (length - TARGET_SHARD_BLOCK_SIZE) delta = (prev_gasprice * (shard_block_length - TARGET_SHARD_BLOCK_SIZE)
// TARGET_SHARD_BLOCK_SIZE // GASPRICE_ADJUSTMENT_COEFFICIENT) // TARGET_SHARD_BLOCK_SIZE // GASPRICE_ADJUSTMENT_COEFFICIENT)
return min(prev_gasprice + delta, MAX_GASPRICE) return min(prev_gasprice + delta, MAX_GASPRICE)
else: else:
delta = (prev_gasprice * (TARGET_SHARD_BLOCK_SIZE - length) delta = (prev_gasprice * (TARGET_SHARD_BLOCK_SIZE - shard_block_length)
// TARGET_SHARD_BLOCK_SIZE // GASPRICE_ADJUSTMENT_COEFFICIENT) // TARGET_SHARD_BLOCK_SIZE // GASPRICE_ADJUSTMENT_COEFFICIENT)
return max(prev_gasprice, MIN_GASPRICE + delta) - delta return max(prev_gasprice, MIN_GASPRICE + delta) - delta
``` ```
#### `compute_committee_source_epoch`
```python
def compute_committee_source_epoch(epoch: Epoch, period: uint64) -> Epoch:
"""
Return the source epoch for computing the committee.
"""
source_epoch = epoch - epoch % period
if source_epoch >= period:
source_epoch -= period # `period` epochs lookahead
return source_epoch
```
### Beacon state accessors ### Beacon state accessors
#### `get_active_shard_count` #### `get_active_shard_count`
@ -481,9 +497,10 @@ def get_online_validator_indices(state: BeaconState) -> Set[ValidatorIndex]:
```python ```python
def get_shard_committee(beacon_state: BeaconState, epoch: Epoch, shard: Shard) -> Sequence[ValidatorIndex]: def get_shard_committee(beacon_state: BeaconState, epoch: Epoch, shard: Shard) -> Sequence[ValidatorIndex]:
source_epoch = epoch - epoch % SHARD_COMMITTEE_PERIOD """
if source_epoch >= SHARD_COMMITTEE_PERIOD: Return the shard committee of the given ``epoch`` of the given ``shard``.
source_epoch -= SHARD_COMMITTEE_PERIOD """
source_epoch = compute_committee_source_epoch(epoch, SHARD_COMMITTEE_PERIOD)
active_validator_indices = get_active_validator_indices(beacon_state, source_epoch) active_validator_indices = get_active_validator_indices(beacon_state, source_epoch)
seed = get_seed(beacon_state, source_epoch, DOMAIN_SHARD_COMMITTEE) seed = get_seed(beacon_state, source_epoch, DOMAIN_SHARD_COMMITTEE)
active_shard_count = get_active_shard_count(beacon_state) active_shard_count = get_active_shard_count(beacon_state)
@ -499,9 +516,10 @@ def get_shard_committee(beacon_state: BeaconState, epoch: Epoch, shard: Shard) -
```python ```python
def get_light_client_committee(beacon_state: BeaconState, epoch: Epoch) -> Sequence[ValidatorIndex]: def get_light_client_committee(beacon_state: BeaconState, epoch: Epoch) -> Sequence[ValidatorIndex]:
source_epoch = epoch - epoch % LIGHT_CLIENT_COMMITTEE_PERIOD """
if source_epoch >= LIGHT_CLIENT_COMMITTEE_PERIOD: Return the light client committee of no more than ``TARGET_COMMITTEE_SIZE`` validators.
source_epoch -= LIGHT_CLIENT_COMMITTEE_PERIOD """
source_epoch = compute_committee_source_epoch(epoch, LIGHT_CLIENT_COMMITTEE_PERIOD)
active_validator_indices = get_active_validator_indices(beacon_state, source_epoch) active_validator_indices = get_active_validator_indices(beacon_state, source_epoch)
seed = get_seed(beacon_state, source_epoch, DOMAIN_LIGHT_CLIENT) seed = get_seed(beacon_state, source_epoch, DOMAIN_LIGHT_CLIENT)
return compute_committee( return compute_committee(
@ -558,11 +576,45 @@ def get_latest_slot_for_shard(state: BeaconState, shard: Shard) -> Slot:
```python ```python
def get_offset_slots(state: BeaconState, shard: Shard) -> Sequence[Slot]: def get_offset_slots(state: BeaconState, shard: Shard) -> Sequence[Slot]:
return compute_offset_slots(state.shard_states[shard].slot, state.slot) """
Return the offset slots of the given ``shard`` between that latest included slot and current slot.
"""
return compute_offset_slots(get_latest_slot_for_shard(state, shard), state.slot)
``` ```
### Predicates ### Predicates
#### `verify_attestation_custody`
```python
def verify_attestation_custody(state: BeaconState, indexed_attestation: IndexedAttestation) -> bool:
"""
Check if ``indexed_attestation`` has valid signature against non-empty custody bits.
"""
attestation = indexed_attestation.attestation
aggregation_bits = attestation.aggregation_bits
domain = get_domain(state, DOMAIN_BEACON_ATTESTER, attestation.data.target.epoch)
all_pubkeys = []
all_signing_roots = []
for block_index, custody_bits in enumerate(attestation.custody_bits_blocks):
assert len(custody_bits) == len(indexed_attestation.committee)
for participant, aggregation_bit, custody_bit in zip(
indexed_attestation.committee, aggregation_bits, custody_bits
):
if aggregation_bit:
all_pubkeys.append(state.validators[participant].pubkey)
# Note: only 2N distinct message hashes
attestation_wrapper = AttestationCustodyBitWrapper(
attestation_data_root=hash_tree_root(attestation.data),
block_index=block_index,
bit=custody_bit,
)
all_signing_roots.append(compute_signing_root(attestation_wrapper, domain))
else:
assert not custody_bit
return bls.AggregateVerify(all_pubkeys, all_signing_roots, signature=attestation.signature)
```
#### Updated `is_valid_indexed_attestation` #### Updated `is_valid_indexed_attestation`
Note that this replaces the Phase 0 `is_valid_indexed_attestation`. Note that this replaces the Phase 0 `is_valid_indexed_attestation`.
@ -573,37 +625,22 @@ def is_valid_indexed_attestation(state: BeaconState, indexed_attestation: Indexe
Check if ``indexed_attestation`` has valid indices and signature. Check if ``indexed_attestation`` has valid indices and signature.
""" """
# Verify aggregate signature # Verify aggregate signature
all_pubkeys = []
all_signing_roots = []
attestation = indexed_attestation.attestation attestation = indexed_attestation.attestation
domain = get_domain(state, DOMAIN_BEACON_ATTESTER, attestation.data.target.epoch)
aggregation_bits = attestation.aggregation_bits aggregation_bits = attestation.aggregation_bits
if not any(aggregation_bits) or len(aggregation_bits) != len(indexed_attestation.committee): if not any(aggregation_bits) or len(aggregation_bits) != len(indexed_attestation.committee):
return False return False
if len(attestation.custody_bits_blocks) == 0: if len(attestation.custody_bits_blocks) == 0:
# fall back on phase0 behavior if there is no shard data. # fall back on phase0 behavior if there is no shard data.
for participant, abit in zip(indexed_attestation.committee, aggregation_bits): domain = get_domain(state, DOMAIN_BEACON_ATTESTER, attestation.data.target.epoch)
if abit: all_pubkeys = []
for participant, aggregation_bit in zip(indexed_attestation.committee, aggregation_bits):
if aggregation_bit:
all_pubkeys.append(state.validators[participant].pubkey) all_pubkeys.append(state.validators[participant].pubkey)
signing_root = compute_signing_root(indexed_attestation.attestation.data, domain) signing_root = compute_signing_root(indexed_attestation.attestation.data, domain)
return bls.FastAggregateVerify(all_pubkeys, signing_root, signature=attestation.signature) return bls.FastAggregateVerify(all_pubkeys, signing_root, signature=attestation.signature)
else: else:
for i, custody_bits in enumerate(attestation.custody_bits_blocks): return verify_attestation_custody(state, indexed_attestation)
assert len(custody_bits) == len(indexed_attestation.committee)
for participant, abit, cbit in zip(indexed_attestation.committee, aggregation_bits, custody_bits):
if abit:
all_pubkeys.append(state.validators[participant].pubkey)
# Note: only 2N distinct message hashes
attestation_wrapper = AttestationCustodyBitWrapper(
attestation_data_root=hash_tree_root(attestation.data),
block_index=i,
bit=cbit
)
all_signing_roots.append(compute_signing_root(attestation_wrapper, domain))
else:
assert not cbit
return bls.AggregateVerify(all_pubkeys, all_signing_roots, signature=attestation.signature)
``` ```
#### `is_on_time_attestation` #### `is_on_time_attestation`
@ -787,21 +824,21 @@ def apply_shard_transition(state: BeaconState, shard: Shard, transition: ShardTr
proposers = [] proposers = []
prev_gasprice = state.shard_states[shard].gasprice prev_gasprice = state.shard_states[shard].gasprice
shard_parent_root = state.shard_states[shard].latest_block_root shard_parent_root = state.shard_states[shard].latest_block_root
for i in range(len(offset_slots)): for i, offset_slot in enumerate(offset_slots):
shard_block_length = transition.shard_block_lengths[i] shard_block_length = transition.shard_block_lengths[i]
shard_state = transition.shard_states[i] shard_state = transition.shard_states[i]
# Verify correct calculation of gas prices and slots # Verify correct calculation of gas prices and slots
assert shard_state.gasprice == compute_updated_gasprice(prev_gasprice, shard_block_length) assert shard_state.gasprice == compute_updated_gasprice(prev_gasprice, shard_block_length)
assert shard_state.slot == offset_slots[i] assert shard_state.slot == offset_slot
# Collect the non-empty proposals result # Collect the non-empty proposals result
is_empty_proposal = shard_block_length == 0 is_empty_proposal = shard_block_length == 0
if not is_empty_proposal: if not is_empty_proposal:
proposal_index = get_shard_proposer_index(state, offset_slots[i], shard) proposal_index = get_shard_proposer_index(state, offset_slot, shard)
# Reconstruct shard headers # Reconstruct shard headers
header = ShardBlockHeader( header = ShardBlockHeader(
shard_parent_root=shard_parent_root, shard_parent_root=shard_parent_root,
beacon_parent_root=get_block_root_at_slot(state, offset_slots[i]), beacon_parent_root=get_block_root_at_slot(state, offset_slot),
slot=offset_slots[i], slot=offset_slot,
shard=shard, shard=shard,
proposer_index=proposal_index, proposer_index=proposal_index,
body_root=transition.shard_data_roots[i] body_root=transition.shard_data_roots[i]
@ -893,13 +930,12 @@ def process_crosslinks(state: BeaconState,
for committee_index in map(CommitteeIndex, range(committee_count)): for committee_index in map(CommitteeIndex, range(committee_count)):
shard = compute_shard_from_committee_index(state, committee_index, state.slot) shard = compute_shard_from_committee_index(state, committee_index, state.slot)
# All attestations in the block for this committee/shard and current slot # All attestations in the block for this committee/shard and current slot
shard_transition = shard_transitions[shard]
shard_attestations = [ shard_attestations = [
attestation for attestation in attestations attestation for attestation in attestations
if is_on_time_attestation(state, attestation) and attestation.data.index == committee_index if is_on_time_attestation(state, attestation) and attestation.data.index == committee_index
] ]
winning_root = process_crosslink_for_shard(state, committee_index, shard_transition, shard_attestations) winning_root = process_crosslink_for_shard(state, committee_index, shard_transitions[shard], shard_attestations)
if winning_root != Root(): if winning_root != Root():
# Mark relevant pending attestations as creating a successful crosslink # Mark relevant pending attestations as creating a successful crosslink
for pending_attestation in state.current_epoch_attestations: for pending_attestation in state.current_epoch_attestations:
@ -938,11 +974,9 @@ def process_shard_transitions(state: BeaconState,
```python ```python
def get_indices_from_committee( def get_indices_from_committee(
committee: List[ValidatorIndex, MAX_VALIDATORS_PER_COMMITTEE], committee: List[ValidatorIndex, MAX_VALIDATORS_PER_COMMITTEE],
bits: Bitlist[MAX_VALIDATORS_PER_COMMITTEE]) -> List[ValidatorIndex, MAX_VALIDATORS_PER_COMMITTEE]: bits: Bitlist[MAX_VALIDATORS_PER_COMMITTEE]) -> Sequence[ValidatorIndex]:
assert len(bits) == len(committee) assert len(bits) == len(committee)
return List[ValidatorIndex, MAX_VALIDATORS_PER_COMMITTEE]( return [validator_index for i, validator_index in enumerate(committee) if bits[i]]
[validator_index for i, validator_index in enumerate(committee) if bits[i]]
)
``` ```
```python ```python
@ -1036,7 +1070,9 @@ def process_online_tracking(state: BeaconState) -> None:
```python ```python
def process_light_client_committee_updates(state: BeaconState) -> None: def process_light_client_committee_updates(state: BeaconState) -> None:
# Update light client committees """
Update light client committees.
"""
if get_current_epoch(state) % LIGHT_CLIENT_COMMITTEE_PERIOD == 0: if get_current_epoch(state) % LIGHT_CLIENT_COMMITTEE_PERIOD == 0:
state.current_light_committee = state.next_light_committee state.current_light_committee = state.next_light_committee
new_committee = get_light_client_committee(state, get_current_epoch(state) + LIGHT_CLIENT_COMMITTEE_PERIOD) new_committee = get_light_client_committee(state, get_current_epoch(state) + LIGHT_CLIENT_COMMITTEE_PERIOD)

View File

@ -269,12 +269,8 @@ def get_shard_transition(beacon_state: BeaconState,
shard: Shard, shard: Shard,
shard_blocks: Sequence[SignedShardBlock]) -> ShardTransition: shard_blocks: Sequence[SignedShardBlock]) -> ShardTransition:
offset_slots = get_offset_slots(beacon_state, shard) offset_slots = get_offset_slots(beacon_state, shard)
start_slot = offset_slots[0]
proposals, shard_states, shard_data_roots = get_shard_state_transition_result(beacon_state, shard, shard_blocks) proposals, shard_states, shard_data_roots = get_shard_state_transition_result(beacon_state, shard, shard_blocks)
assert len(proposals) > 0
assert len(shard_data_roots) > 0
shard_block_lengths = [] shard_block_lengths = []
proposer_signatures = [] proposer_signatures = []
for proposal in proposals: for proposal in proposals:
@ -288,7 +284,7 @@ def get_shard_transition(beacon_state: BeaconState,
proposer_signature_aggregate = NO_SIGNATURE proposer_signature_aggregate = NO_SIGNATURE
return ShardTransition( return ShardTransition(
start_slot=start_slot, start_slot=offset_slots[0],
shard_block_lengths=shard_block_lengths, shard_block_lengths=shard_block_lengths,
shard_data_roots=shard_data_roots, shard_data_roots=shard_data_roots,
shard_states=shard_states, shard_states=shard_states,