add missing attestation validations; cleanup process_attestations and fix minor bugs

This commit is contained in:
Danny Ryan 2019-11-22 18:09:31 -07:00
parent a1ac0d5a80
commit f533fef167
No known key found for this signature in database
GPG Key ID: 2765A792E42CE07A
2 changed files with 110 additions and 61 deletions

View File

@ -415,6 +415,14 @@ def get_shard(state: BeaconState, attestation: Attestation) -> Shard:
return Shard((attestation.data.index + get_start_shard(state, attestation.data.slot)) % ACTIVE_SHARDS)
```
#### `get_next_slot_for_shard`
```python
def get_next_slot_for_shard(state: BeaconState, shard: Shard) -> Slot:
return Slot(state.shard_transitions[shard].slot + 1)
```
#### `get_offset_slots`
```python
@ -422,7 +430,6 @@ def get_offset_slots(state: BeaconState, start_slot: Slot) -> Sequence[Slot]:
return [Slot(start_slot + x) for x in SHARD_BLOCK_OFFSETS if start_slot + x < state.slot]
```
### Predicates
#### Updated `is_valid_indexed_attestation`
@ -507,17 +514,30 @@ def process_operations(state: BeaconState, body: BeaconBlockBody) -> None:
```python
def validate_attestation(state: BeaconState, attestation: Attestation) -> None:
data = attestation.data
assert data.index < get_committee_count_at_slot(state, data.slot)
assert data.index < ACTIVE_SHARDS
assert data.target.epoch in (get_previous_epoch(state), get_current_epoch(state))
assert data.slot + MIN_ATTESTATION_INCLUSION_DELAY <= state.slot <= data.slot + SLOTS_PER_EPOCH
committee = get_beacon_committee(state, data.slot, data.index)
assert len(attestation.aggregation_bits) == len(committee)
if attestation.data.target.epoch == get_current_epoch(state):
assert attestation.data.source == state.current_justified_checkpoint
else:
assert attestation.data.source == state.previous_justified_checkpoint
shard = get_shard(state, attestation)
shard_start_slot = get_next_slot_for_shard(state, shard)
# Signature check
assert is_valid_indexed_attestation(state, get_indexed_attestation(state, attestation))
# Type 1: on-time attestations
if attestation.custody_bits != []:
# Correct slot
assert data.slot == state.slot
assert data.slot + MIN_ATTESTATION_INCLUSION_DELAY == state.slot
# Correct data root count
assert len(attestation.custody_bits) == len(get_offset_slots(state, state.shard_next_slots[shard]))
assert len(attestation.custody_bits) == len(get_offset_slots(state, shard_start_slot))
# Correct parent block root
assert data.beacon_block_root == get_block_root_at_slot(state, get_previous_slot(state))
# Type 2: delayed attestations
@ -531,7 +551,7 @@ def validate_attestation(state: BeaconState, attestation: Attestation) -> None:
```python
def apply_shard_transition(state: BeaconState, shard: Shard, transition: ShardTransition) -> None:
# Slot the attestation starts counting from
start_slot = state.shard_next_slots[shard]
start_slot = get_next_slot_for_shard(state, shard)
# Correct data root count
offset_slots = get_offset_slots(state, start_slot)
@ -543,7 +563,7 @@ def apply_shard_transition(state: BeaconState, shard: Shard, transition: ShardTr
)
assert transition.start_slot == start_slot
# Reonstruct shard headers
# Reconstruct shard headers
headers = []
proposers = []
shard_parent_root = state.shard_states[shard].latest_block_root
@ -582,6 +602,84 @@ def apply_shard_transition(state: BeaconState, shard: Shard, transition: ShardTr
state.shard_states[shard].slot = state.slot - 1
```
###### `process_crosslink_for_shard`
```python
def process_crosslink_for_shard(state: BeaconState,
shard: Shard,
shard_transition: ShardTransition,
attestations: Sequence[Attestation]) -> Hash:
committee = get_beacon_committee(state, get_current_epoch(state), shard)
online_indices = get_online_validator_indices(state)
# Loop over all shard transition roots
shard_transition_roots = set([a.data.shard_transition_root for a in attestations])
for shard_transition_root in sorted(shard_transition_roots):
transition_attestations = [a for a in attestations if a.data.shard_transition_root == shard_transition_root]
transition_participants: Set[ValidatorIndex] = set()
for attestation in transition_attestations:
participants = get_attesting_indices(state, attestation.data, attestation.aggregation_bits)
transition_participants = transition_participants.union(participants)
enough_online_stake = (
get_total_balance(state, online_indices.intersection(transition_participants)) * 3 >=
get_total_balance(state, online_indices.intersection(committee)) * 2
)
# If not enough stake, try next transition root
if not enough_online_stake:
continue
# Attestation <-> shard transition consistency
assert shard_transition_root == hash_tree_root(shard_transition)
assert (
attestation.data.head_shard_root
== chunks_to_body_root(shard_transition.shard_data_roots[-1])
)
# Apply transition
apply_shard_transition(state, shard, shard_transition)
# Apply proposer reward and cost
beacon_proposer_index = get_beacon_proposer_index(state)
estimated_attester_reward = sum([get_base_reward(state, attester) for attester in transition_participants])
proposer_reward = Gwei(estimated_attester_reward // PROPOSER_REWARD_QUOTIENT)
increase_balance(state, beacon_proposer_index, proposer_reward)
states_slots_lengths = zip(
shard_transition.shard_states,
get_offset_slots(state, get_next_slot_for_shard(state, shard)),
shard_transition.shard_block_lengths
)
for shard_state, slot, length in states_slots_lengths:
proposer_index = get_shard_proposer_index(state, shard, slot)
decrease_balance(state, proposer_index, shard_state.gasprice * length)
# Return winning transition root
return shard_transition_root
# No winning transition root, ensure empty and return empty root
assert shard_transition == ShardTransition()
return Hash()
```
###### `process_crosslinks`
```python
def process_crosslinks(state: BeaconState,
block_body: BeaconBlockBody,
attestations: Sequence[Attestation]) -> Set[Tuple[Shard, Hash]]:
winners: Set[Tuple[Shard, Hash]] = set()
for shard in map(Shard, range(ACTIVE_SHARDS)):
# All attestations in the block for this shard
shard_attestations = [
attestation for attestation in attestations
if get_shard(state, attestation) == shard and attestation.data.slot == state.slot
]
shard_transition = block_body.shard_transitions[shard]
winning_root = process_crosslink_for_shard(state, shard, shard_transition, shard_attestations)
if winning_root != Hash():
winners.add((shard, winning_root))
return winners
```
###### `process_attestations`
```python
@ -589,72 +687,23 @@ def process_attestations(state: BeaconState, block_body: BeaconBlockBody, attest
# Basic validation
for attestation in attestations:
validate_attestation(state, attestation)
# Process crosslinks
online_indices = get_online_validator_indices(state)
winners = set()
for shard in map(Shard, range(ACTIVE_SHARDS)):
success = False
# All attestations in the block for this shard
this_shard_attestations = [
attestation for attestation in attestations
if get_shard(state, attestation) == shard and attestation.data.slot == state.slot
]
# The committee for this shard
this_shard_committee = get_beacon_committee(state, get_current_epoch(state), shard)
# Loop over all shard transition roots
shard_transition_roots = set([a.data.shard_transition_root for a in this_shard_attestations])
for shard_transition_root in sorted(shard_transition_roots):
all_participants: Set[ValidatorIndex] = set()
participating_attestations = []
for attestation in this_shard_attestations:
participating_attestations.append(attestation)
if attestation.data.shard_transition_root == shard_transition_root:
participants = get_attesting_indices(state, attestation.data, attestation.aggregation_bits)
all_participants = all_participants.union(participants)
if (
get_total_balance(state, online_indices.intersection(all_participants)) * 3 >=
get_total_balance(state, online_indices.intersection(this_shard_committee)) * 2
and success is False
):
# Attestation <-> shard transition consistency
assert shard_transition_root == hash_tree_root(block_body.shard_transition)
assert (
attestation.data.head_shard_root
== chunks_to_body_root(block_body.shard_transition.shard_data_roots[-1])
)
# Apply transition
apply_shard_transition(state, shard, block_body.shard_transition)
# Apply proposer reward and cost
beacon_proposer_index = get_beacon_proposer_index(state)
estimated_attester_reward = sum([get_base_reward(state, attester) for attester in all_participants])
proposer_reward = Gwei(estimated_attester_reward // PROPOSER_REWARD_QUOTIENT)
increase_balance(state, beacon_proposer_index, proposer_reward)
states_slots_lengths = zip(
block_body.shard_transition.shard_states,
get_offset_slots(state, state.shard_next_slots[get_shard(state, attestation)]),
block_body.shard_transition.shard_block_lengths
)
for shard_state, slot, length in states_slots_lengths:
proposer_index = get_shard_proposer_index(state, shard, slot)
decrease_balance(state, proposer_index, shard_state.gasprice * length)
winners.add((shard, shard_transition_root))
success = True
if not success:
assert block_body.shard_transitions[shard] == ShardTransition()
winners = process_crosslinks(state, block_body, attestations)
# Store pending attestations for epoch processing
for attestation in attestations:
is_winning_transition = (get_shard(state, attestation), attestation.shard_transition_root) in winners
is_winning_transition = (get_shard(state, attestation), attestation.data.shard_transition_root) in winners
pending_attestation = PendingAttestation(
aggregation_bits=attestation.aggregation_bits,
data=attestation.data,
inclusion_delay=state.slot - attestation.data.slot,
crosslink_success=is_winning_transition and attestation.data.slot == state.slot,
proposer_index=proposer_index
proposer_index=get_beacon_proposer_index(state),
)
if attestation.data.target.epoch == get_current_epoch(state):
assert attestation.data.source == state.current_justified_checkpoint
state.current_epoch_attestations.append(pending_attestation)
else:
assert attestation.data.source == state.previous_justified_checkpoint
state.previous_epoch_attestations.append(pending_attestation)
```

View File

@ -86,7 +86,7 @@ def upgrade_to_phase1(pre: phase0.BeaconState) -> BeaconState:
# Phase 1
shard_states=List[ShardState, MAX_SHARDS](
ShardState(
slot=0,
slot=pre.slot,
gasprice=INITIAL_GASPRICE,
data=Root(),
latest_block_root=Hash(),