Efficiency updates

This commit is contained in:
Vitalik Buterin 2018-08-21 19:36:01 -04:00
parent b3b78bd4e5
commit 131b934f2c
2 changed files with 29 additions and 16 deletions

View File

@ -9,9 +9,9 @@ def _simple_ft(vals, modulus, roots_of_unity):
return o return o
def _fft(vals, modulus, roots_of_unity): def _fft(vals, modulus, roots_of_unity):
if len(vals) <= 1: if len(vals) <= 4:
return vals #return vals
# return _simple_ft(vals, modulus, roots_of_unity) return _simple_ft(vals, modulus, roots_of_unity)
L = _fft(vals[::2], modulus, roots_of_unity[::2]) L = _fft(vals[::2], modulus, roots_of_unity[::2])
R = _fft(vals[1::2], modulus, roots_of_unity[::2]) R = _fft(vals[1::2], modulus, roots_of_unity[::2])
o = [0 for i in vals] o = [0 for i in vals]
@ -39,7 +39,14 @@ def fft(vals, modulus, root_of_unity, inv=False):
return _fft(vals, modulus, rootz[:-1]) return _fft(vals, modulus, rootz[:-1])
def mul_polys(a, b, modulus, root_of_unity): def mul_polys(a, b, modulus, root_of_unity):
x1 = fft(a, modulus, root_of_unity) rootz = [1, root_of_unity]
x2 = fft(b, modulus, root_of_unity) while rootz[-1] != 1:
return fft([(v1*v2)%modulus for v1,v2 in zip(x1,x2)], rootz.append((rootz[-1] * root_of_unity) % modulus)
modulus, root_of_unity, inv=True) if len(rootz) > len(a) + 1:
a = a + [0] * (len(rootz) - len(a) - 1)
if len(rootz) > len(b) + 1:
b = b + [0] * (len(rootz) - len(b) - 1)
x1 = _fft(a, modulus, rootz[:-1])
x2 = _fft(b, modulus, rootz[:-1])
return _fft([(v1*v2)%modulus for v1,v2 in zip(x1,x2)],
modulus, rootz[:0:-1])

View File

@ -23,31 +23,30 @@ def p_of_kx(poly, modulus, k):
# Return (x - root**positions[0]) * (x - root**positions[1]) * ... # Return (x - root**positions[0]) * (x - root**positions[1]) * ...
# possibly with a constant factor offset # possibly with a constant factor offset
def zpoly(positions, modulus, root_of_unity): def _zpoly(positions, modulus, roots_of_unity):
# If there are not more than 4 positions, use the naive # If there are not more than 4 positions, use the naive
# O(n^2) algorithm as it is faster # O(n^2) algorithm as it is faster
if len(positions) <= 4: if len(positions) <= 4:
root = [1] root = [1]
for pos in positions: for pos in positions:
x = pow(root_of_unity, pos, modulus) x = roots_of_unity[pos]
root.insert(0, 0) root.insert(0, 0)
for j in range(len(root)-1): for j in range(len(root)-1):
root[j] -= root[j+1] * x root[j] -= root[j+1] * x
return [x % modulus for x in root] return [x % modulus for x in root]
else: else:
half_order_root_of_unity = pow(root_of_unity, 2, modulus)
# Recursively find the zpoly for even indices and odd # Recursively find the zpoly for even indices and odd
# indices, operating over a half-size subgroup in each # indices, operating over a half-size subgroup in each
# case # case
left = zpoly([x//2 for x in positions if x%2 == 0], left = _zpoly([x//2 for x in positions if x%2 == 0],
modulus, half_order_root_of_unity) modulus, roots_of_unity[::2])
right = zpoly([x//2 for x in positions if x%2 == 1], right = _zpoly([x//2 for x in positions if x%2 == 1],
modulus, half_order_root_of_unity) modulus, roots_of_unity[::2])
invroot = pow(root_of_unity, modulus - 2, modulus) invroot = roots_of_unity[-1]
# Offset the result for the odd indices, and combine # Offset the result for the odd indices, and combine
# the two # the two
o = mul_polys(left, p_of_kx(right, modulus, invroot), o = mul_polys(left, p_of_kx(right, modulus, invroot),
modulus, root_of_unity) modulus, roots_of_unity[1])
# Deal with the special case where mul_polys returns zero # Deal with the special case where mul_polys returns zero
# when it should return x ^ (2 ** k) - 1 # when it should return x ^ (2 ** k) - 1
if o == [0] * len(o): if o == [0] * len(o):
@ -55,6 +54,13 @@ def zpoly(positions, modulus, root_of_unity):
else: else:
return o return o
def zpoly(positions, modulus, root_of_unity):
# Precompute roots of unity
rootz = [1, root_of_unity]
while rootz[-1] != 1:
rootz.append((rootz[-1] * root_of_unity) % modulus)
return _zpoly(positions, modulus, rootz[:-1])
def erasure_code_recover(vals, modulus, root_of_unity): def erasure_code_recover(vals, modulus, root_of_unity):
# Generate the polynomial that is zero at the roots of unity # Generate the polynomial that is zero at the roots of unity
# corresponding to the indices where vals[i] is None # corresponding to the indices where vals[i] is None