mirror of
https://github.com/status-im/research.git
synced 2025-01-27 07:15:14 +00:00
More efficiency
This commit is contained in:
parent
3180447866
commit
59b8020fec
@ -1,17 +1,21 @@
|
||||
#### Disclaimer
|
||||
|
||||
DO NOT USE FOR ANYTHING IN REAL LIFE. DO NOT ASSUME THE PROTOCOL DESCRIBED HERE IS SOUND. TALK TO A SPECIALIST IF YOU'RE LOOKING TO USE STARKS IN YOUR APPLICATION.
|
||||
|
||||
#### What is this?
|
||||
|
||||
This is a very basic implementation of a STARK on a MIMC computation that is probably (ie. definitely) broken in a few places but is intended as a proof of concept to show the rough level of complexity that is involved in implementing a simple STARK. A STARK is a really cool proof-of-computation scheme that allows you to create an efficiently verifiable proof that some computation was executed correctly; the verification time only rises logarithmically with the computation time, and that relies only on hashes and information theory for security.
|
||||
|
||||
The STARKs are done over a finite field chosen to have 2^32'th roots of unity (to facilitate STARKs), and NOT have 3rd roots of unity (to facilitate MIMC). The MIMC permutation in general takes the form:
|
||||
The STARKs are done over a finite field chosen to have `2**32`'th roots of unity (to facilitate STARKs), and NOT have 3rd roots of unity (to facilitate MIMC). The MIMC permutation in general takes the form:
|
||||
|
||||
k_1 k_2
|
||||
| |
|
||||
v v
|
||||
x0 ---> (x->x^3) ---> xor ---> (x->x^3) ---> xor --- ... ---> output
|
||||
k_1 k_2
|
||||
| |
|
||||
v v
|
||||
x0 ---> (x->x**3) ---> xor ---> (x->x**3) ---> xor --- ... ---> output
|
||||
|
||||
Where the `k_i` values are round constants. MIMC can be used as a building block in a hash function, or as a verifiable delay function; its simple arithmetic representation makes it ideal for use in STARKs, SNARKs, MPC and other "cryptography over general-purpose computation" schemes.
|
||||
|
||||
The MIMC round constants used here are 2^k'th roots of unity for k <= 32. The computational trace is computed over successive powers of a 2^k'th root of unity. This allows the constraint checking polynomial to be `C(P(x), P(r*x), x) = P(r*x) - P(x)**3 - x`.
|
||||
The MIMC round constants used here are successive powers of 9 mod `2**256` xored with 1, though could be set to anything. The computational trace is computed over successive powers of a `2**k`'th root of unity. This allows the constraint checking polynomial to be `C(P(x), P(r*x), K(x)) = P(r*x) - P(x)**3 - K(x)`.
|
||||
|
||||
For a description of how STARKs work, see:
|
||||
|
||||
@ -22,4 +26,18 @@ For more discussion on MIMC, see:
|
||||
|
||||
* [Zcash issue 2233](https://github.com/zcash/zcash/issues/2233)
|
||||
|
||||
The implementation is definitely suboptimal. For example, one optimization would be to reduce the two FRIs to one by low-degree proving a random linear combination k1 * D + k2 * P^2; this could cut the verification time by close to half. Another would be to remove redundancies in computing zero polynomials in multiple places; a more sophisticated optimization would be to move to subquadratic algorithms (Karatsuba or ideally FFT-based) for polynomial interpolation. The purpose is for the code to be deliberately simple to aid with understanding.
|
||||
### How does the STARK scheme here work?
|
||||
|
||||
Here are the approximate steps in the code. Note that all arithmetic is done over the finite field of integers mod `2**256 - 2**32 * 351 + 1`.
|
||||
|
||||
1. Let `P[0]` equal the input, and `P[i+1] = P[i]**3 + K[i]`, up to `steps` (where `steps` must be a power of 2)
|
||||
2. Construct a polynomial where `P(subroot ^ i) = P[i]` up to steps-1, where `subroot` is a steps'th root of unity (that is, `subroot**steps = 1`). Do the same with K.
|
||||
3. Construct the polynomial `CP(x) = C(P(x), P(x * subroot), K(x))`. Note that since `P(x * subroot) = P(x) ^ 3 + K(x), CP(x) = 0` for any x inside the computation trace (that is, powers of subroot except the last).
|
||||
4. Construct the polynomial `Z(x)`, which is the minimal polynomial that is 0 across the computation trace. Note that because Z is minimal, CP must be a multiple of Z.
|
||||
5. Construct `D = CP / Z`. Because CP is a multiple of Z, D must itself be a low-degree polynomial.
|
||||
6. Put D and P into Merkle trees.
|
||||
7. Construct `L = D + k * x**steps * P`, where `k` is selected based on the Merkle roots of D and P.
|
||||
8. Create an FRI proof that L is low-degree ((7) and (8) together are a clever way of making an aggregate low-degree proof of D and P)
|
||||
9. Create probabilistic checks, using the Merkle root of L as source data, to spot-check that D(x) * Z(x) actually equal C(P(x)).
|
||||
|
||||
The probabilistic checks and the FRI proof are themselves restricted to `2**precision`'th roots of unity, where `precision = 8 * steps` (so we're FRI checking that the degree of L, which equals twice the degree of P, is at most 1/4 the theoretical maximum).
|
||||
|
37
mimc_stark/better_lagrange.py
Normal file
37
mimc_stark/better_lagrange.py
Normal file
@ -0,0 +1,37 @@
|
||||
def inv(a, modulus):
|
||||
if a == 0:
|
||||
return 0
|
||||
lm, hm = 1, 0
|
||||
low, high = a % modulus, modulus
|
||||
while low > 1:
|
||||
r = high//low
|
||||
nm, new = hm-lm*r, high-low*r
|
||||
lm, low, hm, high = nm, new, lm, low
|
||||
return lm % modulus
|
||||
|
||||
def eval_poly_at(poly, x, modulus):
|
||||
o, p = 0, 1
|
||||
for coeff in poly:
|
||||
o += coeff * p
|
||||
p *= x
|
||||
return o % modulus
|
||||
|
||||
def lagrange_interp_4(pieces, xs, modulus):
|
||||
x01, x02, x03, x12, x13, x23 = \
|
||||
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
|
||||
eq0 = [-x12 * xs[3] % modulus, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
|
||||
eq1 = [-x02 * xs[3] % modulus, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
|
||||
eq2 = [-x01 * xs[3] % modulus, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
|
||||
eq3 = [-x01 * xs[2] % modulus, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
|
||||
e0 = eval_poly_at(eq0, xs[0], modulus)
|
||||
e1 = eval_poly_at(eq1, xs[1], modulus)
|
||||
e2 = eval_poly_at(eq2, xs[2], modulus)
|
||||
e3 = eval_poly_at(eq3, xs[3], modulus)
|
||||
e01 = e0 * e1
|
||||
e23 = e2 * e3
|
||||
invall = inv(e01 * e23, modulus)
|
||||
inv_y0 = pieces[0] * invall * e1 * e23 % modulus
|
||||
inv_y1 = pieces[1] * invall * e0 * e23 % modulus
|
||||
inv_y2 = pieces[2] * invall * e01 * e3 % modulus
|
||||
inv_y3 = pieces[3] * invall * e01 * e2 % modulus
|
||||
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % modulus for i in range(4)]
|
@ -1,6 +1,15 @@
|
||||
def _simple_ft(vals, modulus, roots_of_unity):
|
||||
L = len(roots_of_unity)
|
||||
o = [0 for _ in range(L)]
|
||||
for i in range(L):
|
||||
for j in range(L):
|
||||
o[i] += vals[j] * roots_of_unity[(i*j)%L]
|
||||
return [x % modulus for x in o]
|
||||
|
||||
def _fft(vals, modulus, roots_of_unity):
|
||||
if len(vals) == 1:
|
||||
return vals
|
||||
# return _simple_ft(vals, modulus, roots_of_unity)
|
||||
L = _fft(vals[::2], modulus, roots_of_unity[::2])
|
||||
R = _fft(vals[1::2], modulus, roots_of_unity[::2])
|
||||
o = [0 for i in vals]
|
||||
@ -8,7 +17,6 @@ def _fft(vals, modulus, roots_of_unity):
|
||||
y_times_root = y*roots_of_unity[i]
|
||||
o[i] = (x+y_times_root) % modulus
|
||||
o[i+len(L)] = (x-y_times_root) % modulus
|
||||
# print(vals, root_of_unity, o)
|
||||
return o
|
||||
|
||||
def fft(vals, modulus, root_of_unity, inv=False):
|
||||
@ -22,10 +30,11 @@ def fft(vals, modulus, root_of_unity, inv=False):
|
||||
if inv:
|
||||
# Inverse FFT
|
||||
invlen = pow(len(vals), modulus-2, modulus)
|
||||
return [(x*invlen) % modulus for x in _fft(vals, modulus, rootz[::-1])]
|
||||
return [(x*invlen) % modulus for x in
|
||||
_fft(vals, modulus, rootz[:0:-1])]
|
||||
else:
|
||||
# Regular FFT
|
||||
return _fft(vals, modulus, rootz)
|
||||
return _fft(vals, modulus, rootz[:-1])
|
||||
|
||||
def mul_polys(a, b, modulus, root_of_unity):
|
||||
x1 = fft(a, modulus, root_of_unity)
|
||||
|
@ -1,8 +1,9 @@
|
||||
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
||||
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
||||
from ecpoly import PrimeField
|
||||
from fft import fft, mul_polys
|
||||
from better_lagrange import lagrange_interp_4
|
||||
import time
|
||||
from fft import fft
|
||||
|
||||
modulus = 2**256 - 2**32 * 351 + 1
|
||||
f = PrimeField(modulus)
|
||||
@ -14,18 +15,6 @@ quartic_roots_of_unity = [1,
|
||||
|
||||
spot_check_security_factor = 240
|
||||
|
||||
# Treat a polynomial as a bivariate polynomial g(x, y) and
|
||||
# evaluate it as such. Invariant: eval_as_bivariate(p, x, x**4) = eval(p, x)
|
||||
def eval_as_bivariate(p, x, y):
|
||||
o = 0
|
||||
ypow = 1
|
||||
xpows = [pow(x, i, modulus) for i in range(4)]
|
||||
for i in range(0, len(p), 4):
|
||||
for j in range(4):
|
||||
o += xpows[j] * ypow * p[i+j]
|
||||
ypow = (ypow * y) % modulus
|
||||
return o % modulus
|
||||
|
||||
# Get the set of powers of R, until but not including when the powers
|
||||
# loop back to 1
|
||||
def get_power_cycle(r):
|
||||
@ -68,9 +57,10 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1):
|
||||
# directly, as this is more efficient
|
||||
column = []
|
||||
for i in range(len(xs)//4):
|
||||
x_poly = f.lagrange_interp(
|
||||
x_poly = lagrange_interp_4(
|
||||
[values[i+len(values)*j//4] for j in range(4)],
|
||||
[xs[i+len(xs)*j//4] for j in range(4)]
|
||||
[xs[i+len(xs)*j//4] for j in range(4)],
|
||||
modulus
|
||||
)
|
||||
column.append(f.eval_poly_at(x_poly, special_x))
|
||||
m2 = merkelize(column)
|
||||
@ -119,12 +109,12 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
|
||||
|
||||
|
||||
# Verify for each selected y coordinate that the four points from the polynomial
|
||||
# and the one point from the column that are on that y coordinate are on a
|
||||
# and the one point from the column that are on that y coordinate are on the same
|
||||
# deg < 4 polynomial
|
||||
for i, y in enumerate(ys):
|
||||
# The five x coordinates we are checking
|
||||
# The x coordinates from the polynomial
|
||||
x1 = pow(root_of_unity, y, modulus)
|
||||
eckses = [special_x] + [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)]
|
||||
xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)]
|
||||
|
||||
# The values from the polynomial
|
||||
row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])]
|
||||
@ -133,8 +123,8 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1):
|
||||
values = [verify_branch(root2, y, branches[i][0])] + row
|
||||
|
||||
# Lagrange interpolate and check deg is < 4
|
||||
p = f.lagrange_interp(values, eckses)
|
||||
assert p[4] == 0
|
||||
p = lagrange_interp_4(row, xcoords, modulus)
|
||||
assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0])
|
||||
|
||||
# Update constants to check the next proof
|
||||
merkle_root = root2
|
||||
@ -175,9 +165,11 @@ def mimc(inp, logsteps, logprecision):
|
||||
precision = 2**logprecision
|
||||
# Get (steps)th root of unity
|
||||
subroot = pow(7, (modulus-1)//steps, modulus)
|
||||
xs = get_power_cycle(subroot)
|
||||
# We use powers of 9 mod 2^256 as the ith round constant for the moment
|
||||
k = 1
|
||||
for i in range(steps-1):
|
||||
inp = (inp**3 + xs[i]) % modulus
|
||||
inp = (inp**3 + (k ^ 1)) % modulus
|
||||
k = (k * 9) & ((1 << 256) - 1)
|
||||
print("MIMC computed in %.4f sec" % (time.time() - start_time))
|
||||
return inp
|
||||
|
||||
@ -215,27 +207,30 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
||||
xs = get_power_cycle(subroot)
|
||||
|
||||
# Generate the computational trace
|
||||
constants = []
|
||||
values = [inp]
|
||||
k = 1
|
||||
for i in range(steps-1):
|
||||
values.append((values[-1]**3 + xs[i]) % modulus)
|
||||
values.append((values[-1]**3 + (k ^ 1)) % modulus)
|
||||
constants.append(k ^ 1)
|
||||
k = (k * 9) & ((1 << 256) - 1)
|
||||
constants.append(0)
|
||||
print('Done generating computational trace')
|
||||
|
||||
# Interpolate the computational trace into a polynomial
|
||||
# values_polynomial = f.lagrange_interp(values, [pow(subroot, i, modulus) for i in range(steps)])
|
||||
values_polynomial = fft(values, modulus, subroot, inv=True)
|
||||
print('Computed polynomial')
|
||||
|
||||
#for x, v in zip(xs, values):
|
||||
# assert f.eval_poly_at(values_polynomial, x) == v
|
||||
constants_polynomial = fft(constants, modulus, subroot, inv=True)
|
||||
print('Converted computational steps and constants into a polynomial')
|
||||
|
||||
# Create the composed polynomial such that
|
||||
# C(P(x), P(rx)) = P(rx) - P(x)**3 - x
|
||||
# C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x)
|
||||
term1 = multiply_base(values_polynomial, subroot)
|
||||
term2 = fft([pow(x, 3, modulus) for x in fft(values_polynomial, modulus, root)], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2]
|
||||
c_of_values = f.sub_polys(f.sub_polys(term1, term2), [0, 1])
|
||||
print('Computed C(P) polynomial')
|
||||
p_evaluations = fft(values_polynomial, modulus, root)
|
||||
term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root, inv=True)[:len(values_polynomial) * 3 - 2]
|
||||
c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial)
|
||||
print('Computed C(P, K) polynomial')
|
||||
|
||||
# Compute D(x) = C(P(x)) / Z(x)
|
||||
# Compute D(x) = C(P(x), P(rx), K(x)) / Z(x)
|
||||
# Z(x) = (x^steps - 1) / (x - x_atlast_step)
|
||||
d = divide_by_xnm1(f.mul_polys(c_of_values,
|
||||
[modulus-xs[steps-1], 1]),
|
||||
@ -243,42 +238,55 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
||||
# assert f.mul_polys(d, z) == c_of_values
|
||||
print('Computed D polynomial')
|
||||
|
||||
# Evaluate P and D across the entire subgroup
|
||||
p_evaluations = fft(values_polynomial, modulus, root)
|
||||
# Evaluate D and K across the entire subgroup
|
||||
d_evaluations = fft(d, modulus, root)
|
||||
print('Evaluated P and D')
|
||||
k_evaluations = fft(constants_polynomial, modulus, root)
|
||||
print('Evaluated P, D and K')
|
||||
|
||||
# Compute their Merkle roots
|
||||
p_mtree = merkelize(p_evaluations)
|
||||
d_mtree = merkelize(d_evaluations)
|
||||
k_mtree = merkelize(k_evaluations)
|
||||
print('Computed hash root')
|
||||
|
||||
# Based on the hashes of P and D, we select a random linear combination
|
||||
# of P * x^steps and D, and prove the low-degreeness of that, instead of proving
|
||||
# the low-degreeness of P and D separately
|
||||
k = int.from_bytes(blake(p_mtree[1] + d_mtree[1]), 'big')
|
||||
|
||||
lincomb = f.add_polys(d, f.mul_by_const([0] * steps + values_polynomial, k))
|
||||
l_evaluations = fft(lincomb, modulus, root)
|
||||
l_mtree = merkelize(l_evaluations)
|
||||
|
||||
print('Computed random linear combination')
|
||||
|
||||
# Do some spot checks of the Merkle tree at pseudo-random coordinates
|
||||
branches = []
|
||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||
positions = get_indices(blake(p_mtree[1] + d_mtree[1]), precision - skips, samples)
|
||||
positions = get_indices(l_mtree[1], precision - skips, samples)
|
||||
for pos in positions:
|
||||
branches.append(mk_branch(p_mtree, pos))
|
||||
branches.append(mk_branch(p_mtree, pos + skips))
|
||||
branches.append(mk_branch(d_mtree, pos))
|
||||
branches.append(mk_branch(k_mtree, pos))
|
||||
branches.append(mk_branch(l_mtree, pos))
|
||||
print('Computed %d spot checks' % samples)
|
||||
|
||||
while len(d) < steps * 2:
|
||||
d += [0]
|
||||
|
||||
# Return the Merkle roots of P and D, the spot check Merkle proofs,
|
||||
# and low-degree proofs of P and D
|
||||
o = [p_mtree[1],
|
||||
d_mtree[1],
|
||||
k_mtree[1],
|
||||
l_mtree[1],
|
||||
branches,
|
||||
prove_low_degree(values_polynomial, root, p_evaluations, steps),
|
||||
prove_low_degree(d, root, d_evaluations, steps * 2)]
|
||||
prove_low_degree(lincomb, root, l_evaluations, steps * 2)]
|
||||
print("STARK computed in %.4f sec" % (time.time() - start_time))
|
||||
return o
|
||||
|
||||
# Verifies a STARK
|
||||
def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
|
||||
p_root, d_root, branches, p_proof, d_proof = proof
|
||||
p_root, d_root, k_root, l_root, branches, fri_proof = proof
|
||||
start_time = time.time()
|
||||
|
||||
steps = 2**logsteps
|
||||
@ -289,37 +297,42 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
|
||||
skips = precision // steps
|
||||
|
||||
# Verifies the low-degree proofs
|
||||
assert verify_low_degree_proof(p_root, root_of_unity, p_proof, steps)
|
||||
assert verify_low_degree_proof(d_root, root_of_unity, d_proof, steps * 2)
|
||||
assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2)
|
||||
|
||||
# Performs the spot checks
|
||||
k = int.from_bytes(blake(p_root + d_root), 'big')
|
||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||
positions = get_indices(blake(p_root + d_root), precision - skips, samples)
|
||||
positions = get_indices(l_root, precision - skips, samples)
|
||||
for i, pos in enumerate(positions):
|
||||
|
||||
# Check C(P(x)) = Z(x) * D(x)
|
||||
x = pow(root_of_unity, pos, modulus)
|
||||
p_of_x = verify_branch(p_root, pos, branches[i*3])
|
||||
p_of_rx = verify_branch(p_root, pos+skips, branches[i*3 + 1])
|
||||
d_of_x = verify_branch(d_root, pos, branches[i*3 + 2])
|
||||
p_of_x = verify_branch(p_root, pos, branches[i*5])
|
||||
p_of_rx = verify_branch(p_root, pos+skips, branches[i*5 + 1])
|
||||
d_of_x = verify_branch(d_root, pos, branches[i*5 + 2])
|
||||
k_of_x = verify_branch(k_root, pos, branches[i*5 + 3])
|
||||
l_of_x = verify_branch(l_root, pos, branches[i*5 + 4])
|
||||
zvalue = f.div(pow(x, steps, modulus) - 1,
|
||||
x - pow(root_of_unity, (steps - 1) * skips, modulus))
|
||||
assert (p_of_rx - p_of_x ** 3 - x - zvalue * d_of_x) % modulus == 0
|
||||
assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0
|
||||
assert (l_of_x - d_of_x - k * p_of_x * pow(x, steps, modulus)) % modulus == 0
|
||||
|
||||
print('Verified %d consistency checks' % (spot_check_security_factor // (logprecision - logsteps)))
|
||||
print('Verified STARK in %.4f sec' % (time.time() - start_time))
|
||||
print('Note: this does not include verifying the Merkle root of the constants tree')
|
||||
print('This can be done by every client once as a precomputation')
|
||||
return True
|
||||
|
||||
INPUT = 3
|
||||
LOGSTEPS = 13
|
||||
LOGPRECISION = 16
|
||||
LOGSTEPS = 17
|
||||
LOGPRECISION = 20
|
||||
|
||||
# Full STARK test
|
||||
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
|
||||
L1 = bin_length(compress_branches(proof[2]))
|
||||
L2 = bin_length(compress_fri(proof[3]))
|
||||
L3 = bin_length(compress_fri(proof[4]))
|
||||
print("Approx proof length: %d (branches), %d (FRI proof 1), %d (FRI proof 2), %d (total)" % (L1, L2, L3, L1 + L2 + L3))
|
||||
p_root, d_root, k_root, l_root, branches, fri_proof = proof
|
||||
L1 = bin_length(compress_branches(branches))
|
||||
L2 = bin_length(compress_fri(fri_proof))
|
||||
print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2))
|
||||
root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus)
|
||||
subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus)
|
||||
skips = 2**(LOGPRECISION - LOGSTEPS)
|
||||
|
Loading…
x
Reference in New Issue
Block a user