research/mimc_stark/compression.py

86 lines
2.3 KiB
Python

def compress_fri(prf):
o = []
def add_obj(x):
if x in o:
o.append(o.index(x).to_bytes(2, 'big'))
else:
o.append(x)
for root, yproofs in prf[:-1]:
# print('Adding proof item, pos %d' % len(o))
add_obj(b'----')
add_obj(root)
for yproof in yproofs:
for branch in yproof:
for p in branch:
add_obj(p)
add_obj(b'++++')
add_obj(b'====')
# print('Adding final proof, pos %d' % len(o))
add_obj(b'////')
for x in prf[-1]:
add_obj(x)
assert decompress_fri(o) == prf
return o
def decompress_fri(proof):
def get_obj(pos):
return proof[int.from_bytes(proof[pos], 'big')] if len(proof[pos]) == 2 else proof[pos]
o = []
pos = 0
while proof[pos] != b'////':
# print("Processing proof item", pos)
assert get_obj(pos) == b'----'
root = get_obj(pos + 1)
pos += 2
yproofs = []
while get_obj(pos) not in (b'----', b'////'):
yproof = []
while get_obj(pos) != b'====':
branch = []
while get_obj(pos) != b'++++':
branch.append(get_obj(pos))
pos += 1
yproof.append(branch)
pos += 1
yproofs.append(yproof)
pos += 1
o.append([root, yproofs])
# print('Processing final proof, pos %d' % pos)
pos += 1
o.append([get_obj(x) for x in range(pos, len(proof))])
return o
def compress_branches(branches):
o = []
def add_obj(x):
if x in o:
o.append(o.index(x).to_bytes(2, 'big'))
else:
o.append(x)
for branch in branches:
for p in branch:
add_obj(p)
add_obj(b'----')
assert decompress_branches(o) == branches
return o
def decompress_branches(proof):
def get_obj(pos):
return proof[int.from_bytes(proof[pos], 'big')] if len(proof[pos]) == 2 else proof[pos]
o = []
pos = 0
while pos < len(proof):
branch = []
while pos < len(proof) and get_obj(pos) != b'----':
branch.append(get_obj(pos))
pos += 1
o.append(branch)
pos += 1
return o
def bin_length(c):
return len(b''.join([(b'\xff' if len(x) == 32 else b'') + x for x in c]))