From f9df7eddb543e88fd3efb2442ed8b68f4e2b759b Mon Sep 17 00:00:00 2001 From: Vitalik Buterin Date: Sat, 30 Jun 2018 00:27:02 -0400 Subject: [PATCH] Added FFT-based efficiency improvements --- mimc_stark/fft.py | 40 +++++++++++++++++++++ mimc_stark/mimc_stark.py | 75 +++++++++++++++++++++++++--------------- 2 files changed, 87 insertions(+), 28 deletions(-) create mode 100644 mimc_stark/fft.py diff --git a/mimc_stark/fft.py b/mimc_stark/fft.py new file mode 100644 index 0000000..2a12473 --- /dev/null +++ b/mimc_stark/fft.py @@ -0,0 +1,40 @@ +def _fft(vals, modulus, roots_of_unity): + if len(vals) == 1: + return vals + L = _fft(vals[::2], modulus, roots_of_unity[::2]) + R = _fft(vals[1::2], modulus, roots_of_unity[::2]) + o = [0 for i in vals] + for i, (x, y) in enumerate(zip(L, R)): + y_times_root = y*roots_of_unity[i] + o[i] = (x+y_times_root) % modulus + o[i+len(L)] = (x-y_times_root) % modulus + # print(vals, root_of_unity, o) + return o + +def fft(vals, modulus, root_of_unity, inv=False): + # Build up roots of unity + rootz = [1, root_of_unity] + while rootz[-1] != 1: + rootz.append((rootz[-1] * root_of_unity) % modulus) + # Fill in vals with zeroes if needed + if len(rootz) > len(vals) + 1: + vals = vals + [0] * (len(rootz) - len(vals) - 1) + if inv: + # Inverse FFT + invlen = pow(len(vals), modulus-2, modulus) + return [(x*invlen) % modulus for x in _fft(vals, modulus, rootz[::-1])] + else: + # Regular FFT + return _fft(vals, modulus, rootz) + +def mul_polys(a, b, modulus, root_of_unity): + x1 = fft(a, modulus, root_of_unity) + x2 = fft(b, modulus, root_of_unity) + return fft([(v1*v2)%modulus for v1,v2 in zip(x1,x2)], + modulus, root_of_unity, inv=True) + +def div_polys(a, b, modulus, root_of_unity): + x1 = fft(a, modulus, root_of_unity) + x2 = fft(b, modulus, root_of_unity) + return fft([(v1*pow(v2,modulus-2,modulus))%modulus for v1,v2 in zip(x1,x2)], + modulus, root_of_unity, inv=True) diff --git a/mimc_stark/mimc_stark.py b/mimc_stark/mimc_stark.py index 88fd9c4..3021fce 100644 --- a/mimc_stark/mimc_stark.py +++ b/mimc_stark/mimc_stark.py @@ -1,6 +1,7 @@ from merkle_tree import merkelize, mk_branch, verify_branch, blake from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length from ecpoly import PrimeField +from fft import fft, mul_polys, div_polys import time modulus = 2**256 - 2**32 * 351 + 1 @@ -63,7 +64,15 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1): # Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html) # at that x coordinate - column = [eval_as_bivariate(poly, special_x, xs[i]) for i in range(0, len(xs), 4)] + # We calculate the column by Lagrange-interpolating the row, and not + # directly, as this is more efficient + column = [] + for i in range(len(xs)//4): + x_poly = f.lagrange_interp( + [values[i+len(values)*j//4] for j in range(4)], + [xs[i+len(xs)*j//4] for j in range(4)] + ) + column.append(f.eval_poly_at(x_poly, special_x)) m2 = merkelize(column) # Pseudo-randomly select y indices to sample @@ -85,7 +94,8 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1): sub_xs = [xs[i] for i in range(0, len(xs), 4)] # Interpolate the polynomial for the column - ypoly = f.lagrange_interp(column[:len(sub_xs)], sub_xs) + ypoly = fft(column[:len(sub_xs)], modulus, + pow(root_of_unity, 4, modulus), inv=True) # Recurse... return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4) @@ -145,11 +155,10 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1): mtree = merkelize(data) assert mtree[1] == merkle_root - # Check its degree - xs = get_power_cycle(root_of_unity) - poly = f.lagrange_interp(data[:maxdeg_plus_1], xs[:maxdeg_plus_1]) - for x, datum in zip(xs[maxdeg_plus_1:], data[maxdeg_plus_1:]): - assert f.eval_poly_at(poly, x) == datum + # Check the degree of the data + poly = fft(data, modulus, root_of_unity, inv=True) + for i in range(maxdeg_plus_1, len(poly)): + assert poly[i] == 0 print('FRI proof verified') return True @@ -157,7 +166,7 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1): # Pure FRI tests poly = list(range(512)) root_of_unity = pow(7, (modulus-1)//1024, modulus) -evaluations = [f.eval_poly_at(poly, pow(root_of_unity, i, modulus)) for i in range(1024)] +evaluations = fft(poly, modulus, root_of_unity) proof = prove_low_degree(poly, root_of_unity, evaluations, 512) print("Approx proof length: %d" % bin_length(compress_fri(proof))) assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512) @@ -189,14 +198,17 @@ def mk_mimc_proof(inp, logsteps, logprecision): # along the sequence of roots of unity xs = get_power_cycle(root) + skips = precision // steps + subroot = pow(root, skips) # Generate the computational trace values = [inp] for i in range(steps-1): - values.append((values[-1]**3 + xs[i]) % modulus) + values.append((values[-1]**3 + xs[i*skips]) % modulus) print('Done generating computational trace') # Interpolate the computational trace into a polynomial - values_polynomial = f.lagrange_interp(values, xs[:len(values)]) + # values_polynomial = f.lagrange_interp(values, [pow(subroot, i, modulus) for i in range(steps)]) + values_polynomial = fft(values, modulus, subroot, inv=True) print('Computed polynomial') #for x, v in zip(xs, values): @@ -204,27 +216,29 @@ def mk_mimc_proof(inp, logsteps, logprecision): # Create the composed polynomial such that # C(P(x), P(rx)) = P(rx) - P(x)**3 - x - term1 = f.compose_polys(values_polynomial, [0, root]) - term2 = f.mul_polys(f.mul_polys(values_polynomial, values_polynomial), - values_polynomial) + term1 = f.compose_polys(values_polynomial, [0, subroot]) + term2 = fft([pow(x, 3, modulus) for x in fft(values_polynomial, modulus, root)], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2] c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1]) + print('Computed C(P) polynomial') #for i in range(steps-1): - # assert f.eval_poly_at(c_of_values, xs[i]) == 0 + # assert f.eval_poly_at(c_of_values, xs[i*skips]) == 0 #print('C(P(x)) check passed') # Compute the Z(x) polynomial that is 0 along the trace - z = f.zpoly(xs[:steps-1]) + z = fft([0] * (steps-1) + [1], + modulus, subroot, inv=True) + # z2 = f.zpoly(xs[:skips*(steps-1):skips]) print('Computed Z polynomial') # Compute D(x) = C(P(x)) / Z(x) d = f.div_polys(c_of_values, z) - assert f.mul_polys(d, z) == c_of_values - print('Computed and checked D polynomial') + # assert f.mul_polys(d, z) == c_of_values + print('Computed D polynomial') # Evaluate P and D across the entire subgroup - p_evaluations = [f.eval_poly_at(values_polynomial, x) for x in xs] - d_evaluations = [f.eval_poly_at(d, x) for x in xs] + p_evaluations = fft(values_polynomial, modulus, root) + d_evaluations = fft(d, modulus, root) print('Evaluated P and D') # Compute their Merkle roots @@ -235,10 +249,10 @@ def mk_mimc_proof(inp, logsteps, logprecision): # Do some spot checks of the Merkle tree at pseudo-random coordinates branches = [] samples = spot_check_security_factor // (logprecision - logsteps) - positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - 1, samples) + positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - skips, samples) for pos in positions: branches.append(mk_branch(p_mtree, pos)) - branches.append(mk_branch(p_mtree, pos + 1)) + branches.append(mk_branch(p_mtree, pos + skips)) branches.append(mk_branch(d_mtree, pos)) print('Computed %d spot checks' % samples) @@ -265,6 +279,7 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof): # Get (steps)th root of unity root_of_unity = pow(7, (modulus-1)//precision, modulus) + skips = precision // steps # Verifies the low-degree proofs assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps) @@ -272,13 +287,13 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof): # Performs the spot checks samples = spot_check_security_factor // (logprecision - logsteps) - positions = get_indices(blake(p_root + d_root), len(xs) - 1, samples) + positions = get_indices(blake(p_root + d_root), len(xs) - skips, samples) for i, pos in enumerate(positions): # Check C(P(x)) = Z(x) * D(x) x = pow(root_of_unity, pos, modulus) p_of_x = verify_branch(p_root, pos, branches[i*3]) - p_of_rx = verify_branch(p_root, pos+1, branches[i*3 + 1]) + p_of_rx = verify_branch(p_root, pos+skips, branches[i*3 + 1]) d_of_x = verify_branch(d_root, pos, branches[i*3 + 2]) assert (p_of_rx - p_of_x ** 3 - x - zvalues[pos] * d_of_x) % modulus == 0 @@ -287,8 +302,8 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof): return True INPUT = 3 -LOGSTEPS = 8 -LOGPRECISION = 11 +LOGSTEPS = 15 +LOGPRECISION = 18 # Full STARK test proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION) @@ -296,7 +311,11 @@ L1 = bin_length(compress_branches(proof[2])) L2 = bin_length(compress_fri(proof[3])) L3 = bin_length(compress_fri(proof[4])) print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3)) -xs = get_power_cycle(pow(7, (modulus-1)//2**LOGPRECISION, modulus)) -zpoly = f.zpoly(xs[:2**LOGSTEPS-1]) -zpoly_vals = [f.eval_poly_at(zpoly, x) for x in xs] +root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus) +subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus) +xs = get_power_cycle(root_of_unity) +skips = 2**(LOGPRECISION - LOGSTEPS) +zpoly = fft([0] * (2**LOGSTEPS-1) + [1], + modulus, subroot, inv=True) +zpoly_vals = fft(zpoly, modulus, root_of_unity) assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), zpoly_vals, proof)