More efficiency

Vitalik Buterin 2018-07-04 05:33:51 -04:00
parent 3180447866
commit 59b8020fec
4 changed files with 142 additions and 65 deletions

View File

@ -1,17 +1,21 @@
#### Disclaimer
DO NOT USE FOR ANYTHING IN REAL LIFE. DO NOT ASSUME THE PROTOCOL DESCRIBED HERE IS SOUND. TALK TO A SPECIALIST IF YOU'RE LOOKING TO USE STARKS IN YOUR APPLICATION.
#### What is this? #### What is this?
This is a very basic implementation of a STARK on a MIMC computation that is probably (ie. definitely) broken in a few places but is intended as a proof of concept to show the rough level of complexity that is involved in implementing a simple STARK. A STARK is a really cool proof-of-computation scheme that allows you to create an efficiently verifiable proof that some computation was executed correctly; the verification time only rises logarithmically with the computation time, and that relies only on hashes and information theory for security. This is a very basic implementation of a STARK on a MIMC computation that is probably (ie. definitely) broken in a few places but is intended as a proof of concept to show the rough level of complexity that is involved in implementing a simple STARK. A STARK is a really cool proof-of-computation scheme that allows you to create an efficiently verifiable proof that some computation was executed correctly; the verification time only rises logarithmically with the computation time, and that relies only on hashes and information theory for security.
The STARKs are done over a finite field chosen to have 2^32'th roots of unity (to facilitate STARKs), and NOT have 3rd roots of unity (to facilitate MIMC). The MIMC permutation in general takes the form: The STARKs are done over a finite field chosen to have `2**32`'th roots of unity (to facilitate STARKs), and NOT have 3rd roots of unity (to facilitate MIMC). The MIMC permutation in general takes the form:
k_1 k_2 k_1 k_2
| | | |
v v v v
x0 ---> (x->x^3) ---> xor ---> (x->x^3) ---> xor --- ... ---> output x0 ---> (x->x**3) ---> xor ---> (x->x**3) ---> xor --- ... ---> output
Where the `k_i` values are round constants. MIMC can be used as a building block in a hash function, or as a verifiable delay function; its simple arithmetic representation makes it ideal for use in STARKs, SNARKs, MPC and other "cryptography over general-purpose computation" schemes. Where the `k_i` values are round constants. MIMC can be used as a building block in a hash function, or as a verifiable delay function; its simple arithmetic representation makes it ideal for use in STARKs, SNARKs, MPC and other "cryptography over general-purpose computation" schemes.
The MIMC round constants used here are 2^k'th roots of unity for k <= 32. The computational trace is computed over successive powers of a 2^k'th root of unity. This allows the constraint checking polynomial to be `C(P(x), P(r*x), x) = P(r*x) - P(x)**3 - x`. The MIMC round constants used here are successive powers of 9 mod `2**256` xored with 1, though could be set to anything. The computational trace is computed over successive powers of a `2**k`'th root of unity. This allows the constraint checking polynomial to be `C(P(x), P(r*x), K(x)) = P(r*x) - P(x)**3 - K(x)`.
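A minimal sketch of the field properties and the round function described above. It follows the code in this commit, which adds each round constant mod the field modulus (the diagram labels that step "xor"). The helper names `round_constants`, `mimc_trace` and `mimc_inverse_round` are hypothetical and do not appear in the commit; the cube-root exponent assumes the modulus is prime, as the `PrimeField` usage implies.

```python
# Field modulus taken from the code in this commit.
MODULUS = 2**256 - 2**32 * 351 + 1

# The group order MODULUS - 1 is divisible by 2**32 (so 2**32'th roots of
# unity exist) and not divisible by 3 (so x -> x**3 is a permutation).
assert (MODULUS - 1) % 2**32 == 0
assert (MODULUS - 1) % 3 != 0

def round_constants(count):
    """First `count` round constants: (9**i mod 2**256) xor 1."""
    constants = []
    k = 1
    for _ in range(count):
        constants.append(k ^ 1)
        k = (k * 9) & ((1 << 256) - 1)
    return constants

def mimc_trace(inp, steps):
    """Hypothetical helper: the `steps` successive states, starting at `inp`."""
    trace = [inp % MODULUS]
    for k in round_constants(steps - 1):
        trace.append((trace[-1]**3 + k) % MODULUS)
    return trace

def mimc_inverse_round(value, k):
    """Undo one round: subtract the constant, then take the cube root.

    Assumes MODULUS is prime, so raising to (2*MODULUS - 1) // 3 inverts cubing.
    """
    return pow((value - k) % MODULUS, (2 * MODULUS - 1) // 3, MODULUS)

trace = mimc_trace(3, 3)
```

Because no 3rd roots of unity exist in this field, each round can be undone, which is what makes the construction usable as a permutation.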
For a description of how STARKs work, see: For a description of how STARKs work, see:
@ -22,4 +26,18 @@ For more discussion on MIMC, see:
* [Zcash issue 2233](https://github.com/zcash/zcash/issues/2233) * [Zcash issue 2233](https://github.com/zcash/zcash/issues/2233)
The implementation is definitely suboptimal. For example, one optimization would be to reduce the two FRIs to one by low-degree proving a random linear combination k1 * D + k2 * P^2; this could cut the verification time by close to half. Another would be to remove redundancies in computing zero polynomials in multiple places; a more sophisticated optimization would be to move to subquadratic algorithms (Karatsuba or ideally FFT-based) for polynomial interpolation. The purpose is for the code to be deliberately simple to aid with understanding. ### How does the STARK scheme here work?
Here are the approximate steps in the code. Note that all arithmetic is done over the finite field of integers mod `2**256 - 2**32 * 351 + 1`.
1. Let `P[0]` equal the input, and `P[i+1] = P[i]**3 + K[i]`, up to `steps` (where `steps` must be a power of 2)
2. Construct a polynomial where `P(subroot**i) = P[i]` for `i` from 0 up to `steps-1`, where `subroot` is a `steps`'th root of unity (that is, `subroot**steps = 1`). Do the same with K.
3. Construct the polynomial `CP(x) = C(P(x), P(x * subroot), K(x))`. Note that since `P(x * subroot) = P(x)**3 + K(x)`, `CP(x) = 0` for any x inside the computation trace (that is, powers of subroot except the last).
4. Construct the polynomial `Z(x)`, which is the minimal polynomial that is 0 across the computation trace. Note that because Z is minimal, CP must be a multiple of Z.
5. Construct `D = CP / Z`. Because CP is a multiple of Z, D must itself be a low-degree polynomial.
6. Put D and P into Merkle trees.
7. Construct `L = D + k * x**steps * P`, where `k` is selected based on the Merkle roots of D and P.
8. Create an FRI proof that L is low-degree ((7) and (8) together are a clever way of making an aggregate low-degree proof of D and P).
9. Create probabilistic checks, using the Merkle root of L as source data, to spot-check that `D(x) * Z(x)` actually equals `C(P(x), P(x * subroot), K(x))`.
The probabilistic checks and the FRI proof are themselves restricted to `2**precision`'th roots of unity, where `precision = 8 * steps` (so we're FRI checking that the degree of L, which equals twice the degree of P, is at most 1/4 the theoretical maximum).
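Steps 1 to 3 can be sketched on a toy field. This is an illustration under stated assumptions, not the commit's code: the prime 257, the trace length 8, the generator 3 and the helpers `eval_poly`, `interpolate` and `cp` are all chosen here for readability, and the interpolation is a naive quadratic-time inverse DFT rather than the FFT the commit uses.

```python
MODULUS = 257   # toy prime (assumption): 256 = 2**8, and 3 does not divide 256
STEPS = 8       # trace length, must be a power of 2
# 3 generates the multiplicative group mod 257, so SUBROOT has order STEPS.
SUBROOT = pow(3, (MODULUS - 1) // STEPS, MODULUS)

def eval_poly(poly, x):
    """Evaluate a coefficient list (lowest degree first) with Horner's rule."""
    result = 0
    for coeff in reversed(poly):
        result = (result * x + coeff) % MODULUS
    return result

def interpolate(values):
    """Naive inverse DFT: returns p with p(SUBROOT**i) == values[i]."""
    n = len(values)
    inv_n = pow(n, MODULUS - 2, MODULUS)
    inv_root = pow(SUBROOT, MODULUS - 2, MODULUS)
    return [inv_n * sum(v * pow(inv_root, i * j, MODULUS)
                        for i, v in enumerate(values)) % MODULUS
            for j in range(n)]

# Step 1: the trace P[i+1] = P[i]**3 + K[i]; K is padded with a final 0,
# mirroring `constants.append(0)` in the commit.
K = [((9**i) ^ 1) % MODULUS for i in range(STEPS - 1)] + [0]
P = [3]
for i in range(STEPS - 1):
    P.append((P[-1]**3 + K[i]) % MODULUS)

# Step 2: polynomials that pass through the trace and the constants.
p_poly = interpolate(P)
k_poly = interpolate(K)

# Step 3: CP(x) = P(x * subroot) - P(x)**3 - K(x).
def cp(x):
    return (eval_poly(p_poly, x * SUBROOT % MODULUS)
            - eval_poly(p_poly, x)**3
            - eval_poly(k_poly, x)) % MODULUS

trace_xs = [pow(SUBROOT, i, MODULUS) for i in range(STEPS)]
cp_on_trace = [cp(x) for x in trace_xs]
```

`cp_on_trace` is zero at every trace point except possibly the last one, which is exactly the set of roots that `Z(x)` is built to have; that is why the division in step 5 leaves no remainder.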

View File

@ -0,0 +1,37 @@
def inv(a, modulus):
    if a == 0:
        return 0
    lm, hm = 1, 0
    low, high = a % modulus, modulus
    while low > 1:
        r = high//low
        nm, new = hm-lm*r, high-low*r
        lm, low, hm, high = nm, new, lm, low
    return lm % modulus

def eval_poly_at(poly, x, modulus):
    o, p = 0, 1
    for coeff in poly:
        o += coeff * p
        p *= x
    return o % modulus

def lagrange_interp_4(pieces, xs, modulus):
    x01, x02, x03, x12, x13, x23 = \
        xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
    eq0 = [-x12 * xs[3] % modulus, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
    eq1 = [-x02 * xs[3] % modulus, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
    eq2 = [-x01 * xs[3] % modulus, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
    eq3 = [-x01 * xs[2] % modulus, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
    e0 = eval_poly_at(eq0, xs[0], modulus)
    e1 = eval_poly_at(eq1, xs[1], modulus)
    e2 = eval_poly_at(eq2, xs[2], modulus)
    e3 = eval_poly_at(eq3, xs[3], modulus)
    e01 = e0 * e1
    e23 = e2 * e3
    invall = inv(e01 * e23, modulus)
    inv_y0 = pieces[0] * invall * e1 * e23 % modulus
    inv_y1 = pieces[1] * invall * e0 * e23 % modulus
    inv_y2 = pieces[2] * invall * e01 * e3 % modulus
    inv_y3 = pieces[3] * invall * e01 * e2 % modulus
    return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % modulus for i in range(4)]
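
`lagrange_interp_4` calls `inv` only once: it inverts the product of the four denominators and recovers each individual inverse by multiplying by the other three. A generalised sketch of that batch-inversion idea (often called Montgomery's trick) is below. The helper `batch_inverse` and the toy modulus 257 are assumptions for illustration and are not part of this commit.

```python
def batch_inverse(values, modulus):
    """Invert every entry of `values` using a single modular inversion.

    Hypothetical helper; assumes `modulus` is prime and that no entry is
    0 mod `modulus`.
    """
    # prefix[i] = values[0] * ... * values[i-1]  (prefix[0] is the empty product)
    prefix = [1]
    for v in values:
        prefix.append(prefix[-1] * v % modulus)
    # The only inversion: 1 / (product of all values), via Fermat's little theorem.
    inv_running = pow(prefix[-1], modulus - 2, modulus)
    inverses = [0] * len(values)
    for i in range(len(values) - 1, -1, -1):
        # Here inv_running == 1 / (values[0] * ... * values[i]).
        inverses[i] = inv_running * prefix[i] % modulus
        inv_running = inv_running * values[i] % modulus
    return inverses

MODULUS = 257  # toy prime, chosen only for this example
denominators = [3, 10, 85, 200]
inverses = batch_inverse(denominators, MODULUS)
```

The saving matters because a modular inversion costs far more than a multiplication, so trading three inversions for a handful of multiplications is a net win inside the FRI inner loop.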

View File

@ -1,6 +1,15 @@
def _simple_ft(vals, modulus, roots_of_unity):
L = len(roots_of_unity)
o = [0 for _ in range(L)]
for i in range(L):
for j in range(L):
o[i] += vals[j] * roots_of_unity[(i*j)%L]
return [x % modulus for x in o]
def _fft(vals, modulus, roots_of_unity): def _fft(vals, modulus, roots_of_unity):
if len(vals) == 1: if len(vals) == 1:
return vals return vals
# return _simple_ft(vals, modulus, roots_of_unity)
L = _fft(vals[::2], modulus, roots_of_unity[::2]) L = _fft(vals[::2], modulus, roots_of_unity[::2])
R = _fft(vals[1::2], modulus, roots_of_unity[::2]) R = _fft(vals[1::2], modulus, roots_of_unity[::2])
o = [0 for i in vals] o = [0 for i in vals]
@ -8,7 +17,6 @@ def _fft(vals, modulus, roots_of_unity):
y_times_root = y*roots_of_unity[i] y_times_root = y*roots_of_unity[i]
o[i] = (x+y_times_root) % modulus o[i] = (x+y_times_root) % modulus
o[i+len(L)] = (x-y_times_root) % modulus o[i+len(L)] = (x-y_times_root) % modulus
# print(vals, root_of_unity, o)
return o return o
def fft(vals, modulus, root_of_unity, inv=False): def fft(vals, modulus, root_of_unity, inv=False):
@ -22,10 +30,11 @@ def fft(vals, modulus, root_of_unity, inv=False):
if inv: if inv:
# Inverse FFT # Inverse FFT
invlen = pow(len(vals), modulus-2, modulus) invlen = pow(len(vals), modulus-2, modulus)
return [(x*invlen) % modulus for x in _fft(vals, modulus, rootz[::-1])] return [(x*invlen) % modulus for x in
_fft(vals, modulus, rootz[:0:-1])]
else: else:
# Regular FFT # Regular FFT
return _fft(vals, modulus, rootz) return _fft(vals, modulus, rootz[:-1])
def mul_polys(a, b, modulus, root_of_unity): def mul_polys(a, b, modulus, root_of_unity):
x1 = fft(a, modulus, root_of_unity) x1 = fft(a, modulus, root_of_unity)

View File

@ -1,8 +1,9 @@
from merkle_tree import merkelize, mk_branch, verify_branch, blake from merkle_tree import merkelize, mk_branch, verify_branch, blake
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
from ecpoly import PrimeField from ecpoly import PrimeField
from fft import fft, mul_polys from better_lagrange import lagrange_interp_4
import time import time
from fft import fft
modulus = 2**256 - 2**32 * 351 + 1 modulus = 2**256 - 2**32 * 351 + 1
f = PrimeField(modulus) f = PrimeField(modulus)
@ -14,18 +15,6 @@ quartic_roots_of_unity = [1,
spot_check_security_factor = 240 spot_check_security_factor = 240
# Treat a polynomial as a bivariate polynomial g(x, y) and
# evaluate it as such. Invariant: eval_as_bivariate(p, x, x**4) = eval(p, x)
def eval_as_bivariate(p, x, y):
o = 0
ypow = 1
xpows = [pow(x, i, modulus) for i in range(4)]
for i in range(0, len(p), 4):
for j in range(4):
o += xpows[j] * ypow * p[i+j]
ypow = (ypow * y) % modulus
return o % modulus
# Get the set of powers of R, until but not including when the powers # Get the set of powers of R, until but not including when the powers
# loop back to 1 # loop back to 1
def get_power_cycle(r): def get_power_cycle(r):
@ -68,9 +57,10 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
# directly, as this is more efficient # directly, as this is more efficient
column = [] column = []
for i in range(len(xs)//4): for i in range(len(xs)//4):
x_poly = f.lagrange_interp( x_poly = lagrange_interp_4(
[values[i+len(values)*j//4] for j in range(4)], [values[i+len(values)*j//4] for j in range(4)],
[xs[i+len(xs)*j//4] for j in range(4)] [xs[i+len(xs)*j//4] for j in range(4)],
modulus
) )
column.append(f.eval_poly_at(x_poly, special_x)) column.append(f.eval_poly_at(x_poly, special_x))
m2 = merkelize(column) m2 = merkelize(column)
@ -119,12 +109,12 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
# Verify for each selected y coordinate that the four points from the polynomial # Verify for each selected y coordinate that the four points from the polynomial
# and the one point from the column that are on that y coordinate are on a # and the one point from the column that are on that y coordinate are on the same
# deg < 4 polynomial # deg < 4 polynomial
for i, y in enumerate(ys): for i, y in enumerate(ys):
# The five x coordinates we are checking # The x coordinates from the polynomial
x1 = pow(root_of_unity, y, modulus) x1 = pow(root_of_unity, y, modulus)
eckses = [special_x] + [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)] xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)]
# The values from the polynomial # The values from the polynomial
row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])] row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])]
@ -133,8 +123,8 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
values = [verify_branch(root2, y, branches[i][0])] + row values = [verify_branch(root2, y, branches[i][0])] + row
# Lagrange interpolate and check deg is < 4 # Lagrange interpolate and check deg is < 4
p = f.lagrange_interp(values, eckses) p = lagrange_interp_4(row, xcoords, modulus)
assert p[4] == 0 assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0])
# Update constants to check the next proof # Update constants to check the next proof
merkle_root = root2 merkle_root = root2
@ -175,9 +165,11 @@ def mimc(inp, logsteps, logprecision):
precision = 2**logprecision precision = 2**logprecision
# Get (steps)th root of unity # Get (steps)th root of unity
subroot = pow(7, (modulus-1)//steps, modulus) subroot = pow(7, (modulus-1)//steps, modulus)
xs = get_power_cycle(subroot) # We use powers of 9 mod 2^256 as the ith round constant for the moment
k = 1
for i in range(steps-1): for i in range(steps-1):
inp = (inp**3 + xs[i]) % modulus inp = (inp**3 + (k ^ 1)) % modulus
k = (k * 9) & ((1 << 256) - 1)
print("MIMC computed in %.4f sec" % (time.time() - start_time)) print("MIMC computed in %.4f sec" % (time.time() - start_time))
return inp return inp
@ -215,27 +207,30 @@ def mk_mimc_proof(inp, logsteps, logprecision):
xs = get_power_cycle(subroot) xs = get_power_cycle(subroot)
# Generate the computational trace # Generate the computational trace
constants = []
values = [inp] values = [inp]
k = 1
for i in range(steps-1): for i in range(steps-1):
values.append((values[-1]**3 + xs[i]) % modulus) values.append((values[-1]**3 + (k ^ 1)) % modulus)
constants.append(k ^ 1)
k = (k * 9) & ((1 << 256) - 1)
constants.append(0)
print('Done generating computational trace') print('Done generating computational trace')
# Interpolate the computational trace into a polynomial # Interpolate the computational trace into a polynomial
# values_polynomial = f.lagrange_interp(values, [pow(subroot, i, modulus) for i in range(steps)])
values_polynomial = fft(values, modulus, subroot, inv=True) values_polynomial = fft(values, modulus, subroot, inv=True)
print('Computed polynomial') constants_polynomial = fft(constants, modulus, subroot, inv=True)
print('Converted computational steps and constants into a polynomial')
#for x, v in zip(xs, values):
# assert f.eval_poly_at(values_polynomial, x) == v
# Create the composed polynomial such that # Create the composed polynomial such that
# C(P(x), P(rx)) = P(rx) - P(x)**3 - x # C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x)
term1 = multiply_base(values_polynomial, subroot) term1 = multiply_base(values_polynomial, subroot)
term2 = fft([pow(x, 3, modulus) for x in fft(values_polynomial, modulus, root)], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2] p_evaluations = fft(values_polynomial, modulus, root)
c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1]) term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2]
print('Computed C(P) polynomial') c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial)
print('Computed C(P, K) polynomial')
# Compute D(x) = C(P(x)) / Z(x) # Compute D(x) = C(P(x), P(rx), K(x)) / Z(x)
# Z(x) = (x^steps - 1) / (x - x_atlast_step) # Z(x) = (x^steps - 1) / (x - x_atlast_step)
d = divide_by_xnm1(f.mul_polys(c_of_values, d = divide_by_xnm1(f.mul_polys(c_of_values,
[modulus-xs[steps-1], 1]), [modulus-xs[steps-1], 1]),
@ -243,42 +238,55 @@ def mk_mimc_proof(inp, logsteps, logprecision):
# assert f.mul_polys(d, z) == c_of_values # assert f.mul_polys(d, z) == c_of_values
print('Computed D polynomial') print('Computed D polynomial')
# Evaluate P and D across the entire subgroup # Evaluate D and K across the entire subgroup
p_evaluations = fft(values_polynomial, modulus, root)
d_evaluations = fft(d, modulus, root) d_evaluations = fft(d, modulus, root)
print('Evaluated P and D') k_evaluations = fft(constants_polynomial, modulus, root)
print('Evaluated P, D and K')
# Compute their Merkle roots # Compute their Merkle roots
p_mtree = merkelize(p_evaluations) p_mtree = merkelize(p_evaluations)
d_mtree = merkelize(d_evaluations) d_mtree = merkelize(d_evaluations)
k_mtree = merkelize(k_evaluations)
print('Computed hash root') print('Computed hash root')
# Based on the hashes of P and D, we select a random linear combination
# of P * x^steps and D, and prove the low-degreeness of that, instead of proving
# the low-degreeness of P and D separately
k = int.from_bytes(blake(p_mtree[1] + d_mtree[1]), 'big')
lincomb = f.add_polys(d, f.mul_by_const([0] * steps + values_polynomial, k))
l_evaluations = fft(lincomb, modulus, root)
l_mtree = merkelize(l_evaluations)
print('Computed random linear combination')
# Do some spot checks of the Merkle tree at pseudo-random coordinates # Do some spot checks of the Merkle tree at pseudo-random coordinates
branches = [] branches = []
samples = spot_check_security_factor // (logprecision - logsteps) samples = spot_check_security_factor // (logprecision - logsteps)
positions = get_indices(blake(p_mtree[1] + d_mtree[1]), precision - skips, samples) positions = get_indices(l_mtree[1], precision - skips, samples)
for pos in positions: for pos in positions:
branches.append(mk_branch(p_mtree, pos)) branches.append(mk_branch(p_mtree, pos))
branches.append(mk_branch(p_mtree, pos + skips)) branches.append(mk_branch(p_mtree, pos + skips))
branches.append(mk_branch(d_mtree, pos)) branches.append(mk_branch(d_mtree, pos))
branches.append(mk_branch(k_mtree, pos))
branches.append(mk_branch(l_mtree, pos))
print('Computed %d spot checks' % samples) print('Computed %d spot checks' % samples)
while len(d) < steps * 2:
d += [0]
# Return the Merkle roots of P and D, the spot check Merkle proofs, # Return the Merkle roots of P and D, the spot check Merkle proofs,
# and low-degree proofs of P and D # and low-degree proofs of P and D
o = [p_mtree[1], o = [p_mtree[1],
d_mtree[1], d_mtree[1],
k_mtree[1],
l_mtree[1],
branches, branches,
prove_low_degree(values_polynomial, root, p_evaluations, steps), prove_low_degree(lincomb, root, l_evaluations, steps * 2)]
prove_low_degree(d, root, d_evaluations, steps * 2)]
print("STARK computed in %.4f sec" % (time.time() - start_time)) print("STARK computed in %.4f sec" % (time.time() - start_time))
return o return o
# Verifies a STARK # Verifies a STARK
def verify_mimc_proof(inp, logsteps, logprecision, output, proof): def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
p_root, d_root, branches, p_proof, d_proof = proof p_root, d_root, k_root, l_root, branches, fri_proof = proof
start_time = time.time() start_time = time.time()
steps = 2**logsteps steps = 2**logsteps
@ -289,37 +297,42 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
skips = precision // steps skips = precision // steps
# Verifies the low-degree proofs # Verifies the low-degree proofs
assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps) assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2)
assert verify_low_degree_proof(d_root, root_of_unity, d_proof, steps * 2)
# Performs the spot checks # Performs the spot checks
k = int.from_bytes(blake(p_root + d_root), 'big')
samples = spot_check_security_factor // (logprecision - logsteps) samples = spot_check_security_factor // (logprecision - logsteps)
positions = get_indices(blake(p_root + d_root), precision - skips, samples) positions = get_indices(l_root, precision - skips, samples)
for i, pos in enumerate(positions): for i, pos in enumerate(positions):
# Check C(P(x)) = Z(x) * D(x) # Check C(P(x)) = Z(x) * D(x)
x = pow(root_of_unity, pos, modulus) x = pow(root_of_unity, pos, modulus)
p_of_x = verify_branch(p_root, pos, branches[i*3]) p_of_x = verify_branch(p_root, pos, branches[i*5])
p_of_rx = verify_branch(p_root, pos+skips, branches[i*3 + 1]) p_of_rx = verify_branch(p_root, pos+skips, branches[i*5 + 1])
d_of_x = verify_branch(d_root, pos, branches[i*3 + 2]) d_of_x = verify_branch(d_root, pos, branches[i*5 + 2])
k_of_x = verify_branch(k_root, pos, branches[i*5 + 3])
l_of_x = verify_branch(l_root, pos, branches[i*5 + 4])
zvalue = f.div(pow(x, steps, modulus) - 1, zvalue = f.div(pow(x, steps, modulus) - 1,
x - pow(root_of_unity, (steps - 1) * skips, modulus)) x - pow(root_of_unity, (steps - 1) * skips, modulus))
assert (p_of_rx - p_of_x ** 3 - x - zvalue * d_of_x) % modulus == 0 assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0
assert (l_of_x - d_of_x - k * p_of_x * pow(x, steps, modulus)) % modulus == 0
print('Verified %d consistency checks' % (spot_check_security_factor // (logprecision - logsteps))) print('Verified %d consistency checks' % (spot_check_security_factor // (logprecision - logsteps)))
print('Verified STARK in %.4f sec' % (time.time() - start_time)) print('Verified STARK in %.4f sec' % (time.time() - start_time))
print('Note: this does not include verifying the Merkle root of the constants tree')
print('This can be done by every client once as a precomputation')
return True return True
INPUT = 3 INPUT = 3
LOGSTEPS = 13 LOGSTEPS = 17
LOGPRECISION = 16 LOGPRECISION = 20
# Full STARK test # Full STARK test
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION) proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
L1 = bin_length(compress_branches(proof[2])) p_root, d_root, k_root, l_root, branches, fri_proof = proof
L2 = bin_length(compress_fri(proof[3])) L1 = bin_length(compress_branches(branches))
L3 = bin_length(compress_fri(proof[4])) L2 = bin_length(compress_fri(fri_proof))
print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3)) print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2))
root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus) root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus)
subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus) subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus)
skips = 2**(LOGPRECISION - LOGSTEPS) skips = 2**(LOGPRECISION - LOGSTEPS)