Attestation committee refactor

* Remove `get_crosslink_committees_at_slot` (that function's ugly man...)
* Make the "base" that everything works off instead be `get_crosslink_committee`
* Attestations store epoch, start shard and shard, no longer slot (slot can be calculated from the other three)
* Retaining start shard in attestations allows `get_attesting_indices` to peek much further back into the past, making it useful for slashings (Phase 1)
* Some two-layer-deep nested loops become one-layer-deep loops
This commit is contained in:
vbuterin 2019-04-29 11:02:39 -05:00 committed by GitHub
parent 2787fea5fe
commit 77d7aa7630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 73 additions and 62 deletions

View File

@ -65,7 +65,8 @@
- [`get_epoch_committee_count`](#get_epoch_committee_count)
- [`get_shard_delta`](#get_shard_delta)
- [`compute_committee`](#compute_committee)
- [`get_crosslink_committees_at_slot`](#get_crosslink_committees_at_slot)
- [`get_epoch_start_shard`](#get_epoch_start_shard)
- [`committee_shard_to_slot`](#committee_shard_to_slot)
- [`get_block_root_at_slot`](#get_block_root_at_slot)
- [`get_block_root`](#get_block_root)
- [`get_state_root`](#get_state_root)
@ -74,6 +75,7 @@
- [`generate_seed`](#generate_seed)
- [`get_beacon_proposer_index`](#get_beacon_proposer_index)
- [`verify_merkle_branch`](#verify_merkle_branch)
- [`get_crosslink_committee`](#get_crosslink_committee)
- [`get_attesting_indices`](#get_attesting_indices)
- [`int_to_bytes1`, `int_to_bytes2`, ...](#int_to_bytes1-int_to_bytes2-)
- [`bytes_to_int`](#bytes_to_int)
@ -307,7 +309,7 @@ The types are defined topologically to aid in facilitating an executable version
```python
{
# LMD GHOST vote
'slot': 'uint64',
'epoch': 'uint64',
'beacon_block_root': 'bytes32',
# FFG vote
@ -316,6 +318,7 @@ The types are defined topologically to aid in facilitating an executable version
'target_root': 'bytes32',
# Crosslink vote
'epoch_start_shard': 'uint64',
'shard': 'uint64',
'previous_crosslink_root': 'bytes32',
'crosslink_data_root': 'bytes32',
@ -805,44 +808,30 @@ def compute_committee(validator_indices: List[ValidatorIndex],
Note: this definition and the next few definitions are highly inefficient as algorithms, as they re-calculate many sub-expressions. Production implementations are expected to appropriately use caching/memoization to avoid redoing work.
### `get_crosslink_committees_at_slot`
### `get_epoch_start_shard`
```python
def get_crosslink_committees_at_slot(state: BeaconState,
slot: Slot) -> List[Tuple[List[ValidatorIndex], Shard]]:
"""
Return the list of ``(committee, shard)`` tuples for the ``slot``.
"""
epoch = slot_to_epoch(slot)
current_epoch = get_current_epoch(state)
previous_epoch = get_previous_epoch(state)
next_epoch = current_epoch + 1
def get_epoch_start_shard(state: BeaconState, epoch: Epoch) -> Shard:
if epoch == get_current_epoch(state):
return state.latest_start_shard
elif epoch == get_previous_epoch(state):
previous_shard_delta = get_shard_delta(state, epoch)
return (state.latest_start_shard - previous_shard_delta) % SHARD_COUNT
elif epoch == get_current_epoch(state) + 1:
current_shard_delta = get_shard_delta(state, get_current_epoch(state))
return (state.latest_start_shard + current_shard_delta) % SHARD_COUNT
else:
raise Exception("Not supported")
```
assert previous_epoch <= epoch <= next_epoch
indices = get_active_validator_indices(state, epoch)
### `committee_shard_to_slot`
if epoch == current_epoch:
start_shard = state.latest_start_shard
elif epoch == previous_epoch:
previous_shard_delta = get_shard_delta(state, previous_epoch)
start_shard = (state.latest_start_shard - previous_shard_delta) % SHARD_COUNT
elif epoch == next_epoch:
current_shard_delta = get_shard_delta(state, current_epoch)
start_shard = (state.latest_start_shard + current_shard_delta) % SHARD_COUNT
committees_per_epoch = get_epoch_committee_count(state, epoch)
committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH
offset = slot % SLOTS_PER_EPOCH
slot_start_shard = (start_shard + committees_per_slot * offset) % SHARD_COUNT
seed = generate_seed(state, epoch)
return [
(
compute_committee(indices, seed, committees_per_slot * offset + i, committees_per_epoch),
(slot_start_shard + i) % SHARD_COUNT,
)
for i in range(committees_per_slot)
]
```python
def committee_shard_to_slot(state: BeaconState, epoch: Epoch, shard: Shard) -> Slot:
start_shard = get_epoch_start_shard(state, epoch)
committees_per_slot = get_epoch_committee_count(state, epoch) // SLOTS_PER_EPOCH
offset = (shard - get_epoch_start_slot(epoch)) % SHARD_COUNT
return get_epoch_start_slot(epoch) + offset // committees_per_slot
```
### `get_block_root_at_slot`
@ -927,7 +916,9 @@ def get_beacon_proposer_index(state: BeaconState) -> ValidatorIndex:
Return the beacon proposer index at ``state.slot``.
"""
current_epoch = get_current_epoch(state)
first_committee, _ = get_crosslink_committees_at_slot(state, state.slot)[0]
committees_per_slot = get_epoch_committee_count(state, current_epoch) // SLOTS_PER_EPOCH
offset = committees_per_slot * (state.slot % EPOCH_LENGTH)
first_committee = get_crosslink_committee(state, epoch, offset)
MAX_RANDOM_BYTE = 2**8 - 1
i = 0
while True:
@ -956,6 +947,18 @@ def verify_merkle_branch(leaf: Bytes32, proof: List[Bytes32], depth: int, index:
return value == root
```
### `get_crosslink_committee`
```python
def get_crosslink_committee(state: BeaconState, epoch: Epoch, offset: int):
return compute_committee(
validator_indices=get_active_validator_indices(state, epoch),
seed=generate_seed(state, epoch),
index=offset,
total_committees=get_epoch_committee_count(state, epoch)
)
```
### `get_attesting_indices`
```python
@ -965,10 +968,10 @@ def get_attesting_indices(state: BeaconState,
"""
Return the sorted attesting indices corresponding to ``attestation_data`` and ``bitfield``.
"""
crosslink_committees = get_crosslink_committees_at_slot(state, attestation_data.slot)
crosslink_committee = [committee for committee, shard in crosslink_committees if shard == attestation_data.shard][0]
assert verify_bitfield(bitfield, len(crosslink_committee))
return sorted([index for i, index in enumerate(crosslink_committee) if get_bitfield_bit(bitfield, i) == 0b1])
offset = (attestation_data.shard - attestation_data.epoch_start_shard) % SHARD_COUNT
committee = get_crosslink_committee(state, attestation_data.epoch, offset)
assert verify_bitfield(bitfield, len(committee))
return sorted([index for i, index in enumerate(committee) if get_bitfield_bit(bitfield, i) == 0b1])
```
### `int_to_bytes1`, `int_to_bytes2`, ...
@ -1088,7 +1091,7 @@ def verify_indexed_attestation(state: BeaconState, indexed_attestation: IndexedA
hash_tree_root(AttestationDataAndCustodyBit(data=indexed_attestation.data, custody_bit=0b1)),
],
signature=indexed_attestation.signature,
domain=get_domain(state, DOMAIN_ATTESTATION, slot_to_epoch(indexed_attestation.data.slot)),
domain=get_domain(state, DOMAIN_ATTESTATION, indexed_attestation.data.epoch),
)
```
@ -1331,7 +1334,7 @@ def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> List[P
def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
return [
a for a in get_matching_source_attestations(state, epoch)
if a.data.beacon_block_root == get_block_root_at_slot(state, a.data.slot)
if a.data.beacon_block_root == get_block_root_at_slot(state, committee_shard_to_slot(state, epoch, a.data.shard))
]
```
@ -1351,7 +1354,7 @@ def get_attesting_balance(state: BeaconState, attestations: List[PendingAttestat
```python
def get_crosslink_from_attestation_data(state: BeaconState, data: AttestationData) -> Crosslink:
return Crosslink(
epoch=min(slot_to_epoch(data.slot), state.current_crosslinks[data.shard].epoch + MAX_CROSSLINK_EPOCHS),
epoch=min(data.epoch, state.current_crosslinks[data.shard].epoch + MAX_CROSSLINK_EPOCHS),
previous_crosslink_root=data.previous_crosslink_root,
crosslink_data_root=data.crosslink_data_root,
)
@ -1444,8 +1447,10 @@ def process_crosslinks(state: BeaconState) -> None:
previous_epoch = get_previous_epoch(state)
next_epoch = get_current_epoch(state) + 1
for slot in range(get_epoch_start_slot(previous_epoch), get_epoch_start_slot(next_epoch)):
epoch = slot_to_epoch(slot)
for crosslink_committee, shard in get_crosslink_committees_at_slot(state, slot):
for epoch in (get_previous_epoch(state), get_current_epoch(state)):
for offset in range(get_epoch_committee_count(state, epoch)):
shard = (get_epoch_start_shard(epoch) + offset) % SHARD_COUNT
crosslink_committee = get_crosslink_committee(state, epoch, offset)
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee):
state.current_crosslinks[shard] = winning_crosslink
@ -1492,7 +1497,8 @@ def get_attestation_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
for index in get_unslashed_attesting_indices(state, matching_source_attestations):
earliest_attestation = get_earliest_attestation(state, matching_source_attestations, index)
rewards[earliest_attestation.proposer_index] += get_base_reward(state, index) // PROPOSER_REWARD_QUOTIENT
inclusion_delay = earliest_attestation.inclusion_slot - earliest_attestation.data.slot
attestation_slot = committee_shard_to_slot(state, earliest_attestation.data.epoch, earliest_attestation.data.shard)
inclusion_delay = earliest_attestation.inclusion_slot - attestation_slot
rewards[index] += get_base_reward(state, index) * MIN_ATTESTATION_INCLUSION_DELAY // inclusion_delay
# Inactivity penalty
@ -1511,18 +1517,19 @@ def get_attestation_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
def get_crosslink_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
rewards = [0 for index in range(len(state.validator_registry))]
penalties = [0 for index in range(len(state.validator_registry))]
for slot in range(get_epoch_start_slot(get_previous_epoch(state)), get_epoch_start_slot(get_current_epoch(state))):
epoch = slot_to_epoch(slot)
for crosslink_committee, shard in get_crosslink_committees_at_slot(state, slot):
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
attesting_balance = get_total_balance(state, attesting_indices)
committee_balance = get_total_balance(state, crosslink_committee)
for index in crosslink_committee:
base_reward = get_base_reward(state, index)
if index in attesting_indices:
rewards[index] += base_reward * attesting_balance // committee_balance
else:
penalties[index] += base_reward
epoch = get_previous_epoch(state)
for offset in range(get_epoch_committee_count(state, epoch)):
shard = (get_epoch_start_shard(epoch) + offset) % SHARD_COUNT
crosslink_committee = get_crosslink_committee(state, epoch, offset)
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
attesting_balance = get_total_balance(state, attesting_indices)
committee_balance = get_total_balance(state, crosslink_committee)
for index in crosslink_committee:
base_reward = get_base_reward(state, index)
if index in attesting_indices:
rewards[index] += base_reward * attesting_balance // committee_balance
else:
penalties[index] += base_reward
return [rewards, penalties]
```
@ -1770,15 +1777,19 @@ def process_attestation(state: BeaconState, attestation: Attestation) -> None:
Note that this function mutates ``state``.
"""
data = attestation.data
attestation_slot = committee_shard_to_slot(state, data.epoch, attestation.shard)
min_slot = state.slot - SLOTS_PER_EPOCH if get_current_epoch(state) > GENESIS_EPOCH else GENESIS_SLOT
assert min_slot <= data.slot <= state.slot - MIN_ATTESTATION_INCLUSION_DELAY
assert min_slot <= attestation_slot <= state.slot - MIN_ATTESTATION_INCLUSION_DELAY
# Check target epoch, source epoch, source root, and source crosslink
target_epoch = slot_to_epoch(data.slot)
assert (target_epoch, data.source_epoch, data.source_root, data.previous_crosslink_root) in {
assert (data.epoch, data.source_epoch, data.source_root, data.previous_crosslink_root) in {
(get_current_epoch(state), state.current_justified_epoch, state.current_justified_root, hash_tree_root(state.current_crosslinks[data.shard])),
(get_previous_epoch(state), state.previous_justified_epoch, state.previous_justified_root, hash_tree_root(state.previous_crosslinks[data.shard])),
}
# Check shard and epoch start shard
assert data.epoch_start_shard == get_epoch_start_shard(state, data.epoch)
assert (data.shard - data.epoch_start_shard) % SHARD_COUNT < get_epoch_committee_count(state, data.epoch)
# Check crosslink data root
assert data.crosslink_data_root == ZERO_HASH # [to be removed in phase 1]