From 70bd73d2b5f4b8c3f19a9397fd983c23fd1eecc6 Mon Sep 17 00:00:00 2001
From: Hsiao-Wei Wang <hwwang156@gmail.com>
Date: Mon, 27 Apr 2020 20:47:36 +0800
Subject: [PATCH] Apply PR feedback from @djrtwo

Fix get_eth1_vote test cases
---
 .../test/validator/test_validator_unittest.py | 164 ++++++++++++++----
 1 file changed, 131 insertions(+), 33 deletions(-)

diff --git a/tests/core/pyspec/eth2spec/test/validator/test_validator_unittest.py b/tests/core/pyspec/eth2spec/test/validator/test_validator_unittest.py
index 4f6697f2f..a655cb486 100644
--- a/tests/core/pyspec/eth2spec/test/validator/test_validator_unittest.py
+++ b/tests/core/pyspec/eth2spec/test/validator/test_validator_unittest.py
@@ -8,12 +8,29 @@ from eth2spec.utils import bls
 from eth2spec.utils.ssz.ssz_typing import Bitlist
 
 
-def run_is_candidate_block(spec, eth1_block, period_start, success):
-    result = spec.is_candidate_block(eth1_block, period_start)
-    if success:
-        assert result
+def run_get_signature_test(spec, state, obj, domain, get_signature_fn, privkey, pubkey):
+    signature = get_signature_fn(state, obj, privkey)
+    signing_root = spec.compute_signing_root(obj, domain)
+    assert bls.Verify(pubkey, signing_root, signature)
+
+
+def run_get_committee_assignment(spec, state, epoch, validator_index, valid=True):
+    try:
+        assignment = spec.get_committee_assignment(state, epoch, validator_index)
+        committee, committee_index, slot = assignment
+        assert spec.compute_epoch_at_slot(slot) == epoch
+        assert committee == spec.get_beacon_committee(state, slot, committee_index)
+        assert committee_index < spec.get_committee_count_at_slot(state, slot)
+        assert validator_index in committee
+        assert valid
+    except AssertionError:
+        assert not valid
     else:
-        assert not result
+        assert valid
+
+
+def run_is_candidate_block(spec, eth1_block, period_start, success=True):
+    assert success == spec.is_candidate_block(eth1_block, period_start)
 
 
 def get_min_new_period_epochs(spec):
@@ -60,11 +77,25 @@ def test_check_if_validator_active(spec, state):
 def test_get_committee_assignment_current_epoch(spec, state):
     epoch = spec.get_current_epoch(state)
     validator_index = len(state.validators) - 1
-    assignment = spec.get_committee_assignment(state, epoch, validator_index)
-    committee, committee_index, slot = assignment
-    assert spec.compute_epoch_at_slot(slot) == epoch
-    assert committee == spec.get_beacon_committee(state, slot, committee_index)
-    assert committee_index < spec.get_committee_count_at_slot(state, slot)
+    run_get_committee_assignment(spec, state, epoch, validator_index, valid=True)
+
+
+@with_all_phases
+@spec_state_test
+@never_bls
+def test_get_committee_assignment_next_epoch(spec, state):
+    epoch = spec.get_current_epoch(state) + 1
+    validator_index = len(state.validators) - 1
+    run_get_committee_assignment(spec, state, epoch, validator_index, valid=True)
+
+
+@with_all_phases
+@spec_state_test
+@never_bls
+def test_get_committee_assignment_out_bound_epoch(spec, state):
+    epoch = spec.get_current_epoch(state) + 2
+    validator_index = len(state.validators) - 1
+    run_get_committee_assignment(spec, state, epoch, validator_index, valid=False)
 
 
 @with_all_phases
@@ -92,10 +123,16 @@ def test_get_epoch_signature(spec, state):
     block = spec.BeaconBlock()
     privkey = privkeys[0]
     pubkey = pubkeys[0]
-    signature = spec.get_epoch_signature(state, block, privkey)
     domain = spec.get_domain(state, spec.DOMAIN_RANDAO, spec.compute_epoch_at_slot(block.slot))
-    signing_root = spec.compute_signing_root(spec.compute_epoch_at_slot(block.slot), domain)
-    assert bls.Verify(pubkey, signing_root, signature)
+    run_get_signature_test(
+        spec=spec,
+        state=state,
+        obj=block,
+        domain=domain,
+        get_signature_fn=spec.get_epoch_signature,
+        privkey=privkey,
+        pubkey=pubkey,
+    )
 
 
 @with_all_phases
@@ -130,7 +167,7 @@ def test_is_candidate_block(spec, state):
 
 @with_all_phases
 @spec_state_test
-def test_get_eth1_data_default_vote(spec, state):
+def test_get_eth1_vote_default_vote(spec, state):
     min_new_period_epochs = get_min_new_period_epochs(spec)
     for _ in range(min_new_period_epochs):
         next_epoch(spec, state)
@@ -143,24 +180,61 @@ def test_get_eth1_data_default_vote(spec, state):
 
 @with_all_phases
 @spec_state_test
-def test_get_eth1_data_consensus_vote(spec, state):
+def test_get_eth1_vote_consensus_vote(spec, state):
     min_new_period_epochs = get_min_new_period_epochs(spec)
-    for _ in range(min_new_period_epochs):
+    for _ in range(min_new_period_epochs + 2):
         next_epoch(spec, state)
 
     period_start = spec.voting_period_start_time(state)
     votes_length = spec.get_current_epoch(state) % spec.EPOCHS_PER_ETH1_VOTING_PERIOD
+    assert votes_length >= 3  # We need to have the majority vote
     state.eth1_data_votes = ()
-    eth1_chain = []
+
+    block_1 = spec.Eth1Block(timestamp=period_start - spec.SECONDS_PER_ETH1_BLOCK * spec.ETH1_FOLLOW_DISTANCE - 1)
+    block_2 = spec.Eth1Block(timestamp=period_start - spec.SECONDS_PER_ETH1_BLOCK * spec.ETH1_FOLLOW_DISTANCE)
+    eth1_chain = [block_1, block_2]
     eth1_data_votes = []
-    block = spec.Eth1Block(timestamp=period_start - spec.SECONDS_PER_ETH1_BLOCK * spec.ETH1_FOLLOW_DISTANCE)
+
+    # Only the first vote is for block_1
+    eth1_data_votes.append(spec.get_eth1_data(block_1))
+    # Other votes are for block_2
+    for _ in range(votes_length - 1):
+        eth1_data_votes.append(spec.get_eth1_data(block_2))
+
+    state.eth1_data_votes = eth1_data_votes
+    eth1_data = spec.get_eth1_vote(state, eth1_chain)
+    assert eth1_data.block_hash == block_2.hash_tree_root()
+
+
+@with_all_phases
+@spec_state_test
+def test_get_eth1_vote_tie(spec, state):
+    min_new_period_epochs = get_min_new_period_epochs(spec)
+    for _ in range(min_new_period_epochs + 1):
+        next_epoch(spec, state)
+
+    period_start = spec.voting_period_start_time(state)
+    votes_length = spec.get_current_epoch(state) % spec.EPOCHS_PER_ETH1_VOTING_PERIOD
+    assert votes_length > 0 and votes_length % 2 == 0
+
+    state.eth1_data_votes = ()
+    block_1 = spec.Eth1Block(timestamp=period_start - spec.SECONDS_PER_ETH1_BLOCK * spec.ETH1_FOLLOW_DISTANCE - 1)
+    block_2 = spec.Eth1Block(timestamp=period_start - spec.SECONDS_PER_ETH1_BLOCK * spec.ETH1_FOLLOW_DISTANCE)
+    eth1_chain = [block_1, block_2]
+    eth1_data_votes = []
+    # Half votes are for block_1, another half votes are for block_2
     for i in range(votes_length):
-        eth1_chain.append(block)
+        if i % 2 == 0:
+            block = block_1
+        else:
+            block = block_2
         eth1_data_votes.append(spec.get_eth1_data(block))
 
     state.eth1_data_votes = eth1_data_votes
     eth1_data = spec.get_eth1_vote(state, eth1_chain)
-    assert eth1_data.block_hash == block.hash_tree_root()
+
+    # Tiebreak by smallest distance -> eth1_chain[0]
+    assert eth1_data.block_hash == eth1_chain[0].hash_tree_root()
 
 
 @with_all_phases
@@ -185,10 +259,16 @@ def test_get_block_signature(spec, state):
     privkey = privkeys[0]
     pubkey = pubkeys[0]
     block = build_empty_block(spec, state)
-    signature = spec.get_block_signature(state, block, privkey)
     domain = spec.get_domain(state, spec.DOMAIN_BEACON_PROPOSER, spec.compute_epoch_at_slot(block.slot))
-    signing_root = spec.compute_signing_root(block, domain)
-    assert bls.Verify(pubkey, signing_root, signature)
+    run_get_signature_test(
+        spec=spec,
+        state=state,
+        obj=block,
+        domain=domain,
+        get_signature_fn=spec.get_block_signature,
+        privkey=privkey,
+        pubkey=pubkey,
+    )
 
 
 # Attesting
@@ -200,10 +280,16 @@ def test_get_attestation_signature(spec, state):
     privkey = privkeys[0]
     pubkey = pubkeys[0]
     attestation_data = spec.AttestationData(slot=10)
-    signature = spec.get_attestation_signature(state, attestation_data, privkey)
     domain = spec.get_domain(state, spec.DOMAIN_BEACON_ATTESTER, attestation_data.target.epoch)
-    signing_root = spec.compute_signing_root(attestation_data, domain)
-    assert bls.Verify(pubkey, signing_root, signature)
+    run_get_signature_test(
+        spec=spec,
+        state=state,
+        obj=attestation_data,
+        domain=domain,
+        get_signature_fn=spec.get_attestation_signature,
+        privkey=privkey,
+        pubkey=pubkey,
+    )
 
 
 # Attestation aggregation
@@ -214,11 +300,17 @@ def test_get_attestation_signature(spec, state):
 def test_get_slot_signature(spec, state):
     privkey = privkeys[0]
     pubkey = pubkeys[0]
-    slot = 10
-    signature = spec.get_slot_signature(state, spec.Slot(slot), privkey)
+    slot = spec.Slot(10)
     domain = spec.get_domain(state, spec.DOMAIN_SELECTION_PROOF, spec.compute_epoch_at_slot(slot))
-    signing_root = spec.compute_signing_root(spec.Slot(slot), domain)
-    assert bls.Verify(pubkey, signing_root, signature)
+    run_get_signature_test(
+        spec=spec,
+        state=state,
+        obj=slot,
+        domain=domain,
+        get_signature_fn=spec.get_slot_signature,
+        privkey=privkey,
+        pubkey=pubkey,
+    )
 
 
 @with_all_phases
@@ -290,7 +382,13 @@ def test_get_aggregate_and_proof_signature(spec, state):
     pubkey = pubkeys[0]
     aggregate = get_mock_aggregate(spec)
     aggregate_and_proof = spec.get_aggregate_and_proof(state, spec.ValidatorIndex(1), aggregate, privkey)
-    signature = spec.get_aggregate_and_proof_signature(state, aggregate_and_proof, privkey)
     domain = spec.get_domain(state, spec.DOMAIN_AGGREGATE_AND_PROOF, spec.compute_epoch_at_slot(aggregate.data.slot))
-    signing_root = spec.compute_signing_root(aggregate_and_proof, domain)
-    assert bls.Verify(pubkey, signing_root, signature)
+    run_get_signature_test(
+        spec=spec,
+        state=state,
+        obj=aggregate_and_proof,
+        domain=domain,
+        get_signature_fn=spec.get_aggregate_and_proof_signature,
+        privkey=privkey,
+        pubkey=pubkey,
+    )