Serialization/deserialization seems to work

This commit is contained in:
Vitalik Buterin 2018-06-11 12:53:32 -04:00
parent de85691047
commit 058d45f0b5
3 changed files with 51 additions and 24 deletions

View File

@ -74,33 +74,41 @@ def get_shuffling(seed, validator_count, sample=None):
return o
class ValidatorRecord():
def __init__(self, **kwargs):
self.fields = {'pubkey': 'int256', 'return_shard': 'int16',
fields = {'pubkey': 'int256', 'return_shard': 'int16',
'return_address': 'address', 'randao_commitment': 'hash32',
'balance': 'int64', 'switch_dynasty': 'int64'}
defaults = {}
def __init__(self, **kwargs):
for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k]))
assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
class CheckpointRecord():
def __init__(self, **kwargs):
self.fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
defaults = {}
def __init__(self, **kwargs):
for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k]))
assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
class ActiveState():
def __init__(self, **kwargs):
self.fields = {'height': 'int64', 'randao': 'hash32',
fields = {'height': 'int64', 'randao': 'hash32',
'validator_ffg_voted': 'bytes', 'rewarded': ['int24'],
'penalized': ['int24'], 'checkpoints': [CheckpointRecord],
'total_skip_count': 'int64'}
defaults = {'height': 0, 'randao': b'\x00'*32,
'validator_ffg_voted': b'', 'rewarded': [],
'penalized': [], 'checkpoints': [], 'total_skip_count': 0}
def __init__(self, **kwargs):
for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k]))
assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
class CrystallizedState():
def __init__(self, **kwargs):
@ -121,7 +129,8 @@ class CrystallizedState():
'current_checkpoint': b'\x00'*32,
'total_deposits': 0}
for k in self.fields.keys():
setattr(self, k, kwargs.get(k, defaults[k]))
assert k in kwargs or k in self.defaults
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
def compute_state_transition(parent_state, block, ):

View File

@ -1,8 +1,10 @@
def serialize(val, typ):
def serialize(val, typ=None):
if typ is None and hasattr(val, 'fields'):
typ = type(val)
if typ in ('hash32', 'address'):
assert len(val) == 20 if typ == 'address' else 32
return val
elif typ[:3] == 'int':
elif isinstance(typ, str) and typ[:3] == 'int':
length = int(typ[3:])
assert length % 8 == 0
return val.to_bytes(length // 8, 'big')
@ -21,7 +23,7 @@ def _deserialize(data, start, typ):
length = 20 if typ == 'address' else 32
assert len(data) + start >= length
return data[start: start+length], start+length
elif typ[:3] == 'int':
elif isinstance(typ, str) and typ[:3] == 'int':
length = int(typ[3:])
assert length % 8 == 0
assert len(data) + start >= length // 8
@ -32,7 +34,7 @@ def _deserialize(data, start, typ):
return data[start+4: start+4+length], start+4+length
elif isinstance(typ, list):
assert len(typ) == 1
length = int.from_bytes(data[start:start+4])
length = int.from_bytes(data[start:start+4], 'big')
pos, o = start + 4, []
while pos < start + 4 + length:
result, pos = _deserialize(data, pos, typ[0])
@ -40,10 +42,12 @@ def _deserialize(data, start, typ):
assert pos == start + 4 + length
return o, pos
elif isinstance(typ, type):
length = int.from_bytes(data[start:start+4], 'big')
values = {}
pos = start
pos = start + 4
for k in sorted(typ.fields.keys()):
values[k], pos = _deserialize(data, pos, typ.fields[k])
assert pos == start + 4 + length
return typ(**values), pos
def deserialize(data, typ):

View File

@ -4,7 +4,7 @@ from bls import G1, G2, hash_to_G2, compress_G1, compress_G2, \
from simpleserialize import serialize, deserialize
from full_pos import ActiveState
from full_pos import ActiveState, CheckpointRecord
for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
print('Testing with privkey %d' % x)
@ -42,7 +42,21 @@ assert serialize(b'cow', 'bytes') == b'\x00\x00\x00\x03cow'
assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow'
print('Testing advanced serialization')
def eq(x, y):
if hasattr(x, 'fields') and hasattr(y, 'fields'):
for f in x.fields:
if not eq(getattr(x, f), getattr(y, f)):
print('Unequal:', x, y, f, getattr(x, f), getattr(y, f))
return False
return True
else:
return x == y
s = ActiveState()
ds = deserialize(serialize(s))
for x in s.fields:
assert getattr(s, x) == getattr(ds, x)
ds = deserialize(serialize(s, type(s)), type(s))
assert eq(s, ds)
s = ActiveState(checkpoints=[CheckpointRecord(checkpoint_hash=b'\x55'*32, bitmask=b'31337dawg')],
height=555, randao=b'\x88'*32, rewarded=[5,7,9,579], penalized=[3]*333)
ds = deserialize(serialize(s, type(s)), type(s))
assert eq(s, ds)