research/mimc_stark/compression.py

90 lines
2.5 KiB
Python
Raw Normal View History

2018-06-29 04:26:17 -04:00
def compress_fri(prf):
    """Flatten a nested proof into a list of byte strings, de-duplicating repeats.

    ``prf`` is a sequence whose last element is an iterable of byte strings
    and whose other elements are ``[root, yproofs]`` pairs, where ``yproofs``
    is a list of y-proofs, each y-proof a list of branches, and each branch a
    list of byte strings.

    Output layout, for every ``[root, yproofs]`` pair::

        b'----', root, then per y-proof: (branch items, b'++++')..., b'===='

    followed by ``b'////'`` and the items of ``prf[-1]``.

    The first occurrence of an object (markers included) is stored verbatim;
    a later occurrence is replaced by the 2-byte big-endian index of that
    first occurrence. An object whose first occurrence sits at an index that
    does not fit in two bytes is never referenced and is always stored
    verbatim instead.

    The decoder treats every 2-byte entry as a back-reference, so the input
    must not contain 2-byte objects; the round-trip assertion below catches
    that case (when assertions are enabled).

    Returns the flat list. Raises AssertionError if decoding the result does
    not reproduce ``prf``.
    """
    max_ref = 0xFFFF  # largest index representable in a 2-byte reference
    out = []
    first_seen = {}

    def add_obj(x):
        idx = first_seen.get(x)
        if idx is not None:
            out.append(idx.to_bytes(2, 'big'))
            return
        # Only remember positions that a 2-byte reference can address;
        # anything beyond that is simply emitted verbatim every time.
        if len(out) <= max_ref:
            first_seen[x] = len(out)
        out.append(x)

    for root, yproofs in prf[:-1]:
        add_obj(b'----')
        add_obj(root)
        for yproof in yproofs:
            for branch in yproof:
                for p in branch:
                    add_obj(p)
                add_obj(b'++++')
            add_obj(b'====')

    add_obj(b'////')
    for x in prf[-1]:
        add_obj(x)

    # Self-check: the encoding must decode back to exactly the input.
    assert decompress_fri(out) == prf
    return out
def decompress_fri(proof):
    """Rebuild the nested proof structure from its flat, de-duplicated form.

    ``proof`` is a list of byte strings laid out as a sequence of items, each
    ``b'----', root, (branch items, b'++++')..., b'===='`` per y-proof,
    followed by ``b'////'`` and the trailing items. Any entry that is exactly
    two bytes long is interpreted as a big-endian index into ``proof`` and is
    replaced by the entry stored at that index.

    Returns a list of ``[root, yproofs]`` lists followed by one final list
    holding the resolved entries that come after ``b'////'``.

    Raises ValueError when an item does not start with the ``b'----'``
    marker. A truncated ``proof`` surfaces as IndexError from the list
    lookups.
    """
    def get_obj(pos):
        entry = proof[pos]
        if len(entry) == 2:
            # Two-byte entries are back-references, not data.
            return proof[int.from_bytes(entry, 'big')]
        return entry

    o = []
    pos = 0
    # The end marker is compared unresolved: it is looked for as a literal
    # entry rather than through a back-reference.
    while proof[pos] != b'////':
        # Explicit check instead of `assert`, so that malformed input is
        # still rejected when Python runs with assertions disabled (-O).
        if get_obj(pos) != b'----':
            raise ValueError('expected item marker at position %d' % pos)
        root = get_obj(pos + 1)
        pos += 2
        yproofs = []
        while get_obj(pos) not in (b'----', b'////'):
            yproof = []
            while get_obj(pos) != b'====':
                branch = []
                while get_obj(pos) != b'++++':
                    branch.append(get_obj(pos))
                    pos += 1
                yproof.append(branch)
                pos += 1  # skip the b'++++' branch terminator
            yproofs.append(yproof)
            pos += 1  # skip the b'====' y-proof terminator
        o.append([root, yproofs])

    pos += 1  # skip the b'////' marker
    o.append([get_obj(x) for x in range(pos, len(proof))])
    return o
def compress_branches(branches):
    """Flatten a list of branches into one list of byte strings, de-duplicated.

    ``branches`` is a list of branches, each a list of byte strings. Every
    branch is written out item by item and terminated with ``b'----'``.

    The first occurrence of an object (the terminator included) is stored
    verbatim; a later occurrence is replaced by the 2-byte big-endian index
    of that first occurrence. An object whose first occurrence sits at an
    index that does not fit in two bytes is never referenced and is always
    stored verbatim instead.

    The decoder treats every 2-byte entry as a back-reference, so branches
    must not contain 2-byte objects; the round-trip assertion below catches
    that case (when assertions are enabled).

    Returns the flat list. Raises AssertionError if decoding the result does
    not reproduce ``branches``.
    """
    max_ref = 0xFFFF  # largest index representable in a 2-byte reference
    out = []
    first_seen = {}

    def add_obj(x):
        idx = first_seen.get(x)
        if idx is not None:
            out.append(idx.to_bytes(2, 'big'))
            return
        # Only remember positions that a 2-byte reference can address;
        # anything beyond that is simply emitted verbatim every time.
        if len(out) <= max_ref:
            first_seen[x] = len(out)
        out.append(x)

    for branch in branches:
        for p in branch:
            add_obj(p)
        add_obj(b'----')

    # Self-check: the encoding must decode back to exactly the input.
    assert decompress_branches(out) == branches
    return out
def decompress_branches(proof):
    """Rebuild a list of branches from its flat, de-duplicated form.

    ``proof`` is a list of byte strings in which each branch is a run of
    items terminated by ``b'----'``. Any entry that is exactly two bytes long
    is interpreted as a big-endian index into ``proof`` and is replaced by
    the entry stored at that index.

    Returns a list of branches (lists of byte strings). An empty ``proof``
    yields an empty list; a final run with no terminator is still returned
    as a branch.
    """
    def get_obj(pos):
        entry = proof[pos]
        if len(entry) == 2:
            # Two-byte entries are back-references, not data.
            return proof[int.from_bytes(entry, 'big')]
        return entry

    branches = []
    total = len(proof)
    pos = 0
    while pos < total:
        branch = []
        while pos < total:
            # Resolve each slot once, then decide whether it ends the branch.
            item = get_obj(pos)
            if item == b'----':
                break
            branch.append(item)
            pos += 1
        branches.append(branch)
        pos += 1  # step over the terminator (or past the end of the input)
    return branches
def bin_length(c):
    """Return the total serialized size, in bytes, of the byte strings in ``c``.

    Every item contributes its own length; an item that is exactly 32 bytes
    long contributes one extra byte, matching a one-byte 0xff prefix placed
    in front of such items. The size is summed directly rather than by
    building the concatenated byte string first.
    """
    return sum(len(x) + (1 if len(x) == 32 else 0) for x in c)