Added optimizations

This commit is contained in:
vub 2017-02-20 11:32:38 -05:00
parent 78f5968333
commit 88ea67aabe
4 changed files with 299 additions and 18 deletions

View File

@ -1,5 +1,7 @@
from bn128_pairing import pairing, neg, G2, G1, multiply, FQ12, curve_order
from optimized_pairing import pairing, neg, G2, G1, multiply, FQ12, curve_order
import time
a = time.time()
print('Starting tests')
p1 = pairing(G2, G1)
pn1 = pairing(G2, neg(G1))
@ -23,3 +25,4 @@ p3 = pairing(multiply(G2, 27), multiply(G1, 37))
po3 = pairing(G2, multiply(G1, 999))
assert p3 == po3
print('Composite check passed')
print('Total time: %.3f' % (time.time() - a))

138
zksnark/optimized_curve.py Normal file
View File

@ -0,0 +1,138 @@
from bn128_field_elements import field_modulus, FQ
from optimized_field_elements import FQ2, FQ12
# from bn128_field_elements import FQ2, FQ12
curve_order = 21888242871839275222246405745257275088548364400416034343698204186575808495617
# Curve order should be prime
assert pow(2, curve_order, curve_order) == 2
# Curve order should be a factor of field_modulus**12 - 1
assert (field_modulus ** 12 - 1) % curve_order == 0
# Curve is y**2 = x**3 + 3
b = FQ(3)
# Twisted curve over FQ**2
b2 = FQ2([3, 0]) / FQ2([0, 1])
# Extension curve over FQ**12; same b value as over FQ
b12 = FQ12([3] + [0] * 11)
# Generator for curve over FQ
G1 = (FQ(1), FQ(2), FQ(1))
# Generator for twisted curve over FQ2
G2 = (FQ2([16260673061341949275257563295988632869519996389676903622179081103440260644990, 11559732032986387107991004021392285783925812861821192530917403151452391805634]),
FQ2([15530828784031078730107954109694902500959150953518636601196686752670329677317, 4082367875863433681332203403145435568316851327593401208105741076214120093531]), FQ2.one())
# Check that a point is on the curve defined by y**2 == x**3 + b
def is_on_curve(pt, b):
if pt is None:
return True
x, y, z = pt
return y**2 * z - x**3 == b * z**3
assert is_on_curve(G1, b)
assert is_on_curve(G2, b2)
# Elliptic curve doubling
def double(pt):
x, y, z = pt
W = 3 * x * x
S = y * z
B = x * y * S
H = W * W - 8 * B
S_squared = S * S
newx = 2 * H * S
newy = W * (4 * B - H) - 8 * y * y * S_squared
newz = 8 * S * S_squared
return newx, newy, newz
# Elliptic curve addition
def add(p1, p2):
one, zero = p1[0].__class__.one(), p1[0].__class__.zero()
if p1[2] == zero or p2[2] == zero:
return p1 if zero else p2
x1, y1, z1 = p1
x2, y2, z2 = p2
U1 = y2 * z1
U2 = y1 * z2
V1 = x2 * z1
V2 = x1 * z2
if V1 == V2 and U1 == U2:
return double(p1)
elif V1 == V2:
return (one, one, zero)
U = U1 - U2
V = V1 - V2
V_squared = V * V
V_squared_times_V2 = V_squared * V2
V_cubed = V * V_squared
W = z1 * z2
A = U * U * W - V_cubed - 2 * V_squared_times_V2
newx = V * A
newy = U * (V_squared_times_V2 - A) - V_cubed * U2
newz = V_cubed * W
return (newx, newy, newz)
# Elliptic curve point multiplication
def multiply(pt, n):
if n == 0:
return None
elif n == 1:
return pt
elif not n % 2:
return multiply(double(pt), n // 2)
else:
return add(multiply(double(pt), int(n // 2)), pt)
def eq(p1, p2):
x1, y1, z1 = p1
x2, y2, z2 = p2
return x1 * z2 == x2 * z1 and y1 * z2 == y2 * z1
def normalize(pt):
x, y, z = pt
return (x / z, y / z)
# Check that the G1 curve works fine
assert eq(add(add(double(G1), G1), G1), double(double(G1)))
assert not eq(double(G1), G1)
assert eq(add(multiply(G1, 9), multiply(G1, 5)), add(multiply(G1, 12), multiply(G1, 2)))
assert eq(multiply(G1, curve_order), (1, 1, 0))
# Check that the G2 curve works fine
assert eq(add(add(double(G2), G2), G2), double(double(G2)))
assert not eq(double(G2), G2)
assert eq(add(multiply(G2, 9), multiply(G2, 5)), add(multiply(G2, 12), multiply(G2, 2)))
assert eq(multiply(G2, curve_order), (1, 1, 0))
assert not eq(multiply(G2, 2 * field_modulus - curve_order), (1, 1, 0))
assert is_on_curve(multiply(G2, 9), b2)
# "Twist" a point in E(FQ2) into a point in E(FQ12)
w = FQ12([0, 1] + [0] * 10)
# Convert P => -P
def neg(pt):
if pt is None:
return None
x, y, z = pt
return (x, -y, z)
def twist(pt):
if pt is None:
return None
x, y, z = pt
nx = FQ12([x.coeffs[0]] + [0] * 5 + [x.coeffs[1]] + [0] * 5)
ny = FQ12([y.coeffs[0]] + [0] * 5 + [y.coeffs[1]] + [0] * 5)
nz = FQ12([z.coeffs[0]] + [0] * 5 + [z.coeffs[1]] + [0] * 5)
return (nx * w **2, ny * w**3, nz)
# Check that the twist creates a point that is on the curve
assert is_on_curve(twist(G2), b12)
# Check that the G12 curve works fine
G12 = twist(G2)
assert eq(add(add(double(G12), G12), G12), double(double(G12)))
assert not eq(double(G12), G12)
assert eq(add(multiply(G12, 9), multiply(G12, 5)), add(multiply(G12, 12), multiply(G12, 2)))
assert is_on_curve(multiply(G12, 9), b12)
assert eq(multiply(G12, curve_order), (1, 1, 0))

View File

@ -1,6 +1,8 @@
field_modulus = 21888242871839275222246405745257275088696311157297823662689037894645226208583
FQ2_modulus_coeffs = [82, -18] # Implied + [1]
FQ2_mc_tuples = [(0, 82), (1, -18)]
FQ12_modulus_coeffs = [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0] # Implied + [1]
FQ12_mc_tuples = [(i, c) for i, c in enumerate(FQ12_modulus_coeffs) if c]
# python3 compatibility
try:
@ -39,6 +41,31 @@ def poly_rounded_div(a, b):
temp[c + i] = (temp[c + i] - o[c])
return [x % field_modulus for x in o[:deg(o)+1]]
def karatsuba(a, b, c, d):
L = len(a)
EXTENDED_LEN = L * 2 - 1
# phi = (a+b)(c+d)
# psi = (a-b)(c-d)
phi, psi, bd2 = [0] * EXTENDED_LEN, [0] * EXTENDED_LEN, [0] * EXTENDED_LEN
for i in range(L):
for j in range(L):
phi[i + j] += (a[i] + b[i]) * (c[j] + d[j])
psi[i + j] += (a[i] - b[i]) * (c[j] - d[j])
bd2[i + j] += b[i] * d[j] * 2
o = [0] * (L * 4 - 1)
# L = (phi + psi - bd2) / 2
# M = (phi - psi) / 2
# H = bd2 / 2
for i in range(L * 2 - 1):
o[i] += phi[i] + psi[i] - bd2[i]
o[i + L] += phi[i] - psi[i]
o[i + L * 2] += bd2[i]
inv_2 = (field_modulus + 1) // 2
return [a * inv_2 if a % 2 else a // 2 for a in o]
o = karatsuba([1, 3], [3, 1], [1, 3], [3, 1])
assert [x % field_modulus for x in o] == [1, 6, 15, 20, 15, 6, 1]
# A class for elements in polynomial extension fields
class FQP():
def __init__(self, coeffs, modulus_coeffs):
@ -61,15 +88,18 @@ class FQP():
if isinstance(other, (int, long)):
return self.__class__([c * other % field_modulus for c in self.coeffs])
else:
assert isinstance(other, self.__class__)
b = [0 for i in range(self.degree * 2 - 1)]
for i in range(self.degree):
for j in range(self.degree):
b[i + j] += self.coeffs[i] * other.coeffs[j]
while len(b) > self.degree:
exp, top = len(b) - self.degree - 1, b.pop()
for i in range(self.degree):
b[exp + i] -= top * self.modulus_coeffs[i]
# assert isinstance(other, self.__class__)
b = [0] * (self.degree * 2 - 1)
inner_enumerate = list(enumerate(other.coeffs))
for i, eli in enumerate(self.coeffs):
for j, elj in inner_enumerate:
b[i + j] += eli * elj
# MID = len(self.coeffs) // 2
# b = karatsuba(self.coeffs[:MID], self.coeffs[MID:], other.coeffs[:MID], other.coeffs[MID:])
for exp in range(self.degree - 2, -1, -1):
top = b.pop()
for i, c in self.mc_tuples:
b[exp + i] -= top * c
return self.__class__([x % field_modulus for x in b])
def __rmul__(self, other):
@ -86,14 +116,14 @@ class FQP():
return self.__div__(other)
def __pow__(self, other):
if other == 0:
return self.__class__([1] + [0] * (self.degree - 1))
elif other == 1:
return self.__class__(self.coeffs)
elif other % 2 == 0:
return (self * self) ** (other // 2)
else:
return ((self * self) ** int(other // 2)) * self
o = self.__class__([1] + [0] * (self.degree - 1))
t = self
while other > 0:
if other & 1:
o = o * t
other >>= 1
t = t * t
return o
# Extended euclidean algorithm used to find the modular inverse
def inv(self):
@ -143,6 +173,7 @@ class FQ2(FQP):
def __init__(self, coeffs):
self.coeffs = coeffs
self.modulus_coeffs = FQ2_modulus_coeffs
self.mc_tuples = FQ2_mc_tuples
self.degree = 2
self.__class__.degree = 2
@ -164,6 +195,7 @@ class FQcomplex(FQP):
def __init__(self, coeffs):
self.coeffs = coeffs
self.modulus_coeffs = [1, 0]
self.mc_tuples = [(0, 1)]
self.degree = 2
self.__class__.degree = 2
@ -172,5 +204,6 @@ class FQ12(FQP):
def __init__(self, coeffs):
self.coeffs = coeffs
self.modulus_coeffs = FQ12_modulus_coeffs
self.mc_tuples = FQ12_mc_tuples
self.degree = 12
self.__class__.degree = 12

View File

@ -0,0 +1,107 @@
from optimized_curve import double, add, multiply, is_on_curve, neg, twist, b, b2, b12, curve_order, G1, G2, G12, normalize
from bn128_field_elements import field_modulus, FQ
from optimized_field_elements import FQ2, FQ12, FQcomplex
# from bn128_field_elements import FQ2, FQ12, FQcomplex
ate_loop_count = 29793968203157093288
log_ate_loop_count = 63
pseudo_binary_encoding = [0, 0, 0, 1, 0, 1, 0, -1, 0, 0, 1, -1, 0, 0, 1, 0,
0, 1, 1, 0, -1, 0, 0, 1, 0, -1, 0, 0, 0, 0, 1, 1,
1, 0, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, -1, 0, 0, 1,
1, 0, 0, -1, 0, 0, 0, 1, 1, 0, -1, 0, 0, 1, 0, 1, 1]
assert sum([e * 2**i for i, e in enumerate(pseudo_binary_encoding)]) == ate_loop_count
def normalize1(p):
x, y = normalize(p)
return x, y, x.__class__.one()
# Create a function representing the line between P1 and P2,
# and evaluate it at T. Returns a numerator and a denominator
# to avoid unneeded divisions
def linefunc(P1, P2, T):
zero = P1[0].__class__.zero()
x1, y1, z1 = P1
x2, y2, z2 = P2
xt, yt, zt = T
# points in projective coords: (x / z, y / z)
# hence, m = (y2/z2 - y1/z1) / (x2/z2 - x1/z1)
# multiply numerator and denominator by z1z2 to get values below
m_numerator = y2 * z1 - y1 * z2
m_denominator = x2 * z1 - x1 * z2
if m_denominator != zero:
# m * ((xt/zt) - (x1/z1)) - ((yt/zt) - (y1/z1))
return m_numerator * (xt * z1 - x1 * zt) - m_denominator * (yt * z1 - y1 * zt), \
m_denominator * zt * z1
elif m_numerator == zero:
# m = 3(x/z)^2 / 2(y/z), multiply num and den by z**2
m_numerator = 3 * x1 * x1
m_denominator = 2 * y1 * z1
return m_numerator * (xt * z1 - x1 * zt) - m_denominator * (yt * z1 - y1 * zt), \
m_denominator * zt * z1
else:
return xt * z1 - x1 * zt, z1 * zt
def cast_point_to_fq12(pt):
if pt is None:
return None
x, y, z = pt
return (FQ12([x.n] + [0] * 11), FQ12([y.n] + [0] * 11), FQ12([z.n] + [0] * 11))
# Check consistency of the "line function"
one, two, three = G1, double(G1), multiply(G1, 3)
negone, negtwo, negthree = multiply(G1, curve_order - 1), multiply(G1, curve_order - 2), multiply(G1, curve_order - 3)
assert linefunc(one, two, one)[0] == FQ(0)
assert linefunc(one, two, two)[0] == FQ(0)
assert linefunc(one, two, three)[0] != FQ(0)
assert linefunc(one, two, negthree)[0] == FQ(0)
assert linefunc(one, negone, one)[0] == FQ(0)
assert linefunc(one, negone, negone)[0] == FQ(0)
assert linefunc(one, negone, two)[0] != FQ(0)
assert linefunc(one, one, one)[0] == FQ(0)
assert linefunc(one, one, two)[0] != FQ(0)
assert linefunc(one, one, negtwo)[0] == FQ(0)
# Main miller loop
def miller_loop(Q, P):
if Q is None or P is None:
return FQ12.one()
R = Q
f_num, f_den = FQ12.one(), FQ12.one()
for b in pseudo_binary_encoding[63::-1]:
#for i in range(log_ate_loop_count, -1, -1):
_n, _d = linefunc(R, R, P)
f_num = f_num * f_num * _n
f_den = f_den * f_den * _d
R = double(R)
#if ate_loop_count & (2**i):
if b == 1:
_n, _d = linefunc(R, Q, P)
f_num = f_num * _n
f_den = f_den * _d
R = add(R, Q)
elif b == -1:
nQ = neg(Q)
_n, _d = linefunc(R, nQ, P)
f_num = f_num * _n
f_den = f_den * _d
R = add(R, nQ)
# assert R == multiply(Q, ate_loop_count)
Q1 = (Q[0] ** field_modulus, Q[1] ** field_modulus, Q[2] ** field_modulus)
# assert is_on_curve(Q1, b12)
nQ2 = (Q1[0] ** field_modulus, -Q1[1] ** field_modulus, Q1[2] ** field_modulus)
# assert is_on_curve(nQ2, b12)
_n1, _d1 = linefunc(R, Q1, P)
R = add(R, Q1)
_n2, _d2 = linefunc(R, nQ2, P)
f = f_num * _n1 * _n2 / (f_den * _d1 * _d2)
# R = add(R, nQ2) This line is in many specifications but it technically does nothing
return f ** ((field_modulus ** 12 - 1) // curve_order)
# Pairing computation
def pairing(Q, P):
assert is_on_curve(Q, b2)
assert is_on_curve(P, b)
return miller_loop(twist(Q), cast_point_to_fq12(P))