Added FFT-based efficiency improvements

This commit is contained in:
Vitalik Buterin 2018-06-30 00:27:02 -04:00
parent 67b8079689
commit f9df7eddb5
2 changed files with 87 additions and 28 deletions

40
mimc_stark/fft.py Normal file
View File

@ -0,0 +1,40 @@
def _fft(vals, modulus, roots_of_unity):
if len(vals) == 1:
return vals
L = _fft(vals[::2], modulus, roots_of_unity[::2])
R = _fft(vals[1::2], modulus, roots_of_unity[::2])
o = [0 for i in vals]
for i, (x, y) in enumerate(zip(L, R)):
y_times_root = y*roots_of_unity[i]
o[i] = (x+y_times_root) % modulus
o[i+len(L)] = (x-y_times_root) % modulus
# print(vals, root_of_unity, o)
return o
def fft(vals, modulus, root_of_unity, inv=False):
# Build up roots of unity
rootz = [1, root_of_unity]
while rootz[-1] != 1:
rootz.append((rootz[-1] * root_of_unity) % modulus)
# Fill in vals with zeroes if needed
if len(rootz) > len(vals) + 1:
vals = vals + [0] * (len(rootz) - len(vals) - 1)
if inv:
# Inverse FFT
invlen = pow(len(vals), modulus-2, modulus)
return [(x*invlen) % modulus for x in _fft(vals, modulus, rootz[::-1])]
else:
# Regular FFT
return _fft(vals, modulus, rootz)
def mul_polys(a, b, modulus, root_of_unity):
x1 = fft(a, modulus, root_of_unity)
x2 = fft(b, modulus, root_of_unity)
return fft([(v1*v2)%modulus for v1,v2 in zip(x1,x2)],
modulus, root_of_unity, inv=True)
def div_polys(a, b, modulus, root_of_unity):
x1 = fft(a, modulus, root_of_unity)
x2 = fft(b, modulus, root_of_unity)
return fft([(v1*pow(v2,modulus-2,modulus))%modulus for v1,v2 in zip(x1,x2)],
modulus, root_of_unity, inv=True)

View File

@ -1,6 +1,7 @@
from merkle_tree import merkelize, mk_branch, verify_branch, blake
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
from ecpoly import PrimeField
from fft import fft, mul_polys, div_polys
import time
modulus = 2**256 - 2**32 * 351 + 1
@ -63,7 +64,15 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
# Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
# at that x coordinate
column = [eval_as_bivariate(poly, special_x, xs[i]) for i in range(0, len(xs), 4)]
# We calculate the column by Lagrange-interpolating the row, and not
# directly, as this is more efficient
column = []
for i in range(len(xs)//4):
x_poly = f.lagrange_interp(
[values[i+len(values)*j//4] for j in range(4)],
[xs[i+len(xs)*j//4] for j in range(4)]
)
column.append(f.eval_poly_at(x_poly, special_x))
m2 = merkelize(column)
# Pseudo-randomly select y indices to sample
@ -85,7 +94,8 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
sub_xs = [xs[i] for i in range(0, len(xs), 4)]
# Interpolate the polynomial for the column
ypoly = f.lagrange_interp(column[:len(sub_xs)], sub_xs)
ypoly = fft(column[:len(sub_xs)], modulus,
pow(root_of_unity, 4, modulus), inv=True)
# Recurse...
return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4)
@ -145,11 +155,10 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
mtree = merkelize(data)
assert mtree[1] == merkle_root
# Check its degree
xs = get_power_cycle(root_of_unity)
poly = f.lagrange_interp(data[:maxdeg_plus_1], xs[:maxdeg_plus_1])
for x, datum in zip(xs[maxdeg_plus_1:], data[maxdeg_plus_1:]):
assert f.eval_poly_at(poly, x) == datum
# Check the degree of the data
poly = fft(data, modulus, root_of_unity, inv=True)
for i in range(maxdeg_plus_1, len(poly)):
assert poly[i] == 0
print('FRI proof verified')
return True
@ -157,7 +166,7 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
# Pure FRI tests
poly = list(range(512))
root_of_unity = pow(7, (modulus-1)//1024, modulus)
evaluations = [f.eval_poly_at(poly, pow(root_of_unity, i, modulus)) for i in range(1024)]
evaluations = fft(poly, modulus, root_of_unity)
proof = prove_low_degree(poly, root_of_unity, evaluations, 512)
print("Approx proof length: %d" % bin_length(compress_fri(proof)))
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512)
@ -189,14 +198,17 @@ def mk_mimc_proof(inp, logsteps, logprecision):
# along the sequence of roots of unity
xs = get_power_cycle(root)
skips = precision // steps
subroot = pow(root, skips)
# Generate the computational trace
values = [inp]
for i in range(steps-1):
values.append((values[-1]**3 + xs[i]) % modulus)
values.append((values[-1]**3 + xs[i*skips]) % modulus)
print('Done generating computational trace')
# Interpolate the computational trace into a polynomial
values_polynomial = f.lagrange_interp(values, xs[:len(values)])
# values_polynomial = f.lagrange_interp(values, [pow(subroot, i, modulus) for i in range(steps)])
values_polynomial = fft(values, modulus, subroot, inv=True)
print('Computed polynomial')
#for x, v in zip(xs, values):
@ -204,27 +216,29 @@ def mk_mimc_proof(inp, logsteps, logprecision):
# Create the composed polynomial such that
# C(P(x), P(rx)) = P(rx) - P(x)**3 - x
term1 = f.compose_polys(values_polynomial, [0, root])
term2 = f.mul_polys(f.mul_polys(values_polynomial, values_polynomial),
values_polynomial)
term1 = f.compose_polys(values_polynomial, [0, subroot])
term2 = fft([pow(x, 3, modulus) for x in fft(values_polynomial, modulus, root)], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2]
c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1])
print('Computed C(P) polynomial')
#for i in range(steps-1):
# assert f.eval_poly_at(c_of_values, xs[i]) == 0
# assert f.eval_poly_at(c_of_values, xs[i*skips]) == 0
#print('C(P(x)) check passed')
# Compute the Z(x) polynomial that is 0 along the trace
z = f.zpoly(xs[:steps-1])
z = fft([0] * (steps-1) + [1],
modulus, subroot, inv=True)
# z2 = f.zpoly(xs[:skips*(steps-1):skips])
print('Computed Z polynomial')
# Compute D(x) = C(P(x)) / Z(x)
d = f.div_polys(c_of_values, z)
assert f.mul_polys(d, z) == c_of_values
print('Computed and checked D polynomial')
# assert f.mul_polys(d, z) == c_of_values
print('Computed D polynomial')
# Evaluate P and D across the entire subgroup
p_evaluations = [f.eval_poly_at(values_polynomial, x) for x in xs]
d_evaluations = [f.eval_poly_at(d, x) for x in xs]
p_evaluations = fft(values_polynomial, modulus, root)
d_evaluations = fft(d, modulus, root)
print('Evaluated P and D')
# Compute their Merkle roots
@ -235,10 +249,10 @@ def mk_mimc_proof(inp, logsteps, logprecision):
# Do some spot checks of the Merkle tree at pseudo-random coordinates
branches = []
samples = spot_check_security_factor // (logprecision - logsteps)
positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - 1, samples)
positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - skips, samples)
for pos in positions:
branches.append(mk_branch(p_mtree, pos))
branches.append(mk_branch(p_mtree, pos + 1))
branches.append(mk_branch(p_mtree, pos + skips))
branches.append(mk_branch(d_mtree, pos))
print('Computed %d spot checks' % samples)
@ -265,6 +279,7 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
# Get (steps)th root of unity
root_of_unity = pow(7, (modulus-1)//precision, modulus)
skips = precision // steps
# Verifies the low-degree proofs
assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps)
@ -272,13 +287,13 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
# Performs the spot checks
samples = spot_check_security_factor // (logprecision - logsteps)
positions = get_indices(blake(p_root + d_root), len(xs) - 1, samples)
positions = get_indices(blake(p_root + d_root), len(xs) - skips, samples)
for i, pos in enumerate(positions):
# Check C(P(x)) = Z(x) * D(x)
x = pow(root_of_unity, pos, modulus)
p_of_x = verify_branch(p_root, pos, branches[i*3])
p_of_rx = verify_branch(p_root, pos+1, branches[i*3 + 1])
p_of_rx = verify_branch(p_root, pos+skips, branches[i*3 + 1])
d_of_x = verify_branch(d_root, pos, branches[i*3 + 2])
assert (p_of_rx - p_of_x ** 3 - x - zvalues[pos] * d_of_x) % modulus == 0
@ -287,8 +302,8 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
return True
INPUT = 3
LOGSTEPS = 8
LOGPRECISION = 11
LOGSTEPS = 15
LOGPRECISION = 18
# Full STARK test
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
@ -296,7 +311,11 @@ L1 = bin_length(compress_branches(proof[2]))
L2 = bin_length(compress_fri(proof[3]))
L3 = bin_length(compress_fri(proof[4]))
print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3))
xs = get_power_cycle(pow(7, (modulus-1)//2**LOGPRECISION, modulus))
zpoly = f.zpoly(xs[:2**LOGSTEPS-1])
zpoly_vals = [f.eval_poly_at(zpoly, x) for x in xs]
root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus)
subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus)
xs = get_power_cycle(root_of_unity)
skips = 2**(LOGPRECISION - LOGSTEPS)
zpoly = fft([0] * (2**LOGSTEPS-1) + [1],
modulus, subroot, inv=True)
zpoly_vals = fft(zpoly, modulus, root_of_unity)
assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), zpoly_vals, proof)