fix up some PR feedback and testing for #1009

This commit is contained in:
Danny Ryan 2019-04-30 12:55:14 -06:00
parent a40f37b9a2
commit b3373a2d71
No known key found for this signature in database
GPG Key ID: 2765A792E42CE07A
6 changed files with 69 additions and 45 deletions

View File

@ -46,23 +46,23 @@ Store = None
code_lines.append(""" code_lines.append("""
# Monkey patch validator get committee code # Monkey patch validator get committee code
_compute_committee = compute_committee _get_crosslink_committee = get_crosslink_committee
committee_cache = {} committee_cache = {}
def compute_committee(validator_indices: List[ValidatorIndex], def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> List[ValidatorIndex]:
seed: Bytes32, active_indices = get_active_validator_indices(state, epoch)
index: int, seed = generate_seed(state, epoch)
total_committees: int) -> List[ValidatorIndex]: committee_count = get_epoch_committee_count(state, epoch)
committee_index = (shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT
param_hash = (hash_tree_root(validator_indices), seed, index, total_committees) param_hash = (hash_tree_root(active_indices), seed, committee_count, committee_index)
if param_hash in committee_cache: if param_hash in committee_cache:
# print("Cache hit, epoch={0}".format(epoch)) # print("Cache hit, epoch={0}".format(epoch))
return committee_cache[param_hash] return committee_cache[param_hash]
else: else:
# print("Cache miss, epoch={0}".format(epoch)) # print("Cache miss, epoch={0}".format(epoch))
ret = _compute_committee(validator_indices, seed, index, total_committees) ret = _get_crosslink_committee(state, epoch, shard)
committee_cache[param_hash] = ret committee_cache[param_hash] = ret
return ret return ret

View File

@ -66,6 +66,7 @@
- [`get_attestation_slot`](#get_attestation_slot) - [`get_attestation_slot`](#get_attestation_slot)
- [`get_block_root_at_slot`](#get_block_root_at_slot) - [`get_block_root_at_slot`](#get_block_root_at_slot)
- [`get_block_root`](#get_block_root) - [`get_block_root`](#get_block_root)
- [`get_state_root`](#get_state_root)
- [`get_randao_mix`](#get_randao_mix) - [`get_randao_mix`](#get_randao_mix)
- [`get_active_index_root`](#get_active_index_root) - [`get_active_index_root`](#get_active_index_root)
- [`generate_seed`](#generate_seed) - [`generate_seed`](#generate_seed)
@ -82,6 +83,8 @@
- [`verify_bitfield`](#verify_bitfield) - [`verify_bitfield`](#verify_bitfield)
- [`convert_to_indexed`](#convert_to_indexed) - [`convert_to_indexed`](#convert_to_indexed)
- [`verify_indexed_attestation`](#verify_indexed_attestation) - [`verify_indexed_attestation`](#verify_indexed_attestation)
- [`is_double_vote`](#is_double_vote)
- [`is_surround_vote`](#is_surround_vote)
- [`integer_squareroot`](#integer_squareroot) - [`integer_squareroot`](#integer_squareroot)
- [`get_delayed_activation_exit_epoch`](#get_delayed_activation_exit_epoch) - [`get_delayed_activation_exit_epoch`](#get_delayed_activation_exit_epoch)
- [`get_churn_limit`](#get_churn_limit) - [`get_churn_limit`](#get_churn_limit)
@ -763,7 +766,7 @@ def get_epoch_start_shard(state: BeaconState, epoch: Epoch) -> Shard:
def get_attestation_slot(state: BeaconState, attestation: Attestation) -> Slot: def get_attestation_slot(state: BeaconState, attestation: Attestation) -> Slot:
epoch = attestation.data.target_epoch epoch = attestation.data.target_epoch
committee_count = get_epoch_committee_count(state, epoch) committee_count = get_epoch_committee_count(state, epoch)
offset = (attestation.data.shard - get_epoch_start_slot(epoch)) % SHARD_COUNT offset = (attestation.data.shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT
return get_epoch_start_slot(epoch) + offset // (committee_count // SLOTS_PER_EPOCH) return get_epoch_start_slot(epoch) + offset // (committee_count // SLOTS_PER_EPOCH)
``` ```
@ -790,6 +793,18 @@ def get_block_root(state: BeaconState,
return get_block_root_at_slot(state, get_epoch_start_slot(epoch)) return get_block_root_at_slot(state, get_epoch_start_slot(epoch))
``` ```
### `get_state_root`
```python
def get_state_root(state: BeaconState,
slot: Slot) -> Bytes32:
"""
Return the state root at a recent ``slot``.
"""
assert slot < state.slot <= slot + SLOTS_PER_HISTORICAL_ROOT
return state.latest_state_roots[slot % SLOTS_PER_HISTORICAL_ROOT]
```
### `get_randao_mix` ### `get_randao_mix`
```python ```python
@ -838,8 +853,9 @@ def get_beacon_proposer_index(state: BeaconState) -> ValidatorIndex:
""" """
epoch = get_current_epoch(state) epoch = get_current_epoch(state)
committees_per_slot = get_epoch_committee_count(state, epoch) // SLOTS_PER_EPOCH committees_per_slot = get_epoch_committee_count(state, epoch) // SLOTS_PER_EPOCH
offset = committees_per_slot * (state.slot % EPOCH_LENGTH) offset = committees_per_slot * (state.slot % SLOTS_PER_EPOCH)
first_committee = get_crosslink_committee(state, epoch, offset) shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
first_committee = get_crosslink_committee(state, epoch, shard)
MAX_RANDOM_BYTE = 2**8 - 1 MAX_RANDOM_BYTE = 2**8 - 1
i = 0 i = 0
while True: while True:
@ -905,8 +921,9 @@ def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> L
start_validator_index = (len(active_indices) * committee_index) // committee_count start_validator_index = (len(active_indices) * committee_index) // committee_count
end_validator_index = (len(active_indices) * (committee_index + 1)) // committee_count end_validator_index = (len(active_indices) * (committee_index + 1)) // committee_count
seed = generate_seed(state, epoch)
return [ return [
active_indices[get_shuffled_index(i, len(active_indices), generate_seed(state, epoch))] active_indices[get_shuffled_index(i, len(active_indices), seed)]
for i in range(start_validator_index, end_validator_index) for i in range(start_validator_index, end_validator_index)
] ]
``` ```
@ -1046,6 +1063,29 @@ def verify_indexed_attestation(state: BeaconState, indexed_attestation: IndexedA
) )
``` ```
### `is_double_vote`
```python
def is_double_vote(attestation_data_1: AttestationData, attestation_data_2: AttestationData) -> bool:
"""
Check if ``attestation_data_1`` and ``attestation_data_2`` violate Casper "double" slashing rule.
"""
return attestation_data_1.target_epoch == attestation_data_2.target_epoch
```
### `is_surround_vote`
```python
def is_surround_vote(attestation_data_1: AttestationData, attestation_data_2: AttestationData) -> bool:
"""
Check if ``attestation_data_1`` and ``attestation_data_2`` violate Casper "surround" slashing rule.
"""
return (
attestation_data_1.source_epoch < attestation_data_2.source_epoch and
attestation_data_2.target_epoch < attestation_data_1.target_epoch
)
```
### `integer_squareroot` ### `integer_squareroot`
```python ```python
@ -1359,7 +1399,7 @@ def process_crosslinks(state: BeaconState) -> None:
state.previous_crosslinks = [c for c in state.current_crosslinks] state.previous_crosslinks = [c for c in state.current_crosslinks]
for epoch in (get_previous_epoch(state), get_current_epoch(state)): for epoch in (get_previous_epoch(state), get_current_epoch(state)):
for offset in range(get_epoch_committee_count(state, epoch)): for offset in range(get_epoch_committee_count(state, epoch)):
shard = (get_epoch_start_shard(epoch) + offset) % SHARD_COUNT shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
crosslink_committee = get_crosslink_committee(state, epoch, shard) crosslink_committee = get_crosslink_committee(state, epoch, shard)
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch) winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee): if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee):
@ -1429,8 +1469,8 @@ def get_crosslink_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
penalties = [0 for index in range(len(state.validator_registry))] penalties = [0 for index in range(len(state.validator_registry))]
epoch = get_previous_epoch(state) epoch = get_previous_epoch(state)
for offset in range(get_epoch_committee_count(state, epoch)): for offset in range(get_epoch_committee_count(state, epoch)):
shard = (get_epoch_start_shard(epoch) + offset) % SHARD_COUNT shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
crosslink_committee = get_crosslink_committee(state, epoch, offset) crosslink_committee = get_crosslink_committee(state, epoch, shard)
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch) winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
attesting_balance = get_total_balance(state, attesting_indices) attesting_balance = get_total_balance(state, attesting_indices)
committee_balance = get_total_balance(state, crosslink_committee) committee_balance = get_total_balance(state, crosslink_committee)
@ -1653,16 +1693,9 @@ def process_attester_slashing(state: BeaconState,
attestation_2 = attester_slashing.attestation_2 attestation_2 = attester_slashing.attestation_2
# Check that the attestations are conflicting # Check that the attestations are conflicting
assert attestation_1.data != attestation_2.data assert attestation_1.data != attestation_2.data
source_1 = attestation_1.data.source_epoch
target_1 = attestation_1.data.target_epoch
source_2 = attestation_2.data.source_epoch
target_2 = attestation_2.data.target_epoch
assert ( assert (
# Double vote is_double_vote(attestation_1.data, attestation_2.data) or
(target_1 == target_2) or is_surround_vote(attestation_1.data, attestation_2.data)
# Surround vote (attestation 1 surrounds attestation 2)
(source_1 < source_2 and target_2 < target_1)
) )
assert verify_indexed_attestation(state, attestation_1) assert verify_indexed_attestation(state, attestation_1)
@ -1715,7 +1748,7 @@ def process_attestation(state: BeaconState, attestation: Attestation) -> None:
inclusion_delay=state.slot - attestation_slot, inclusion_delay=state.slot - attestation_slot,
proposer_index=get_beacon_proposer_index(state), proposer_index=get_beacon_proposer_index(state),
) )
if target_epoch == get_current_epoch(state): if data.target_epoch == get_current_epoch(state):
state.current_epoch_attestations.append(pending_attestation) state.current_epoch_attestations.append(pending_attestation)
else: else:
state.previous_epoch_attestations.append(pending_attestation) state.previous_epoch_attestations.append(pending_attestation)

View File

@ -38,8 +38,7 @@ def run_attestation_processing(state, attestation, valid=True):
process_attestation(post_state, attestation) process_attestation(post_state, attestation)
current_epoch = get_current_epoch(state) current_epoch = get_current_epoch(state)
target_epoch = slot_to_epoch(attestation.data.slot) if attestation.data.target_epoch == current_epoch:
if target_epoch == current_epoch:
assert len(post_state.current_epoch_attestations) == len(state.current_epoch_attestations) + 1 assert len(post_state.current_epoch_attestations) == len(state.current_epoch_attestations) + 1
else: else:
assert len(post_state.previous_epoch_attestations) == len(state.previous_epoch_attestations) + 1 assert len(post_state.previous_epoch_attestations) == len(state.previous_epoch_attestations) + 1

View File

@ -65,7 +65,7 @@ def test_success_surround(state):
# set attestion1 to surround attestation 2 # set attestion1 to surround attestation 2
attester_slashing.attestation_1.data.source_epoch = attester_slashing.attestation_2.data.source_epoch - 1 attester_slashing.attestation_1.data.source_epoch = attester_slashing.attestation_2.data.source_epoch - 1
attester_slashing.attestation_1.data.slot = attester_slashing.attestation_2.data.slot + spec.SLOTS_PER_EPOCH attester_slashing.attestation_1.data.target_epoch = attester_slashing.attestation_2.data.target_epoch + 1
pre_state, post_state = run_attester_slashing_processing(state, attester_slashing) pre_state, post_state = run_attester_slashing_processing(state, attester_slashing)
@ -85,7 +85,7 @@ def test_same_data(state):
def test_no_double_or_surround(state): def test_no_double_or_surround(state):
attester_slashing = get_valid_attester_slashing(state) attester_slashing = get_valid_attester_slashing(state)
attester_slashing.attestation_1.data.slot += spec.SLOTS_PER_EPOCH attester_slashing.attestation_1.data.target_epoch += 1
pre_state, post_state = run_attester_slashing_processing(state, attester_slashing, False) pre_state, post_state = run_attester_slashing_processing(state, attester_slashing, False)

View File

@ -15,7 +15,7 @@ from tests.helpers import (
add_attestation_to_state, add_attestation_to_state,
build_empty_block_for_next_slot, build_empty_block_for_next_slot,
fill_aggregate_attestation, fill_aggregate_attestation,
get_crosslink_committee_for_attestation, get_crosslink_committee,
get_valid_attestation, get_valid_attestation,
next_epoch, next_epoch,
next_slot, next_slot,
@ -88,7 +88,7 @@ def test_single_crosslink_update_from_previous_epoch(state):
assert post_state.previous_crosslinks[shard] != post_state.current_crosslinks[shard] assert post_state.previous_crosslinks[shard] != post_state.current_crosslinks[shard]
assert pre_state.current_crosslinks[shard] != post_state.current_crosslinks[shard] assert pre_state.current_crosslinks[shard] != post_state.current_crosslinks[shard]
# ensure rewarded # ensure rewarded
for index in get_crosslink_committee_for_attestation(state, attestation.data): for index in get_crosslink_committee(state, attestation.data.target_epoch, attestation.data.shard):
assert crosslink_deltas[0][index] > 0 assert crosslink_deltas[0][index] > 0
assert crosslink_deltas[1][index] == 0 assert crosslink_deltas[1][index] == 0
@ -129,7 +129,7 @@ def test_double_late_crosslink(state):
# ensure that the current crosslinks were not updated by the second attestation # ensure that the current crosslinks were not updated by the second attestation
assert post_state.previous_crosslinks[shard] == post_state.current_crosslinks[shard] assert post_state.previous_crosslinks[shard] == post_state.current_crosslinks[shard]
# ensure no reward, only penalties for the failed crosslink # ensure no reward, only penalties for the failed crosslink
for index in get_crosslink_committee_for_attestation(state, attestation_2.data): for index in get_crosslink_committee(state, attestation_2.data.target_epoch, attestation_2.data.shard):
assert crosslink_deltas[0][index] == 0 assert crosslink_deltas[0][index] == 0
assert crosslink_deltas[1][index] > 0 assert crosslink_deltas[1][index] > 0

View File

@ -29,7 +29,7 @@ from eth2spec.phase0.spec import (
get_attesting_indices, get_attesting_indices,
get_block_root, get_block_root,
get_block_root_at_slot, get_block_root_at_slot,
get_crosslink_committees_at_slot, get_crosslink_committee,
get_current_epoch, get_current_epoch,
get_domain, get_domain,
get_epoch_start_slot, get_epoch_start_slot,
@ -174,11 +174,11 @@ def build_attestation_data(state, slot, shard):
crosslinks = state.current_crosslinks if slot_to_epoch(slot) == get_current_epoch(state) else state.previous_crosslinks crosslinks = state.current_crosslinks if slot_to_epoch(slot) == get_current_epoch(state) else state.previous_crosslinks
return AttestationData( return AttestationData(
slot=slot,
shard=shard, shard=shard,
beacon_block_root=block_root, beacon_block_root=block_root,
source_epoch=justified_epoch, source_epoch=justified_epoch,
source_root=justified_block_root, source_root=justified_block_root,
target_epoch=slot_to_epoch(slot),
target_root=epoch_boundary_root, target_root=epoch_boundary_root,
crosslink_data_root=spec.ZERO_HASH, crosslink_data_root=spec.ZERO_HASH,
previous_crosslink_root=hash_tree_root(crosslinks[shard]), previous_crosslink_root=hash_tree_root(crosslinks[shard]),
@ -276,14 +276,6 @@ def get_valid_attester_slashing(state):
) )
def get_crosslink_committee_for_attestation(state, attestation_data):
"""
Return the crosslink committee corresponding to ``attestation_data``.
"""
crosslink_committees = get_crosslink_committees_at_slot(state, attestation_data.slot)
return [committee for committee, shard in crosslink_committees if shard == attestation_data.shard][0]
def get_valid_attestation(state, slot=None): def get_valid_attestation(state, slot=None):
if slot is None: if slot is None:
slot = state.slot slot = state.slot
@ -296,7 +288,7 @@ def get_valid_attestation(state, slot=None):
attestation_data = build_attestation_data(state, slot, shard) attestation_data = build_attestation_data(state, slot, shard)
crosslink_committee = get_crosslink_committee_for_attestation(state, attestation_data) crosslink_committee = get_crosslink_committee(state, attestation_data.target_epoch, attestation_data.shard)
committee_size = len(crosslink_committee) committee_size = len(crosslink_committee)
bitfield_length = (committee_size + 7) // 8 bitfield_length = (committee_size + 7) // 8
@ -383,13 +375,13 @@ def get_attestation_signature(state, attestation_data, privkey, custody_bit=0b0)
domain=get_domain( domain=get_domain(
state=state, state=state,
domain_type=spec.DOMAIN_ATTESTATION, domain_type=spec.DOMAIN_ATTESTATION,
message_epoch=slot_to_epoch(attestation_data.slot), message_epoch=attestation_data.target_epoch,
) )
) )
def fill_aggregate_attestation(state, attestation): def fill_aggregate_attestation(state, attestation):
crosslink_committee = get_crosslink_committee_for_attestation(state, attestation.data) crosslink_committee = get_crosslink_committee(state, attestation.data.target_epoch, attestation.data.shard)
for i in range(len(crosslink_committee)): for i in range(len(crosslink_committee)):
attestation.aggregation_bitfield = set_bitfield_bit(attestation.aggregation_bitfield, i) attestation.aggregation_bitfield = set_bitfield_bit(attestation.aggregation_bitfield, i)