# Source listing metadata: 206 lines, 9.1 KiB, Python.
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
|
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
|
from poly_utils import PrimeField
|
|
import time
|
|
from fft import fft
|
|
from fri import prove_low_degree, verify_low_degree_proof
|
|
from utils import get_power_cycle, get_pseudorandom_indices, is_a_power_of_2
|
|
|
|
# Prime field modulus.  modulus - 1 == 2**32 * (2**224 - 351) and 2**224 - 351
# is odd, so the (cyclic) multiplicative group has subgroups -- and hence roots
# of unity -- of every power-of-two order up to 2**32.
modulus = 2**256 - 351 * 2**32 + 1

# Field-arithmetic helper bound to `modulus`.
f = PrimeField(modulus)

# The value 7 is the base the prover/verifier raise to derive roots of unity.
# NOTE(review): the name suggests it is a non-residue mod `modulus`; that
# property is not checked here.
nonresidue = 7

# Number of pseudo-random positions spot-checked by prover and verifier.
spot_check_security_factor = 80

# Low-degree-extension blowup: the evaluation domain has
# steps * extension_factor points.
extension_factor = 8
|
|
|
|
# Compute a MIMC permutation for some number of steps
def mimc(inp, steps, round_constants):
    """Apply ``steps - 1`` MiMC rounds to *inp* and return the final value.

    Round ``r`` maps ``x`` to
    ``(x**3 + round_constants[r % len(round_constants)]) % modulus``, so the
    round constants are reused cyclically.  When ``steps <= 1`` no round runs
    and *inp* is returned as-is.  The elapsed wall-clock time is printed
    before returning.
    """
    t0 = time.time()
    value = inp
    for round_no in range(steps - 1):
        k = round_constants[round_no % len(round_constants)]
        value = (value**3 + k) % modulus
    print("MIMC computed in %.4f sec" % (time.time() - t0))
    return value
|
|
|
|
# Generate a STARK for a MIMC calculation
def mk_mimc_proof(inp, steps, round_constants):
    """Build a STARK proving that the MiMC trace starting at *inp* is correct.

    Args:
        inp: Value fed into the first MiMC round.
        steps: Trace length (``steps - 1`` rounds are applied).  Must be a
            power of two and at most ``2**32 // extension_factor``.
        round_constants: Round constants, reused cyclically.  Their count
            must be a power of two and strictly smaller than ``steps``.

    Returns:
        ``[m_root, l_root, branches, fri_proof]`` where ``m_root`` is the
        Merkle root over the concatenated (P, D, B) evaluations, ``l_root``
        is the Merkle root of the random linear combination L, ``branches``
        holds three Merkle branches per spot-checked position (the (P, D, B)
        tree at ``pos`` and at ``pos + skips``, and the L tree at ``pos``),
        and ``fri_proof`` is the low-degree proof of L.

    Raises:
        AssertionError: if the size constraints above are violated.  Raised
            explicitly (not via ``assert``) so the checks survive ``python -O``.
    """
    start_time = time.time()

    # Some constraints to make our job easier
    if steps > 2**32 // extension_factor:
        raise AssertionError("steps must be at most 2**32 // extension_factor")
    if not (is_a_power_of_2(steps) and is_a_power_of_2(len(round_constants))):
        raise AssertionError("steps and len(round_constants) must be powers of 2")
    if len(round_constants) >= steps:
        raise AssertionError("len(round_constants) must be less than steps")

    precision = steps * extension_factor

    # Root of unity such that x^precision=1
    G2 = f.exp(nonresidue, (modulus-1)//precision)

    # Root of unity such that x^steps=1
    skips = precision // steps
    G1 = f.exp(G2, skips)

    # Powers of the higher-order root of unity
    xs = get_power_cycle(G2, modulus)
    last_step_position = xs[(steps-1)*extension_factor]

    # Generate the computational trace
    computational_trace = [inp]
    for i in range(steps-1):
        computational_trace.append(
            (computational_trace[-1]**3 + round_constants[i % len(round_constants)]) % modulus
        )
    output = computational_trace[-1]
    print('Done generating computational trace')

    # Interpolate the computational trace into a polynomial P, with each step
    # along a successive power of G1, then evaluate P over the larger domain
    # generated by G2 (the low-degree extension).
    computational_trace_polynomial = fft(computational_trace, modulus, G1, inv=True)
    p_evaluations = fft(computational_trace_polynomial, modulus, G2)
    print('Converted computational steps into a polynomial and low-degree extended it')

    # The constants cycle with period len(round_constants), so K(x) equals
    # Kmini(x**skips2), where Kmini interpolates round_constants over the
    # subgroup generated by G1**skips2.  Only Kmini and its extension are needed.
    skips2 = steps // len(round_constants)
    constants_mini_polynomial = fft(round_constants, modulus, f.exp(G1, skips2), inv=True)
    constants_mini_extension = fft(constants_mini_polynomial, modulus, f.exp(G2, skips2))
    print('Converted round constants into a polynomial and low-degree extended it')

    # Create the composed polynomial such that
    # C(P(x), P(g1*x), K(x)) = P(g1*x) - P(x)**3 - K(x)
    # On the extended domain, g1*x sits `extension_factor` positions after x.
    c_of_p_evaluations = [(p_evaluations[(i+extension_factor)%precision] -
                           f.exp(p_evaluations[i], 3) -
                           constants_mini_extension[i % len(constants_mini_extension)])
                          % modulus for i in range(precision)]
    print('Computed C(P, K) polynomial')

    # Compute D(x) = C(P(x), P(g1*x), K(x)) / Z(x)
    # Z(x) = (x^steps - 1) / (x - x_atlast_step)
    # x^steps - 1 vanishes at multiples of extension_factor; those positions
    # are excluded from the spot checks and the FRI proof below.
    z_num_evaluations = [xs[(i * steps) % precision] - 1 for i in range(precision)]
    z_num_inv = f.multi_inv(z_num_evaluations)
    z_den_evaluations = [xs[i] - last_step_position for i in range(precision)]
    d_evaluations = [cp * zd * zni % modulus for cp, zd, zni in zip(c_of_p_evaluations, z_den_evaluations, z_num_inv)]
    print('Computed D polynomial')

    # Compute interpolant of ((1, input), (x_atlast_step, output))
    interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
    i_evaluations = [f.eval_poly_at(interpolant, x) for x in xs]

    # B(x) = (P(x) - I(x)) / Q(x) with Q(x) = (x - 1) * (x - x_atlast_step)
    zeropoly2 = f.mul_polys([-1, 1], [-last_step_position, 1])
    inv_z2_evaluations = f.multi_inv([f.eval_poly_at(zeropoly2, x) for x in xs])

    b_evaluations = [((p - i) * invq) % modulus for p, i, invq in zip(p_evaluations, i_evaluations, inv_z2_evaluations)]
    print('Computed B polynomial')

    # Compute their Merkle root: each leaf is P(x), D(x) and B(x), each as a
    # 32-byte big-endian integer, concatenated.
    mtree = merkelize([pval.to_bytes(32, 'big') +
                       dval.to_bytes(32, 'big') +
                       bval.to_bytes(32, 'big') for
                       pval, dval, bval in zip(p_evaluations, d_evaluations, b_evaluations)])
    print('Computed hash root')

    # Based on the hashes of P, D and B, we select a random linear combination
    # of P * x^steps, P, B * x^steps, B and D, and prove the low-degreeness of that,
    # instead of proving the low-degreeness of P, B and D separately
    k1 = int.from_bytes(blake(mtree[1] + b'\x01'), 'big')
    k2 = int.from_bytes(blake(mtree[1] + b'\x02'), 'big')
    k3 = int.from_bytes(blake(mtree[1] + b'\x03'), 'big')
    k4 = int.from_bytes(blake(mtree[1] + b'\x04'), 'big')

    # Compute the linear combination. We don't even bother calculating it in
    # coefficient form; we just compute the evaluations.
    # powers[i] == (G2**i)**steps, i.e. x^steps at the i-th domain point.
    G2_to_the_steps = f.exp(G2, steps)
    powers = [1]
    for i in range(1, precision):
        powers.append(powers[-1] * G2_to_the_steps % modulus)

    l_evaluations = [(d_evaluations[i] +
                      p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] +
                      b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus
                     for i in range(precision)]

    l_mtree = merkelize(l_evaluations)
    print('Computed random linear combination')

    # Do some spot checks of the Merkle tree at pseudo-random coordinates, excluding
    # multiples of `extension_factor`
    branches = []
    samples = spot_check_security_factor
    positions = get_pseudorandom_indices(l_mtree[1], precision, samples,
                                         exclude_multiples_of=extension_factor)
    for pos in positions:
        branches.append(mk_branch(mtree, pos))
        branches.append(mk_branch(mtree, (pos + skips) % precision))
        branches.append(mk_branch(l_mtree, pos))
    print('Computed %d spot checks' % samples)

    # Return the Merkle root of the (P, D, B) tree, the Merkle root of L,
    # the spot check Merkle proofs, and a low-degree proof of L
    o = [mtree[1],
         l_mtree[1],
         branches,
         prove_low_degree(l_evaluations, G2, steps * 2, modulus, exclude_multiples_of=extension_factor)]
    print("STARK computed in %.4f sec" % (time.time() - start_time))
    return o
|
|
|
|
# Verifies a STARK
def verify_mimc_proof(inp, steps, round_constants, output, proof):
    """Check a STARK produced by ``mk_mimc_proof``.

    Args:
        inp: Claimed MiMC input (first trace value).
        steps: Trace length; same constraints as for the prover.
        round_constants: The round constants the prover used.
        output: Claimed MiMC output (last trace value).
        proof: ``[m_root, l_root, branches, fri_proof]`` as returned by
            ``mk_mimc_proof``.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: if a parameter constraint is violated or any check on
            the proof fails.  Every check raises explicitly instead of using
            ``assert``: under ``python -O`` an ``assert`` is removed entirely,
            which would skip the FRI verification call and all constraint
            checks and make any proof "verify".
        Exceptions raised by ``verify_branch`` or ``verify_low_degree_proof``
        are not caught.
    """
    m_root, l_root, branches, fri_proof = proof
    start_time = time.time()
    if steps > 2**32 // extension_factor:
        raise AssertionError("steps must be at most 2**32 // extension_factor")
    if not (is_a_power_of_2(steps) and is_a_power_of_2(len(round_constants))):
        raise AssertionError("steps and len(round_constants) must be powers of 2")
    if len(round_constants) >= steps:
        raise AssertionError("len(round_constants) must be less than steps")

    precision = steps * extension_factor

    # Root of unity such that x^precision=1 (generator of the evaluation domain)
    G2 = f.exp(nonresidue, (modulus-1)//precision)
    skips = precision // steps

    # Gets the polynomial representing the round constants: K(x) = Kmini(x^skips2)
    skips2 = steps // len(round_constants)
    constants_mini_polynomial = fft(round_constants, modulus, f.exp(G2, extension_factor * skips2), inv=True)

    # Verifies the low-degree proof of L
    if not verify_low_degree_proof(l_root, G2, fri_proof, steps * 2, modulus, exclude_multiples_of=extension_factor):
        raise AssertionError("low-degree proof of the linear combination failed")

    # Re-derive the linear-combination coefficients from the (P, D, B) root
    k1 = int.from_bytes(blake(m_root + b'\x01'), 'big')
    k2 = int.from_bytes(blake(m_root + b'\x02'), 'big')
    k3 = int.from_bytes(blake(m_root + b'\x03'), 'big')
    k4 = int.from_bytes(blake(m_root + b'\x04'), 'big')

    # Performs the spot checks
    samples = spot_check_security_factor
    positions = list(get_pseudorandom_indices(l_root, precision, samples,
                                              exclude_multiples_of=extension_factor))
    # The prover emits exactly three branches per position; reject anything else.
    if len(branches) != 3 * len(positions):
        raise AssertionError("expected 3 Merkle branches per spot check, got %d for %d positions"
                             % (len(branches), len(positions)))

    last_step_position = f.exp(G2, (steps - 1) * skips)
    # The boundary polynomials do not depend on x, so build them once.
    interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
    zeropoly2 = f.mul_polys([-1, 1], [-last_step_position, 1])

    for i, pos in enumerate(positions):
        x = f.exp(G2, pos)
        x_to_the_steps = f.exp(x, steps)
        mbranch1 = verify_branch(m_root, pos, branches[i*3])
        mbranch2 = verify_branch(m_root, (pos+skips) % precision, branches[i*3+1])
        l_of_x = verify_branch(l_root, pos, branches[i*3 + 2], output_as_int=True)

        # Each (P, D, B) leaf is three concatenated 32-byte big-endian values
        p_of_x = int.from_bytes(mbranch1[:32], 'big')
        p_of_g1x = int.from_bytes(mbranch2[:32], 'big')
        d_of_x = int.from_bytes(mbranch1[32:64], 'big')
        b_of_x = int.from_bytes(mbranch1[64:], 'big')

        zvalue = f.div(x_to_the_steps - 1, x - last_step_position)
        k_of_x = f.eval_poly_at(constants_mini_polynomial, f.exp(x, skips2))

        # Check transition constraints C(P(x)) = Z(x) * D(x)
        if (p_of_g1x - pow(p_of_x, 3, modulus) - k_of_x - zvalue * d_of_x) % modulus != 0:
            raise AssertionError("transition constraint failed at position %d" % pos)

        # Check boundary constraints B(x) * Q(x) + I(x) = P(x)
        if (p_of_x - b_of_x * f.eval_poly_at(zeropoly2, x) -
                f.eval_poly_at(interpolant, x)) % modulus != 0:
            raise AssertionError("boundary constraint failed at position %d" % pos)

        # Check correctness of the linear combination
        if (l_of_x - d_of_x -
                k1 * p_of_x - k2 * p_of_x * x_to_the_steps -
                k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus != 0:
            raise AssertionError("linear combination mismatch at position %d" % pos)

    print('Verified %d consistency checks' % spot_check_security_factor)
    print('Verified STARK in %.4f sec' % (time.time() - start_time))
    return True
|