From e441352785552ff74b5ddc8713b8d277f8ddda11 Mon Sep 17 00:00:00 2001 From: Vitalik Buterin Date: Wed, 13 Jun 2018 08:11:11 -0400 Subject: [PATCH] Yay! Epoch transitions work! --- beacon_chain_impl/bls.py | 5 +- beacon_chain_impl/full_pos.py | 97 +++++++++++++++++++--------- beacon_chain_impl/simpleserialize.py | 2 + beacon_chain_impl/test_full_pos.py | 12 ++-- 4 files changed, 79 insertions(+), 37 deletions(-) diff --git a/beacon_chain_impl/bls.py b/beacon_chain_impl/bls.py index bbba1d0..b03c1ae 100644 --- a/beacon_chain_impl/bls.py +++ b/beacon_chain_impl/bls.py @@ -1,4 +1,7 @@ -from hashlib import blake2s +try: + from hashlib import blake2s +except: + from pyblake2 import blake2s blake = lambda x: blake2s(x).digest() from py_ecc.optimized_bn128 import G1, G2, add, multiply, FQ, FQ2, pairing, \ normalize, field_modulus, b, b2, is_on_curve, curve_order diff --git a/beacon_chain_impl/full_pos.py b/beacon_chain_impl/full_pos.py index 9762331..e7f174e 100644 --- a/beacon_chain_impl/full_pos.py +++ b/beacon_chain_impl/full_pos.py @@ -1,17 +1,16 @@ -from hashlib import blake2s +try: + from hashlib import blake2s +except: + from pyblake2 import blake2s blake = lambda x: blake2s(x).digest() -from rlp.sedes import big_endian_int, Binary, binary, CountableList, BigEndianInt, Binary -int256 = BigEndianInt(256) -hash32 = Binary.fixed_length(32) -import rlp import bls import random from bls import decompress_G1, aggregate_pubs, verify, sign, privtopub from simpleserialize import deepcopy, serialize, to_dict -SHARD_COUNT = 100 -ATTESTER_COUNT=32 +SHARD_COUNT = 20 +ATTESTER_COUNT = 32 DEFAULT_BALANCE = 20000 class AggregateVote(): @@ -143,6 +142,15 @@ class ActiveState(): for k in self.fields.keys(): assert k in kwargs or k in self.defaults setattr(self, k, kwargs.get(k, self.defaults.get(k))) + +class CrosslinkRecord(): + fields = {'epoch': 'int64', 'hash': 'hash32'} + defaults = {'epoch': 0, 'hash': b'\x00'*32} + + def __init__(self, **kwargs): + for k in self.fields.keys(): + assert k in kwargs or k in self.defaults + setattr(self, k, kwargs.get(k, self.defaults.get(k))) class CrystallizedState(): fields = {'active_validators': [ValidatorRecord], @@ -155,7 +163,7 @@ class CrystallizedState(): 'dynasty': 'int64', 'next_shard': 'int16', 'current_checkpoint': 'hash32', - 'checkpoint_last_crosslinked': ['int64'], + 'crosslink_records': [CrosslinkRecord], 'total_deposits': 'int256'} defaults = {'active_validators': [], 'queued_validators': [], @@ -167,7 +175,7 @@ class CrystallizedState(): 'dynasty': 0, 'next_shard': 0, 'current_checkpoint': b'\x00'*32, - 'checkpoint_last_crosslinked': [], + 'crosslink_records': [], 'total_deposits': 0} def __init__(self, **kwargs): @@ -194,29 +202,39 @@ def get_shard_attesters(crystallized_state, shard_id): def compute_state_transition(parent_state, parent_block, block, verify_sig=True): crystallized_state, active_state = parent_state - # Possibly initialize a new epoch + # Initialize a new epoch if needed if active_state.height % SHARD_COUNT == 0: + print('Processing epoch transition') # Process rewards from FFG/crosslink votes new_validator_records = deepcopy(crystallized_state.active_validators) # Who voted in the last epoch ffg_voter_bitmask = bytearray(active_state.ffg_voter_bitmask) # Total deposit size total_deposits = crystallized_state.total_deposits + # Old total deposit size + td = total_deposits # Number of epochs since last finality finality_distance = crystallized_state.current_epoch - crystallized_state.last_finalized_epoch - online_reward = 3 if finality_distance == 2 else 0 - offline_penalty = 2 * finality_distance + online_reward = 6 if finality_distance <= 2 else 0 + offline_penalty = 3 * finality_distance total_vote_count = 0 total_vote_deposits = 0 - for i in range(len(active_validators)): + total_validators = len(crystallized_state.active_validators) + for i in range(total_validators): if ffg_voter_bitmask[i // 8] & (128 >> (i % 8)): total_vote_deposits += new_validator_records[i].balance new_validator_records[i].balance += online_reward total_vote_count += 1 else: new_validator_records[i].balance -= offline_penalty + print('Total voted: %d of %d validators (%.2f%%), %d of %d deposits (%.2f%%)' % + (total_vote_count, total_validators, total_vote_count * 100 / total_validators, + total_vote_deposits, total_deposits, total_vote_deposits * 100 / total_deposits)) + print('FFG online reward: %d, offline penalty: %d' % (online_reward, offline_penalty)) total_deposits += total_vote_count * online_reward - \ - (len(active_validators) - total_vote_count) * online_penalty + (total_validators - total_vote_count) * offline_penalty + print('Total deposit change from FFG: %d' % (total_deposits - td)) + td = total_deposits # Find the most popular crosslink in each shard main_crosslink = {} for c in active_state.checkpoints: @@ -228,11 +246,13 @@ def compute_state_transition(parent_state, parent_block, block, verify_sig=True) if vote_count > main_crosslink.get(c.shard_id, (b'', 0, b''))[1]: main_crosslink[c.shard_id] = (c.checkpoint_hash, vote_count, mask) # Adjust crosslinks - new_checkpoint_last_crosslinked = deepcopy(crystallized_state.checkpoint_last_crosslinked) + new_crosslink_records = deepcopy(crystallized_state.crosslink_records) for shard in range(SHARD_COUNT): - h, votes, mask = main_crosslink.get(shard, (b'', 0)) - crosslink_distance = crystallized_state.epoch - crystallized_state.checkpoint_last_crosslinked[shard] + print('Processing crosslink data for shard %d' % shard) indices = get_shard_attesters(crystallized_state, shard) + h, votes, mask = main_crosslink.get(shard, (b'', 0, bytearray((len(indices)+7)//8))) + crosslink_distance = crystallized_state.current_epoch - crystallized_state.crosslink_records[shard].epoch + print('Last crosslink from this shard was from epoch %d' % crystallized_state.crosslink_records[shard].epoch) online_reward = 3 if crosslink_distance <= 2 else 0 offline_penalty = crosslink_distance * 2 for i, index in enumerate(indices): @@ -240,22 +260,33 @@ def compute_state_transition(parent_state, parent_block, block, verify_sig=True) new_validator_records[index].balance += online_reward else: new_validator_records[index].balance -= offline_penalty + total_deposits += votes * online_reward - (len(indices) - votes) * offline_penalty + print('Total voters: %d of %d (%.2f%%)' % (votes, len(indices), votes * 100 / len(indices))) + print('Crosslink online reward: %d, offline penalty: %d' % (online_reward, offline_penalty)) # New checkpoint last crosslinked record if votes * 3 >= len(indices) * 2: - new_checkpoint_last_crosslinked[shard] = crystallized_state.epoch - + new_crosslink_records[shard] = CrosslinkRecord(hash=h, epoch=crystallized_state.current_epoch) + print('Finalized checkpoint: %s' % hex(int.from_bytes(h, 'big'))) + print('Total deposit change from crosslinks: %d' % (total_deposits - td)) + td = total_deposits # Process other balance deltas for i in active_state.balance_deltas: if i % 256 <= 128: - new_validator_records[i >> 8] += i % 256 + new_validator_records[i >> 8].balance += i % 256 + total_deposits += i % 256 else: - new_validator_records[i >> 8] += (i % 256) - 256 + new_validator_records[i >> 8].balance += (i % 256) - 256 + total_deposits += (i % 256) - 256 + print('Total deposit change from deltas: %d' % (total_deposits - td)) + print('New total deposits: %d' % total_deposits) # Process finality and validator set changes - justify, finalize = False + justify, finalize = False, False if total_vote_deposits * 3 >= total_deposits * 2: justify = True + print('Justifying last epoch') if crystallized_state.last_justified_epoch == crystallized_state.current_epoch - 1: finalize = True + print('Finalizing last epoch') if finalize: new_active_validators = [v for v in crystallized_state.active_validators] new_exited_validators = [v for v in crystallized_state.exited_validators] @@ -263,12 +294,15 @@ def compute_state_transition(parent_state, parent_block, block, verify_sig=True) while i < len(new_active_validators): if new_validator_records[i].balance <= DEFAULT_BALANCE // 2: new_exited_validators.append(new_validator_records.pop(i)) - elif new_validator_records.switch_dynasty == crystallized_state.dynasty + 1: + elif new_validator_records[i].switch_dynasty == crystallized_state.dynasty + 1: new_exited_validators.append(new_validator_records.pop(i)) else: i += 1 - induct = min(len(crystallized_state.queued_validators), crystallized_state.active_validators // 30 + 1) + induct = min(len(crystallized_state.queued_validators), len(crystallized_state.active_validators) // 30 + 1) for i in range(induct): + if crystallized_state.queued_validators[i].switch_dynasty > crystallized_state.dynasty + 1: + induct = i + break new_active_validators.append(crystallized_state.queued_validators[i]) new_queued_validators = crystallized_state.queued_validators[induct:] else: @@ -280,12 +314,12 @@ def compute_state_transition(parent_state, parent_block, block, verify_sig=True) active_validators=new_active_validators, exited_validators=new_exited_validators, current_shuffling=get_shuffling(active_state.randao, len(new_active_validators)), - last_justified_epoch = crystallized_state.current_epoch if justified else crystallized_state.last_justified_epoch, - last_finalized_epoch = crystallized_state.current_epoch-1 if finalized else crystallized_state.last_finalized_epoch, - dynasty = crystallized_state.dynasty + (1 if finalized else 0), + last_justified_epoch = crystallized_state.current_epoch if justify else crystallized_state.last_justified_epoch, + last_finalized_epoch = crystallized_state.current_epoch-1 if finalize else crystallized_state.last_finalized_epoch, + dynasty = crystallized_state.dynasty + (1 if finalize else 0), next_shard = 0, - current_epoch = parent_block.hash, - checkpoint_last_crosslinked = new_checkpoint_last_crosslinked, + current_epoch = crystallized_state.current_epoch + 1, + crosslink_records = new_crosslink_records, total_deposits = total_deposits ) # Reset the active state @@ -313,7 +347,8 @@ def compute_state_transition(parent_state, parent_block, block, verify_sig=True) print('Verified aggregate sig') # Verify the attestations of checkpoint hashes - checkpoint_votes = {} + checkpoint_votes = {vote.checkpoint_hash + vote.shard_id.to_bytes(2, 'big'): + vote.voter_bitmask for vote in active_state.checkpoints} new_ffg_bitmask = bytearray(active_state.ffg_voter_bitmask) for vote in block.shard_aggregate_votes: attestation = get_checkpoint_aggvote_msg(vote.shard_id, vote.checkpoint_hash, crystallized_state) @@ -371,7 +406,7 @@ def mk_genesis_state_and_block(pubkeys): dynasty=1, next_shard=0, current_checkpoint=blake(b'insert EOS constitution here'), - checkpoint_last_crosslinked=[0] * SHARD_COUNT, + crosslink_records=[CrosslinkRecord(hash=b'\x00'*32, epoch=0) for i in range(SHARD_COUNT)], total_deposits=DEFAULT_BALANCE*len(pubkeys)) a = ActiveState(height=1, randao=b'\x45'*32, diff --git a/beacon_chain_impl/simpleserialize.py b/beacon_chain_impl/simpleserialize.py index 19eef0c..0788840 100644 --- a/beacon_chain_impl/simpleserialize.py +++ b/beacon_chain_impl/simpleserialize.py @@ -17,6 +17,7 @@ def serialize(val, typ=None): elif isinstance(typ, type): sub = b''.join([serialize(getattr(val, k), typ.fields[k]) for k in sorted(typ.fields.keys())]) return len(sub).to_bytes(4, 'big') + sub + raise Exception("Cannot serialize", val, typ) def _deserialize(data, start, typ): if typ in ('hash32', 'address'): @@ -49,6 +50,7 @@ def _deserialize(data, start, typ): values[k], pos = _deserialize(data, pos, typ.fields[k]) assert pos == start + 4 + length return typ(**values), pos + raise Exception("Cannot deserialize", typ) def deserialize(data, typ): return _deserialize(data, 0, typ)[0] diff --git a/beacon_chain_impl/test_full_pos.py b/beacon_chain_impl/test_full_pos.py index 3f2204c..e3e889a 100644 --- a/beacon_chain_impl/test_full_pos.py +++ b/beacon_chain_impl/test_full_pos.py @@ -1,6 +1,6 @@ from full_pos import blake, mk_genesis_state_and_block, compute_state_transition, \ get_attesters_and_signer, Block, get_checkpoint_aggvote_msg, AggregateVote, \ - SHARD_COUNT, ATTESTER_COUNT + SHARD_COUNT, ATTESTER_COUNT, get_shard_attesters import random import bls from simpleserialize import serialize, deserialize, eq, deepcopy @@ -34,8 +34,8 @@ def mock_make_child(parent_state, parent, skips, attester_share=0.8, checkpoint_ # Randomly pick indices to include for checkpoints shard_aggregate_votes = [] for shard, crosslinker_share in checkpoint_shards: - print('Making crosslink in shard %d') - indices = crystallized_state.current_shuffling[(validator_count * shard) // 100: (validator_count * (shard + 1)) // 100] + print('Making crosslink in shard %d' % shard) + indices = get_shard_attesters(crystallized_state, shard) print('Indices: %r' % indices) bitfield = [1 if random.random() < crosslinker_share else 0 for i in indices] bitmask = bytearray((len(bitfield)+7) // 8) @@ -82,8 +82,10 @@ print('Block size:', len(serialize(block))) block2, c2, a2 = mock_make_child((c, a), block, 0, 0.8, []) assert compute_state_transition((c, a), block, block2) print('Verified a block!') -block3, c3, a3 = mock_make_child((c2, a2), block2, 0, 0.8, [(0, 0.7)]) +block3, c3, a3 = mock_make_child((c2, a2), block2, 0, 0.8, [(0, 0.75)]) print('Verified a block with a committee!') while a3.height % SHARD_COUNT > 0: - block3, c3, a3 = mock_make_child((c3, a3), block3, 0, 0.8, [(a3.height, 0.7)]) + block3, c3, a3 = mock_make_child((c3, a3), block3, 0, 0.8, [(a3.height, 0.6 + 0.02 * a3.height)]) print('Height: %d' % a3.height) +print('FFG bitmask:', bin(int.from_bytes(a3.ffg_voter_bitmask, 'big'))) +block4, c4, a4 = mock_make_child((c3, a3), block3, 1, 0.55, [])