fix up some PR feedback and testing for #1009

This commit is contained in:
Danny Ryan 2019-04-30 12:55:14 -06:00
parent a40f37b9a2
commit b3373a2d71
No known key found for this signature in database
GPG Key ID: 2765A792E42CE07A
6 changed files with 69 additions and 45 deletions

View File

@ -46,23 +46,23 @@ Store = None
code_lines.append("""
# Monkey patch validator get committee code
_compute_committee = compute_committee
_get_crosslink_committee = get_crosslink_committee
committee_cache = {}
def compute_committee(validator_indices: List[ValidatorIndex],
seed: Bytes32,
index: int,
total_committees: int) -> List[ValidatorIndex]:
param_hash = (hash_tree_root(validator_indices), seed, index, total_committees)
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> List[ValidatorIndex]:
active_indices = get_active_validator_indices(state, epoch)
seed = generate_seed(state, epoch)
committee_count = get_epoch_committee_count(state, epoch)
committee_index = (shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT
param_hash = (hash_tree_root(active_indices), seed, committee_count, committee_index)
if param_hash in committee_cache:
# print("Cache hit, epoch={0}".format(epoch))
return committee_cache[param_hash]
else:
# print("Cache miss, epoch={0}".format(epoch))
ret = _compute_committee(validator_indices, seed, index, total_committees)
ret = _get_crosslink_committee(state, epoch, shard)
committee_cache[param_hash] = ret
return ret

View File

@ -66,6 +66,7 @@
- [`get_attestation_slot`](#get_attestation_slot)
- [`get_block_root_at_slot`](#get_block_root_at_slot)
- [`get_block_root`](#get_block_root)
- [`get_state_root`](#get_state_root)
- [`get_randao_mix`](#get_randao_mix)
- [`get_active_index_root`](#get_active_index_root)
- [`generate_seed`](#generate_seed)
@ -82,6 +83,8 @@
- [`verify_bitfield`](#verify_bitfield)
- [`convert_to_indexed`](#convert_to_indexed)
- [`verify_indexed_attestation`](#verify_indexed_attestation)
- [`is_double_vote`](#is_double_vote)
- [`is_surround_vote`](#is_surround_vote)
- [`integer_squareroot`](#integer_squareroot)
- [`get_delayed_activation_exit_epoch`](#get_delayed_activation_exit_epoch)
- [`get_churn_limit`](#get_churn_limit)
@ -763,7 +766,7 @@ def get_epoch_start_shard(state: BeaconState, epoch: Epoch) -> Shard:
def get_attestation_slot(state: BeaconState, attestation: Attestation) -> Slot:
epoch = attestation.data.target_epoch
committee_count = get_epoch_committee_count(state, epoch)
offset = (attestation.data.shard - get_epoch_start_slot(epoch)) % SHARD_COUNT
offset = (attestation.data.shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT
return get_epoch_start_slot(epoch) + offset // (committee_count // SLOTS_PER_EPOCH)
```
@ -790,6 +793,18 @@ def get_block_root(state: BeaconState,
return get_block_root_at_slot(state, get_epoch_start_slot(epoch))
```
### `get_state_root`
```python
def get_state_root(state: BeaconState,
slot: Slot) -> Bytes32:
"""
Return the state root at a recent ``slot``.
"""
assert slot < state.slot <= slot + SLOTS_PER_HISTORICAL_ROOT
return state.latest_state_roots[slot % SLOTS_PER_HISTORICAL_ROOT]
```
### `get_randao_mix`
```python
@ -838,8 +853,9 @@ def get_beacon_proposer_index(state: BeaconState) -> ValidatorIndex:
"""
epoch = get_current_epoch(state)
committees_per_slot = get_epoch_committee_count(state, epoch) // SLOTS_PER_EPOCH
offset = committees_per_slot * (state.slot % EPOCH_LENGTH)
first_committee = get_crosslink_committee(state, epoch, offset)
offset = committees_per_slot * (state.slot % SLOTS_PER_EPOCH)
shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
first_committee = get_crosslink_committee(state, epoch, shard)
MAX_RANDOM_BYTE = 2**8 - 1
i = 0
while True:
@ -905,8 +921,9 @@ def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> L
start_validator_index = (len(active_indices) * committee_index) // committee_count
end_validator_index = (len(active_indices) * (committee_index + 1)) // committee_count
seed = generate_seed(state, epoch)
return [
active_indices[get_shuffled_index(i, len(active_indices), generate_seed(state, epoch))]
active_indices[get_shuffled_index(i, len(active_indices), seed)]
for i in range(start_validator_index, end_validator_index)
]
```
@ -1046,6 +1063,29 @@ def verify_indexed_attestation(state: BeaconState, indexed_attestation: IndexedA
)
```
### `is_double_vote`
```python
def is_double_vote(attestation_data_1: AttestationData, attestation_data_2: AttestationData) -> bool:
"""
Check if ``attestation_data_1`` and ``attestation_data_2`` violate Casper "double" slashing rule.
"""
return attestation_data_1.target_epoch == attestation_data_2.target_epoch
```
### `is_surround_vote`
```python
def is_surround_vote(attestation_data_1: AttestationData, attestation_data_2: AttestationData) -> bool:
"""
Check if ``attestation_data_1`` and ``attestation_data_2`` violate Casper "surround" slashing rule.
"""
return (
attestation_data_1.source_epoch < attestation_data_2.source_epoch and
attestation_data_2.target_epoch < attestation_data_1.target_epoch
)
```
### `integer_squareroot`
```python
@ -1359,7 +1399,7 @@ def process_crosslinks(state: BeaconState) -> None:
state.previous_crosslinks = [c for c in state.current_crosslinks]
for epoch in (get_previous_epoch(state), get_current_epoch(state)):
for offset in range(get_epoch_committee_count(state, epoch)):
shard = (get_epoch_start_shard(epoch) + offset) % SHARD_COUNT
shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
crosslink_committee = get_crosslink_committee(state, epoch, shard)
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee):
@ -1429,8 +1469,8 @@ def get_crosslink_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
penalties = [0 for index in range(len(state.validator_registry))]
epoch = get_previous_epoch(state)
for offset in range(get_epoch_committee_count(state, epoch)):
shard = (get_epoch_start_shard(epoch) + offset) % SHARD_COUNT
crosslink_committee = get_crosslink_committee(state, epoch, offset)
shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
crosslink_committee = get_crosslink_committee(state, epoch, shard)
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
attesting_balance = get_total_balance(state, attesting_indices)
committee_balance = get_total_balance(state, crosslink_committee)
@ -1653,16 +1693,9 @@ def process_attester_slashing(state: BeaconState,
attestation_2 = attester_slashing.attestation_2
# Check that the attestations are conflicting
assert attestation_1.data != attestation_2.data
source_1 = attestation_1.data.source_epoch
target_1 = attestation_1.data.target_epoch
source_2 = attestation_2.data.source_epoch
target_2 = attestation_2.data.target_epoch
assert (
# Double vote
(target_1 == target_2) or
# Surround vote (attestation 1 surrounds attestation 2)
(source_1 < source_2 and target_2 < target_1)
is_double_vote(attestation_1.data, attestation_2.data) or
is_surround_vote(attestation_1.data, attestation_2.data)
)
assert verify_indexed_attestation(state, attestation_1)
@ -1715,7 +1748,7 @@ def process_attestation(state: BeaconState, attestation: Attestation) -> None:
inclusion_delay=state.slot - attestation_slot,
proposer_index=get_beacon_proposer_index(state),
)
if target_epoch == get_current_epoch(state):
if data.target_epoch == get_current_epoch(state):
state.current_epoch_attestations.append(pending_attestation)
else:
state.previous_epoch_attestations.append(pending_attestation)

View File

@ -38,8 +38,7 @@ def run_attestation_processing(state, attestation, valid=True):
process_attestation(post_state, attestation)
current_epoch = get_current_epoch(state)
target_epoch = slot_to_epoch(attestation.data.slot)
if target_epoch == current_epoch:
if attestation.data.target_epoch == current_epoch:
assert len(post_state.current_epoch_attestations) == len(state.current_epoch_attestations) + 1
else:
assert len(post_state.previous_epoch_attestations) == len(state.previous_epoch_attestations) + 1

View File

@ -65,7 +65,7 @@ def test_success_surround(state):
# set attestion1 to surround attestation 2
attester_slashing.attestation_1.data.source_epoch = attester_slashing.attestation_2.data.source_epoch - 1
attester_slashing.attestation_1.data.slot = attester_slashing.attestation_2.data.slot + spec.SLOTS_PER_EPOCH
attester_slashing.attestation_1.data.target_epoch = attester_slashing.attestation_2.data.target_epoch + 1
pre_state, post_state = run_attester_slashing_processing(state, attester_slashing)
@ -85,7 +85,7 @@ def test_same_data(state):
def test_no_double_or_surround(state):
attester_slashing = get_valid_attester_slashing(state)
attester_slashing.attestation_1.data.slot += spec.SLOTS_PER_EPOCH
attester_slashing.attestation_1.data.target_epoch += 1
pre_state, post_state = run_attester_slashing_processing(state, attester_slashing, False)

View File

@ -15,7 +15,7 @@ from tests.helpers import (
add_attestation_to_state,
build_empty_block_for_next_slot,
fill_aggregate_attestation,
get_crosslink_committee_for_attestation,
get_crosslink_committee,
get_valid_attestation,
next_epoch,
next_slot,
@ -88,7 +88,7 @@ def test_single_crosslink_update_from_previous_epoch(state):
assert post_state.previous_crosslinks[shard] != post_state.current_crosslinks[shard]
assert pre_state.current_crosslinks[shard] != post_state.current_crosslinks[shard]
# ensure rewarded
for index in get_crosslink_committee_for_attestation(state, attestation.data):
for index in get_crosslink_committee(state, attestation.data.target_epoch, attestation.data.shard):
assert crosslink_deltas[0][index] > 0
assert crosslink_deltas[1][index] == 0
@ -129,7 +129,7 @@ def test_double_late_crosslink(state):
# ensure that the current crosslinks were not updated by the second attestation
assert post_state.previous_crosslinks[shard] == post_state.current_crosslinks[shard]
# ensure no reward, only penalties for the failed crosslink
for index in get_crosslink_committee_for_attestation(state, attestation_2.data):
for index in get_crosslink_committee(state, attestation_2.data.target_epoch, attestation_2.data.shard):
assert crosslink_deltas[0][index] == 0
assert crosslink_deltas[1][index] > 0

View File

@ -29,7 +29,7 @@ from eth2spec.phase0.spec import (
get_attesting_indices,
get_block_root,
get_block_root_at_slot,
get_crosslink_committees_at_slot,
get_crosslink_committee,
get_current_epoch,
get_domain,
get_epoch_start_slot,
@ -174,11 +174,11 @@ def build_attestation_data(state, slot, shard):
crosslinks = state.current_crosslinks if slot_to_epoch(slot) == get_current_epoch(state) else state.previous_crosslinks
return AttestationData(
slot=slot,
shard=shard,
beacon_block_root=block_root,
source_epoch=justified_epoch,
source_root=justified_block_root,
target_epoch=slot_to_epoch(slot),
target_root=epoch_boundary_root,
crosslink_data_root=spec.ZERO_HASH,
previous_crosslink_root=hash_tree_root(crosslinks[shard]),
@ -276,14 +276,6 @@ def get_valid_attester_slashing(state):
)
def get_crosslink_committee_for_attestation(state, attestation_data):
"""
Return the crosslink committee corresponding to ``attestation_data``.
"""
crosslink_committees = get_crosslink_committees_at_slot(state, attestation_data.slot)
return [committee for committee, shard in crosslink_committees if shard == attestation_data.shard][0]
def get_valid_attestation(state, slot=None):
if slot is None:
slot = state.slot
@ -296,7 +288,7 @@ def get_valid_attestation(state, slot=None):
attestation_data = build_attestation_data(state, slot, shard)
crosslink_committee = get_crosslink_committee_for_attestation(state, attestation_data)
crosslink_committee = get_crosslink_committee(state, attestation_data.target_epoch, attestation_data.shard)
committee_size = len(crosslink_committee)
bitfield_length = (committee_size + 7) // 8
@ -383,13 +375,13 @@ def get_attestation_signature(state, attestation_data, privkey, custody_bit=0b0)
domain=get_domain(
state=state,
domain_type=spec.DOMAIN_ATTESTATION,
message_epoch=slot_to_epoch(attestation_data.slot),
message_epoch=attestation_data.target_epoch,
)
)
def fill_aggregate_attestation(state, attestation):
crosslink_committee = get_crosslink_committee_for_attestation(state, attestation.data)
crosslink_committee = get_crosslink_committee(state, attestation.data.target_epoch, attestation.data.shard)
for i in range(len(crosslink_committee)):
attestation.aggregation_bitfield = set_bitfield_bit(attestation.aggregation_bitfield, i)