mirror of
https://github.com/status-im/research.git
synced 2025-01-14 17:14:16 +00:00
Serialization/deserialization seems to work
This commit is contained in:
parent
de85691047
commit
058d45f0b5
@ -74,33 +74,41 @@ def get_shuffling(seed, validator_count, sample=None):
|
|||||||
return o
|
return o
|
||||||
|
|
||||||
class ValidatorRecord():
|
class ValidatorRecord():
|
||||||
def __init__(self, **kwargs):
|
fields = {'pubkey': 'int256', 'return_shard': 'int16',
|
||||||
self.fields = {'pubkey': 'int256', 'return_shard': 'int16',
|
|
||||||
'return_address': 'address', 'randao_commitment': 'hash32',
|
'return_address': 'address', 'randao_commitment': 'hash32',
|
||||||
'balance': 'int64', 'switch_dynasty': 'int64'}
|
'balance': 'int64', 'switch_dynasty': 'int64'}
|
||||||
defaults = {}
|
defaults = {}
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
for k in self.fields.keys():
|
for k in self.fields.keys():
|
||||||
setattr(self, k, kwargs.get(k, defaults[k]))
|
assert k in kwargs or k in self.defaults
|
||||||
|
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
|
||||||
|
|
||||||
class CheckpointRecord():
|
class CheckpointRecord():
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
|
fields = {'checkpoint_hash': 'hash32', 'bitmask': 'bytes'}
|
||||||
defaults = {}
|
defaults = {}
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
for k in self.fields.keys():
|
for k in self.fields.keys():
|
||||||
setattr(self, k, kwargs.get(k, defaults[k]))
|
assert k in kwargs or k in self.defaults
|
||||||
|
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
|
||||||
|
|
||||||
|
|
||||||
class ActiveState():
|
class ActiveState():
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.fields = {'height': 'int64', 'randao': 'hash32',
|
fields = {'height': 'int64', 'randao': 'hash32',
|
||||||
'validator_ffg_voted': 'bytes', 'rewarded': ['int24'],
|
'validator_ffg_voted': 'bytes', 'rewarded': ['int24'],
|
||||||
'penalized': ['int24'], 'checkpoints': [CheckpointRecord],
|
'penalized': ['int24'], 'checkpoints': [CheckpointRecord],
|
||||||
'total_skip_count': 'int64'}
|
'total_skip_count': 'int64'}
|
||||||
defaults = {'height': 0, 'randao': b'\x00'*32,
|
defaults = {'height': 0, 'randao': b'\x00'*32,
|
||||||
'validator_ffg_voted': b'', 'rewarded': [],
|
'validator_ffg_voted': b'', 'rewarded': [],
|
||||||
'penalized': [], 'checkpoints': [], 'total_skip_count': 0}
|
'penalized': [], 'checkpoints': [], 'total_skip_count': 0}
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
for k in self.fields.keys():
|
for k in self.fields.keys():
|
||||||
setattr(self, k, kwargs.get(k, defaults[k]))
|
assert k in kwargs or k in self.defaults
|
||||||
|
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
|
||||||
|
|
||||||
class CrystallizedState():
|
class CrystallizedState():
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@ -121,7 +129,8 @@ class CrystallizedState():
|
|||||||
'current_checkpoint': b'\x00'*32,
|
'current_checkpoint': b'\x00'*32,
|
||||||
'total_deposits': 0}
|
'total_deposits': 0}
|
||||||
for k in self.fields.keys():
|
for k in self.fields.keys():
|
||||||
setattr(self, k, kwargs.get(k, defaults[k]))
|
assert k in kwargs or k in self.defaults
|
||||||
|
setattr(self, k, kwargs.get(k, self.defaults.get(k)))
|
||||||
|
|
||||||
|
|
||||||
def compute_state_transition(parent_state, block, ):
|
def compute_state_transition(parent_state, block, ):
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
def serialize(val, typ):
|
def serialize(val, typ=None):
|
||||||
|
if typ is None and hasattr(val, 'fields'):
|
||||||
|
typ = type(val)
|
||||||
if typ in ('hash32', 'address'):
|
if typ in ('hash32', 'address'):
|
||||||
assert len(val) == 20 if typ == 'address' else 32
|
assert len(val) == 20 if typ == 'address' else 32
|
||||||
return val
|
return val
|
||||||
elif typ[:3] == 'int':
|
elif isinstance(typ, str) and typ[:3] == 'int':
|
||||||
length = int(typ[3:])
|
length = int(typ[3:])
|
||||||
assert length % 8 == 0
|
assert length % 8 == 0
|
||||||
return val.to_bytes(length // 8, 'big')
|
return val.to_bytes(length // 8, 'big')
|
||||||
@ -21,7 +23,7 @@ def _deserialize(data, start, typ):
|
|||||||
length = 20 if typ == 'address' else 32
|
length = 20 if typ == 'address' else 32
|
||||||
assert len(data) + start >= length
|
assert len(data) + start >= length
|
||||||
return data[start: start+length], start+length
|
return data[start: start+length], start+length
|
||||||
elif typ[:3] == 'int':
|
elif isinstance(typ, str) and typ[:3] == 'int':
|
||||||
length = int(typ[3:])
|
length = int(typ[3:])
|
||||||
assert length % 8 == 0
|
assert length % 8 == 0
|
||||||
assert len(data) + start >= length // 8
|
assert len(data) + start >= length // 8
|
||||||
@ -32,7 +34,7 @@ def _deserialize(data, start, typ):
|
|||||||
return data[start+4: start+4+length], start+4+length
|
return data[start+4: start+4+length], start+4+length
|
||||||
elif isinstance(typ, list):
|
elif isinstance(typ, list):
|
||||||
assert len(typ) == 1
|
assert len(typ) == 1
|
||||||
length = int.from_bytes(data[start:start+4])
|
length = int.from_bytes(data[start:start+4], 'big')
|
||||||
pos, o = start + 4, []
|
pos, o = start + 4, []
|
||||||
while pos < start + 4 + length:
|
while pos < start + 4 + length:
|
||||||
result, pos = _deserialize(data, pos, typ[0])
|
result, pos = _deserialize(data, pos, typ[0])
|
||||||
@ -40,10 +42,12 @@ def _deserialize(data, start, typ):
|
|||||||
assert pos == start + 4 + length
|
assert pos == start + 4 + length
|
||||||
return o, pos
|
return o, pos
|
||||||
elif isinstance(typ, type):
|
elif isinstance(typ, type):
|
||||||
|
length = int.from_bytes(data[start:start+4], 'big')
|
||||||
values = {}
|
values = {}
|
||||||
pos = start
|
pos = start + 4
|
||||||
for k in sorted(typ.fields.keys()):
|
for k in sorted(typ.fields.keys()):
|
||||||
values[k], pos = _deserialize(data, pos, typ.fields[k])
|
values[k], pos = _deserialize(data, pos, typ.fields[k])
|
||||||
|
assert pos == start + 4 + length
|
||||||
return typ(**values), pos
|
return typ(**values), pos
|
||||||
|
|
||||||
def deserialize(data, typ):
|
def deserialize(data, typ):
|
||||||
|
@ -4,7 +4,7 @@ from bls import G1, G2, hash_to_G2, compress_G1, compress_G2, \
|
|||||||
|
|
||||||
from simpleserialize import serialize, deserialize
|
from simpleserialize import serialize, deserialize
|
||||||
|
|
||||||
from full_pos import ActiveState
|
from full_pos import ActiveState, CheckpointRecord
|
||||||
|
|
||||||
for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
|
for x in (1, 5, 124, 735, 127409812145, 90768492698215092512159, 0):
|
||||||
print('Testing with privkey %d' % x)
|
print('Testing with privkey %d' % x)
|
||||||
@ -42,7 +42,21 @@ assert serialize(b'cow', 'bytes') == b'\x00\x00\x00\x03cow'
|
|||||||
assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow'
|
assert deserialize(b'\x00\x00\x00\x03cow', 'bytes') == b'cow'
|
||||||
|
|
||||||
print('Testing advanced serialization')
|
print('Testing advanced serialization')
|
||||||
|
|
||||||
|
def eq(x, y):
|
||||||
|
if hasattr(x, 'fields') and hasattr(y, 'fields'):
|
||||||
|
for f in x.fields:
|
||||||
|
if not eq(getattr(x, f), getattr(y, f)):
|
||||||
|
print('Unequal:', x, y, f, getattr(x, f), getattr(y, f))
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return x == y
|
||||||
|
|
||||||
s = ActiveState()
|
s = ActiveState()
|
||||||
ds = deserialize(serialize(s))
|
ds = deserialize(serialize(s, type(s)), type(s))
|
||||||
for x in s.fields:
|
assert eq(s, ds)
|
||||||
assert getattr(s, x) == getattr(ds, x)
|
s = ActiveState(checkpoints=[CheckpointRecord(checkpoint_hash=b'\x55'*32, bitmask=b'31337dawg')],
|
||||||
|
height=555, randao=b'\x88'*32, rewarded=[5,7,9,579], penalized=[3]*333)
|
||||||
|
ds = deserialize(serialize(s, type(s)), type(s))
|
||||||
|
assert eq(s, ds)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user