From 88ea67aabe546034abc9e1b2598abb597eb76c99 Mon Sep 17 00:00:00 2001 From: vub Date: Mon, 20 Feb 2017 11:32:38 -0500 Subject: [PATCH] Added optimizations --- zksnark/bn128_pairing_test.py | 5 +- zksnark/optimized_curve.py | 138 ++++++++++++++++++++++++++++ zksnark/optimized_field_elements.py | 67 ++++++++++---- zksnark/optimized_pairing.py | 107 +++++++++++++++++++++ 4 files changed, 299 insertions(+), 18 deletions(-) create mode 100644 zksnark/optimized_curve.py create mode 100644 zksnark/optimized_pairing.py diff --git a/zksnark/bn128_pairing_test.py b/zksnark/bn128_pairing_test.py index 2e38019..3edfbc4 100644 --- a/zksnark/bn128_pairing_test.py +++ b/zksnark/bn128_pairing_test.py @@ -1,5 +1,7 @@ -from bn128_pairing import pairing, neg, G2, G1, multiply, FQ12, curve_order +from optimized_pairing import pairing, neg, G2, G1, multiply, FQ12, curve_order +import time +a = time.time() print('Starting tests') p1 = pairing(G2, G1) pn1 = pairing(G2, neg(G1)) @@ -23,3 +25,4 @@ p3 = pairing(multiply(G2, 27), multiply(G1, 37)) po3 = pairing(G2, multiply(G1, 999)) assert p3 == po3 print('Composite check passed') +print('Total time: %.3f' % (time.time() - a)) diff --git a/zksnark/optimized_curve.py b/zksnark/optimized_curve.py new file mode 100644 index 0000000..27d38f9 --- /dev/null +++ b/zksnark/optimized_curve.py @@ -0,0 +1,138 @@ +from bn128_field_elements import field_modulus, FQ +from optimized_field_elements import FQ2, FQ12 +# from bn128_field_elements import FQ2, FQ12 + +curve_order = 21888242871839275222246405745257275088548364400416034343698204186575808495617 + +# Curve order should be prime +assert pow(2, curve_order, curve_order) == 2 +# Curve order should be a factor of field_modulus**12 - 1 +assert (field_modulus ** 12 - 1) % curve_order == 0 + +# Curve is y**2 = x**3 + 3 +b = FQ(3) +# Twisted curve over FQ**2 +b2 = FQ2([3, 0]) / FQ2([0, 1]) +# Extension curve over FQ**12; same b value as over FQ +b12 = FQ12([3] + [0] * 11) + +# Generator for curve over FQ +G1 = (FQ(1), FQ(2), FQ(1)) +# Generator for twisted curve over FQ2 +G2 = (FQ2([16260673061341949275257563295988632869519996389676903622179081103440260644990, 11559732032986387107991004021392285783925812861821192530917403151452391805634]), + FQ2([15530828784031078730107954109694902500959150953518636601196686752670329677317, 4082367875863433681332203403145435568316851327593401208105741076214120093531]), FQ2.one()) + +# Check that a point is on the curve defined by y**2 == x**3 + b +def is_on_curve(pt, b): + if pt is None: + return True + x, y, z = pt + return y**2 * z - x**3 == b * z**3 + +assert is_on_curve(G1, b) +assert is_on_curve(G2, b2) + +# Elliptic curve doubling +def double(pt): + x, y, z = pt + W = 3 * x * x + S = y * z + B = x * y * S + H = W * W - 8 * B + S_squared = S * S + newx = 2 * H * S + newy = W * (4 * B - H) - 8 * y * y * S_squared + newz = 8 * S * S_squared + return newx, newy, newz + +# Elliptic curve addition +def add(p1, p2): + one, zero = p1[0].__class__.one(), p1[0].__class__.zero() + if p1[2] == zero or p2[2] == zero: + return p1 if zero else p2 + x1, y1, z1 = p1 + x2, y2, z2 = p2 + U1 = y2 * z1 + U2 = y1 * z2 + V1 = x2 * z1 + V2 = x1 * z2 + if V1 == V2 and U1 == U2: + return double(p1) + elif V1 == V2: + return (one, one, zero) + U = U1 - U2 + V = V1 - V2 + V_squared = V * V + V_squared_times_V2 = V_squared * V2 + V_cubed = V * V_squared + W = z1 * z2 + A = U * U * W - V_cubed - 2 * V_squared_times_V2 + newx = V * A + newy = U * (V_squared_times_V2 - A) - V_cubed * U2 + newz = V_cubed * W + return (newx, newy, newz) + +# Elliptic curve point multiplication +def multiply(pt, n): + if n == 0: + return None + elif n == 1: + return pt + elif not n % 2: + return multiply(double(pt), n // 2) + else: + return add(multiply(double(pt), int(n // 2)), pt) + +def eq(p1, p2): + x1, y1, z1 = p1 + x2, y2, z2 = p2 + return x1 * z2 == x2 * z1 and y1 * z2 == y2 * z1 + +def normalize(pt): + x, y, z = pt + return (x / z, y / z) + +# Check that the G1 curve works fine +assert eq(add(add(double(G1), G1), G1), double(double(G1))) +assert not eq(double(G1), G1) +assert eq(add(multiply(G1, 9), multiply(G1, 5)), add(multiply(G1, 12), multiply(G1, 2))) +assert eq(multiply(G1, curve_order), (1, 1, 0)) + +# Check that the G2 curve works fine +assert eq(add(add(double(G2), G2), G2), double(double(G2))) +assert not eq(double(G2), G2) +assert eq(add(multiply(G2, 9), multiply(G2, 5)), add(multiply(G2, 12), multiply(G2, 2))) +assert eq(multiply(G2, curve_order), (1, 1, 0)) +assert not eq(multiply(G2, 2 * field_modulus - curve_order), (1, 1, 0)) +assert is_on_curve(multiply(G2, 9), b2) + +# "Twist" a point in E(FQ2) into a point in E(FQ12) +w = FQ12([0, 1] + [0] * 10) + +# Convert P => -P +def neg(pt): + if pt is None: + return None + x, y, z = pt + return (x, -y, z) + +def twist(pt): + if pt is None: + return None + x, y, z = pt + nx = FQ12([x.coeffs[0]] + [0] * 5 + [x.coeffs[1]] + [0] * 5) + ny = FQ12([y.coeffs[0]] + [0] * 5 + [y.coeffs[1]] + [0] * 5) + nz = FQ12([z.coeffs[0]] + [0] * 5 + [z.coeffs[1]] + [0] * 5) + return (nx * w **2, ny * w**3, nz) + +# Check that the twist creates a point that is on the curve +assert is_on_curve(twist(G2), b12) + +# Check that the G12 curve works fine + +G12 = twist(G2) +assert eq(add(add(double(G12), G12), G12), double(double(G12))) +assert not eq(double(G12), G12) +assert eq(add(multiply(G12, 9), multiply(G12, 5)), add(multiply(G12, 12), multiply(G12, 2))) +assert is_on_curve(multiply(G12, 9), b12) +assert eq(multiply(G12, curve_order), (1, 1, 0)) diff --git a/zksnark/optimized_field_elements.py b/zksnark/optimized_field_elements.py index 0383555..34fd3e8 100644 --- a/zksnark/optimized_field_elements.py +++ b/zksnark/optimized_field_elements.py @@ -1,6 +1,8 @@ field_modulus = 21888242871839275222246405745257275088696311157297823662689037894645226208583 FQ2_modulus_coeffs = [82, -18] # Implied + [1] +FQ2_mc_tuples = [(0, 82), (1, -18)] FQ12_modulus_coeffs = [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0] # Implied + [1] +FQ12_mc_tuples = [(i, c) for i, c in enumerate(FQ12_modulus_coeffs) if c] # python3 compatibility try: @@ -39,6 +41,31 @@ def poly_rounded_div(a, b): temp[c + i] = (temp[c + i] - o[c]) return [x % field_modulus for x in o[:deg(o)+1]] +def karatsuba(a, b, c, d): + L = len(a) + EXTENDED_LEN = L * 2 - 1 + # phi = (a+b)(c+d) + # psi = (a-b)(c-d) + phi, psi, bd2 = [0] * EXTENDED_LEN, [0] * EXTENDED_LEN, [0] * EXTENDED_LEN + for i in range(L): + for j in range(L): + phi[i + j] += (a[i] + b[i]) * (c[j] + d[j]) + psi[i + j] += (a[i] - b[i]) * (c[j] - d[j]) + bd2[i + j] += b[i] * d[j] * 2 + o = [0] * (L * 4 - 1) + # L = (phi + psi - bd2) / 2 + # M = (phi - psi) / 2 + # H = bd2 / 2 + for i in range(L * 2 - 1): + o[i] += phi[i] + psi[i] - bd2[i] + o[i + L] += phi[i] - psi[i] + o[i + L * 2] += bd2[i] + inv_2 = (field_modulus + 1) // 2 + return [a * inv_2 if a % 2 else a // 2 for a in o] + +o = karatsuba([1, 3], [3, 1], [1, 3], [3, 1]) +assert [x % field_modulus for x in o] == [1, 6, 15, 20, 15, 6, 1] + # A class for elements in polynomial extension fields class FQP(): def __init__(self, coeffs, modulus_coeffs): @@ -61,15 +88,18 @@ class FQP(): if isinstance(other, (int, long)): return self.__class__([c * other % field_modulus for c in self.coeffs]) else: - assert isinstance(other, self.__class__) - b = [0 for i in range(self.degree * 2 - 1)] - for i in range(self.degree): - for j in range(self.degree): - b[i + j] += self.coeffs[i] * other.coeffs[j] - while len(b) > self.degree: - exp, top = len(b) - self.degree - 1, b.pop() - for i in range(self.degree): - b[exp + i] -= top * self.modulus_coeffs[i] + # assert isinstance(other, self.__class__) + b = [0] * (self.degree * 2 - 1) + inner_enumerate = list(enumerate(other.coeffs)) + for i, eli in enumerate(self.coeffs): + for j, elj in inner_enumerate: + b[i + j] += eli * elj + # MID = len(self.coeffs) // 2 + # b = karatsuba(self.coeffs[:MID], self.coeffs[MID:], other.coeffs[:MID], other.coeffs[MID:]) + for exp in range(self.degree - 2, -1, -1): + top = b.pop() + for i, c in self.mc_tuples: + b[exp + i] -= top * c return self.__class__([x % field_modulus for x in b]) def __rmul__(self, other): @@ -86,14 +116,14 @@ class FQP(): return self.__div__(other) def __pow__(self, other): - if other == 0: - return self.__class__([1] + [0] * (self.degree - 1)) - elif other == 1: - return self.__class__(self.coeffs) - elif other % 2 == 0: - return (self * self) ** (other // 2) - else: - return ((self * self) ** int(other // 2)) * self + o = self.__class__([1] + [0] * (self.degree - 1)) + t = self + while other > 0: + if other & 1: + o = o * t + other >>= 1 + t = t * t + return o # Extended euclidean algorithm used to find the modular inverse def inv(self): @@ -143,6 +173,7 @@ class FQ2(FQP): def __init__(self, coeffs): self.coeffs = coeffs self.modulus_coeffs = FQ2_modulus_coeffs + self.mc_tuples = FQ2_mc_tuples self.degree = 2 self.__class__.degree = 2 @@ -164,6 +195,7 @@ class FQcomplex(FQP): def __init__(self, coeffs): self.coeffs = coeffs self.modulus_coeffs = [1, 0] + self.mc_tuples = [(0, 1)] self.degree = 2 self.__class__.degree = 2 @@ -172,5 +204,6 @@ class FQ12(FQP): def __init__(self, coeffs): self.coeffs = coeffs self.modulus_coeffs = FQ12_modulus_coeffs + self.mc_tuples = FQ12_mc_tuples self.degree = 12 self.__class__.degree = 12 diff --git a/zksnark/optimized_pairing.py b/zksnark/optimized_pairing.py new file mode 100644 index 0000000..73aee02 --- /dev/null +++ b/zksnark/optimized_pairing.py @@ -0,0 +1,107 @@ +from optimized_curve import double, add, multiply, is_on_curve, neg, twist, b, b2, b12, curve_order, G1, G2, G12, normalize +from bn128_field_elements import field_modulus, FQ +from optimized_field_elements import FQ2, FQ12, FQcomplex +# from bn128_field_elements import FQ2, FQ12, FQcomplex + +ate_loop_count = 29793968203157093288 +log_ate_loop_count = 63 +pseudo_binary_encoding = [0, 0, 0, 1, 0, 1, 0, -1, 0, 0, 1, -1, 0, 0, 1, 0, + 0, 1, 1, 0, -1, 0, 0, 1, 0, -1, 0, 0, 0, 0, 1, 1, + 1, 0, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, -1, 0, 0, 1, + 1, 0, 0, -1, 0, 0, 0, 1, 1, 0, -1, 0, 0, 1, 0, 1, 1] + +assert sum([e * 2**i for i, e in enumerate(pseudo_binary_encoding)]) == ate_loop_count + + +def normalize1(p): + x, y = normalize(p) + return x, y, x.__class__.one() + +# Create a function representing the line between P1 and P2, +# and evaluate it at T. Returns a numerator and a denominator +# to avoid unneeded divisions +def linefunc(P1, P2, T): + zero = P1[0].__class__.zero() + x1, y1, z1 = P1 + x2, y2, z2 = P2 + xt, yt, zt = T + # points in projective coords: (x / z, y / z) + # hence, m = (y2/z2 - y1/z1) / (x2/z2 - x1/z1) + # multiply numerator and denominator by z1z2 to get values below + m_numerator = y2 * z1 - y1 * z2 + m_denominator = x2 * z1 - x1 * z2 + if m_denominator != zero: + # m * ((xt/zt) - (x1/z1)) - ((yt/zt) - (y1/z1)) + return m_numerator * (xt * z1 - x1 * zt) - m_denominator * (yt * z1 - y1 * zt), \ + m_denominator * zt * z1 + elif m_numerator == zero: + # m = 3(x/z)^2 / 2(y/z), multiply num and den by z**2 + m_numerator = 3 * x1 * x1 + m_denominator = 2 * y1 * z1 + return m_numerator * (xt * z1 - x1 * zt) - m_denominator * (yt * z1 - y1 * zt), \ + m_denominator * zt * z1 + else: + return xt * z1 - x1 * zt, z1 * zt + +def cast_point_to_fq12(pt): + if pt is None: + return None + x, y, z = pt + return (FQ12([x.n] + [0] * 11), FQ12([y.n] + [0] * 11), FQ12([z.n] + [0] * 11)) + +# Check consistency of the "line function" +one, two, three = G1, double(G1), multiply(G1, 3) +negone, negtwo, negthree = multiply(G1, curve_order - 1), multiply(G1, curve_order - 2), multiply(G1, curve_order - 3) + +assert linefunc(one, two, one)[0] == FQ(0) +assert linefunc(one, two, two)[0] == FQ(0) +assert linefunc(one, two, three)[0] != FQ(0) +assert linefunc(one, two, negthree)[0] == FQ(0) +assert linefunc(one, negone, one)[0] == FQ(0) +assert linefunc(one, negone, negone)[0] == FQ(0) +assert linefunc(one, negone, two)[0] != FQ(0) +assert linefunc(one, one, one)[0] == FQ(0) +assert linefunc(one, one, two)[0] != FQ(0) +assert linefunc(one, one, negtwo)[0] == FQ(0) + +# Main miller loop +def miller_loop(Q, P): + if Q is None or P is None: + return FQ12.one() + R = Q + f_num, f_den = FQ12.one(), FQ12.one() + for b in pseudo_binary_encoding[63::-1]: + #for i in range(log_ate_loop_count, -1, -1): + _n, _d = linefunc(R, R, P) + f_num = f_num * f_num * _n + f_den = f_den * f_den * _d + R = double(R) + #if ate_loop_count & (2**i): + if b == 1: + _n, _d = linefunc(R, Q, P) + f_num = f_num * _n + f_den = f_den * _d + R = add(R, Q) + elif b == -1: + nQ = neg(Q) + _n, _d = linefunc(R, nQ, P) + f_num = f_num * _n + f_den = f_den * _d + R = add(R, nQ) + # assert R == multiply(Q, ate_loop_count) + Q1 = (Q[0] ** field_modulus, Q[1] ** field_modulus, Q[2] ** field_modulus) + # assert is_on_curve(Q1, b12) + nQ2 = (Q1[0] ** field_modulus, -Q1[1] ** field_modulus, Q1[2] ** field_modulus) + # assert is_on_curve(nQ2, b12) + _n1, _d1 = linefunc(R, Q1, P) + R = add(R, Q1) + _n2, _d2 = linefunc(R, nQ2, P) + f = f_num * _n1 * _n2 / (f_den * _d1 * _d2) + # R = add(R, nQ2) This line is in many specifications but it technically does nothing + return f ** ((field_modulus ** 12 - 1) // curve_order) + +# Pairing computation +def pairing(Q, P): + assert is_on_curve(Q, b2) + assert is_on_curve(P, b) + return miller_loop(twist(Q), cast_point_to_fq12(P))