Added FFT-based efficiency improvements
This commit is contained in:
parent
67b8079689
commit
f9df7eddb5
|
@ -0,0 +1,40 @@
|
|||
def _fft(vals, modulus, roots_of_unity):
|
||||
if len(vals) == 1:
|
||||
return vals
|
||||
L = _fft(vals[::2], modulus, roots_of_unity[::2])
|
||||
R = _fft(vals[1::2], modulus, roots_of_unity[::2])
|
||||
o = [0 for i in vals]
|
||||
for i, (x, y) in enumerate(zip(L, R)):
|
||||
y_times_root = y*roots_of_unity[i]
|
||||
o[i] = (x+y_times_root) % modulus
|
||||
o[i+len(L)] = (x-y_times_root) % modulus
|
||||
# print(vals, root_of_unity, o)
|
||||
return o
|
||||
|
||||
def fft(vals, modulus, root_of_unity, inv=False):
|
||||
# Build up roots of unity
|
||||
rootz = [1, root_of_unity]
|
||||
while rootz[-1] != 1:
|
||||
rootz.append((rootz[-1] * root_of_unity) % modulus)
|
||||
# Fill in vals with zeroes if needed
|
||||
if len(rootz) > len(vals) + 1:
|
||||
vals = vals + [0] * (len(rootz) - len(vals) - 1)
|
||||
if inv:
|
||||
# Inverse FFT
|
||||
invlen = pow(len(vals), modulus-2, modulus)
|
||||
return [(x*invlen) % modulus for x in _fft(vals, modulus, rootz[::-1])]
|
||||
else:
|
||||
# Regular FFT
|
||||
return _fft(vals, modulus, rootz)
|
||||
|
||||
def mul_polys(a, b, modulus, root_of_unity):
|
||||
x1 = fft(a, modulus, root_of_unity)
|
||||
x2 = fft(b, modulus, root_of_unity)
|
||||
return fft([(v1*v2)%modulus for v1,v2 in zip(x1,x2)],
|
||||
modulus, root_of_unity, inv=True)
|
||||
|
||||
def div_polys(a, b, modulus, root_of_unity):
|
||||
x1 = fft(a, modulus, root_of_unity)
|
||||
x2 = fft(b, modulus, root_of_unity)
|
||||
return fft([(v1*pow(v2,modulus-2,modulus))%modulus for v1,v2 in zip(x1,x2)],
|
||||
modulus, root_of_unity, inv=True)
|
|
@ -1,6 +1,7 @@
|
|||
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
||||
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
||||
from ecpoly import PrimeField
|
||||
from fft import fft, mul_polys, div_polys
|
||||
import time
|
||||
|
||||
modulus = 2**256 - 2**32 * 351 + 1
|
||||
|
@ -63,7 +64,15 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
|
|||
|
||||
# Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
|
||||
# at that x coordinate
|
||||
column = [eval_as_bivariate(poly, special_x, xs[i]) for i in range(0, len(xs), 4)]
|
||||
# We calculate the column by Lagrange-interpolating the row, and not
|
||||
# directly, as this is more efficient
|
||||
column = []
|
||||
for i in range(len(xs)//4):
|
||||
x_poly = f.lagrange_interp(
|
||||
[values[i+len(values)*j//4] for j in range(4)],
|
||||
[xs[i+len(xs)*j//4] for j in range(4)]
|
||||
)
|
||||
column.append(f.eval_poly_at(x_poly, special_x))
|
||||
m2 = merkelize(column)
|
||||
|
||||
# Pseudo-randomly select y indices to sample
|
||||
|
@ -85,7 +94,8 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
|
|||
sub_xs = [xs[i] for i in range(0, len(xs), 4)]
|
||||
|
||||
# Interpolate the polynomial for the column
|
||||
ypoly = f.lagrange_interp(column[:len(sub_xs)], sub_xs)
|
||||
ypoly = fft(column[:len(sub_xs)], modulus,
|
||||
pow(root_of_unity, 4, modulus), inv=True)
|
||||
|
||||
# Recurse...
|
||||
return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4)
|
||||
|
@ -145,11 +155,10 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
|
|||
mtree = merkelize(data)
|
||||
assert mtree[1] == merkle_root
|
||||
|
||||
# Check its degree
|
||||
xs = get_power_cycle(root_of_unity)
|
||||
poly = f.lagrange_interp(data[:maxdeg_plus_1], xs[:maxdeg_plus_1])
|
||||
for x, datum in zip(xs[maxdeg_plus_1:], data[maxdeg_plus_1:]):
|
||||
assert f.eval_poly_at(poly, x) == datum
|
||||
# Check the degree of the data
|
||||
poly = fft(data, modulus, root_of_unity, inv=True)
|
||||
for i in range(maxdeg_plus_1, len(poly)):
|
||||
assert poly[i] == 0
|
||||
|
||||
print('FRI proof verified')
|
||||
return True
|
||||
|
@ -157,7 +166,7 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
|
|||
# Pure FRI tests
|
||||
poly = list(range(512))
|
||||
root_of_unity = pow(7, (modulus-1)//1024, modulus)
|
||||
evaluations = [f.eval_poly_at(poly, pow(root_of_unity, i, modulus)) for i in range(1024)]
|
||||
evaluations = fft(poly, modulus, root_of_unity)
|
||||
proof = prove_low_degree(poly, root_of_unity, evaluations, 512)
|
||||
print("Approx proof length: %d" % bin_length(compress_fri(proof)))
|
||||
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512)
|
||||
|
@ -189,14 +198,17 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
|||
# along the sequence of roots of unity
|
||||
xs = get_power_cycle(root)
|
||||
|
||||
skips = precision // steps
|
||||
subroot = pow(root, skips)
|
||||
# Generate the computational trace
|
||||
values = [inp]
|
||||
for i in range(steps-1):
|
||||
values.append((values[-1]**3 + xs[i]) % modulus)
|
||||
values.append((values[-1]**3 + xs[i*skips]) % modulus)
|
||||
print('Done generating computational trace')
|
||||
|
||||
# Interpolate the computational trace into a polynomial
|
||||
values_polynomial = f.lagrange_interp(values, xs[:len(values)])
|
||||
# values_polynomial = f.lagrange_interp(values, [pow(subroot, i, modulus) for i in range(steps)])
|
||||
values_polynomial = fft(values, modulus, subroot, inv=True)
|
||||
print('Computed polynomial')
|
||||
|
||||
#for x, v in zip(xs, values):
|
||||
|
@ -204,27 +216,29 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
|||
|
||||
# Create the composed polynomial such that
|
||||
# C(P(x), P(rx)) = P(rx) - P(x)**3 - x
|
||||
term1 = f.compose_polys(values_polynomial, [0, root])
|
||||
term2 = f.mul_polys(f.mul_polys(values_polynomial, values_polynomial),
|
||||
values_polynomial)
|
||||
term1 = f.compose_polys(values_polynomial, [0, subroot])
|
||||
term2 = fft([pow(x, 3, modulus) for x in fft(values_polynomial, modulus, root)], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2]
|
||||
c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1])
|
||||
print('Computed C(P) polynomial')
|
||||
|
||||
#for i in range(steps-1):
|
||||
# assert f.eval_poly_at(c_of_values, xs[i]) == 0
|
||||
# assert f.eval_poly_at(c_of_values, xs[i*skips]) == 0
|
||||
#print('C(P(x)) check passed')
|
||||
|
||||
# Compute the Z(x) polynomial that is 0 along the trace
|
||||
z = f.zpoly(xs[:steps-1])
|
||||
z = fft([0] * (steps-1) + [1],
|
||||
modulus, subroot, inv=True)
|
||||
# z2 = f.zpoly(xs[:skips*(steps-1):skips])
|
||||
print('Computed Z polynomial')
|
||||
|
||||
# Compute D(x) = C(P(x)) / Z(x)
|
||||
d = f.div_polys(c_of_values, z)
|
||||
assert f.mul_polys(d, z) == c_of_values
|
||||
print('Computed and checked D polynomial')
|
||||
# assert f.mul_polys(d, z) == c_of_values
|
||||
print('Computed D polynomial')
|
||||
|
||||
# Evaluate P and D across the entire subgroup
|
||||
p_evaluations = [f.eval_poly_at(values_polynomial, x) for x in xs]
|
||||
d_evaluations = [f.eval_poly_at(d, x) for x in xs]
|
||||
p_evaluations = fft(values_polynomial, modulus, root)
|
||||
d_evaluations = fft(d, modulus, root)
|
||||
print('Evaluated P and D')
|
||||
|
||||
# Compute their Merkle roots
|
||||
|
@ -235,10 +249,10 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
|||
# Do some spot checks of the Merkle tree at pseudo-random coordinates
|
||||
branches = []
|
||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||
positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - 1, samples)
|
||||
positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - skips, samples)
|
||||
for pos in positions:
|
||||
branches.append(mk_branch(p_mtree, pos))
|
||||
branches.append(mk_branch(p_mtree, pos + 1))
|
||||
branches.append(mk_branch(p_mtree, pos + skips))
|
||||
branches.append(mk_branch(d_mtree, pos))
|
||||
print('Computed %d spot checks' % samples)
|
||||
|
||||
|
@ -265,6 +279,7 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
|
|||
|
||||
# Get (steps)th root of unity
|
||||
root_of_unity = pow(7, (modulus-1)//precision, modulus)
|
||||
skips = precision // steps
|
||||
|
||||
# Verifies the low-degree proofs
|
||||
assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps)
|
||||
|
@ -272,13 +287,13 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
|
|||
|
||||
# Performs the spot checks
|
||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||
positions = get_indices(blake(p_root + d_root), len(xs) - 1, samples)
|
||||
positions = get_indices(blake(p_root + d_root), len(xs) - skips, samples)
|
||||
for i, pos in enumerate(positions):
|
||||
|
||||
# Check C(P(x)) = Z(x) * D(x)
|
||||
x = pow(root_of_unity, pos, modulus)
|
||||
p_of_x = verify_branch(p_root, pos, branches[i*3])
|
||||
p_of_rx = verify_branch(p_root, pos+1, branches[i*3 + 1])
|
||||
p_of_rx = verify_branch(p_root, pos+skips, branches[i*3 + 1])
|
||||
d_of_x = verify_branch(d_root, pos, branches[i*3 + 2])
|
||||
assert (p_of_rx - p_of_x ** 3 - x - zvalues[pos] * d_of_x) % modulus == 0
|
||||
|
||||
|
@ -287,8 +302,8 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
|
|||
return True
|
||||
|
||||
INPUT = 3
|
||||
LOGSTEPS = 8
|
||||
LOGPRECISION = 11
|
||||
LOGSTEPS = 15
|
||||
LOGPRECISION = 18
|
||||
|
||||
# Full STARK test
|
||||
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
|
||||
|
@ -296,7 +311,11 @@ L1 = bin_length(compress_branches(proof[2]))
|
|||
L2 = bin_length(compress_fri(proof[3]))
|
||||
L3 = bin_length(compress_fri(proof[4]))
|
||||
print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3))
|
||||
xs = get_power_cycle(pow(7, (modulus-1)//2**LOGPRECISION, modulus))
|
||||
zpoly = f.zpoly(xs[:2**LOGSTEPS-1])
|
||||
zpoly_vals = [f.eval_poly_at(zpoly, x) for x in xs]
|
||||
root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus)
|
||||
subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus)
|
||||
xs = get_power_cycle(root_of_unity)
|
||||
skips = 2**(LOGPRECISION - LOGSTEPS)
|
||||
zpoly = fft([0] * (2**LOGSTEPS-1) + [1],
|
||||
modulus, subroot, inv=True)
|
||||
zpoly_vals = fft(zpoly, modulus, root_of_unity)
|
||||
assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), zpoly_vals, proof)
|
||||
|
|
Loading…
Reference in New Issue