add some attestation tests. fix genesi crosslink bug

This commit is contained in:
Danny Ryan 2019-03-26 11:27:07 -06:00
parent cdbba3e181
commit a8410b8b84
No known key found for this signature in database
GPG Key ID: 2765A792E42CE07A
4 changed files with 124 additions and 54 deletions

View File

@ -905,21 +905,20 @@ def get_crosslink_committees_at_slot(state: BeaconState,
next_epoch = current_epoch + 1 next_epoch = current_epoch + 1
assert previous_epoch <= epoch <= next_epoch assert previous_epoch <= epoch <= next_epoch
active_validator_indices = get_active_validator_indices( indices = get_active_validator_indices(
state.validator_registry, state.validator_registry,
epoch, epoch,
) )
committees_per_epoch = get_epoch_committee_count(len(active_validator_indices)) committees_per_epoch = get_epoch_committee_count(len(indices))
if epoch == current_epoch: if epoch == current_epoch:
start_shard = state.latest_start_shard start_shard = state.latest_start_shard
elif epoch == previous_epoch: elif epoch == previous_epoch:
start_shard = (state.latest_start_shard - SLOTS_PER_EPOCH * committees_per_epoch) % SHARD_COUNT start_shard = (state.latest_start_shard - committees_per_epoch) % SHARD_COUNT
elif epoch == next_epoch: elif epoch == next_epoch:
current_epoch_committees = get_current_epoch_committee_count(state) current_epoch_committees = get_current_epoch_committee_count(state)
start_shard = (state.latest_start_shard + EPOCH_LENGTH * current_epoch_committees) % SHARD_COUNT start_shard = (state.latest_start_shard + current_epoch_committees) % SHARD_COUNT
indices = get_active_validator_indices(state.validator_registry, epoch)
committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH
offset = slot % SLOTS_PER_EPOCH offset = slot % SLOTS_PER_EPOCH
slot_start_shard = (start_shard + committees_per_slot * offset) % SHARD_COUNT slot_start_shard = (start_shard + committees_per_slot * offset) % SHARD_COUNT
@ -1830,7 +1829,7 @@ Run the following function:
```python ```python
def process_crosslinks(state: BeaconState) -> None: def process_crosslinks(state: BeaconState) -> None:
current_epoch = get_current_epoch(state) current_epoch = get_current_epoch(state)
previous_epoch = current_epoch - 1 previous_epoch = max(current_epoch - 1, GENESIS_EPOCH)
next_epoch = current_epoch + 1 next_epoch = current_epoch + 1
for slot in range(get_epoch_start_slot(previous_epoch), get_epoch_start_slot(next_epoch)): for slot in range(get_epoch_start_slot(previous_epoch), get_epoch_start_slot(next_epoch)):
for crosslink_committee, shard in get_crosslink_committees_at_slot(state, slot): for crosslink_committee, shard in get_crosslink_committees_at_slot(state, slot):

View File

@ -0,0 +1,67 @@
from copy import deepcopy
import pytest
import build.phase0.spec as spec
from build.phase0.state_transition import (
state_transition,
)
from build.phase0.spec import (
ZERO_HASH,
get_current_epoch,
process_attestation,
slot_to_epoch,
)
from tests.phase0.helpers import (
build_empty_block_for_next_slot,
get_valid_attestation,
)
# mark entire file as 'attestations'
pytestmark = pytest.mark.attestations
def run_attestation_processing(state, attestation, valid=True):
"""
Run ``process_attestation`` returning the pre and post state.
If ``valid == False``, run expecting ``AssertionError``
"""
post_state = deepcopy(state)
if not valid:
with pytest.raises(AssertionError):
process_attestation(post_state, attestation)
return state, None
process_attestation(post_state, attestation)
current_epoch = get_current_epoch(state)
target_epoch = slot_to_epoch(attestation.data.slot)
if target_epoch == current_epoch:
assert len(post_state.current_epoch_attestations) == len(state.current_epoch_attestations) + 1
else:
assert len(post_state.previous_epoch_attestations) == len(state.previous_epoch_attestations) + 1
return state, post_state
def test_success(state):
attestation = get_valid_attestation(state)
state.slot += spec.MIN_ATTESTATION_INCLUSION_DELAY
pre_state, post_state = run_attestation_processing(state, attestation)
return pre_state, attestation, post_state
def test_success_prevous_epoch(state):
attestation = get_valid_attestation(state)
block = build_empty_block_for_next_slot(state)
block.slot = state.slot + spec.SLOTS_PER_EPOCH
state_transition(state, block)
pre_state, post_state = run_attestation_processing(state, attestation)
return pre_state, attestation, post_state

View File

@ -9,7 +9,9 @@ from build.phase0.spec import (
EMPTY_SIGNATURE, EMPTY_SIGNATURE,
ZERO_HASH, ZERO_HASH,
# SSZ # SSZ
Attestation,
AttestationData, AttestationData,
AttestationDataAndCustodyBit,
BeaconBlockHeader, BeaconBlockHeader,
Deposit, Deposit,
DepositData, DepositData,
@ -18,7 +20,9 @@ from build.phase0.spec import (
VoluntaryExit, VoluntaryExit,
# functions # functions
get_active_validator_indices, get_active_validator_indices,
get_attestation_participants,
get_block_root, get_block_root,
get_crosslink_committees_at_slot,
get_current_epoch, get_current_epoch,
get_domain, get_domain,
get_empty_block, get_empty_block,
@ -236,3 +240,49 @@ def get_valid_proposer_slashing(state):
header_1=header_1, header_1=header_1,
header_2=header_2, header_2=header_2,
) )
def get_valid_attestation(state, slot=None):
if slot is None:
slot = state.slot
shard = state.latest_start_shard
attestation_data = build_attestation_data(state, slot, shard)
crosslink_committees = get_crosslink_committees_at_slot(state, slot)
crosslink_committee = [committee for committee, _shard in crosslink_committees if _shard == attestation_data.shard][0]
committee_size = len(crosslink_committee)
bitfield_length = (committee_size + 7) // 8
aggregation_bitfield = b'\x01' + b'\x00' * (bitfield_length - 1)
custody_bitfield = b'\x00' * bitfield_length
attestation = Attestation(
aggregation_bitfield=aggregation_bitfield,
data=attestation_data,
custody_bitfield=custody_bitfield,
aggregate_signature=EMPTY_SIGNATURE,
)
participants = get_attestation_participants(
state,
attestation.data,
attestation.aggregation_bitfield,
)
assert len(participants) == 1
validator_index = participants[0]
privkey = privkeys[validator_index]
message_hash = AttestationDataAndCustodyBit(
data=attestation.data,
custody_bit=0b0,
).hash_tree_root()
attestation.aggregation_signature = bls.sign(
message_hash=message_hash,
privkey=privkey,
domain=get_domain(
fork=state.fork,
epoch=get_current_epoch(state),
domain_type=spec.DOMAIN_ATTESTATION,
)
)
return attestation

View File

@ -11,19 +11,13 @@ from build.phase0.spec import (
EMPTY_SIGNATURE, EMPTY_SIGNATURE,
ZERO_HASH, ZERO_HASH,
# SSZ # SSZ
Attestation,
AttestationDataAndCustodyBit,
BeaconBlockHeader,
Deposit, Deposit,
Transfer, Transfer,
ProposerSlashing,
VoluntaryExit, VoluntaryExit,
# functions # functions
get_active_validator_indices, get_active_validator_indices,
get_attestation_participants,
get_balance, get_balance,
get_block_root, get_block_root,
get_crosslink_committees_at_slot,
get_current_epoch, get_current_epoch,
get_domain, get_domain,
get_state_root, get_state_root,
@ -42,10 +36,10 @@ from build.phase0.utils.merkle_minimal import (
get_merkle_root, get_merkle_root,
) )
from tests.phase0.helpers import ( from tests.phase0.helpers import (
build_attestation_data,
build_deposit_data, build_deposit_data,
build_empty_block_for_next_slot, build_empty_block_for_next_slot,
force_registry_change_at_next_epoch, force_registry_change_at_next_epoch,
get_valid_attestation,
get_valid_proposer_slashing, get_valid_proposer_slashing,
privkeys, privkeys,
pubkeys, pubkeys,
@ -222,47 +216,7 @@ def test_deposit_top_up(state):
def test_attestation(state): def test_attestation(state):
test_state = deepcopy(state) test_state = deepcopy(state)
slot = state.slot attestation = get_valid_attestation(state)
shard = state.latest_start_shard
attestation_data = build_attestation_data(state, slot, shard)
crosslink_committees = get_crosslink_committees_at_slot(state, slot)
crosslink_committee = [committee for committee, _shard in crosslink_committees if _shard == attestation_data.shard][0]
committee_size = len(crosslink_committee)
bitfield_length = (committee_size + 7) // 8
aggregation_bitfield = b'\x01' + b'\x00' * (bitfield_length - 1)
custody_bitfield = b'\x00' * bitfield_length
attestation = Attestation(
aggregation_bitfield=aggregation_bitfield,
data=attestation_data,
custody_bitfield=custody_bitfield,
aggregate_signature=EMPTY_SIGNATURE,
)
participants = get_attestation_participants(
test_state,
attestation.data,
attestation.aggregation_bitfield,
)
assert len(participants) == 1
validator_index = participants[0]
privkey = privkeys[validator_index]
message_hash = AttestationDataAndCustodyBit(
data=attestation.data,
custody_bit=0b0,
).hash_tree_root()
attestation.aggregation_signature = bls.sign(
message_hash=message_hash,
privkey=privkey,
domain=get_domain(
fork=test_state.fork,
epoch=get_current_epoch(test_state),
domain_type=spec.DOMAIN_ATTESTATION,
)
)
# #
# Add to state via block transition # Add to state via block transition