Serialization/deserialization seems to work

This commit is contained in:
Vitalik Buterin 2018-06-11 12:53:32 -04:00
parent de85691047
commit 058d45f0b5
3 changed files with 51 additions and 24 deletions

View File

@ -74,33 +74,41 @@ def get_shuffling(seed, validator_count, sample=None):
return o return o
class ValidatorRecord(): class ValidatorRecord():
def __init__(self, **kwargs): fields = {'pubkey': 'int256', 'return_shard': 'int16',
self.fields = {'pubkey': 'int256', 'return_shard': 'int16',
'return_address': 'address', 'randao_commitment': 'hash32', 'return_address': 'address', 'randao_commitment': 'hash32',
'balance': 'int64', 'switch_dynasty': 'int64'} 'balance': 'int64', 'switch_dynasty': 'int64'}
defaults = {} defaults = {}
def __init__(self, **kwargs):
for k in self.fields.keys(): for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k])) assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
class CheckpointRecord(): class CheckpointRecord():
def __init__(self, **kwargs):
self.fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'} fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
defaults = {} defaults = {}
def __init__(self, **kwargs):
for k in self.fields.keys(): for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k])) assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
class ActiveState(): class ActiveState():
def __init__(self, **kwargs):
self.fields = {'height': 'int64', 'randao': 'hash32', fields = {'height': 'int64', 'randao': 'hash32',
'validator_ffg_voted': 'bytes', 'rewarded': ['int24'], 'validator_ffg_voted': 'bytes', 'rewarded': ['int24'],
'penalized': ['int24'], 'checkpoints': [CheckpointRecord], 'penalized': ['int24'], 'checkpoints': [CheckpointRecord],
'total_skip_count': 'int64'} 'total_skip_count': 'int64'}
defaults = {'height': 0, 'randao': b'\x00'*32, defaults = {'height': 0, 'randao': b'\x00'*32,
'validator_ffg_voted': b'', 'rewarded': [], 'validator_ffg_voted': b'', 'rewarded': [],
'penalized': [], 'checkpoints': [], 'total_skip_count': 0} 'penalized': [], 'checkpoints': [], 'total_skip_count': 0}
def __init__(self, **kwargs):
for k in self.fields.keys(): for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k])) assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
class CrystallizedState(): class CrystallizedState():
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -121,7 +129,8 @@ class CrystallizedState():
'current_checkpoint': b'\x00'*32, 'current_checkpoint': b'\x00'*32,
'total_deposits': 0} 'total_deposits': 0}
for k in self.fields.keys(): for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k])) assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
def compute_state_transition(parent_state, block, ): def compute_state_transition(parent_state, block, ):

View File

@ -1,8 +1,10 @@
def serialize(val, typ): def serialize(val, typ=None):
if typ is None and hasattr(val, 'fields'):
typ = type(val)
if typ in ('hash32', 'address'): if typ in ('hash32', 'address'):
assert len(val) == 20 if typ == 'address' else 32 assert len(val) == 20 if typ == 'address' else 32
return val return val
elif typ[:3] == 'int': elif isinstance(typ, str) and typ[:3] == 'int':
length = int(typ[3:]) length = int(typ[3:])
assert length % 8 == 0 assert length % 8 == 0
return val.to_bytes(length // 8, 'big') return val.to_bytes(length // 8, 'big')
@ -21,7 +23,7 @@ def _deserialize(data, start, typ):
length = 20 if typ == 'address' else 32 length = 20 if typ == 'address' else 32
assert len(data) + start >= length assert len(data) + start >= length
return data[start: start+length], start+length return data[start: start+length], start+length
elif typ[:3] == 'int': elif isinstance(typ, str) and typ[:3] == 'int':
length = int(typ[3:]) length = int(typ[3:])
assert length % 8 == 0 assert length % 8 == 0
assert len(data) + start >= length // 8 assert len(data) + start >= length // 8
@ -32,7 +34,7 @@ def _deserialize(data, start, typ):
return data[start+4: start+4+length], start+4+length return data[start+4: start+4+length], start+4+length
elif isinstance(typ, list): elif isinstance(typ, list):
assert len(typ) == 1 assert len(typ) == 1
length = int.from_bytes(data[start:start+4]) length = int.from_bytes(data[start:start+4], 'big')
pos, o = start + 4, [] pos, o = start + 4, []
while pos < start + 4 + length: while pos < start + 4 + length:
result, pos = _deserialize(data, pos, typ[0]) result, pos = _deserialize(data, pos, typ[0])
@ -40,10 +42,12 @@ def _deserialize(data, start, typ):
assert pos == start + 4 + length assert pos == start + 4 + length
return o, pos return o, pos
elif isinstance(typ, type): elif isinstance(typ, type):
length = int.from_bytes(data[start:start+4], 'big')
values = {} values = {}
pos = start pos = start + 4
for k in sorted(typ.fields.keys()): for k in sorted(typ.fields.keys()):
values[k], pos = _deserialize(data, pos, typ.fields[k]) values[k], pos = _deserialize(data, pos, typ.fields[k])
assert pos == start + 4 + length
return typ(**values), pos return typ(**values), pos
def deserialize(data, typ): def deserialize(data, typ):

View File

@ -4,7 +4,7 @@ from bls import G1, G2, hash_to_G2, compress_G1, compress_G2, \
from simpleserialize import serialize, deserialize from simpleserialize import serialize, deserialize
from full_pos import ActiveState from full_pos import ActiveState, CheckpointRecord
for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0): for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
print('Testing with privkey %d' % x) print('Testing with privkey %d' % x)
@ -42,7 +42,21 @@ assert serialize(b'cow', 'bytes') == b'\x00\x00\x00\x03cow'
assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow' assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow'
print('Testing advanced serialization') print('Testing advanced serialization')
def eq(x, y):
if hasattr(x, 'fields') and hasattr(y, 'fields'):
for f in x.fields:
if not eq(getattr(x, f), getattr(y, f)):
print('Unequal:', x, y, f, getattr(x, f), getattr(y, f))
return False
return True
else:
return x == y
s = ActiveState() s = ActiveState()
ds = deserialize(serialize(s)) ds = deserialize(serialize(s, type(s)), type(s))
for x in s.fields: assert eq(s, ds)
assert getattr(s, x) == getattr(ds, x) s = ActiveState(checkpoints=[CheckpointRecord(checkpoint_hash=b'\x55'*32, bitmask=b'31337dawg')],
height=555, randao=b'\x88'*32, rewarded=[5,7,9,579], penalized=[3]*333)
ds = deserialize(serialize(s, type(s)), type(s))
assert eq(s, ds)