198 lines
7.8 KiB
Python
198 lines
7.8 KiB
Python
# Creates an object that provides convenience operations for numbers
# and polynomials over a prime field.
class PrimeField:
    """Convenience arithmetic for numbers and polynomials modulo a prime.

    Polynomials are coefficient lists in ascending order of degree:
    ``[c0, c1, c2]`` represents ``c0 + c1*x + c2*x**2``. Results are
    reduced into ``range(modulus)``.
    """

    def __init__(self, modulus):
        """Create a field of integers modulo *modulus*.

        The assertion is a base-2 Fermat check. It rejects most composite
        moduli cheaply, but it is not a primality proof: it accepts the
        pseudoprime 341, and it rejects the prime 2.
        """
        assert pow(2, modulus, modulus) == 2
        self.modulus = modulus

    def add(self, x, y):
        """Return ``x + y`` reduced modulo the field prime."""
        return (x + y) % self.modulus

    def sub(self, x, y):
        """Return ``x - y`` reduced modulo the field prime."""
        return (x - y) % self.modulus

    def mul(self, x, y):
        """Return ``x * y`` reduced modulo the field prime."""
        return (x * y) % self.modulus

    def exp(self, x, p):
        """Return ``x ** p`` modulo the field prime."""
        return pow(x, p, self.modulus)

    def inv(self, a):
        """Return the multiplicative inverse of *a* via the extended Euclidean algorithm.

        By convention, zero has inverse 0. This includes any multiple of
        the modulus, because the input is reduced before the zero check.
        """
        low = a % self.modulus
        # Reduce first: a non-zero multiple of the modulus is zero in the field.
        if low == 0:
            return 0
        lm, hm = 1, 0
        high = self.modulus
        while low > 1:
            r = high // low
            lm, low, hm, high = hm - lm * r, high - low * r, lm, low
        return lm % self.modulus

    def multi_inv(self, values):
        """Invert every element of *values* using a single field inversion.

        This is batch inversion via prefix products (Montgomery's trick).
        Elements that are zero in the field, including non-zero multiples
        of the modulus, map to 0.
        """
        m = self.modulus
        reduced = [v % m for v in values]
        # partials[i] is the product of the non-zero elements before index i.
        partials = [1]
        for v in reduced:
            partials.append(partials[-1] * (v or 1) % m)
        inv = self.inv(partials[-1])
        outputs = [0] * len(reduced)
        # Walk backwards. `inv` always holds the inverse of partials[i + 1].
        for i in range(len(reduced) - 1, -1, -1):
            v = reduced[i]
            if v:
                outputs[i] = partials[i] * inv % m
                inv = inv * v % m
        return outputs

    def div(self, x, y):
        """Return ``x / y`` in the field; division by zero yields 0."""
        return self.mul(x, self.inv(y))

    def eval_poly_at(self, p, x):
        """Evaluate polynomial *p* (ascending coefficients) at point *x*."""
        m = self.modulus
        y = 0
        power_of_x = 1
        for coeff in p:
            y = (y + power_of_x * coeff) % m
            power_of_x = power_of_x * x % m
        return y

    def add_polys(self, a, b):
        """Return the coefficient-wise sum of polynomials *a* and *b*."""
        from itertools import zip_longest
        m = self.modulus
        return [(x + y) % m for x, y in zip_longest(a, b, fillvalue=0)]

    def sub_polys(self, a, b):
        """Return the coefficient-wise difference ``a - b`` of two polynomials."""
        from itertools import zip_longest
        m = self.modulus
        return [(x - y) % m for x, y in zip_longest(a, b, fillvalue=0)]

    def mul_by_const(self, a, c):
        """Return polynomial *a* with every coefficient multiplied by *c*."""
        return [(x * c) % self.modulus for x in a]

    def mul_polys(self, a, b):
        """Return the product of polynomials *a* and *b* (schoolbook multiplication)."""
        o = [0] * (len(a) + len(b) - 1)
        for i, aval in enumerate(a):
            for j, bval in enumerate(b):
                o[i + j] += aval * bval
        return [x % self.modulus for x in o]

    def div_polys(self, a, b):
        """Return the quotient of polynomial *a* divided by *b*; the remainder is dropped.

        Requires ``len(a) >= len(b)``. The highest coefficient of *b* must
        be non-zero in the field, otherwise the result is meaningless.
        """
        assert len(a) >= len(b)
        m = self.modulus
        a = list(a)
        o = []
        apos = len(a) - 1
        bpos = len(b) - 1
        # The divisor's leading coefficient is the same in every step: invert it once.
        lead_inv = self.inv(b[bpos])
        diff = apos - bpos
        while diff >= 0:
            quot = a[apos] * lead_inv % m
            o.append(quot)
            for i in range(bpos, -1, -1):
                a[diff + i] = (a[diff + i] - b[i] * quot) % m
            apos -= 1
            diff -= 1
        # Quotient terms were produced from the highest degree down.
        o.reverse()
        return o

    def zpoly(self, xs):
        """Return the monic polynomial that is zero at every x in *xs*.

        The result is ``(X - xs[0]) * (X - xs[1]) * ...``.
        """
        m = self.modulus
        root = [1]
        for x in xs:
            # Multiply the current product by (X - x): shift up one degree, then subtract x * old.
            root.insert(0, 0)
            for j in range(len(root) - 1):
                root[j] = (root[j] - root[j + 1] * x) % m
        return root

    def lagrange_interp(self, xs, ys):
        """Recover the unique polynomial through the points ``(xs[i], ys[i])``.

        Given p+1 error-free points, this returns the polynomial of degree
        at most p, as p+1 coefficients. It works as follows:
          1. Build the master numerator (X - x0)(X - x1)...(X - xn).
          2. For each xi, divide out (X - xi) and rescale. The result
             equals yi at xi and 0 at every other given x.
          3. Sum these per-point polynomials.
        """
        m = self.modulus
        root = self.zpoly(xs)
        assert len(root) == len(ys) + 1
        # Per-point numerators: the master polynomial with (X - x) divided back out.
        nums = [self.div_polys(root, [-x, 1]) for x in xs]
        # Denominators: each numerator evaluated at its own x.
        denoms = [self.eval_poly_at(num, x) for num, x in zip(nums, xs)]
        invdenoms = self.multi_inv(denoms)
        b = [0] * len(ys)
        for num, y, invdenom in zip(nums, ys, invdenoms):
            if not y:
                continue
            yslice = y * invdenom % m
            for j, coeff in enumerate(num):
                b[j] += coeff * yslice
        return [x % m for x in b]

    def eval_quartic(self, p, x):
        """Evaluate the 4-coefficient (degree <= 3) polynomial *p* at *x*.

        Only ``p[0]..p[3]`` are read. The name is kept for compatibility
        even though the polynomial is cubic.
        """
        m = self.modulus
        xsq = x * x % m
        xcb = xsq * x % m
        return (p[0] + p[1] * x + p[2] * xsq + p[3] * xcb) % m

    def _cubic_basis(self, xs):
        """Return the 4 numerators prod_{k != i} (X - xs[k]) for the points xs[0..3]."""
        m = self.modulus
        x0, x1, x2, x3 = xs[0], xs[1], xs[2], xs[3]
        x01, x02, x03 = x0 * x1, x0 * x2, x0 * x3
        x12, x13, x23 = x1 * x2, x1 * x3, x2 * x3
        return (
            [-x12 * x3 % m, (x12 + x13 + x23) % m, -(x1 + x2 + x3) % m, 1],
            [-x02 * x3 % m, (x02 + x03 + x23) % m, -(x0 + x2 + x3) % m, 1],
            [-x01 * x3 % m, (x01 + x03 + x13) % m, -(x0 + x1 + x3) % m, 1],
            [-x01 * x2 % m, (x01 + x02 + x12) % m, -(x0 + x1 + x2) % m, 1],
        )

    def lagrange_interp_4(self, xs, ys):
        """Interpolate exactly 4 points; returns 4 coefficients (degree <= 3)."""
        m = self.modulus
        eq0, eq1, eq2, eq3 = self._cubic_basis(xs)
        e0 = self.eval_quartic(eq0, xs[0])
        e1 = self.eval_quartic(eq1, xs[1])
        e2 = self.eval_quartic(eq2, xs[2])
        e3 = self.eval_quartic(eq3, xs[3])
        e01 = e0 * e1 % m
        e23 = e2 * e3 % m
        # One inversion of the product; each 1/ei is recovered by multiplying the other three back.
        invall = self.inv(e01 * e23)
        inv_y0 = ys[0] * invall * e1 * e23 % m
        inv_y1 = ys[1] * invall * e0 * e23 % m
        inv_y2 = ys[2] * invall * e01 * e3 % m
        inv_y3 = ys[3] * invall * e01 * e2 % m
        return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m
                for i in range(4)]

    def lagrange_interp_2(self, xs, ys):
        """Interpolate exactly 2 points; returns 2 coefficients (degree <= 1)."""
        m = self.modulus
        eq0 = [-xs[1] % m, 1]
        eq1 = [-xs[0] % m, 1]
        e0 = self.eval_poly_at(eq0, xs[0])
        e1 = self.eval_poly_at(eq1, xs[1])
        invall = self.inv(e0 * e1)
        inv_y0 = ys[0] * invall * e1 % m
        inv_y1 = ys[1] * invall * e0 % m
        return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]

    def multi_interp_4(self, xsets, ysets):
        """Batch version of ``lagrange_interp_4`` over many 4-point sets.

        All denominators across all sets share a single ``multi_inv`` call.
        Returns one 4-coefficient list per zipped ``(xs, ys)`` pair.
        """
        m = self.modulus
        data = []
        invtargets = []
        for xs, ys in zip(xsets, ysets):
            eqs = self._cubic_basis(xs)
            data.append((ys, eqs))
            invtargets.extend(self.eval_quartic(eq, x) for eq, x in zip(eqs, xs))
        invalls = self.multi_inv(invtargets)
        o = []
        for k, (ys, eqs) in enumerate(data):
            scales = [ys[j] * invalls[4 * k + j] % m for j in range(4)]
            o.append([sum(eq[i] * s for eq, s in zip(eqs, scales)) % m for i in range(4)])
        return o