From 770c0a2c781a49e52945e8bab2dd8526d362d56c Mon Sep 17 00:00:00 2001 From: Vitalik Buterin Date: Tue, 10 Jul 2018 08:49:25 -0400 Subject: [PATCH] Reorganized the code somewhat --- mimc_stark/better_lagrange.py | 12 +- mimc_stark/compression.py | 8 +- mimc_stark/fri.py | 128 +++++++++++++++++ mimc_stark/merkle_tree.py | 4 - mimc_stark/mimc_stark.py | 262 +++++++++------------------------- mimc_stark/test.py | 36 +++++ mimc_stark/utils.py | 17 +++ 7 files changed, 264 insertions(+), 203 deletions(-) create mode 100644 mimc_stark/fri.py create mode 100644 mimc_stark/test.py create mode 100644 mimc_stark/utils.py diff --git a/mimc_stark/better_lagrange.py b/mimc_stark/better_lagrange.py index 2d337ef..fa2dd16 100644 --- a/mimc_stark/better_lagrange.py +++ b/mimc_stark/better_lagrange.py @@ -13,7 +13,7 @@ def eval_poly_at(poly, x, modulus): o, p = 0, 1 for coeff in poly: o += coeff * p - p *= x + p = (p * x % modulus) return o % modulus def lagrange_interp_4(pieces, xs, modulus): @@ -35,3 +35,13 @@ def lagrange_interp_4(pieces, xs, modulus): inv_y2 = pieces[2] * invall * e01 * e3 % modulus inv_y3 = pieces[3] * invall * e01 * e2 % modulus return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % modulus for i in range(4)] + +def lagrange_interp_2(pieces, xs, modulus): + eq0 = [-xs[1] % modulus, 1] + eq1 = [-xs[0] % modulus, 1] + e0 = eval_poly_at(eq0, xs[0], modulus) + e1 = eval_poly_at(eq1, xs[1], modulus) + invall = inv(e0 * e1, modulus) + inv_y0 = pieces[0] * invall * e1 + inv_y1 = pieces[1] * invall * e0 + return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % modulus for i in range(2)] diff --git a/mimc_stark/compression.py b/mimc_stark/compression.py index ba50ea4..838a04f 100644 --- a/mimc_stark/compression.py +++ b/mimc_stark/compression.py @@ -2,7 +2,7 @@ def compress_fri(prf): o = [] def add_obj(x): if x in o: - o.append(o.index(x).to_bytes(3, 'big')) + o.append(o.index(x).to_bytes(2, 'big')) else: o.append(x) @@ -26,7 +26,7 @@ def compress_fri(prf): def decompress_fri(proof): def get_obj(pos): - return proof[int.from_bytes(proof[pos], 'big')] if len(proof[pos]) == 3 else proof[pos] + return proof[int.from_bytes(proof[pos], 'big')] if len(proof[pos]) == 2 else proof[pos] o = [] pos = 0 while proof[pos] != b'////': @@ -56,7 +56,7 @@ def compress_branches(branches): o = [] def add_obj(x): if x in o: - o.append(o.index(x).to_bytes(3, 'big')) + o.append(o.index(x).to_bytes(2, 'big')) else: o.append(x) @@ -69,7 +69,7 @@ def compress_branches(branches): def decompress_branches(proof): def get_obj(pos): - return proof[int.from_bytes(proof[pos], 'big')] if len(proof[pos]) == 3 else proof[pos] + return proof[int.from_bytes(proof[pos], 'big')] if len(proof[pos]) == 2 else proof[pos] o = [] pos = 0 while pos < len(proof): diff --git a/mimc_stark/fri.py b/mimc_stark/fri.py new file mode 100644 index 0000000..23b21dc --- /dev/null +++ b/mimc_stark/fri.py @@ -0,0 +1,128 @@ +from better_lagrange import lagrange_interp_4, eval_poly_at +from merkle_tree import merkelize, mk_branch, verify_branch +from utils import get_power_cycle, get_pseudorandom_indices +from fft import fft + +# Generate an FRI proof +def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus): + print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1)) + + # If the degree we are checking for is less than or equal to 32, + # use the polynomial directly as a proof + if maxdeg_plus_1 <= 32: + print('Produced FRI proof') + return [[x.to_bytes(32, 'big') for x in values]] + + # Calculate the set of x coordinates + xs = get_power_cycle(root_of_unity, modulus) + assert len(values) == len(xs) + + # Put the values into a Merkle tree. This is the root that the + # proof will be checked against + m = merkelize(values) + + # Select a pseudo-random x coordinate + special_x = int.from_bytes(m[1], 'big') % modulus + + # Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html) + # at that x coordinate + # We calculate the column by Lagrange-interpolating the row, and not + # directly, as this is more efficient + column = [] + for i in range(len(xs)//4): + x_poly = lagrange_interp_4( + [values[i+len(values)*j//4] for j in range(4)], + [xs[i+len(xs)*j//4] for j in range(4)], + modulus + ) + column.append(eval_poly_at(x_poly, special_x, modulus)) + m2 = merkelize(column) + + # Pseudo-randomly select y indices to sample + ys = get_pseudorandom_indices(m2[1], len(column), 40) + + # Compute the Merkle branches for the values in the polynomial and the column + branches = [] + for y in ys: + branches.append([mk_branch(m2, y)] + [mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)]) + + # This component of the proof + o = [m2[1], branches] + + # In the next iteration of the proof, we'll work with smaller roots of unity + sub_xs = [xs[i] for i in range(0, len(xs), 4)] + + # Interpolate the polynomial for the column + ypoly = fft(column[:len(sub_xs)], modulus, + pow(root_of_unity, 4, modulus), inv=True) + + # Recurse... + return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4, modulus) + +# Verify an FRI proof +def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus): + + # Calculate which root of unity we're working with + testval = root_of_unity + roudeg = 1 + while testval != 1: + roudeg *= 2 + testval = (testval * testval) % modulus + + quartic_roots_of_unity = [1, + pow(root_of_unity, roudeg // 4, modulus), + pow(root_of_unity, roudeg // 2, modulus), + pow(root_of_unity, roudeg * 3 // 4, modulus)] + + # Verify the recursive components of the proof + for prf in proof[:-1]: + root2, branches = prf + print('Verifying degree <= %d' % maxdeg_plus_1) + + # Calculate the pseudo-random x coordinate + special_x = int.from_bytes(merkle_root, 'big') % modulus + + # Calculate the pseudo-randomly sampled y indices + ys = get_pseudorandom_indices(root2, roudeg // 4, 40) + + + # Verify for each selected y coordinate that the four points from the polynomial + # and the one point from the column that are on that y coordinate are on the same + # deg < 4 polynomial + for i, y in enumerate(ys): + # The x coordinates from the polynomial + x1 = pow(root_of_unity, y, modulus) + xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)] + + # The values from the polynomial + row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])] + + # Verify proof and recover the column value + values = [verify_branch(root2, y, branches[i][0])] + row + + # Lagrange interpolate and check deg is < 4 + p = lagrange_interp_4(row, xcoords, modulus) + assert eval_poly_at(p, special_x, modulus) == verify_branch(root2, y, branches[i][0]) + + # Update constants to check the next proof + merkle_root = root2 + root_of_unity = pow(root_of_unity, 4, modulus) + maxdeg_plus_1 //= 4 + roudeg //= 4 + + # Verify the direct components of the proof + data = [int.from_bytes(x, 'big') for x in proof[-1]] + print('Verifying degree <= %d' % maxdeg_plus_1) + assert maxdeg_plus_1 <= 32 + + # Check the Merkle root matches up + mtree = merkelize(data) + assert mtree[1] == merkle_root + + # Check the degree of the data + poly = fft(data, modulus, root_of_unity, inv=True) + for i in range(maxdeg_plus_1, len(poly)): + assert poly[i] == 0 + + print('FRI proof verified') + return True diff --git a/mimc_stark/merkle_tree.py b/mimc_stark/merkle_tree.py index b0c8a35..3c67671 100644 --- a/mimc_stark/merkle_tree.py +++ b/mimc_stark/merkle_tree.py @@ -30,8 +30,4 @@ def verify_branch(root, index, proof): assert v == root return int.from_bytes(proof[0], 'big') -t = merkelize(range(128)) -b = mk_branch(t, 59) -assert verify_branch(t[1], 59, b) == 59 -print('Merkle tree works') diff --git a/mimc_stark/mimc_stark.py b/mimc_stark/mimc_stark.py index 54ba82f..551a01a 100644 --- a/mimc_stark/mimc_stark.py +++ b/mimc_stark/mimc_stark.py @@ -1,162 +1,18 @@ from merkle_tree import merkelize, mk_branch, verify_branch, blake from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length from ecpoly import PrimeField -from better_lagrange import lagrange_interp_4 +from better_lagrange import lagrange_interp_4, lagrange_interp_2 import time from fft import fft +from fri import prove_low_degree, verify_low_degree_proof +from utils import get_power_cycle, get_pseudorandom_indices modulus = 2**256 - 2**32 * 351 + 1 f = PrimeField(modulus) nonresidue = 7 -quartic_roots_of_unity = [1, - pow(7, (modulus-1)//4, modulus), - pow(7, (modulus-1)//2, modulus), - pow(7, (modulus-1)*3//4, modulus)] spot_check_security_factor = 240 -# Get the set of powers of R, until but not including when the powers -# loop back to 1 -def get_power_cycle(r): - o = [1, r] - while o[-1] != 1: - o.append((o[-1] * r) % modulus) - return o[:-1] - -# Extract pseudorandom indices from entropy -def get_indices(seed, modulus, count): - assert modulus < 2**24 - data = seed - while len(data) < 4 * count: - data += blake(data[-32:]) - return [int.from_bytes(data[i: i+4], 'big') % modulus for i in range(0, count * 4, 4)] - -# Generate an FRI proof -def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1): - print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1)) - - # If the degree we are checking for is less than or equal to 32, - # use the polynomial directly as a proof - if maxdeg_plus_1 <= 32: - print('Produced FRI proof') - return [[x.to_bytes(32, 'big') for x in values]] - - # Calculate the set of x coordinates - xs = get_power_cycle(root_of_unity) - - # Put the values into a Merkle tree. This is the root that the - # proof will be checked against - m = merkelize(values) - - # Select a pseudo-random x coordinate - special_x = int.from_bytes(m[1], 'big') % modulus - - # Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html) - # at that x coordinate - # We calculate the column by Lagrange-interpolating the row, and not - # directly, as this is more efficient - column = [] - for i in range(len(xs)//4): - x_poly = lagrange_interp_4( - [values[i+len(values)*j//4] for j in range(4)], - [xs[i+len(xs)*j//4] for j in range(4)], - modulus - ) - column.append(f.eval_poly_at(x_poly, special_x)) - m2 = merkelize(column) - - # Pseudo-randomly select y indices to sample - ys = get_indices(m2[1], len(column), 40) - - # Compute the Merkle branches for the values in the polynomial and the column - branches = [] - for y in ys: - branches.append([mk_branch(m2, y)] + [mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)]) - - # This component of the proof - o = [m2[1], branches] - - # In the next iteration of the proof, we'll work with smaller roots of unity - sub_xs = [xs[i] for i in range(0, len(xs), 4)] - - # Interpolate the polynomial for the column - ypoly = fft(column[:len(sub_xs)], modulus, - pow(root_of_unity, 4, modulus), inv=True) - - # Recurse... - return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4) - -# Verify an FRI proof -def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1): - - # Calculate which root of unity we're working with - testval = root_of_unity - roudeg = 1 - while testval != 1: - roudeg *= 2 - testval = (testval * testval) % modulus - - # Verify the recursive components of the proof - for prf in proof[:-1]: - root2, branches = prf - print('Verifying degree <= %d' % maxdeg_plus_1) - - # Calculate the pseudo-random x coordinate - special_x = int.from_bytes(merkle_root, 'big') % modulus - - # Calculate the pseudo-randomly sampled y indices - ys = get_indices(root2, roudeg // 4, 40) - - - # Verify for each selected y coordinate that the four points from the polynomial - # and the one point from the column that are on that y coordinate are on the same - # deg < 4 polynomial - for i, y in enumerate(ys): - # The x coordinates from the polynomial - x1 = pow(root_of_unity, y, modulus) - xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)] - - # The values from the polynomial - row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])] - - # Verify proof and recover the column value - values = [verify_branch(root2, y, branches[i][0])] + row - - # Lagrange interpolate and check deg is < 4 - p = lagrange_interp_4(row, xcoords, modulus) - assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0]) - - # Update constants to check the next proof - merkle_root = root2 - root_of_unity = pow(root_of_unity, 4, modulus) - maxdeg_plus_1 //= 4 - roudeg //= 4 - - # Verify the direct components of the proof - data = [int.from_bytes(x, 'big') for x in proof[-1]] - print('Verifying degree <= %d' % maxdeg_plus_1) - assert maxdeg_plus_1 <= 32 - - # Check the Merkle root matches up - mtree = merkelize(data) - assert mtree[1] == merkle_root - - # Check the degree of the data - poly = fft(data, modulus, root_of_unity, inv=True) - for i in range(maxdeg_plus_1, len(poly)): - assert poly[i] == 0 - - print('FRI proof verified') - return True - -# Pure FRI tests -poly = list(range(512)) -root_of_unity = pow(7, (modulus-1)//1024, modulus) -evaluations = fft(poly, modulus, root_of_unity) -proof = prove_low_degree(poly, root_of_unity, evaluations, 512) -print("Approx proof length: %d" % bin_length(compress_fri(proof))) -assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512) - # Compute a MIMC permutation for 2**logsteps steps, using round constants # from the multiplicative subgroup of size 2**logprecision def mimc(inp, logsteps, logprecision): @@ -196,15 +52,16 @@ def mk_mimc_proof(inp, logsteps, logprecision): precision = 2**logprecision # Root of unity such that x^precision=1 - root = pow(7, (modulus-1)//precision, modulus) + root_of_unity = pow(7, (modulus-1)//precision, modulus) # Root of unity such that x^skips=1 skips = precision // steps - subroot = pow(root, skips) + subroot = pow(root_of_unity, skips) # Powers of the root of unity, our computational trace will be # along the sequence of roots of unity - xs = get_power_cycle(subroot) + xs = get_power_cycle(subroot, modulus) + last_step_position = xs[steps-1] # Generate the computational trace constants = [] @@ -215,6 +72,7 @@ def mk_mimc_proof(inp, logsteps, logprecision): constants.append(k ^ 1) k = (k * 9) & ((1 << 256) - 1) constants.append(0) + output = values[-1] print('Done generating computational trace') # Interpolate the computational trace into a polynomial @@ -225,37 +83,56 @@ def mk_mimc_proof(inp, logsteps, logprecision): # Create the composed polynomial such that # C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x) term1 = multiply_base(values_polynomial, subroot) - p_evaluations = fft(values_polynomial, modulus, root) - term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2] + p_evaluations = fft(values_polynomial, modulus, root_of_unity) + term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2] c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial) print('Computed C(P, K) polynomial') # Compute D(x) = C(P(x), P(rx), K(x)) / Z(x) # Z(x) = (x^steps - 1) / (x - x_atlast_step) d = divide_by_xnm1(f.mul_polys(c_of_values, - [modulus-xs[steps-1], 1]), + [-last_step_position, 1]), steps) - # assert f.mul_polys(d, z) == c_of_values + # Consistency check + assert (f.eval_poly_at(d, 90833) * + (pow(90833, steps, modulus) - 1) * + f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) - + f.eval_poly_at(c_of_values, 90833)) % modulus == 0 print('Computed D polynomial') - # Evaluate D and K across the entire subgroup - d_evaluations = fft(d, modulus, root) - k_evaluations = fft(constants_polynomial, modulus, root) - print('Evaluated P, D and K') + # Compute interpolant of ((1, input), (x_atlast_step, output)) + interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus) + quotient = f.mul_polys([-1, 1], [-last_step_position, 1]) + b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient) + # Consistency check + assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == f.eval_poly_at(values_polynomial, 7045) + print('Computed B polynomial') + + # Evaluate B, D and K across the entire subgroup + d_evaluations = fft(d, modulus, root_of_unity) + k_evaluations = fft(constants_polynomial, modulus, root_of_unity) + b_evaluations = fft(b, modulus, root_of_unity) + print('Evaluated low-degree extension of B, D and K') # Compute their Merkle roots p_mtree = merkelize(p_evaluations) d_mtree = merkelize(d_evaluations) k_mtree = merkelize(k_evaluations) + b_mtree = merkelize(b_evaluations) print('Computed hash root') - # Based on the hashes of P and D, we select a random linear combination - # of P * x^steps and D, and prove the low-degreeness of that, instead of proving - # the low-degreeness of P and D separately - k = int.from_bytes(blake(p_mtree[1] + d_mtree[1]), 'big') + # Based on the hashes of P, D and B, we select a random linear combination + # of P * x^steps, P, B * x^steps, B and D, and prove the low-degreeness of that, + # instead of proving the low-degreeness of P, B and D separately + k1 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x01'), 'big') + k2 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x02'), 'big') + k3 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x03'), 'big') + k4 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x04'), 'big') - lincomb = f.add_polys(d, f.mul_by_const([0] * steps + values_polynomial, k)) - l_evaluations = fft(lincomb, modulus, root) + lincomb = f.add_polys(f.add_polys(d, + f.mul_by_const(values_polynomial, k1) + f.mul_by_const(values_polynomial, k2)), + f.mul_by_const(b, k3) + [0, 0] + f.mul_by_const(b, k4) + [0,0]) + l_evaluations = fft(lincomb, modulus, root_of_unity) l_mtree = merkelize(l_evaluations) print('Computed random linear combination') @@ -263,30 +140,31 @@ def mk_mimc_proof(inp, logsteps, logprecision): # Do some spot checks of the Merkle tree at pseudo-random coordinates branches = [] samples = spot_check_security_factor // (logprecision - logsteps) - positions = get_indices(l_mtree[1], precision - skips, samples) + positions = get_pseudorandom_indices(l_mtree[1], precision - skips, samples) for pos in positions: branches.append(mk_branch(p_mtree, pos)) branches.append(mk_branch(p_mtree, pos + skips)) branches.append(mk_branch(d_mtree, pos)) branches.append(mk_branch(k_mtree, pos)) + branches.append(mk_branch(b_mtree, pos)) branches.append(mk_branch(l_mtree, pos)) print('Computed %d spot checks' % samples) - # Return the Merkle roots of P and D, the spot check Merkle proofs, # and low-degree proofs of P and D o = [p_mtree[1], d_mtree[1], k_mtree[1], + b_mtree[1], l_mtree[1], branches, - prove_low_degree(lincomb, root, l_evaluations, steps * 2)] + prove_low_degree(lincomb, root_of_unity, l_evaluations, steps * 2, modulus)] print("STARK computed in %.4f sec" % (time.time() - start_time)) return o # Verifies a STARK def verify_mimc_proof(inp, logsteps, logprecision, output, proof): - p_root, d_root, k_root, l_root, branches, fri_proof = proof + p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof start_time = time.time() steps = 2**logsteps @@ -297,43 +175,39 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, proof): skips = precision // steps # Verifies the low-degree proofs - assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2) + assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2, modulus) # Performs the spot checks - k = int.from_bytes(blake(p_root + d_root), 'big') + k1 = int.from_bytes(blake(p_root + d_root + b_root + b'\x01'), 'big') + k2 = int.from_bytes(blake(p_root + d_root + b_root + b'\x02'), 'big') + k3 = int.from_bytes(blake(p_root + d_root + b_root + b'\x03'), 'big') + k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big') samples = spot_check_security_factor // (logprecision - logsteps) - positions = get_indices(l_root, precision - skips, samples) + positions = get_pseudorandom_indices(l_root, precision - skips, samples) + last_step_position = pow(root_of_unity, (steps - 1) * skips, modulus) for i, pos in enumerate(positions): - # Check C(P(x)) = Z(x) * D(x) x = pow(root_of_unity, pos, modulus) - p_of_x = verify_branch(p_root, pos, branches[i*5]) - p_of_rx = verify_branch(p_root, pos+skips, branches[i*5 + 1]) - d_of_x = verify_branch(d_root, pos, branches[i*5 + 2]) - k_of_x = verify_branch(k_root, pos, branches[i*5 + 3]) - l_of_x = verify_branch(l_root, pos, branches[i*5 + 4]) + x_to_the_steps = pow(x, steps, modulus) + p_of_x = verify_branch(p_root, pos, branches[i*6]) + p_of_rx = verify_branch(p_root, pos+skips, branches[i*6 + 1]) + d_of_x = verify_branch(d_root, pos, branches[i*6 + 2]) + k_of_x = verify_branch(k_root, pos, branches[i*6 + 3]) + b_of_x = verify_branch(b_root, pos, branches[i*6 + 4]) + l_of_x = verify_branch(l_root, pos, branches[i*6 + 5]) zvalue = f.div(pow(x, steps, modulus) - 1, - x - pow(root_of_unity, (steps - 1) * skips, modulus)) + x - last_step_position) assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0 - assert (l_of_x - d_of_x - k * p_of_x * pow(x, steps, modulus)) % modulus == 0 + interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus) + quotient = f.mul_polys([-1, 1], [-last_step_position, 1]) + assert (p_of_x - b_of_x * f.eval_poly_at(quotient, x) - + f.eval_poly_at(interpolant, x)) % modulus == 0 + assert (l_of_x - d_of_x - + k1 * p_of_x - k2 * p_of_x * x_to_the_steps - + k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0 print('Verified %d consistency checks' % (spot_check_security_factor // (logprecision - logsteps))) print('Verified STARK in %.4f sec' % (time.time() - start_time)) print('Note: this does not include verifying the Merkle root of the constants tree') print('This can be done by every client once as a precomputation') return True - -INPUT = 3 -LOGSTEPS = 17 -LOGPRECISION = 20 - -# Full STARK test -proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION) -p_root, d_root, k_root, l_root, branches, fri_proof = proof -L1 = bin_length(compress_branches(branches)) -L2 = bin_length(compress_fri(fri_proof)) -print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2)) -root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus) -subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus) -skips = 2**(LOGPRECISION - LOGSTEPS) -assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), proof) diff --git a/mimc_stark/test.py b/mimc_stark/test.py new file mode 100644 index 0000000..c306ff1 --- /dev/null +++ b/mimc_stark/test.py @@ -0,0 +1,36 @@ +from fft import fft +from mimc_stark import mk_mimc_proof, modulus, mimc, verify_mimc_proof +from compression import compress_fri, compress_branches, bin_length +from merkle_tree import merkelize, mk_branch, verify_branch +from fri import prove_low_degree, verify_low_degree_proof + +def test_merkletree(): + t = merkelize(range(128)) + b = mk_branch(t, 59) + assert verify_branch(t[1], 59, b) == 59 + print('Merkle tree works') + +def test_fri(): + # Pure FRI tests + poly = list(range(512)) + root_of_unity = pow(7, (modulus-1)//1024, modulus) + evaluations = fft(poly, modulus, root_of_unity) + proof = prove_low_degree(poly, root_of_unity, evaluations, 512, modulus) + print("Approx proof length: %d" % bin_length(compress_fri(proof))) + assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512, modulus) + +def test_stark(): + INPUT = 3 + LOGSTEPS = 13 + LOGPRECISION = 16 + + # Full STARK test + proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION) + p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof + L1 = bin_length(compress_branches(branches)) + L2 = bin_length(compress_fri(fri_proof)) + print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2)) + root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus) + subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus) + skips = 2**(LOGPRECISION - LOGSTEPS) + assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), proof) diff --git a/mimc_stark/utils.py b/mimc_stark/utils.py new file mode 100644 index 0000000..59b2b0d --- /dev/null +++ b/mimc_stark/utils.py @@ -0,0 +1,17 @@ +from merkle_tree import blake + +# Get the set of powers of R, until but not including when the powers +# loop back to 1 +def get_power_cycle(r, modulus): + o = [1, r] + while o[-1] != 1: + o.append((o[-1] * r) % modulus) + return o[:-1] + +# Extract pseudorandom indices from entropy +def get_pseudorandom_indices(seed, modulus, count): + assert modulus < 2**24 + data = seed + while len(data) < 4 * count: + data += blake(data[-32:]) + return [int.from_bytes(data[i: i+4], 'big') % modulus for i in range(0, count * 4, 4)]