Improved efficiency a bunch and added boundary checks

This commit is contained in:
Vitalik Buterin 2018-07-10 15:45:12 -04:00
parent 770c0a2c78
commit da1d723780
11 changed files with 173 additions and 381 deletions

View File

@ -16,32 +16,3 @@ def eval_poly_at(poly, x, modulus):
p = (p * x % modulus) p = (p * x % modulus)
return o % modulus return o % modulus
def lagrange_interp_4(pieces, xs, modulus):
x01, x02, x03, x12, x13, x23 = \
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
eq0 = [-x12 * xs[3] % modulus, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
eq1 = [-x02 * xs[3] % modulus, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
eq2 = [-x01 * xs[3] % modulus, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
eq3 = [-x01 * xs[2] % modulus, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
e0 = eval_poly_at(eq0, xs[0], modulus)
e1 = eval_poly_at(eq1, xs[1], modulus)
e2 = eval_poly_at(eq2, xs[2], modulus)
e3 = eval_poly_at(eq3, xs[3], modulus)
e01 = e0 * e1
e23 = e2 * e3
invall = inv(e01 * e23, modulus)
inv_y0 = pieces[0] * invall * e1 * e23 % modulus
inv_y1 = pieces[1] * invall * e0 * e23 % modulus
inv_y2 = pieces[2] * invall * e01 * e3 % modulus
inv_y3 = pieces[3] * invall * e01 * e2 % modulus
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % modulus for i in range(4)]
def lagrange_interp_2(pieces, xs, modulus):
eq0 = [-xs[1] % modulus, 1]
eq1 = [-xs[0] % modulus, 1]
e0 = eval_poly_at(eq0, xs[0], modulus)
e1 = eval_poly_at(eq1, xs[1], modulus)
invall = inv(e0 * e1, modulus)
inv_y0 = pieces[0] * invall * e1
inv_y1 = pieces[1] * invall * e0
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % modulus for i in range(2)]

View File

@ -1 +0,0 @@
from .poly_utils import PrimeField

View File

@ -1,198 +0,0 @@
modulus_poly = [1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 1, 0, 0, 1,
1]
modulus_poly_as_int = sum([(v << i) for i, v in enumerate(modulus_poly)])
degree = len(modulus_poly) - 1
two_to_the_degree = 2**degree
two_to_the_degree_m1 = 2**degree - 1
def galoistpl(a):
# 2 is not a primitive root, so we have to use 3 as our logarithm base
if a * 2 < two_to_the_degree:
return (a * 2) ^ a
else:
return (a * 2) ^ a ^ modulus_poly_as_int
# Precomputing a log table for increased speed of addition and multiplication
glogtable = [0] * (two_to_the_degree)
gexptable = []
v = 1
for i in range(two_to_the_degree_m1):
glogtable[v] = i
gexptable.append(v)
v = galoistpl(v)
gexptable += gexptable + gexptable
# Add two values in the Galois field
def galois_add(x, y):
return x ^ y
# In binary fields, addition and subtraction are the same thing
galois_sub = galois_add
# Multiply two values in the Galois field
def galois_mul(x, y):
return 0 if x*y == 0 else gexptable[glogtable[x] + glogtable[y]]
# Divide two values in the Galois field
def galois_div(x, y):
return 0 if x == 0 else gexptable[(glogtable[x] - glogtable[y]) % two_to_the_degree_m1]
# Evaluate a polynomial at a point
def eval_poly_at(p, x):
if x == 0:
return p[0]
y = 0
logx = glogtable[x]
for i, p_coeff in enumerate(p):
if p_coeff:
# Add x**i * coeff
y ^= gexptable[(logx * i + glogtable[p_coeff]) % two_to_the_degree_m1]
return y
# Given p+1 y values and x values with no errors, recovers the original
# p+1 degree polynomial.
# Lagrange interpolation works roughly in the following way.
# 1. Suppose you have a set of points, eg. x = [1, 2, 3], y = [2, 5, 10]
# 2. For each x, generate a polynomial which equals its corresponding
# y coordinate at that point and 0 at all other points provided.
# 3. Add these polynomials together.
def lagrange_interp(pieces, xs):
# Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
root = mk_root_2(xs)
#print(root)
assert len(root) == len(pieces) + 1
# print(root)
# Generate the derivative
d = derivative(root)
# Generate denominators by evaluating numerator polys at each x
denoms = multi_eval_2(d, xs)
print(denoms)
# denoms = [eval_poly_at(d, xs[i]) for i in range(len(xs))]
# Generate output polynomial, which is the sum of the per-value numerator
# polynomials rescaled to have the right y values
factors = [galois_div(p, d) for p, d in zip(pieces, denoms)]
o = multi_root_derive(xs, factors)
# print(o)
return o
def multi_root_derive(xs, muls):
if len(xs) == 1:
return [muls[0]]
R1 = mk_root_2(xs[:len(xs) // 2])
R2 = mk_root_2(xs[len(xs) // 2:])
x1 = karatsuba_mul(R1, multi_root_derive(xs[len(xs) // 2:], muls[len(muls) // 2:]) + [0])
x2 = karatsuba_mul(R2, multi_root_derive(xs[:len(xs) // 2], muls[:len(muls) // 2]) + [0])
o = [v1 ^ v2 for v1, v2 in zip(x1, x2)][:len(xs)]
# print(len(R1), len(x1), len(xs), len(o))
return o
def multi_root_derive_1(xs, muls):
o = [0] * len(xs)
for i in range(len(xs)):
_xs = xs[:i] + xs[(i+1):]
root = mk_root_2(_xs)
for j in range(len(root)):
o[j] ^= galois_mul(root[j], muls[i])
return o
a = 124
b = 8932
c = 12415
assert galois_mul(galois_add(a, b), c) == galois_add(galois_mul(a, c), galois_mul(b, c))
def karatsuba_mul(p1, p2):
L = len(p1)
# assert L == len(p2)
if L <= 16:
o = [0] * (L * 2)
for i, v1 in enumerate(p1):
for j, v2 in enumerate(p2):
if v1 and v2:
o[i + j] ^= gexptable[glogtable[v1] + glogtable[v2]]
return o
if L % 2:
p1 = p1 + [0]
p2 = p2 + [0]
L += 1
halflen = L // 2
low1 = p1[:halflen]
high1 = p1[halflen:]
sum1 = [l ^ h for l, h in zip(low1, high1)]
low2 = p2[:halflen]
high2 = p2[halflen:]
sum2 = [l ^ h for l, h in zip(low2, high2)]
z2 = karatsuba_mul(high1, high2)
z0 = karatsuba_mul(low1, low2)
z1 = [m ^ _z0 ^ _z2 for m, _z0, _z2 in zip(karatsuba_mul(sum1, sum2), z0, z2)]
o = z0[:halflen] + \
[a ^ b for a, b in zip(z0[halflen:], z1[:halflen])] + \
[a ^ b for a, b in zip(z2[:halflen], z1[halflen:])] + \
z2[halflen:]
return o
def mk_root_1(xs):
root = [1]
for x in xs:
logx = glogtable[x]
root.insert(0, 0)
for j in range(len(root)-1):
if root[j+1] and x:
root[j] ^= gexptable[glogtable[root[j+1]] + logx]
return root
def mk_root_2(xs):
if len(xs) >= 128:
return karatsuba_mul(mk_root_2(xs[:len(xs) // 2]), mk_root_2(xs[len(xs) // 2:]))[:len(xs) + 1]
root = [1]
for x in xs:
logx = glogtable[x]
root.insert(0, 0)
for j in range(len(root)-1):
if root[j+1] and x:
root[j] ^= gexptable[glogtable[root[j+1]] + logx]
return root
def derivative(root):
return [0 if i % 2 else r for i, r in enumerate(root[1:])]
# Credit to http://people.csail.mit.edu/madhu/ST12/scribe/lect06.pdf for the algorithm
def xn_mod_poly(p):
if len(p) == 1:
return [galois_div(1, p[0])]
halflen = len(p) // 2
lowinv = xn_mod_poly(p[:halflen])
submod_high = karatsuba_mul(lowinv, p[:halflen])[halflen:]
med = karatsuba_mul(p[halflen:], lowinv)[:halflen]
med_plus_high = [x ^ y for x, y in zip(med, submod_high)]
highinv = karatsuba_mul(med_plus_high, lowinv)
o = (lowinv + highinv)[:len(p)]
print(halflen, lowinv, submod_high, med, highinv)
# assert karatsuba_mul(o, p)[:len(p)] == [1] + [0] * (len(p) - 1)
return o
def mod(a, b):
assert len(a) == 2 * (len(b) - 1)
L = len(b)
inv_rev_b = xn_mod_poly(b[::-1] + [0] * (len(a) - L))[:L]
quot = karatsuba_mul(inv_rev_b, a[::-1][:L])[:L-1][::-1]
subt = karatsuba_mul(b, quot + [0])[:-1]
o = [x ^ y for x, y in zip(a[:L-1], subt[:L-1])]
# assert [x^y for x, y in zip(karatsuba_mul(quot + [0], b), o)] == a
return o
def multi_eval_1(poly, xs):
return [eval_poly_at(poly, x) for x in xs]
def multi_eval_2(poly, xs):
if len(xs) <= 1024:
return [eval_poly_at(poly, x) for x in xs]
halflen = len(xs) // 2
return multi_eval_2(mod(poly, mk_root_2(xs[:halflen])), xs[:halflen]) + \
multi_eval_2(mod(poly, mk_root_2(xs[halflen:])), xs[halflen:])
# [eval_poly_at(poly, xs[-2]), eval_poly_at(poly, xs[-1])]

View File

@ -1,24 +0,0 @@
# -*- coding: utf-8 -*-
from setuptools import setup, find_packages
with open('README.md') as f:
readme = f.read()
with open('LICENSE') as f:
license = f.read()
setup(
name='ecpoly',
version='1.0.0',
description='Erasure code utilities for prime fields',
long_description=readme,
author='Vitalik Buterin',
author_email='',
url='https://github.com/ethereum/research/tree/master/erasure_code/ecpoly',
license=license,
packages=find_packages(exclude=('tests', 'docs')),
install_requires=[
],
)

View File

@ -1,15 +0,0 @@
from ecpoly import PrimeField
f = PrimeField(65537)
k1 = list(range(10))
k2 = list(range(100, 200))
k3 = f.mul_polys(k1, k2)
assert f.div_polys(k3, k1) == k2
assert f.div_polys(k3, k2) == k1
assert (f.eval_poly_at(k1, 9999) * f.eval_poly_at(k2, 9999) -
f.eval_poly_at(k3, 9999)) % f.modulus == 0
k4 = f.compose_polys(k1, k2)
assert f.eval_poly_at(k4, 9998) == f.eval_poly_at(k1, f.eval_poly_at(k2, 9998))
print("All passed!")

View File

@ -1,10 +1,17 @@
from better_lagrange import lagrange_interp_4, eval_poly_at
from merkle_tree import merkelize, mk_branch, verify_branch from merkle_tree import merkelize, mk_branch, verify_branch
from utils import get_power_cycle, get_pseudorandom_indices from utils import get_power_cycle, get_pseudorandom_indices
from fft import fft from fft import fft
from poly_utils import PrimeField
# Generate an FRI proof # Generate an FRI proof that the polynomial that has the specified
def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus): # values at successive powers of the specified root of unity has a
# degree lower than maxdeg_plus_1
#
# We use maxdeg+1 instead of maxdeg because it's more mathematically
# convenient in this case.
def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
f = PrimeField(modulus)
print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1)) print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1))
# If the degree we are checking for is less than or equal to 32, # If the degree we are checking for is less than or equal to 32,
@ -24,18 +31,17 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus):
# Select a pseudo-random x coordinate # Select a pseudo-random x coordinate
special_x = int.from_bytes(m[1], 'big') % modulus special_x = int.from_bytes(m[1], 'big') % modulus
# Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html) # Calculate the "column" at that x coordinate
# at that x coordinate # (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
# We calculate the column by Lagrange-interpolating the row, and not # We calculate the column by Lagrange-interpolating the row, and not
# directly, as this is more efficient # directly from the polynomial, as this is more efficient
column = [] column = []
for i in range(len(xs)//4): for i in range(len(xs)//4):
x_poly = lagrange_interp_4( x_poly = f.lagrange_interp_4(
[values[i+len(values)*j//4] for j in range(4)],
[xs[i+len(xs)*j//4] for j in range(4)], [xs[i+len(xs)*j//4] for j in range(4)],
modulus [values[i+len(values)*j//4] for j in range(4)],
) )
column.append(eval_poly_at(x_poly, special_x, modulus)) column.append(f.eval_poly_at(x_poly, special_x))
m2 = merkelize(column) m2 = merkelize(column)
# Pseudo-randomly select y indices to sample # Pseudo-randomly select y indices to sample
@ -44,23 +50,24 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus):
# Compute the Merkle branches for the values in the polynomial and the column # Compute the Merkle branches for the values in the polynomial and the column
branches = [] branches = []
for y in ys: for y in ys:
branches.append([mk_branch(m2, y)] + [mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)]) branches.append([mk_branch(m2, y)] +
[mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)])
# This component of the proof # This component of the proof
o = [m2[1], branches] o = [m2[1], branches]
# In the next iteration of the proof, we'll work with smaller roots of unity
sub_xs = [xs[i] for i in range(0, len(xs), 4)]
# Interpolate the polynomial for the column # Interpolate the polynomial for the column
ypoly = fft(column[:len(sub_xs)], modulus, # sub_xs = [xs[i] for i in range(0, len(xs), 4)]
pow(root_of_unity, 4, modulus), inv=True) # ypoly = fft(column[:len(sub_xs)], modulus,
# f.exp(root_of_unity, 4), inv=True)
# Recurse... # Recurse...
return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4, modulus) return [o] + prove_low_degree(column, f.exp(root_of_unity, 4),
maxdeg_plus_1 // 4, modulus)
# Verify an FRI proof # Verify an FRI proof
def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus): def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus):
f = PrimeField(modulus)
# Calculate which root of unity we're working with # Calculate which root of unity we're working with
testval = root_of_unity testval = root_of_unity
@ -69,10 +76,11 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
roudeg *= 2 roudeg *= 2
testval = (testval * testval) % modulus testval = (testval * testval) % modulus
# Powers of the given root of unity 1, p, p**2, p**3 such that p**4 = 1
quartic_roots_of_unity = [1, quartic_roots_of_unity = [1,
pow(root_of_unity, roudeg // 4, modulus), f.exp(root_of_unity, roudeg // 4),
pow(root_of_unity, roudeg // 2, modulus), f.exp(root_of_unity, roudeg // 2),
pow(root_of_unity, roudeg * 3 // 4, modulus)] f.exp(root_of_unity, roudeg * 3 // 4)]
# Verify the recursive components of the proof # Verify the recursive components of the proof
for prf in proof[:-1]: for prf in proof[:-1]:
@ -86,27 +94,28 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
ys = get_pseudorandom_indices(root2, roudeg // 4, 40) ys = get_pseudorandom_indices(root2, roudeg // 4, 40)
# Verify for each selected y coordinate that the four points from the polynomial # Verify for each selected y coordinate that the four points from the
# and the one point from the column that are on that y coordinate are on the same # polynomial and the one point from the column that are on that y
# deg < 4 polynomial # coordinate are on the same deg < 4 polynomial
for i, y in enumerate(ys): for i, y in enumerate(ys):
# The x coordinates from the polynomial # The x coordinates from the polynomial
x1 = pow(root_of_unity, y, modulus) x1 = f.exp(root_of_unity, y)
xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)] xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)]
# The values from the polynomial # The values from the polynomial
row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])] row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf)
for j, prf in zip(range(4), branches[i][1:])]
# Verify proof and recover the column value # Verify proof and recover the column value
values = [verify_branch(root2, y, branches[i][0])] + row values = [verify_branch(root2, y, branches[i][0])] + row
# Lagrange interpolate and check deg is < 4 # Lagrange interpolate and check deg is < 4
p = lagrange_interp_4(row, xcoords, modulus) p = f.lagrange_interp_4(xcoords, row)
assert eval_poly_at(p, special_x, modulus) == verify_branch(root2, y, branches[i][0]) assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0])
# Update constants to check the next proof # Update constants to check the next proof
merkle_root = root2 merkle_root = root2
root_of_unity = pow(root_of_unity, 4, modulus) root_of_unity = f.exp(root_of_unity, 4)
maxdeg_plus_1 //= 4 maxdeg_plus_1 //= 4
roudeg //= 4 roudeg //= 4

View File

@ -1,7 +1,6 @@
from merkle_tree import merkelize, mk_branch, verify_branch, blake from merkle_tree import merkelize, mk_branch, verify_branch, blake
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
from ecpoly import PrimeField from poly_utils import PrimeField
from better_lagrange import lagrange_interp_4, lagrange_interp_2
import time import time
from fft import fft from fft import fft
from fri import prove_low_degree, verify_low_degree_proof from fri import prove_low_degree, verify_low_degree_proof
@ -13,15 +12,11 @@ nonresidue = 7
spot_check_security_factor = 240 spot_check_security_factor = 240
# Compute a MIMC permutation for 2**logsteps steps, using round constants # Compute a MIMC permutation for 2**logsteps steps
# from the multiplicative subgroup of size 2**logprecision def mimc(inp, logsteps):
def mimc(inp, logsteps, logprecision):
start_time = time.time() start_time = time.time()
steps = 2**logsteps steps = 2**logsteps
precision = 2**logprecision # We use powers of 9 mod 2^256 XORed with 1 as the ith round constant for the moment
# Get (steps)th root of unity
subroot = pow(7, (modulus-1)//steps, modulus)
# We use powers of 9 mod 2^256 as the ith round constant for the moment
k = 1 k = 1
for i in range(steps-1): for i in range(steps-1):
inp = (inp**3 + (k ^ 1)) % modulus inp = (inp**3 + (k ^ 1)) % modulus
@ -29,34 +24,20 @@ def mimc(inp, logsteps, logprecision):
print("MIMC computed in %.4f sec" % (time.time() - start_time)) print("MIMC computed in %.4f sec" % (time.time() - start_time))
return inp return inp
# Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x)
def multiply_base(poly, fac):
o = []
r = 1
for p in poly:
o.append(p * r % modulus)
r = r * fac % modulus
return o
# Divides a polynomial by x^n-1
def divide_by_xnm1(poly, n):
if len(poly) <= n:
return []
return f.add_polys(poly[n:], divide_by_xnm1(poly[n:], n))
# Generate a STARK for a MIMC calculation # Generate a STARK for a MIMC calculation
def mk_mimc_proof(inp, logsteps, logprecision): def mk_mimc_proof(inp, logsteps):
start_time = time.time() start_time = time.time()
assert logsteps < logprecision <= 32 assert logsteps <= 29
logprecision = logsteps + 3
steps = 2**logsteps steps = 2**logsteps
precision = 2**logprecision precision = 2**logprecision
# Root of unity such that x^precision=1 # Root of unity such that x^precision=1
root_of_unity = pow(7, (modulus-1)//precision, modulus) root_of_unity = f.exp(7, (modulus-1)//precision)
# Root of unity such that x^skips=1 # Root of unity such that x^skips=1
skips = precision // steps skips = precision // steps
subroot = pow(root_of_unity, skips) subroot = f.exp(root_of_unity, skips)
# Powers of the root of unity, our computational trace will be # Powers of the root of unity, our computational trace will be
# along the sequence of roots of unity # along the sequence of roots of unity
@ -82,30 +63,31 @@ def mk_mimc_proof(inp, logsteps, logprecision):
# Create the composed polynomial such that # Create the composed polynomial such that
# C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x) # C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x)
term1 = multiply_base(values_polynomial, subroot) term1 = f.multiply_base(values_polynomial, subroot)
p_evaluations = fft(values_polynomial, modulus, root_of_unity) p_evaluations = fft(values_polynomial, modulus, root_of_unity)
term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2] term2 = fft([f.exp(x, 3) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2]
c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial) c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial)
print('Computed C(P, K) polynomial') print('Computed C(P, K) polynomial')
# Compute D(x) = C(P(x), P(rx), K(x)) / Z(x) # Compute D(x) = C(P(x), P(rx), K(x)) / Z(x)
# Z(x) = (x^steps - 1) / (x - x_atlast_step) # Z(x) = (x^steps - 1) / (x - x_atlast_step)
d = divide_by_xnm1(f.mul_polys(c_of_values, d = f.divide_by_xnm1(f.mul_polys(c_of_values,
[-last_step_position, 1]), [-last_step_position, 1]),
steps) steps)
# Consistency check # Consistency check
assert (f.eval_poly_at(d, 90833) * # assert (f.eval_poly_at(d, 90833) *
(pow(90833, steps, modulus) - 1) * # (f.exp(90833, steps) - 1) *
f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) - # f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) -
f.eval_poly_at(c_of_values, 90833)) % modulus == 0 # f.eval_poly_at(c_of_values, 90833)) % modulus == 0
print('Computed D polynomial') print('Computed D polynomial')
# Compute interpolant of ((1, input), (x_atlast_step, output)) # Compute interpolant of ((1, input), (x_atlast_step, output))
interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus) interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
quotient = f.mul_polys([-1, 1], [-last_step_position, 1]) quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient) b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient)
# Consistency check # Consistency check
assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == f.eval_poly_at(values_polynomial, 7045) # assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == \
# f.eval_poly_at(values_polynomial, 7045)
print('Computed B polynomial') print('Computed B polynomial')
# Evaluate B, D and K across the entire subgroup # Evaluate B, D and K across the entire subgroup
@ -129,12 +111,18 @@ def mk_mimc_proof(inp, logsteps, logprecision):
k3 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x03'), 'big') k3 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x03'), 'big')
k4 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x04'), 'big') k4 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x04'), 'big')
lincomb = f.add_polys(f.add_polys(d, # Compute the linear combination. We don't even both calculating it in
f.mul_by_const(values_polynomial, k1) + f.mul_by_const(values_polynomial, k2)), # coefficient form; we just compute the evaluations
f.mul_by_const(b, k3) + [0, 0] + f.mul_by_const(b, k4) + [0,0]) root_of_unity_to_the_steps = f.exp(root_of_unity, steps)
l_evaluations = fft(lincomb, modulus, root_of_unity) powers = [1]
l_mtree = merkelize(l_evaluations) for i in range(1, precision):
powers.append(powers[-1] * root_of_unity_to_the_steps % modulus)
l_evaluations = [(d_evaluations[i] +
p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] +
b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus
for i in range(precision)]
l_mtree = merkelize(l_evaluations)
print('Computed random linear combination') print('Computed random linear combination')
# Do some spot checks of the Merkle tree at pseudo-random coordinates # Do some spot checks of the Merkle tree at pseudo-random coordinates
@ -158,20 +146,21 @@ def mk_mimc_proof(inp, logsteps, logprecision):
b_mtree[1], b_mtree[1],
l_mtree[1], l_mtree[1],
branches, branches,
prove_low_degree(lincomb, root_of_unity, l_evaluations, steps * 2, modulus)] prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus)]
print("STARK computed in %.4f sec" % (time.time() - start_time)) print("STARK computed in %.4f sec" % (time.time() - start_time))
return o return o
# Verifies a STARK # Verifies a STARK
def verify_mimc_proof(inp, logsteps, logprecision, output, proof): def verify_mimc_proof(inp, logsteps, output, proof):
p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof
start_time = time.time() start_time = time.time()
logprecision = logsteps + 3
steps = 2**logsteps steps = 2**logsteps
precision = 2**logprecision precision = 2**logprecision
# Get (steps)th root of unity # Get (steps)th root of unity
root_of_unity = pow(7, (modulus-1)//precision, modulus) root_of_unity = f.exp(7, (modulus-1)//precision)
skips = precision // steps skips = precision // steps
# Verifies the low-degree proofs # Verifies the low-degree proofs
@ -184,24 +173,29 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big') k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big')
samples = spot_check_security_factor // (logprecision - logsteps) samples = spot_check_security_factor // (logprecision - logsteps)
positions = get_pseudorandom_indices(l_root, precision - skips, samples) positions = get_pseudorandom_indices(l_root, precision - skips, samples)
last_step_position = pow(root_of_unity, (steps - 1) * skips, modulus) last_step_position = f.exp(root_of_unity, (steps - 1) * skips)
for i, pos in enumerate(positions): for i, pos in enumerate(positions):
# Check C(P(x)) = Z(x) * D(x) x = f.exp(root_of_unity, pos)
x = pow(root_of_unity, pos, modulus) x_to_the_steps = f.exp(x, steps)
x_to_the_steps = pow(x, steps, modulus)
p_of_x = verify_branch(p_root, pos, branches[i*6]) p_of_x = verify_branch(p_root, pos, branches[i*6])
p_of_rx = verify_branch(p_root, pos+skips, branches[i*6 + 1]) p_of_rx = verify_branch(p_root, pos+skips, branches[i*6 + 1])
d_of_x = verify_branch(d_root, pos, branches[i*6 + 2]) d_of_x = verify_branch(d_root, pos, branches[i*6 + 2])
k_of_x = verify_branch(k_root, pos, branches[i*6 + 3]) k_of_x = verify_branch(k_root, pos, branches[i*6 + 3])
b_of_x = verify_branch(b_root, pos, branches[i*6 + 4]) b_of_x = verify_branch(b_root, pos, branches[i*6 + 4])
l_of_x = verify_branch(l_root, pos, branches[i*6 + 5]) l_of_x = verify_branch(l_root, pos, branches[i*6 + 5])
zvalue = f.div(pow(x, steps, modulus) - 1, zvalue = f.div(f.exp(x, steps) - 1,
x - last_step_position) x - last_step_position)
# Check transition constraints C(P(x)) = Z(x) * D(x)
assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0 assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0
interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus) interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
quotient = f.mul_polys([-1, 1], [-last_step_position, 1]) quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
# Check boundary constraints B(x) * Q(x) + I(x) = P(x)
assert (p_of_x - b_of_x * f.eval_poly_at(quotient, x) - assert (p_of_x - b_of_x * f.eval_poly_at(quotient, x) -
f.eval_poly_at(interpolant, x)) % modulus == 0 f.eval_poly_at(interpolant, x)) % modulus == 0
# Check correctness of the linear combination
assert (l_of_x - d_of_x - assert (l_of_x - d_of_x -
k1 * p_of_x - k2 * p_of_x * x_to_the_steps - k1 * p_of_x - k2 * p_of_x * x_to_the_steps -
k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0 k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0

View File

@ -1,3 +1,5 @@
# Creates an object that includes convenience operations for numbers
# and polynomials in some prime field
class PrimeField(): class PrimeField():
def __init__(self, modulus): def __init__(self, modulus):
assert pow(2, modulus, modulus) == 2 assert pow(2, modulus, modulus) == 2
@ -12,9 +14,13 @@ class PrimeField():
def mul(self, x, y): def mul(self, x, y):
return (x*y) % self.modulus return (x*y) % self.modulus
def exp(self, x, p):
return pow(x, p, self.modulus)
# Modular inverse using the extended Euclidean algorithm
def inv(self, a): def inv(self, a):
if a == 0: if a == 0:
return 0 raise Exception("Cannot invert 0")
lm, hm = 1, 0 lm, hm = 1, 0
low, high = a % self.modulus, self.modulus low, high = a % self.modulus, self.modulus
while low > 1: while low > 1:
@ -24,14 +30,10 @@ class PrimeField():
return lm % self.modulus return lm % self.modulus
def div(self, x, y): def div(self, x, y):
if x == 0 and y == 0:
return 1
return self.mul(x, self.inv(y)) return self.mul(x, self.inv(y))
# Evaluate a polynomial at a point # Evaluate a polynomial at a point
def eval_poly_at(self, p, x): def eval_poly_at(self, p, x):
if x == 0:
return p[0]
y = 0 y = 0
power_of_x = 1 power_of_x = 1
for i, p_coeff in enumerate(p): for i, p_coeff in enumerate(p):
@ -39,7 +41,7 @@ class PrimeField():
power_of_x = (power_of_x * x) % self.modulus power_of_x = (power_of_x * x) % self.modulus
return y % self.modulus return y % self.modulus
# Build a polynomial that returns 0 at all xs # Build a polynomial that returns 0 at all specified xs
def zpoly(self, xs): def zpoly(self, xs):
root = [1] root = [1]
for x in xs: for x in xs:
@ -56,35 +58,62 @@ class PrimeField():
# y coordinate at that point and 0 at all other points provided. # y coordinate at that point and 0 at all other points provided.
# 3. Add these polynomials together. # 3. Add these polynomials together.
def lagrange_interp(self, pieces, xs): def lagrange_interp(self, xs, ys):
# Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn) # Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
root = self.zpoly(xs) root = self.zpoly(xs)
#print(root) assert len(root) == len(ys) + 1
assert len(root) == len(pieces) + 1
# print(root) # print(root)
# Generate per-value numerator polynomials, eg. for x=x2, # Generate per-value numerator polynomials, eg. for x=x2,
# (x - x1) * (x - x3) * ... * (x - xn), by dividing the master # (x - x1) * (x - x3) * ... * (x - xn), by dividing the master
# polynomial back by each x coordinate # polynomial back by each x coordinate
nums = [] nums = [self.div_polys(root, [-x, 1]) for x in xs]
for x in xs:
output = [0] * (len(root) - 2) + [1]
for j in range(len(root) - 2, 0, -1):
output[j-1] = root[j] + output[j] * x
assert len(output) == len(pieces)
nums.append(output)
#print(nums)
# Generate denominators by evaluating numerator polys at each x # Generate denominators by evaluating numerator polys at each x
denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))] denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
# Generate output polynomial, which is the sum of the per-value numerator # Generate output polynomial, which is the sum of the per-value numerator
# polynomials rescaled to have the right y values # polynomials rescaled to have the right y values
b = [0 for p in pieces] b = [0 for y in ys]
for i in range(len(xs)): for i in range(len(xs)):
yslice = self.div(pieces[i], denoms[i]) yslice = self.div(ys[i], denoms[i])
for j in range(len(pieces)): for j in range(len(ys)):
if nums[i][j] and pieces[i]: if nums[i][j] and ys[i]:
b[j] += nums[i][j] * yslice b[j] += nums[i][j] * yslice
return [x % self.modulus for x in b] return [x % self.modulus for x in b]
# Optimized version of the above restricted to deg-4 polynomials
def lagrange_interp_4(self, xs, ys):
x01, x02, x03, x12, x13, x23 = \
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
m = self.modulus
eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
e0 = self.eval_poly_at(eq0, xs[0])
e1 = self.eval_poly_at(eq1, xs[1])
e2 = self.eval_poly_at(eq2, xs[2])
e3 = self.eval_poly_at(eq3, xs[3])
e01 = e0 * e1
e23 = e2 * e3
invall = self.inv(e01 * e23)
inv_y0 = ys[0] * invall * e1 * e23 % m
inv_y1 = ys[1] * invall * e0 * e23 % m
inv_y2 = ys[2] * invall * e01 * e3 % m
inv_y3 = ys[3] * invall * e01 * e2 % m
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)]
# Optimized version of the above restricted to deg-2 polynomials
def lagrange_interp_2(self, xs, ys):
m = self.modulus
eq0 = [-xs[1] % m, 1]
eq1 = [-xs[0] % m, 1]
e0 = self.eval_poly_at(eq0, xs[0])
e1 = self.eval_poly_at(eq1, xs[1])
invall = self.inv(e0 * e1)
inv_y0 = ys[0] * invall * e1
inv_y1 = ys[1] * invall * e0
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]
# Arithmetic for polynomials
def add_polys(self, a, b): def add_polys(self, a, b):
return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0)) return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0))
% self.modulus for i in range(max(len(a), len(b)))] % self.modulus for i in range(max(len(a), len(b)))]
@ -119,6 +148,13 @@ class PrimeField():
diff -= 1 diff -= 1
return [x % self.modulus for x in o] return [x % self.modulus for x in o]
# Divides a polynomial by x^n-1
def divide_by_xnm1(self, poly, n):
if len(poly) <= n:
return []
return self.add_polys(poly[n:], self.divide_by_xnm1(poly[n:], n))
# Returns P(x) = A(B(x))
def compose_polys(self, a, b): def compose_polys(self, a, b):
o = [] o = []
p = [1] p = [1]
@ -127,3 +163,12 @@ class PrimeField():
p = self.mul_polys(p, b) p = self.mul_polys(p, b)
return o return o
# Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x)
# Equivalent to compose_polys(poly, [0, fac])
def multiply_base(self, poly, fac):
o = []
r = 1
for p in poly:
o.append(p * r % self.modulus)
r = r * fac % self.modulus
return o

View File

@ -12,25 +12,36 @@ def test_merkletree():
def test_fri(): def test_fri():
# Pure FRI tests # Pure FRI tests
poly = list(range(512)) poly = list(range(4096))
root_of_unity = pow(7, (modulus-1)//1024, modulus) root_of_unity = pow(7, (modulus-1)//16384, modulus)
evaluations = fft(poly, modulus, root_of_unity) evaluations = fft(poly, modulus, root_of_unity)
proof = prove_low_degree(poly, root_of_unity, evaluations, 512, modulus) proof = prove_low_degree(evaluations, root_of_unity, 4096, modulus)
print("Approx proof length: %d" % bin_length(compress_fri(proof))) print("Approx proof length: %d" % bin_length(compress_fri(proof)))
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512, modulus) assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 4096, modulus)
try:
fakedata = [x if pow(3, i, 4096) > 400 else 39 for x, i in enumerate(evaluations)]
proof2 = prove_low_degree(fakedata, root_of_unity, 4096, modulus)
assert verify_low_degree_proof(merkelize(fakedata)[1], root_of_unity, proof, 4096, modulus)
raise Exception("Fake data passed FRI")
except:
pass
try:
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 2048, modulus)
raise Exception("Fake data passed FRI")
except:
pass
def test_stark(): def test_stark():
INPUT = 3 INPUT = 3
LOGSTEPS = 13 LOGSTEPS = 13
LOGPRECISION = 16
# Full STARK test # Full STARK test
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION) proof = mk_mimc_proof(INPUT, LOGSTEPS)
p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof
L1 = bin_length(compress_branches(branches)) L1 = bin_length(compress_branches(branches))
L2 = bin_length(compress_fri(fri_proof)) L2 = bin_length(compress_fri(fri_proof))
print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2)) print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2))
root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus) assert verify_mimc_proof(3, LOGSTEPS, mimc(3, LOGSTEPS), proof)
subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus)
skips = 2**(LOGPRECISION - LOGSTEPS) if __name__ == '__main__':
assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), proof) test_stark()