Serialization/deserialization seems to work
This commit is contained in:
parent
de85691047
commit
058d45f0b5
|
@ -74,33 +74,41 @@ def get_shuffling(seed, validator_count, sample=None):
|
|||
return o
|
||||
|
||||
class ValidatorRecord():
|
||||
def __init__(self, **kwargs):
|
||||
self.fields = {'pubkey': 'int256', 'return_shard': 'int16',
|
||||
fields = {'pubkey': 'int256', 'return_shard': 'int16',
|
||||
'return_address': 'address', 'randao_commitment': 'hash32',
|
||||
'balance': 'int64', 'switch_dynasty': 'int64'}
|
||||
defaults = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k in self.fields.keys():
|
||||
setattr(self, k, kwargs.get(k, defaults[k]))
|
||||
assert k in kwargs or k in self.defaults
|
||||
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
|
||||
|
||||
class CheckpointRecord():
|
||||
def __init__(self, **kwargs):
|
||||
self.fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
|
||||
|
||||
fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
|
||||
defaults = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k in self.fields.keys():
|
||||
setattr(self, k, kwargs.get(k, defaults[k]))
|
||||
assert k in kwargs or k in self.defaults
|
||||
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
|
||||
|
||||
|
||||
class ActiveState():
|
||||
def __init__(self, **kwargs):
|
||||
self.fields = {'height': 'int64', 'randao': 'hash32',
|
||||
|
||||
fields = {'height': 'int64', 'randao': 'hash32',
|
||||
'validator_ffg_voted': 'bytes', 'rewarded': ['int24'],
|
||||
'penalized': ['int24'], 'checkpoints': [CheckpointRecord],
|
||||
'total_skip_count': 'int64'}
|
||||
defaults = {'height': 0, 'randao': b'\x00'*32,
|
||||
'validator_ffg_voted': b'', 'rewarded': [],
|
||||
'penalized': [], 'checkpoints': [], 'total_skip_count': 0}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k in self.fields.keys():
|
||||
setattr(self, k, kwargs.get(k, defaults[k]))
|
||||
assert k in kwargs or k in self.defaults
|
||||
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
|
||||
|
||||
class CrystallizedState():
|
||||
def __init__(self, **kwargs):
|
||||
|
@ -121,7 +129,8 @@ class CrystallizedState():
|
|||
'current_checkpoint': b'\x00'*32,
|
||||
'total_deposits': 0}
|
||||
for k in self.fields.keys():
|
||||
setattr(self, k, kwargs.get(k, defaults[k]))
|
||||
assert k in kwargs or k in self.defaults
|
||||
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
|
||||
|
||||
|
||||
def compute_state_transition(parent_state, block, ):
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
def serialize(val, typ):
|
||||
def serialize(val, typ=None):
|
||||
if typ is None and hasattr(val, 'fields'):
|
||||
typ = type(val)
|
||||
if typ in ('hash32', 'address'):
|
||||
assert len(val) == 20 if typ == 'address' else 32
|
||||
return val
|
||||
elif typ[:3] == 'int':
|
||||
elif isinstance(typ, str) and typ[:3] == 'int':
|
||||
length = int(typ[3:])
|
||||
assert length % 8 == 0
|
||||
return val.to_bytes(length // 8, 'big')
|
||||
|
@ -21,7 +23,7 @@ def _deserialize(data, start, typ):
|
|||
length = 20 if typ == 'address' else 32
|
||||
assert len(data) + start >= length
|
||||
return data[start: start+length], start+length
|
||||
elif typ[:3] == 'int':
|
||||
elif isinstance(typ, str) and typ[:3] == 'int':
|
||||
length = int(typ[3:])
|
||||
assert length % 8 == 0
|
||||
assert len(data) + start >= length // 8
|
||||
|
@ -32,7 +34,7 @@ def _deserialize(data, start, typ):
|
|||
return data[start+4: start+4+length], start+4+length
|
||||
elif isinstance(typ, list):
|
||||
assert len(typ) == 1
|
||||
length = int.from_bytes(data[start:start+4])
|
||||
length = int.from_bytes(data[start:start+4], 'big')
|
||||
pos, o = start + 4, []
|
||||
while pos < start + 4 + length:
|
||||
result, pos = _deserialize(data, pos, typ[0])
|
||||
|
@ -40,10 +42,12 @@ def _deserialize(data, start, typ):
|
|||
assert pos == start + 4 + length
|
||||
return o, pos
|
||||
elif isinstance(typ, type):
|
||||
length = int.from_bytes(data[start:start+4], 'big')
|
||||
values = {}
|
||||
pos = start
|
||||
pos = start + 4
|
||||
for k in sorted(typ.fields.keys()):
|
||||
values[k], pos = _deserialize(data, pos, typ.fields[k])
|
||||
assert pos == start + 4 + length
|
||||
return typ(**values), pos
|
||||
|
||||
def deserialize(data, typ):
|
||||
|
|
|
@ -4,7 +4,7 @@ from bls import G1, G2, hash_to_G2, compress_G1, compress_G2, \
|
|||
|
||||
from simpleserialize import serialize, deserialize
|
||||
|
||||
from full_pos import ActiveState
|
||||
from full_pos import ActiveState, CheckpointRecord
|
||||
|
||||
for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
|
||||
print('Testing with privkey %d' % x)
|
||||
|
@ -42,7 +42,21 @@ assert serialize(b'cow', 'bytes') == b'\x00\x00\x00\x03cow'
|
|||
assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow'
|
||||
|
||||
print('Testing advanced serialization')
|
||||
|
||||
def eq(x, y):
|
||||
if hasattr(x, 'fields') and hasattr(y, 'fields'):
|
||||
for f in x.fields:
|
||||
if not eq(getattr(x, f), getattr(y, f)):
|
||||
print('Unequal:', x, y, f, getattr(x, f), getattr(y, f))
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
s = ActiveState()
|
||||
ds = deserialize(serialize(s))
|
||||
for x in s.fields:
|
||||
assert getattr(s, x) == getattr(ds, x)
|
||||
ds = deserialize(serialize(s, type(s)), type(s))
|
||||
assert eq(s, ds)
|
||||
s = ActiveState(checkpoints=[CheckpointRecord(checkpoint_hash=b'\x55'*32, bitmask=b'31337dawg')],
|
||||
height=555, randao=b'\x88'*32, rewarded=[5,7,9,579], penalized=[3]*333)
|
||||
ds = deserialize(serialize(s, type(s)), type(s))
|
||||
assert eq(s, ds)
|
||||
|
|
Loading…
Reference in New Issue