From 711bd9532b4534ef5ae6277bd7afe625195506d5 Mon Sep 17 00:00:00 2001 From: vub Date: Sat, 24 Dec 2016 09:51:55 -0500 Subject: [PATCH] Added field and curve preliminaries for pairing math --- zksnark/bn128_curve.py | 84 +++++++++++ zksnark/bn128_field_elements.py | 259 ++++++++++++++++++++++++++++++++ 2 files changed, 343 insertions(+) create mode 100644 zksnark/bn128_curve.py create mode 100644 zksnark/bn128_field_elements.py diff --git a/zksnark/bn128_curve.py b/zksnark/bn128_curve.py new file mode 100644 index 0000000..8827193 --- /dev/null +++ b/zksnark/bn128_curve.py @@ -0,0 +1,84 @@ +from bn128_field_elements import field_modulus, FQ, FQ2, FQ12 + +curve_order = 21888242871839275222246405745257275088548364400416034343698204186575808495617 + +# Curve is y**2 = x**3 + 3 +b = FQ(3) +b2 = FQ2([3, 0]) +b12 = FQ12([3] + [0] * 11) / FQ12([0] * 6 + [1] + [0] * 5) + +ate_loop_count = 29793968203157093288 + +G1 = (FQ(1), FQ(2)) +# Second element corresponds to modsqrt(67) * i in our quadratic field representation +G2 = (FQ2([4, 0]), FQ2([16893045765507297706785249332518927989146279141265438554111591828131739815230L, 16469166999615883226695964867118064280147127342783597836693979910667010785192])) + +# Check that a point is on the curve defined by y**2 == x**3 + b +def is_on_curve(pt, b): + x, y = pt + return y**2 - x**3 == b + +assert is_on_curve(G1, b) +assert is_on_curve(G2, b2) + +# Elliptic curve doubling +def double(pt): + x, y = pt + l = 3 * x**2 / (2 * y) + newx = l**2 - 2 * x + newy = -l * newx + l * x - y + return newx, newy + +# Elliptic curve addition +def add(p1, p2): + if p1 is None or p2 is None: + return p1 if p2 is None else p2 + x1, y1 = p1 + x2, y2 = p2 + if x2 == x1 and y2 == y1: + return double(p1) + elif x2 == x1: + return None + else: + l = (y2 - y1) / (x2 - x1) + newx = l**2 - x1 - x2 + newy = -l * newx + l * x1 - y1 + assert newy == (-l * newx + l * x2 - y2) + return (newx, newy) + +# Elliptic curve point multiplication +def multiply(pt, n): + if n == 0: + return None + elif n == 1: + return pt + elif not n % 2: + return multiply(double(pt), n / 2) + else: + return add(multiply(double(pt), int(n / 2)), pt) + +# Check that the G1 curve works fine +assert add(add(double(G1), G1), G1) == double(double(G1)) +assert double(G1) != G1 +assert add(multiply(G1, 9), multiply(G1, 5)) == add(multiply(G1, 12), multiply(G1, 2)) +assert multiply(G1, curve_order) is None + +# Check that the G2 curve works fine +assert add(add(double(G2), G2), G2) == double(double(G2)) +assert double(G2) != G2 +assert add(multiply(G2, 9), multiply(G2, 5)) == add(multiply(G2, 12), multiply(G2, 2)) +assert multiply(G2, 2 * field_modulus - curve_order) is not None + +# "Twist" a point in E(FQ2) into a point in E(FQ12) +w = FQ12([0, 1] + [0] * 10) + +def twist(pt): + if pt is None: + return None + x, y = pt + nx = FQ12([x.coeffs[0]] + [0] * 5 + [x.coeffs[1]] + [0] * 5) + ny = FQ12([y.coeffs[0]] + [0] * 5 + [y.coeffs[1]] + [0] * 5) + return (nx / w **2, ny / w**3) + +# Check that the twist creates a point that is on the curve +assert is_on_curve(twist(G2), b12) diff --git a/zksnark/bn128_field_elements.py b/zksnark/bn128_field_elements.py new file mode 100644 index 0000000..58277b0 --- /dev/null +++ b/zksnark/bn128_field_elements.py @@ -0,0 +1,259 @@ +import sys +sys.setrecursionlimit(10000) + +# The prime modulus of the field +field_modulus = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# See, it's prime! +assert pow(2, field_modulus, field_modulus) == 2 + +# The modulus of the polynomial in this representation of FQ2 +FQ2_modulus_coeffs = [82, 18] # Implied + [1] +# And in FQ12 +FQ12_modulus_coeffs = [82, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0] # Implied + [1] + +# Extended euclidean algorithm to find modular inverses for +# integers +def inv(a, n): + if a == 0: + return 0 + lm, hm = 1, 0 + low, high = a % n, n + while low > 1: + r = high//low + nm, new = hm-lm*r, high-low*r + lm, low, hm, high = nm, new, lm, low + return lm % n + +# A class for field elements in FQ. Wrap a number in this class, +# and it becomes a field element. +class FQ(): + def __init__(self, n): + if isinstance(n, self.__class__): + self.n = n.n + else: + self.n = n % field_modulus + assert isinstance(self.n, (int, long)) + + def __add__(self, other): + on = other.n if isinstance(other, FQ) else other + return FQ((self.n + on) % field_modulus) + + def __mul__(self, other): + on = other.n if isinstance(other, FQ) else other + return FQ((self.n * on) % field_modulus) + + def __rmul__(self, other): + return self * other + + def __radd__(self, other): + return self + other + + def __rsub__(self, other): + on = other.n if isinstance(other, FQ) else other + return FQ((on - self.n) % field_modulus) + + def __sub__(self, other): + on = other.n if isinstance(other, FQ) else other + return FQ((self.n - on) % field_modulus) + + def __div__(self, other): + on = other.n if isinstance(other, FQ) else other + assert isinstance(on, (int, long)) + return FQ(self.n * inv(on, field_modulus) % field_modulus) + + def __rdiv__(self, other): + on = other.n if isinstance(other, FQ) else other + assert isinstance(on, (int, long)), on + return FQ(inv(self.n, field_modulus) * on % field_modulus) + + def __pow__(self, other): + if other == 0: + return FQ(1) + elif other == 1: + return FQ(self.n) + elif other % 2 == 0: + return (self * self) ** (other / 2) + else: + return ((self * self) ** int(other / 2)) * self + + def __eq__(self, other): + if isinstance(other, FQ): + return self.n == other.n + else: + return self.n == other + + def __ne__(self, other): + return not self == other + + def __neg__(self): + return FQ(-self.n) + + def __repr__(self): + return repr(self.n) + + @classmethod + def one(cls): + return cls(1) + + @classmethod + def zero(cls): + return cls(0) + +# Check that the field works fine +assert FQ(2) * FQ(2) == FQ(4) +assert FQ(2) / FQ(7) + FQ(9) / FQ(7) == FQ(11) / FQ(7) +assert FQ(2) * FQ(7) + FQ(9) * FQ(7) == FQ(11) * FQ(7) +assert FQ(9) ** field_modulus == FQ(9) + +# Utility methods for polynomial math +def deg(p): + d = len(p) - 1 + while p[d] == 0 and d: + d -= 1 + return d + +def poly_rounded_div(a, b): + dega = deg(a) + degb = deg(b) + temp = [x for x in a] + o = [0 for x in a] + for i in range(dega - degb, -1, -1): + o[i] += temp[degb + i] / b[degb] + for c in range(degb + 1): + temp[c + i] -= o[c] + return o[:deg(o)+1] + +# A class for elements in polynomial extension fields +class FQP(): + def __init__(self, coeffs, modulus_coeffs): + assert len(coeffs) == len(modulus_coeffs) + self.coeffs = [FQ(c) for c in coeffs] + # The coefficients of the modulus, without the leading [1] + self.modulus_coeffs = modulus_coeffs + # The degree of the extension field + self.degree = len(self.modulus_coeffs) + + def __add__(self, other): + assert isinstance(other, self.__class__) + return self.__class__([x+y for x,y in zip(self.coeffs, other.coeffs)]) + + def __sub__(self, other): + assert isinstance(other, self.__class__) + return self.__class__([x-y for x,y in zip(self.coeffs, other.coeffs)]) + + def __mul__(self, other): + if isinstance(other, (FQ, int, long)): + return self.__class__([c * other for c in self.coeffs]) + else: + assert isinstance(other, self.__class__) + b = [FQ(0) for i in range(self.degree * 2 - 1)] + for i in range(self.degree): + for j in range(self.degree): + b[i + j] += self.coeffs[i] * other.coeffs[j] + while len(b) > self.degree: + exp, top = len(b) - self.degree - 1, b.pop() + for i in range(self.degree): + b[exp + i] -= top * FQ(self.modulus_coeffs[i]) + return self.__class__(b) + + def __rmul__(self, other): + return self * other + + def __div__(self, other): + if isinstance(other, (FQ, int, long)): + return self.__class__([c / other for c in self.coeffs]) + else: + assert isinstance(other, self.__class__) + return self * other.inv() + + def __pow__(self, other): + if other == 0: + return self.__class__([1] + [0] * (self.degree - 1)) + elif other == 1: + return self.__class__(self.coeffs) + elif other % 2 == 0: + return (self * self) ** (other / 2) + else: + return ((self * self) ** int(other / 2)) * self + + # Extended euclidean algorithm used to find the modular inverse + def inv(self): + lm, hm = [1] + [0] * self.degree, [0] * (self.degree + 1) + low, high = self.coeffs + [0], self.modulus_coeffs + [1] + while deg(low): + r = poly_rounded_div(high, low) + r += [0] * (self.degree + 1 - len(r)) + nm = [x for x in hm] + new = [x for x in high] + assert len(lm) == len(hm) == len(low) == len(high) == len(nm) == len(new) == self.degree + 1 + for i in range(self.degree + 1): + for j in range(self.degree + 1 - i): + nm[i+j] -= lm[i] * r[j] + new[i+j] -= low[i] * r[j] + lm, low, hm, high = nm, new, lm, low + return self.__class__(lm[:self.degree]) / low[0] + + def __repr__(self): + return repr(self.coeffs) + + def __eq__(self, other): + assert isinstance(other, self.__class__) + for c1, c2 in zip(self.coeffs, other.coeffs): + if c1 != c2: + return False + return True + + def __ne__(self, other): + return not self == other + + def __neg__(self): + return self.__class__([-c for c in self.coeffs]) + + @classmethod + def one(cls): + return cls([1] + [0] * (cls.degree - 1)) + + @classmethod + def zero(cls): + return cls([0] * cls.degree) + +# The quadratic extension field +class FQ2(FQP): + def __init__(self, coeffs): + self.coeffs = [FQ(c) for c in coeffs] + self.modulus_coeffs = FQ2_modulus_coeffs + self.degree = 2 + self.__class__.degree = 2 + +x = FQ2([1, 0]) +f = FQ2([1, 2]) +fpx = FQ2([2, 2]) +one = FQ2.one() + +# Check that the field works fine +assert x + f == fpx +assert f / f == one +assert one / f + x / f == (one + x) / f +assert one * f + x * f == (one + x) * f +assert x ** (field_modulus ** 2 - 1) == one + +# The 12th-degree extension field +class FQ12(FQP): + def __init__(self, coeffs): + self.coeffs = [FQ(c) for c in coeffs] + self.modulus_coeffs = FQ12_modulus_coeffs + self.degree = 12 + self.__class__.degree = 12 + +x = FQ12([1] + [0] * 11) +f = FQ12([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) +fpx = FQ12([2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) +one = FQ12.one() + +# Check that the field works fine +assert x + f == fpx +assert f / f == one +assert one / f + x / f == (one + x) / f +assert one * f + x * f == (one + x) * f +# This check takes too long +# assert x ** (field_modulus ** 12 - 1) == one