Added 2-of-3 and TRIBES low influence functions

This commit is contained in:
Vitalik Buterin 2018-07-18 20:43:34 -04:00
parent f49c584f83
commit 1830ae2091
6 changed files with 180 additions and 32 deletions

View File

@ -1,10 +1,12 @@
def compress_fri(prf): def compress_fri(prf):
o = [] o = []
oindex = {}
def add_obj(x): def add_obj(x):
if x in o: if x in oindex:
o.append(o.index(x).to_bytes(2, 'big')) o.append(oindex[x].to_bytes(2, 'big'))
else: else:
o.append(x) o.append(x)
oindex[x] = len(o)-1
for root, yproofs in prf[:-1]: for root, yproofs in prf[:-1]:
# print('Adding proof item, pos %d' % len(o)) # print('Adding proof item, pos %d' % len(o))
@ -54,11 +56,13 @@ def decompress_fri(proof):
def compress_branches(branches): def compress_branches(branches):
o = [] o = []
oindex = {}
def add_obj(x): def add_obj(x):
if x in o: if x in oindex:
o.append(o.index(x).to_bytes(2, 'big')) o.append(oindex[x].to_bytes(2, 'big'))
else: else:
o.append(x) o.append(x)
oindex[x] = len(o)-1
for branch in branches: for branch in branches:
for p in branch: for p in branch:

View File

@ -1,13 +1,15 @@
def _simple_ft(vals, modulus, roots_of_unity): def _simple_ft(vals, modulus, roots_of_unity):
L = len(roots_of_unity) L = len(roots_of_unity)
o = [0 for _ in range(L)] o = []
for i in range(L): for i in range(L):
last = 0
for j in range(L): for j in range(L):
o[i] += vals[j] * roots_of_unity[(i*j)%L] last += vals[j] * roots_of_unity[(i*j)%L]
return [x % modulus for x in o] o.append(last % modulus)
return o
def _fft(vals, modulus, roots_of_unity): def _fft(vals, modulus, roots_of_unity):
if len(vals) == 1: if len(vals) <= 1:
return vals return vals
# return _simple_ft(vals, modulus, roots_of_unity) # return _simple_ft(vals, modulus, roots_of_unity)
L = _fft(vals[::2], modulus, roots_of_unity[::2]) L = _fft(vals[::2], modulus, roots_of_unity[::2])

View File

@ -1,6 +1,5 @@
from merkle_tree import merkelize, mk_branch, verify_branch from merkle_tree import merkelize, mk_branch, verify_branch
from utils import get_power_cycle, get_pseudorandom_indices from utils import get_power_cycle, get_pseudorandom_indices
from fft import fft
from poly_utils import PrimeField from poly_utils import PrimeField
# Generate an FRI proof that the polynomial that has the specified # Generate an FRI proof that the polynomial that has the specified
@ -33,15 +32,14 @@ def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus, exclude_mult
# Calculate the "column" at that x coordinate # Calculate the "column" at that x coordinate
# (see https://vitalik.ca/general/2017/11/22/starks_part_2.html) # (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
# We calculate the column by Lagrange-interpolating the row, and not # We calculate the column by Lagrange-interpolating each row, and not
# directly from the polynomial, as this is more efficient # directly from the polynomial, as this is more efficient
column = [] quarter_len = len(xs)//4
for i in range(len(xs)//4): x_polys = f.multi_interp_4(
x_poly = f.lagrange_interp_4( [[xs[i+quarter_len*j] for j in range(4)] for i in range(quarter_len)],
[xs[i+len(xs)*j//4] for j in range(4)], [[values[i+quarter_len*j] for j in range(4)] for i in range(quarter_len)]
[values[i+len(values)*j//4] for j in range(4)], )
) column = [f.eval_quartic(p, special_x) for p in x_polys]
column.append(f.eval_poly_at(x_poly, special_x))
m2 = merkelize(column) m2 = merkelize(column)
# Pseudo-randomly select y indices to sample # Pseudo-randomly select y indices to sample
@ -56,11 +54,6 @@ def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus, exclude_mult
# This component of the proof # This component of the proof
o = [m2[1], branches] o = [m2[1], branches]
# Interpolate the polynomial for the column
# sub_xs = [xs[i] for i in range(0, len(xs), 4)]
# ypoly = fft(column[:len(sub_xs)], modulus,
# f.exp(root_of_unity, 4), inv=True)
# Recurse... # Recurse...
return [o] + prove_low_degree(column, f.exp(root_of_unity, 4), return [o] + prove_low_degree(column, f.exp(root_of_unity, 4),
maxdeg_plus_1 // 4, modulus, exclude_multiples_of=exclude_multiples_of) maxdeg_plus_1 // 4, modulus, exclude_multiples_of=exclude_multiples_of)
@ -94,24 +87,30 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
ys = get_pseudorandom_indices(root2, roudeg // 4, 40, ys = get_pseudorandom_indices(root2, roudeg // 4, 40,
exclude_multiples_of=exclude_multiples_of) exclude_multiples_of=exclude_multiples_of)
# Verify for each selected y coordinate that the four points from the # For each y coordinate, get the x coordinates on the row, the values on
# polynomial and the one point from the column that are on that y # the row, and the value at that y from the column
# coordinate are on the same deg < 4 polynomial xcoords = []
rows = []
columnvals = []
for i, y in enumerate(ys): for i, y in enumerate(ys):
# The x coordinates from the polynomial # The x coordinates from the polynomial
x1 = f.exp(root_of_unity, y) x1 = f.exp(root_of_unity, y)
xcoords = [(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)] xcoords.append([(quartic_roots_of_unity[j] * x1) % modulus for j in range(4)])
# The values from the polynomial # The values from the original polynomial
row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf) row = [verify_branch(merkle_root, y + (roudeg // 4) * j, prf)
for j, prf in zip(range(4), branches[i][1:])] for j, prf in zip(range(4), branches[i][1:])]
rows.append(row)
# Verify proof and recover the column value columnvals.append(verify_branch(root2, y, branches[i][0]))
values = [verify_branch(root2, y, branches[i][0])] + row
# Lagrange interpolate and check deg is < 4 # Verify for each selected y coordinate that the four points from the
p = f.lagrange_interp_4(xcoords, row) # polynomial and the one point from the column that are on that y
assert f.eval_poly_at(p, special_x) == verify_branch(root2, y, branches[i][0]) # coordinate are on the same deg < 4 polynomial
polys = f.multi_interp_4(xcoords, rows)
for p, c in zip(polys, columnvals):
assert f.eval_quartic(p, special_x) == c
# Update constants to check the next proof # Update constants to check the next proof
merkle_root = root2 merkle_root = root2

View File

@ -126,6 +126,12 @@ class PrimeField():
b[j] += nums[i][j] * yslice b[j] += nums[i][j] * yslice
return [x % self.modulus for x in b] return [x % self.modulus for x in b]
# Optimized poly evaluation for degree 4
def eval_quartic(self, p, x):
xsq = x * x % self.modulus
xcb = xsq * x
return (p[0] + p[1] * x + p[2] * xsq + p[3] * xcb) % self.modulus
# Optimized version of the above restricted to deg-4 polynomials # Optimized version of the above restricted to deg-4 polynomials
def lagrange_interp_4(self, xs, ys): def lagrange_interp_4(self, xs, ys):
x01, x02, x03, x12, x13, x23 = \ x01, x02, x03, x12, x13, x23 = \
@ -159,3 +165,33 @@ class PrimeField():
inv_y0 = ys[0] * invall * e1 inv_y0 = ys[0] * invall * e1
inv_y1 = ys[1] * invall * e0 inv_y1 = ys[1] * invall * e0
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)] return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]
# Optimized version of the above restricted to deg-4 polynomials
def multi_interp_4(self, xsets, ysets):
data = []
invtargets = []
for xs, ys in zip(xsets, ysets):
x01, x02, x03, x12, x13, x23 = \
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
m = self.modulus
eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
e0 = self.eval_quartic(eq0, xs[0])
e1 = self.eval_quartic(eq1, xs[1])
e2 = self.eval_quartic(eq2, xs[2])
e3 = self.eval_quartic(eq3, xs[3])
data.append([ys, eq0, eq1, eq2, eq3])
invtargets.extend([e0, e1, e2, e3])
invalls = self.multi_inv(invtargets)
o = []
for (i, (ys, eq0, eq1, eq2, eq3)) in enumerate(data):
invallz = invalls[i*4:i*4+4]
inv_y0 = ys[0] * invallz[0] % m
inv_y1 = ys[1] * invallz[1] % m
inv_y2 = ys[2] * invallz[2] % m
inv_y3 = ys[3] * invallz[3] % m
o.append([(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)])
# assert o == [self.lagrange_interp_4(xs, ys) for xs, ys in zip(xsets, ysets)]
return o

View File

@ -0,0 +1,61 @@
# Implements the "iterated 2-of-3 majority" low-influence function
# from https://arxiv.org/pdf/1406.5694.pdf and outputs the probability
# that any specific user will be able to influence the result
import random
def mkbits(depth):
return random.randrange(2**(3**depth))
def winner(val, depth):
if depth == 0:
return val & 1
subwinners = [
winner(val, depth-1),
winner(val >> (3**(depth-1)), depth-1),
winner(val >> (3**(depth-1) * 2), depth-1)
]
return 1 if sum(subwinners) >= 2 else 0
def is_marginal(val, depth):
if depth == 0:
return True
s1, s2, s3 = val, val >> (3**(depth-1)), val >> (3**(depth-1) * 2)
w1, w2, w3 = (
winner(s1, depth-1),
winner(s2, depth-1),
winner(s3, depth-1)
)
if w1 == w2 and w2 == w3:
return False
dominants = (s2, s3) if (w1 != w2 and w1 != w3) else \
(s1, s3) if (w2 != w1 and w2 != w3) else \
(s1, s2) if (w3 != w1 and w3 != w2) else \
False
return is_marginal(dominants[0], depth-1) or \
is_marginal(dominants[1], depth-1)
def is_marginal2(val, depth):
w = winner(val, depth)
for x in range(3**depth):
val2 = val ^ (1 << x)
w2 = winner(val2, depth)
if w2 != w:
return True
return False
def influence(val, depth):
tot = 0
w = winner(val, depth)
for i in range(3**depth):
val2 = val ^ (1 << i)
w2 = winner(val2, depth)
if w != w2:
tot += 1
return tot / (3**depth)
bitz = [mkbits(3) for i in range(1000)]
print(sum([influence(b, 3) for b in bitz]) / 1000)

View File

@ -0,0 +1,46 @@
# Implements a modified version of the TRIBES low-influence function
# mentioned in https://arxiv.org/pdf/1406.5694.pdf and outputs the
# probability that any specific user will be able to influence the result
import random, math
def mkbits(n):
return random.randrange(2**n)
def tribes_log(n):
w = 1
while w * 2**w * 693 < n * 1000:
w += 1
return w
def tribes(val, n):
split = tribes_log(n)
o = []
full_subset = (1 << split) - 1
for i in range(n):
vall = val ^ ((2*i+3)**n % 2**n)
t = 0
for _ in range(n // split):
if vall & full_subset == full_subset:
t = 1
break
vall >>= split
o.append(t)
if len(o) % 2 == 0 and o[-2] == 0 and o[-1] == 1:
return 0
if len(o) % 2 == 0 and o[-2] == 1 and o[-1] == 0:
return 1
return o[-1]
def influence(val, n):
tot = 0
w = tribes(val, n)
for i in range(n):
val2 = val ^ (1 << i)
w2 = tribes(val2, n)
if w != w2:
tot += 1
return tot / n
print(sum([influence(mkbits(50), 50) for i in range(1000)]) / 1000)
# print(sum([tribes(mkbits(50), 50) for i in range(1000)]) / 1000)