# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

# ############################################################
#
# BLS12-381 GLS Endomorphism
# Lattice Decomposition
#
# ############################################################

# Parameters
# BLS12 curves derive everything from a single integer x:
#   p = (x - 1)^2 * (x^4 - x^2 + 1) / 3 + x   base field prime
#   r = x^4 - x^2 + 1                         prime subgroup order
#   t = x + 1                                 trace
x = -(2^63 + 2^62 + 2^60 + 2^57 + 2^48 + 2^16)
p = (x - 1)^2 * (x^4 - x^2 + 1)//3 + x
r = x^4 - x^2 + 1
t = x + 1

print(f' Prime modulus p: 0x{p.hex()}')
print(f' Curve order r: 0x{r.hex()}')
print(f' trace t: 0x{t.hex()}')

# Finite fields
#   Fp  = GF(p)
#   Fp2 = Fp[u] / (u^2 + 1), generator beta with beta^2 = -1
Fp = GF(p)
K2.<u> = PolynomialRing(Fp)
Fp2.<beta> = Fp.extension(u^2+1)

SNR = Fp2([1, 1]) # Sextic Non-Residue for Sextic Twist, i.e. 1 + beta

# Curves, short Weierstrass form y^2 = x^3 + b:
#   G1 over Fp : y^2 = x^3 + 4
#   G2 over Fp2: y^2 = x^3 + 4*SNR (the sextic twist)
# These are the full curve groups; the order-r subgroups are reached by
# multiplying by the cofactors computed below.
b = 4
G1 = EllipticCurve(Fp, [0, b])
G2 = EllipticCurve(Fp2, [0, b*SNR])

# Cofactors: h = #E // r for each curve group
# https://crypto.stackexchange.com/questions/64064/order-of-twisted-curve-in-pairings
# https://math.stackexchange.com/questions/144194/how-to-find-the-order-of-elliptic-curve-over-finite-field-extension
cofactorG1 = G1.order() // r
cofactorG2 = G2.order() // r

print('')
print(f'cofactor G1: {cofactorG1.hex()}')
print(f'cofactor G2: {cofactorG2.hex()}')
print('')

# Frobenius constants (D type: use SNR, M type use 1/SNR)
#
# psi(x, y)  = (c^2 * x^p,   c^3 * y^p)   with c = (1/SNR)^((p-1)/6)
# psi2(x, y) = psi(psi(x, y))
#            = (c^2 * (c^2)^p * x^(p^2), c^3 * (c^3)^p * y^(p^2))
#
# (p - 1) is divisible by 6, so the exponent is computed with exact integer
# division. Using '/' would produce a Sage Rational exponent.
assert (p - 1) % 6 == 0
FrobConst_psi = (1/SNR)^((p - 1)//6)
FrobConst_psi_2 = FrobConst_psi * FrobConst_psi
FrobConst_psi_3 = FrobConst_psi_2 * FrobConst_psi
FrobConst_psi2_2 = FrobConst_psi_2 * FrobConst_psi_2^p
FrobConst_psi2_3 = FrobConst_psi_3 * FrobConst_psi_3^p

def psi(P):
    """Apply the GLS endomorphism psi to a point of G2.

    psi(x, y) = (FrobConst_psi_2 * x^p, FrobConst_psi_3 * y^p).
    The point at infinity is mapped to itself.
    """
    (Px, Py, Pz) = P
    if Pz == 0:
        # Point at infinity (0 : 1 : 0). Scaling its coordinates would
        # generally give an (x, y) pair that is not on the curve, and G2([...])
        # would reject it.
        return G2(0)
    return G2([
        FrobConst_psi_2 * Px.frobenius(1),
        FrobConst_psi_3 * Py.frobenius(1)
        # Pz.frobenius() - Always 1 after extract
    ])

def psi2(P):
    """Apply psi twice (psi o psi) to a point of G2 in one step.

    psi2(x, y) = (FrobConst_psi2_2 * x^(p^2), FrobConst_psi2_3 * y^(p^2)).
    The point at infinity is mapped to itself.
    """
    (Px, Py, Pz) = P
    if Pz == 0:
        # Point at infinity (0 : 1 : 0): a fixed point of the endomorphism
        return G2(0)
    return G2([
        FrobConst_psi2_2 * Px.frobenius(2),
        FrobConst_psi2_3 * Py.frobenius(2)
        # Pz - Always 1 after extract
    ])

def clearCofactorG2(P):
    """Send P into the order-r subgroup of G2 by scaling with the full cofactor #G2 // r."""
    h2 = cofactorG2
    return h2 * P

# Test generator
# Seed the random generator so the random point drawn in checkEndo is reproducible.
set_random_seed(1337)

# Check
def checkEndo():
    """Check psi / psi2 against known endomorphism identities.

    A random G2 point is moved into the order-r subgroup by cofactor clearing.
    The function then asserts:
      - psi^2 - [t] psi + [p] = 0        (Galbraith-Lin-Scott 2008, Theorem 1)
      - psi2^2 - psi2 + 1 = 0            (Galbraith-Scott 2008, Lemma 1,
                                          12th cyclotomic polynomial)
      - p ≡ t - 1 (mod r), and t - 1 is a root of X^4 - X^2 + 1 mod r
      - psi acts on the subgroup as multiplication by t - 1.

    Raises AssertionError on failure. Prints 'Endomorphism OK' on success.
    """
    P = G2.random_point()
    P = clearCofactorG2(P)

    # Galbraith-Lin-Scott, 2008, Theorem 1
    assert psi(psi(P)) - t*psi(P) + p*P == G2([0, 1, 0])
    # Galbraith-Scott, 2008, Lemma 1
    # k-th cyclotomic polynomial with k = 12
    assert psi2(psi2(P)) - psi2(P) + P == G2([0, 1, 0])

    assert p % r == (t-1) % r
    # assert (p^4 - p^2 + 1) % r == 0
    assert ((t-1)^4 - (t-1)^2 + 1) % r == 0
    assert (t-1)*P == (p % r)*P
    assert (t-1)*P == psi(P)

    print('Endomorphism OK')

checkEndo()

def subgroup_check(P):
    """Assert that P lies in the order-r subgroup of G2.

    The relation checked is [x] psi^3(P) - psi^2(P) + P == O, where psi^2 is
    psi2(P) and psi^3 is psi(psi2(P)). Raises AssertionError if it fails.
    """
    psi2_P = psi2(P)
    psi3_P = psi(psi2_P)
    infinity = G2([0,1,0])
    assert x * psi3_P - psi2_P + P == infinity

# Decomposition generated by LLL-algorithm and Babai rounding
# to solve the Shortest (Basis) Vector Problem
#
# Every basis vector (k0, k1, k2, k3) must satisfy
#   k0 + k1*λ + k2*λ^2 + k3*λ^3 ≡ 0 (mod r)
# with λ = t - 1 ≡ x (mod r), the eigenvalue of psi (see checkEndo).
# λ is a root of r = X^4 - X^2 + 1, so the 4th vector is (1, 0, -1, x):
#   1 - λ^2 + x*λ^3 = x^4 - x^2 + 1 = r ≡ 0.
#
# TODO: This lattice is generating wrong result
# Lattice from Guide to Pairing-Based Cryptography
# Lat = [
#   [ x, 1, 0, 0],
#   [ 0, x, 1, 0],
#   [ 0, 0, x, 1],
#   [ 1, 0,-1, x]
# ]
# ahat = [x*(x^2+1), -(x^2+1), x, -1]
# NOTE: as transcribed, rows 0-2 of that lattice vanish for λ = -x, but row 3
# vanishes for λ = x. Its rows therefore do not all belong to the same kernel.

# Lattice from my own LLL+Babai rounding routines.
# The last row used to read [1, 0, -1, -x]. That vector is NOT in the kernel
# (1 - λ^2 - x*λ^3 ≡ 2 - 2x^2 ≢ 0 mod r), and it gives
# det(Lat) = x^4 + x^2 - 1 ≠ ±r. The old version only appeared to work because
# ahat[3] = -r/det(Lat) lies in (-1, 0) and int() truncated it to 0, which
# silently reduced the decomposition to 3 basis vectors.
Lat = Matrix([
    [-x, 1, 0, 0],
    [ 0,-x, 1, 0],
    [ 0, 0,-x, 1],
    [ 1, 0,-1, x]
])
# print('Lat: ' + str(Lat))

# Sanity checks: every row is in the kernel, and the basis has index r.
assert all((c0 + c1*(t-1) + c2*(t-1)^2 + c3*(t-1)^3) % r == 0
           for (c0, c1, c2, c3) in Lat.rows())
assert abs(Lat.det()) == r

# ahat satisfies ahat * Lat = (r, 0, 0, 0). Because |det(Lat)| = r it is an
# integer vector, so the int() conversion below is exact.
ahat = vector([r, 0, 0, 0]) * Lat.inverse()
# print('ahat: ' + str(ahat))
assert all(a.denominator() == 1 for a in ahat)

# Fixed-point Babai rounding vector: v_i = floor(ahat_i * 2^n / r)
n = int(r).bit_length()
n = int(((n + 64 - 1) // 64) * 64) # round to next multiple of 64
v = [Integer(int(a) << n) // r for a in ahat]

def pretty_print_lattice(Lat):
    """Print each row of Lat as right-aligned signed hex cells ('0x..' or '-0x..').

    Every cell is padded to the widest cell width plus 2, and cells are joined
    by a single space.
    """
    def to_signed_hex(cell):
        if cell < 0:
            return '-0x' + (-cell).hex()
        return '0x' + cell.hex()

    rows = [[to_signed_hex(cell) for cell in vec] for vec in Lat]
    width = max(len(text) for row in rows for text in row) + 2
    for row in rows:
        print(' '.join(text.rjust(width) for text in row))

print('\nLattice')
pretty_print_lattice(Lat)

print('\nbasis:')
for basisIdx, basisCoef in enumerate(v):
    print(f' 𝛼\u0305{basisIdx}: 0x{basisCoef.hex()}')
print('')

# Maximum infinity norm of the mini-scalars, used to size the recoding
maxInfNorm = abs(x + 2)
print('\nmax infinity norm:')
print(f' ||(a0 , a1 , a2 , a3)||∞ ≤ 0x{maxInfNorm.hex()}')
print(f' infinity norm bitlength: {int(maxInfNorm).bit_length()}')

# Contrary to Faz2013 paper, we use the max infinity norm
# to properly dimension our recoding instead of ⌈log2 r/m⌉ + 1
# which fails for some inputs.
# Recoding length L = bitlength(maxInfNorm)
#                     + 1 for signed column
#                     + 1 (optional) for handling negative miniscalars
L = int(maxInfNorm).bit_length() + 1 + 1

# Powers of the psi eigenvalue λ = t - 1, reduced modulo r
lambda1 = (t - 1) % r
lambda2 = (lambda1 * lambda1) % r
lambda3 = (lambda2 * lambda1) % r

def getGLV2_decomp(scalar):
    """Split `scalar` into four mini-scalars by Babai rounding on the lattice `Lat`.

    a_i = floor(v_i * scalar / 2^n) approximates floor(scalar * ahat_i / r).
    The mini-scalars are then k = (scalar, 0, 0, 0) - a * Lat.

    Returns:
        (k0, k1, k2, k3) with
        scalar ≡ k0 + k1*lambda1 + k2*lambda2 + k3*lambda3 (mod r).

    Prints each mini-scalar's bit length and hex value. Asserts the congruence
    above, and asserts that every mini-scalar fits in maxLen bits.
    """
    # ⌈log2(r) / 4⌉ + 1, plus 1 more bit to deal with negative miniscalars
    maxLen = (int(r).bit_length() + 3) // 4 + 1 + 1

    # Babai rounding coefficients (v and n are precomputed at module level)
    a = [(vi * scalar) >> n for vi in v]

    # k = (scalar, 0, 0, 0) - sum_i a_i * Lat[i], computed column by column
    k = []
    for col in range(4):
        acc = scalar if col == 0 else 0
        for row in range(4):
            acc = acc - a[row] * Lat[row][col]
        k.append(acc)
    k0, k1, k2, k3 = k

    names = ('k0', 'k1', 'k2', 'k3')
    for name, ki in zip(names, k):
        print(name + ".bitlength(): " + str(int(ki).bit_length()))
    for name, ki in zip(names, k):
        print(name + ': ' + ki.hex())

    assert scalar == (k0 + k1*lambda1 + k2*lambda2 + k3*lambda3) % r
    assert all(int(ki).bit_length() <= maxLen for ki in k)

    return k0, k1, k2, k3

def recodeScalars(k, length=None):
    """Sign-aligned recoding of mini-scalars (Faz2013).

    Args:
        k: sequence [k0, k1, ..., k_{m-1}] of integers. k0 must be odd (the
           caller makes it odd). The sequence is NOT modified.
        length: number of recoded digits per mini-scalar. Defaults to the
           module-level `L`.

    Returns:
        b: m lists of `length` digits each, least-significant first.
           b[0][i] is the sign of column i (0 = positive, 1 = negative), and
           b[0][length-1] is always 0. For j >= 1, b[j][i] in {0, 1} is the
           digit of k_j in column i; it carries the sign of that column.
    """
    if length is None:
        length = L
    m = len(k)
    k = list(k)  # work on a copy: the loop below consumes the mini-scalars

    b = [[0] * length for _ in range(m)]

    # Sign row. For odd k0, column i is positive iff bit i+1 of k0 is set,
    # and the top column is always positive.
    b[0][length-1] = 0
    for i in range(0, length-1): # l-2 inclusive
        b[0][i] = 1 - ((k[0] >> (i+1)) & 1)

    # Digit rows. When a set digit falls in a negative column it subtracts,
    # so 1 is carried back into the remaining value.
    for j in range(1, m):
        for i in range(0, length):
            b[j][i] = k[j] & 1
            k[j] = k[j]//2 + (b[j][i] & b[0][i])

    return b

def clearBit(v, bit):
    """Return v with bit number `bit` (0 = least significant) forced to 0."""
    mask = int(1 << bit)
    return v & ~mask

def buildLut(P0, P_endos):
    """Build the lookup table of P0 plus every subset of the endomorphism images.

    Entry u (0 <= u < 2^len(P_endos)) is P0 plus P_endos[i] for each bit i set
    in u. Each entry reuses the entry with u's most significant bit cleared, so
    it costs a single addition. A symbolic view of the table ('P0', 'P0 + P1',
    ...) is printed.

    Args:
        P0: base point.
        P_endos: endomorphism images [P1, P2, ...]; 3 of them for the
            4-dimensional decomposition used in this file.
    Returns:
        list of 2^len(P_endos) points.
    """
    m = len(P_endos) + 1
    tableSize = 1 << (m-1)

    lut = [0] * tableSize
    lut[0] = P0

    lutS = [''] * tableSize
    lutS[0] = 'P0'
    endoS = ['P' + str(i+1) for i in range(m-1)]

    for u in range(1, tableSize):
        msb = u.bit_length() - 1
        # u with its most significant bit cleared. Subtraction is used instead
        # of XOR because '^' is exponentiation in Sage.
        idx = u - (1 << msb)
        lut[u] = lut[idx] + P_endos[msb]
        lutS[u] = lutS[idx] + ' + ' + endoS[msb]

    print('LUT: ' + str(lutS))
    return lut

def pointToString(P):
    """Format a G2 point as 'Point(...)', showing each Fp2 coordinate as c0 + β * c1 in hex."""
    (Px, Py, _) = P
    pieces = ['Point(\n']
    for label, coord in ((' Px: ', Px), (' Py: ', Py)):
        coeffs = vector(coord)
        pieces.append(label + Integer(coeffs[0]).hex() + ' + β * ' + Integer(coeffs[1]).hex() + '\n')
    pieces.append(')')
    return ''.join(pieces)

def getIndex(glvRecoding, bit):
    """Return the lookup-table index for column `bit` of a recoding.

    Bit (j-1) of the index is the low bit of glvRecoding[j][bit], for
    j = 1 .. len(glvRecoding) - 1. Row 0 (the sign row) is not part of the
    index. The row count comes from the recoding instead of a hard-coded m = 4.
    """
    index = 0
    for k in range(1, len(glvRecoding)):
        index |= ((glvRecoding[k][bit] & 1) << (k-1))
    return index

def updateFactors(factors, recoded, bit):
    """Debug helper: add the signed digits of column `bit` into `factors`.

    The column is positive when recoded[0][bit] == 0 and negative otherwise.
    factors[0] moves by ±1, and factors[j] (j = 1..3) moves by
    ±(recoded[j][bit] & 1). Mutates `factors` in place and returns None.
    """
    sign = 1 if recoded[0][bit] == 0 else -1
    factors[0] += sign
    for j in (1, 2, 3):
        factors[j] += sign * (recoded[j][bit] & 1)

def doubleFactors(factors):
    """Double every entry of `factors` in place (debug mirror of the point doubling)."""
    for idx, value in enumerate(factors):
        factors[idx] = value * 2

def printFactors(factors):
    """Print a header, then each tracked factor in hex, one per line."""
    print('Multiplication done: ')
    for idx, factor in enumerate(factors):
        print(f' f{idx}: {factor.hex()}')

def scalarMulEndo(scalar, P0):
    """Compute scalar * P0 on G2 with the 4-dimensional GLS decomposition, printing each step.

    1. Decompose scalar into (k0, k1, k2, k3) with getGLV2_decomp. Check that
       k0*P0 + k1*psi(P0) + k2*psi2(P0) + k3*psi(psi2(P0)) == scalar * P0.
    2. If k0 is even, recode k0 + 1 and subtract P0 at the end. Then recode
       all mini-scalars with recodeScalars.
    3. Build the 8-entry table of P0 plus subsets of {P1, P2, P3}.
    4. Run a double-and-add loop that performs one doubling and one signed
       table addition for every column below the top one.

    Debug tracking in `factors` mirrors the applied mini-scalars.
    Asserts the final result equals scalar * P0 and returns None.
    """
    print('r bits: ' + str(int(r).bit_length()))
    print('L: ' + str(L))

    print('scalar: ' + Integer(scalar).hex())

    k0, k1, k2, k3 = getGLV2_decomp(scalar)

    # Endomorphism images: P1 = psi(P0), P2 = psi^2(P0), P3 = psi^3(P0)
    P1 = psi(P0)
    P2 = psi2(P0)
    P3 = psi(P2)

    expected = scalar * P0
    decomp = k0*P0 + k1*P1 + k2*P2 + k3*P3
    print('expected: ' + pointToString(expected))
    print('decomp: ' + pointToString(decomp))
    assert expected == decomp

    # Alternative to adding an extra bit
    # to deal with miniscalars, unfortunately broken
    # for some input
    # for example 0x5668a2332db27199dcfb7cbdfca6317c2ff128db26d7df68483e0a095ec8e88f
    # which is missing bits for b[2]
    # if k0 < 0: k0 = -k0; P0 = -P0
    # if k1 < 0: k1 = -k1; P1 = -P1
    # if k2 < 0: k2 = -k2; P2 = -P2
    # if k3 < 0: k3 = -k3; P3 = -P3

    print('------ recode scalar -----------')
    # The recoding requires an odd k0: recode k0 + 1 and compensate with -P0 at the end
    even = k0 & 1 == 0
    print('was even: ' + str(even))
    if even:
        k0 += 1

    b = recodeScalars([k0, k1, k2, k3])
    print('b0: ' + str(list(reversed(b[0]))))
    print('b1: ' + str(list(reversed(b[1]))))
    print('b2: ' + str(list(reversed(b[2]))))
    print('b3: ' + str(list(reversed(b[3]))))

    print('------------ lut ---------------')

    lut = buildLut(P0, [P1, P2, P3])

    print('------------ mul ---------------')
    # b[0][L-1] is always 0
    print(f'L-1: {getIndex(b, L-1)}')
    print(f'L-2: {getIndex(b, L-2)}')
    print(f'L-3: {getIndex(b, L-3)}')
    print(f'L-4: {getIndex(b, L-4)}')
    print(f'L-5: {getIndex(b, L-5)}')
    print(f'L-6: {getIndex(b, L-6)}')

    factors = [0, 0, 0, 0] # Track the decomposed scalar applied (debugging)
    updateFactors(factors, b, L-1)

    # The top column is always positive, so start directly from its table entry
    Q = lut[getIndex(b, L-1)]
    for bit in range(L-2, -1, -1):
        Q *= 2
        # b[0][bit] == 0 -> add the table entry, 1 -> subtract it
        Q += (1 - 2 * b[0][bit]) * lut[getIndex(b, bit)]

        doubleFactors(factors)
        updateFactors(factors, b, bit)

    if even:
        Q -= P0

    print('----')
    print('final Q: ' + pointToString(Q))
    print('expected: ' + pointToString(expected))
    print('----')
    printFactors(factors)
    print('Mul expected:')
    print(' k0: ' + k0.hex())
    print(' k1: ' + k1.hex())
    print(' k2: ' + k2.hex())
    print(' k3: ' + k3.hex())

    assert Q == expected

# Test generator
set_random_seed(1337)

for _ in range(1):
    print('---------------------------------------')
    # scalar = randrange(r) # Pick an integer below curve order
    # P = G2.random_point()
    # P = clearCofactorG2(P)

    # scalar = Integer('0x1f7bef2a74f3bf8ac0225a9edfa514bb5666b15e7be3e929059f2ef75f0035a6')
    # P = G2([
    #   Fp2([Integer('0x989f16bcb9da60ef72383e6134ba194f57e30109806304336c0c995e2857ed20bf5b6e03d6fe1424332e9c666cbd10a'),
    #        Integer('0x16692643cb5e7466e3730d3ea775c7741ac34d670b3be685761a7d6ab722a2673ce374ddab87b7c4d2675ba2199f9121')]),
    #   Fp2([Integer('0x931e416488bef7cb4a053e4bd86ef44818bc03a5be5b04606b2a4dc1d139a3a452f5f7172f24eeaad84702b73b155bb'),
    #        Integer('0x192c3e2a6619473216b7bb2447448cdbeb9f7e3c9486b0a05aadf6dcd91d7cb275a5d84c1a362628efffbc8711a62a67')])
    # ])

    # This integer leads to a negative miniscalar. Handling it properly requires either:
    # 1. Negating it and then negating the corresponding curve point P
    # 2. Adding an extra bit to the recoding, which will do the right thing™
    #
    # For implementation, solution 1 is faster:
    # - Double + Add is about 5000~8000 cycles on 6 64-bit limbs (BLS12-381)
    # - Conditional negate is about 10 cycles per Fp; on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles
    #   We need to test the mini scalar, which is 65 bits so 2 limbs, so about 2 cycles,
    #   and negate it as well.
    # scalar = Integer('0x6448f296d9b1a8d81319a0b789df04c587c6165776ccf39f50a354204aabe0da')
    # P = G2([
    #   Fp2([Integer('0x5adc112fb04bf4ca642d5a7d7343ccd6b93546442d2fff5b9d32c15e456d54884cba49dd7f94ce4ddaad4018e55d0f2'),
    #        Integer('0x5d1c5bbf5d7a833dc76ba206bfa99c281fc37941be050e18f8c6d267b2376b3634d8ad6eb951e52a6d096315abd17d6')]),
    #   Fp2([Integer('0x15a959e54981fab9ac3c6f5bfd6fb60a50a916bd43d96a09922a54309b84812736581bfa728670cba864b08b9e391bb9'),
    #        Integer('0xf5d6d74f1dd3d9c07451340b8f6990fe93a28fe5e176564eb920bf17eb02df8b6f1e626eda5542ff415f89d51943001')])
    # ])

    # The following input fails in Constantine when negating the base point
    # but not when adding an extra bit

    # scalar = Integer('0x5668a2332db27199dcfb7cbdfca6317c2ff128db26d7df68483e0a095ec8e88f')
    # P = G2([
    #   Fp2([Integer('0xa8c5649d2df1bae84fd9e8bfcde5113937b3acea22d67ddfedaf1fb8de8c1ef4c70591cf505c24c31e54020c2c510c3'),
    #        Integer('0xa0553f98229a6a067489c3ee204161c11e96f421b3e9c145dc3865b03e9d4ff6cab14c5b5308ecd31173f954463690c')]),
    #   Fp2([Integer('0xb29d8dfe18dc41b4826c3a102c1bf8f306cb42433cc36ee38080f47a324c02a678f9daed0a2bc577c18b9865de029f0'),
    #        Integer('0x558cdabf11e37c5c5e8abd668bbdd71bb3f07f320948ccaac8a207359fffe38424bfd9b1ef1d24b28b2fbb9f76faff1')])
    # ])

    # The following fails when we have both the extra bit and negation of the first
    # scalar if it is negative.
    # It also uses 65 bits instead of the expected max of 64
    # and triggers an off-by-1 when negating.

    scalar = Integer('0x6448f296d9b1a8d81319a0b789df04c587c6165776ccf39f50a354204aabe0da')
    coordX = Fp2([Integer('0x5adc112fb04bf4ca642d5a7d7343ccd6b93546442d2fff5b9d32c15e456d54884cba49dd7f94ce4ddaad4018e55d0f2'),
                  Integer('0x5d1c5bbf5d7a833dc76ba206bfa99c281fc37941be050e18f8c6d267b2376b3634d8ad6eb951e52a6d096315abd17d6')])
    coordY = Fp2([Integer('0x15a959e54981fab9ac3c6f5bfd6fb60a50a916bd43d96a09922a54309b84812736581bfa728670cba864b08b9e391bb9'),
                  Integer('0xf5d6d74f1dd3d9c07451340b8f6990fe93a28fe5e176564eb920bf17eb02df8b6f1e626eda5542ff415f89d51943001')])
    P = G2([coordX, coordY])

    subgroup_check(P)
    scalarMulEndo(scalar, P)