Completed active state transition function (I think)

This commit is contained in:
Vitalik Buterin 2018-06-11 23:58:08 -04:00
parent 058d45f0b5
commit 15c2faf63b
3 changed files with 101 additions and 30 deletions

View File

@ -6,19 +6,22 @@ hash32 = Binary.fixed_length(32)
import rlp
import bls
import random
from bls import decompress_G1, aggregate_pubs, verify
privkeys = [int.from_bytes(blake(str(i).encode('utf-8')), 'big') for i in range(30)]
keymap = {bls.privtopub(k): k for k in privkeys}
SHARD_COUNT = 100
class AggregateVote(rlp.Serializable):
fields = [
('shard_id', int256),
('checkpoint', hash32),
('checkpoint_hash', hash32),
('signer_bitmask', binary),
('aggregate_sig', int256)
]
def __init__(self, shard_id, checkpoint, signer_bitmask, aggregate_sig):
def __init__(self, shard_id, checkpoint_hash, signer_bitmask, aggregate_sig):
# at the beginning of a method, locals() is a dict of all arguments
fields = {k: v for k, v in locals().items() if k != 'self'}
super(BlockHeader, self).__init__(**fields)
@ -86,7 +89,7 @@ class ValidatorRecord():
class CheckpointRecord():
fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
fields = {'checkpoint_hash': 'hash32', 'voters': 'int16'}
defaults = {}
def __init__(self, **kwargs):
@ -98,12 +101,12 @@ class CheckpointRecord():
class ActiveState():
fields = {'height': 'int64', 'randao': 'hash32',
'validator_ffg_voted': 'bytes', 'rewarded': ['int24'],
'penalized': ['int24'], 'checkpoints': [CheckpointRecord],
'ffg_voter_bitmask': 'bytes', 'balance_deltas': ['int32'],
'checkpoints': [CheckpointRecord],
'total_skip_count': 'int64'}
defaults = {'height': 0, 'randao': b'\x00'*32,
'validator_ffg_voted': b'', 'rewarded': [],
'penalized': [], 'checkpoints': [], 'total_skip_count': 0}
'ffg_voter_bitmask': b'', 'balance_deltas': [],
'checkpoints': [], 'total_skip_count': 0}
def __init__(self, **kwargs):
for k in self.fields.keys():
@ -115,6 +118,8 @@ class CrystallizedState():
self.fields = {'active_validators': [ValidatorRecord],
'queued_validators': [ValidatorRecord],
'exited_validators': [ValidatorRecord],
'current_shuffling': ['int24'],
'current_epoch': 'int64',
'last_justified_epoch': 'int64',
'last_finalized_epoch': 'int64',
'dynasty': 'int64',
@ -123,6 +128,8 @@ class CrystallizedState():
self.defaults = {'active_validators': [],
'queued_validators': [],
'exited_validators': [],
'current_shuffling': ['int24'],
'current_epoch': 0,
'last_justified_epoch': 0,
'last_finalized_epoch': 0,
'dynasty': 0,
@ -131,23 +138,82 @@ class CrystallizedState():
for k in self.fields.keys():
assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
def compute_state_transition(parent_state, block, ):
pass
def get_checkpoint_aggvote_msg(aggvote, crystallized_state):
return aggvote.shard_id.to_bytes(2, 'big') + \
aggvote.checkpoint_hash + \
crystallized_state.current_checkpoint + \
crystallized_state.current_epoch.to_bytes(8, 'big') + \
crystallized_state.last_justified_epoch.to_bytes(8, 'big')
def get_attesters_and_signer(crystallized_state, active_state, skip_count):
attestation_count = min(crystallized_state.active_validators, 128)
indices = get_shuffling(active_state.randao, len(crystallized_state.active_validators),
attestation_count + skip_count + 1)
return indices[:attestation_count], indices[-1]
def get_shard_attesters(crystallized_state, shard_id):
vc = len(crystallized_state.active_validators)
return crystallized_state.current_shuffling[(vc * shard_id) // SHARD_COUNT: (vc * (shard_id + 1)) // SHARD_COUNT]
def compute_state_transition(parent_state, parent_block, block):
crystallized_state, active_state = parent_state
# Possibly initialize a new epoch
# Process the block-by-block stuff
# Verify the attestations of the parent
attestation_indices, main_signer = \
get_attesters_and_signer(crystallized_state, active_state, block.skip_count)
pubs = []
balance_deltas = []
assert len(block.attestation_bitmask) == len(attestation_indices + 7) // 8
for i, index in enumerate(attestation_indices):
if block.attestation_bitmask[i//8] & (1<<(i%8)):
pubs.append(crystallized_state.active_validators[index])
balance_deltas.append((index << 8) + 1)
assert len(balance_deltas) <= 128
balance_deltas.append((main_signer << 8) + len(balance_deltas))
assert verify(parent_block.hash, aggregate_pubs(pubs), block.aggregate_sig)
# Verify the attestations of checkpoint hashes
checkpoint_votes = {x.checkpoint_hash: x.votes for x in active_state.checkpoints}
new_ffg_bitmask = bytearray(active_state.ffg_voter_bitmask)
for vote in block.shard_aggregate_votes:
attestation = get_checkpoint_aggvote_msg(vote, crystallized_state)
indices = get_shard_attesters(crystallized_state, vote.shard_id)
assert len(vote.signer_bitmask) == len(indices + 7) // 8
pubs = []
voters = 0
for i, index in enumerate(indices):
if (vote.signer_bitmask[i//8] >> (i%8)) % 2:
pubs.append(crystallized_state.active_validators[index])
if new_ffg_bitmask[index//8] & (1<<(index%8)) == 0:
new_ffg_bitmask[index//8] ^= 1<<(index%8)
voters += 1
assert verify(attestation, aggregate_pubs(pubs), vote.aggregate_sig)
balance_deltas.append((main_signer << 8) + (voters * 16 // len(indices)))
checkpoint_votes[vote.checkpoint_hash] = checkpoint_votes.get(vote.checkpoint_hash, 0) + voters
o = ActiveState(height=active_state.height + 1,
randao=(int.from_bytes(active_state.randao, 'big') ^
int.from_bytes(block.randao_reveal)).to_bytes(32, 'big'),
total_skip_count=active_state.total_skip_count + block.skip_count,
checkpoints=[CheckpointRecord(checkpoint_hash=h, votes=checkpoint_votes[h])
for h in sorted(checkpoint_votes.keys())],
ffg_voter_bitmask=new_ffg_bitmask,
balance_deltas=active_state.balance_deltas + balance_deltas)
def mock_make_child(parent_state, parent_hash, skips, attester_share=0.8, checkpoint_shards=[]):
parent_attestation_hash = parent_hash + \
parent_state.checkpoint_hash + \
parent_state.epoch.to_bytes(32, 'big') + \
parent_state.source_epoch.to_bytes(32, 'big')
checkpoint_attestation_hash = b'\x00' * 32 + \
parent_state.checkpoint_hash + \
parent_state.epoch.to_bytes(32, 'big') + \
parent_state.source_epoch.to_bytes(32, 'big')
parent_attestation_hash = parent_hash
validator_count = len(parent_state.active_validators)
attestation_count = min(parent_state.active_validators, 128)
indices = get_shuffling(parent.randao_state, validator_count,
min(parent_state.active_validators, 128))
attestation_count + skip_count + 1)
main_signer = indices[-1]
# Randomly pick indices to include
bitfield = [1 if random.random() < attester_share else 0 for i in indices]
# Attestations
@ -164,6 +230,10 @@ def mock_make_child(parent_state, parent_hash, skips, attester_share=0.8, checkp
bitfield = [1 if random.random() < crosslinker_share else 0 for i in indices]
bitmask = bytearray((len(bitfield)-1) // 8 + 1)
checkpoint = blake(bytes([shard]))
checkpoint_attestation_hash = checkpoint + \
parent_state.checkpoint_hash + \
parent_state.epoch.to_bytes(32, 'big') + \
parent_state.source_epoch.to_bytes(32, 'big')
sigs = [bls.sign(checkpoint, keymap[parent_state.active_validators[indices[i]].pubkey])
for i in range(len(indices)) if bitfield[i]]
shard_aggregate_votes.append(AggregateVote(shard, checkpoint, bitmask, bls.aggregate_sig(sigs)))

View File

@ -52,3 +52,13 @@ def _deserialize(data, start, typ):
def deserialize(data, typ):
return _deserialize(data, 0, typ)[0]
def eq(x, y):
if hasattr(x, 'fields') and hasattr(y, 'fields'):
for f in x.fields:
if not eq(getattr(x, f), getattr(y, f)):
print('Unequal:', x, y, f, getattr(x, f), getattr(y, f))
return False
return True
else:
return x == y

View File

@ -2,7 +2,7 @@ from bls import G1, G2, hash_to_G2, compress_G1, compress_G2, \
decompress_G1, decompress_G2, normalize, multiply, \
sign, privtopub, aggregate_sigs, aggregate_pubs, verify
from simpleserialize import serialize, deserialize
from simpleserialize import serialize, deserialize, eq
from full_pos import ActiveState, CheckpointRecord
@ -43,20 +43,11 @@ assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow'
print('Testing advanced serialization')
def eq(x, y):
if hasattr(x, 'fields') and hasattr(y, 'fields'):
for f in x.fields:
if not eq(getattr(x, f), getattr(y, f)):
print('Unequal:', x, y, f, getattr(x, f), getattr(y, f))
return False
return True
else:
return x == y
s = ActiveState()
ds = deserialize(serialize(s, type(s)), type(s))
assert eq(s, ds)
s = ActiveState(checkpoints=[CheckpointRecord(checkpoint_hash=b'\x55'*32, bitmask=b'31337dawg')],
height=555, randao=b'\x88'*32, rewarded=[5,7,9,579], penalized=[3]*333)
height=555, randao=b'\x88'*32, balance_deltas=[5,7,9,579] + [3] * 333)
ds = deserialize(serialize(s, type(s)), type(s))
assert eq(s, ds)