Further commit to full PoS chain
parent 6250c90ed2, commit de85691047
@ -83,25 +83,3 @@ def aggregate_pubs(pubs):
    for p in pubs:
        o = add(o, decompress_G1(p))
    return compress_G1(o)

for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
    print('Testing with privkey %d' % x)
    p1 = multiply(G1, x)
    p2 = multiply(G2, x)
    msg = str(x).encode('utf-8')
    msghash = hash_to_G2(msg)
    assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1)
    assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2)
    assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash)
    sig = sign(msg, x)
    pub = privtopub(x)
    assert verify(msg, pub, sig)

print('Testing signature aggregation')
msg = b'cow'
keys = [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]
sigs = [sign(msg, k) for k in keys]
pubs = [privtopub(k) for k in keys]
aggsig = aggregate_sigs(sigs)
aggpub = aggregate_pubs(pubs)
assert verify(msg, aggpub, aggsig)
@ -1,68 +1,168 @@
from hashlib import blake2s
blake = lambda x: blake2s(x).digest()
from ethereum.utils import normalize_address, hash32, trie_root, \
    big_endian_int, address, int256, encode_hex, decode_hex, encode_int, \
    big_endian_to_int
from rlp.sedes import big_endian_int, Binary, binary, CountableList, BigEndianInt
int256 = BigEndianInt(32)  # 256-bit (32-byte) big-endian integers
hash32 = Binary.fixed_length(32)
import rlp
import bls
import random

privkeys = [int.from_bytes(blake(str(i).encode('utf-8')), 'big') for i in range(30)]
keymap = {bls.privtopub(k): k for k in privkeys}
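# keymap lets the mock block producer below look up the private key behind a
# validator's registered pubkey, so test attestations can be signed directly.
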
class AggregateVote(rlp.Serializable):
    fields = [
        ('shard_id', int256),
        ('checkpoint', hash32),
        ('signer_bitmask', binary),
        ('aggregate_sig', int256)
    ]

    def __init__(self, shard_id, checkpoint, signer_bitmask, aggregate_sig):
        # at the beginning of a method, locals() is a dict of all arguments
        fields = {k: v for k, v in locals().items() if k != 'self'}
        super(AggregateVote, self).__init__(**fields)

class BeaconBlock(rlp.Serializable):

    fields = [
        # Hash of the parent block
        ('parent_hash', hash32),
        # Number of skips (for the full PoS mechanism)
        ('skip_count', int256),
        # Randao commitment reveal
        ('randao_reveal', hash32),
        # Bitmask of who participated in the block notarization committee
        ('attestation_bitmask', binary),
        # Their aggregate sig
        ('attestation_aggregate_sig', int256),
        # FFG signer bitmask and their aggregate sig
        ('ffg_signer_list', binary),
        ('ffg_aggregate_sig', int256),
        # Shard aggregate votes
        ('shard_aggregate_votes', CountableList(AggregateVote)),
        # Reference to main chain block
        ('main_chain_ref', hash32),
        # Hash of the state
        ('state_hash', hash32),
        # Block height
        ('height', int256),
        # Signature from signer
        ('sig', int256)
    ]

    def __init__(self,
                 parent_hash=b'\x00'*32, skip_count=0, randao_reveal=b'\x00'*32,
                 attestation_bitmask=b'', attestation_aggregate_sig=0,
                 ffg_signer_list=b'', ffg_aggregate_sig=0,
                 shard_aggregate_votes=[], main_chain_ref=b'\x00'*32,
                 state_hash=b'\x00'*32, height=0, sig=0):
        # at the beginning of a method, locals() is a dict of all arguments
        fields = {k: v for k, v in locals().items() if k != 'self'}
        super(BeaconBlock, self).__init__(**fields)

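# Illustrative (assumes standard pyrlp behaviour): a default block round-trips
# through RLP, e.g.
#   b = BeaconBlock()
#   assert rlp.decode(rlp.encode(b), BeaconBlock).parent_hash == b'\x00' * 32
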
def quick_sample(seed, validator_count, sample_count):
    n = validator_count
    # Smallest k such that 256**k >= n, i.e. how many seed bytes each draw needs
    k = 0
    while 256**k < n:
        k += 1
    o = []; source = seed; pos = 0
    while len(o) < sample_count:
        if pos + k > 32:
            source = blake(source)
            pos = 0
        m = big_endian_to_int(source[pos:pos+k])
        # Rejection-sample draws that would introduce modulo bias
        if n * (m // n + 1) <= 256**k:
            o.append(m % n)
        pos += k
    return o

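# Illustrative: quick_sample(blake(b'seed'), 100, 10) deterministically draws 10
# indices in range(100), with replacement; the same seed always yields the same list.
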
def get_shuffling(seed, validator_count, sample=None):
    assert validator_count <= 16777216
    rand_max = 16777216 - 16777216 % validator_count
    o = list(range(validator_count)); source = seed
    i = 0
    end = sample if sample is not None else validator_count
    while i < end:
        source = blake(source)
        # Each blake output provides ten 3-byte draws
        for pos in range(0, 30, 3):
            remaining = validator_count - i
            if i >= end or remaining == 0:
                break
            m = int.from_bytes(source[pos:pos+3], 'big')
            # Reject draws above rand_max to avoid modulo bias
            if m < rand_max:
                replacement_pos = m % remaining + i
                o[i], o[replacement_pos] = o[replacement_pos], o[i]
                i += 1
    return o

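# Illustrative: get_shuffling(blake(b'seed'), 6) is a deterministic permutation of
# [0, 1, 2, 3, 4, 5]; with sample=k only the first k positions are guaranteed to be
# shuffled, which is all the committee selection below needs.
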
class ValidatorRecord():
    # fields/defaults live on the class so simpleserialize can look them up via
    # the type (typ.fields) as well as on instances
    fields = {'pubkey': 'int256', 'return_shard': 'int16',
              'return_address': 'address', 'randao_commitment': 'hash32',
              'balance': 'int64', 'switch_dynasty': 'int64'}
    defaults = {'pubkey': 0, 'return_shard': 0,
                'return_address': b'\x00'*20, 'randao_commitment': b'\x00'*32,
                'balance': 0, 'switch_dynasty': 0}

    def __init__(self, **kwargs):
        for k in self.fields.keys():
            setattr(self, k, kwargs.get(k, self.defaults[k]))

class CheckpointRecord():
    fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
    defaults = {'checkpoint_hash': b'\x00'*32, 'bitmask': b''}

    def __init__(self, **kwargs):
        for k in self.fields.keys():
            setattr(self, k, kwargs.get(k, self.defaults[k]))

class ActiveState():
    fields = {'height': 'int64', 'randao': 'hash32',
              'validator_ffg_voted': 'bytes', 'rewarded': ['int24'],
              'penalized': ['int24'], 'checkpoints': [CheckpointRecord],
              'total_skip_count': 'int64'}
    defaults = {'height': 0, 'randao': b'\x00'*32,
                'validator_ffg_voted': b'', 'rewarded': [],
                'penalized': [], 'checkpoints': [], 'total_skip_count': 0}

    def __init__(self, **kwargs):
        for k in self.fields.keys():
            setattr(self, k, kwargs.get(k, self.defaults[k]))

class CrystallizedState():
    fields = {'active_validators': [ValidatorRecord],
              'queued_validators': [ValidatorRecord],
              'exited_validators': [ValidatorRecord],
              'last_justified_epoch': 'int64',
              'last_finalized_epoch': 'int64',
              'dynasty': 'int64',
              'current_checkpoint': 'hash32',
              'total_deposits': 'int256'}
    defaults = {'active_validators': [],
                'queued_validators': [],
                'exited_validators': [],
                'last_justified_epoch': 0,
                'last_finalized_epoch': 0,
                'dynasty': 0,
                'current_checkpoint': b'\x00'*32,
                'total_deposits': 0}

    def __init__(self, **kwargs):
        for k in self.fields.keys():
            setattr(self, k, kwargs.get(k, self.defaults[k]))

def compute_state_transition(parent_state, block):
    # Not implemented yet in this commit
    pass

def mock_make_child(parent_state, parent_hash, skips, attester_share=0.8, checkpoint_shards=[]):
    # NOTE: still work in progress; several state fields referenced below
    # (checkpoint_hash, epoch, source_epoch, shuffling, ...) are not yet defined
    # on the state classes in this file.
    parent_attestation_hash = parent_hash + \
                              parent_state.checkpoint_hash + \
                              parent_state.epoch.to_bytes(32, 'big') + \
                              parent_state.source_epoch.to_bytes(32, 'big')
    checkpoint_attestation_hash = b'\x00' * 32 + \
                                  parent_state.checkpoint_hash + \
                                  parent_state.epoch.to_bytes(32, 'big') + \
                                  parent_state.source_epoch.to_bytes(32, 'big')
    validator_count = len(parent_state.active_validators)
    indices = get_shuffling(parent_state.randao, validator_count,
                            min(validator_count, 128))
    # Randomly pick indices to include
    bitfield = [1 if random.random() < attester_share else 0 for i in indices]
    # Attestations
    sigs = [bls.sign(parent_attestation_hash, keymap[parent_state.active_validators[indices[i]].pubkey])
            for i in range(len(indices)) if bitfield[i]]
    attestation_aggregate_sig = bls.aggregate_sigs(sigs)
    attestation_bitmask = bytearray((len(bitfield)-1) // 8 + 1)
    for i, b in enumerate(bitfield):
        attestation_bitmask[i//8] ^= (128 >> (i % 8)) * b
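    # Worked example: bitfield [1, 0, 1, 1] packs into the single byte
    # 0b10110000 (0xb0): bit i becomes the (i % 8)-th most significant bit
    # of byte i // 8.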
    # Randomly pick indices to include for checkpoints
    shard_aggregate_votes = []
    for shard, crosslinker_share in checkpoint_shards:
        indices = parent_state.shuffling[(validator_count * shard) // 100: (validator_count * (shard + 1)) // 100]
        bitfield = [1 if random.random() < crosslinker_share else 0 for i in indices]
        bitmask = bytearray((len(bitfield)-1) // 8 + 1)
        for i, b in enumerate(bitfield):
            bitmask[i//8] ^= (128 >> (i % 8)) * b
        checkpoint = blake(bytes([shard]))
        sigs = [bls.sign(checkpoint, keymap[parent_state.active_validators[indices[i]].pubkey])
                for i in range(len(indices)) if bitfield[i]]
        shard_aggregate_votes.append(AggregateVote(shard, checkpoint, bitmask, bls.aggregate_sigs(sigs)))
    # State calculations: not implemented yet (compute_state_transition above is
    # still a stub), so the child `state` referenced below is a placeholder.

    o = BeaconBlock(parent_hash=parent_hash, skip_count=skips,
                    randao_reveal=blake(str(random.random()).encode('utf-8')),
                    attestation_bitmask=attestation_bitmask,
                    attestation_aggregate_sig=attestation_aggregate_sig,
                    shard_aggregate_votes=shard_aggregate_votes,
                    main_chain_ref=b'\x00'*32, state_hash=state.hash, height=state.height)
    # Main signature (a BeaconBlock.sign helper does not exist yet in this commit)
    o.sign(keymap[parent_state.active_validators[indices[-1]].pubkey])
    return o

@ -0,0 +1,50 @@
def serialize(val, typ):
    if typ in ('hash32', 'address'):
        assert len(val) == (20 if typ == 'address' else 32)
        return val
    elif isinstance(typ, str) and typ[:3] == 'int':
        length = int(typ[3:])
        assert length % 8 == 0
        return val.to_bytes(length // 8, 'big')
    elif typ == 'bytes':
        return len(val).to_bytes(4, 'big') + val
    elif isinstance(typ, list):
        assert len(typ) == 1
        sub = b''.join([serialize(x, typ[0]) for x in val])
        return len(sub).to_bytes(4, 'big') + sub
    elif isinstance(typ, type):
        sub = b''.join([serialize(getattr(val, k), typ.fields[k]) for k in sorted(typ.fields.keys())])
        return len(sub).to_bytes(4, 'big') + sub

def _deserialize(data, start, typ):
    if typ in ('hash32', 'address'):
        length = 20 if typ == 'address' else 32
        assert len(data) - start >= length
        return data[start: start+length], start+length
    elif isinstance(typ, str) and typ[:3] == 'int':
        length = int(typ[3:])
        assert length % 8 == 0
        assert len(data) - start >= length // 8
        return int.from_bytes(data[start: start+length//8], 'big'), start+length//8
    elif typ == 'bytes':
        length = int.from_bytes(data[start:start+4], 'big')
        assert len(data) - start >= 4 + length
        return data[start+4: start+4+length], start+4+length
    elif isinstance(typ, list):
        length = int.from_bytes(data[start:start+4], 'big')
        pos, o = start + 4, []
        while pos < start + 4 + length:
            result, pos = _deserialize(data, pos, typ[0])
            o.append(result)
        assert pos == start + 4 + length
        return o, pos
    elif isinstance(typ, type):
        values = {}
        # Skip the 4-byte length prefix that serialize() writes for object types
        pos = start + 4
        for k in sorted(typ.fields.keys()):
            values[k], pos = _deserialize(data, pos, typ.fields[k])
        return typ(**values), pos

def deserialize(data, typ):
    return _deserialize(data, 0, typ)[0]

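# Wire format sketch (illustrative, based on the functions above): fixed-size
# types are raw bytes, while 'bytes', list and object types carry a 4-byte
# big-endian length prefix. For example:
#   serialize(3, 'int16')       == b'\x00\x03'
#   serialize([1, 2], ['int8']) == b'\x00\x00\x00\x02' + b'\x01\x02'
#   deserialize(serialize([1, 2], ['int8']), ['int8']) == [1, 2]
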
@ -0,0 +1,48 @@
from bls import G1, G2, hash_to_G2, compress_G1, compress_G2, \
    decompress_G1, decompress_G2, normalize, multiply, \
    sign, privtopub, aggregate_sigs, aggregate_pubs, verify

from simpleserialize import serialize, deserialize

from full_pos import ActiveState

for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
    print('Testing with privkey %d' % x)
    p1 = multiply(G1, x)
    p2 = multiply(G2, x)
    msg = str(x).encode('utf-8')
    msghash = hash_to_G2(msg)
    assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1)
    assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2)
    assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash)
    sig = sign(msg, x)
    pub = privtopub(x)
    assert verify(msg, pub, sig)

print('Testing signature aggregation')
msg = b'cow'
keys = [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]
sigs = [sign(msg, k) for k in keys]
pubs = [privtopub(k) for k in keys]
aggsig = aggregate_sigs(sigs)
aggpub = aggregate_pubs(pubs)
assert verify(msg, aggpub, aggsig)

print('Testing basic serialization')

assert serialize(5, 'int8') == b'\x05'
assert deserialize(b'\x05', 'int8') == 5
assert serialize(2**32-3, 'int40') == b'\x00\xff\xff\xff\xfd'
assert deserialize(b'\x00\xff\xff\xff\xfd', 'int40') == 2**32-3
assert serialize(b'\x35'*20, 'address') == b'\x35'*20
assert deserialize(b'\x35'*20, 'address') == b'\x35'*20
assert serialize(b'\x35'*32, 'hash32') == b'\x35'*32
assert deserialize(b'\x35'*32, 'hash32') == b'\x35'*32
assert serialize(b'cow', 'bytes') == b'\x00\x00\x00\x03cow'
assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow'

print('Testing advanced serialization')
s = ActiveState()
# serialize/deserialize take the target type explicitly (see simpleserialize)
ds = deserialize(serialize(s, ActiveState), ActiveState)
for x in s.fields:
    assert getattr(s, x) == getattr(ds, x)