Cleaned up a bit from poly_utils

This commit is contained in:
Vitalik Buterin 2018-07-11 17:33:39 -04:00
parent e4d2fc055a
commit f49c584f83
1 changed files with 35 additions and 60 deletions

View File

@ -52,6 +52,41 @@ class PrimeField():
power_of_x = (power_of_x * x) % self.modulus
return y % self.modulus
# Arithmetic for polynomials
def add_polys(self, a, b):
return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0))
% self.modulus for i in range(max(len(a), len(b)))]
def sub_polys(self, a, b):
return [((a[i] if i < len(a) else 0) - (b[i] if i < len(b) else 0))
% self.modulus for i in range(max(len(a), len(b)))]
def mul_by_const(self, a, c):
return [(x*c) % self.modulus for x in a]
def mul_polys(self, a, b):
o = [0] * (len(a) + len(b) - 1)
for i, aval in enumerate(a):
for j, bval in enumerate(b):
o[i+j] += a[i] * b[j]
return [x % self.modulus for x in o]
def div_polys(self, a, b):
assert len(a) >= len(b)
a = [x for x in a]
o = []
apos = len(a) - 1
bpos = len(b) - 1
diff = apos - bpos
while diff >= 0:
quot = self.div(a[apos], b[bpos])
o.insert(0, quot)
for i in range(bpos, -1, -1):
a[diff+i] -= b[i] * quot
apos -= 1
diff -= 1
return [x % self.modulus for x in o]
# Build a polynomial that returns 0 at all specified xs
def zpoly(self, xs):
root = [1]
@ -124,63 +159,3 @@ class PrimeField():
inv_y0 = ys[0] * invall * e1
inv_y1 = ys[1] * invall * e0
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]
# Arithmetic for polynomials
def add_polys(self, a, b):
return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0))
% self.modulus for i in range(max(len(a), len(b)))]
def sub_polys(self, a, b):
return [((a[i] if i < len(a) else 0) - (b[i] if i < len(b) else 0))
% self.modulus for i in range(max(len(a), len(b)))]
def mul_by_const(self, a, c):
return [(x*c) % self.modulus for x in a]
def mul_polys(self, a, b):
o = [0] * (len(a) + len(b) - 1)
for i, aval in enumerate(a):
for j, bval in enumerate(b):
o[i+j] += a[i] * b[j]
return [x % self.modulus for x in o]
def div_polys(self, a, b):
assert len(a) >= len(b)
a = [x for x in a]
o = []
apos = len(a) - 1
bpos = len(b) - 1
diff = apos - bpos
while diff >= 0:
quot = self.div(a[apos], b[bpos])
o.insert(0, quot)
for i in range(bpos, -1, -1):
a[diff+i] -= b[i] * quot
apos -= 1
diff -= 1
return [x % self.modulus for x in o]
# Divides a polynomial by x^n-1
def divide_by_xnm1(self, poly, n):
if len(poly) <= n:
return []
return self.add_polys(poly[n:], self.divide_by_xnm1(poly[n:], n))
# Returns P(x) = A(B(x))
def compose_polys(self, a, b):
o = []
p = [1]
for c in a:
o = self.add_polys(o, self.mul_by_const(p, c))
p = self.mul_polys(p, b)
return o
# Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x)
# Equivalent to compose_polys(poly, [0, fac])
def multiply_base(self, poly, fac):
o = []
r = 1
for p in poly:
o.append(p * r % self.modulus)
r = r * fac % self.modulus
return o