From 15c2faf63b37e36d9efb27e854c478ffb93bf632 Mon Sep 17 00:00:00 2001 From: Vitalik Buterin Date: Mon, 11 Jun 2018 23:58:08 -0400 Subject: [PATCH] Completed active state transition function (I think) --- beacon_chain_impl/full_pos.py | 108 ++++++++++++++++++++++----- beacon_chain_impl/simpleserialize.py | 10 +++ beacon_chain_impl/test.py | 13 +--- 3 files changed, 101 insertions(+), 30 deletions(-) diff --git a/beacon_chain_impl/full_pos.py b/beacon_chain_impl/full_pos.py index 5ee7ddc..31d6a96 100644 --- a/beacon_chain_impl/full_pos.py +++ b/beacon_chain_impl/full_pos.py @@ -6,19 +6,22 @@ hash32 = Binary.fixed_length(32) import rlp import bls import random +from bls import decompress_G1, aggregate_pubs, verify privkeys = [int.from_bytes(blake(str(i).encode('utf-8')), 'big') for i in range(30)] keymap = {bls.privtopub(k): k for k in privkeys} +SHARD_COUNT = 100 + class AggregateVote(rlp.Serializable): fields = [ ('shard_id', int256), - ('checkpoint', hash32), + ('checkpoint_hash', hash32), ('signer_bitmask', binary), ('aggregate_sig', int256) ] - def __init__(self, shard_id, checkpoint, signer_bitmask, aggregate_sig): + def __init__(self, shard_id, checkpoint_hash, signer_bitmask, aggregate_sig): # at the beginning of a method, locals() is a dict of all arguments fields = {k: v for k, v in locals().items() if k != 'self'} super(BlockHeader, self).__init__(**fields) @@ -86,7 +89,7 @@ class ValidatorRecord(): class CheckpointRecord(): - fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'} + fields = {'checkpoint_hash': 'hash32', 'voters': 'int16'} defaults = {} def __init__(self, **kwargs): @@ -98,12 +101,12 @@ class CheckpointRecord(): class ActiveState(): fields = {'height': 'int64', 'randao': 'hash32', - 'validator_ffg_voted': 'bytes', 'rewarded': ['int24'], - 'penalized': ['int24'], 'checkpoints': [CheckpointRecord], + 'ffg_voter_bitmask': 'bytes', 'balance_deltas': ['int32'], + 'checkpoints': [CheckpointRecord], 'total_skip_count': 'int64'} defaults = {'height': 0, 'randao': b'\x00'*32, - 'validator_ffg_voted': b'', 'rewarded': [], - 'penalized': [], 'checkpoints': [], 'total_skip_count': 0} + 'ffg_voter_bitmask': b'', 'balance_deltas': [], + 'checkpoints': [], 'total_skip_count': 0} def __init__(self, **kwargs): for k in self.fields.keys(): @@ -115,6 +118,8 @@ class CrystallizedState(): self.fields = {'active_validators': [ValidatorRecord], 'queued_validators': [ValidatorRecord], 'exited_validators': [ValidatorRecord], + 'current_shuffling': ['int24'], + 'current_epoch': 'int64', 'last_justified_epoch': 'int64', 'last_finalized_epoch': 'int64', 'dynasty': 'int64', @@ -123,6 +128,8 @@ class CrystallizedState(): self.defaults = {'active_validators': [], 'queued_validators': [], 'exited_validators': [], + 'current_shuffling': ['int24'], + 'current_epoch': 0, 'last_justified_epoch': 0, 'last_finalized_epoch': 0, 'dynasty': 0, @@ -131,23 +138,82 @@ class CrystallizedState(): for k in self.fields.keys(): assert k in kwargs or k in self.defaults setattr(self, k, kwargs.get(k, self.defaults.get(k))) - -def compute_state_transition(parent_state, block, ): - pass +def get_checkpoint_aggvote_msg(aggvote, crystallized_state): + return aggvote.shard_id.to_bytes(2, 'big') + \ + aggvote.checkpoint_hash + \ + crystallized_state.current_checkpoint + \ + crystallized_state.current_epoch.to_bytes(8, 'big') + \ + crystallized_state.last_justified_epoch.to_bytes(8, 'big') + +def get_attesters_and_signer(crystallized_state, active_state, skip_count): + attestation_count = min(crystallized_state.active_validators, 128) + indices = get_shuffling(active_state.randao, len(crystallized_state.active_validators), + attestation_count + skip_count + 1) + return indices[:attestation_count], indices[-1] + +def get_shard_attesters(crystallized_state, shard_id): + vc = len(crystallized_state.active_validators) + return crystallized_state.current_shuffling[(vc * shard_id) // SHARD_COUNT: (vc * (shard_id + 1)) // SHARD_COUNT] + +def compute_state_transition(parent_state, parent_block, block): + crystallized_state, active_state = parent_state + # Possibly initialize a new epoch + # Process the block-by-block stuff + + # Verify the attestations of the parent + attestation_indices, main_signer = \ + get_attesters_and_signer(crystallized_state, active_state, block.skip_count) + pubs = [] + balance_deltas = [] + assert len(block.attestation_bitmask) == len(attestation_indices + 7) // 8 + for i, index in enumerate(attestation_indices): + if block.attestation_bitmask[i//8] & (1<<(i%8)): + pubs.append(crystallized_state.active_validators[index]) + balance_deltas.append((index << 8) + 1) + assert len(balance_deltas) <= 128 + balance_deltas.append((main_signer << 8) + len(balance_deltas)) + assert verify(parent_block.hash, aggregate_pubs(pubs), block.aggregate_sig) + + # Verify the attestations of checkpoint hashes + checkpoint_votes = {x.checkpoint_hash: x.votes for x in active_state.checkpoints} + new_ffg_bitmask = bytearray(active_state.ffg_voter_bitmask) + for vote in block.shard_aggregate_votes: + attestation = get_checkpoint_aggvote_msg(vote, crystallized_state) + indices = get_shard_attesters(crystallized_state, vote.shard_id) + assert len(vote.signer_bitmask) == len(indices + 7) // 8 + pubs = [] + voters = 0 + for i, index in enumerate(indices): + if (vote.signer_bitmask[i//8] >> (i%8)) % 2: + pubs.append(crystallized_state.active_validators[index]) + if new_ffg_bitmask[index//8] & (1<<(index%8)) == 0: + new_ffg_bitmask[index//8] ^= 1<<(index%8) + voters += 1 + assert verify(attestation, aggregate_pubs(pubs), vote.aggregate_sig) + balance_deltas.append((main_signer << 8) + (voters * 16 // len(indices))) + checkpoint_votes[vote.checkpoint_hash] = checkpoint_votes.get(vote.checkpoint_hash, 0) + voters + + + o = ActiveState(height=active_state.height + 1, + randao=(int.from_bytes(active_state.randao, 'big') ^ + int.from_bytes(block.randao_reveal)).to_bytes(32, 'big'), + total_skip_count=active_state.total_skip_count + block.skip_count, + checkpoints=[CheckpointRecord(checkpoint_hash=h, votes=checkpoint_votes[h]) + for h in sorted(checkpoint_votes.keys())], + ffg_voter_bitmask=new_ffg_bitmask, + balance_deltas=active_state.balance_deltas + balance_deltas) + + + def mock_make_child(parent_state, parent_hash, skips, attester_share=0.8, checkpoint_shards=[]): - parent_attestation_hash = parent_hash + \ - parent_state.checkpoint_hash + \ - parent_state.epoch.to_bytes(32, 'big') + \ - parent_state.source_epoch.to_bytes(32, 'big') - checkpoint_attestation_hash = b'\x00' * 32 + \ - parent_state.checkpoint_hash + \ - parent_state.epoch.to_bytes(32, 'big') + \ - parent_state.source_epoch.to_bytes(32, 'big') + parent_attestation_hash = parent_hash validator_count = len(parent_state.active_validators) + attestation_count = min(parent_state.active_validators, 128) indices = get_shuffling(parent.randao_state, validator_count, - min(parent_state.active_validators, 128)) + attestation_count + skip_count + 1) + main_signer = indices[-1] # Randomly pick indices to include bitfield = [1 if random.random() < attester_share else 0 for i in indices] # Attestations @@ -164,6 +230,10 @@ def mock_make_child(parent_state, parent_hash, skips, attester_share=0.8, checkp bitfield = [1 if random.random() < crosslinker_share else 0 for i in indices] bitmask = bytearray((len(bitfield)-1) // 8 + 1) checkpoint = blake(bytes([shard])) + checkpoint_attestation_hash = checkpoint + \ + parent_state.checkpoint_hash + \ + parent_state.epoch.to_bytes(32, 'big') + \ + parent_state.source_epoch.to_bytes(32, 'big') sigs = [bls.sign(checkpoint, keymap[parent_state.active_validators[indices[i]].pubkey]) for i in range(len(indices)) if bitfield[i]] shard_aggregate_votes.append(AggregateVote(shard, checkpoint, bitmask, bls.aggregate_sig(sigs))) diff --git a/beacon_chain_impl/simpleserialize.py b/beacon_chain_impl/simpleserialize.py index a99ad3d..36cd42d 100644 --- a/beacon_chain_impl/simpleserialize.py +++ b/beacon_chain_impl/simpleserialize.py @@ -52,3 +52,13 @@ def _deserialize(data, start, typ): def deserialize(data, typ): return _deserialize(data, 0, typ)[0] + +def eq(x, y): + if hasattr(x, 'fields') and hasattr(y, 'fields'): + for f in x.fields: + if not eq(getattr(x, f), getattr(y, f)): + print('Unequal:', x, y, f, getattr(x, f), getattr(y, f)) + return False + return True + else: + return x == y diff --git a/beacon_chain_impl/test.py b/beacon_chain_impl/test.py index a2dc078..2aac957 100644 --- a/beacon_chain_impl/test.py +++ b/beacon_chain_impl/test.py @@ -2,7 +2,7 @@ from bls import G1, G2, hash_to_G2, compress_G1, compress_G2, \ decompress_G1, decompress_G2, normalize, multiply, \ sign, privtopub, aggregate_sigs, aggregate_pubs, verify -from simpleserialize import serialize, deserialize +from simpleserialize import serialize, deserialize, eq from full_pos import ActiveState, CheckpointRecord @@ -43,20 +43,11 @@ assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow' print('Testing advanced serialization') -def eq(x, y): - if hasattr(x, 'fields') and hasattr(y, 'fields'): - for f in x.fields: - if not eq(getattr(x, f), getattr(y, f)): - print('Unequal:', x, y, f, getattr(x, f), getattr(y, f)) - return False - return True - else: - return x == y s = ActiveState() ds = deserialize(serialize(s, type(s)), type(s)) assert eq(s, ds) s = ActiveState(checkpoints=[CheckpointRecord(checkpoint_hash=b'\x55'*32, bitmask=b'31337dawg')], - height=555, randao=b'\x88'*32, rewarded=[5,7,9,579], penalized=[3]*333) + height=555, randao=b'\x88'*32, balance_deltas=[5,7,9,579] + [3] * 333) ds = deserialize(serialize(s, type(s)), type(s)) assert eq(s, ds)