fix up some PR feedback and testing for #1009
This commit is contained in:
parent
a40f37b9a2
commit
b3373a2d71
|
@ -46,23 +46,23 @@ Store = None
|
||||||
|
|
||||||
code_lines.append("""
|
code_lines.append("""
|
||||||
# Monkey patch validator get committee code
|
# Monkey patch validator get committee code
|
||||||
_compute_committee = compute_committee
|
_get_crosslink_committee = get_crosslink_committee
|
||||||
committee_cache = {}
|
committee_cache = {}
|
||||||
|
|
||||||
|
|
||||||
def compute_committee(validator_indices: List[ValidatorIndex],
|
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> List[ValidatorIndex]:
|
||||||
seed: Bytes32,
|
active_indices = get_active_validator_indices(state, epoch)
|
||||||
index: int,
|
seed = generate_seed(state, epoch)
|
||||||
total_committees: int) -> List[ValidatorIndex]:
|
committee_count = get_epoch_committee_count(state, epoch)
|
||||||
|
committee_index = (shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT
|
||||||
param_hash = (hash_tree_root(validator_indices), seed, index, total_committees)
|
param_hash = (hash_tree_root(active_indices), seed, committee_count, committee_index)
|
||||||
|
|
||||||
if param_hash in committee_cache:
|
if param_hash in committee_cache:
|
||||||
# print("Cache hit, epoch={0}".format(epoch))
|
# print("Cache hit, epoch={0}".format(epoch))
|
||||||
return committee_cache[param_hash]
|
return committee_cache[param_hash]
|
||||||
else:
|
else:
|
||||||
# print("Cache miss, epoch={0}".format(epoch))
|
# print("Cache miss, epoch={0}".format(epoch))
|
||||||
ret = _compute_committee(validator_indices, seed, index, total_committees)
|
ret = _get_crosslink_committee(state, epoch, shard)
|
||||||
committee_cache[param_hash] = ret
|
committee_cache[param_hash] = ret
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
|
@ -66,6 +66,7 @@
|
||||||
- [`get_attestation_slot`](#get_attestation_slot)
|
- [`get_attestation_slot`](#get_attestation_slot)
|
||||||
- [`get_block_root_at_slot`](#get_block_root_at_slot)
|
- [`get_block_root_at_slot`](#get_block_root_at_slot)
|
||||||
- [`get_block_root`](#get_block_root)
|
- [`get_block_root`](#get_block_root)
|
||||||
|
- [`get_state_root`](#get_state_root)
|
||||||
- [`get_randao_mix`](#get_randao_mix)
|
- [`get_randao_mix`](#get_randao_mix)
|
||||||
- [`get_active_index_root`](#get_active_index_root)
|
- [`get_active_index_root`](#get_active_index_root)
|
||||||
- [`generate_seed`](#generate_seed)
|
- [`generate_seed`](#generate_seed)
|
||||||
|
@ -82,6 +83,8 @@
|
||||||
- [`verify_bitfield`](#verify_bitfield)
|
- [`verify_bitfield`](#verify_bitfield)
|
||||||
- [`convert_to_indexed`](#convert_to_indexed)
|
- [`convert_to_indexed`](#convert_to_indexed)
|
||||||
- [`verify_indexed_attestation`](#verify_indexed_attestation)
|
- [`verify_indexed_attestation`](#verify_indexed_attestation)
|
||||||
|
- [`is_double_vote`](#is_double_vote)
|
||||||
|
- [`is_surround_vote`](#is_surround_vote)
|
||||||
- [`integer_squareroot`](#integer_squareroot)
|
- [`integer_squareroot`](#integer_squareroot)
|
||||||
- [`get_delayed_activation_exit_epoch`](#get_delayed_activation_exit_epoch)
|
- [`get_delayed_activation_exit_epoch`](#get_delayed_activation_exit_epoch)
|
||||||
- [`get_churn_limit`](#get_churn_limit)
|
- [`get_churn_limit`](#get_churn_limit)
|
||||||
|
@ -763,7 +766,7 @@ def get_epoch_start_shard(state: BeaconState, epoch: Epoch) -> Shard:
|
||||||
def get_attestation_slot(state: BeaconState, attestation: Attestation) -> Slot:
|
def get_attestation_slot(state: BeaconState, attestation: Attestation) -> Slot:
|
||||||
epoch = attestation.data.target_epoch
|
epoch = attestation.data.target_epoch
|
||||||
committee_count = get_epoch_committee_count(state, epoch)
|
committee_count = get_epoch_committee_count(state, epoch)
|
||||||
offset = (attestation.data.shard - get_epoch_start_slot(epoch)) % SHARD_COUNT
|
offset = (attestation.data.shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT
|
||||||
return get_epoch_start_slot(epoch) + offset // (committee_count // SLOTS_PER_EPOCH)
|
return get_epoch_start_slot(epoch) + offset // (committee_count // SLOTS_PER_EPOCH)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -790,6 +793,18 @@ def get_block_root(state: BeaconState,
|
||||||
return get_block_root_at_slot(state, get_epoch_start_slot(epoch))
|
return get_block_root_at_slot(state, get_epoch_start_slot(epoch))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `get_state_root`
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_state_root(state: BeaconState,
|
||||||
|
slot: Slot) -> Bytes32:
|
||||||
|
"""
|
||||||
|
Return the state root at a recent ``slot``.
|
||||||
|
"""
|
||||||
|
assert slot < state.slot <= slot + SLOTS_PER_HISTORICAL_ROOT
|
||||||
|
return state.latest_state_roots[slot % SLOTS_PER_HISTORICAL_ROOT]
|
||||||
|
```
|
||||||
|
|
||||||
### `get_randao_mix`
|
### `get_randao_mix`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -838,8 +853,9 @@ def get_beacon_proposer_index(state: BeaconState) -> ValidatorIndex:
|
||||||
"""
|
"""
|
||||||
epoch = get_current_epoch(state)
|
epoch = get_current_epoch(state)
|
||||||
committees_per_slot = get_epoch_committee_count(state, epoch) // SLOTS_PER_EPOCH
|
committees_per_slot = get_epoch_committee_count(state, epoch) // SLOTS_PER_EPOCH
|
||||||
offset = committees_per_slot * (state.slot % EPOCH_LENGTH)
|
offset = committees_per_slot * (state.slot % SLOTS_PER_EPOCH)
|
||||||
first_committee = get_crosslink_committee(state, epoch, offset)
|
shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
|
||||||
|
first_committee = get_crosslink_committee(state, epoch, shard)
|
||||||
MAX_RANDOM_BYTE = 2**8 - 1
|
MAX_RANDOM_BYTE = 2**8 - 1
|
||||||
i = 0
|
i = 0
|
||||||
while True:
|
while True:
|
||||||
|
@ -905,8 +921,9 @@ def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> L
|
||||||
|
|
||||||
start_validator_index = (len(active_indices) * committee_index) // committee_count
|
start_validator_index = (len(active_indices) * committee_index) // committee_count
|
||||||
end_validator_index = (len(active_indices) * (committee_index + 1)) // committee_count
|
end_validator_index = (len(active_indices) * (committee_index + 1)) // committee_count
|
||||||
|
seed = generate_seed(state, epoch)
|
||||||
return [
|
return [
|
||||||
active_indices[get_shuffled_index(i, len(active_indices), generate_seed(state, epoch))]
|
active_indices[get_shuffled_index(i, len(active_indices), seed)]
|
||||||
for i in range(start_validator_index, end_validator_index)
|
for i in range(start_validator_index, end_validator_index)
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
@ -1046,6 +1063,29 @@ def verify_indexed_attestation(state: BeaconState, indexed_attestation: IndexedA
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `is_double_vote`
|
||||||
|
|
||||||
|
```python
|
||||||
|
def is_double_vote(attestation_data_1: AttestationData, attestation_data_2: AttestationData) -> bool:
|
||||||
|
"""
|
||||||
|
Check if ``attestation_data_1`` and ``attestation_data_2`` violate Casper "double" slashing rule.
|
||||||
|
"""
|
||||||
|
return attestation_data_1.target_epoch == attestation_data_2.target_epoch
|
||||||
|
```
|
||||||
|
|
||||||
|
### `is_surround_vote`
|
||||||
|
|
||||||
|
```python
|
||||||
|
def is_surround_vote(attestation_data_1: AttestationData, attestation_data_2: AttestationData) -> bool:
|
||||||
|
"""
|
||||||
|
Check if ``attestation_data_1`` and ``attestation_data_2`` violate Casper "surround" slashing rule.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
attestation_data_1.source_epoch < attestation_data_2.source_epoch and
|
||||||
|
attestation_data_2.target_epoch < attestation_data_1.target_epoch
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
### `integer_squareroot`
|
### `integer_squareroot`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -1359,7 +1399,7 @@ def process_crosslinks(state: BeaconState) -> None:
|
||||||
state.previous_crosslinks = [c for c in state.current_crosslinks]
|
state.previous_crosslinks = [c for c in state.current_crosslinks]
|
||||||
for epoch in (get_previous_epoch(state), get_current_epoch(state)):
|
for epoch in (get_previous_epoch(state), get_current_epoch(state)):
|
||||||
for offset in range(get_epoch_committee_count(state, epoch)):
|
for offset in range(get_epoch_committee_count(state, epoch)):
|
||||||
shard = (get_epoch_start_shard(epoch) + offset) % SHARD_COUNT
|
shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
|
||||||
crosslink_committee = get_crosslink_committee(state, epoch, shard)
|
crosslink_committee = get_crosslink_committee(state, epoch, shard)
|
||||||
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
|
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
|
||||||
if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee):
|
if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee):
|
||||||
|
@ -1429,8 +1469,8 @@ def get_crosslink_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
|
||||||
penalties = [0 for index in range(len(state.validator_registry))]
|
penalties = [0 for index in range(len(state.validator_registry))]
|
||||||
epoch = get_previous_epoch(state)
|
epoch = get_previous_epoch(state)
|
||||||
for offset in range(get_epoch_committee_count(state, epoch)):
|
for offset in range(get_epoch_committee_count(state, epoch)):
|
||||||
shard = (get_epoch_start_shard(epoch) + offset) % SHARD_COUNT
|
shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
|
||||||
crosslink_committee = get_crosslink_committee(state, epoch, offset)
|
crosslink_committee = get_crosslink_committee(state, epoch, shard)
|
||||||
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
|
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, shard, epoch)
|
||||||
attesting_balance = get_total_balance(state, attesting_indices)
|
attesting_balance = get_total_balance(state, attesting_indices)
|
||||||
committee_balance = get_total_balance(state, crosslink_committee)
|
committee_balance = get_total_balance(state, crosslink_committee)
|
||||||
|
@ -1653,16 +1693,9 @@ def process_attester_slashing(state: BeaconState,
|
||||||
attestation_2 = attester_slashing.attestation_2
|
attestation_2 = attester_slashing.attestation_2
|
||||||
# Check that the attestations are conflicting
|
# Check that the attestations are conflicting
|
||||||
assert attestation_1.data != attestation_2.data
|
assert attestation_1.data != attestation_2.data
|
||||||
|
|
||||||
source_1 = attestation_1.data.source_epoch
|
|
||||||
target_1 = attestation_1.data.target_epoch
|
|
||||||
source_2 = attestation_2.data.source_epoch
|
|
||||||
target_2 = attestation_2.data.target_epoch
|
|
||||||
assert (
|
assert (
|
||||||
# Double vote
|
is_double_vote(attestation_1.data, attestation_2.data) or
|
||||||
(target_1 == target_2) or
|
is_surround_vote(attestation_1.data, attestation_2.data)
|
||||||
# Surround vote (attestation 1 surrounds attestation 2)
|
|
||||||
(source_1 < source_2 and target_2 < target_1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert verify_indexed_attestation(state, attestation_1)
|
assert verify_indexed_attestation(state, attestation_1)
|
||||||
|
@ -1715,7 +1748,7 @@ def process_attestation(state: BeaconState, attestation: Attestation) -> None:
|
||||||
inclusion_delay=state.slot - attestation_slot,
|
inclusion_delay=state.slot - attestation_slot,
|
||||||
proposer_index=get_beacon_proposer_index(state),
|
proposer_index=get_beacon_proposer_index(state),
|
||||||
)
|
)
|
||||||
if target_epoch == get_current_epoch(state):
|
if data.target_epoch == get_current_epoch(state):
|
||||||
state.current_epoch_attestations.append(pending_attestation)
|
state.current_epoch_attestations.append(pending_attestation)
|
||||||
else:
|
else:
|
||||||
state.previous_epoch_attestations.append(pending_attestation)
|
state.previous_epoch_attestations.append(pending_attestation)
|
||||||
|
|
|
@ -38,8 +38,7 @@ def run_attestation_processing(state, attestation, valid=True):
|
||||||
process_attestation(post_state, attestation)
|
process_attestation(post_state, attestation)
|
||||||
|
|
||||||
current_epoch = get_current_epoch(state)
|
current_epoch = get_current_epoch(state)
|
||||||
target_epoch = slot_to_epoch(attestation.data.slot)
|
if attestation.data.target_epoch == current_epoch:
|
||||||
if target_epoch == current_epoch:
|
|
||||||
assert len(post_state.current_epoch_attestations) == len(state.current_epoch_attestations) + 1
|
assert len(post_state.current_epoch_attestations) == len(state.current_epoch_attestations) + 1
|
||||||
else:
|
else:
|
||||||
assert len(post_state.previous_epoch_attestations) == len(state.previous_epoch_attestations) + 1
|
assert len(post_state.previous_epoch_attestations) == len(state.previous_epoch_attestations) + 1
|
||||||
|
|
|
@ -65,7 +65,7 @@ def test_success_surround(state):
|
||||||
|
|
||||||
# set attestion1 to surround attestation 2
|
# set attestion1 to surround attestation 2
|
||||||
attester_slashing.attestation_1.data.source_epoch = attester_slashing.attestation_2.data.source_epoch - 1
|
attester_slashing.attestation_1.data.source_epoch = attester_slashing.attestation_2.data.source_epoch - 1
|
||||||
attester_slashing.attestation_1.data.slot = attester_slashing.attestation_2.data.slot + spec.SLOTS_PER_EPOCH
|
attester_slashing.attestation_1.data.target_epoch = attester_slashing.attestation_2.data.target_epoch + 1
|
||||||
|
|
||||||
pre_state, post_state = run_attester_slashing_processing(state, attester_slashing)
|
pre_state, post_state = run_attester_slashing_processing(state, attester_slashing)
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ def test_same_data(state):
|
||||||
def test_no_double_or_surround(state):
|
def test_no_double_or_surround(state):
|
||||||
attester_slashing = get_valid_attester_slashing(state)
|
attester_slashing = get_valid_attester_slashing(state)
|
||||||
|
|
||||||
attester_slashing.attestation_1.data.slot += spec.SLOTS_PER_EPOCH
|
attester_slashing.attestation_1.data.target_epoch += 1
|
||||||
|
|
||||||
pre_state, post_state = run_attester_slashing_processing(state, attester_slashing, False)
|
pre_state, post_state = run_attester_slashing_processing(state, attester_slashing, False)
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ from tests.helpers import (
|
||||||
add_attestation_to_state,
|
add_attestation_to_state,
|
||||||
build_empty_block_for_next_slot,
|
build_empty_block_for_next_slot,
|
||||||
fill_aggregate_attestation,
|
fill_aggregate_attestation,
|
||||||
get_crosslink_committee_for_attestation,
|
get_crosslink_committee,
|
||||||
get_valid_attestation,
|
get_valid_attestation,
|
||||||
next_epoch,
|
next_epoch,
|
||||||
next_slot,
|
next_slot,
|
||||||
|
@ -88,7 +88,7 @@ def test_single_crosslink_update_from_previous_epoch(state):
|
||||||
assert post_state.previous_crosslinks[shard] != post_state.current_crosslinks[shard]
|
assert post_state.previous_crosslinks[shard] != post_state.current_crosslinks[shard]
|
||||||
assert pre_state.current_crosslinks[shard] != post_state.current_crosslinks[shard]
|
assert pre_state.current_crosslinks[shard] != post_state.current_crosslinks[shard]
|
||||||
# ensure rewarded
|
# ensure rewarded
|
||||||
for index in get_crosslink_committee_for_attestation(state, attestation.data):
|
for index in get_crosslink_committee(state, attestation.data.target_epoch, attestation.data.shard):
|
||||||
assert crosslink_deltas[0][index] > 0
|
assert crosslink_deltas[0][index] > 0
|
||||||
assert crosslink_deltas[1][index] == 0
|
assert crosslink_deltas[1][index] == 0
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ def test_double_late_crosslink(state):
|
||||||
# ensure that the current crosslinks were not updated by the second attestation
|
# ensure that the current crosslinks were not updated by the second attestation
|
||||||
assert post_state.previous_crosslinks[shard] == post_state.current_crosslinks[shard]
|
assert post_state.previous_crosslinks[shard] == post_state.current_crosslinks[shard]
|
||||||
# ensure no reward, only penalties for the failed crosslink
|
# ensure no reward, only penalties for the failed crosslink
|
||||||
for index in get_crosslink_committee_for_attestation(state, attestation_2.data):
|
for index in get_crosslink_committee(state, attestation_2.data.target_epoch, attestation_2.data.shard):
|
||||||
assert crosslink_deltas[0][index] == 0
|
assert crosslink_deltas[0][index] == 0
|
||||||
assert crosslink_deltas[1][index] > 0
|
assert crosslink_deltas[1][index] > 0
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ from eth2spec.phase0.spec import (
|
||||||
get_attesting_indices,
|
get_attesting_indices,
|
||||||
get_block_root,
|
get_block_root,
|
||||||
get_block_root_at_slot,
|
get_block_root_at_slot,
|
||||||
get_crosslink_committees_at_slot,
|
get_crosslink_committee,
|
||||||
get_current_epoch,
|
get_current_epoch,
|
||||||
get_domain,
|
get_domain,
|
||||||
get_epoch_start_slot,
|
get_epoch_start_slot,
|
||||||
|
@ -174,11 +174,11 @@ def build_attestation_data(state, slot, shard):
|
||||||
|
|
||||||
crosslinks = state.current_crosslinks if slot_to_epoch(slot) == get_current_epoch(state) else state.previous_crosslinks
|
crosslinks = state.current_crosslinks if slot_to_epoch(slot) == get_current_epoch(state) else state.previous_crosslinks
|
||||||
return AttestationData(
|
return AttestationData(
|
||||||
slot=slot,
|
|
||||||
shard=shard,
|
shard=shard,
|
||||||
beacon_block_root=block_root,
|
beacon_block_root=block_root,
|
||||||
source_epoch=justified_epoch,
|
source_epoch=justified_epoch,
|
||||||
source_root=justified_block_root,
|
source_root=justified_block_root,
|
||||||
|
target_epoch=slot_to_epoch(slot),
|
||||||
target_root=epoch_boundary_root,
|
target_root=epoch_boundary_root,
|
||||||
crosslink_data_root=spec.ZERO_HASH,
|
crosslink_data_root=spec.ZERO_HASH,
|
||||||
previous_crosslink_root=hash_tree_root(crosslinks[shard]),
|
previous_crosslink_root=hash_tree_root(crosslinks[shard]),
|
||||||
|
@ -276,14 +276,6 @@ def get_valid_attester_slashing(state):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_crosslink_committee_for_attestation(state, attestation_data):
|
|
||||||
"""
|
|
||||||
Return the crosslink committee corresponding to ``attestation_data``.
|
|
||||||
"""
|
|
||||||
crosslink_committees = get_crosslink_committees_at_slot(state, attestation_data.slot)
|
|
||||||
return [committee for committee, shard in crosslink_committees if shard == attestation_data.shard][0]
|
|
||||||
|
|
||||||
|
|
||||||
def get_valid_attestation(state, slot=None):
|
def get_valid_attestation(state, slot=None):
|
||||||
if slot is None:
|
if slot is None:
|
||||||
slot = state.slot
|
slot = state.slot
|
||||||
|
@ -296,7 +288,7 @@ def get_valid_attestation(state, slot=None):
|
||||||
|
|
||||||
attestation_data = build_attestation_data(state, slot, shard)
|
attestation_data = build_attestation_data(state, slot, shard)
|
||||||
|
|
||||||
crosslink_committee = get_crosslink_committee_for_attestation(state, attestation_data)
|
crosslink_committee = get_crosslink_committee(state, attestation_data.target_epoch, attestation_data.shard)
|
||||||
|
|
||||||
committee_size = len(crosslink_committee)
|
committee_size = len(crosslink_committee)
|
||||||
bitfield_length = (committee_size + 7) // 8
|
bitfield_length = (committee_size + 7) // 8
|
||||||
|
@ -383,13 +375,13 @@ def get_attestation_signature(state, attestation_data, privkey, custody_bit=0b0)
|
||||||
domain=get_domain(
|
domain=get_domain(
|
||||||
state=state,
|
state=state,
|
||||||
domain_type=spec.DOMAIN_ATTESTATION,
|
domain_type=spec.DOMAIN_ATTESTATION,
|
||||||
message_epoch=slot_to_epoch(attestation_data.slot),
|
message_epoch=attestation_data.target_epoch,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def fill_aggregate_attestation(state, attestation):
|
def fill_aggregate_attestation(state, attestation):
|
||||||
crosslink_committee = get_crosslink_committee_for_attestation(state, attestation.data)
|
crosslink_committee = get_crosslink_committee(state, attestation.data.target_epoch, attestation.data.shard)
|
||||||
for i in range(len(crosslink_committee)):
|
for i in range(len(crosslink_committee)):
|
||||||
attestation.aggregation_bitfield = set_bitfield_bit(attestation.aggregation_bitfield, i)
|
attestation.aggregation_bitfield = set_bitfield_bit(attestation.aggregation_bitfield, i)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue