diff --git a/zksnark/bn128_curve.py b/zksnark/bn128_curve.py index 9b8940f..c79638b 100644 --- a/zksnark/bn128_curve.py +++ b/zksnark/bn128_curve.py @@ -14,6 +14,8 @@ G2 = (FQ2([4, 0]), FQ2([16893045765507297706785249332518927989146279141265438554 # Check that a point is on the curve defined by y**2 == x**3 + b def is_on_curve(pt, b): + if pt is None: + return True x, y = pt return y**2 - x**3 == b @@ -81,3 +83,10 @@ def twist(pt): # Check that the twist creates a point that is on the curve assert is_on_curve(twist(G2), b12) + +# Check that the G12 curve works fine + +G12 = twist(G2) +assert add(add(double(G12), G12), G12) == double(double(G12)) +assert double(G12) != G12 +assert add(multiply(G12, 9), multiply(G12, 5)) == add(multiply(G12, 12), multiply(G12, 2)) diff --git a/zksnark/bn128_pairing.py b/zksnark/bn128_pairing.py index 41febd3..0f309ce 100644 --- a/zksnark/bn128_pairing.py +++ b/zksnark/bn128_pairing.py @@ -4,11 +4,12 @@ from bn128_curve import double, add, multiply, is_on_curve, twist, b, b2, b12, c from bn128_field_elements import field_modulus, FQ, FQ2, FQ12 ate_loop_count = 29793968203157093288 -log_ate_loop_count = 64 +log_ate_loop_count = 63 # Create a function representing the line between P1 and P2, # and evaluate it at T def linefunc(P1, P2, T): + assert P1 and P2 and T # No points-at-infinity allowed, sorry x1, y1 = P1 x2, y2 = P2 xt, yt = T @@ -28,27 +29,33 @@ def cast_point_to_fq12(pt): return (FQ12([x.n] + [0] * 11), FQ12([y.n] + [0] * 11)) # Check consistency of the "line function" -one, two, three, negone = G1, double(G1), multiply(G1, 3), multiply(G1, curve_order - 1) +one, two, three = G1, double(G1), multiply(G1, 3) +negone, negtwo, negthree = multiply(G1, curve_order - 1), multiply(G1, curve_order - 2), multiply(G1, curve_order - 3) assert linefunc(one, two, one) == FQ(0) assert linefunc(one, two, two) == FQ(0) assert linefunc(one, two, three) != FQ(0) +assert linefunc(one, two, negthree) == FQ(0) assert linefunc(one, negone, one) == FQ(0) assert linefunc(one, negone, negone) == FQ(0) assert linefunc(one, negone, two) != FQ(0) assert linefunc(one, one, one) == FQ(0) assert linefunc(one, one, two) != FQ(0) +assert linefunc(one, one, negtwo) == FQ(0) # Main miller loop def miller_loop(Q, P): + if Q is None or P is None: + return FQ12.one() R = Q f = FQ12.one() for i in range(log_ate_loop_count, -1, -1): - f = f * f / linefunc(R, R, P) + f = f * f * linefunc(R, R, P) R = double(R) if ate_loop_count & (2**i): f = f * linefunc(R, Q, P) R = add(R, Q) + assert R == multiply(Q, ate_loop_count) Q1 = (Q[0] ** field_modulus, Q[1] ** field_modulus) nQ2 = (Q[0] ** (field_modulus ** 2), -Q[1] ** (field_modulus ** 2)) f = f * linefunc(R, Q1, P)