Improved efficiency a bunch and added boundary checks
This commit is contained in:
parent
770c0a2c78
commit
da1d723780
|
@ -16,32 +16,3 @@ def eval_poly_at(poly, x, modulus):
|
|||
p = (p * x % modulus)
|
||||
return o % modulus
|
||||
|
||||
def lagrange_interp_4(pieces, xs, modulus):
|
||||
x01, x02, x03, x12, x13, x23 = \
|
||||
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
|
||||
eq0 = [-x12 * xs[3] % modulus, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
|
||||
eq1 = [-x02 * xs[3] % modulus, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
|
||||
eq2 = [-x01 * xs[3] % modulus, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
|
||||
eq3 = [-x01 * xs[2] % modulus, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
|
||||
e0 = eval_poly_at(eq0, xs[0], modulus)
|
||||
e1 = eval_poly_at(eq1, xs[1], modulus)
|
||||
e2 = eval_poly_at(eq2, xs[2], modulus)
|
||||
e3 = eval_poly_at(eq3, xs[3], modulus)
|
||||
e01 = e0 * e1
|
||||
e23 = e2 * e3
|
||||
invall = inv(e01 * e23, modulus)
|
||||
inv_y0 = pieces[0] * invall * e1 * e23 % modulus
|
||||
inv_y1 = pieces[1] * invall * e0 * e23 % modulus
|
||||
inv_y2 = pieces[2] * invall * e01 * e3 % modulus
|
||||
inv_y3 = pieces[3] * invall * e01 * e2 % modulus
|
||||
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % modulus for i in range(4)]
|
||||
|
||||
def lagrange_interp_2(pieces, xs, modulus):
|
||||
eq0 = [-xs[1] % modulus, 1]
|
||||
eq1 = [-xs[0] % modulus, 1]
|
||||
e0 = eval_poly_at(eq0, xs[0], modulus)
|
||||
e1 = eval_poly_at(eq1, xs[1], modulus)
|
||||
invall = inv(e0 * e1, modulus)
|
||||
inv_y0 = pieces[0] * invall * e1
|
||||
inv_y1 = pieces[1] * invall * e0
|
||||
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % modulus for i in range(2)]
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
from .poly_utils import PrimeField
|
|
@ -1,198 +0,0 @@
|
|||
modulus_poly = [1, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 1, 0, 1, 0, 0, 1,
|
||||
1]
|
||||
modulus_poly_as_int = sum([(v << i) for i, v in enumerate(modulus_poly)])
|
||||
degree = len(modulus_poly) - 1
|
||||
|
||||
two_to_the_degree = 2**degree
|
||||
two_to_the_degree_m1 = 2**degree - 1
|
||||
|
||||
def galoistpl(a):
|
||||
# 2 is not a primitive root, so we have to use 3 as our logarithm base
|
||||
if a * 2 < two_to_the_degree:
|
||||
return (a * 2) ^ a
|
||||
else:
|
||||
return (a * 2) ^ a ^ modulus_poly_as_int
|
||||
|
||||
# Precomputing a log table for increased speed of addition and multiplication
|
||||
glogtable = [0] * (two_to_the_degree)
|
||||
gexptable = []
|
||||
v = 1
|
||||
for i in range(two_to_the_degree_m1):
|
||||
glogtable[v] = i
|
||||
gexptable.append(v)
|
||||
v = galoistpl(v)
|
||||
|
||||
gexptable += gexptable + gexptable
|
||||
|
||||
# Add two values in the Galois field
|
||||
def galois_add(x, y):
|
||||
return x ^ y
|
||||
|
||||
# In binary fields, addition and subtraction are the same thing
|
||||
galois_sub = galois_add
|
||||
|
||||
# Multiply two values in the Galois field
|
||||
def galois_mul(x, y):
|
||||
return 0 if x*y == 0 else gexptable[glogtable[x] + glogtable[y]]
|
||||
|
||||
# Divide two values in the Galois field
|
||||
def galois_div(x, y):
|
||||
return 0 if x == 0 else gexptable[(glogtable[x] - glogtable[y]) % two_to_the_degree_m1]
|
||||
|
||||
# Evaluate a polynomial at a point
|
||||
def eval_poly_at(p, x):
|
||||
if x == 0:
|
||||
return p[0]
|
||||
y = 0
|
||||
logx = glogtable[x]
|
||||
for i, p_coeff in enumerate(p):
|
||||
if p_coeff:
|
||||
# Add x**i * coeff
|
||||
y ^= gexptable[(logx * i + glogtable[p_coeff]) % two_to_the_degree_m1]
|
||||
return y
|
||||
|
||||
|
||||
# Given p+1 y values and x values with no errors, recovers the original
|
||||
# p+1 degree polynomial.
|
||||
# Lagrange interpolation works roughly in the following way.
|
||||
# 1. Suppose you have a set of points, eg. x = [1, 2, 3], y = [2, 5, 10]
|
||||
# 2. For each x, generate a polynomial which equals its corresponding
|
||||
# y coordinate at that point and 0 at all other points provided.
|
||||
# 3. Add these polynomials together.
|
||||
|
||||
def lagrange_interp(pieces, xs):
|
||||
# Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
|
||||
root = mk_root_2(xs)
|
||||
#print(root)
|
||||
assert len(root) == len(pieces) + 1
|
||||
# print(root)
|
||||
# Generate the derivative
|
||||
d = derivative(root)
|
||||
# Generate denominators by evaluating numerator polys at each x
|
||||
denoms = multi_eval_2(d, xs)
|
||||
print(denoms)
|
||||
# denoms = [eval_poly_at(d, xs[i]) for i in range(len(xs))]
|
||||
# Generate output polynomial, which is the sum of the per-value numerator
|
||||
# polynomials rescaled to have the right y values
|
||||
factors = [galois_div(p, d) for p, d in zip(pieces, denoms)]
|
||||
o = multi_root_derive(xs, factors)
|
||||
# print(o)
|
||||
return o
|
||||
|
||||
def multi_root_derive(xs, muls):
|
||||
if len(xs) == 1:
|
||||
return [muls[0]]
|
||||
R1 = mk_root_2(xs[:len(xs) // 2])
|
||||
R2 = mk_root_2(xs[len(xs) // 2:])
|
||||
x1 = karatsuba_mul(R1, multi_root_derive(xs[len(xs) // 2:], muls[len(muls) // 2:]) + [0])
|
||||
x2 = karatsuba_mul(R2, multi_root_derive(xs[:len(xs) // 2], muls[:len(muls) // 2]) + [0])
|
||||
o = [v1 ^ v2 for v1, v2 in zip(x1, x2)][:len(xs)]
|
||||
# print(len(R1), len(x1), len(xs), len(o))
|
||||
return o
|
||||
|
||||
def multi_root_derive_1(xs, muls):
|
||||
o = [0] * len(xs)
|
||||
for i in range(len(xs)):
|
||||
_xs = xs[:i] + xs[(i+1):]
|
||||
root = mk_root_2(_xs)
|
||||
for j in range(len(root)):
|
||||
o[j] ^= galois_mul(root[j], muls[i])
|
||||
return o
|
||||
|
||||
a = 124
|
||||
b = 8932
|
||||
c = 12415
|
||||
|
||||
assert galois_mul(galois_add(a, b), c) == galois_add(galois_mul(a, c), galois_mul(b, c))
|
||||
|
||||
def karatsuba_mul(p1, p2):
|
||||
L = len(p1)
|
||||
# assert L == len(p2)
|
||||
if L <= 16:
|
||||
o = [0] * (L * 2)
|
||||
for i, v1 in enumerate(p1):
|
||||
for j, v2 in enumerate(p2):
|
||||
if v1 and v2:
|
||||
o[i + j] ^= gexptable[glogtable[v1] + glogtable[v2]]
|
||||
return o
|
||||
if L % 2:
|
||||
p1 = p1 + [0]
|
||||
p2 = p2 + [0]
|
||||
L += 1
|
||||
halflen = L // 2
|
||||
low1 = p1[:halflen]
|
||||
high1 = p1[halflen:]
|
||||
sum1 = [l ^ h for l, h in zip(low1, high1)]
|
||||
low2 = p2[:halflen]
|
||||
high2 = p2[halflen:]
|
||||
sum2 = [l ^ h for l, h in zip(low2, high2)]
|
||||
z2 = karatsuba_mul(high1, high2)
|
||||
z0 = karatsuba_mul(low1, low2)
|
||||
z1 = [m ^ _z0 ^ _z2 for m, _z0, _z2 in zip(karatsuba_mul(sum1, sum2), z0, z2)]
|
||||
o = z0[:halflen] + \
|
||||
[a ^ b for a, b in zip(z0[halflen:], z1[:halflen])] + \
|
||||
[a ^ b for a, b in zip(z2[:halflen], z1[halflen:])] + \
|
||||
z2[halflen:]
|
||||
return o
|
||||
|
||||
def mk_root_1(xs):
|
||||
root = [1]
|
||||
for x in xs:
|
||||
logx = glogtable[x]
|
||||
root.insert(0, 0)
|
||||
for j in range(len(root)-1):
|
||||
if root[j+1] and x:
|
||||
root[j] ^= gexptable[glogtable[root[j+1]] + logx]
|
||||
return root
|
||||
|
||||
def mk_root_2(xs):
|
||||
if len(xs) >= 128:
|
||||
return karatsuba_mul(mk_root_2(xs[:len(xs) // 2]), mk_root_2(xs[len(xs) // 2:]))[:len(xs) + 1]
|
||||
root = [1]
|
||||
for x in xs:
|
||||
logx = glogtable[x]
|
||||
root.insert(0, 0)
|
||||
for j in range(len(root)-1):
|
||||
if root[j+1] and x:
|
||||
root[j] ^= gexptable[glogtable[root[j+1]] + logx]
|
||||
return root
|
||||
|
||||
def derivative(root):
|
||||
return [0 if i % 2 else r for i, r in enumerate(root[1:])]
|
||||
|
||||
# Credit to http://people.csail.mit.edu/madhu/ST12/scribe/lect06.pdf for the algorithm
|
||||
def xn_mod_poly(p):
|
||||
if len(p) == 1:
|
||||
return [galois_div(1, p[0])]
|
||||
halflen = len(p) // 2
|
||||
lowinv = xn_mod_poly(p[:halflen])
|
||||
submod_high = karatsuba_mul(lowinv, p[:halflen])[halflen:]
|
||||
med = karatsuba_mul(p[halflen:], lowinv)[:halflen]
|
||||
med_plus_high = [x ^ y for x, y in zip(med, submod_high)]
|
||||
highinv = karatsuba_mul(med_plus_high, lowinv)
|
||||
o = (lowinv + highinv)[:len(p)]
|
||||
print(halflen, lowinv, submod_high, med, highinv)
|
||||
# assert karatsuba_mul(o, p)[:len(p)] == [1] + [0] * (len(p) - 1)
|
||||
return o
|
||||
|
||||
def mod(a, b):
|
||||
assert len(a) == 2 * (len(b) - 1)
|
||||
L = len(b)
|
||||
inv_rev_b = xn_mod_poly(b[::-1] + [0] * (len(a) - L))[:L]
|
||||
quot = karatsuba_mul(inv_rev_b, a[::-1][:L])[:L-1][::-1]
|
||||
subt = karatsuba_mul(b, quot + [0])[:-1]
|
||||
o = [x ^ y for x, y in zip(a[:L-1], subt[:L-1])]
|
||||
# assert [x^y for x, y in zip(karatsuba_mul(quot + [0], b), o)] == a
|
||||
return o
|
||||
|
||||
def multi_eval_1(poly, xs):
|
||||
return [eval_poly_at(poly, x) for x in xs]
|
||||
|
||||
def multi_eval_2(poly, xs):
|
||||
if len(xs) <= 1024:
|
||||
return [eval_poly_at(poly, x) for x in xs]
|
||||
halflen = len(xs) // 2
|
||||
return multi_eval_2(mod(poly, mk_root_2(xs[:halflen])), xs[:halflen]) + \
|
||||
multi_eval_2(mod(poly, mk_root_2(xs[halflen:])), xs[halflen:])
|
||||
# [eval_poly_at(poly, xs[-2]), eval_poly_at(poly, xs[-1])]
|
|
@ -1,24 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
|
||||
with open('README.md') as f:
|
||||
readme = f.read()
|
||||
|
||||
with open('LICENSE') as f:
|
||||
license = f.read()
|
||||
|
||||
setup(
|
||||
name='ecpoly',
|
||||
version='1.0.0',
|
||||
description='Erasure code utilities for prime fields',
|
||||
long_description=readme,
|
||||
author='Vitalik Buterin',
|
||||
author_email='',
|
||||
url='https://github.com/ethereum/research/tree/master/erasure_code/ecpoly',
|
||||
license=license,
|
||||
packages=find_packages(exclude=('tests', 'docs')),
|
||||
install_requires=[
|
||||
],
|
||||
)
|
|
@ -1,15 +0,0 @@
|
|||
from ecpoly import PrimeField
|
||||
|
||||
f = PrimeField(65537)
|
||||
|
||||
k1 = list(range(10))
|
||||
k2 = list(range(100, 200))
|
||||
k3 = f.mul_polys(k1, k2)
|
||||
assert f.div_polys(k3, k1) == k2
|
||||
assert f.div_polys(k3, k2) == k1
|
||||
assert (f.eval_poly_at(k1, 9999) * f.eval_poly_at(k2, 9999) -
|
||||
f.eval_poly_at(k3, 9999)) % f.modulus == 0
|
||||
k4 = f.compose_polys(k1, k2)
|
||||
assert f.eval_poly_at(k4, 9998) == f.eval_poly_at(k1, f.eval_poly_at(k2, 9998))
|
||||
|
||||
print("All passed!")
|
|
@ -1,10 +1,17 @@
|
|||
from better_lagrange import lagrange_interp_4, eval_poly_at
|
||||
from merkle_tree import merkelize, mk_branch, verify_branch
|
||||
from utils import get_power_cycle, get_pseudorandom_indices
|
||||
from fft import fft
|
||||
from poly_utils import PrimeField
|
||||
|
||||
# Generate an FRI proof
|
||||
def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus):
|
||||
# Generate an FRI proof that the polynomial that has the specified
|
||||
# values at successive powers of the specified root of unity has a
|
||||
# degree lower than maxdeg_plus_1
|
||||
#
|
||||
# We use maxdeg+1 instead of maxdeg because it's more mathematically
|
||||
# convenient in this case.
|
||||
|
||||
def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
|
||||
f = PrimeField(modulus)
|
||||
print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1))
|
||||
|
||||
# If the degree we are checking for is less than or equal to 32,
|
||||
|
@ -24,18 +31,17 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus):
|
|||
# Select a pseudo-random x coordinate
|
||||
special_x = int.from_bytes(m[1], 'big') % modulus
|
||||
|
||||
# Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
|
||||
# at that x coordinate
|
||||
# Calculate the "column" at that x coordinate
|
||||
# (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
|
||||
# We calculate the column by Lagrange-interpolating the row, and not
|
||||
# directly, as this is more efficient
|
||||
# directly from the polynomial, as this is more efficient
|
||||
column = []
|
||||
for i in range(len(xs)//4):
|
||||
x_poly = lagrange_interp_4(
|
||||
[values[i+len(values)*j//4] for j in range(4)],
|
||||
x_poly = f.lagrange_interp_4(
|
||||
[xs[i+len(xs)*j//4] for j in range(4)],
|
||||
modulus
|
||||
[values[i+len(values)*j//4] for j in range(4)],
|
||||
)
|
||||
column.append(eval_poly_at(x_poly, special_x, modulus))
|
||||
column.append(f.eval_poly_at(x_poly, special_x))
|
||||
m2 = merkelize(column)
|
||||
|
||||
# Pseudo-randomly select y indices to sample
|
||||
|
@ -44,23 +50,24 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus):
|
|||
# Compute the Merkle branches for the values in the polynomial and the column
|
||||
branches = []
|
||||
for y in ys:
|
||||
branches.append([mk_branch(m2, y)] + [mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)])
|
||||
branches.append([mk_branch(m2, y)] +
|
||||
[mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)])
|
||||
|
||||
# This component of the proof
|
||||
o = [m2[1], branches]
|
||||
|
||||
# In the next iteration of the proof, we'll work with smaller roots of unity
|
||||
sub_xs = [xs[i] for i in range(0, len(xs), 4)]
|
||||
|
||||
# Interpolate the polynomial for the column
|
||||
ypoly = fft(column[:len(sub_xs)], modulus,
|
||||
pow(root_of_unity, 4, modulus), inv=True)
|
||||
# sub_xs = [xs[i] for i in range(0, len(xs), 4)]
|
||||
# ypoly = fft(column[:len(sub_xs)], modulus,
|
||||
# f.exp(root_of_unity, 4), inv=True)
|
||||
|
||||
# Recurse...
|
||||
return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4, modulus)
|
||||
return [o] + prove_low_degree(column, f.exp(root_of_unity, 4),
|
||||
maxdeg_plus_1 // 4, modulus)
|
||||
|
||||
# Verify an FRI proof
|
||||
def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus):
|
||||
f = PrimeField(modulus)
|
||||
|
||||
# Calculate which root of unity we're working with
|
||||
testval = root_of_unity
|
||||
|
@ -69,10 +76,11 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
|
|||
roudeg *= 2
|
||||
testval = (testval * testval) % modulus
|
||||
|
||||
# Powers of the given root of unity 1, p, p**2, p**3 such that p**4 = 1
|
||||
quartic_roots_of_unity = [1,
|
||||
pow(root_of_unity, roudeg // 4, modulus),
|
||||
pow(root_of_unity, roudeg // 2, modulus),
|
||||
pow(root_of_unity, roudeg * 3 // 4, modulus)]
|
||||
f.exp(root_of_unity, roudeg // 4),
|
||||
f.exp(root_of_unity, roudeg // 2),
|
||||
f.exp(root_of_unity, roudeg * 3 // 4)]
|
||||
|
||||
# Verify the recursive components of the proof
|
||||
for prf in proof[:-1]:
|
||||
|
@ -86,27 +94,28 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
|
|||
ys = get_pseudorandom_indices(root2, roudeg // 4, 40)
|
||||
|
||||
|
||||
# Verify for each selected y coordinate that the four points from the polynomial
|
||||
# and the one point from the column that are on that y coordinate are on the same
|
||||
# deg < 4 polynomial
|
||||
# Verify for each selected y coordinate that the four points from the
|
||||
# polynomial and the one point from the column that are on that y
|
||||
# coordinate are on the same deg < 4 polynomial
|
||||
for i, y in enumerate(ys):
|
||||
# The x coordinates from the polynomial
|
||||
x1 = pow(root_of_unity, y, modulus)
|
||||
x1 = f.exp(root_of_unity, y)
|
||||
xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)]
|
||||
|
||||
# The values from the polynomial
|
||||
row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])]
|
||||
row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf)
|
||||
for j, prf in zip(range(4), branches[i][1:])]
|
||||
|
||||
# Verify proof and recover the column value
|
||||
values = [verify_branch(root2, y, branches[i][0])] + row
|
||||
|
||||
# Lagrange interpolate and check deg is < 4
|
||||
p = lagrange_interp_4(row, xcoords, modulus)
|
||||
assert eval_poly_at(p, special_x, modulus) == verify_branch(root2, y, branches[i][0])
|
||||
p = f.lagrange_interp_4(xcoords, row)
|
||||
assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0])
|
||||
|
||||
# Update constants to check the next proof
|
||||
merkle_root = root2
|
||||
root_of_unity = pow(root_of_unity, 4, modulus)
|
||||
root_of_unity = f.exp(root_of_unity, 4)
|
||||
maxdeg_plus_1 //= 4
|
||||
roudeg //= 4
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
||||
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
||||
from ecpoly import PrimeField
|
||||
from better_lagrange import lagrange_interp_4, lagrange_interp_2
|
||||
from poly_utils import PrimeField
|
||||
import time
|
||||
from fft import fft
|
||||
from fri import prove_low_degree, verify_low_degree_proof
|
||||
|
@ -13,15 +12,11 @@ nonresidue = 7
|
|||
|
||||
spot_check_security_factor = 240
|
||||
|
||||
# Compute a MIMC permutation for 2**logsteps steps, using round constants
|
||||
# from the multiplicative subgroup of size 2**logprecision
|
||||
def mimc(inp, logsteps, logprecision):
|
||||
# Compute a MIMC permutation for 2**logsteps steps
|
||||
def mimc(inp, logsteps):
|
||||
start_time = time.time()
|
||||
steps = 2**logsteps
|
||||
precision = 2**logprecision
|
||||
# Get (steps)th root of unity
|
||||
subroot = pow(7, (modulus-1)//steps, modulus)
|
||||
# We use powers of 9 mod 2^256 as the ith round constant for the moment
|
||||
# We use powers of 9 mod 2^256 XORed with 1 as the ith round constant for the moment
|
||||
k = 1
|
||||
for i in range(steps-1):
|
||||
inp = (inp**3 + (k ^ 1)) % modulus
|
||||
|
@ -29,34 +24,20 @@ def mimc(inp, logsteps, logprecision):
|
|||
print("MIMC computed in %.4f sec" % (time.time() - start_time))
|
||||
return inp
|
||||
|
||||
# Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x)
|
||||
def multiply_base(poly, fac):
|
||||
o = []
|
||||
r = 1
|
||||
for p in poly:
|
||||
o.append(p * r % modulus)
|
||||
r = r * fac % modulus
|
||||
return o
|
||||
|
||||
# Divides a polynomial by x^n-1
|
||||
def divide_by_xnm1(poly, n):
|
||||
if len(poly) <= n:
|
||||
return []
|
||||
return f.add_polys(poly[n:], divide_by_xnm1(poly[n:], n))
|
||||
|
||||
# Generate a STARK for a MIMC calculation
|
||||
def mk_mimc_proof(inp, logsteps, logprecision):
|
||||
def mk_mimc_proof(inp, logsteps):
|
||||
start_time = time.time()
|
||||
assert logsteps < logprecision <= 32
|
||||
assert logsteps <= 29
|
||||
logprecision = logsteps + 3
|
||||
steps = 2**logsteps
|
||||
precision = 2**logprecision
|
||||
|
||||
# Root of unity such that x^precision=1
|
||||
root_of_unity = pow(7, (modulus-1)//precision, modulus)
|
||||
root_of_unity = f.exp(7, (modulus-1)//precision)
|
||||
|
||||
# Root of unity such that x^skips=1
|
||||
skips = precision // steps
|
||||
subroot = pow(root_of_unity, skips)
|
||||
subroot = f.exp(root_of_unity, skips)
|
||||
|
||||
# Powers of the root of unity, our computational trace will be
|
||||
# along the sequence of roots of unity
|
||||
|
@ -82,30 +63,31 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
|||
|
||||
# Create the composed polynomial such that
|
||||
# C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x)
|
||||
term1 = multiply_base(values_polynomial, subroot)
|
||||
term1 = f.multiply_base(values_polynomial, subroot)
|
||||
p_evaluations = fft(values_polynomial, modulus, root_of_unity)
|
||||
term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2]
|
||||
term2 = fft([f.exp(x, 3) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2]
|
||||
c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial)
|
||||
print('Computed C(P, K) polynomial')
|
||||
|
||||
# Compute D(x) = C(P(x), P(rx), K(x)) / Z(x)
|
||||
# Z(x) = (x^steps - 1) / (x - x_atlast_step)
|
||||
d = divide_by_xnm1(f.mul_polys(c_of_values,
|
||||
[-last_step_position, 1]),
|
||||
steps)
|
||||
d = f.divide_by_xnm1(f.mul_polys(c_of_values,
|
||||
[-last_step_position, 1]),
|
||||
steps)
|
||||
# Consistency check
|
||||
assert (f.eval_poly_at(d, 90833) *
|
||||
(pow(90833, steps, modulus) - 1) *
|
||||
f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) -
|
||||
f.eval_poly_at(c_of_values, 90833)) % modulus == 0
|
||||
# assert (f.eval_poly_at(d, 90833) *
|
||||
# (f.exp(90833, steps) - 1) *
|
||||
# f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) -
|
||||
# f.eval_poly_at(c_of_values, 90833)) % modulus == 0
|
||||
print('Computed D polynomial')
|
||||
|
||||
# Compute interpolant of ((1, input), (x_atlast_step, output))
|
||||
interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus)
|
||||
interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
|
||||
quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
|
||||
b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient)
|
||||
# Consistency check
|
||||
assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == f.eval_poly_at(values_polynomial, 7045)
|
||||
# assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == \
|
||||
# f.eval_poly_at(values_polynomial, 7045)
|
||||
print('Computed B polynomial')
|
||||
|
||||
# Evaluate B, D and K across the entire subgroup
|
||||
|
@ -129,12 +111,18 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
|||
k3 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x03'), 'big')
|
||||
k4 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x04'), 'big')
|
||||
|
||||
lincomb = f.add_polys(f.add_polys(d,
|
||||
f.mul_by_const(values_polynomial, k1) + f.mul_by_const(values_polynomial, k2)),
|
||||
f.mul_by_const(b, k3) + [0, 0] + f.mul_by_const(b, k4) + [0,0])
|
||||
l_evaluations = fft(lincomb, modulus, root_of_unity)
|
||||
l_mtree = merkelize(l_evaluations)
|
||||
# Compute the linear combination. We don't even both calculating it in
|
||||
# coefficient form; we just compute the evaluations
|
||||
root_of_unity_to_the_steps = f.exp(root_of_unity, steps)
|
||||
powers = [1]
|
||||
for i in range(1, precision):
|
||||
powers.append(powers[-1] * root_of_unity_to_the_steps % modulus)
|
||||
|
||||
l_evaluations = [(d_evaluations[i] +
|
||||
p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] +
|
||||
b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus
|
||||
for i in range(precision)]
|
||||
l_mtree = merkelize(l_evaluations)
|
||||
print('Computed random linear combination')
|
||||
|
||||
# Do some spot checks of the Merkle tree at pseudo-random coordinates
|
||||
|
@ -158,20 +146,21 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
|||
b_mtree[1],
|
||||
l_mtree[1],
|
||||
branches,
|
||||
prove_low_degree(lincomb, root_of_unity, l_evaluations, steps * 2, modulus)]
|
||||
prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus)]
|
||||
print("STARK computed in %.4f sec" % (time.time() - start_time))
|
||||
return o
|
||||
|
||||
# Verifies a STARK
|
||||
def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
|
||||
def verify_mimc_proof(inp, logsteps, output, proof):
|
||||
p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof
|
||||
start_time = time.time()
|
||||
|
||||
logprecision = logsteps + 3
|
||||
steps = 2**logsteps
|
||||
precision = 2**logprecision
|
||||
|
||||
# Get (steps)th root of unity
|
||||
root_of_unity = pow(7, (modulus-1)//precision, modulus)
|
||||
root_of_unity = f.exp(7, (modulus-1)//precision)
|
||||
skips = precision // steps
|
||||
|
||||
# Verifies the low-degree proofs
|
||||
|
@ -184,24 +173,29 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
|
|||
k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big')
|
||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||
positions = get_pseudorandom_indices(l_root, precision - skips, samples)
|
||||
last_step_position = pow(root_of_unity, (steps - 1) * skips, modulus)
|
||||
last_step_position = f.exp(root_of_unity, (steps - 1) * skips)
|
||||
for i, pos in enumerate(positions):
|
||||
# Check C(P(x)) = Z(x) * D(x)
|
||||
x = pow(root_of_unity, pos, modulus)
|
||||
x_to_the_steps = pow(x, steps, modulus)
|
||||
x = f.exp(root_of_unity, pos)
|
||||
x_to_the_steps = f.exp(x, steps)
|
||||
p_of_x = verify_branch(p_root, pos, branches[i*6])
|
||||
p_of_rx = verify_branch(p_root, pos+skips, branches[i*6 + 1])
|
||||
d_of_x = verify_branch(d_root, pos, branches[i*6 + 2])
|
||||
k_of_x = verify_branch(k_root, pos, branches[i*6 + 3])
|
||||
b_of_x = verify_branch(b_root, pos, branches[i*6 + 4])
|
||||
l_of_x = verify_branch(l_root, pos, branches[i*6 + 5])
|
||||
zvalue = f.div(pow(x, steps, modulus) - 1,
|
||||
zvalue = f.div(f.exp(x, steps) - 1,
|
||||
x - last_step_position)
|
||||
|
||||
# Check transition constraints C(P(x)) = Z(x) * D(x)
|
||||
assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0
|
||||
interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus)
|
||||
interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
|
||||
quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
|
||||
|
||||
# Check boundary constraints B(x) * Q(x) + I(x) = P(x)
|
||||
assert (p_of_x - b_of_x * f.eval_poly_at(quotient, x) -
|
||||
f.eval_poly_at(interpolant, x)) % modulus == 0
|
||||
|
||||
# Check correctness of the linear combination
|
||||
assert (l_of_x - d_of_x -
|
||||
k1 * p_of_x - k2 * p_of_x * x_to_the_steps -
|
||||
k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# Creates an object that includes convenience operations for numbers
|
||||
# and polynomials in some prime field
|
||||
class PrimeField():
|
||||
def __init__(self, modulus):
|
||||
assert pow(2, modulus, modulus) == 2
|
||||
|
@ -12,9 +14,13 @@ class PrimeField():
|
|||
def mul(self, x, y):
|
||||
return (x*y) % self.modulus
|
||||
|
||||
def exp(self, x, p):
|
||||
return pow(x, p, self.modulus)
|
||||
|
||||
# Modular inverse using the extended Euclidean algorithm
|
||||
def inv(self, a):
|
||||
if a == 0:
|
||||
return 0
|
||||
raise Exception("Cannot invert 0")
|
||||
lm, hm = 1, 0
|
||||
low, high = a % self.modulus, self.modulus
|
||||
while low > 1:
|
||||
|
@ -24,14 +30,10 @@ class PrimeField():
|
|||
return lm % self.modulus
|
||||
|
||||
def div(self, x, y):
|
||||
if x == 0 and y == 0:
|
||||
return 1
|
||||
return self.mul(x, self.inv(y))
|
||||
|
||||
# Evaluate a polynomial at a point
|
||||
def eval_poly_at(self, p, x):
|
||||
if x == 0:
|
||||
return p[0]
|
||||
y = 0
|
||||
power_of_x = 1
|
||||
for i, p_coeff in enumerate(p):
|
||||
|
@ -39,7 +41,7 @@ class PrimeField():
|
|||
power_of_x = (power_of_x * x) % self.modulus
|
||||
return y % self.modulus
|
||||
|
||||
# Build a polynomial that returns 0 at all xs
|
||||
# Build a polynomial that returns 0 at all specified xs
|
||||
def zpoly(self, xs):
|
||||
root = [1]
|
||||
for x in xs:
|
||||
|
@ -56,35 +58,62 @@ class PrimeField():
|
|||
# y coordinate at that point and 0 at all other points provided.
|
||||
# 3. Add these polynomials together.
|
||||
|
||||
def lagrange_interp(self, pieces, xs):
|
||||
def lagrange_interp(self, xs, ys):
|
||||
# Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
|
||||
root = self.zpoly(xs)
|
||||
#print(root)
|
||||
assert len(root) == len(pieces) + 1
|
||||
assert len(root) == len(ys) + 1
|
||||
# print(root)
|
||||
# Generate per-value numerator polynomials, eg. for x=x2,
|
||||
# (x - x1) * (x - x3) * ... * (x - xn), by dividing the master
|
||||
# polynomial back by each x coordinate
|
||||
nums = []
|
||||
for x in xs:
|
||||
output = [0] * (len(root) - 2) + [1]
|
||||
for j in range(len(root) - 2, 0, -1):
|
||||
output[j-1] = root[j] + output[j] * x
|
||||
assert len(output) == len(pieces)
|
||||
nums.append(output)
|
||||
#print(nums)
|
||||
nums = [self.div_polys(root, [-x, 1]) for x in xs]
|
||||
# Generate denominators by evaluating numerator polys at each x
|
||||
denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
|
||||
# Generate output polynomial, which is the sum of the per-value numerator
|
||||
# polynomials rescaled to have the right y values
|
||||
b = [0 for p in pieces]
|
||||
b = [0 for y in ys]
|
||||
for i in range(len(xs)):
|
||||
yslice = self.div(pieces[i], denoms[i])
|
||||
for j in range(len(pieces)):
|
||||
if nums[i][j] and pieces[i]:
|
||||
yslice = self.div(ys[i], denoms[i])
|
||||
for j in range(len(ys)):
|
||||
if nums[i][j] and ys[i]:
|
||||
b[j] += nums[i][j] * yslice
|
||||
return [x % self.modulus for x in b]
|
||||
|
||||
# Optimized version of the above restricted to deg-4 polynomials
|
||||
def lagrange_interp_4(self, xs, ys):
|
||||
x01, x02, x03, x12, x13, x23 = \
|
||||
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
|
||||
m = self.modulus
|
||||
eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
|
||||
eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
|
||||
eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
|
||||
eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
|
||||
e0 = self.eval_poly_at(eq0, xs[0])
|
||||
e1 = self.eval_poly_at(eq1, xs[1])
|
||||
e2 = self.eval_poly_at(eq2, xs[2])
|
||||
e3 = self.eval_poly_at(eq3, xs[3])
|
||||
e01 = e0 * e1
|
||||
e23 = e2 * e3
|
||||
invall = self.inv(e01 * e23)
|
||||
inv_y0 = ys[0] * invall * e1 * e23 % m
|
||||
inv_y1 = ys[1] * invall * e0 * e23 % m
|
||||
inv_y2 = ys[2] * invall * e01 * e3 % m
|
||||
inv_y3 = ys[3] * invall * e01 * e2 % m
|
||||
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)]
|
||||
|
||||
# Optimized version of the above restricted to deg-2 polynomials
|
||||
def lagrange_interp_2(self, xs, ys):
|
||||
m = self.modulus
|
||||
eq0 = [-xs[1] % m, 1]
|
||||
eq1 = [-xs[0] % m, 1]
|
||||
e0 = self.eval_poly_at(eq0, xs[0])
|
||||
e1 = self.eval_poly_at(eq1, xs[1])
|
||||
invall = self.inv(e0 * e1)
|
||||
inv_y0 = ys[0] * invall * e1
|
||||
inv_y1 = ys[1] * invall * e0
|
||||
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]
|
||||
|
||||
# Arithmetic for polynomials
|
||||
def add_polys(self, a, b):
|
||||
return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0))
|
||||
% self.modulus for i in range(max(len(a), len(b)))]
|
||||
|
@ -118,7 +147,14 @@ class PrimeField():
|
|||
apos -= 1
|
||||
diff -= 1
|
||||
return [x % self.modulus for x in o]
|
||||
|
||||
# Divides a polynomial by x^n-1
|
||||
def divide_by_xnm1(self, poly, n):
|
||||
if len(poly) <= n:
|
||||
return []
|
||||
return self.add_polys(poly[n:], self.divide_by_xnm1(poly[n:], n))
|
||||
|
||||
# Returns P(x) = A(B(x))
|
||||
def compose_polys(self, a, b):
|
||||
o = []
|
||||
p = [1]
|
||||
|
@ -126,4 +162,13 @@ class PrimeField():
|
|||
o = self.add_polys(o, self.mul_by_const(p, c))
|
||||
p = self.mul_polys(p, b)
|
||||
return o
|
||||
|
||||
|
||||
# Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x)
|
||||
# Equivalent to compose_polys(poly, [0, fac])
|
||||
def multiply_base(self, poly, fac):
|
||||
o = []
|
||||
r = 1
|
||||
for p in poly:
|
||||
o.append(p * r % self.modulus)
|
||||
r = r * fac % self.modulus
|
||||
return o
|
|
@ -12,25 +12,36 @@ def test_merkletree():
|
|||
|
||||
def test_fri():
|
||||
# Pure FRI tests
|
||||
poly = list(range(512))
|
||||
root_of_unity = pow(7, (modulus-1)//1024, modulus)
|
||||
poly = list(range(4096))
|
||||
root_of_unity = pow(7, (modulus-1)//16384, modulus)
|
||||
evaluations = fft(poly, modulus, root_of_unity)
|
||||
proof = prove_low_degree(poly, root_of_unity, evaluations, 512, modulus)
|
||||
proof = prove_low_degree(evaluations, root_of_unity, 4096, modulus)
|
||||
print("Approx proof length: %d" % bin_length(compress_fri(proof)))
|
||||
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512, modulus)
|
||||
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 4096, modulus)
|
||||
|
||||
try:
|
||||
fakedata = [x if pow(3, i, 4096) > 400 else 39 for x, i in enumerate(evaluations)]
|
||||
proof2 = prove_low_degree(fakedata, root_of_unity, 4096, modulus)
|
||||
assert verify_low_degree_proof(merkelize(fakedata)[1], root_of_unity, proof, 4096, modulus)
|
||||
raise Exception("Fake data passed FRI")
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 2048, modulus)
|
||||
raise Exception("Fake data passed FRI")
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_stark():
|
||||
INPUT = 3
|
||||
LOGSTEPS = 13
|
||||
LOGPRECISION = 16
|
||||
|
||||
# Full STARK test
|
||||
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
|
||||
proof = mk_mimc_proof(INPUT, LOGSTEPS)
|
||||
p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof
|
||||
L1 = bin_length(compress_branches(branches))
|
||||
L2 = bin_length(compress_fri(fri_proof))
|
||||
print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2))
|
||||
root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus)
|
||||
subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus)
|
||||
skips = 2**(LOGPRECISION - LOGSTEPS)
|
||||
assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), proof)
|
||||
assert verify_mimc_proof(3, LOGSTEPS, mimc(3, LOGSTEPS), proof)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_stark()
|
||||
|
|
Loading…
Reference in New Issue