# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

# ############################################################
#
#                BLS12-377 GLS Endomorphism
#                  Lattice Decomposition
#
# ############################################################

# SageMath script: `^` is exponentiation (Sage preparser) and
# `K.<g> = ...` declares `g` as the generator of the new ring/field.

# Parameters
# x = 3 * 2^46 * (7 * 13 * 499) + 1 is the BLS12-377 curve parameter.
x = 3 * 2^46 * (7 * 13 * 499) + 1
p = (x - 1)^2 * (x^4 - x^2 + 1)//3 + x
r = x^4 - x^2 + 1
t = x+1
print(' Prime modulus p: 0x' + p.hex())
print(' Curve order r: 0x' + r.hex())
print(' trace t: 0x' + t.hex())

# Finite fields
Fp = GF(p)
K2.<u> = PolynomialRing(Fp)
Fp2.<beta> = Fp.extension(u^2+5)  # Fp2 = Fp[u]/(u^2 + 5)
SNR = Fp2([0, 1])  # Sextic Non-Residue for Sextic Twist

# Curves
b = 1
G1 = EllipticCurve(Fp, [0, b])
G2 = EllipticCurve(Fp2, [0, b/SNR])  # D-type twist: y^2 = x^3 + b/SNR

# https://crypto.stackexchange.com/questions/64064/order-of-twisted-curve-in-pairings
# https://math.stackexchange.com/questions/144194/how-to-find-the-order-of-elliptic-curve-over-finite-field-extension
cofactorG1 = G1.order() // r
cofactorG2 = G2.order() // r

print('')
print('cofactor G1: ' + cofactorG1.hex())
print('cofactor G2: ' + cofactorG2.hex())
print('')

# Frobenius constants (D type: use SNR, M type use 1/SNR)
# p ≡ 1 (mod 6) for this curve, so the exponent is an exact integer;
# `//` keeps it an Integer instead of a Rational.
FrobConst_psi = SNR^((p-1)//6)
FrobConst_psi_2 = FrobConst_psi * FrobConst_psi
FrobConst_psi_3 = FrobConst_psi_2 * FrobConst_psi
FrobConst_psi2_2 = FrobConst_psi_2 * FrobConst_psi_2^p
FrobConst_psi2_3 = FrobConst_psi_3 * FrobConst_psi_3^p


def psi(P):
    """Apply the GLS endomorphism ψ to a point P of G2.

    Maps (x, y) to (c2 * x^p, c3 * y^p), where c2, c3 are FrobConst_psi_2/_3.
    Relies on Sage's normalized coordinates (Pz == 1); the point at
    infinity (Pz == 0) is not handled.
    """
    (Px, Py, Pz) = P
    return G2([
        FrobConst_psi_2 * Px.frobenius(1),
        FrobConst_psi_3 * Py.frobenius(1)
        # Pz.frobenius() - Always 1 after extract
    ])


def psi2(P):
    """Apply ψ² directly: (x, y) -> (c2' * x^(p^2), c3' * y^(p^2)).

    Same normalization assumption as psi (Pz == 1).
    """
    (Px, Py, Pz) = P
    return G2([
        FrobConst_psi2_2 * Px.frobenius(2),
        FrobConst_psi2_3 * Py.frobenius(2)
        # Pz - Always 1 after extract
    ])


def clearCofactorG2(P):
    """Map a G2 point into the r-torsion subgroup by multiplying by the cofactor."""
    return cofactorG2 * P


# Test generator
set_random_seed(1337)


# Check
def checkEndo():
    """Assert the expected relations of ψ on a random r-torsion point of G2."""
    P = G2.random_point()
    P = clearCofactorG2(P)

    # Galbraith-Lin-Scott, 2008, Theorem 1
    assert psi(psi(P)) - t*psi(P) + p*P == G2([0, 1, 0])

    # Galbraith-Scott, 2008, Lemma 1
    # k-th cyclotomic polynomial with k = 12
    assert psi2(psi2(P)) - psi2(P) + P == G2([0, 1, 0])

    assert p % r == (t-1) % r
    # assert (p^4 - p^2 + 1) % r == 0
    assert ((t-1)^4 - (t-1)^2 + 1) % r == 0
    assert (t-1)*P == (p % r)*P

    # ψ acts as multiplication by λ = t - 1 = x on the r-torsion subgroup
    assert (t-1)*P == psi(P)

    print('Endomorphism OK')


checkEndo()


def subgroup_check(P):
    """Assert the ψ-based G2 membership test: x·ψ³(P) - ψ²(P) + P == O."""
    ppP = psi2(P)
    assert x * psi(ppP) - ppP + P == G2([0,1,0])


# Decomposition generated by LLL-algorithm and Babai rounding
# to solve the Shortest (Basis) Vector Problem
#
# TODO: This lattice is generating wrong result
# Lattice from Guide to Pairing-Based Cryptography
# Lat = [
#   [ x, 1, 0, 0],
#   [ 0, x, 1, 0],
#   [ 0, 0, x, 1],
#   [ 1, 0,-1, x]
# ]
# ahat = [x*(x^2+1), -(x^2+1), x, -1]

# Lattice from my own LLL+Babai rounding routines
#
# NOTE(review): with λ = lambda1 = x, rows 0..2 satisfy row·(1, λ, λ², λ³) ≡ 0 (mod r),
# but row 3 gives 1 - x^2 - x^4 ≢ 0 (mod r). The decomposition still passes its
# modular assert only because ahat[3] = -r/(x^4 + x^2 - 1) lies in (-1, 0), so
# int() below truncates it to 0 and row 3 never contributes. [1, 0, -1, x]
# (dot product = r) may be what was intended — confirm before changing.
Lat = Matrix([
    [-x, 1, 0, 0],
    [ 0,-x, 1, 0],
    [ 0, 0,-x, 1],
    [ 1, 0,-1,-x]
])
# print('Lat: ' + str(Lat))
ahat = vector([r, 0, 0, 0]) * Lat.inverse()
# print('ahat: ' + str(ahat))

# Precompute Babai constants v_i ≈ ahat_i * 2^n / r so a_i = (v_i * scalar) >> n
# approximates floor(scalar * ahat_i / r) with shifts instead of divisions.
# ahat entries are rationals here; int() truncates them toward zero.
n = int(r).bit_length()
n = int(((n + 64 - 1) // 64) * 64)  # round to next multiple of 64
v = [Integer(int(a) << n) // r for a in ahat]


def pretty_print_lattice(Lat):
    """Print the lattice as right-aligned, signed hexadecimal cells."""
    latHex = [['0x' + c.hex() if c >= 0 else '-0x' + (-c).hex() for c in vec] for vec in Lat]
    maxlen = max([len(cell) for row in latHex for cell in row])
    for row in latHex:
        row = ' '.join(cell.rjust(maxlen + 2) for cell in row)
        print(row)


print('\nLattice')
pretty_print_lattice(Lat)

print('\nbasis:')
print(' 𝛼\u03050: 0x' + v[0].hex())
print(' 𝛼\u03051: 0x' + v[1].hex())
print(' 𝛼\u03052: 0x' + v[2].hex())
print(' 𝛼\u03053: 0x' + v[3].hex())
print('')

maxInfNorm = abs(x + 2)
print('\nmax infinity norm:')
print(' ||(a0, a1, a2, a3)||∞ ≤ 0x' + str(maxInfNorm.hex()))
print(' infinity norm bitlength: ' + str(int(maxInfNorm).bit_length()))

# Contrary to Faz2013 paper, we use the max infinity norm
# to properly dimension our recoding instead of ⌈log2 r/m⌉ + 1
# which fails for some inputs
# +1 for signed column
# Optional +1 for handling negative miniscalars
L = int(maxInfNorm).bit_length() + 1
L += 1

lambda1 = (t-1) % r
lambda2 = lambda1^2 % r
lambda3 = lambda1^3 % r


def getGLV2_decomp(scalar):
    """Split `scalar` into mini-scalars (k0, k1, k2, k3) by Babai rounding against Lat.

    Guarantees (asserted): scalar ≡ k0 + k1·λ + k2·λ² + k3·λ³ (mod r) with
    λ = lambda1, and every |k_i| has at most maxLen bits. Prints debug output.
    """
    maxLen = (int(r).bit_length() + 3) // 4 + 1
    maxLen += 1  # Deal with negative miniscalars

    a0 = (v[0] * scalar) >> n
    a1 = (v[1] * scalar) >> n
    a2 = (v[2] * scalar) >> n
    a3 = (v[3] * scalar) >> n

    # k = (scalar, 0, 0, 0) - (a0, a1, a2, a3) · Lat
    k0 = scalar - a0 * Lat[0][0] - a1 * Lat[1][0] - a2 * Lat[2][0] - a3 * Lat[3][0]
    k1 = 0 - a0 * Lat[0][1] - a1 * Lat[1][1] - a2 * Lat[2][1] - a3 * Lat[3][1]
    k2 = 0 - a0 * Lat[0][2] - a1 * Lat[1][2] - a2 * Lat[2][2] - a3 * Lat[3][2]
    k3 = 0 - a0 * Lat[0][3] - a1 * Lat[1][3] - a2 * Lat[2][3] - a3 * Lat[3][3]

    print("k0.bitlength(): " + str(int(k0).bit_length()))
    print("k1.bitlength(): " + str(int(k1).bit_length()))
    print("k2.bitlength(): " + str(int(k2).bit_length()))
    print("k3.bitlength(): " + str(int(k3).bit_length()))

    print('k0: ' + k0.hex())
    print('k1: ' + k1.hex())
    print('k2: ' + k2.hex())
    print('k3: ' + k3.hex())

    assert scalar == (k0 + k1*lambda1 + k2*lambda2 + k3*lambda3) % r
    assert int(k0).bit_length() <= maxLen
    assert int(k1).bit_length() <= maxLen
    assert int(k2).bit_length() <= maxLen
    assert int(k3).bit_length() <= maxLen

    return k0, k1, k2, k3


def recodeScalars(k):
    """Recode 4 mini-scalars into L-column sign-aligned digits (GLV-SAC style).

    Args:
        k: sequence [k0, k1, k2, k3]; k0 is expected odd and non-negative
           (scalarMulEndo ensures this). The input is not modified.

    Returns:
        b, a 4 x L bit matrix indexed b[j][i] (column i = bit position):
        b[0][i] is the column sign (0 => +1, 1 => -1, b[0][L-1] == 0) and
        b[j][i] for j >= 1 says whether P_j is included in that column.
    """
    m = 4
    k = list(k)  # work on a copy: the loop below consumes k[1..m-1]
    b = [[0] * L for _ in range(m)]

    b[0][L-1] = 0
    for i in range(0, L-1):  # l-2 inclusive
        b[0][i] = 1 - ((k[0] >> (i+1)) & 1)
    for j in range(1, m):
        for i in range(0, L):
            b[j][i] = k[j] & 1
            # When the column sign is negative and we consumed a 1 bit,
            # carry +1 into the remaining value.
            k[j] = k[j]//2 + (b[j][i] & b[0][i])

    return b


def clearBit(v, bit):
    """Return `v` with bit number `bit` cleared."""
    return v & ~int(1 << bit)


def buildLut(P0, P_endos):
    """Build the 2^(m-1) table lut[e] = P0 + sum of P_endos[i] for each set bit i of e.

    Also prints a symbolic rendering of the table (e.g. 'P0 + P1 + P3').
    """
    m = 4
    assert len(P_endos) == m-1
    lut = [0] * (1 << (m-1))
    lut[0] = P0

    lutS = [''] * (1 << (m-1))
    lutS[0] = 'P0'
    endoS = ['P1', 'P2', 'P3']

    for entry in range(1, 1 << (m-1)):
        msb = entry.bit_length() - 1
        idx = clearBit(entry, msb)  # same entry without its top bit: already filled
        lut[entry] = lut[idx] + P_endos[msb]
        lutS[entry] = lutS[idx] + ' + ' + endoS[msb]

    print('LUT: ' + str(lutS))
    return lut


def pointToString(P):
    """Render a G2 point's affine Fp2 coordinates as 'c0 + β * c1' hex strings."""
    (Px, Py, Pz) = P
    vPx = vector(Px)
    vPy = vector(Py)
    result = 'Point(\n'
    result += ' Px: ' + Integer(vPx[0]).hex() + ' + β * ' + Integer(vPx[1]).hex() + '\n'
    result += ' Py: ' + Integer(vPy[0]).hex() + ' + β * ' + Integer(vPy[1]).hex() + '\n'
    result += ')'
    return result


def getIndex(glvRecoding, bit):
    """Pack column `bit` of rows 1..3 of the recoding into a LUT index (row 1 -> bit 0)."""
    m = 4
    index = 0
    for k in range(1, m):
        index |= ((glvRecoding[k][bit] & 1) << (k-1))
    return index


def updateFactors(factors, recoded, bit):
    """Debug mirror of one add/sub step: update per-point coefficients in place."""
    index = getIndex(recoded, bit)
    if recoded[0][bit] == 0:  # Positive
        factors[0] += 1
        factors[1] += (index >> 0) & 1
        factors[2] += (index >> 1) & 1
        factors[3] += (index >> 2) & 1
    else:
        factors[0] -= 1
        factors[1] -= (index >> 0) & 1
        factors[2] -= (index >> 1) & 1
        factors[3] -= (index >> 2) & 1


def doubleFactors(factors):
    """Debug mirror of a point doubling: double every tracked coefficient in place."""
    for i in range(len(factors)):
        factors[i] *= 2


def printFactors(factors):
    """Print the tracked per-point coefficients in hex."""
    print('Multiplication done: ')
    for i in range(len(factors)):
        print(f' f{i}: {factors[i].hex()}')


def scalarMulEndo(scalar, P0):
    """Compute scalar * P0 on G2 with a 4-way GLS decomposition and recoded ladder.

    Prints extensive debug traces and asserts that both the decomposition and
    the final ladder result equal scalar * P0.
    """
    print('L: ' + str(L))
    print('scalar: ' + Integer(scalar).hex())

    k0, k1, k2, k3 = getGLV2_decomp(scalar)

    P1 = psi(P0)
    P2 = psi2(P0)
    P3 = psi(P2)

    expected = scalar * P0
    decomp = k0*P0 + k1*P1 + k2*P2 + k3*P3
    print('expected: ' + pointToString(expected))
    print('decomp: ' + pointToString(decomp))
    assert expected == decomp

    # Alternative to adding an extra bit
    # to deal with miniscalars
    # if k0 < 0: k0 = -k0; P0 = -P0
    # if k1 < 0: k1 = -k1; P1 = -P1
    # if k2 < 0: k2 = -k2; P2 = -P2
    # if k3 < 0: k3 = -k3; P3 = -P3
    # assert expected == k0*P0 + k1*P1 + k2*P2 + k3*P3

    # Somehow the recoding doesn't cope with first scalar being negative
    if k0 < 0:
        k0 = -k0
        P0 = -P0

    print('------ recode scalar -----------')
    # The recoding requires an odd k0: bump even k0 by one and
    # subtract P0 once more after the ladder.
    even = (k0 & 1) == 0
    print('was even: ' + str(even))
    if even:
        k0 += 1

    b = recodeScalars([k0, k1, k2, k3])
    print('b0: ' + str(list(reversed(b[0]))))
    print('b1: ' + str(list(reversed(b[1]))))
    print('b2: ' + str(list(reversed(b[2]))))
    print('b3: ' + str(list(reversed(b[3]))))

    print('------------ lut ---------------')

    lut = buildLut(P0, [P1, P2, P3])

    print('------------ mul ---------------')
    # b[0][L-1] is always 0
    print(f'L-1: {getIndex(b, L-1)}')
    print(f'L-2: {getIndex(b, L-2)}')
    print(f'L-3: {getIndex(b, L-3)}')
    print(f'L-4: {getIndex(b, L-4)}')
    print(f'L-5: {getIndex(b, L-5)}')
    print(f'L-6: {getIndex(b, L-6)}')

    factors = [0, 0, 0, 0]  # Track the decomposed scalar applied (debugging)

    updateFactors(factors, b, L-1)
    Q = lut[getIndex(b, L-1)]
    for bit in range(L-2, -1, -1):
        Q *= 2
        Q += (1 - 2 * b[0][bit]) * lut[getIndex(b, bit)]  # sign: 0 -> +1, 1 -> -1

        doubleFactors(factors)
        updateFactors(factors, b, bit)

    if even:
        Q -= P0

    print('----')
    print('final Q: ' + pointToString(Q))
    print('expected: ' + pointToString(expected))
    print('----')
    printFactors(factors)
    print('Mul expected:')
    print(' k0: ' + k0.hex())
    print(' k1: ' + k1.hex())
    print(' k2: ' + k2.hex())
    print(' k3: ' + k3.hex())

    assert Q == expected


# Test generator
set_random_seed(1337)

for i in range(1):
    print('---------------------------------------')
    scalar = randrange(r)  # Pick an integer below curve order
    # P = G2.random_point()
    # P = clearCofactorG2(P)

    # Fixed regression vector (overrides the random picks above)
    scalar = Integer('0x9d432eb58ec68bbc09d10961451d99c7796fb2f795eca603d6feaf3e2a1634b')
    P = G2([
        Fp2([Integer('0x267401f3ef554fe74ae131d56a10edf14ae40192654901b4618d2bf7af22e77c2a9b79e407348dbd4aad13ca73b33a'),
             Integer('0x12dcca838f46a3e0418e5dd8b978362757a16bfd78f0b77f4a1916ace353938389ae3ea228d0eb5020a0aaa58884aec')]),
        Fp2([Integer('0x11799118d2e054aabd9f74c0843fecbdc1c0d56f61c61c5854c2507ae2416e48a6b2cd3bc8bf7495a4d3d8270eafe2b'),
             Integer('0x823b9f8fb9f8297734a14359fa2c2a0de275e7e638197eaaaa7cff28f9cb3101bdabb570016672455f1ecae625e294')])
    ])

    subgroup_check(P)
    scalarMulEndo(scalar, P)