# Listing metadata from the original page: 303 lines, 11 KiB, Python.
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
|
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
|
from ecpoly import PrimeField
|
|
import time
|
|
|
|
# Field modulus p = 2**256 - 351 * 2**32 + 1. Since p - 1 = 2**32 * (2**224 - 351),
# the multiplicative group has subgroups of every power-of-two order up to
# 2**32 (mk_mimc_proof asserts logprecision <= 32).
modulus = 2**256 - 2**32 * 351 + 1

f = PrimeField(modulus)

# Base element from which roots of unity are derived throughout this file as
# nonresidue ** ((modulus - 1) // n). Presumably chosen as a quadratic
# non-residue so that those roots are primitive -- not checked here.
nonresidue = 7

# The four 4th roots of unity [1, w, w**2, w**3], where
# w = nonresidue ** ((modulus - 1) // 4).
quartic_roots_of_unity = [1,
                          pow(nonresidue, (modulus - 1) // 4, modulus),
                          pow(nonresidue, (modulus - 1) // 2, modulus),
                          pow(nonresidue, (modulus - 1) * 3 // 4, modulus)]

# Total spot-check budget: the STARK prover/verifier sample
# spot_check_security_factor // (logprecision - logsteps) positions.
spot_check_security_factor = 240
|
|
|
|
def eval_as_bivariate(p, x, y):
    """Evaluate the coefficient list ``p`` as a bivariate polynomial g(x, y).

    Coefficient ``p[4*k + j]`` (0 <= j < 4) is treated as the coefficient of
    ``x**j * y**k``. Consequently the invariant
    ``eval_as_bivariate(p, x, x**4) == eval(p, x)`` holds modulo ``modulus``.

    Args:
        p: Polynomial coefficients in ascending-degree order. The length does
            not need to be a multiple of 4; missing trailing coefficients of
            the last group are treated as zero.
        x: Value substituted for the "column" variable x.
        y: Value substituted for the "row" variable y.

    Returns:
        g(x, y) reduced modulo ``modulus`` (in the range [0, modulus)).
    """
    xpows = [pow(x, j, modulus) for j in range(4)]
    total = 0
    ypow = 1
    for i in range(0, len(p), 4):
        # zip stops at the shorter sequence, so a final group with fewer than
        # 4 coefficients is handled instead of raising IndexError.
        group_sum = sum(xp * c for xp, c in zip(xpows, p[i:i + 4]))
        total = (total + group_sum * ypow) % modulus
        ypow = (ypow * y) % modulus
    return total
|
|
|
|
def get_power_cycle(r):
    """Return [1, r, r**2, ...] (mod ``modulus``), stopping just before the
    sequence comes back around to 1.

    For a root of unity ``r`` of order n this is the n distinct powers
    r**0 .. r**(n-1). The loop never terminates when no power of ``r`` equals
    1 (for example r == 0).
    """
    powers = [1]
    current = r
    while current != 1:
        powers.append(current)
        current = (current * r) % modulus
    return powers
|
|
|
|
def get_indices(seed, modulus, count):
    """Derive ``count`` pseudo-random indices in [0, modulus) from ``seed``.

    The seed is stretched by repeatedly appending ``blake`` of its last 32
    bytes until at least 4 * count bytes are available. Each consecutive
    4-byte big-endian word is then reduced modulo ``modulus``.

    Note: ``modulus`` here is the index range and shadows the module-level
    field modulus; it must be below 2**24.
    """
    assert modulus < 2**24
    needed = 4 * count
    stream = seed
    while len(stream) < needed:
        stream += blake(stream[-32:])
    indices = []
    for offset in range(0, needed, 4):
        word = int.from_bytes(stream[offset:offset + 4], 'big')
        indices.append(word % modulus)
    return indices
|
|
|
|
def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
    """Generate an FRI proof that ``values`` come from a polynomial of
    degree < ``maxdeg_plus_1``.

    Args:
        poly: Coefficients (ascending order) of the polynomial whose
            evaluations are ``values``; used to compute the column directly.
            Its length is consumed in groups of 4 by eval_as_bivariate.
        root_of_unity: Generator of the evaluation domain; ``values[i]`` is
            assumed to be the evaluation at ``root_of_unity**i``.
        values: Evaluations of ``poly`` over the powers of ``root_of_unity``.
        maxdeg_plus_1: Claimed exclusive degree bound.

    Returns:
        A list of layers. Each recursive layer is ``[column_root, branches]``.
        The final element is the list of remaining values serialized as
        32-byte big-endian strings.

    NOTE(review): the bound is reduced with ``maxdeg_plus_1 // 4`` per layer
    (the verifier does the same). When a bound is not divisible by 4, this
    under-approximates the column's true bound of ceil(bound / 4).
    """
    print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1))

    # Small enough: send the values directly as the proof
    if maxdeg_plus_1 <= 32:
        print('Produced FRI proof')
        return [[x.to_bytes(32, 'big') for x in values]]

    # x coordinates of the evaluation domain
    xs = get_power_cycle(root_of_unity)
    # Index distance between the four domain points sharing the same x**4
    row_stride = len(xs) // 4

    # Commit to the values; the proof is checked against this root
    m = merkelize(values)

    # Pseudo-random x coordinate derived from the Merkle root
    special_x = int.from_bytes(m[1], 'big') % modulus

    # Next layer's domain: every 4th point (the y = x**4 values)
    sub_xs = xs[::4]

    # The "column" g(special_x, y) for each y in sub_xs
    # (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
    column = [eval_as_bivariate(poly, special_x, y) for y in sub_xs]
    m2 = merkelize(column)

    # Pseudo-randomly select 40 rows to spot-check
    ys = get_indices(m2[1], len(column), 40)

    # Per row: the column branch, then the branches of the four values whose
    # x coordinates have that row's y = x**4
    branches = [
        [mk_branch(m2, y)] + [mk_branch(m, y + row_stride * j) for j in range(4)]
        for y in ys
    ]

    # This layer of the proof
    o = [m2[1], branches]

    # Interpolate the column over the smaller domain and recurse on it
    ypoly = f.lagrange_interp(column, sub_xs)
    return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4)
|
|
|
|
def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
    """Verify an FRI proof produced by ``prove_low_degree``.

    Args:
        merkle_root: Merkle root of the evaluations being checked.
        root_of_unity: Generator of the evaluation domain; its multiplicative
            order must be a power of two.
        proof: List of layers as returned by ``prove_low_degree``.
        maxdeg_plus_1: Claimed exclusive degree bound.

    Returns:
        True if every check passes.

    Raises:
        AssertionError: if a Merkle, consistency or degree check fails.
        ValueError: if ``root_of_unity`` does not have power-of-two order.
    """
    # Order of root_of_unity (size of the evaluation domain), found by
    # repeated squaring. No element's order can reach the modulus, so reaching
    # it means the order is not a power of two (the loop would never end).
    testval = root_of_unity
    roudeg = 1
    while testval != 1:
        if roudeg >= modulus:
            raise ValueError('root_of_unity must have power-of-two order')
        roudeg *= 2
        testval = (testval * testval) % modulus

    # Verify the recursive layers of the proof
    for layer in proof[:-1]:
        root2, branches = layer
        print('Verifying degree <= %d' % maxdeg_plus_1)

        # Recompute the pseudo-random x coordinate and the sampled rows
        special_x = int.from_bytes(merkle_root, 'big') % modulus
        row_stride = roudeg // 4
        ys = get_indices(root2, row_stride, 40)

        # Values at indices y + row_stride * j sit at x1 * w**j, where
        # w = root_of_unity**row_stride is a 4th root of unity. Deriving w
        # from root_of_unity keeps this consistent for any valid domain root.
        w = pow(root_of_unity, row_stride, modulus)
        row_roots = [pow(w, j, modulus) for j in range(4)]

        # For each sampled row, check that the column value and the four
        # committed values lie on one polynomial of degree < 4
        for i, y in enumerate(ys):
            # One column branch followed by exactly four row branches
            assert len(branches[i]) == 5
            x1 = pow(root_of_unity, y, modulus)
            eckses = [special_x] + [(r * x1) % modulus for r in row_roots]

            # Values from the committed evaluations
            row = [verify_branch(merkle_root, y + row_stride * j, branch)
                   for j, branch in enumerate(branches[i][1:])]

            # Column value first, matching eckses[0] == special_x
            values = [verify_branch(root2, y, branches[i][0])] + row

            # Interpolating 5 points gives degree <= 4; the x**4
            # coefficient must vanish
            p = f.lagrange_interp(values, eckses)
            assert p[4] == 0

        # Move on to the next (4x smaller) layer
        merkle_root = root2
        root_of_unity = pow(root_of_unity, 4, modulus)
        maxdeg_plus_1 //= 4
        roudeg //= 4

    # Verify the final, directly supplied layer
    data = [int.from_bytes(x, 'big') for x in proof[-1]]
    print('Verifying degree <= %d' % maxdeg_plus_1)
    assert maxdeg_plus_1 <= 32

    # The supplied values must be exactly what merkle_root commits to
    mtree = merkelize(data)
    assert mtree[1] == merkle_root

    # Exactly one value per domain point; a shorter list would make the
    # degree check below (partly or fully) vacuous
    xs = get_power_cycle(root_of_unity)
    assert len(data) == len(xs)

    # Interpolate the first maxdeg_plus_1 points; all others must agree
    poly = f.lagrange_interp(data[:maxdeg_plus_1], xs[:maxdeg_plus_1])
    for x, datum in zip(xs[maxdeg_plus_1:], data[maxdeg_plus_1:]):
        assert f.eval_poly_at(poly, x) == datum

    print('FRI proof verified')
    return True
|
|
|
|
# --- Pure FRI self-test ---
# poly = [0, 1, ..., 511] has degree 511, so its evaluations over a
# 1024-point domain must pass a "degree < 512" proof and verification.
poly = list(range(512))
root_of_unity = pow(7, (modulus - 1) // 1024, modulus)
evaluations = [
    f.eval_poly_at(poly, pow(root_of_unity, k, modulus))
    for k in range(1024)
]
proof = prove_low_degree(poly, root_of_unity, evaluations, 512)
print("Approx proof length: %d" % bin_length(compress_fri(proof)))
assert verify_low_degree_proof(
    merkelize(evaluations)[1], root_of_unity, proof, 512
)
|
|
|
|
def mimc(inp, logsteps, logprecision):
    """Compute the MIMC permutation directly, without producing a proof.

    Applies 2**logsteps - 1 rounds of ``state -> state**3 + c_k (mod modulus)``.
    The round constant c_k is root**k, where
    root = 7 ** ((modulus - 1) // 2**logprecision) is a 2**logprecision-th
    root of unity. Prints the elapsed time.

    Returns:
        The state after the final round.
    """
    t_start = time.time()
    round_count = 2**logsteps - 1
    domain_size = 2**logprecision
    # Round constants: successive powers of a (domain_size)-th root of unity
    constants = get_power_cycle(pow(7, (modulus - 1) // domain_size, modulus))
    state = inp
    for k in range(round_count):
        state = (state**3 + constants[k]) % modulus
    print("MIMC computed in %.4f sec" % (time.time() - t_start))
    return state
|
|
|
|
def mk_mimc_proof(inp, logsteps, logprecision):
    """Generate a STARK for the MIMC computation starting from ``inp``.

    The trace polynomial P satisfies P(xs[i]) = trace[i]. The transition
    constraint C = P(r*x) - P(x)**3 - x vanishes on xs[:steps-1]. It is
    divided by Z(x), which vanishes on exactly those points, to give D(x).

    Args:
        inp: Starting value of the computation.
        logsteps: log2 of the number of trace entries.
        logprecision: log2 of the evaluation domain size; requires
            logsteps < logprecision <= 32.

    Returns:
        [P Merkle root, D Merkle root, spot-check branches,
         FRI proof for P (degree < steps), FRI proof for D (degree < 2*steps)]
    """
    start_time = time.time()
    assert logsteps < logprecision <= 32
    steps = 2**logsteps
    precision = 2**logprecision

    # (precision)th root of unity; the trace lives on its first `steps` powers
    root = pow(7, (modulus-1)//precision, modulus)
    xs = get_power_cycle(root)

    # Computational trace: values[i+1] = values[i]**3 + xs[i]
    values = [inp]
    for i in range(steps-1):
        values.append((values[-1]**3 + xs[i]) % modulus)
    print('Done generating computational trace')

    # Interpolate the trace into P(x)
    values_polynomial = f.lagrange_interp(values, xs[:len(values)])
    print('Computed polynomial')

    # C(P(x), P(rx)) = P(rx) - P(x)**3 - x
    term1 = f.compose_polys(values_polynomial, [0, root])
    term2 = f.mul_polys(f.mul_polys(values_polynomial, values_polynomial),
                        values_polynomial)
    c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1])

    # Z(x) is zero on every trace point where the transition must hold
    z = f.zpoly(xs[:steps-1])
    print('Computed Z polynomial')

    # D(x) = C(P(x)) / Z(x); the division must be exact
    d = f.div_polys(c_of_values, z)
    assert f.mul_polys(d, z) == c_of_values
    print('Computed and checked D polynomial')

    # Evaluate P and D across the entire domain and commit to them
    p_evaluations = [f.eval_poly_at(values_polynomial, x) for x in xs]
    d_evaluations = [f.eval_poly_at(d, x) for x in xs]
    print('Evaluated P and D')
    p_mtree = merkelize(p_evaluations)
    d_mtree = merkelize(d_evaluations)
    print('Computed hash root')

    # Spot checks at pseudo-random positions derived from both roots. pos is
    # drawn from [0, len(xs) - 1) so that pos + 1 (the point r*x) stays in
    # the domain.
    samples = spot_check_security_factor // (logprecision - logsteps)
    positions = get_indices(blake(p_mtree[1] + d_mtree[1]), len(xs) - 1, samples)
    branches = []
    for pos in positions:
        branches.append(mk_branch(p_mtree, pos))
        branches.append(mk_branch(p_mtree, pos + 1))
        branches.append(mk_branch(d_mtree, pos))
    print('Computed %d spot checks' % samples)

    # Zero-pad D's coefficients (in place) up to 2*steps, presumably so
    # eval_as_bivariate, which reads coefficients in groups of 4, gets a
    # suitable length.
    d += [0] * max(0, steps * 2 - len(d))

    # Merkle roots of P and D, the spot-check branches, and the low-degree
    # proofs of P and D
    o = [p_mtree[1],
         d_mtree[1],
         branches,
         prove_low_degree(values_polynomial, root, p_evaluations, steps),
         prove_low_degree(d, root, d_evaluations, steps * 2)]
    print("STARK computed in %.4f sec" % (time.time() - start_time))
    return o
|
|
|
|
def verify_mimc_proof(inp, logsteps, logprecision, output, zvalues, proof):
    """Verify a STARK produced by ``mk_mimc_proof``.

    Args:
        inp: Claimed starting value (currently unused, see NOTE).
        logsteps: Same value the prover used.
        logprecision: Same value the prover used.
        output: Claimed result (currently unused, see NOTE).
        zvalues: Z(x) evaluated over the full domain, indexed by position.
        proof: The list returned by ``mk_mimc_proof``.

    Returns:
        True if every check passes; failures raise AssertionError.

    NOTE(review): ``inp`` and ``output`` are never compared with the proof.
    This therefore only shows that *some* trace satisfies the transition
    constraint. Binding input/output needs boundary checks that the current
    proof format does not carry.
    """
    p_root, d_root, branches, p_proof, d_proof = proof
    start_time = time.time()

    steps = 2**logsteps
    precision = 2**logprecision

    # (precision)th root of unity generating the evaluation domain
    root_of_unity = pow(7, (modulus-1)//precision, modulus)

    # Verify the low-degree proofs
    assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps)
    assert verify_low_degree_proof(d_root, root_of_unity, d_proof, steps * 2)

    # Re-derive the spot-check positions exactly as the prover did: it
    # samples from [0, len(xs) - 1) with xs = get_power_cycle(root). This is
    # computed locally instead of relying on a module-level `xs`.
    domain_size = len(get_power_cycle(root_of_unity))
    samples = spot_check_security_factor // (logprecision - logsteps)
    positions = get_indices(blake(p_root + d_root), domain_size - 1, samples)
    # Three branches (P(x), P(rx), D(x)) per sampled position
    assert len(branches) >= 3 * len(positions)

    for i, pos in enumerate(positions):
        # Check C(P(x)) = Z(x) * D(x) at x = root_of_unity**pos
        x = pow(root_of_unity, pos, modulus)
        p_of_x = verify_branch(p_root, pos, branches[i*3])
        p_of_rx = verify_branch(p_root, pos+1, branches[i*3 + 1])
        d_of_x = verify_branch(d_root, pos, branches[i*3 + 2])
        assert (p_of_rx - p_of_x ** 3 - x - zvalues[pos] * d_of_x) % modulus == 0

    print('Verified %d consistency checks' % samples)
    print('Verified STARK in %.4f sec' % (time.time() - start_time))
    return True
|
|
|
|
# Parameters for the end-to-end STARK test
INPUT = 3
LOGSTEPS = 8
LOGPRECISION = 11

# Full STARK test: prove, report approximate proof sizes, then verify
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
L1 = bin_length(compress_branches(proof[2]))
L2 = bin_length(compress_fri(proof[3]))
L3 = bin_length(compress_fri(proof[4]))
print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3))

# The verifier is given Z(x) evaluated over the whole domain
xs = get_power_cycle(pow(7, (modulus-1)//2**LOGPRECISION, modulus))
zpoly = f.zpoly(xs[:2**LOGSTEPS-1])
zpoly_vals = [f.eval_poly_at(zpoly, x) for x in xs]

# Use INPUT (not a literal) so the verified input/output always match the
# input the proof was generated from
assert verify_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION, mimc(INPUT, LOGSTEPS, LOGPRECISION), zpoly_vals, proof)
|