diff --git a/mimc_stark/poly_utils.py b/mimc_stark/poly_utils.py index 2199b4c..61cd853 100644 --- a/mimc_stark/poly_utils.py +++ b/mimc_stark/poly_utils.py @@ -51,6 +51,41 @@ class PrimeField(): y += power_of_x * p_coeff power_of_x = (power_of_x * x) % self.modulus return y % self.modulus + + # Arithmetic for polynomials + def add_polys(self, a, b): + return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0)) + % self.modulus for i in range(max(len(a), len(b)))] + + def sub_polys(self, a, b): + return [((a[i] if i < len(a) else 0) - (b[i] if i < len(b) else 0)) + % self.modulus for i in range(max(len(a), len(b)))] + + def mul_by_const(self, a, c): + return [(x*c) % self.modulus for x in a] + + def mul_polys(self, a, b): + o = [0] * (len(a) + len(b) - 1) + for i, aval in enumerate(a): + for j, bval in enumerate(b): + o[i+j] += a[i] * b[j] + return [x % self.modulus for x in o] + + def div_polys(self, a, b): + assert len(a) >= len(b) + a = [x for x in a] + o = [] + apos = len(a) - 1 + bpos = len(b) - 1 + diff = apos - bpos + while diff >= 0: + quot = self.div(a[apos], b[bpos]) + o.insert(0, quot) + for i in range(bpos, -1, -1): + a[diff+i] -= b[i] * quot + apos -= 1 + diff -= 1 + return [x % self.modulus for x in o] # Build a polynomial that returns 0 at all specified xs def zpoly(self, xs): @@ -124,63 +159,3 @@ class PrimeField(): inv_y0 = ys[0] * invall * e1 inv_y1 = ys[1] * invall * e0 return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)] - - # Arithmetic for polynomials - def add_polys(self, a, b): - return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0)) - % self.modulus for i in range(max(len(a), len(b)))] - - def sub_polys(self, a, b): - return [((a[i] if i < len(a) else 0) - (b[i] if i < len(b) else 0)) - % self.modulus for i in range(max(len(a), len(b)))] - - def mul_by_const(self, a, c): - return [(x*c) % self.modulus for x in a] - - def mul_polys(self, a, b): - o = [0] * (len(a) + len(b) - 1) - for i, aval in enumerate(a): - for j, bval in enumerate(b): - o[i+j] += a[i] * b[j] - return [x % self.modulus for x in o] - - def div_polys(self, a, b): - assert len(a) >= len(b) - a = [x for x in a] - o = [] - apos = len(a) - 1 - bpos = len(b) - 1 - diff = apos - bpos - while diff >= 0: - quot = self.div(a[apos], b[bpos]) - o.insert(0, quot) - for i in range(bpos, -1, -1): - a[diff+i] -= b[i] * quot - apos -= 1 - diff -= 1 - return [x % self.modulus for x in o] - - # Divides a polynomial by x^n-1 - def divide_by_xnm1(self, poly, n): - if len(poly) <= n: - return [] - return self.add_polys(poly[n:], self.divide_by_xnm1(poly[n:], n)) - - # Returns P(x) = A(B(x)) - def compose_polys(self, a, b): - o = [] - p = [1] - for c in a: - o = self.add_polys(o, self.mul_by_const(p, c)) - p = self.mul_polys(p, b) - return o - - # Convert a polynomial P(x) into a polynomial Q(x) = P(fac * x) - # Equivalent to compose_polys(poly, [0, fac]) - def multiply_base(self, poly, fac): - o = [] - r = 1 - for p in poly: - o.append(p * r % self.modulus) - r = r * fac % self.modulus - return o