From da1d7237800b37bce054c0bdd310de2abb1dcc28 Mon Sep 17 00:00:00 2001 From: Vitalik Buterin Date: Tue, 10 Jul 2018 15:45:12 -0400 Subject: [PATCH] Improved efficiency a bunch and added boundary checks --- mimc_stark/better_lagrange.py | 29 --- mimc_stark/ecpoly/LICENSE | 0 mimc_stark/ecpoly/README.md | 0 mimc_stark/ecpoly/ecpoly/__init__.py | 1 - .../ecpoly/ecpoly/subquadratic_poly_utils.py | 198 ------------------ mimc_stark/ecpoly/setup.py | 24 --- mimc_stark/ecpoly/tests/test_basic_ops.py | 15 -- mimc_stark/fri.py | 65 +++--- mimc_stark/mimc_stark.py | 100 +++++---- mimc_stark/{ecpoly/ecpoly => }/poly_utils.py | 89 ++++++-- mimc_stark/test.py | 33 ++- 11 files changed, 173 insertions(+), 381 deletions(-) delete mode 100644 mimc_stark/ecpoly/LICENSE delete mode 100644 mimc_stark/ecpoly/README.md delete mode 100644 mimc_stark/ecpoly/ecpoly/__init__.py delete mode 100644 mimc_stark/ecpoly/ecpoly/subquadratic_poly_utils.py delete mode 100644 mimc_stark/ecpoly/setup.py delete mode 100644 mimc_stark/ecpoly/tests/test_basic_ops.py rename mimc_stark/{ecpoly/ecpoly => }/poly_utils.py (56%) diff --git a/mimc_stark/better_lagrange.py b/mimc_stark/better_lagrange.py index fa2dd16..bbd823e 100644 --- a/mimc_stark/better_lagrange.py +++ b/mimc_stark/better_lagrange.py @@ -16,32 +16,3 @@ def eval_poly_at(poly, x, modulus): p = (p * x % modulus) return o % modulus -def lagrange_interp_4(pieces, xs, modulus): - x01, x02, x03, x12, x13, x23 = \ - xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3] - eq0 = [-x12 * xs[3] % modulus, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1] - eq1 = [-x02 * xs[3] % modulus, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1] - eq2 = [-x01 * xs[3] % modulus, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1] - eq3 = [-x01 * xs[2] % modulus, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1] - e0 = eval_poly_at(eq0, xs[0], modulus) - e1 = eval_poly_at(eq1, xs[1], modulus) - e2 = eval_poly_at(eq2, xs[2], modulus) - e3 = eval_poly_at(eq3, xs[3], modulus) - e01 = e0 * e1 - e23 = e2 * e3 - invall = inv(e01 * e23, modulus) - inv_y0 = pieces[0] * invall * e1 * e23 % modulus - inv_y1 = pieces[1] * invall * e0 * e23 % modulus - inv_y2 = pieces[2] * invall * e01 * e3 % modulus - inv_y3 = pieces[3] * invall * e01 * e2 % modulus - return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % modulus for i in range(4)] - -def lagrange_interp_2(pieces, xs, modulus): - eq0 = [-xs[1] % modulus, 1] - eq1 = [-xs[0] % modulus, 1] - e0 = eval_poly_at(eq0, xs[0], modulus) - e1 = eval_poly_at(eq1, xs[1], modulus) - invall = inv(e0 * e1, modulus) - inv_y0 = pieces[0] * invall * e1 - inv_y1 = pieces[1] * invall * e0 - return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % modulus for i in range(2)] diff --git a/mimc_stark/ecpoly/LICENSE b/mimc_stark/ecpoly/LICENSE deleted file mode 100644 index e69de29..0000000 diff --git a/mimc_stark/ecpoly/README.md b/mimc_stark/ecpoly/README.md deleted file mode 100644 index e69de29..0000000 diff --git a/mimc_stark/ecpoly/ecpoly/__init__.py b/mimc_stark/ecpoly/ecpoly/__init__.py deleted file mode 100644 index 99729e5..0000000 --- a/mimc_stark/ecpoly/ecpoly/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .poly_utils import PrimeField diff --git a/mimc_stark/ecpoly/ecpoly/subquadratic_poly_utils.py b/mimc_stark/ecpoly/ecpoly/subquadratic_poly_utils.py deleted file mode 100644 index c069678..0000000 --- a/mimc_stark/ecpoly/ecpoly/subquadratic_poly_utils.py +++ /dev/null @@ -1,198 +0,0 @@ -modulus_poly = [1, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 0, 1, 0, 0, 1, - 1] -modulus_poly_as_int = sum([(v << i) for i, v in enumerate(modulus_poly)]) -degree = len(modulus_poly) - 1 - -two_to_the_degree = 2**degree -two_to_the_degree_m1 = 2**degree - 1 - -def galoistpl(a): - # 2 is not a primitive root, so we have to use 3 as our logarithm base - if a * 2 < two_to_the_degree: - return (a * 2) ^ a - else: - return (a * 2) ^ a ^ modulus_poly_as_int - -# Precomputing a log table for increased speed of addition and multiplication -glogtable = [0] * (two_to_the_degree) -gexptable = [] -v = 1 -for i in range(two_to_the_degree_m1): - glogtable[v] = i - gexptable.append(v) - v = galoistpl(v) - -gexptable += gexptable + gexptable - -# Add two values in the Galois field -def galois_add(x, y): - return x ^ y - -# In binary fields, addition and subtraction are the same thing -galois_sub = galois_add - -# Multiply two values in the Galois field -def galois_mul(x, y): - return 0 if x*y == 0 else gexptable[glogtable[x] + glogtable[y]] - -# Divide two values in the Galois field -def galois_div(x, y): - return 0 if x == 0 else gexptable[(glogtable[x] - glogtable[y]) % two_to_the_degree_m1] - -# Evaluate a polynomial at a point -def eval_poly_at(p, x): - if x == 0: - return p[0] - y = 0 - logx = glogtable[x] - for i, p_coeff in enumerate(p): - if p_coeff: - # Add x**i * coeff - y ^= gexptable[(logx * i + glogtable[p_coeff]) % two_to_the_degree_m1] - return y - - -# Given p+1 y values and x values with no errors, recovers the original -# p+1 degree polynomial. -# Lagrange interpolation works roughly in the following way. -# 1. Suppose you have a set of points, eg. x = [1, 2, 3], y = [2, 5, 10] -# 2. For each x, generate a polynomial which equals its corresponding -# y coordinate at that point and 0 at all other points provided. -# 3. Add these polynomials together. - -def lagrange_interp(pieces, xs): - # Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn) - root = mk_root_2(xs) - #print(root) - assert len(root) == len(pieces) + 1 - # print(root) - # Generate the derivative - d = derivative(root) - # Generate denominators by evaluating numerator polys at each x - denoms = multi_eval_2(d, xs) - print(denoms) - # denoms = [eval_poly_at(d, xs[i]) for i in range(len(xs))] - # Generate output polynomial, which is the sum of the per-value numerator - # polynomials rescaled to have the right y values - factors = [galois_div(p, d) for p, d in zip(pieces, denoms)] - o = multi_root_derive(xs, factors) - # print(o) - return o - -def multi_root_derive(xs, muls): - if len(xs) == 1: - return [muls[0]] - R1 = mk_root_2(xs[:len(xs) // 2]) - R2 = mk_root_2(xs[len(xs) // 2:]) - x1 = karatsuba_mul(R1, multi_root_derive(xs[len(xs) // 2:], muls[len(muls) // 2:]) + [0]) - x2 = karatsuba_mul(R2, multi_root_derive(xs[:len(xs) // 2], muls[:len(muls) // 2]) + [0]) - o = [v1 ^ v2 for v1, v2 in zip(x1, x2)][:len(xs)] - # print(len(R1), len(x1), len(xs), len(o)) - return o - -def multi_root_derive_1(xs, muls): - o = [0] * len(xs) - for i in range(len(xs)): - _xs = xs[:i] + xs[(i+1):] - root = mk_root_2(_xs) - for j in range(len(root)): - o[j] ^= galois_mul(root[j], muls[i]) - return o - -a = 124 -b = 8932 -c = 12415 - -assert galois_mul(galois_add(a, b), c) == galois_add(galois_mul(a, c), galois_mul(b, c)) - -def karatsuba_mul(p1, p2): - L = len(p1) - # assert L == len(p2) - if L <= 16: - o = [0] * (L * 2) - for i, v1 in enumerate(p1): - for j, v2 in enumerate(p2): - if v1 and v2: - o[i + j] ^= gexptable[glogtable[v1] + glogtable[v2]] - return o - if L % 2: - p1 = p1 + [0] - p2 = p2 + [0] - L += 1 - halflen = L // 2 - low1 = p1[:halflen] - high1 = p1[halflen:] - sum1 = [l ^ h for l, h in zip(low1, high1)] - low2 = p2[:halflen] - high2 = p2[halflen:] - sum2 = [l ^ h for l, h in zip(low2, high2)] - z2 = karatsuba_mul(high1, high2) - z0 = karatsuba_mul(low1, low2) - z1 = [m ^ _z0 ^ _z2 for m, _z0, _z2 in zip(karatsuba_mul(sum1, sum2), z0, z2)] - o = z0[:halflen] + \ - [a ^ b for a, b in zip(z0[halflen:], z1[:halflen])] + \ - [a ^ b for a, b in zip(z2[:halflen], z1[halflen:])] + \ - z2[halflen:] - return o - -def mk_root_1(xs): - root = [1] - for x in xs: - logx = glogtable[x] - root.insert(0, 0) - for j in range(len(root)-1): - if root[j+1] and x: - root[j] ^= gexptable[glogtable[root[j+1]] + logx] - return root - -def mk_root_2(xs): - if len(xs) >= 128: - return karatsuba_mul(mk_root_2(xs[:len(xs) // 2]), mk_root_2(xs[len(xs) // 2:]))[:len(xs) + 1] - root = [1] - for x in xs: - logx = glogtable[x] - root.insert(0, 0) - for j in range(len(root)-1): - if root[j+1] and x: - root[j] ^= gexptable[glogtable[root[j+1]] + logx] - return root - -def derivative(root): - return [0 if i % 2 else r for i, r in enumerate(root[1:])] - -# Credit to http://people.csail.mit.edu/madhu/ST12/scribe/lect06.pdf for the algorithm -def xn_mod_poly(p): - if len(p) == 1: - return [galois_div(1, p[0])] - halflen = len(p) // 2 - lowinv = xn_mod_poly(p[:halflen]) - submod_high = karatsuba_mul(lowinv, p[:halflen])[halflen:] - med = karatsuba_mul(p[halflen:], lowinv)[:halflen] - med_plus_high = [x ^ y for x, y in zip(med, submod_high)] - highinv = karatsuba_mul(med_plus_high, lowinv) - o = (lowinv + highinv)[:len(p)] - print(halflen, lowinv, submod_high, med, highinv) - # assert karatsuba_mul(o, p)[:len(p)] == [1] + [0] * (len(p) - 1) - return o - -def mod(a, b): - assert len(a) == 2 * (len(b) - 1) - L = len(b) - inv_rev_b = xn_mod_poly(b[::-1] + [0] * (len(a) - L))[:L] - quot = karatsuba_mul(inv_rev_b, a[::-1][:L])[:L-1][::-1] - subt = karatsuba_mul(b, quot + [0])[:-1] - o = [x ^ y for x, y in zip(a[:L-1], subt[:L-1])] - # assert [x^y for x, y in zip(karatsuba_mul(quot + [0], b), o)] == a - return o - -def multi_eval_1(poly, xs): - return [eval_poly_at(poly, x) for x in xs] - -def multi_eval_2(poly, xs): - if len(xs) <= 1024: - return [eval_poly_at(poly, x) for x in xs] - halflen = len(xs) // 2 - return multi_eval_2(mod(poly, mk_root_2(xs[:halflen])), xs[:halflen]) + \ - multi_eval_2(mod(poly, mk_root_2(xs[halflen:])), xs[halflen:]) - # [eval_poly_at(poly, xs[-2]), eval_poly_at(poly, xs[-1])] diff --git a/mimc_stark/ecpoly/setup.py b/mimc_stark/ecpoly/setup.py deleted file mode 100644 index 97589ab..0000000 --- a/mimc_stark/ecpoly/setup.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- - -from setuptools import setup, find_packages - - -with open('README.md') as f: - readme = f.read() - -with open('LICENSE') as f: - license = f.read() - -setup( - name='ecpoly', - version='1.0.0', - description='Erasure code utilities for prime fields', - long_description=readme, - author='Vitalik Buterin', - author_email='', - url='https://github.com/ethereum/research/tree/master/erasure_code/ecpoly', - license=license, - packages=find_packages(exclude=('tests', 'docs')), - install_requires=[ - ], -) diff --git a/mimc_stark/ecpoly/tests/test_basic_ops.py b/mimc_stark/ecpoly/tests/test_basic_ops.py deleted file mode 100644 index a9056b3..0000000 --- a/mimc_stark/ecpoly/tests/test_basic_ops.py +++ /dev/null @@ -1,15 +0,0 @@ -from ecpoly import PrimeField - -f = PrimeField(65537) - -k1 = list(range(10)) -k2 = list(range(100, 200)) -k3 = f.mul_polys(k1, k2) -assert f.div_polys(k3, k1) == k2 -assert f.div_polys(k3, k2) == k1 -assert (f.eval_poly_at(k1, 9999) * f.eval_poly_at(k2, 9999) - - f.eval_poly_at(k3, 9999)) % f.modulus == 0 -k4 = f.compose_polys(k1, k2) -assert f.eval_poly_at(k4, 9998) == f.eval_poly_at(k1, f.eval_poly_at(k2, 9998)) - -print("All passed!") diff --git a/mimc_stark/fri.py b/mimc_stark/fri.py index 23b21dc..5eaa0df 100644 --- a/mimc_stark/fri.py +++ b/mimc_stark/fri.py @@ -1,10 +1,17 @@ -from better_lagrange import lagrange_interp_4, eval_poly_at from merkle_tree import merkelize, mk_branch, verify_branch from utils import get_power_cycle, get_pseudorandom_indices from fft import fft +from poly_utils import PrimeField -# Generate an FRI proof -def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus): +# Generate an FRI proof that the polynomial that has the specified +# values at successive powers of the specified root of unity has a +# degree lower than maxdeg_plus_1 +# +# We use maxdeg+1 instead of maxdeg because it's more mathematically +# convenient in this case. + +def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus): + f = PrimeField(modulus) print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1)) # If the degree we are checking for is less than or equal to 32, @@ -24,18 +31,17 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus): # Select a pseudo-random x coordinate special_x = int.from_bytes(m[1], 'big') % modulus - # Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html) - # at that x coordinate + # Calculate the "column" at that x coordinate + # (see https://vitalik.ca/general/2017/11/22/starks_part_2.html) # We calculate the column by Lagrange-interpolating the row, and not - # directly, as this is more efficient + # directly from the polynomial, as this is more efficient column = [] for i in range(len(xs)//4): - x_poly = lagrange_interp_4( - [values[i+len(values)*j//4] for j in range(4)], + x_poly = f.lagrange_interp_4( [xs[i+len(xs)*j//4] for j in range(4)], - modulus + [values[i+len(values)*j//4] for j in range(4)], ) - column.append(eval_poly_at(x_poly, special_x, modulus)) + column.append(f.eval_poly_at(x_poly, special_x)) m2 = merkelize(column) # Pseudo-randomly select y indices to sample @@ -44,23 +50,24 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus): # Compute the Merkle branches for the values in the polynomial and the column branches = [] for y in ys: - branches.append([mk_branch(m2, y)] + [mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)]) + branches.append([mk_branch(m2, y)] + + [mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)]) # This component of the proof o = [m2[1], branches] - # In the next iteration of the proof, we'll work with smaller roots of unity - sub_xs = [xs[i] for i in range(0, len(xs), 4)] - # Interpolate the polynomial for the column - ypoly = fft(column[:len(sub_xs)], modulus, - pow(root_of_unity, 4, modulus), inv=True) + # sub_xs = [xs[i] for i in range(0, len(xs), 4)] + # ypoly = fft(column[:len(sub_xs)], modulus, + # f.exp(root_of_unity, 4), inv=True) # Recurse... - return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4, modulus) + return [o] + prove_low_degree(column, f.exp(root_of_unity, 4), + maxdeg_plus_1 // 4, modulus) # Verify an FRI proof def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus): + f = PrimeField(modulus) # Calculate which root of unity we're working with testval = root_of_unity @@ -69,10 +76,11 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo roudeg *= 2 testval = (testval * testval) % modulus + # Powers of the given root of unity 1, p, p**2, p**3 such that p**4 = 1 quartic_roots_of_unity = [1, - pow(root_of_unity, roudeg // 4, modulus), - pow(root_of_unity, roudeg // 2, modulus), - pow(root_of_unity, roudeg * 3 // 4, modulus)] + f.exp(root_of_unity, roudeg // 4), + f.exp(root_of_unity, roudeg // 2), + f.exp(root_of_unity, roudeg * 3 // 4)] # Verify the recursive components of the proof for prf in proof[:-1]: @@ -86,27 +94,28 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo ys = get_pseudorandom_indices(root2, roudeg // 4, 40) - # Verify for each selected y coordinate that the four points from the polynomial - # and the one point from the column that are on that y coordinate are on the same - # deg < 4 polynomial + # Verify for each selected y coordinate that the four points from the + # polynomial and the one point from the column that are on that y + # coordinate are on the same deg < 4 polynomial for i, y in enumerate(ys): # The x coordinates from the polynomial - x1 = pow(root_of_unity, y, modulus) + x1 = f.exp(root_of_unity, y) xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)] # The values from the polynomial - row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])] + row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) + for j, prf in zip(range(4), branches[i][1:])] # Verify proof and recover the column value values = [verify_branch(root2, y, branches[i][0])] + row # Lagrange interpolate and check deg is < 4 - p = lagrange_interp_4(row, xcoords, modulus) - assert eval_poly_at(p, special_x, modulus) == verify_branch(root2, y, branches[i][0]) + p = f.lagrange_interp_4(xcoords, row) + assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0]) # Update constants to check the next proof merkle_root = root2 - root_of_unity = pow(root_of_unity, 4, modulus) + root_of_unity = f.exp(root_of_unity, 4) maxdeg_plus_1 //= 4 roudeg //= 4 diff --git a/mimc_stark/mimc_stark.py b/mimc_stark/mimc_stark.py index 551a01a..6345eb4 100644 --- a/mimc_stark/mimc_stark.py +++ b/mimc_stark/mimc_stark.py @@ -1,7 +1,6 @@ from merkle_tree import merkelize, mk_branch, verify_branch, blake from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length -from ecpoly import PrimeField -from better_lagrange import lagrange_interp_4, lagrange_interp_2 +from poly_utils import PrimeField import time from fft import fft from fri import prove_low_degree, verify_low_degree_proof @@ -13,15 +12,11 @@ nonresidue = 7 spot_check_security_factor = 240 -# Compute a MIMC permutation for 2**logsteps steps, using round constants -# from the multiplicative subgroup of size 2**logprecision -def mimc(inp, logsteps, logprecision): +# Compute a MIMC permutation for 2**logsteps steps +def mimc(inp, logsteps): start_time = time.time() steps = 2**logsteps - precision = 2**logprecision - # Get (steps)th root of unity - subroot = pow(7, (modulus-1)//steps, modulus) - # We use powers of 9 mod 2^256 as the ith round constant for the moment + # We use powers of 9 mod 2^256 XORed with 1 as the ith round constant for the moment k = 1 for i in range(steps-1): inp = (inp**3 + (k ^ 1)) % modulus @@ -29,34 +24,20 @@ def mimc(inp, logsteps, logprecision): print("MIMC computed in %.4f sec" % (time.time() - start_time)) return inp -# Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x) -def multiply_base(poly, fac): - o = [] - r = 1 - for p in poly: - o.append(p * r % modulus) - r = r * fac % modulus - return o - -# Divides a polynomial by x^n-1 -def divide_by_xnm1(poly, n): - if len(poly) <= n: - return [] - return f.add_polys(poly[n:], divide_by_xnm1(poly[n:], n)) - # Generate a STARK for a MIMC calculation -def mk_mimc_proof(inp, logsteps, logprecision): +def mk_mimc_proof(inp, logsteps): start_time = time.time() - assert logsteps < logprecision <= 32 + assert logsteps <= 29 + logprecision = logsteps + 3 steps = 2**logsteps precision = 2**logprecision # Root of unity such that x^precision=1 - root_of_unity = pow(7, (modulus-1)//precision, modulus) + root_of_unity = f.exp(7, (modulus-1)//precision) # Root of unity such that x^skips=1 skips = precision // steps - subroot = pow(root_of_unity, skips) + subroot = f.exp(root_of_unity, skips) # Powers of the root of unity, our computational trace will be # along the sequence of roots of unity @@ -82,30 +63,31 @@ def mk_mimc_proof(inp, logsteps, logprecision): # Create the composed polynomial such that # C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x) - term1 = multiply_base(values_polynomial, subroot) + term1 = f.multiply_base(values_polynomial, subroot) p_evaluations = fft(values_polynomial, modulus, root_of_unity) - term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2] + term2 = fft([f.exp(x, 3) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2] c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial) print('Computed C(P, K) polynomial') # Compute D(x) = C(P(x), P(rx), K(x)) / Z(x) # Z(x) = (x^steps - 1) / (x - x_atlast_step) - d = divide_by_xnm1(f.mul_polys(c_of_values, - [-last_step_position, 1]), - steps) + d = f.divide_by_xnm1(f.mul_polys(c_of_values, + [-last_step_position, 1]), + steps) # Consistency check - assert (f.eval_poly_at(d, 90833) * - (pow(90833, steps, modulus) - 1) * - f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) - - f.eval_poly_at(c_of_values, 90833)) % modulus == 0 + # assert (f.eval_poly_at(d, 90833) * + # (f.exp(90833, steps) - 1) * + # f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) - + # f.eval_poly_at(c_of_values, 90833)) % modulus == 0 print('Computed D polynomial') # Compute interpolant of ((1, input), (x_atlast_step, output)) - interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus) + interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output]) quotient = f.mul_polys([-1, 1], [-last_step_position, 1]) b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient) # Consistency check - assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == f.eval_poly_at(values_polynomial, 7045) + # assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == \ + # f.eval_poly_at(values_polynomial, 7045) print('Computed B polynomial') # Evaluate B, D and K across the entire subgroup @@ -129,12 +111,18 @@ def mk_mimc_proof(inp, logsteps, logprecision): k3 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x03'), 'big') k4 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x04'), 'big') - lincomb = f.add_polys(f.add_polys(d, - f.mul_by_const(values_polynomial, k1) + f.mul_by_const(values_polynomial, k2)), - f.mul_by_const(b, k3) + [0, 0] + f.mul_by_const(b, k4) + [0,0]) - l_evaluations = fft(lincomb, modulus, root_of_unity) - l_mtree = merkelize(l_evaluations) + # Compute the linear combination. We don't even both calculating it in + # coefficient form; we just compute the evaluations + root_of_unity_to_the_steps = f.exp(root_of_unity, steps) + powers = [1] + for i in range(1, precision): + powers.append(powers[-1] * root_of_unity_to_the_steps % modulus) + l_evaluations = [(d_evaluations[i] + + p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] + + b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus + for i in range(precision)] + l_mtree = merkelize(l_evaluations) print('Computed random linear combination') # Do some spot checks of the Merkle tree at pseudo-random coordinates @@ -158,20 +146,21 @@ def mk_mimc_proof(inp, logsteps, logprecision): b_mtree[1], l_mtree[1], branches, - prove_low_degree(lincomb, root_of_unity, l_evaluations, steps * 2, modulus)] + prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus)] print("STARK computed in %.4f sec" % (time.time() - start_time)) return o # Verifies a STARK -def verify_mimc_proof(inp, logsteps, logprecision, output, proof): +def verify_mimc_proof(inp, logsteps, output, proof): p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof start_time = time.time() + logprecision = logsteps + 3 steps = 2**logsteps precision = 2**logprecision # Get (steps)th root of unity - root_of_unity = pow(7, (modulus-1)//precision, modulus) + root_of_unity = f.exp(7, (modulus-1)//precision) skips = precision // steps # Verifies the low-degree proofs @@ -184,24 +173,29 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, proof): k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big') samples = spot_check_security_factor // (logprecision - logsteps) positions = get_pseudorandom_indices(l_root, precision - skips, samples) - last_step_position = pow(root_of_unity, (steps - 1) * skips, modulus) + last_step_position = f.exp(root_of_unity, (steps - 1) * skips) for i, pos in enumerate(positions): - # Check C(P(x)) = Z(x) * D(x) - x = pow(root_of_unity, pos, modulus) - x_to_the_steps = pow(x, steps, modulus) + x = f.exp(root_of_unity, pos) + x_to_the_steps = f.exp(x, steps) p_of_x = verify_branch(p_root, pos, branches[i*6]) p_of_rx = verify_branch(p_root, pos+skips, branches[i*6 + 1]) d_of_x = verify_branch(d_root, pos, branches[i*6 + 2]) k_of_x = verify_branch(k_root, pos, branches[i*6 + 3]) b_of_x = verify_branch(b_root, pos, branches[i*6 + 4]) l_of_x = verify_branch(l_root, pos, branches[i*6 + 5]) - zvalue = f.div(pow(x, steps, modulus) - 1, + zvalue = f.div(f.exp(x, steps) - 1, x - last_step_position) + + # Check transition constraints C(P(x)) = Z(x) * D(x) assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0 - interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus) + interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output]) quotient = f.mul_polys([-1, 1], [-last_step_position, 1]) + + # Check boundary constraints B(x) * Q(x) + I(x) = P(x) assert (p_of_x - b_of_x * f.eval_poly_at(quotient, x) - f.eval_poly_at(interpolant, x)) % modulus == 0 + + # Check correctness of the linear combination assert (l_of_x - d_of_x - k1 * p_of_x - k2 * p_of_x * x_to_the_steps - k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0 diff --git a/mimc_stark/ecpoly/ecpoly/poly_utils.py b/mimc_stark/poly_utils.py similarity index 56% rename from mimc_stark/ecpoly/ecpoly/poly_utils.py rename to mimc_stark/poly_utils.py index 7ca98c1..804d27d 100644 --- a/mimc_stark/ecpoly/ecpoly/poly_utils.py +++ b/mimc_stark/poly_utils.py @@ -1,3 +1,5 @@ +# Creates an object that includes convenience operations for numbers +# and polynomials in some prime field class PrimeField(): def __init__(self, modulus): assert pow(2, modulus, modulus) == 2 @@ -12,9 +14,13 @@ class PrimeField(): def mul(self, x, y): return (x*y) % self.modulus + def exp(self, x, p): + return pow(x, p, self.modulus) + + # Modular inverse using the extended Euclidean algorithm def inv(self, a): if a == 0: - return 0 + raise Exception("Cannot invert 0") lm, hm = 1, 0 low, high = a % self.modulus, self.modulus while low > 1: @@ -24,14 +30,10 @@ class PrimeField(): return lm % self.modulus def div(self, x, y): - if x == 0 and y == 0: - return 1 return self.mul(x, self.inv(y)) # Evaluate a polynomial at a point def eval_poly_at(self, p, x): - if x == 0: - return p[0] y = 0 power_of_x = 1 for i, p_coeff in enumerate(p): @@ -39,7 +41,7 @@ class PrimeField(): power_of_x = (power_of_x * x) % self.modulus return y % self.modulus - # Build a polynomial that returns 0 at all xs + # Build a polynomial that returns 0 at all specified xs def zpoly(self, xs): root = [1] for x in xs: @@ -56,35 +58,62 @@ class PrimeField(): # y coordinate at that point and 0 at all other points provided. # 3. Add these polynomials together. - def lagrange_interp(self, pieces, xs): + def lagrange_interp(self, xs, ys): # Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn) root = self.zpoly(xs) - #print(root) - assert len(root) == len(pieces) + 1 + assert len(root) == len(ys) + 1 # print(root) # Generate per-value numerator polynomials, eg. for x=x2, # (x - x1) * (x - x3) * ... * (x - xn), by dividing the master # polynomial back by each x coordinate - nums = [] - for x in xs: - output = [0] * (len(root) - 2) + [1] - for j in range(len(root) - 2, 0, -1): - output[j-1] = root[j] + output[j] * x - assert len(output) == len(pieces) - nums.append(output) - #print(nums) + nums = [self.div_polys(root, [-x, 1]) for x in xs] # Generate denominators by evaluating numerator polys at each x denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))] # Generate output polynomial, which is the sum of the per-value numerator # polynomials rescaled to have the right y values - b = [0 for p in pieces] + b = [0 for y in ys] for i in range(len(xs)): - yslice = self.div(pieces[i], denoms[i]) - for j in range(len(pieces)): - if nums[i][j] and pieces[i]: + yslice = self.div(ys[i], denoms[i]) + for j in range(len(ys)): + if nums[i][j] and ys[i]: b[j] += nums[i][j] * yslice return [x % self.modulus for x in b] + + # Optimized version of the above restricted to deg-4 polynomials + def lagrange_interp_4(self, xs, ys): + x01, x02, x03, x12, x13, x23 = \ + xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3] + m = self.modulus + eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1] + eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1] + eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1] + eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1] + e0 = self.eval_poly_at(eq0, xs[0]) + e1 = self.eval_poly_at(eq1, xs[1]) + e2 = self.eval_poly_at(eq2, xs[2]) + e3 = self.eval_poly_at(eq3, xs[3]) + e01 = e0 * e1 + e23 = e2 * e3 + invall = self.inv(e01 * e23) + inv_y0 = ys[0] * invall * e1 * e23 % m + inv_y1 = ys[1] * invall * e0 * e23 % m + inv_y2 = ys[2] * invall * e01 * e3 % m + inv_y3 = ys[3] * invall * e01 * e2 % m + return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)] + # Optimized version of the above restricted to deg-2 polynomials + def lagrange_interp_2(self, xs, ys): + m = self.modulus + eq0 = [-xs[1] % m, 1] + eq1 = [-xs[0] % m, 1] + e0 = self.eval_poly_at(eq0, xs[0]) + e1 = self.eval_poly_at(eq1, xs[1]) + invall = self.inv(e0 * e1) + inv_y0 = ys[0] * invall * e1 + inv_y1 = ys[1] * invall * e0 + return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)] + + # Arithmetic for polynomials def add_polys(self, a, b): return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0)) % self.modulus for i in range(max(len(a), len(b)))] @@ -118,7 +147,14 @@ class PrimeField(): apos -= 1 diff -= 1 return [x % self.modulus for x in o] + + # Divides a polynomial by x^n-1 + def divide_by_xnm1(self, poly, n): + if len(poly) <= n: + return [] + return self.add_polys(poly[n:], self.divide_by_xnm1(poly[n:], n)) + # Returns P(x) = A(B(x)) def compose_polys(self, a, b): o = [] p = [1] @@ -126,4 +162,13 @@ class PrimeField(): o = self.add_polys(o, self.mul_by_const(p, c)) p = self.mul_polys(p, b) return o - + + # Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x) + # Equivalent to compose_polys(poly, [0, fac]) + def multiply_base(self, poly, fac): + o = [] + r = 1 + for p in poly: + o.append(p * r % self.modulus) + r = r * fac % self.modulus + return o diff --git a/mimc_stark/test.py b/mimc_stark/test.py index c306ff1..814d9ac 100644 --- a/mimc_stark/test.py +++ b/mimc_stark/test.py @@ -12,25 +12,36 @@ def test_merkletree(): def test_fri(): # Pure FRI tests - poly = list(range(512)) - root_of_unity = pow(7, (modulus-1)//1024, modulus) + poly = list(range(4096)) + root_of_unity = pow(7, (modulus-1)//16384, modulus) evaluations = fft(poly, modulus, root_of_unity) - proof = prove_low_degree(poly, root_of_unity, evaluations, 512, modulus) + proof = prove_low_degree(evaluations, root_of_unity, 4096, modulus) print("Approx proof length: %d" % bin_length(compress_fri(proof))) - assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512, modulus) + assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 4096, modulus) + + try: + fakedata = [x if pow(3, i, 4096) > 400 else 39 for x, i in enumerate(evaluations)] + proof2 = prove_low_degree(fakedata, root_of_unity, 4096, modulus) + assert verify_low_degree_proof(merkelize(fakedata)[1], root_of_unity, proof, 4096, modulus) + raise Exception("Fake data passed FRI") + except: + pass + try: + assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 2048, modulus) + raise Exception("Fake data passed FRI") + except: + pass def test_stark(): INPUT = 3 LOGSTEPS = 13 - LOGPRECISION = 16 - # Full STARK test - proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION) + proof = mk_mimc_proof(INPUT, LOGSTEPS) p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof L1 = bin_length(compress_branches(branches)) L2 = bin_length(compress_fri(fri_proof)) print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2)) - root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus) - subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus) - skips = 2**(LOGPRECISION - LOGSTEPS) - assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), proof) + assert verify_mimc_proof(3, LOGSTEPS, mimc(3, LOGSTEPS), proof) + +if __name__ == '__main__': + test_stark()