diff --git a/mimc_stark/README.md b/mimc_stark/README.md index 7f5500b..f742403 100644 --- a/mimc_stark/README.md +++ b/mimc_stark/README.md @@ -1,17 +1,21 @@ +#### Disclaimer + +DO NOT USE FOR ANYTHING IN REAL LIFE. DO NOT ASSUME THE PROTOCOL DESCRIBED HERE IS SOUND. TALK TO A SPECIALIST IF YOU'RE LOOKING TO USE STARKS IN YOUR APPLICATION. + #### What is this? This is a very basic implementation of a STARK on a MIMC computation that is probably (ie. definitely) broken in a few places but is intended as a proof of concept to show the rough level of complexity that is involved in implementing a simple STARK. A STARK is a really cool proof-of-computation scheme that allows you to create an efficiently verifiable proof that some computation was executed correctly; the verification time only rises logarithmically with the computation time, and that relies only on hashes and information theory for security. -The STARKs are done over a finite field chosen to have 2^32'th roots of unity (to facilitate STARKs), and NOT have 3rd roots of unity (to facilitate MIMC). The MIMC permutation in general takes the form: +The STARKs are done over a finite field chosen to have `2**32`'th roots of unity (to facilitate STARKs), and NOT have 3rd roots of unity (to facilitate MIMC). The MIMC permutation in general takes the form: - k_1 k_2 - | | - v v - x0 ---> (x->x^3) ---> xor ---> (x->x^3) ---> xor --- ... ---> output + k_1 k_2 + | | + v v + x0 ---> (x->x**3) ---> xor ---> (x->x**3) ---> xor --- ... ---> output Where the `k_i` values are round constants. MIMC can be used as a building block in a hash function, or as a verifiable delay function; its simple arithmetic representation makes it ideal for use in STARKs, SNARKs, MPC and other "cryptography over general-purpose computation" schemes. -The MIMC round constants used here are 2^k'th roots of unity for k <= 32. The computational trace is computed over successive powers of a 2^k'th root of unity. This allows the constraint checking polynomial to be `C(P(x), P(r*x), x) = P(r*x) - P(x)**3 - x`. +The MIMC round constants used here are successive powers of 9 mod `2**256` xored with 1, though could be set to anything. The computational trace is computed over successive powers of a `2**k`'th root of unity. This allows the constraint checking polynomial to be `C(P(x), P(r*x), K(x)) = P(r*x) - P(x)**3 - K(x)`. For a description of how STARKs work, see: @@ -22,4 +26,18 @@ For more discussion on MIMC, see: * [Zcash issue 2233](https://github.com/zcash/zcash/issues/2233) -The implementation is definitely suboptimal. For example, one optimization would be to reduce the two FRIs to one by low-degree proving a random linear combination k1 * D + k2 * P^2; this could cut the verification time by close to half. Another would be to remove redundancies in computing zero polynomials in multiple places; a more sophisticated optimization would be to move to subquadratic algorithms (Karatsuba or ideally FFT-based) for polynomial interpolation. The purpose is for the code to be deliberately simple to aid with understanding. +### How does the STARK scheme here work? + +Here are the approximate steps in the code. Note that all arithmetic is done over the finite field of integers mod `2**256 - 2**32 * 351 + 1`. + +1. Let `P[0]` equal the input, and `P[i+1] = P[i]**3 + K[i]`, up to `steps` (where `steps` must be a power of 2) +2. Construct a polynomial where `P(subroot ^ i) = P[i]` up to steps-1, where `subroot` is a steps'th root of unity (that is, `subroot**steps = 1`). Do the same with K. +3. Construct the polynomial `CP(x) = C(P(x), P(x * subroot), K(x))`. Note that since `P(x * subroot) = P(x) ^ 3 + K(x), CP(x) = 0` for any x inside the computation trace (that is, powers of subroot except the last). +4. Construct the polynomial `Z(x)`, which is the minimal polynomial that is 0 across the computation trace. Note that because Z is minimal, CP must be a multiple of Z. +5. Construct `D = CP / Z`. Because CP is a multiple of Z, D must itself be a low-degree polynomial. +6. Put D and P into Merkle trees. +7. Construct `L = D + k * x**steps * P`, where `k` is selected based on the Merkle roots of D and P. +8. Create an FRI proof that L is low-degree ((7) and (8) together are a clever way of making an aggregate low-degree proof of D and P) +9. Create probabilistic checks, using the Merkle root of L as source data, to spot-check that D(x) * Z(x) actually equal C(P(x)). + +The probabilistic checks and the FRI proof are themselves restricted to `2**precision`'th roots of unity, where `precision = 8 * steps` (so we're FRI checking that the degree of L, which equals twice the degree of P, is at most 1/4 the theoretical maximum). diff --git a/mimc_stark/better_lagrange.py b/mimc_stark/better_lagrange.py new file mode 100644 index 0000000..2d337ef --- /dev/null +++ b/mimc_stark/better_lagrange.py @@ -0,0 +1,37 @@ +def inv(a, modulus): + if a == 0: + return 0 + lm, hm = 1, 0 + low, high = a % modulus, modulus + while low > 1: + r = high//low + nm, new = hm-lm*r, high-low*r + lm, low, hm, high = nm, new, lm, low + return lm % modulus + +def eval_poly_at(poly, x, modulus): + o, p = 0, 1 + for coeff in poly: + o += coeff * p + p *= x + return o % modulus + +def lagrange_interp_4(pieces, xs, modulus): + x01, x02, x03, x12, x13, x23 = \ + xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3] + eq0 = [-x12 * xs[3] % modulus, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1] + eq1 = [-x02 * xs[3] % modulus, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1] + eq2 = [-x01 * xs[3] % modulus, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1] + eq3 = [-x01 * xs[2] % modulus, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1] + e0 = eval_poly_at(eq0, xs[0], modulus) + e1 = eval_poly_at(eq1, xs[1], modulus) + e2 = eval_poly_at(eq2, xs[2], modulus) + e3 = eval_poly_at(eq3, xs[3], modulus) + e01 = e0 * e1 + e23 = e2 * e3 + invall = inv(e01 * e23, modulus) + inv_y0 = pieces[0] * invall * e1 * e23 % modulus + inv_y1 = pieces[1] * invall * e0 * e23 % modulus + inv_y2 = pieces[2] * invall * e01 * e3 % modulus + inv_y3 = pieces[3] * invall * e01 * e2 % modulus + return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % modulus for i in range(4)] diff --git a/mimc_stark/fft.py b/mimc_stark/fft.py index 2b7f499..e2500c4 100644 --- a/mimc_stark/fft.py +++ b/mimc_stark/fft.py @@ -1,6 +1,15 @@ +def _simple_ft(vals, modulus, roots_of_unity): + L = len(roots_of_unity) + o = [0 for _ in range(L)] + for i in range(L): + for j in range(L): + o[i] += vals[j] * roots_of_unity[(i*j)%L] + return [x % modulus for x in o] + def _fft(vals, modulus, roots_of_unity): if len(vals) == 1: return vals + # return _simple_ft(vals, modulus, roots_of_unity) L = _fft(vals[::2], modulus, roots_of_unity[::2]) R = _fft(vals[1::2], modulus, roots_of_unity[::2]) o = [0 for i in vals] @@ -8,7 +17,6 @@ def _fft(vals, modulus, roots_of_unity): y_times_root = y*roots_of_unity[i] o[i] = (x+y_times_root) % modulus o[i+len(L)] = (x-y_times_root) % modulus - # print(vals, root_of_unity, o) return o def fft(vals, modulus, root_of_unity, inv=False): @@ -22,10 +30,11 @@ def fft(vals, modulus, root_of_unity, inv=False): if inv: # Inverse FFT invlen = pow(len(vals), modulus-2, modulus) - return [(x*invlen) % modulus for x in _fft(vals, modulus, rootz[::-1])] + return [(x*invlen) % modulus for x in + _fft(vals, modulus, rootz[:0:-1])] else: # Regular FFT - return _fft(vals, modulus, rootz) + return _fft(vals, modulus, rootz[:-1]) def mul_polys(a, b, modulus, root_of_unity): x1 = fft(a, modulus, root_of_unity) diff --git a/mimc_stark/mimc_stark.py b/mimc_stark/mimc_stark.py index 7f665d6..54ba82f 100644 --- a/mimc_stark/mimc_stark.py +++ b/mimc_stark/mimc_stark.py @@ -1,8 +1,9 @@ from merkle_tree import merkelize, mk_branch, verify_branch, blake from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length from ecpoly import PrimeField -from fft import fft, mul_polys +from better_lagrange import lagrange_interp_4 import time +from fft import fft modulus = 2**256 - 2**32 * 351 + 1 f = PrimeField(modulus) @@ -14,18 +15,6 @@ quartic_roots_of_unity = [1, spot_check_security_factor = 240 -# Treat a polynomial as a bivariate polynomial g(x, y) and -# evaluate it as such. Invariant: eval_as_bivariate(p, x, x**4) = eval(p, x) -def eval_as_bivariate(p, x, y): - o = 0 - ypow = 1 - xpows = [pow(x, i, modulus) for i in range(4)] - for i in range(0, len(p), 4): - for j in range(4): - o += xpows[j] * ypow * p[i+j] - ypow = (ypow * y) % modulus - return o % modulus - # Get the set of powers of R, until but not including when the powers # loop back to 1 def get_power_cycle(r): @@ -68,9 +57,10 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1): # directly, as this is more efficient column = [] for i in range(len(xs)//4): - x_poly = f.lagrange_interp( + x_poly = lagrange_interp_4( [values[i+len(values)*j//4] for j in range(4)], - [xs[i+len(xs)*j//4] for j in range(4)] + [xs[i+len(xs)*j//4] for j in range(4)], + modulus ) column.append(f.eval_poly_at(x_poly, special_x)) m2 = merkelize(column) @@ -119,12 +109,12 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1): # Verify for each selected y coordinate that the four points from the polynomial - # and the one point from the column that are on that y coordinate are on a + # and the one point from the column that are on that y coordinate are on the same # deg < 4 polynomial for i, y in enumerate(ys): - # The five x coordinates we are checking + # The x coordinates from the polynomial x1 = pow(root_of_unity, y, modulus) - eckses = [special_x] + [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)] + xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)] # The values from the polynomial row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])] @@ -133,8 +123,8 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1): values = [verify_branch(root2, y, branches[i][0])] + row # Lagrange interpolate and check deg is < 4 - p = f.lagrange_interp(values, eckses) - assert p[4] == 0 + p = lagrange_interp_4(row, xcoords, modulus) + assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0]) # Update constants to check the next proof merkle_root = root2 @@ -175,9 +165,11 @@ def mimc(inp, logsteps, logprecision): precision = 2**logprecision # Get (steps)th root of unity subroot = pow(7, (modulus-1)//steps, modulus) - xs = get_power_cycle(subroot) + # We use powers of 9 mod 2^256 as the ith round constant for the moment + k = 1 for i in range(steps-1): - inp = (inp**3 + xs[i]) % modulus + inp = (inp**3 + (k ^ 1)) % modulus + k = (k * 9) & ((1 << 256) - 1) print("MIMC computed in %.4f sec" % (time.time() - start_time)) return inp @@ -215,27 +207,30 @@ def mk_mimc_proof(inp, logsteps, logprecision): xs = get_power_cycle(subroot) # Generate the computational trace + constants = [] values = [inp] + k = 1 for i in range(steps-1): - values.append((values[-1]**3 + xs[i]) % modulus) + values.append((values[-1]**3 + (k ^ 1)) % modulus) + constants.append(k ^ 1) + k = (k * 9) & ((1 << 256) - 1) + constants.append(0) print('Done generating computational trace') # Interpolate the computational trace into a polynomial - # values_polynomial = f.lagrange_interp(values, [pow(subroot, i, modulus) for i in range(steps)]) values_polynomial = fft(values, modulus, subroot, inv=True) - print('Computed polynomial') - - #for x, v in zip(xs, values): - # assert f.eval_poly_at(values_polynomial, x) == v + constants_polynomial = fft(constants, modulus, subroot, inv=True) + print('Converted computational steps and constants into a polynomial') # Create the composed polynomial such that - # C(P(x), P(rx)) = P(rx) - P(x)**3 - x + # C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x) term1 = multiply_base(values_polynomial, subroot) - term2 = fft([pow(x, 3, modulus) for x in fft(values_polynomial, modulus, root)], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2] - c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1]) - print('Computed C(P) polynomial') + p_evaluations = fft(values_polynomial, modulus, root) + term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2] + c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial) + print('Computed C(P, K) polynomial') - # Compute D(x) = C(P(x)) / Z(x) + # Compute D(x) = C(P(x), P(rx), K(x)) / Z(x) # Z(x) = (x^steps - 1) / (x - x_atlast_step) d = divide_by_xnm1(f.mul_polys(c_of_values, [modulus-xs[steps-1], 1]), @@ -243,42 +238,55 @@ def mk_mimc_proof(inp, logsteps, logprecision): # assert f.mul_polys(d, z) == c_of_values print('Computed D polynomial') - # Evaluate P and D across the entire subgroup - p_evaluations = fft(values_polynomial, modulus, root) + # Evaluate D and K across the entire subgroup d_evaluations = fft(d, modulus, root) - print('Evaluated P and D') + k_evaluations = fft(constants_polynomial, modulus, root) + print('Evaluated P, D and K') # Compute their Merkle roots p_mtree = merkelize(p_evaluations) d_mtree = merkelize(d_evaluations) + k_mtree = merkelize(k_evaluations) print('Computed hash root') + # Based on the hashes of P and D, we select a random linear combination + # of P * x^steps and D, and prove the low-degreeness of that, instead of proving + # the low-degreeness of P and D separately + k = int.from_bytes(blake(p_mtree[1] + d_mtree[1]), 'big') + + lincomb = f.add_polys(d, f.mul_by_const([0] * steps + values_polynomial, k)) + l_evaluations = fft(lincomb, modulus, root) + l_mtree = merkelize(l_evaluations) + + print('Computed random linear combination') + # Do some spot checks of the Merkle tree at pseudo-random coordinates branches = [] samples = spot_check_security_factor // (logprecision - logsteps) - positions = get_indices(blake(p_mtree[1] + d_mtree[1]), precision - skips, samples) + positions = get_indices(l_mtree[1], precision - skips, samples) for pos in positions: branches.append(mk_branch(p_mtree, pos)) branches.append(mk_branch(p_mtree, pos + skips)) branches.append(mk_branch(d_mtree, pos)) + branches.append(mk_branch(k_mtree, pos)) + branches.append(mk_branch(l_mtree, pos)) print('Computed %d spot checks' % samples) - while len(d) < steps * 2: - d += [0] # Return the Merkle roots of P and D, the spot check Merkle proofs, # and low-degree proofs of P and D o = [p_mtree[1], d_mtree[1], + k_mtree[1], + l_mtree[1], branches, - prove_low_degree(values_polynomial, root, p_evaluations, steps), - prove_low_degree(d, root, d_evaluations, steps * 2)] + prove_low_degree(lincomb, root, l_evaluations, steps * 2)] print("STARK computed in %.4f sec" % (time.time() - start_time)) return o # Verifies a STARK def verify_mimc_proof(inp, logsteps, logprecision, output, proof): - p_root, d_root, branches, p_proof, d_proof = proof + p_root, d_root, k_root, l_root, branches, fri_proof = proof start_time = time.time() steps = 2**logsteps @@ -289,37 +297,42 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, proof): skips = precision // steps # Verifies the low-degree proofs - assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps) - assert verify_low_degree_proof(d_root, root_of_unity, d_proof, steps * 2) + assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2) # Performs the spot checks + k = int.from_bytes(blake(p_root + d_root), 'big') samples = spot_check_security_factor // (logprecision - logsteps) - positions = get_indices(blake(p_root + d_root), precision - skips, samples) + positions = get_indices(l_root, precision - skips, samples) for i, pos in enumerate(positions): # Check C(P(x)) = Z(x) * D(x) x = pow(root_of_unity, pos, modulus) - p_of_x = verify_branch(p_root, pos, branches[i*3]) - p_of_rx = verify_branch(p_root, pos+skips, branches[i*3 + 1]) - d_of_x = verify_branch(d_root, pos, branches[i*3 + 2]) + p_of_x = verify_branch(p_root, pos, branches[i*5]) + p_of_rx = verify_branch(p_root, pos+skips, branches[i*5 + 1]) + d_of_x = verify_branch(d_root, pos, branches[i*5 + 2]) + k_of_x = verify_branch(k_root, pos, branches[i*5 + 3]) + l_of_x = verify_branch(l_root, pos, branches[i*5 + 4]) zvalue = f.div(pow(x, steps, modulus) - 1, x - pow(root_of_unity, (steps - 1) * skips, modulus)) - assert (p_of_rx - p_of_x ** 3 - x - zvalue * d_of_x) % modulus == 0 + assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0 + assert (l_of_x - d_of_x - k * p_of_x * pow(x, steps, modulus)) % modulus == 0 print('Verified %d consistency checks' % (spot_check_security_factor // (logprecision - logsteps))) print('Verified STARK in %.4f sec' % (time.time() - start_time)) + print('Note: this does not include verifying the Merkle root of the constants tree') + print('This can be done by every client once as a precomputation') return True INPUT = 3 -LOGSTEPS = 13 -LOGPRECISION = 16 +LOGSTEPS = 17 +LOGPRECISION = 20 # Full STARK test proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION) -L1 = bin_length(compress_branches(proof[2])) -L2 = bin_length(compress_fri(proof[3])) -L3 = bin_length(compress_fri(proof[4])) -print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3)) +p_root, d_root, k_root, l_root, branches, fri_proof = proof +L1 = bin_length(compress_branches(branches)) +L2 = bin_length(compress_fri(fri_proof)) +print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2)) root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus) subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus) skips = 2**(LOGPRECISION - LOGSTEPS)