Further commit to full PoS chain

This commit is contained in:
Vitalik Buterin 2018-06-11 11:14:51 -04:00
parent 6250c90ed2
commit de85691047
4 changed files with 234 additions and 58 deletions

View File

@ -83,25 +83,3 @@ def aggregate_pubs(pubs):
for p in pubs:
o = add(o, decompress_G1(p))
return compress_G1(o)
for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
print('Testing with privkey %d' % x)
p1 = multiply(G1, x)
p2 = multiply(G2, x)
msg = str(x).encode('utf-8')
msghash = hash_to_G2(msg)
assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1)
assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2)
assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash)
sig = sign(msg, x)
pub = privtopub(x)
assert verify(msg, pub, sig)
print('Testing signature aggregation')
msg = b'cow'
keys = [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]
sigs = [sign(msg, k) for k in keys]
pubs = [privtopub(k) for k in keys]
aggsig = aggregate_sigs(sigs)
aggpub = aggregate_pubs(pubs)
assert verify(msg, aggpub, aggsig)

View File

@ -1,68 +1,168 @@
from hashlib import blake2s
blake = lambda x: blake2s(x).digest()
from ethereum.utils import normalize_address, hash32, trie_root, \
big_endian_int, address, int256, encode_hex, decode_hex, encode_int, \
big_endian_to_int
from rlp.sedes import big_endian_int, Binary, binary, CountableList
from rlp.sedes import big_endian_int, Binary, binary, CountableList, BigEndianInt, Binary
int256 = BigEndianInt(256)
hash32 = Binary.fixed_length(32)
import rlp
import bls
import random
privkeys = [int.from_bytes(blake(str(i).encode('utf-8')), 'big') for i in range(30)]
keymap = {bls.privtopub(k): k for k in privkeys}
class AggregateVote(rlp.Serializable):
fields = [
('shard_id', int256),
('checkpoint', hash32),
('signer_bitmask', binary),
('aggregate_sig', int256)
]
def __init__(self, shard_id, checkpoint, signer_bitmask, aggregate_sig):
# at the beginning of a method, locals() is a dict of all arguments
fields = {k: v for k, v in locals().items() if k != 'self'}
super(BlockHeader, self).__init__(**fields)
class BeaconBlock(rlp.Serializable):
fields = [
# Hash of the parent block
('parent_hash', hash32),
# Number of skips (for the full PoS mechanism)
('skip_count', int256),
# Randao commitment reveal
('randao_reveal', hash32),
# Bitmask of who participated in the block notarization committee
('attestation_bitmask', binary),
# Their aggregate sig
('attestation_aggregate_sig', int256),
('ffg_signer_list', binary),
('ffg_aggregate_sig', int256),
# Shard aggregate votes
('shard_aggregate_votes', CountableList(AggregateVote)),
# Reference to main chain block
('main_chain_ref', hash32),
# Hash of the state
('state_hash', hash32),
# Block height
('height', int256),
# Signature from signer
('sig', int256)
]
def __init__(self,
parent_hash=b'\x00'*32, skip_count=0, randao_reveal=b'\x00'*32,
attestation_bitmask=b'', attestation_aggregate_sig=0,
ffg_signer_list=b'', ffg_aggregate_sig=0, main_chain_ref=b'\x00'*32,
shard_aggregate_votes=[], main_chain_ref=b'\x00'*32,
state_hash=b'\x00'*32, height=0, sig=0):
# at the beginning of a method, locals() is a dict of all arguments
fields = {k: v for k, v in locals().items() if k != 'self'}
super(BlockHeader, self).__init__(**fields)
def quick_sample(seed, validator_count, sample_count):
k = 0
while 256**k < n:
k += 1
o = []; source = seed; pos = 0
while len(o) < sample_count:
if pos + k > 32:
source = blake(source)
pos = 0
m = big_endian_to_int(source[pos:pos+k])
if n * (m // n + 1) <= 256**k:
o.append(m % n)
pos += k
def get_shuffling(seed, validator_count, sample=None):
assert validator_count <= 16777216
rand_max = 16777216 - 16777216 % validator_count
o = list(range(validator_count)); source = seed
i = 0
while i < sample if sample is not None else validator_count:
source = blake(source)
for pos in range(0, 30, 3):
m = int.from_bytes(source[pos:pos+3], 'big')
remaining = validator_count - i
if validator_count < rand_max:
replacement_pos = m % remaining + i
o[i], o[replacement_pos] = o[replacement_pos], o[i]
i += 1
return o
privkeys = [int.from_bytes(blake2s(str(i).encode('utf-8'))) for i in range(3000)]
class ValidatorRecord():
def __init__(self, **kwargs):
self.fields = {'pubkey': 'int256', 'return_shard': 'int16',
'return_address': 'address', 'randao_commitment': 'hash32',
'balance': 'int64', 'switch_dynasty': 'int64'}
defaults = {}
for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k]))
def mock_make_child(parent_state, skips, ):
attest
class CheckpointRecord():
def __init__(self, **kwargs):
self.fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
defaults = {}
for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k]))
fields = [
('parent_hash', hash32),
('skip_count', int256),
('randao_reveal', hash32),
('attestation_bitmask', binary),
('attestation_aggregate_sig', int256),
('ffg_signer_list', binary),
('ffg_aggregate_sig', int256),
('main_chain_ref', hash32),
('state_hash', hash32),
('height', int256),
('sig', int256)
]
class ActiveState():
def __init__(self, **kwargs):
self.fields = {'height': 'int64', 'randao': 'hash32',
'validator_ffg_voted': 'bytes', 'rewarded': ['int24'],
'penalized': ['int24'], 'checkpoints': [CheckpointRecord],
'total_skip_count': 'int64'}
defaults = {'height': 0, 'randao': b'\x00'*32,
'validator_ffg_voted': b'', 'rewarded': [],
'penalized': [], 'checkpoints': [], 'total_skip_count': 0}
for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k]))
class CrystallizedState():
def __init__(self, **kwargs):
self.fields = {'active_validators': [ValidatorRecord],
'queued_validators': [ValidatorRecord],
'exited_validators': [ValidatorRecord],
'last_justified_epoch': 'int64',
'last_finalized_epoch': 'int64',
'dynasty': 'int64',
'current_checkpoint': 'hash32',
'total_deposits': 'int256'}
self.defaults = {'active_validators': [],
'queued_validators': [],
'exited_validators': [],
'last_justified_epoch': 0,
'last_finalized_epoch': 0,
'dynasty': 0,
'current_checkpoint': b'\x00'*32,
'total_deposits': 0}
for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k]))
def compute_state_transition(parent_state, block, ):
pass
def mock_make_child(parent_state, parent_hash, skips, attester_share=0.8, checkpoint_shards=[]):
parent_attestation_hash = parent_hash + \
parent_state.checkpoint_hash + \
parent_state.epoch.to_bytes(32, 'big') + \
parent_state.source_epoch.to_bytes(32, 'big')
checkpoint_attestation_hash = b'\x00' * 32 + \
parent_state.checkpoint_hash + \
parent_state.epoch.to_bytes(32, 'big') + \
parent_state.source_epoch.to_bytes(32, 'big')
validator_count = len(parent_state.active_validators)
indices = get_shuffling(parent.randao_state, validator_count,
min(parent_state.active_validators, 128))
# Randomly pick indices to include
bitfield = [1 if random.random() < attester_share else 0 for i in indices]
# Attestations
sigs = [bls.sign(parent_attestation_hash, keymap[parent_state.active_validators[indices[i]].pubkey])
for i in range(len(indices)) if bitfield[i]]
attestation_aggregate_sig = bls.aggregate_sig(sigs)
attestation_bitmask = bytearray((len(bitfield)-1) // 8 + 1)
for i, b in enumerate(bitfield):
attestation_bitmask[i//8] ^= (128 >> (i % 8)) * b
# Randomly pick indices to include for checkpoints
shard_aggregate_votes = []
for shard, crosslinker_share in checkpoint_shards:
indices = parent_state.shuffling[(validator_count * shard) // 100: (validator_count * (shard + 1)) // 100]
bitfield = [1 if random.random() < crosslinker_share else 0 for i in indices]
bitmask = bytearray((len(bitfield)-1) // 8 + 1)
checkpoint = blake(bytes([shard]))
sigs = [bls.sign(checkpoint, keymap[parent_state.active_validators[indices[i]].pubkey])
for i in range(len(indices)) if bitfield[i]]
shard_aggregate_votes.append(AggregateVote(shard, checkpoint, bitmask, bls.aggregate_sig(sigs)))
# State calculations
o = BlockHeader(parent.hash, skips, blake(str(random.random()).encode('utf-8')),
attestation_bitmask, attestation_aggregate_sig, shard_aggregate_votes,
b'\x00'*32, state.hash, state.height)
# Main signature
o.sign(keymap[parent_state.active_validators[indices[-1]].pubkey])
return o

View File

@ -0,0 +1,50 @@
def serialize(val, typ):
if typ in ('hash32', 'address'):
assert len(val) == 20 if typ == 'address' else 32
return val
elif typ[:3] == 'int':
length = int(typ[3:])
assert length % 8 == 0
return val.to_bytes(length // 8, 'big')
elif typ == 'bytes':
return len(val).to_bytes(4, 'big') + val
elif isinstance(typ, list):
assert len(typ) == 1
sub = b''.join([serialize(x, typ[0]) for x in val])
return len(sub).to_bytes(4, 'big') + sub
elif isinstance(typ, type):
sub = b''.join([serialize(getattr(val, k), typ.fields[k]) for k in sorted(typ.fields.keys())])
return len(sub).to_bytes(4, 'big') + sub
def _deserialize(data, start, typ):
if typ in ('hash32', 'address'):
length = 20 if typ == 'address' else 32
assert len(data) + start >= length
return data[start: start+length], start+length
elif typ[:3] == 'int':
length = int(typ[3:])
assert length % 8 == 0
assert len(data) + start >= length // 8
return int.from_bytes(data[start: start+length//8], 'big'), start+length//8
elif typ == 'bytes':
length = int.from_bytes(data[start:start+4], 'big')
assert len(data) + start >= 4+length
return data[start+4: start+4+length], start+4+length
elif isinstance(typ, list):
assert len(typ) == 1
length = int.from_bytes(data[start:start+4])
pos, o = start + 4, []
while pos < start + 4 + length:
result, pos = _deserialize(data, pos, typ[0])
o.append(result)
assert pos == start + 4 + length
return o, pos
elif isinstance(typ, type):
values = {}
pos = start
for k in sorted(typ.fields.keys()):
values[k], pos = _deserialize(data, pos, typ.fields[k])
return typ(**values), pos
def deserialize(data, typ):
return _deserialize(data, 0, typ)[0]

48
beacon_chain_impl/test.py Normal file
View File

@ -0,0 +1,48 @@
from bls import G1, G2, hash_to_G2, compress_G1, compress_G2, \
decompress_G1, decompress_G2, normalize, multiply, \
sign, privtopub, aggregate_sigs, aggregate_pubs, verify
from simpleserialize import serialize, deserialize
from full_pos import ActiveState
for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
print('Testing with privkey %d' % x)
p1 = multiply(G1, x)
p2 = multiply(G2, x)
msg = str(x).encode('utf-8')
msghash = hash_to_G2(msg)
assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1)
assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2)
assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash)
sig = sign(msg, x)
pub = privtopub(x)
assert verify(msg, pub, sig)
print('Testing signature aggregation')
msg = b'cow'
keys = [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]
sigs = [sign(msg, k) for k in keys]
pubs = [privtopub(k) for k in keys]
aggsig = aggregate_sigs(sigs)
aggpub = aggregate_pubs(pubs)
assert verify(msg, aggpub, aggsig)
print('Testing basic serialization')
assert serialize(5, 'int8') == b'\x05'
assert deserialize(b'\x05', 'int8') == 5
assert serialize(2**32-3, 'int40') == b'\x00\xff\xff\xff\xfd'
assert deserialize(b'\x00\xff\xff\xff\xfd', 'int40') == 2**32-3
assert serialize(b'\x35'*20, 'address') == b'\x35'*20
assert deserialize(b'\x35'*20, 'address') == b'\x35'*20
assert serialize(b'\x35'*32, 'hash32') == b'\x35'*32
assert deserialize(b'\x35'*32, 'hash32') == b'\x35'*32
assert serialize(b'cow', 'bytes') == b'\x00\x00\x00\x03cow'
assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow'
print('Testing advanced serialization')
s = ActiveState()
ds = deserialize(serialize(s))
for x in s.fields:
assert getattr(s, x) == getattr(ds, x)