add some attestation tests. fix genesi crosslink bug
This commit is contained in:
parent
cdbba3e181
commit
a8410b8b84
|
@ -905,21 +905,20 @@ def get_crosslink_committees_at_slot(state: BeaconState,
|
||||||
next_epoch = current_epoch + 1
|
next_epoch = current_epoch + 1
|
||||||
|
|
||||||
assert previous_epoch <= epoch <= next_epoch
|
assert previous_epoch <= epoch <= next_epoch
|
||||||
active_validator_indices = get_active_validator_indices(
|
indices = get_active_validator_indices(
|
||||||
state.validator_registry,
|
state.validator_registry,
|
||||||
epoch,
|
epoch,
|
||||||
)
|
)
|
||||||
committees_per_epoch = get_epoch_committee_count(len(active_validator_indices))
|
committees_per_epoch = get_epoch_committee_count(len(indices))
|
||||||
|
|
||||||
if epoch == current_epoch:
|
if epoch == current_epoch:
|
||||||
start_shard = state.latest_start_shard
|
start_shard = state.latest_start_shard
|
||||||
elif epoch == previous_epoch:
|
elif epoch == previous_epoch:
|
||||||
start_shard = (state.latest_start_shard - SLOTS_PER_EPOCH * committees_per_epoch) % SHARD_COUNT
|
start_shard = (state.latest_start_shard - committees_per_epoch) % SHARD_COUNT
|
||||||
elif epoch == next_epoch:
|
elif epoch == next_epoch:
|
||||||
current_epoch_committees = get_current_epoch_committee_count(state)
|
current_epoch_committees = get_current_epoch_committee_count(state)
|
||||||
start_shard = (state.latest_start_shard + EPOCH_LENGTH * current_epoch_committees) % SHARD_COUNT
|
start_shard = (state.latest_start_shard + current_epoch_committees) % SHARD_COUNT
|
||||||
|
|
||||||
indices = get_active_validator_indices(state.validator_registry, epoch)
|
|
||||||
committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH
|
committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH
|
||||||
offset = slot % SLOTS_PER_EPOCH
|
offset = slot % SLOTS_PER_EPOCH
|
||||||
slot_start_shard = (start_shard + committees_per_slot * offset) % SHARD_COUNT
|
slot_start_shard = (start_shard + committees_per_slot * offset) % SHARD_COUNT
|
||||||
|
@ -1830,7 +1829,7 @@ Run the following function:
|
||||||
```python
|
```python
|
||||||
def process_crosslinks(state: BeaconState) -> None:
|
def process_crosslinks(state: BeaconState) -> None:
|
||||||
current_epoch = get_current_epoch(state)
|
current_epoch = get_current_epoch(state)
|
||||||
previous_epoch = current_epoch - 1
|
previous_epoch = max(current_epoch - 1, GENESIS_EPOCH)
|
||||||
next_epoch = current_epoch + 1
|
next_epoch = current_epoch + 1
|
||||||
for slot in range(get_epoch_start_slot(previous_epoch), get_epoch_start_slot(next_epoch)):
|
for slot in range(get_epoch_start_slot(previous_epoch), get_epoch_start_slot(next_epoch)):
|
||||||
for crosslink_committee, shard in get_crosslink_committees_at_slot(state, slot):
|
for crosslink_committee, shard in get_crosslink_committees_at_slot(state, slot):
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
from copy import deepcopy
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import build.phase0.spec as spec
|
||||||
|
|
||||||
|
from build.phase0.state_transition import (
|
||||||
|
state_transition,
|
||||||
|
)
|
||||||
|
from build.phase0.spec import (
|
||||||
|
ZERO_HASH,
|
||||||
|
get_current_epoch,
|
||||||
|
process_attestation,
|
||||||
|
slot_to_epoch,
|
||||||
|
)
|
||||||
|
from tests.phase0.helpers import (
|
||||||
|
build_empty_block_for_next_slot,
|
||||||
|
get_valid_attestation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# mark entire file as 'attestations'
|
||||||
|
pytestmark = pytest.mark.attestations
|
||||||
|
|
||||||
|
|
||||||
|
def run_attestation_processing(state, attestation, valid=True):
|
||||||
|
"""
|
||||||
|
Run ``process_attestation`` returning the pre and post state.
|
||||||
|
If ``valid == False``, run expecting ``AssertionError``
|
||||||
|
"""
|
||||||
|
post_state = deepcopy(state)
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
process_attestation(post_state, attestation)
|
||||||
|
return state, None
|
||||||
|
|
||||||
|
process_attestation(post_state, attestation)
|
||||||
|
|
||||||
|
current_epoch = get_current_epoch(state)
|
||||||
|
target_epoch = slot_to_epoch(attestation.data.slot)
|
||||||
|
if target_epoch == current_epoch:
|
||||||
|
assert len(post_state.current_epoch_attestations) == len(state.current_epoch_attestations) + 1
|
||||||
|
else:
|
||||||
|
assert len(post_state.previous_epoch_attestations) == len(state.previous_epoch_attestations) + 1
|
||||||
|
|
||||||
|
|
||||||
|
return state, post_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_success(state):
|
||||||
|
attestation = get_valid_attestation(state)
|
||||||
|
state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY
|
||||||
|
|
||||||
|
pre_state, post_state = run_attestation_processing(state, attestation)
|
||||||
|
|
||||||
|
return pre_state, attestation, post_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_success_prevous_epoch(state):
|
||||||
|
attestation = get_valid_attestation(state)
|
||||||
|
block = build_empty_block_for_next_slot(state)
|
||||||
|
block.slot = state.slot + spec.SLOTS_PER_EPOCH
|
||||||
|
state_transition(state, block)
|
||||||
|
|
||||||
|
pre_state, post_state = run_attestation_processing(state, attestation)
|
||||||
|
|
||||||
|
return pre_state, attestation, post_state
|
|
@ -9,7 +9,9 @@ from build.phase0.spec import (
|
||||||
EMPTY_SIGNATURE,
|
EMPTY_SIGNATURE,
|
||||||
ZERO_HASH,
|
ZERO_HASH,
|
||||||
# SSZ
|
# SSZ
|
||||||
|
Attestation,
|
||||||
AttestationData,
|
AttestationData,
|
||||||
|
AttestationDataAndCustodyBit,
|
||||||
BeaconBlockHeader,
|
BeaconBlockHeader,
|
||||||
Deposit,
|
Deposit,
|
||||||
DepositData,
|
DepositData,
|
||||||
|
@ -18,7 +20,9 @@ from build.phase0.spec import (
|
||||||
VoluntaryExit,
|
VoluntaryExit,
|
||||||
# functions
|
# functions
|
||||||
get_active_validator_indices,
|
get_active_validator_indices,
|
||||||
|
get_attestation_participants,
|
||||||
get_block_root,
|
get_block_root,
|
||||||
|
get_crosslink_committees_at_slot,
|
||||||
get_current_epoch,
|
get_current_epoch,
|
||||||
get_domain,
|
get_domain,
|
||||||
get_empty_block,
|
get_empty_block,
|
||||||
|
@ -236,3 +240,49 @@ def get_valid_proposer_slashing(state):
|
||||||
header_1=header_1,
|
header_1=header_1,
|
||||||
header_2=header_2,
|
header_2=header_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_valid_attestation(state, slot=None):
|
||||||
|
if slot is None:
|
||||||
|
slot = state.slot
|
||||||
|
shard = state.latest_start_shard
|
||||||
|
attestation_data = build_attestation_data(state, slot, shard)
|
||||||
|
|
||||||
|
crosslink_committees = get_crosslink_committees_at_slot(state, slot)
|
||||||
|
crosslink_committee = [committee for committee, _shard in crosslink_committees if _shard == attestation_data.shard][0]
|
||||||
|
|
||||||
|
committee_size = len(crosslink_committee)
|
||||||
|
bitfield_length = (committee_size + 7) // 8
|
||||||
|
aggregation_bitfield = b'\x01' + b'\x00' * (bitfield_length - 1)
|
||||||
|
custody_bitfield = b'\x00' * bitfield_length
|
||||||
|
attestation = Attestation(
|
||||||
|
aggregation_bitfield=aggregation_bitfield,
|
||||||
|
data=attestation_data,
|
||||||
|
custody_bitfield=custody_bitfield,
|
||||||
|
aggregate_signature=EMPTY_SIGNATURE,
|
||||||
|
)
|
||||||
|
participants = get_attestation_participants(
|
||||||
|
state,
|
||||||
|
attestation.data,
|
||||||
|
attestation.aggregation_bitfield,
|
||||||
|
)
|
||||||
|
assert len(participants) == 1
|
||||||
|
|
||||||
|
validator_index = participants[0]
|
||||||
|
privkey = privkeys[validator_index]
|
||||||
|
|
||||||
|
message_hash = AttestationDataAndCustodyBit(
|
||||||
|
data=attestation.data,
|
||||||
|
custody_bit=0b0,
|
||||||
|
).hash_tree_root()
|
||||||
|
|
||||||
|
attestation.aggregation_signature = bls.sign(
|
||||||
|
message_hash=message_hash,
|
||||||
|
privkey=privkey,
|
||||||
|
domain=get_domain(
|
||||||
|
fork=state.fork,
|
||||||
|
epoch=get_current_epoch(state),
|
||||||
|
domain_type=spec.DOMAIN_ATTESTATION,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return attestation
|
||||||
|
|
|
@ -11,19 +11,13 @@ from build.phase0.spec import (
|
||||||
EMPTY_SIGNATURE,
|
EMPTY_SIGNATURE,
|
||||||
ZERO_HASH,
|
ZERO_HASH,
|
||||||
# SSZ
|
# SSZ
|
||||||
Attestation,
|
|
||||||
AttestationDataAndCustodyBit,
|
|
||||||
BeaconBlockHeader,
|
|
||||||
Deposit,
|
Deposit,
|
||||||
Transfer,
|
Transfer,
|
||||||
ProposerSlashing,
|
|
||||||
VoluntaryExit,
|
VoluntaryExit,
|
||||||
# functions
|
# functions
|
||||||
get_active_validator_indices,
|
get_active_validator_indices,
|
||||||
get_attestation_participants,
|
|
||||||
get_balance,
|
get_balance,
|
||||||
get_block_root,
|
get_block_root,
|
||||||
get_crosslink_committees_at_slot,
|
|
||||||
get_current_epoch,
|
get_current_epoch,
|
||||||
get_domain,
|
get_domain,
|
||||||
get_state_root,
|
get_state_root,
|
||||||
|
@ -42,10 +36,10 @@ from build.phase0.utils.merkle_minimal import (
|
||||||
get_merkle_root,
|
get_merkle_root,
|
||||||
)
|
)
|
||||||
from tests.phase0.helpers import (
|
from tests.phase0.helpers import (
|
||||||
build_attestation_data,
|
|
||||||
build_deposit_data,
|
build_deposit_data,
|
||||||
build_empty_block_for_next_slot,
|
build_empty_block_for_next_slot,
|
||||||
force_registry_change_at_next_epoch,
|
force_registry_change_at_next_epoch,
|
||||||
|
get_valid_attestation,
|
||||||
get_valid_proposer_slashing,
|
get_valid_proposer_slashing,
|
||||||
privkeys,
|
privkeys,
|
||||||
pubkeys,
|
pubkeys,
|
||||||
|
@ -222,47 +216,7 @@ def test_deposit_top_up(state):
|
||||||
|
|
||||||
def test_attestation(state):
|
def test_attestation(state):
|
||||||
test_state = deepcopy(state)
|
test_state = deepcopy(state)
|
||||||
slot = state.slot
|
attestation = get_valid_attestation(state)
|
||||||
shard = state.latest_start_shard
|
|
||||||
attestation_data = build_attestation_data(state, slot, shard)
|
|
||||||
|
|
||||||
crosslink_committees = get_crosslink_committees_at_slot(state, slot)
|
|
||||||
crosslink_committee = [committee for committee, _shard in crosslink_committees if _shard == attestation_data.shard][0]
|
|
||||||
|
|
||||||
committee_size = len(crosslink_committee)
|
|
||||||
bitfield_length = (committee_size + 7) // 8
|
|
||||||
aggregation_bitfield = b'\x01' + b'\x00' * (bitfield_length - 1)
|
|
||||||
custody_bitfield = b'\x00' * bitfield_length
|
|
||||||
attestation = Attestation(
|
|
||||||
aggregation_bitfield=aggregation_bitfield,
|
|
||||||
data=attestation_data,
|
|
||||||
custody_bitfield=custody_bitfield,
|
|
||||||
aggregate_signature=EMPTY_SIGNATURE,
|
|
||||||
)
|
|
||||||
participants = get_attestation_participants(
|
|
||||||
test_state,
|
|
||||||
attestation.data,
|
|
||||||
attestation.aggregation_bitfield,
|
|
||||||
)
|
|
||||||
assert len(participants) == 1
|
|
||||||
|
|
||||||
validator_index = participants[0]
|
|
||||||
privkey = privkeys[validator_index]
|
|
||||||
|
|
||||||
message_hash = AttestationDataAndCustodyBit(
|
|
||||||
data=attestation.data,
|
|
||||||
custody_bit=0b0,
|
|
||||||
).hash_tree_root()
|
|
||||||
|
|
||||||
attestation.aggregation_signature = bls.sign(
|
|
||||||
message_hash=message_hash,
|
|
||||||
privkey=privkey,
|
|
||||||
domain=get_domain(
|
|
||||||
fork=test_state.fork,
|
|
||||||
epoch=get_current_epoch(test_state),
|
|
||||||
domain_type=spec.DOMAIN_ATTESTATION,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Add to state via block transition
|
# Add to state via block transition
|
||||||
|
|
Loading…
Reference in New Issue