# research/mimc_stark/test.py
#
# Tests for the Merkle tree helpers, the FRI low-degree proof and the
# MIMC STARK proof.

from fft import fft
from mimc_stark import mk_mimc_proof, modulus, mimc, verify_mimc_proof
from compression import compress_fri, compress_branches, bin_length
from merkle_tree import merkelize, mk_branch, verify_branch
from fri import prove_low_degree, verify_low_degree_proof
def test_merkletree():
    """Check that a branch produced by ``mk_branch`` verifies.

    Merkelizes ``range(128)``, builds the branch for index 59 and asserts
    that ``verify_branch`` returns 59 when given element 1 of the
    ``merkelize`` output (presumably the root -- confirm in merkle_tree).
    """
    index = 59
    tree = merkelize(range(128))
    branch = mk_branch(tree, index)
    recovered = verify_branch(tree[1], index, branch)
    assert recovered == 59
    print("Merkle tree works")
def _fri_rejects(check):
    """Return True when the verification callable ``check`` rejects.

    A rejection is either a falsy return value or an exception raised while
    ``check`` runs.  The exception types the verifier uses to signal failure
    are not visible from this file, so any ``Exception`` counts.
    """
    try:
        return not check()
    except Exception:
        return True


def test_fri():
    """Pure FRI tests: one positive check and two negative checks.

    1. A proof built from the FFT of ``range(4096)`` must verify with a
       degree bound of 4096.
    2. A proof built from corrupted evaluations must be rejected.
    3. The valid proof must be rejected when the claimed bound is 2048.

    Raises:
        AssertionError: if the valid proof does not verify.
        Exception: if either negative check is accepted by the verifier.
    """
    # Pure FRI tests
    poly = list(range(4096))
    root_of_unity = pow(7, (modulus - 1) // 16384, modulus)
    evaluations = fft(poly, modulus, root_of_unity)
    proof = prove_low_degree(evaluations, root_of_unity, 4096, modulus)
    print("Approx proof length: %d" % bin_length(compress_fri(proof)))
    assert verify_low_degree_proof(
        merkelize(evaluations)[1], root_of_unity, proof, 4096, modulus
    )

    # Corrupt the evaluations: keep value x unless pow(3, i, 4096) <= 400
    # for its index i, in which case it is replaced by 39.  enumerate()
    # yields (index, value), so the index must be bound first.
    fakedata = [
        x if pow(3, i, 4096) > 400 else 39
        for i, x in enumerate(evaluations)
    ]

    def fake_data_accepted():
        # Verify the proof generated *for the fake data* (proof2); the
        # original code built proof2 but then checked the stale `proof`.
        proof2 = prove_low_degree(fakedata, root_of_unity, 4096, modulus)
        return verify_low_degree_proof(
            merkelize(fakedata)[1], root_of_unity, proof2, 4096, modulus
        )

    def wrong_degree_accepted():
        return verify_low_degree_proof(
            merkelize(evaluations)[1], root_of_unity, proof, 2048, modulus
        )

    # The failure is raised outside any try/except so it cannot be
    # swallowed, which is what the previous bare `except: pass` did.
    if not _fri_rejects(fake_data_accepted):
        raise Exception("Fake data passed FRI")
    if not _fri_rejects(wrong_degree_accepted):
        raise Exception("Wrong degree bound passed FRI")
# (blame timestamp from the web view: 2018-07-10 12:49:25 +00:00)
def test_stark():
    """Full STARK test: build a MIMC proof, report its size and verify it.

    ``LOGSTEPS`` is read from the first command-line argument when one is
    given and defaults to 13 otherwise.  The round constants come from a
    fixed formula, so every run uses the same values.

    Raises:
        AssertionError: if ``verify_mimc_proof`` does not accept the proof.
        ValueError: if the first command-line argument is not an integer.
    """
    import sys

    INPUT = 3
    LOGSTEPS = int(sys.argv[1]) if len(sys.argv) > 1 else 13

    # Full STARK test
    constants = [(i**7) ^ 42 for i in range(64)]
    proof = mk_mimc_proof(INPUT, LOGSTEPS, constants)
    # Only the branches and the FRI proof are needed for the size report;
    # the four leading elements are unpacked to keep the 6-tuple check.
    _p_root, _d_root, _b_root, _l_root, branches, fri_proof = proof

    L1 = bin_length(compress_branches(branches))
    L2 = bin_length(compress_fri(fri_proof))
    print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2))

    # Use INPUT (not a repeated literal 3) so the proof and the
    # verification always agree on the input value.
    expected_output = mimc(INPUT, LOGSTEPS, constants)
    assert verify_mimc_proof(INPUT, LOGSTEPS, constants, expected_output, proof)
# Script entry point: running this file directly executes only the full
# STARK test; the other tests are not invoked here.
if __name__ == "__main__":
    test_stark()