Added FFT-based efficiency improvements
This commit is contained in:
parent
67b8079689
commit
f9df7eddb5
|
@ -0,0 +1,40 @@
|
||||||
|
def _fft(vals, modulus, roots_of_unity):
|
||||||
|
if len(vals) == 1:
|
||||||
|
return vals
|
||||||
|
L = _fft(vals[::2], modulus, roots_of_unity[::2])
|
||||||
|
R = _fft(vals[1::2], modulus, roots_of_unity[::2])
|
||||||
|
o = [0 for i in vals]
|
||||||
|
for i, (x, y) in enumerate(zip(L, R)):
|
||||||
|
y_times_root = y*roots_of_unity[i]
|
||||||
|
o[i] = (x+y_times_root) % modulus
|
||||||
|
o[i+len(L)] = (x-y_times_root) % modulus
|
||||||
|
# print(vals, root_of_unity, o)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def fft(vals, modulus, root_of_unity, inv=False):
|
||||||
|
# Build up roots of unity
|
||||||
|
rootz = [1, root_of_unity]
|
||||||
|
while rootz[-1] != 1:
|
||||||
|
rootz.append((rootz[-1] * root_of_unity) % modulus)
|
||||||
|
# Fill in vals with zeroes if needed
|
||||||
|
if len(rootz) > len(vals) + 1:
|
||||||
|
vals = vals + [0] * (len(rootz) - len(vals) - 1)
|
||||||
|
if inv:
|
||||||
|
# Inverse FFT
|
||||||
|
invlen = pow(len(vals), modulus-2, modulus)
|
||||||
|
return [(x*invlen) % modulus for x in _fft(vals, modulus, rootz[::-1])]
|
||||||
|
else:
|
||||||
|
# Regular FFT
|
||||||
|
return _fft(vals, modulus, rootz)
|
||||||
|
|
||||||
|
def mul_polys(a, b, modulus, root_of_unity):
|
||||||
|
x1 = fft(a, modulus, root_of_unity)
|
||||||
|
x2 = fft(b, modulus, root_of_unity)
|
||||||
|
return fft([(v1*v2)%modulus for v1,v2 in zip(x1,x2)],
|
||||||
|
modulus, root_of_unity, inv=True)
|
||||||
|
|
||||||
|
def div_polys(a, b, modulus, root_of_unity):
|
||||||
|
x1 = fft(a, modulus, root_of_unity)
|
||||||
|
x2 = fft(b, modulus, root_of_unity)
|
||||||
|
return fft([(v1*pow(v2,modulus-2,modulus))%modulus for v1,v2 in zip(x1,x2)],
|
||||||
|
modulus, root_of_unity, inv=True)
|
|
@ -1,6 +1,7 @@
|
||||||
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
||||||
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
||||||
from ecpoly import PrimeField
|
from ecpoly import PrimeField
|
||||||
|
from fft import fft, mul_polys, div_polys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
modulus = 2**256 - 2**32 * 351 + 1
|
modulus = 2**256 - 2**32 * 351 + 1
|
||||||
|
@ -63,7 +64,15 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
|
||||||
|
|
||||||
# Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
|
# Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
|
||||||
# at that x coordinate
|
# at that x coordinate
|
||||||
column = [eval_as_bivariate(poly, special_x, xs[i]) for i in range(0, len(xs), 4)]
|
# We calculate the column by Lagrange-interpolating the row, and not
|
||||||
|
# directly, as this is more efficient
|
||||||
|
column = []
|
||||||
|
for i in range(len(xs)//4):
|
||||||
|
x_poly = f.lagrange_interp(
|
||||||
|
[values[i+len(values)*j//4] for j in range(4)],
|
||||||
|
[xs[i+len(xs)*j//4] for j in range(4)]
|
||||||
|
)
|
||||||
|
column.append(f.eval_poly_at(x_poly, special_x))
|
||||||
m2 = merkelize(column)
|
m2 = merkelize(column)
|
||||||
|
|
||||||
# Pseudo-randomly select y indices to sample
|
# Pseudo-randomly select y indices to sample
|
||||||
|
@ -85,7 +94,8 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
|
||||||
sub_xs = [xs[i] for i in range(0, len(xs), 4)]
|
sub_xs = [xs[i] for i in range(0, len(xs), 4)]
|
||||||
|
|
||||||
# Interpolate the polynomial for the column
|
# Interpolate the polynomial for the column
|
||||||
ypoly = f.lagrange_interp(column[:len(sub_xs)], sub_xs)
|
ypoly = fft(column[:len(sub_xs)], modulus,
|
||||||
|
pow(root_of_unity, 4, modulus), inv=True)
|
||||||
|
|
||||||
# Recurse...
|
# Recurse...
|
||||||
return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4)
|
return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4)
|
||||||
|
@ -145,11 +155,10 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
|
||||||
mtree = merkelize(data)
|
mtree = merkelize(data)
|
||||||
assert mtree[1] == merkle_root
|
assert mtree[1] == merkle_root
|
||||||
|
|
||||||
# Check its degree
|
# Check the degree of the data
|
||||||
xs = get_power_cycle(root_of_unity)
|
poly = fft(data, modulus, root_of_unity, inv=True)
|
||||||
poly = f.lagrange_interp(data[:maxdeg_plus_1], xs[:maxdeg_plus_1])
|
for i in range(maxdeg_plus_1, len(poly)):
|
||||||
for x, datum in zip(xs[maxdeg_plus_1:], data[maxdeg_plus_1:]):
|
assert poly[i] == 0
|
||||||
assert f.eval_poly_at(poly, x) == datum
|
|
||||||
|
|
||||||
print('FRI proof verified')
|
print('FRI proof verified')
|
||||||
return True
|
return True
|
||||||
|
@ -157,7 +166,7 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
|
||||||
# Pure FRI tests
|
# Pure FRI tests
|
||||||
poly = list(range(512))
|
poly = list(range(512))
|
||||||
root_of_unity = pow(7, (modulus-1)//1024, modulus)
|
root_of_unity = pow(7, (modulus-1)//1024, modulus)
|
||||||
evaluations = [f.eval_poly_at(poly, pow(root_of_unity, i, modulus)) for i in range(1024)]
|
evaluations = fft(poly, modulus, root_of_unity)
|
||||||
proof = prove_low_degree(poly, root_of_unity, evaluations, 512)
|
proof = prove_low_degree(poly, root_of_unity, evaluations, 512)
|
||||||
print("Approx proof length: %d" % bin_length(compress_fri(proof)))
|
print("Approx proof length: %d" % bin_length(compress_fri(proof)))
|
||||||
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512)
|
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512)
|
||||||
|
@ -189,14 +198,17 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
||||||
# along the sequence of roots of unity
|
# along the sequence of roots of unity
|
||||||
xs = get_power_cycle(root)
|
xs = get_power_cycle(root)
|
||||||
|
|
||||||
|
skips = precision // steps
|
||||||
|
subroot = pow(root, skips)
|
||||||
# Generate the computational trace
|
# Generate the computational trace
|
||||||
values = [inp]
|
values = [inp]
|
||||||
for i in range(steps-1):
|
for i in range(steps-1):
|
||||||
values.append((values[-1]**3 + xs[i]) % modulus)
|
values.append((values[-1]**3 + xs[i*skips]) % modulus)
|
||||||
print('Done generating computational trace')
|
print('Done generating computational trace')
|
||||||
|
|
||||||
# Interpolate the computational trace into a polynomial
|
# Interpolate the computational trace into a polynomial
|
||||||
values_polynomial = f.lagrange_interp(values, xs[:len(values)])
|
# values_polynomial = f.lagrange_interp(values, [pow(subroot, i, modulus) for i in range(steps)])
|
||||||
|
values_polynomial = fft(values, modulus, subroot, inv=True)
|
||||||
print('Computed polynomial')
|
print('Computed polynomial')
|
||||||
|
|
||||||
#for x, v in zip(xs, values):
|
#for x, v in zip(xs, values):
|
||||||
|
@ -204,27 +216,29 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
||||||
|
|
||||||
# Create the composed polynomial such that
|
# Create the composed polynomial such that
|
||||||
# C(P(x), P(rx)) = P(rx) - P(x)**3 - x
|
# C(P(x), P(rx)) = P(rx) - P(x)**3 - x
|
||||||
term1 = f.compose_polys(values_polynomial, [0, root])
|
term1 = f.compose_polys(values_polynomial, [0, subroot])
|
||||||
term2 = f.mul_polys(f.mul_polys(values_polynomial, values_polynomial),
|
term2 = fft([pow(x, 3, modulus) for x in fft(values_polynomial, modulus, root)], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2]
|
||||||
values_polynomial)
|
|
||||||
c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1])
|
c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1])
|
||||||
|
print('Computed C(P) polynomial')
|
||||||
|
|
||||||
#for i in range(steps-1):
|
#for i in range(steps-1):
|
||||||
# assert f.eval_poly_at(c_of_values, xs[i]) == 0
|
# assert f.eval_poly_at(c_of_values, xs[i*skips]) == 0
|
||||||
#print('C(P(x)) check passed')
|
#print('C(P(x)) check passed')
|
||||||
|
|
||||||
# Compute the Z(x) polynomial that is 0 along the trace
|
# Compute the Z(x) polynomial that is 0 along the trace
|
||||||
z = f.zpoly(xs[:steps-1])
|
z = fft([0] * (steps-1) + [1],
|
||||||
|
modulus, subroot, inv=True)
|
||||||
|
# z2 = f.zpoly(xs[:skips*(steps-1):skips])
|
||||||
print('Computed Z polynomial')
|
print('Computed Z polynomial')
|
||||||
|
|
||||||
# Compute D(x) = C(P(x)) / Z(x)
|
# Compute D(x) = C(P(x)) / Z(x)
|
||||||
d = f.div_polys(c_of_values, z)
|
d = f.div_polys(c_of_values, z)
|
||||||
assert f.mul_polys(d, z) == c_of_values
|
# assert f.mul_polys(d, z) == c_of_values
|
||||||
print('Computed and checked D polynomial')
|
print('Computed D polynomial')
|
||||||
|
|
||||||
# Evaluate P and D across the entire subgroup
|
# Evaluate P and D across the entire subgroup
|
||||||
p_evaluations = [f.eval_poly_at(values_polynomial, x) for x in xs]
|
p_evaluations = fft(values_polynomial, modulus, root)
|
||||||
d_evaluations = [f.eval_poly_at(d, x) for x in xs]
|
d_evaluations = fft(d, modulus, root)
|
||||||
print('Evaluated P and D')
|
print('Evaluated P and D')
|
||||||
|
|
||||||
# Compute their Merkle roots
|
# Compute their Merkle roots
|
||||||
|
@ -235,10 +249,10 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
||||||
# Do some spot checks of the Merkle tree at pseudo-random coordinates
|
# Do some spot checks of the Merkle tree at pseudo-random coordinates
|
||||||
branches = []
|
branches = []
|
||||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||||
positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - 1, samples)
|
positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - skips, samples)
|
||||||
for pos in positions:
|
for pos in positions:
|
||||||
branches.append(mk_branch(p_mtree, pos))
|
branches.append(mk_branch(p_mtree, pos))
|
||||||
branches.append(mk_branch(p_mtree, pos + 1))
|
branches.append(mk_branch(p_mtree, pos + skips))
|
||||||
branches.append(mk_branch(d_mtree, pos))
|
branches.append(mk_branch(d_mtree, pos))
|
||||||
print('Computed %d spot checks' % samples)
|
print('Computed %d spot checks' % samples)
|
||||||
|
|
||||||
|
@ -265,6 +279,7 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
|
||||||
|
|
||||||
# Get (steps)th root of unity
|
# Get (steps)th root of unity
|
||||||
root_of_unity = pow(7, (modulus-1)//precision, modulus)
|
root_of_unity = pow(7, (modulus-1)//precision, modulus)
|
||||||
|
skips = precision // steps
|
||||||
|
|
||||||
# Verifies the low-degree proofs
|
# Verifies the low-degree proofs
|
||||||
assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps)
|
assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps)
|
||||||
|
@ -272,13 +287,13 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
|
||||||
|
|
||||||
# Performs the spot checks
|
# Performs the spot checks
|
||||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||||
positions = get_indices(blake(p_root + d_root), len(xs) - 1, samples)
|
positions = get_indices(blake(p_root + d_root), len(xs) - skips, samples)
|
||||||
for i, pos in enumerate(positions):
|
for i, pos in enumerate(positions):
|
||||||
|
|
||||||
# Check C(P(x)) = Z(x) * D(x)
|
# Check C(P(x)) = Z(x) * D(x)
|
||||||
x = pow(root_of_unity, pos, modulus)
|
x = pow(root_of_unity, pos, modulus)
|
||||||
p_of_x = verify_branch(p_root, pos, branches[i*3])
|
p_of_x = verify_branch(p_root, pos, branches[i*3])
|
||||||
p_of_rx = verify_branch(p_root, pos+1, branches[i*3 + 1])
|
p_of_rx = verify_branch(p_root, pos+skips, branches[i*3 + 1])
|
||||||
d_of_x = verify_branch(d_root, pos, branches[i*3 + 2])
|
d_of_x = verify_branch(d_root, pos, branches[i*3 + 2])
|
||||||
assert (p_of_rx - p_of_x ** 3 - x - zvalues[pos] * d_of_x) % modulus == 0
|
assert (p_of_rx - p_of_x ** 3 - x - zvalues[pos] * d_of_x) % modulus == 0
|
||||||
|
|
||||||
|
@ -287,8 +302,8 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
INPUT = 3
|
INPUT = 3
|
||||||
LOGSTEPS = 8
|
LOGSTEPS = 15
|
||||||
LOGPRECISION = 11
|
LOGPRECISION = 18
|
||||||
|
|
||||||
# Full STARK test
|
# Full STARK test
|
||||||
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
|
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
|
||||||
|
@ -296,7 +311,11 @@ L1 = bin_length(compress_branches(proof[2]))
|
||||||
L2 = bin_length(compress_fri(proof[3]))
|
L2 = bin_length(compress_fri(proof[3]))
|
||||||
L3 = bin_length(compress_fri(proof[4]))
|
L3 = bin_length(compress_fri(proof[4]))
|
||||||
print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3))
|
print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3))
|
||||||
xs = get_power_cycle(pow(7, (modulus-1)//2**LOGPRECISION, modulus))
|
root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus)
|
||||||
zpoly = f.zpoly(xs[:2**LOGSTEPS-1])
|
subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus)
|
||||||
zpoly_vals = [f.eval_poly_at(zpoly, x) for x in xs]
|
xs = get_power_cycle(root_of_unity)
|
||||||
|
skips = 2**(LOGPRECISION - LOGSTEPS)
|
||||||
|
zpoly = fft([0] * (2**LOGSTEPS-1) + [1],
|
||||||
|
modulus, subroot, inv=True)
|
||||||
|
zpoly_vals = fft(zpoly, modulus, root_of_unity)
|
||||||
assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), zpoly_vals, proof)
|
assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), zpoly_vals, proof)
|
||||||
|
|
Loading…
Reference in New Issue