Add phase1 type hinting checks and fix many bugs

This commit is contained in:
Hsiao-Wei Wang 2019-06-12 20:08:19 -04:00
parent 8a54203796
commit 48e8164e28
No known key found for this signature in database
GPG Key ID: 95B070122902DEA4
3 changed files with 50 additions and 45 deletions

View File

@ -43,6 +43,7 @@ PHASE1_IMPORTS = '''from typing import (
Callable, Callable,
Dict, Dict,
List, List,
Optional,
Set, Set,
Tuple, Tuple,
) )

View File

@ -184,7 +184,7 @@ class CustodyResponse(Container):
```python ```python
class CustodyKeyReveal(Container): class CustodyKeyReveal(Container):
# Index of the validator whose key is being revealed # Index of the validator whose key is being revealed
revealer_index: uint64 revealer_index: ValidatorIndex
# Reveal (masked signature) # Reveal (masked signature)
reveal: Bytes96 reveal: Bytes96
``` ```
@ -198,7 +198,7 @@ class EarlyDerivedSecretReveal(Container):
# Index of the validator whose key is being revealed # Index of the validator whose key is being revealed
revealed_index: uint64 revealed_index: uint64
# RANDAO epoch of the key that is being revealed # RANDAO epoch of the key that is being revealed
epoch: uint64 epoch: Epoch
# Reveal (masked signature) # Reveal (masked signature)
reveal: Bytes96 reveal: Bytes96
# Index of the validator who revealed (whistleblower) # Index of the validator who revealed (whistleblower)
@ -251,7 +251,7 @@ class BeaconBlockBody(Container):
### `ceillog2` ### `ceillog2`
```python ```python
def ceillog2(x): def ceillog2(x: int) -> int:
return x.bit_length() return x.bit_length()
``` ```
@ -269,7 +269,7 @@ def get_custody_chunk_count(crosslink: Crosslink) -> int:
```python ```python
def get_custody_chunk_bit(key: Bytes96, chunk: bytes) -> bool: def get_custody_chunk_bit(key: Bytes96, chunk: bytes) -> bool:
# TODO: Replace with something MPC-friendly, e.g. the Legendre symbol # TODO: Replace with something MPC-friendly, e.g. the Legendre symbol
return get_bitfield_bit(hash(key + chunk), 0) return bool(get_bitfield_bit(hash(key + chunk), 0))
``` ```
### `get_chunk_bits_root` ### `get_chunk_bits_root`
@ -288,7 +288,7 @@ def get_chunk_bits_root(chunk_bitfield: bytes) -> Bytes32:
```python ```python
def get_randao_epoch_for_custody_period(period: int, validator_index: ValidatorIndex) -> Epoch: def get_randao_epoch_for_custody_period(period: int, validator_index: ValidatorIndex) -> Epoch:
next_period_start = (period + 1) * EPOCHS_PER_CUSTODY_PERIOD - validator_index % EPOCHS_PER_CUSTODY_PERIOD next_period_start = (period + 1) * EPOCHS_PER_CUSTODY_PERIOD - validator_index % EPOCHS_PER_CUSTODY_PERIOD
return next_period_start + CUSTODY_PERIOD_TO_RANDAO_PADDING return Epoch(next_period_start + CUSTODY_PERIOD_TO_RANDAO_PADDING)
``` ```
### `get_validators_custody_reveal_period` ### `get_validators_custody_reveal_period`
@ -372,7 +372,11 @@ def process_custody_key_reveal(state: BeaconState,
# Reward Block Preposer # Reward Block Preposer
proposer_index = get_beacon_proposer_index(state) proposer_index = get_beacon_proposer_index(state)
increase_balance(state, proposer_index, get_base_reward(state, reveal.revealer_index) // MINOR_REWARD_QUOTIENT) increase_balance(
state,
proposer_index,
Gwei(get_base_reward(state, reveal.revealer_index) // MINOR_REWARD_QUOTIENT)
)
``` ```
#### Early derived secret reveals #### Early derived secret reveals
@ -433,7 +437,7 @@ def process_early_derived_secret_reveal(state: BeaconState,
// len(get_active_validator_indices(state, get_current_epoch(state))) // len(get_active_validator_indices(state, get_current_epoch(state)))
// PROPOSER_REWARD_QUOTIENT // PROPOSER_REWARD_QUOTIENT
) )
penalty = ( penalty = Gwei(
max_proposer_slot_reward max_proposer_slot_reward
* EARLY_DERIVED_SECRET_REVEAL_SLOT_REWARD_MULTIPLE * EARLY_DERIVED_SECRET_REVEAL_SLOT_REWARD_MULTIPLE
* (len(state.exposed_derived_secrets[derived_secret_location]) + 1) * (len(state.exposed_derived_secrets[derived_secret_location]) + 1)
@ -442,8 +446,8 @@ def process_early_derived_secret_reveal(state: BeaconState,
# Apply penalty # Apply penalty
proposer_index = get_beacon_proposer_index(state) proposer_index = get_beacon_proposer_index(state)
whistleblower_index = reveal.masker_index whistleblower_index = reveal.masker_index
whistleblowing_reward = penalty // WHISTLEBLOWING_REWARD_QUOTIENT whistleblowing_reward = Gwei(penalty // WHISTLEBLOWING_REWARD_QUOTIENT)
proposer_reward = whistleblowing_reward // PROPOSER_REWARD_QUOTIENT proposer_reward = Gwei(whistleblowing_reward // PROPOSER_REWARD_QUOTIENT)
increase_balance(state, proposer_index, proposer_reward) increase_balance(state, proposer_index, proposer_reward)
increase_balance(state, whistleblower_index, whistleblowing_reward - proposer_reward) increase_balance(state, whistleblower_index, whistleblowing_reward - proposer_reward)
decrease_balance(state, reveal.revealed_index, penalty) decrease_balance(state, reveal.revealed_index, penalty)
@ -512,7 +516,7 @@ def process_bit_challenge(state: BeaconState,
pubkey=challenger.pubkey, pubkey=challenger.pubkey,
message_hash=signing_root(challenge), message_hash=signing_root(challenge),
signature=challenge.signature, signature=challenge.signature,
domain=get_domain(state, get_current_epoch(state), DOMAIN_CUSTODY_BIT_CHALLENGE), domain=get_domain(state, DOMAIN_CUSTODY_BIT_CHALLENGE, get_current_epoch(state)),
) )
assert is_slashable_validator(challenger, get_current_epoch(state)) assert is_slashable_validator(challenger, get_current_epoch(state))
@ -535,8 +539,8 @@ def process_bit_challenge(state: BeaconState,
# Verify the responder is a valid custody key # Verify the responder is a valid custody key
epoch_to_sign = get_randao_epoch_for_custody_period( epoch_to_sign = get_randao_epoch_for_custody_period(
get_validators_custody_reveal_period( get_validators_custody_reveal_period(
state=state, state,
index=challenge.responder_index, challenge.responder_index,
epoch=slot_to_epoch(attestation.data.slot)), epoch=slot_to_epoch(attestation.data.slot)),
challenge.responder_index challenge.responder_index
) )
@ -610,7 +614,7 @@ def process_chunk_challenge_response(state: BeaconState,
# Verify the chunk matches the crosslink data root # Verify the chunk matches the crosslink data root
assert verify_merkle_branch( assert verify_merkle_branch(
leaf=hash_tree_root(response.chunk), leaf=hash_tree_root(response.chunk),
branch=response.data_branch, proof=response.data_branch,
depth=challenge.depth, depth=challenge.depth,
index=response.chunk_index, index=response.chunk_index,
root=challenge.data_root, root=challenge.data_root,
@ -620,7 +624,7 @@ def process_chunk_challenge_response(state: BeaconState,
records[records.index(challenge)] = CustodyChunkChallengeRecord() records[records.index(challenge)] = CustodyChunkChallengeRecord()
# Reward the proposer # Reward the proposer
proposer_index = get_beacon_proposer_index(state) proposer_index = get_beacon_proposer_index(state)
increase_balance(state, proposer_index, get_base_reward(state, proposer_index) // MINOR_REWARD_QUOTIENT) increase_balance(state, proposer_index, Gwei(get_base_reward(state, proposer_index) // MINOR_REWARD_QUOTIENT))
``` ```
```python ```python
@ -635,7 +639,7 @@ def process_bit_challenge_response(state: BeaconState,
# Verify the chunk matches the crosslink data root # Verify the chunk matches the crosslink data root
assert verify_merkle_branch( assert verify_merkle_branch(
leaf=hash_tree_root(response.chunk), leaf=hash_tree_root(response.chunk),
branch=response.data_branch, proof=response.data_branch,
depth=ceillog2(challenge.chunk_count), depth=ceillog2(challenge.chunk_count),
index=response.chunk_index, index=response.chunk_index,
root=challenge.data_root, root=challenge.data_root,
@ -643,7 +647,7 @@ def process_bit_challenge_response(state: BeaconState,
# Verify the chunk bit leaf matches the challenge data # Verify the chunk bit leaf matches the challenge data
assert verify_merkle_branch( assert verify_merkle_branch(
leaf=response.chunk_bits_leaf, leaf=response.chunk_bits_leaf,
branch=response.chunk_bits_branch, proof=response.chunk_bits_branch,
depth=ceillog2(challenge.chunk_count) >> 8, depth=ceillog2(challenge.chunk_count) >> 8,
index=response.chunk_index // 256, index=response.chunk_index // 256,
root=challenge.chunk_bits_merkle_root root=challenge.chunk_bits_merkle_root
@ -671,8 +675,8 @@ Run `process_reveal_deadlines(state)` immediately after `process_registry_update
def process_reveal_deadlines(state: BeaconState) -> None: def process_reveal_deadlines(state: BeaconState) -> None:
for index, validator in enumerate(state.validator_registry): for index, validator in enumerate(state.validator_registry):
deadline = validator.next_custody_reveal_period + (CUSTODY_RESPONSE_DEADLINE // EPOCHS_PER_CUSTODY_PERIOD) deadline = validator.next_custody_reveal_period + (CUSTODY_RESPONSE_DEADLINE // EPOCHS_PER_CUSTODY_PERIOD)
if get_validators_custody_reveal_period(state, index) > deadline: if get_validators_custody_reveal_period(state, ValidatorIndex(index)) > deadline:
slash_validator(state, index) slash_validator(state, ValidatorIndex(index))
``` ```
Run `process_challenge_deadlines(state)` immediately after `process_reveal_deadlines(state)`: Run `process_challenge_deadlines(state)` immediately after `process_reveal_deadlines(state)`:
@ -682,17 +686,17 @@ Run `process_challenge_deadlines(state)` immediately after `process_reveal_deadl
process_challenge_deadlines(state) process_challenge_deadlines(state)
# end insert @process_challenge_deadlines # end insert @process_challenge_deadlines
def process_challenge_deadlines(state: BeaconState) -> None: def process_challenge_deadlines(state: BeaconState) -> None:
for challenge in state.custody_chunk_challenge_records: for custody_chunk_challenge in state.custody_chunk_challenge_records:
if get_current_epoch(state) > challenge.inclusion_epoch + CUSTODY_RESPONSE_DEADLINE: if get_current_epoch(state) > custody_chunk_challenge.inclusion_epoch + CUSTODY_RESPONSE_DEADLINE:
slash_validator(state, challenge.responder_index, challenge.challenger_index) slash_validator(state, custody_chunk_challenge.responder_index, custody_chunk_challenge.challenger_index)
records = state.custody_chunk_challenge_records records = state.custody_chunk_challenge
records[records.index(challenge)] = CustodyChunkChallengeRecord() records[records.index(custody_chunk_challenge)] = CustodyChunkChallengeRecord()
for challenge in state.custody_bit_challenge_records: for custody_bit_challenge in state.custody_bit_challenge_records:
if get_current_epoch(state) > challenge.inclusion_epoch + CUSTODY_RESPONSE_DEADLINE: if get_current_epoch(state) > custody_bit_challenge.inclusion_epoch + CUSTODY_RESPONSE_DEADLINE:
slash_validator(state, challenge.responder_index, challenge.challenger_index) slash_validator(state, custody_bit_challenge.responder_index, custody_bit_challenge.challenger_index)
records = state.custody_bit_challenge_records records = state.custody_bit_challenge_records
records[records.index(challenge)] = CustodyBitChallengeRecord() records[records.index(custody_bit_challenge)] = CustodyBitChallengeRecord()
``` ```
Append this to `process_final_updates(state)`: Append this to `process_final_updates(state)`:
@ -713,5 +717,5 @@ def after_process_final_updates(state: BeaconState) -> None:
for index, validator in enumerate(state.validator_registry): for index, validator in enumerate(state.validator_registry):
if index not in validator_indices_in_records: if index not in validator_indices_in_records:
if validator.exit_epoch != FAR_FUTURE_EPOCH and validator.withdrawable_epoch == FAR_FUTURE_EPOCH: if validator.exit_epoch != FAR_FUTURE_EPOCH and validator.withdrawable_epoch == FAR_FUTURE_EPOCH:
validator.withdrawable_epoch = validator.exit_epoch + MIN_VALIDATOR_WITHDRAWABILITY_DELAY validator.withdrawable_epoch = Epoch(validator.exit_epoch + MIN_VALIDATOR_WITHDRAWABILITY_DELAY)
``` ```

View File

@ -79,8 +79,8 @@ class ShardBlockBody(Container):
```python ```python
class ShardAttestation(Container): class ShardAttestation(Container):
class data(Container): class data(Container):
slot: uint64 slot: Slot
shard: uint64 shard: Shard
shard_block_root: Bytes32 shard_block_root: Bytes32
aggregation_bitfield: bytes aggregation_bitfield: bytes
aggregate_signature: Bytes96 aggregate_signature: Bytes96
@ -90,8 +90,8 @@ class ShardAttestation(Container):
```python ```python
class ShardBlock(Container): class ShardBlock(Container):
slot: uint64 slot: Slot
shard: uint64 shard: Shard
beacon_chain_root: Bytes32 beacon_chain_root: Bytes32
parent_root: Bytes32 parent_root: Bytes32
data: ShardBlockBody data: ShardBlockBody
@ -104,8 +104,8 @@ class ShardBlock(Container):
```python ```python
class ShardBlockHeader(Container): class ShardBlockHeader(Container):
slot: uint64 slot: Slot
shard: uint64 shard: Shard
beacon_chain_root: Bytes32 beacon_chain_root: Bytes32
parent_root: Bytes32 parent_root: Bytes32
body_root: Bytes32 body_root: Bytes32
@ -138,8 +138,8 @@ def get_period_committee(state: BeaconState,
### `get_switchover_epoch` ### `get_switchover_epoch`
```python ```python
def get_switchover_epoch(state: BeaconState, epoch: Epoch, index: ValidatorIndex): def get_switchover_epoch(state: BeaconState, epoch: Epoch, index: ValidatorIndex) -> int:
earlier_start_epoch = epoch - (epoch % PERSISTENT_COMMITTEE_PERIOD) - PERSISTENT_COMMITTEE_PERIOD * 2 earlier_start_epoch = Epoch(epoch - (epoch % PERSISTENT_COMMITTEE_PERIOD) - PERSISTENT_COMMITTEE_PERIOD * 2)
return (bytes_to_int(hash(generate_seed(state, earlier_start_epoch) + int_to_bytes(index, length=3)[0:8])) return (bytes_to_int(hash(generate_seed(state, earlier_start_epoch) + int_to_bytes(index, length=3)[0:8]))
% PERSISTENT_COMMITTEE_PERIOD) % PERSISTENT_COMMITTEE_PERIOD)
``` ```
@ -154,19 +154,19 @@ def get_persistent_committee(state: BeaconState,
Return the persistent committee for the given ``shard`` at the given ``slot``. Return the persistent committee for the given ``shard`` at the given ``slot``.
""" """
epoch = slot_to_epoch(slot) epoch = slot_to_epoch(slot)
earlier_start_epoch = epoch - (epoch % PERSISTENT_COMMITTEE_PERIOD) - PERSISTENT_COMMITTEE_PERIOD * 2 earlier_start_epoch = Epoch(epoch - (epoch % PERSISTENT_COMMITTEE_PERIOD) - PERSISTENT_COMMITTEE_PERIOD * 2)
later_start_epoch = epoch - (epoch % PERSISTENT_COMMITTEE_PERIOD) - PERSISTENT_COMMITTEE_PERIOD later_start_epoch = Epoch(epoch - (epoch % PERSISTENT_COMMITTEE_PERIOD) - PERSISTENT_COMMITTEE_PERIOD)
committee_count = max( committee_count = max(
len(get_active_validator_indices(state.validator_registry, earlier_start_epoch)) // len(get_active_validator_indices(state, earlier_start_epoch)) //
(SHARD_COUNT * TARGET_COMMITTEE_SIZE), (SHARD_COUNT * TARGET_COMMITTEE_SIZE),
len(get_active_validator_indices(state.validator_registry, later_start_epoch)) // len(get_active_validator_indices(state, later_start_epoch)) //
(SHARD_COUNT * TARGET_COMMITTEE_SIZE), (SHARD_COUNT * TARGET_COMMITTEE_SIZE),
) + 1 ) + 1
index = slot % committee_count index = slot % committee_count
earlier_committee = get_period_committee(state, shard, earlier_start_epoch, index, committee_count) earlier_committee = get_period_committee(state, earlier_start_epoch, shard, index, committee_count)
later_committee = get_period_committee(state, shard, later_start_epoch, index, committee_count) later_committee = get_period_committee(state, later_start_epoch, shard, index, committee_count)
# Take not-yet-cycled-out validators from earlier committee and already-cycled-in validators from # Take not-yet-cycled-out validators from earlier committee and already-cycled-in validators from
# later committee; return a sorted list of the union of the two, deduplicated # later committee; return a sorted list of the union of the two, deduplicated
@ -181,7 +181,7 @@ def get_persistent_committee(state: BeaconState,
```python ```python
def get_shard_proposer_index(state: BeaconState, def get_shard_proposer_index(state: BeaconState,
shard: Shard, shard: Shard,
slot: Slot) -> ValidatorIndex: slot: Slot) -> Optional[ValidatorIndex]:
# Randomly shift persistent committee # Randomly shift persistent committee
persistent_committee = get_persistent_committee(state, shard, slot) persistent_committee = get_persistent_committee(state, shard, slot)
seed = hash(state.current_shuffling_seed + int_to_bytes(shard, length=8) + int_to_bytes(slot, length=8)) seed = hash(state.current_shuffling_seed + int_to_bytes(shard, length=8) + int_to_bytes(slot, length=8))
@ -231,7 +231,7 @@ def verify_shard_attestation_signature(state: BeaconState,
pubkey=bls_aggregate_pubkeys(pubkeys), pubkey=bls_aggregate_pubkeys(pubkeys),
message_hash=data.shard_block_root, message_hash=data.shard_block_root,
signature=attestation.aggregate_signature, signature=attestation.aggregate_signature,
domain=get_domain(state, slot_to_epoch(data.slot), DOMAIN_SHARD_ATTESTER) domain=get_domain(state, DOMAIN_SHARD_ATTESTER, slot_to_epoch(data.slot))
) )
``` ```
@ -328,7 +328,7 @@ def is_valid_shard_block(beacon_blocks: List[BeaconBlock],
pubkey=beacon_state.validator_registry[proposer_index].pubkey, pubkey=beacon_state.validator_registry[proposer_index].pubkey,
message_hash=signing_root(block), message_hash=signing_root(block),
signature=candidate.signature, signature=candidate.signature,
domain=get_domain(beacon_state, slot_to_epoch(candidate.slot), DOMAIN_SHARD_PROPOSER), domain=get_domain(beacon_state, DOMAIN_SHARD_PROPOSER, slot_to_epoch(candidate.slot)),
) )
return True return True