Improved efficiency a bunch and added boundary checks
This commit is contained in:
parent
770c0a2c78
commit
da1d723780
|
@ -16,32 +16,3 @@ def eval_poly_at(poly, x, modulus):
|
||||||
p = (p * x % modulus)
|
p = (p * x % modulus)
|
||||||
return o % modulus
|
return o % modulus
|
||||||
|
|
||||||
def lagrange_interp_4(pieces, xs, modulus):
|
|
||||||
x01, x02, x03, x12, x13, x23 = \
|
|
||||||
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
|
|
||||||
eq0 = [-x12 * xs[3] % modulus, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
|
|
||||||
eq1 = [-x02 * xs[3] % modulus, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
|
|
||||||
eq2 = [-x01 * xs[3] % modulus, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
|
|
||||||
eq3 = [-x01 * xs[2] % modulus, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
|
|
||||||
e0 = eval_poly_at(eq0, xs[0], modulus)
|
|
||||||
e1 = eval_poly_at(eq1, xs[1], modulus)
|
|
||||||
e2 = eval_poly_at(eq2, xs[2], modulus)
|
|
||||||
e3 = eval_poly_at(eq3, xs[3], modulus)
|
|
||||||
e01 = e0 * e1
|
|
||||||
e23 = e2 * e3
|
|
||||||
invall = inv(e01 * e23, modulus)
|
|
||||||
inv_y0 = pieces[0] * invall * e1 * e23 % modulus
|
|
||||||
inv_y1 = pieces[1] * invall * e0 * e23 % modulus
|
|
||||||
inv_y2 = pieces[2] * invall * e01 * e3 % modulus
|
|
||||||
inv_y3 = pieces[3] * invall * e01 * e2 % modulus
|
|
||||||
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % modulus for i in range(4)]
|
|
||||||
|
|
||||||
def lagrange_interp_2(pieces, xs, modulus):
|
|
||||||
eq0 = [-xs[1] % modulus, 1]
|
|
||||||
eq1 = [-xs[0] % modulus, 1]
|
|
||||||
e0 = eval_poly_at(eq0, xs[0], modulus)
|
|
||||||
e1 = eval_poly_at(eq1, xs[1], modulus)
|
|
||||||
invall = inv(e0 * e1, modulus)
|
|
||||||
inv_y0 = pieces[0] * invall * e1
|
|
||||||
inv_y1 = pieces[1] * invall * e0
|
|
||||||
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % modulus for i in range(2)]
|
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
from .poly_utils import PrimeField
|
|
|
@ -1,198 +0,0 @@
|
||||||
modulus_poly = [1, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 1, 0, 1, 0, 0, 1,
|
|
||||||
1]
|
|
||||||
modulus_poly_as_int = sum([(v << i) for i, v in enumerate(modulus_poly)])
|
|
||||||
degree = len(modulus_poly) - 1
|
|
||||||
|
|
||||||
two_to_the_degree = 2**degree
|
|
||||||
two_to_the_degree_m1 = 2**degree - 1
|
|
||||||
|
|
||||||
def galoistpl(a):
|
|
||||||
# 2 is not a primitive root, so we have to use 3 as our logarithm base
|
|
||||||
if a * 2 < two_to_the_degree:
|
|
||||||
return (a * 2) ^ a
|
|
||||||
else:
|
|
||||||
return (a * 2) ^ a ^ modulus_poly_as_int
|
|
||||||
|
|
||||||
# Precomputing a log table for increased speed of addition and multiplication
|
|
||||||
glogtable = [0] * (two_to_the_degree)
|
|
||||||
gexptable = []
|
|
||||||
v = 1
|
|
||||||
for i in range(two_to_the_degree_m1):
|
|
||||||
glogtable[v] = i
|
|
||||||
gexptable.append(v)
|
|
||||||
v = galoistpl(v)
|
|
||||||
|
|
||||||
gexptable += gexptable + gexptable
|
|
||||||
|
|
||||||
# Add two values in the Galois field
|
|
||||||
def galois_add(x, y):
|
|
||||||
return x ^ y
|
|
||||||
|
|
||||||
# In binary fields, addition and subtraction are the same thing
|
|
||||||
galois_sub = galois_add
|
|
||||||
|
|
||||||
# Multiply two values in the Galois field
|
|
||||||
def galois_mul(x, y):
|
|
||||||
return 0 if x*y == 0 else gexptable[glogtable[x] + glogtable[y]]
|
|
||||||
|
|
||||||
# Divide two values in the Galois field
|
|
||||||
def galois_div(x, y):
|
|
||||||
return 0 if x == 0 else gexptable[(glogtable[x] - glogtable[y]) % two_to_the_degree_m1]
|
|
||||||
|
|
||||||
# Evaluate a polynomial at a point
|
|
||||||
def eval_poly_at(p, x):
|
|
||||||
if x == 0:
|
|
||||||
return p[0]
|
|
||||||
y = 0
|
|
||||||
logx = glogtable[x]
|
|
||||||
for i, p_coeff in enumerate(p):
|
|
||||||
if p_coeff:
|
|
||||||
# Add x**i * coeff
|
|
||||||
y ^= gexptable[(logx * i + glogtable[p_coeff]) % two_to_the_degree_m1]
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
# Given p+1 y values and x values with no errors, recovers the original
|
|
||||||
# p+1 degree polynomial.
|
|
||||||
# Lagrange interpolation works roughly in the following way.
|
|
||||||
# 1. Suppose you have a set of points, eg. x = [1, 2, 3], y = [2, 5, 10]
|
|
||||||
# 2. For each x, generate a polynomial which equals its corresponding
|
|
||||||
# y coordinate at that point and 0 at all other points provided.
|
|
||||||
# 3. Add these polynomials together.
|
|
||||||
|
|
||||||
def lagrange_interp(pieces, xs):
|
|
||||||
# Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
|
|
||||||
root = mk_root_2(xs)
|
|
||||||
#print(root)
|
|
||||||
assert len(root) == len(pieces) + 1
|
|
||||||
# print(root)
|
|
||||||
# Generate the derivative
|
|
||||||
d = derivative(root)
|
|
||||||
# Generate denominators by evaluating numerator polys at each x
|
|
||||||
denoms = multi_eval_2(d, xs)
|
|
||||||
print(denoms)
|
|
||||||
# denoms = [eval_poly_at(d, xs[i]) for i in range(len(xs))]
|
|
||||||
# Generate output polynomial, which is the sum of the per-value numerator
|
|
||||||
# polynomials rescaled to have the right y values
|
|
||||||
factors = [galois_div(p, d) for p, d in zip(pieces, denoms)]
|
|
||||||
o = multi_root_derive(xs, factors)
|
|
||||||
# print(o)
|
|
||||||
return o
|
|
||||||
|
|
||||||
def multi_root_derive(xs, muls):
|
|
||||||
if len(xs) == 1:
|
|
||||||
return [muls[0]]
|
|
||||||
R1 = mk_root_2(xs[:len(xs) // 2])
|
|
||||||
R2 = mk_root_2(xs[len(xs) // 2:])
|
|
||||||
x1 = karatsuba_mul(R1, multi_root_derive(xs[len(xs) // 2:], muls[len(muls) // 2:]) + [0])
|
|
||||||
x2 = karatsuba_mul(R2, multi_root_derive(xs[:len(xs) // 2], muls[:len(muls) // 2]) + [0])
|
|
||||||
o = [v1 ^ v2 for v1, v2 in zip(x1, x2)][:len(xs)]
|
|
||||||
# print(len(R1), len(x1), len(xs), len(o))
|
|
||||||
return o
|
|
||||||
|
|
||||||
def multi_root_derive_1(xs, muls):
|
|
||||||
o = [0] * len(xs)
|
|
||||||
for i in range(len(xs)):
|
|
||||||
_xs = xs[:i] + xs[(i+1):]
|
|
||||||
root = mk_root_2(_xs)
|
|
||||||
for j in range(len(root)):
|
|
||||||
o[j] ^= galois_mul(root[j], muls[i])
|
|
||||||
return o
|
|
||||||
|
|
||||||
a = 124
|
|
||||||
b = 8932
|
|
||||||
c = 12415
|
|
||||||
|
|
||||||
assert galois_mul(galois_add(a, b), c) == galois_add(galois_mul(a, c), galois_mul(b, c))
|
|
||||||
|
|
||||||
def karatsuba_mul(p1, p2):
|
|
||||||
L = len(p1)
|
|
||||||
# assert L == len(p2)
|
|
||||||
if L <= 16:
|
|
||||||
o = [0] * (L * 2)
|
|
||||||
for i, v1 in enumerate(p1):
|
|
||||||
for j, v2 in enumerate(p2):
|
|
||||||
if v1 and v2:
|
|
||||||
o[i + j] ^= gexptable[glogtable[v1] + glogtable[v2]]
|
|
||||||
return o
|
|
||||||
if L % 2:
|
|
||||||
p1 = p1 + [0]
|
|
||||||
p2 = p2 + [0]
|
|
||||||
L += 1
|
|
||||||
halflen = L // 2
|
|
||||||
low1 = p1[:halflen]
|
|
||||||
high1 = p1[halflen:]
|
|
||||||
sum1 = [l ^ h for l, h in zip(low1, high1)]
|
|
||||||
low2 = p2[:halflen]
|
|
||||||
high2 = p2[halflen:]
|
|
||||||
sum2 = [l ^ h for l, h in zip(low2, high2)]
|
|
||||||
z2 = karatsuba_mul(high1, high2)
|
|
||||||
z0 = karatsuba_mul(low1, low2)
|
|
||||||
z1 = [m ^ _z0 ^ _z2 for m, _z0, _z2 in zip(karatsuba_mul(sum1, sum2), z0, z2)]
|
|
||||||
o = z0[:halflen] + \
|
|
||||||
[a ^ b for a, b in zip(z0[halflen:], z1[:halflen])] + \
|
|
||||||
[a ^ b for a, b in zip(z2[:halflen], z1[halflen:])] + \
|
|
||||||
z2[halflen:]
|
|
||||||
return o
|
|
||||||
|
|
||||||
def mk_root_1(xs):
|
|
||||||
root = [1]
|
|
||||||
for x in xs:
|
|
||||||
logx = glogtable[x]
|
|
||||||
root.insert(0, 0)
|
|
||||||
for j in range(len(root)-1):
|
|
||||||
if root[j+1] and x:
|
|
||||||
root[j] ^= gexptable[glogtable[root[j+1]] + logx]
|
|
||||||
return root
|
|
||||||
|
|
||||||
def mk_root_2(xs):
|
|
||||||
if len(xs) >= 128:
|
|
||||||
return karatsuba_mul(mk_root_2(xs[:len(xs) // 2]), mk_root_2(xs[len(xs) // 2:]))[:len(xs) + 1]
|
|
||||||
root = [1]
|
|
||||||
for x in xs:
|
|
||||||
logx = glogtable[x]
|
|
||||||
root.insert(0, 0)
|
|
||||||
for j in range(len(root)-1):
|
|
||||||
if root[j+1] and x:
|
|
||||||
root[j] ^= gexptable[glogtable[root[j+1]] + logx]
|
|
||||||
return root
|
|
||||||
|
|
||||||
def derivative(root):
|
|
||||||
return [0 if i % 2 else r for i, r in enumerate(root[1:])]
|
|
||||||
|
|
||||||
# Credit to http://people.csail.mit.edu/madhu/ST12/scribe/lect06.pdf for the algorithm
|
|
||||||
def xn_mod_poly(p):
|
|
||||||
if len(p) == 1:
|
|
||||||
return [galois_div(1, p[0])]
|
|
||||||
halflen = len(p) // 2
|
|
||||||
lowinv = xn_mod_poly(p[:halflen])
|
|
||||||
submod_high = karatsuba_mul(lowinv, p[:halflen])[halflen:]
|
|
||||||
med = karatsuba_mul(p[halflen:], lowinv)[:halflen]
|
|
||||||
med_plus_high = [x ^ y for x, y in zip(med, submod_high)]
|
|
||||||
highinv = karatsuba_mul(med_plus_high, lowinv)
|
|
||||||
o = (lowinv + highinv)[:len(p)]
|
|
||||||
print(halflen, lowinv, submod_high, med, highinv)
|
|
||||||
# assert karatsuba_mul(o, p)[:len(p)] == [1] + [0] * (len(p) - 1)
|
|
||||||
return o
|
|
||||||
|
|
||||||
def mod(a, b):
|
|
||||||
assert len(a) == 2 * (len(b) - 1)
|
|
||||||
L = len(b)
|
|
||||||
inv_rev_b = xn_mod_poly(b[::-1] + [0] * (len(a) - L))[:L]
|
|
||||||
quot = karatsuba_mul(inv_rev_b, a[::-1][:L])[:L-1][::-1]
|
|
||||||
subt = karatsuba_mul(b, quot + [0])[:-1]
|
|
||||||
o = [x ^ y for x, y in zip(a[:L-1], subt[:L-1])]
|
|
||||||
# assert [x^y for x, y in zip(karatsuba_mul(quot + [0], b), o)] == a
|
|
||||||
return o
|
|
||||||
|
|
||||||
def multi_eval_1(poly, xs):
|
|
||||||
return [eval_poly_at(poly, x) for x in xs]
|
|
||||||
|
|
||||||
def multi_eval_2(poly, xs):
|
|
||||||
if len(xs) <= 1024:
|
|
||||||
return [eval_poly_at(poly, x) for x in xs]
|
|
||||||
halflen = len(xs) // 2
|
|
||||||
return multi_eval_2(mod(poly, mk_root_2(xs[:halflen])), xs[:halflen]) + \
|
|
||||||
multi_eval_2(mod(poly, mk_root_2(xs[halflen:])), xs[halflen:])
|
|
||||||
# [eval_poly_at(poly, xs[-2]), eval_poly_at(poly, xs[-1])]
|
|
|
@ -1,24 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
from setuptools import setup, find_packages
|
|
||||||
|
|
||||||
|
|
||||||
with open('README.md') as f:
|
|
||||||
readme = f.read()
|
|
||||||
|
|
||||||
with open('LICENSE') as f:
|
|
||||||
license = f.read()
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name='ecpoly',
|
|
||||||
version='1.0.0',
|
|
||||||
description='Erasure code utilities for prime fields',
|
|
||||||
long_description=readme,
|
|
||||||
author='Vitalik Buterin',
|
|
||||||
author_email='',
|
|
||||||
url='https://github.com/ethereum/research/tree/master/erasure_code/ecpoly',
|
|
||||||
license=license,
|
|
||||||
packages=find_packages(exclude=('tests', 'docs')),
|
|
||||||
install_requires=[
|
|
||||||
],
|
|
||||||
)
|
|
|
@ -1,15 +0,0 @@
|
||||||
from ecpoly import PrimeField
|
|
||||||
|
|
||||||
f = PrimeField(65537)
|
|
||||||
|
|
||||||
k1 = list(range(10))
|
|
||||||
k2 = list(range(100, 200))
|
|
||||||
k3 = f.mul_polys(k1, k2)
|
|
||||||
assert f.div_polys(k3, k1) == k2
|
|
||||||
assert f.div_polys(k3, k2) == k1
|
|
||||||
assert (f.eval_poly_at(k1, 9999) * f.eval_poly_at(k2, 9999) -
|
|
||||||
f.eval_poly_at(k3, 9999)) % f.modulus == 0
|
|
||||||
k4 = f.compose_polys(k1, k2)
|
|
||||||
assert f.eval_poly_at(k4, 9998) == f.eval_poly_at(k1, f.eval_poly_at(k2, 9998))
|
|
||||||
|
|
||||||
print("All passed!")
|
|
|
@ -1,10 +1,17 @@
|
||||||
from better_lagrange import lagrange_interp_4, eval_poly_at
|
|
||||||
from merkle_tree import merkelize, mk_branch, verify_branch
|
from merkle_tree import merkelize, mk_branch, verify_branch
|
||||||
from utils import get_power_cycle, get_pseudorandom_indices
|
from utils import get_power_cycle, get_pseudorandom_indices
|
||||||
from fft import fft
|
from fft import fft
|
||||||
|
from poly_utils import PrimeField
|
||||||
|
|
||||||
# Generate an FRI proof
|
# Generate an FRI proof that the polynomial that has the specified
|
||||||
def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus):
|
# values at successive powers of the specified root of unity has a
|
||||||
|
# degree lower than maxdeg_plus_1
|
||||||
|
#
|
||||||
|
# We use maxdeg+1 instead of maxdeg because it's more mathematically
|
||||||
|
# convenient in this case.
|
||||||
|
|
||||||
|
def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
|
||||||
|
f = PrimeField(modulus)
|
||||||
print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1))
|
print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1))
|
||||||
|
|
||||||
# If the degree we are checking for is less than or equal to 32,
|
# If the degree we are checking for is less than or equal to 32,
|
||||||
|
@ -24,18 +31,17 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus):
|
||||||
# Select a pseudo-random x coordinate
|
# Select a pseudo-random x coordinate
|
||||||
special_x = int.from_bytes(m[1], 'big') % modulus
|
special_x = int.from_bytes(m[1], 'big') % modulus
|
||||||
|
|
||||||
# Calculate the "column" (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
|
# Calculate the "column" at that x coordinate
|
||||||
# at that x coordinate
|
# (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
|
||||||
# We calculate the column by Lagrange-interpolating the row, and not
|
# We calculate the column by Lagrange-interpolating the row, and not
|
||||||
# directly, as this is more efficient
|
# directly from the polynomial, as this is more efficient
|
||||||
column = []
|
column = []
|
||||||
for i in range(len(xs)//4):
|
for i in range(len(xs)//4):
|
||||||
x_poly = lagrange_interp_4(
|
x_poly = f.lagrange_interp_4(
|
||||||
[values[i+len(values)*j//4] for j in range(4)],
|
|
||||||
[xs[i+len(xs)*j//4] for j in range(4)],
|
[xs[i+len(xs)*j//4] for j in range(4)],
|
||||||
modulus
|
[values[i+len(values)*j//4] for j in range(4)],
|
||||||
)
|
)
|
||||||
column.append(eval_poly_at(x_poly, special_x, modulus))
|
column.append(f.eval_poly_at(x_poly, special_x))
|
||||||
m2 = merkelize(column)
|
m2 = merkelize(column)
|
||||||
|
|
||||||
# Pseudo-randomly select y indices to sample
|
# Pseudo-randomly select y indices to sample
|
||||||
|
@ -44,23 +50,24 @@ def prove_low_degree(poly, root_of_unity, values, maxdeg_plus_1, modulus):
|
||||||
# Compute the Merkle branches for the values in the polynomial and the column
|
# Compute the Merkle branches for the values in the polynomial and the column
|
||||||
branches = []
|
branches = []
|
||||||
for y in ys:
|
for y in ys:
|
||||||
branches.append([mk_branch(m2, y)] + [mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)])
|
branches.append([mk_branch(m2, y)] +
|
||||||
|
[mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)])
|
||||||
|
|
||||||
# This component of the proof
|
# This component of the proof
|
||||||
o = [m2[1], branches]
|
o = [m2[1], branches]
|
||||||
|
|
||||||
# In the next iteration of the proof, we'll work with smaller roots of unity
|
|
||||||
sub_xs = [xs[i] for i in range(0, len(xs), 4)]
|
|
||||||
|
|
||||||
# Interpolate the polynomial for the column
|
# Interpolate the polynomial for the column
|
||||||
ypoly = fft(column[:len(sub_xs)], modulus,
|
# sub_xs = [xs[i] for i in range(0, len(xs), 4)]
|
||||||
pow(root_of_unity, 4, modulus), inv=True)
|
# ypoly = fft(column[:len(sub_xs)], modulus,
|
||||||
|
# f.exp(root_of_unity, 4), inv=True)
|
||||||
|
|
||||||
# Recurse...
|
# Recurse...
|
||||||
return [o] + prove_low_degree(ypoly, pow(root_of_unity, 4, modulus), column, maxdeg_plus_1 // 4, modulus)
|
return [o] + prove_low_degree(column, f.exp(root_of_unity, 4),
|
||||||
|
maxdeg_plus_1 // 4, modulus)
|
||||||
|
|
||||||
# Verify an FRI proof
|
# Verify an FRI proof
|
||||||
def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus):
|
def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus):
|
||||||
|
f = PrimeField(modulus)
|
||||||
|
|
||||||
# Calculate which root of unity we're working with
|
# Calculate which root of unity we're working with
|
||||||
testval = root_of_unity
|
testval = root_of_unity
|
||||||
|
@ -69,10 +76,11 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
|
||||||
roudeg *= 2
|
roudeg *= 2
|
||||||
testval = (testval * testval) % modulus
|
testval = (testval * testval) % modulus
|
||||||
|
|
||||||
|
# Powers of the given root of unity 1, p, p**2, p**3 such that p**4 = 1
|
||||||
quartic_roots_of_unity = [1,
|
quartic_roots_of_unity = [1,
|
||||||
pow(root_of_unity, roudeg // 4, modulus),
|
f.exp(root_of_unity, roudeg // 4),
|
||||||
pow(root_of_unity, roudeg // 2, modulus),
|
f.exp(root_of_unity, roudeg // 2),
|
||||||
pow(root_of_unity, roudeg * 3 // 4, modulus)]
|
f.exp(root_of_unity, roudeg * 3 // 4)]
|
||||||
|
|
||||||
# Verify the recursive components of the proof
|
# Verify the recursive components of the proof
|
||||||
for prf in proof[:-1]:
|
for prf in proof[:-1]:
|
||||||
|
@ -86,27 +94,28 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
|
||||||
ys = get_pseudorandom_indices(root2, roudeg // 4, 40)
|
ys = get_pseudorandom_indices(root2, roudeg // 4, 40)
|
||||||
|
|
||||||
|
|
||||||
# Verify for each selected y coordinate that the four points from the polynomial
|
# Verify for each selected y coordinate that the four points from the
|
||||||
# and the one point from the column that are on that y coordinate are on the same
|
# polynomial and the one point from the column that are on that y
|
||||||
# deg < 4 polynomial
|
# coordinate are on the same deg < 4 polynomial
|
||||||
for i, y in enumerate(ys):
|
for i, y in enumerate(ys):
|
||||||
# The x coordinates from the polynomial
|
# The x coordinates from the polynomial
|
||||||
x1 = pow(root_of_unity, y, modulus)
|
x1 = f.exp(root_of_unity, y)
|
||||||
xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)]
|
xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)]
|
||||||
|
|
||||||
# The values from the polynomial
|
# The values from the polynomial
|
||||||
row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) for j, prf in zip(range(4), branches[i][1:])]
|
row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf)
|
||||||
|
for j, prf in zip(range(4), branches[i][1:])]
|
||||||
|
|
||||||
# Verify proof and recover the column value
|
# Verify proof and recover the column value
|
||||||
values = [verify_branch(root2, y, branches[i][0])] + row
|
values = [verify_branch(root2, y, branches[i][0])] + row
|
||||||
|
|
||||||
# Lagrange interpolate and check deg is < 4
|
# Lagrange interpolate and check deg is < 4
|
||||||
p = lagrange_interp_4(row, xcoords, modulus)
|
p = f.lagrange_interp_4(xcoords, row)
|
||||||
assert eval_poly_at(p, special_x, modulus) == verify_branch(root2, y, branches[i][0])
|
assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0])
|
||||||
|
|
||||||
# Update constants to check the next proof
|
# Update constants to check the next proof
|
||||||
merkle_root = root2
|
merkle_root = root2
|
||||||
root_of_unity = pow(root_of_unity, 4, modulus)
|
root_of_unity = f.exp(root_of_unity, 4)
|
||||||
maxdeg_plus_1 //= 4
|
maxdeg_plus_1 //= 4
|
||||||
roudeg //= 4
|
roudeg //= 4
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
from merkle_tree import merkelize, mk_branch, verify_branch, blake
|
||||||
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
from compression import compress_fri, decompress_fri, compress_branches, decompress_branches, bin_length
|
||||||
from ecpoly import PrimeField
|
from poly_utils import PrimeField
|
||||||
from better_lagrange import lagrange_interp_4, lagrange_interp_2
|
|
||||||
import time
|
import time
|
||||||
from fft import fft
|
from fft import fft
|
||||||
from fri import prove_low_degree, verify_low_degree_proof
|
from fri import prove_low_degree, verify_low_degree_proof
|
||||||
|
@ -13,15 +12,11 @@ nonresidue = 7
|
||||||
|
|
||||||
spot_check_security_factor = 240
|
spot_check_security_factor = 240
|
||||||
|
|
||||||
# Compute a MIMC permutation for 2**logsteps steps, using round constants
|
# Compute a MIMC permutation for 2**logsteps steps
|
||||||
# from the multiplicative subgroup of size 2**logprecision
|
def mimc(inp, logsteps):
|
||||||
def mimc(inp, logsteps, logprecision):
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
steps = 2**logsteps
|
steps = 2**logsteps
|
||||||
precision = 2**logprecision
|
# We use powers of 9 mod 2^256 XORed with 1 as the ith round constant for the moment
|
||||||
# Get (steps)th root of unity
|
|
||||||
subroot = pow(7, (modulus-1)//steps, modulus)
|
|
||||||
# We use powers of 9 mod 2^256 as the ith round constant for the moment
|
|
||||||
k = 1
|
k = 1
|
||||||
for i in range(steps-1):
|
for i in range(steps-1):
|
||||||
inp = (inp**3 + (k ^ 1)) % modulus
|
inp = (inp**3 + (k ^ 1)) % modulus
|
||||||
|
@ -29,34 +24,20 @@ def mimc(inp, logsteps, logprecision):
|
||||||
print("MIMC computed in %.4f sec" % (time.time() - start_time))
|
print("MIMC computed in %.4f sec" % (time.time() - start_time))
|
||||||
return inp
|
return inp
|
||||||
|
|
||||||
# Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x)
|
|
||||||
def multiply_base(poly, fac):
|
|
||||||
o = []
|
|
||||||
r = 1
|
|
||||||
for p in poly:
|
|
||||||
o.append(p * r % modulus)
|
|
||||||
r = r * fac % modulus
|
|
||||||
return o
|
|
||||||
|
|
||||||
# Divides a polynomial by x^n-1
|
|
||||||
def divide_by_xnm1(poly, n):
|
|
||||||
if len(poly) <= n:
|
|
||||||
return []
|
|
||||||
return f.add_polys(poly[n:], divide_by_xnm1(poly[n:], n))
|
|
||||||
|
|
||||||
# Generate a STARK for a MIMC calculation
|
# Generate a STARK for a MIMC calculation
|
||||||
def mk_mimc_proof(inp, logsteps, logprecision):
|
def mk_mimc_proof(inp, logsteps):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
assert logsteps < logprecision <= 32
|
assert logsteps <= 29
|
||||||
|
logprecision = logsteps + 3
|
||||||
steps = 2**logsteps
|
steps = 2**logsteps
|
||||||
precision = 2**logprecision
|
precision = 2**logprecision
|
||||||
|
|
||||||
# Root of unity such that x^precision=1
|
# Root of unity such that x^precision=1
|
||||||
root_of_unity = pow(7, (modulus-1)//precision, modulus)
|
root_of_unity = f.exp(7, (modulus-1)//precision)
|
||||||
|
|
||||||
# Root of unity such that x^skips=1
|
# Root of unity such that x^skips=1
|
||||||
skips = precision // steps
|
skips = precision // steps
|
||||||
subroot = pow(root_of_unity, skips)
|
subroot = f.exp(root_of_unity, skips)
|
||||||
|
|
||||||
# Powers of the root of unity, our computational trace will be
|
# Powers of the root of unity, our computational trace will be
|
||||||
# along the sequence of roots of unity
|
# along the sequence of roots of unity
|
||||||
|
@ -82,30 +63,31 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
||||||
|
|
||||||
# Create the composed polynomial such that
|
# Create the composed polynomial such that
|
||||||
# C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x)
|
# C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x)
|
||||||
term1 = multiply_base(values_polynomial, subroot)
|
term1 = f.multiply_base(values_polynomial, subroot)
|
||||||
p_evaluations = fft(values_polynomial, modulus, root_of_unity)
|
p_evaluations = fft(values_polynomial, modulus, root_of_unity)
|
||||||
term2 = fft([pow(x, 3, modulus) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2]
|
term2 = fft([f.exp(x, 3) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2]
|
||||||
c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial)
|
c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial)
|
||||||
print('Computed C(P, K) polynomial')
|
print('Computed C(P, K) polynomial')
|
||||||
|
|
||||||
# Compute D(x) = C(P(x), P(rx), K(x)) / Z(x)
|
# Compute D(x) = C(P(x), P(rx), K(x)) / Z(x)
|
||||||
# Z(x) = (x^steps - 1) / (x - x_atlast_step)
|
# Z(x) = (x^steps - 1) / (x - x_atlast_step)
|
||||||
d = divide_by_xnm1(f.mul_polys(c_of_values,
|
d = f.divide_by_xnm1(f.mul_polys(c_of_values,
|
||||||
[-last_step_position, 1]),
|
[-last_step_position, 1]),
|
||||||
steps)
|
steps)
|
||||||
# Consistency check
|
# Consistency check
|
||||||
assert (f.eval_poly_at(d, 90833) *
|
# assert (f.eval_poly_at(d, 90833) *
|
||||||
(pow(90833, steps, modulus) - 1) *
|
# (f.exp(90833, steps) - 1) *
|
||||||
f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) -
|
# f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) -
|
||||||
f.eval_poly_at(c_of_values, 90833)) % modulus == 0
|
# f.eval_poly_at(c_of_values, 90833)) % modulus == 0
|
||||||
print('Computed D polynomial')
|
print('Computed D polynomial')
|
||||||
|
|
||||||
# Compute interpolant of ((1, input), (x_atlast_step, output))
|
# Compute interpolant of ((1, input), (x_atlast_step, output))
|
||||||
interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus)
|
interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
|
||||||
quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
|
quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
|
||||||
b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient)
|
b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient)
|
||||||
# Consistency check
|
# Consistency check
|
||||||
assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == f.eval_poly_at(values_polynomial, 7045)
|
# assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == \
|
||||||
|
# f.eval_poly_at(values_polynomial, 7045)
|
||||||
print('Computed B polynomial')
|
print('Computed B polynomial')
|
||||||
|
|
||||||
# Evaluate B, D and K across the entire subgroup
|
# Evaluate B, D and K across the entire subgroup
|
||||||
|
@ -129,12 +111,18 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
||||||
k3 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x03'), 'big')
|
k3 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x03'), 'big')
|
||||||
k4 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x04'), 'big')
|
k4 = int.from_bytes(blake(p_mtree[1] + d_mtree[1] + b_mtree[1] + b'\x04'), 'big')
|
||||||
|
|
||||||
lincomb = f.add_polys(f.add_polys(d,
|
# Compute the linear combination. We don't even both calculating it in
|
||||||
f.mul_by_const(values_polynomial, k1) + f.mul_by_const(values_polynomial, k2)),
|
# coefficient form; we just compute the evaluations
|
||||||
f.mul_by_const(b, k3) + [0, 0] + f.mul_by_const(b, k4) + [0,0])
|
root_of_unity_to_the_steps = f.exp(root_of_unity, steps)
|
||||||
l_evaluations = fft(lincomb, modulus, root_of_unity)
|
powers = [1]
|
||||||
l_mtree = merkelize(l_evaluations)
|
for i in range(1, precision):
|
||||||
|
powers.append(powers[-1] * root_of_unity_to_the_steps % modulus)
|
||||||
|
|
||||||
|
l_evaluations = [(d_evaluations[i] +
|
||||||
|
p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] +
|
||||||
|
b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus
|
||||||
|
for i in range(precision)]
|
||||||
|
l_mtree = merkelize(l_evaluations)
|
||||||
print('Computed random linear combination')
|
print('Computed random linear combination')
|
||||||
|
|
||||||
# Do some spot checks of the Merkle tree at pseudo-random coordinates
|
# Do some spot checks of the Merkle tree at pseudo-random coordinates
|
||||||
|
@ -158,20 +146,21 @@ def mk_mimc_proof(inp, logsteps, logprecision):
|
||||||
b_mtree[1],
|
b_mtree[1],
|
||||||
l_mtree[1],
|
l_mtree[1],
|
||||||
branches,
|
branches,
|
||||||
prove_low_degree(lincomb, root_of_unity, l_evaluations, steps * 2, modulus)]
|
prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus)]
|
||||||
print("STARK computed in %.4f sec" % (time.time() - start_time))
|
print("STARK computed in %.4f sec" % (time.time() - start_time))
|
||||||
return o
|
return o
|
||||||
|
|
||||||
# Verifies a STARK
|
# Verifies a STARK
|
||||||
def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
|
def verify_mimc_proof(inp, logsteps, output, proof):
|
||||||
p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof
|
p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
logprecision = logsteps + 3
|
||||||
steps = 2**logsteps
|
steps = 2**logsteps
|
||||||
precision = 2**logprecision
|
precision = 2**logprecision
|
||||||
|
|
||||||
# Get (steps)th root of unity
|
# Get (steps)th root of unity
|
||||||
root_of_unity = pow(7, (modulus-1)//precision, modulus)
|
root_of_unity = f.exp(7, (modulus-1)//precision)
|
||||||
skips = precision // steps
|
skips = precision // steps
|
||||||
|
|
||||||
# Verifies the low-degree proofs
|
# Verifies the low-degree proofs
|
||||||
|
@ -184,24 +173,29 @@ def verify_mimc_proof(inp, logsteps, logprecision, output, proof):
|
||||||
k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big')
|
k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big')
|
||||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||||
positions = get_pseudorandom_indices(l_root, precision - skips, samples)
|
positions = get_pseudorandom_indices(l_root, precision - skips, samples)
|
||||||
last_step_position = pow(root_of_unity, (steps - 1) * skips, modulus)
|
last_step_position = f.exp(root_of_unity, (steps - 1) * skips)
|
||||||
for i, pos in enumerate(positions):
|
for i, pos in enumerate(positions):
|
||||||
# Check C(P(x)) = Z(x) * D(x)
|
x = f.exp(root_of_unity, pos)
|
||||||
x = pow(root_of_unity, pos, modulus)
|
x_to_the_steps = f.exp(x, steps)
|
||||||
x_to_the_steps = pow(x, steps, modulus)
|
|
||||||
p_of_x = verify_branch(p_root, pos, branches[i*6])
|
p_of_x = verify_branch(p_root, pos, branches[i*6])
|
||||||
p_of_rx = verify_branch(p_root, pos+skips, branches[i*6 + 1])
|
p_of_rx = verify_branch(p_root, pos+skips, branches[i*6 + 1])
|
||||||
d_of_x = verify_branch(d_root, pos, branches[i*6 + 2])
|
d_of_x = verify_branch(d_root, pos, branches[i*6 + 2])
|
||||||
k_of_x = verify_branch(k_root, pos, branches[i*6 + 3])
|
k_of_x = verify_branch(k_root, pos, branches[i*6 + 3])
|
||||||
b_of_x = verify_branch(b_root, pos, branches[i*6 + 4])
|
b_of_x = verify_branch(b_root, pos, branches[i*6 + 4])
|
||||||
l_of_x = verify_branch(l_root, pos, branches[i*6 + 5])
|
l_of_x = verify_branch(l_root, pos, branches[i*6 + 5])
|
||||||
zvalue = f.div(pow(x, steps, modulus) - 1,
|
zvalue = f.div(f.exp(x, steps) - 1,
|
||||||
x - last_step_position)
|
x - last_step_position)
|
||||||
|
|
||||||
|
# Check transition constraints C(P(x)) = Z(x) * D(x)
|
||||||
assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0
|
assert (p_of_rx - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0
|
||||||
interpolant = lagrange_interp_2([inp, output], [1, last_step_position], modulus)
|
interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
|
||||||
quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
|
quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
|
||||||
|
|
||||||
|
# Check boundary constraints B(x) * Q(x) + I(x) = P(x)
|
||||||
assert (p_of_x - b_of_x * f.eval_poly_at(quotient, x) -
|
assert (p_of_x - b_of_x * f.eval_poly_at(quotient, x) -
|
||||||
f.eval_poly_at(interpolant, x)) % modulus == 0
|
f.eval_poly_at(interpolant, x)) % modulus == 0
|
||||||
|
|
||||||
|
# Check correctness of the linear combination
|
||||||
assert (l_of_x - d_of_x -
|
assert (l_of_x - d_of_x -
|
||||||
k1 * p_of_x - k2 * p_of_x * x_to_the_steps -
|
k1 * p_of_x - k2 * p_of_x * x_to_the_steps -
|
||||||
k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0
|
k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Creates an object that includes convenience operations for numbers
|
||||||
|
# and polynomials in some prime field
|
||||||
class PrimeField():
|
class PrimeField():
|
||||||
def __init__(self, modulus):
|
def __init__(self, modulus):
|
||||||
assert pow(2, modulus, modulus) == 2
|
assert pow(2, modulus, modulus) == 2
|
||||||
|
@ -12,9 +14,13 @@ class PrimeField():
|
||||||
def mul(self, x, y):
|
def mul(self, x, y):
|
||||||
return (x*y) % self.modulus
|
return (x*y) % self.modulus
|
||||||
|
|
||||||
|
def exp(self, x, p):
|
||||||
|
return pow(x, p, self.modulus)
|
||||||
|
|
||||||
|
# Modular inverse using the extended Euclidean algorithm
|
||||||
def inv(self, a):
|
def inv(self, a):
|
||||||
if a == 0:
|
if a == 0:
|
||||||
return 0
|
raise Exception("Cannot invert 0")
|
||||||
lm, hm = 1, 0
|
lm, hm = 1, 0
|
||||||
low, high = a % self.modulus, self.modulus
|
low, high = a % self.modulus, self.modulus
|
||||||
while low > 1:
|
while low > 1:
|
||||||
|
@ -24,14 +30,10 @@ class PrimeField():
|
||||||
return lm % self.modulus
|
return lm % self.modulus
|
||||||
|
|
||||||
def div(self, x, y):
|
def div(self, x, y):
|
||||||
if x == 0 and y == 0:
|
|
||||||
return 1
|
|
||||||
return self.mul(x, self.inv(y))
|
return self.mul(x, self.inv(y))
|
||||||
|
|
||||||
# Evaluate a polynomial at a point
|
# Evaluate a polynomial at a point
|
||||||
def eval_poly_at(self, p, x):
|
def eval_poly_at(self, p, x):
|
||||||
if x == 0:
|
|
||||||
return p[0]
|
|
||||||
y = 0
|
y = 0
|
||||||
power_of_x = 1
|
power_of_x = 1
|
||||||
for i, p_coeff in enumerate(p):
|
for i, p_coeff in enumerate(p):
|
||||||
|
@ -39,7 +41,7 @@ class PrimeField():
|
||||||
power_of_x = (power_of_x * x) % self.modulus
|
power_of_x = (power_of_x * x) % self.modulus
|
||||||
return y % self.modulus
|
return y % self.modulus
|
||||||
|
|
||||||
# Build a polynomial that returns 0 at all xs
|
# Build a polynomial that returns 0 at all specified xs
|
||||||
def zpoly(self, xs):
|
def zpoly(self, xs):
|
||||||
root = [1]
|
root = [1]
|
||||||
for x in xs:
|
for x in xs:
|
||||||
|
@ -56,35 +58,62 @@ class PrimeField():
|
||||||
# y coordinate at that point and 0 at all other points provided.
|
# y coordinate at that point and 0 at all other points provided.
|
||||||
# 3. Add these polynomials together.
|
# 3. Add these polynomials together.
|
||||||
|
|
||||||
def lagrange_interp(self, pieces, xs):
|
def lagrange_interp(self, xs, ys):
|
||||||
# Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
|
# Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
|
||||||
root = self.zpoly(xs)
|
root = self.zpoly(xs)
|
||||||
#print(root)
|
assert len(root) == len(ys) + 1
|
||||||
assert len(root) == len(pieces) + 1
|
|
||||||
# print(root)
|
# print(root)
|
||||||
# Generate per-value numerator polynomials, eg. for x=x2,
|
# Generate per-value numerator polynomials, eg. for x=x2,
|
||||||
# (x - x1) * (x - x3) * ... * (x - xn), by dividing the master
|
# (x - x1) * (x - x3) * ... * (x - xn), by dividing the master
|
||||||
# polynomial back by each x coordinate
|
# polynomial back by each x coordinate
|
||||||
nums = []
|
nums = [self.div_polys(root, [-x, 1]) for x in xs]
|
||||||
for x in xs:
|
|
||||||
output = [0] * (len(root) - 2) + [1]
|
|
||||||
for j in range(len(root) - 2, 0, -1):
|
|
||||||
output[j-1] = root[j] + output[j] * x
|
|
||||||
assert len(output) == len(pieces)
|
|
||||||
nums.append(output)
|
|
||||||
#print(nums)
|
|
||||||
# Generate denominators by evaluating numerator polys at each x
|
# Generate denominators by evaluating numerator polys at each x
|
||||||
denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
|
denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
|
||||||
# Generate output polynomial, which is the sum of the per-value numerator
|
# Generate output polynomial, which is the sum of the per-value numerator
|
||||||
# polynomials rescaled to have the right y values
|
# polynomials rescaled to have the right y values
|
||||||
b = [0 for p in pieces]
|
b = [0 for y in ys]
|
||||||
for i in range(len(xs)):
|
for i in range(len(xs)):
|
||||||
yslice = self.div(pieces[i], denoms[i])
|
yslice = self.div(ys[i], denoms[i])
|
||||||
for j in range(len(pieces)):
|
for j in range(len(ys)):
|
||||||
if nums[i][j] and pieces[i]:
|
if nums[i][j] and ys[i]:
|
||||||
b[j] += nums[i][j] * yslice
|
b[j] += nums[i][j] * yslice
|
||||||
return [x % self.modulus for x in b]
|
return [x % self.modulus for x in b]
|
||||||
|
|
||||||
|
# Optimized version of the above restricted to deg-4 polynomials
|
||||||
|
def lagrange_interp_4(self, xs, ys):
|
||||||
|
x01, x02, x03, x12, x13, x23 = \
|
||||||
|
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
|
||||||
|
m = self.modulus
|
||||||
|
eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
|
||||||
|
eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
|
||||||
|
eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
|
||||||
|
eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
|
||||||
|
e0 = self.eval_poly_at(eq0, xs[0])
|
||||||
|
e1 = self.eval_poly_at(eq1, xs[1])
|
||||||
|
e2 = self.eval_poly_at(eq2, xs[2])
|
||||||
|
e3 = self.eval_poly_at(eq3, xs[3])
|
||||||
|
e01 = e0 * e1
|
||||||
|
e23 = e2 * e3
|
||||||
|
invall = self.inv(e01 * e23)
|
||||||
|
inv_y0 = ys[0] * invall * e1 * e23 % m
|
||||||
|
inv_y1 = ys[1] * invall * e0 * e23 % m
|
||||||
|
inv_y2 = ys[2] * invall * e01 * e3 % m
|
||||||
|
inv_y3 = ys[3] * invall * e01 * e2 % m
|
||||||
|
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)]
|
||||||
|
|
||||||
|
# Optimized version of the above restricted to deg-2 polynomials
|
||||||
|
def lagrange_interp_2(self, xs, ys):
|
||||||
|
m = self.modulus
|
||||||
|
eq0 = [-xs[1] % m, 1]
|
||||||
|
eq1 = [-xs[0] % m, 1]
|
||||||
|
e0 = self.eval_poly_at(eq0, xs[0])
|
||||||
|
e1 = self.eval_poly_at(eq1, xs[1])
|
||||||
|
invall = self.inv(e0 * e1)
|
||||||
|
inv_y0 = ys[0] * invall * e1
|
||||||
|
inv_y1 = ys[1] * invall * e0
|
||||||
|
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]
|
||||||
|
|
||||||
|
# Arithmetic for polynomials
|
||||||
def add_polys(self, a, b):
|
def add_polys(self, a, b):
|
||||||
return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0))
|
return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0))
|
||||||
% self.modulus for i in range(max(len(a), len(b)))]
|
% self.modulus for i in range(max(len(a), len(b)))]
|
||||||
|
@ -119,6 +148,13 @@ class PrimeField():
|
||||||
diff -= 1
|
diff -= 1
|
||||||
return [x % self.modulus for x in o]
|
return [x % self.modulus for x in o]
|
||||||
|
|
||||||
|
# Divides a polynomial by x^n-1
|
||||||
|
def divide_by_xnm1(self, poly, n):
|
||||||
|
if len(poly) <= n:
|
||||||
|
return []
|
||||||
|
return self.add_polys(poly[n:], self.divide_by_xnm1(poly[n:], n))
|
||||||
|
|
||||||
|
# Returns P(x) = A(B(x))
|
||||||
def compose_polys(self, a, b):
|
def compose_polys(self, a, b):
|
||||||
o = []
|
o = []
|
||||||
p = [1]
|
p = [1]
|
||||||
|
@ -127,3 +163,12 @@ class PrimeField():
|
||||||
p = self.mul_polys(p, b)
|
p = self.mul_polys(p, b)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
# Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x)
|
||||||
|
# Equivalent to compose_polys(poly, [0, fac])
|
||||||
|
def multiply_base(self, poly, fac):
|
||||||
|
o = []
|
||||||
|
r = 1
|
||||||
|
for p in poly:
|
||||||
|
o.append(p * r % self.modulus)
|
||||||
|
r = r * fac % self.modulus
|
||||||
|
return o
|
|
@ -12,25 +12,36 @@ def test_merkletree():
|
||||||
|
|
||||||
def test_fri():
|
def test_fri():
|
||||||
# Pure FRI tests
|
# Pure FRI tests
|
||||||
poly = list(range(512))
|
poly = list(range(4096))
|
||||||
root_of_unity = pow(7, (modulus-1)//1024, modulus)
|
root_of_unity = pow(7, (modulus-1)//16384, modulus)
|
||||||
evaluations = fft(poly, modulus, root_of_unity)
|
evaluations = fft(poly, modulus, root_of_unity)
|
||||||
proof = prove_low_degree(poly, root_of_unity, evaluations, 512, modulus)
|
proof = prove_low_degree(evaluations, root_of_unity, 4096, modulus)
|
||||||
print("Approx proof length: %d" % bin_length(compress_fri(proof)))
|
print("Approx proof length: %d" % bin_length(compress_fri(proof)))
|
||||||
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 512, modulus)
|
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 4096, modulus)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fakedata = [x if pow(3, i, 4096) > 400 else 39 for x, i in enumerate(evaluations)]
|
||||||
|
proof2 = prove_low_degree(fakedata, root_of_unity, 4096, modulus)
|
||||||
|
assert verify_low_degree_proof(merkelize(fakedata)[1], root_of_unity, proof, 4096, modulus)
|
||||||
|
raise Exception("Fake data passed FRI")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
assert verify_low_degree_proof(merkelize(evaluations)[1], root_of_unity, proof, 2048, modulus)
|
||||||
|
raise Exception("Fake data passed FRI")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
def test_stark():
|
def test_stark():
|
||||||
INPUT = 3
|
INPUT = 3
|
||||||
LOGSTEPS = 13
|
LOGSTEPS = 13
|
||||||
LOGPRECISION = 16
|
|
||||||
|
|
||||||
# Full STARK test
|
# Full STARK test
|
||||||
proof = mk_mimc_proof(INPUT, LOGSTEPS, LOGPRECISION)
|
proof = mk_mimc_proof(INPUT, LOGSTEPS)
|
||||||
p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof
|
p_root, d_root, k_root, b_root, l_root, branches, fri_proof = proof
|
||||||
L1 = bin_length(compress_branches(branches))
|
L1 = bin_length(compress_branches(branches))
|
||||||
L2 = bin_length(compress_fri(fri_proof))
|
L2 = bin_length(compress_fri(fri_proof))
|
||||||
print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2))
|
print("Approx proof length: %d (branches), %d (FRI proof), %d (total)" % (L1, L2, L1 + L2))
|
||||||
root_of_unity = pow(7, (modulus-1)//2**LOGPRECISION, modulus)
|
assert verify_mimc_proof(3, LOGSTEPS, mimc(3, LOGSTEPS), proof)
|
||||||
subroot = pow(7, (modulus-1)//2**LOGSTEPS, modulus)
|
|
||||||
skips = 2**(LOGPRECISION - LOGSTEPS)
|
if __name__ == '__main__':
|
||||||
assert verify_mimc_proof(3, LOGSTEPS, LOGPRECISION, mimc(3, LOGSTEPS, LOGPRECISION), proof)
|
test_stark()
|
||||||
|
|
Loading…
Reference in New Issue