# Source listing: constantine/sage/lattice_decomposition_bn254... (file name truncated)
#
# Listing metadata: 376 lines, 11 KiB, Python
# (web-viewer controls: Raw | Normal View | History)

# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
# ############################################################
#
# BN254-Snarks GLS Endomorphism
# Lattice Decomposition
#
# ############################################################
# Parameters
# BN family parametrized by x: p is the base-field modulus, r the prime
# subgroup order and t the trace of Frobenius.
x = Integer('0x44E992B44A6909F1')
p = 36*x^4 + 36*x^3 + 24*x^2 + 6*x + 1
r = 36*x^4 + 36*x^3 + 18*x^2 + 6*x + 1
t = 6*x^2 + 1
print(' Prime modulus p: 0x' + p.hex())
print(' Curve order r: 0x' + r.hex())
print(' trace t: 0x' + t.hex())
# Finite fields
Fp = GF(p)
K2.<u> = PolynomialRing(Fp)
Fp2.<beta> = Fp.extension(u^2+1)  # Fp2 = Fp[beta] / (beta^2 + 1)
SNR = Fp2([9, 1]) # Sextic Non-Residue for Sextic Twist (9 + beta)
# Curves
b = 3
G1 = EllipticCurve(Fp, [0, b])       # y^2 = x^3 + 3 over Fp
G2 = EllipticCurve(Fp2, [0, b/SNR])  # sextic twist y^2 = x^3 + 3/SNR over Fp2
# https://crypto.stackexchange.com/questions/64064/order-of-twisted-curve-in-pairings
# https://math.stackexchange.com/questions/144194/how-to-find-the-order-of-elliptic-curve-over-finite-field-extension
cofactorG1 = G1.order() // r
cofactorG2 = G2.order() // r
print('')
print('cofactor G1: ' + cofactorG1.hex())
print('cofactor G2: ' + cofactorG2.hex())
print('')
# Frobenius constants (D type: use SNR, M type use 1/SNR)
# Every BN prime satisfies p = 1 (mod 6), so (p-1)/6 is an exact integer.
# Use floor division to keep the exponent an Integer rather than a Rational.
assert (p - 1) % 6 == 0
FrobConst_psi = SNR^((p-1)//6)
FrobConst_psi_2 = FrobConst_psi * FrobConst_psi       # SNR^((p-1)/3), scales x in psi
FrobConst_psi_3 = FrobConst_psi_2 * FrobConst_psi     # SNR^((p-1)/2), scales y in psi
# Constants of psi applied twice: c * (c * z^p)^p = (c * c^p) * z^(p^2)
FrobConst_psi2_2 = FrobConst_psi_2 * FrobConst_psi_2^p
FrobConst_psi2_3 = FrobConst_psi_3 * FrobConst_psi_3^p
def psi(P):
    """GLS endomorphism on G2: the p-power Frobenius conjugated by the sextic twist.

    Maps an affine point (x, y) to
        (FrobConst_psi_2 * x^p, FrobConst_psi_3 * y^p)
    with FrobConst_psi_2 = SNR^((p-1)/3) and FrobConst_psi_3 = SNR^((p-1)/2).
    The point at infinity is mapped to itself.
    """
    (Px, Py, Pz) = P
    if Pz == 0:
        # Point at infinity (0 : 1 : 0): the affine formula below does not
        # apply and would not produce a point of G2.
        return G2([0, 1, 0])
    # Finite points come back normalized with Pz == 1, so only x and y are mapped.
    return G2([
        FrobConst_psi_2 * Px.frobenius(1),
        FrobConst_psi_3 * Py.frobenius(1),
    ])
def psi2(P):
    """psi applied twice, psi(psi(P)), evaluated directly with p^2-Frobenius.

    Maps an affine point (x, y) to
        (FrobConst_psi2_2 * x^(p^2), FrobConst_psi2_3 * y^(p^2)).
    The point at infinity is mapped to itself.
    """
    (Px, Py, Pz) = P
    if Pz == 0:
        # Point at infinity (0 : 1 : 0): the affine formula below does not
        # apply and would not produce a point of G2.
        return G2([0, 1, 0])
    # Finite points come back normalized with Pz == 1, so only x and y are mapped.
    return G2([
        FrobConst_psi2_2 * Px.frobenius(2),
        FrobConst_psi2_3 * Py.frobenius(2),
    ])
def clearCofactorG2(P):
    """Return P scaled by cofactorG2 = #G2 // r.

    Assuming r divides #G2 exactly, the result has order dividing r.
    """
    scaled = cofactorG2 * P
    return scaled
# Test generator
set_random_seed(1337)
# Check
def checkEndo():
    """Sanity-check psi and psi2 on one random point of the order-r subgroup.

    Checks the characteristic equations of psi and psi2 on the point, the
    scalar congruences linking p, t and r, and that psi acts as
    multiplication by t - 1. Raises AssertionError at the first failing
    check; prints 'Endomorphism OK' when all pass.
    """
    identity = G2([0, 1, 0])
    Q = clearCofactorG2(G2.random_point())
    psiQ = psi(Q)
    # Galbraith-Lin-Scott, 2008, Theorem 1: psi^2 - [t]psi + [p] = 0
    assert psi(psiQ) - t*psiQ + p*Q == identity
    # Galbraith-Scott, 2008, Lemma 1
    # k-th cyclotomic polynomial with k = 12: psi2^2 - psi2 + 1 = 0
    psi2Q = psi2(Q)
    assert psi2(psi2Q) - psi2Q + Q == identity
    # Scalar side: p = t - 1 (mod r), and t - 1 is a root of
    # Phi_12(X) = X^4 - X^2 + 1 modulo r.
    lam = t - 1
    assert p % r == lam % r
    # assert (p^4 - p^2 + 1) % r == 0
    assert (lam^4 - lam^2 + 1) % r == 0
    # On the order-r subgroup psi is multiplication by t - 1.
    lamQ = lam*Q
    assert lamQ == (p % r)*Q
    assert lamQ == psiQ
    print('Endomorphism OK')
checkEndo()
# Decomposition generated by LLL-algorithm and Babai rounding
# to solve the Shortest (Basis) Vector Problem
# TODO: The lattice from "Guide to Pairing-Based Cryptography" (commented
# out below) gives mini-scalars bigger than ceil(log2(r)/4) + 1 = 65 bits
# for an unknown reason.
# The linear combination is correct though.
# Lattice from Guide to Pairing-Based Cryptography
# Lat = [
# [ - x , 2*x , 2*x+1, - x ],
# [ -2*x-1, x , x+1, x ],
# [ 2*x+1, 0, 2*x , 1],
# [ -1, 2*x+1, 1, 2*x ]
# ]
# ahat = [2*x+1, -(12*x^3+6*x^2+2*x+1), 2*x*(3*x^2+3*x+1), 6*x^2-1]
# Lattice from Galbraith-Scott 2008
Lat = Matrix([
    [  x+1,     x,      x,  -2*x],
    [2*x+1,    -x,   -x-1,    -x],
    [  2*x, 2*x+1,  2*x+1, 2*x+1],
    [  x-1, 4*x+2, -2*x+1,   x-1],
])
ahat = [2*x^2+3*x+1, 12*x^3+8*x^2+x, 6*x^3+4*x^2+x, -2*x^2-x]
# Fixed-point precision for Babai rounding: bit length of r rounded up to a
# whole number of 64-bit words.
n = int(64 * ((int(r).bit_length() + 63) // 64))
# Precomputed rounding constants v_i = floor(ahat_i * 2^n / r).
v = [Integer(coef << n) // r for coef in ahat]
def pretty_print_lattice(Lat):
latHex = [['0x' + x.hex() if x >= 0 else '-0x' + (-x).hex() for x in vec] for vec in Lat]
maxlen = max([len(cell) for row in latHex for cell in row])
for row in latHex:
row = ' '.join(cell.rjust(maxlen + 2) for cell in row)
print(row)
print('\nLattice')
pretty_print_lattice(Lat)
print('\nbasis:')
# '\u0305' is a combining overline, so each label renders as alpha-bar followed by its index.
print(f' 𝛼\u03050: 0x{v[0].hex()}')
print(f' 𝛼\u03051: 0x{v[1].hex()}')
print(f' 𝛼\u03052: 0x{v[2].hex()}')
print(f' 𝛼\u03053: 0x{v[3].hex()}')
print('')
# Eigenvalue of psi on the order-r subgroup (checkEndo asserts psi(P) == (t-1)*P),
# followed by its square and cube reduced modulo r.
lambda1 = (t-1) % r
lambda2 = (lambda1 * lambda1) % r
lambda3 = (lambda2 * lambda1) % r
def getGLV2_decomp(scalar):
    """Split `scalar` into four mini-scalars for the 4-dimensional GLS method.

    Babai rounding against the lattice `Lat`, using the precomputed
    constants v_j = floor(ahat_j * 2^n / r):
        a_j = floor(v_j * scalar / 2^n)
        (k0, k1, k2, k3) = (scalar, 0, 0, 0) - sum_j a_j * Lat[j]

    Returns (k0, k1, k2, k3) with
        scalar == k0 + k1*lambda1 + k2*lambda2 + k3*lambda3 (mod r).
    Prints intermediate values for debugging. Raises AssertionError if the
    recombination fails or a mini-scalar exceeds ceil(bitlen(r)/4) + 1 bits.
    """
    # ceil(bitlen(r) / 4) + 1: expected size of each mini-scalar.
    maxLen = (int(r).bit_length() + 3) // 4 + 1
    a = [(v[j] * scalar) >> n for j in range(4)]
    for j in range(4):
        print(f'𝛼{j}: ' + a[j].hex())
    print('𝛼3 unred: ' + (v[3] * scalar).hex())
    print('')
    print('Lat[3][0]: ' + Lat[3][0].hex())
    print('a3 * Lat[3][0]: ' + (a[3] * Lat[3][0]).hex())
    print('')
    # k = (scalar, 0, 0, 0) - sum_j a_j * Lat[j], computed once while tracing
    # every term that is subtracted.
    k = [scalar, 0, 0, 0]
    for i in range(4):
        for j in range(4):
            elem = a[j] * Lat[j][i]
            print(f'a{j} * Lat[{j}][{i}] = {elem.hex()}')
            k[i] -= elem
        print(f' k{i} = {k[i].hex()}')
    print('k: ' + str([ki.hex() for ki in k]))
    k0, k1, k2, k3 = k
    for i, ki in enumerate(k):
        print(f"k{i}.bitlength(): " + str(int(ki).bit_length()))
    for i, ki in enumerate(k):
        print(f'k{i}: ' + ki.hex())
    assert scalar == (k0 + k1*lambda1 + k2*lambda2 + k3*lambda3) % r
    for ki in k:
        assert int(ki).bit_length() <= maxLen
    return k0, k1, k2, k3
def recodeScalars(k, L=None):
    """Recode four mini-scalars into signed, aligned digits (GLV-SAC style).

    Args:
        k: [k0, k1, k2, k3]. k0 must be odd. The recoding is exact when
            0 < k0 < 2^L and k1..k3 are non-negative and fit in L digits.
            The list is not modified.
        L: number of digits. Defaults to ceil(bitlen(r)/4) + 1 for this curve.

    Returns:
        b, four lists of L bits each (index 0 = least significant digit):
          - b[0][i] is the sign of digit i: 0 means +1, 1 means -1.
            b[0][L-1] is always 0 (positive).
          - b[j][i] (j = 1..3) is 1 when k_j has a nonzero digit at i, that
            digit taking the sign given by b[0][i].
        Hence k0 = sum_i (1 - 2*b[0][i]) * 2^i and
              k_j = sum_i (1 - 2*b[0][i]) * b[j][i] * 2^i.
    """
    m = 4
    if L is None:
        L = ((int(r).bit_length() + m-1) // m) + 1 # l = ⌈log2 r/m⌉ + 1
    # Work on a copy: the loop below consumes k[1..3] digit by digit.
    k = list(k)
    b = [[0] * L for _ in range(m)]
    # b[0][L-1] stays 0: the most significant digit is always positive.
    for i in range(L - 1):
        b[0][i] = 1 - ((k[0] >> (i+1)) & 1)
    for j in range(1, m):
        for i in range(L):
            b[j][i] = k[j] & 1
            # A digit taken with a negative sign adds one back to the remainder.
            k[j] = k[j]//2 + (b[j][i] & b[0][i])
    return b
def clearBit(v, bit):
    """Return v with bit number `bit` forced to 0."""
    # Convert the mask to a Python int before inverting: in Sage, `~` on an
    # Integer is the multiplicative inverse (1/x), not bitwise NOT.
    mask = int(1 << bit)
    return v & ~mask
def buildLut(P0, P_endos):
    """Build the 8-entry table lut[u] = P0 + sum of P_endos[i] for each bit i set in u.

    Args:
        P0: base point.
        P_endos: exactly three points (its images under the endomorphisms).

    Returns the table as a list. Also prints a symbolic form of each entry,
    e.g. 'P0 + P1 + P3' for u = 0b101.
    """
    m = 4
    assert len(P_endos) == m-1
    size = 1 << (m-1)
    lut = [0] * size
    lut[0] = P0
    lutS = [''] * size
    lutS[0] = 'P0'
    endoS = ['P1', 'P2', 'P3']
    for u in range(1, size):
        msb = u.bit_length() - 1
        # Entry u extends the already-built entry with its top bit removed.
        idx = clearBit(u, msb)
        lut[u] = lut[idx] + P_endos[msb]
        lutS[u] = lutS[idx] + ' + ' + endoS[msb]
    print('LUT: ' + str(lutS))
    return lut
def pointToString(P):
    """Format a G2 point as a multi-line string.

    Each Fp2 coordinate c0 + c1*β is rendered as '<c0 hex> + β * <c1 hex>'.
    Only the x and y coordinates are shown.
    """
    def coordToHex(coord):
        coeffs = vector(coord)
        return Integer(coeffs[0]).hex() + ' + β * ' + Integer(coeffs[1]).hex()

    xCoord, yCoord, _ = P
    return '\n'.join([
        'Point(',
        ' Px: ' + coordToHex(xCoord),
        ' Py: ' + coordToHex(yCoord),
        ')',
    ])
def getIndex(glvRecoding, bit):
    """Return the lookup-table index for digit position `bit`.

    Packs the low bit of glvRecoding[1][bit], glvRecoding[2][bit] and
    glvRecoding[3][bit] into a 3-bit integer, row 1 being the least
    significant bit. Row 0 (the sign row) is ignored.
    """
    return sum((glvRecoding[row][bit] & 1) << (row - 1) for row in (1, 2, 3))
def updateFactors(factors, recoded, bit):
    """Add the signed digit at position `bit` to the running factors, in place.

    factors[0] always moves by one step. factors[j] (j = 1..3) moves by one
    step when recoded[j][bit] has its low bit set. The step is +1 when
    recoded[0][bit] == 0 (positive digit) and -1 otherwise.
    """
    step = 1 if recoded[0][bit] == 0 else -1
    factors[0] += step
    for j in (1, 2, 3):
        if recoded[j][bit] & 1:
            factors[j] += step
def doubleFactors(factors):
    """Double every tracked factor in place (mirrors the point doubling Q *= 2)."""
    factors[:] = [2 * f for f in factors]
def printFactors(factors):
    """Print each tracked factor as 'f<i>: <hex>', using the entry's own .hex()."""
    for idx, value in enumerate(factors):
        print(f'f{idx}: {value.hex()}')
def scalarMulEndo(scalar, P0):
    """Compute scalar * P0 on G2 with the 4-dimensional GLS method and check it.

    Steps (each prints debugging output):
      1. Split `scalar` into mini-scalars k0..k3 (getGLV2_decomp) and assert
         k0*P0 + k1*P1 + k2*P2 + k3*P3 == scalar*P0, with P1 = psi(P0),
         P2 = psi2(P0) and P3 = psi(P2).
      2. Make k0 odd (if even, add 1 now and subtract P0 at the end), then
         recode the mini-scalars into L signed digits (recodeScalars).
      3. Build the 8-entry table of P0 plus subset sums of P1, P2, P3 (buildLut).
      4. Left-to-right double-and-add over the recoded digits. `factors`
         mirrors the coefficients applied to P0..P3, for debugging only.
      5. Assert the final point equals scalar * P0.

    Returns None; raises AssertionError if any check fails.

    NOTE(review): recodeScalars only represents 0 < k0 < 2^L and
    non-negative k1..k3 exactly, and getGLV2_decomp does not guarantee
    non-negative mini-scalars. The final assert may therefore fail for other
    scalars; confirm before reusing this beyond the fixed test vector.
    """
    m = 4
    L = ((int(r).bit_length() + m-1) // m) + 1 # l = ⌈log2 r/m⌉ + 1
    print('L: ' + str(L))
    print('scalar: ' + Integer(scalar).hex())
    k0, k1, k2, k3 = getGLV2_decomp(scalar)
    P1 = psi(P0)
    P2 = psi2(P0)
    P3 = psi(P2)
    expected = scalar * P0
    # The endomorphism decomposition alone must already reproduce scalar * P0.
    decomp = k0*P0 + k1*P1 + k2*P2 + k3*P3
    print('expected: ' + pointToString(expected))
    print('decomp: ' + pointToString(decomp))
    assert expected == decomp
    print('------ recode scalar -----------')
    # The recoding needs an odd k0. Python parses this as (k0 & 1) == 0.
    even = k0 & 1 == 0
    print('was even: ' + str(even))
    if even:
        k0 += 1
    b = recodeScalars([k0, k1, k2, k3])
    # Digit rows are printed most-significant digit first.
    print('b0: ' + str(list(reversed(b[0]))))
    print('b1: ' + str(list(reversed(b[1]))))
    print('b2: ' + str(list(reversed(b[2]))))
    print('b3: ' + str(list(reversed(b[3]))))
    print('------------ lut ---------------')
    lut = buildLut(P0, [P1, P2, P3])
    print('------------ mul ---------------')
    # b[0][L-1] is always 0
    print(f'L-1: {getIndex(b, L-1)}')
    print(f'L-2: {getIndex(b, L-2)}')
    print(f'L-3: {getIndex(b, L-3)}')
    print(f'L-4: {getIndex(b, L-4)}')
    print(f'L-5: {getIndex(b, L-5)}')
    print(f'L-6: {getIndex(b, L-6)}')
    factors = [0, 0, 0, 0] # Track the decomposed scalar applied (debugging)
    # Most significant digit is positive, so the accumulator starts as the
    # table entry itself.
    updateFactors(factors, b, L-1)
    Q = lut[getIndex(b, L-1)]
    for bit in range(L-2, -1, -1):
        # Q = 2*Q + s * lut[index], where s = +1 if b[0][bit] == 0 else -1.
        Q *= 2
        Q += (1 - 2 * b[0][bit]) * lut[getIndex(b, bit)]
        doubleFactors(factors)
        updateFactors(factors, b, bit)
    if even:
        # Undo the k0 += 1 adjustment made before recoding.
        Q -= P0
    print('----')
    print('final Q: ' + pointToString(Q))
    print('expected: ' + pointToString(expected))
    print('----')
    printFactors(factors)
    print('Mul expected:')
    print(' k0: ' + k0.hex())
    print(' k1: ' + k1.hex())
    print(' k2: ' + k2.hex())
    print(' k3: ' + k3.hex())
    assert Q == expected
# Test generator
# NOTE(review): the seed only matters if the commented-out random test
# vector below is re-enabled; the active scalar and point are fixed constants.
set_random_seed(1337)
for i in range(1):
    print('---------------------------------------')
    # Random test vector (disabled):
    # scalar = randrange(r) # Pick an integer below curve order
    # P = G2.random_point()
    # P = clearCofactorG2(P)
    # Alternative fixed test vector (disabled):
    # scalar = Integer('0x2c02275a71bb41c911faf48cab4f7ac7fc6672a5c15586185c8cff3203181da0')
    # P = G2([
    # Fp2([Integer('0x2a028c1328bb0abf252edfbf7133b84eef2a5f20163fe61685b4b54229ca585d'),
    # Integer('0x8f80ad79e8e7e79bbdc645d9f5b339c52dd99a901b90de2494492656f11a9d5')]),
    # Fp2([Integer('0x1f04320578e31ffa2e2b59ad8fcb1aba622b5f307ac540cf2ccdab07dec56503'),
    # Integer('0x2973900c0fdf651b64f5b1a990baec7c582e0743d501bdb991374776d6c73b28')])
    # ])
    # Active fixed test vector: a scalar and a G2 point in affine coordinates
    # (x, y), each coordinate given as Fp2([c0, c1]) = c0 + c1*beta.
    # Presumably the point is already in the order-r subgroup; TODO confirm.
    scalar = Integer('0x24c5b2ce21615dca82231f5fb0fc8d05aa07c6df4bb5aa7c2381ac7b61a6290c')
    P = G2([
        Fp2([Integer('0x1132e63c444e1abce6fc4c39bdf5be5caad586837cbf5ca9d3891482bdefe77'),
             Integer('0x22b71f598dab789f055fc9669ddf66f0d75f581af0e9e8863d7f95a51ef34862')]),
        Fp2([Integer('0x58e39050f64c9948d7238b99ecaee947cb934688a6e9f483c5c36b6e07aa31b'),
             Integer('0x2e64b920f498e12992f2a4ae3f9ced43f3f64705c9008169f3b930a760d055fb')])
    ])
    scalarMulEndo(scalar, P)