diff --git a/constantine/arithmetic/bigints.nim b/constantine/arithmetic/bigints.nim index d4bc8da..9ff2656 100644 --- a/constantine/arithmetic/bigints.nim +++ b/constantine/arithmetic/bigints.nim @@ -137,6 +137,12 @@ func cadd*(a: var BigInt, b: BigInt, ctl: SecretBool): SecretBool = ## The result carry is always computed. (SecretBool) cadd(a.limbs, b.limbs, ctl) +func cadd*(a: var BigInt, b: SecretWord, ctl: SecretBool): SecretBool = + ## Constant-time in-place conditional addition + ## The addition is only performed if ctl is "true" + ## The result carry is always computed. + (SecretBool) cadd(a.limbs, b, ctl) + func csub*(a: var BigInt, b: BigInt, ctl: SecretBool): SecretBool = ## Constant-time in-place conditional substraction ## The substraction is only performed if ctl is "true" diff --git a/constantine/arithmetic/limbs.nim b/constantine/arithmetic/limbs.nim index 8f7be16..e7cee3d 100644 --- a/constantine/arithmetic/limbs.nim +++ b/constantine/arithmetic/limbs.nim @@ -198,6 +198,17 @@ func cadd*(a: var Limbs, b: Limbs, ctl: SecretBool): Carry = addC(result, sum, a[i], b[i], result) ctl.ccopy(a[i], sum) +func cadd*(a: var Limbs, w: SecretWord, ctl: SecretBool): Borrow = + ## Limbs conditional addition, sub a number that fits in a word + ## Returns the borrow + result = Carry(0) + var diff: SecretWord + addC(result, diff, a[0], w, result) + ctl.ccopy(a[0], diff) + for i in 1 ..< a.len: + addC(result, diff, a[i], Zero, result) + ctl.ccopy(a[i], diff) + func sum*(r: var Limbs, a, b: Limbs): Carry = ## Sum `a` and `b` into `r` ## `r` is initialized/overwritten @@ -246,7 +257,7 @@ func csub*(a: var Limbs, b: Limbs, ctl: SecretBool): Borrow = func csub*(a: var Limbs, w: SecretWord, ctl: SecretBool): Borrow = ## Limbs conditional substraction, sub a number that fits in a word ## Returns the borrow - result = Carry(0) + result = Borrow(0) var diff: SecretWord subB(result, diff, a[0], w, result) ctl.ccopy(a[0], diff) diff --git a/constantine/elliptic/ec_endomorphism_accel.nim b/constantine/elliptic/ec_endomorphism_accel.nim index bcaae31..c5849e0 100644 --- a/constantine/elliptic/ec_endomorphism_accel.nim +++ b/constantine/elliptic/ec_endomorphism_accel.nim @@ -59,64 +59,24 @@ type ## and m the number of dimensions of the GLV endomorphism ## (ii) Exactly one subscalar which should be odd ## is expressed by a signed nonzero representation - ## with all digits ∈ {1, −1} + ## with all digits ∈ {1, −1} represented at a lowlevel + ## by bit {0, 1} (0 bit -> positive 1 digit, 1 bit -> negative -1 digit) ## (iii) Other subscalars have digits ∈ {0, 1, −1} - ## - ## We pack the representation, using 2 bits per digit: - ## 0 = 0b00 - ## 1 = 0b01 - ## -1 = 0b11 - ## - ## This means that GLV_SAC uses twice the size of a canonical integer + ## with 0 encoded as 0 and 1/-1 encoded as 1 + ## and the sign taken from the sign subscalar (at position 0) ## ## Digit-Endianness is bigEndian const - BitSize = 2 - Shift = 2 # log2(4) - we can store 4 digit per byte - ByteMask = 3 # we need (mod 4) to access a packed bytearray - DigitMask = 0b11 # Digits take 2-bit + BitSize = 1 + Shift = 1 # log2(2) - we can store 2 digits per byte + ByteMask = 1 # we need (mod 2) to access a packed bytearray + DigitMask = 0b1 # Digits take 1-bit -# template signExtend_2bit(recoded: byte): int8 = -# ## We need to extend: -# ## - 0b00 to 0b0000_0000 ( 0) -# ## - 0b01 to 0b0000_0001 ( 1) -# ## - 0b11 to 0b1111_1111 (-1) -# ## -# ## This can be done by shifting left to have -# ## - 0b00 to 0b0000_0000 -# ## - 0b01 to 0b0100_0000 -# ## - 0b11 to 0b1100_0000 -# ## -# ## And then an arithmetic right shift (SAR) -# ## -# ## However there is no builtin SAR -# ## we can get it in C by right-shifting -# ## with the main compilers/platforms -# ## (GCC, Clang, MSVC, ...) -# ## but this is implementation defined behavior -# ## Nim `ashr` uses C signed right shifting -# ## -# ## We could check the compiler to ensure we only use -# ## well documented behaviors: https://gcc.gnu.org/onlinedocs/gcc/Integers-implementation.html#Integers-implementation -# ## but if we can avoid that altogether in a crypto library -# ## -# ## Instead we use signed bitfield which are automatically sign-extended -# ## in a portable way as sign extension is automatic for builtin types - -type - SignExtender = object - ## Uses C builtin types sign extension to sign extend 2-bit to 8-bit - ## in a portable way as sign extension is automatic for builtin types - ## http://graphics.stanford.edu/~seander/bithacks.html#FixedSignExtend - digit {.bitsize:2.}: int8 - -# TODO: use unsigned to avoid checks and potentially leaking secrets -# or push checks off (or prove that checks can be elided once Nim has Z3 in the compiler) proc `[]`(recoding: Recoded, - digitIdx: int): int8 {.inline.}= + digitIdx: int): uint8 {.inline.}= ## 0 <= digitIdx < LengthInDigits - ## returns digit ∈ {0, 1, −1} + ## returns digit ∈ {0, 1} const len = Recoded.LengthInDigits assert digitIdx < len @@ -124,16 +84,12 @@ proc `[]`(recoding: Recoded, len-1 - (digitIdx shr Shift) ] let recoded = slot shr (BitSize*(digitIdx and ByteMask)) and DigitMask - var signExtender: SignExtender - # Hack with C assignment that return values - {.emit: [result, " = ", signExtender, ".digit = ", recoded, ";"].} - # " # Fix highlighting bug in VScode - + return recoded proc `[]=`(recoding: var Recoded, - digitIdx: int, value: int8) {.inline.}= + digitIdx: int, value: uint8) {.inline.}= ## 0 <= digitIdx < LengthInDigits - ## returns digit ∈ {0, 1, −1} + ## returns digit ∈ {0, 1} ## This is write-once const len = Recoded.LengthInDigits assert digitIdx < Recoded.LengthInDigits @@ -145,7 +101,6 @@ proc `[]=`(recoding: var Recoded, let shifted = byte((value and DigitMask) shl (BitSize*(digitIdx and ByteMask))) slot[] = slot[] or shifted - func nDimMultiScalarRecoding[M, L: static int]( dst: var GLV_SAC[M, L], src: MultiScalar[M, L] @@ -164,64 +119,25 @@ func nDimMultiScalarRecoding[M, L: static int]( # Algorithm 1 Protected Recoding Algorithm for the GLV-SAC Representation. # ------------------------------------------------------------------------ # - # Input: m l-bit positive integers kj = (kj_l−1, ..., kj_0)_2 for - # 0 ≤ j < m, an odd “sign-aligner” kJ ∈ {kj}^m, where - # l = ⌈log2 r/m⌉ + 1, m is the GLV dimension and r is - # the prime subgroup order. - # Output: (bj_l−1 , ..., bj_0)GLV-SAC for 0 ≤ j < m, where - # bJ_i ∈ {1, −1}, and bj_i ∈ {0, bJ_i} for 0 ≤ j < m with - # j != J. - # ------------------------------------------------------------------------ - # - # 1: bJ_l-1 = 1 - # 2: for i = 0 to (l − 2) do - # 3: bJ_i = 2kJ_i+1 - 1 - # 4: for j = 0 to (m − 1), j != J do - # 5: for i = 0 to (l − 1) do - # 6: bj_i = bJ_i kj_0 - # 7: kj = ⌊kj/2⌋ − ⌊bj_i/2⌋ - # 8: return (bj_l−1 , . . . , bj_0)_GLV-SAC for 0 ≤ j < m. - # - # - Guide to Pairing-based Cryptography - # Chapter 6: Scalar Multiplication and Exponentiation in Pairing Groups - # Joppe Bos, Craig Costello, Michael Naehrig - # - # We choose kJ = k0 - # - # Implementation strategy and points of attention - # - The subscalars kj must support extracting the least significant bit - # - The subscalars kj must support floor division by 2 - # For that floored division, kj is 0 or positive - # - The subscalars kj must support individual bit accesses - # - The subscalars kj must support addition by a small value (0 or 1) - # Hence we choose to use our own BigInt representation. - # - # - The digit bji must support floor division by 2 - # For that floored division, bji may be negative!!! - # In particular floored division of -1 is -1 not 0. - # This means that arithmetic right shift must be used instead of logical right shift + # We modify Algorithm 1 with the last paragraph optimization suggestions: + # - instead of ternary coding -1, 0, 1 (for negative, 0, positive) + # - we use 0, 1 for (0, sign of column) + # and in the sign column 0, 1 for (positive, negative) # assert src[0].isOdd - Only happen on implementation error, we don't want to leak a single bit var k = src # Keep the source multiscalar in registers template b: untyped {.dirty.} = dst - b[0][L-1] = 1 + b[0][L-1] = 0 # means positive column for i in 0 .. L-2: - b[0][i] = 2 * k[0].bit(i+1).int8 - 1 + b[0][i] = 1 - k[0].bit(i+1).uint8 for j in 1 .. M-1: for i in 0 .. L-1: - let bji = b[0][i] * k[j].bit0.int8 + let bji = k[j].bit0.uint8 b[j][i] = bji - # In the following equation - # kj = ⌊kj/2⌋ − ⌊bj_i/2⌋ - # We have ⌊bj_i/2⌋ (floor division) - # = -1 if bj_i == -1 - # = 0 if bj_i ∈ {0, 1} - # So we turn that statement in an addition - # by the opposite k[j].div2() - k[j] += SecretWord -bji.ashr(1) + k[j] += SecretWord (bji and b[0][i]) func buildLookupTable[M: static int, F]( P: ECP_SWei_Proj[F], @@ -281,10 +197,6 @@ func tableIndex(glv: GLV_SAC, bit: int): SecretWord = staticFor i, 1, GLV_SAC.M: result = result or SecretWord((glv[i][bit] and 1) shl (i-1)) -func isNeg(glv: GLV_SAC, bit: int): SecretBool = - ## Returns true if the bit requires substraction - SecretBool(glv[0][bit] < 0) - func secretLookup[T](dst: var T, table: openArray[T], index: SecretWord) = ## Load a table[index] into `dst` ## This is constant-time, whatever the `index`, its value is not leaked @@ -309,13 +221,13 @@ func scalarMulGLV*[scalBits]( const M = 2 # 1. Compute endomorphisms - var endomorphisms: array[M-1, typeof(P)] # TODO: zero-init not required + var endomorphisms {.noInit.}: array[M-1, typeof(P)] endomorphisms[0] = P endomorphisms[0].x *= C.getCubicRootOfUnity_mod_p() # 2. Decompose scalar into mini-scalars const L = (C.getCurveOrderBitwidth() + M - 1) div M + 1 - var miniScalars: array[M, BigInt[L]] # TODO: zero-init not required + var miniScalars {.noInit.}: array[M, BigInt[L]] when C == BN254_Snarks: scalar.decomposeScalar_BN254_Snarks_G1( miniScalars @@ -333,7 +245,7 @@ func scalarMulGLV*[scalBits]( # in the GLV representation at the low low price of 1 bit # 4. Precompute lookup table - var lut: array[1 shl (M-1), ECP_SWei_Proj] # TODO: zero-init not required + var lut {.noInit.}: array[1 shl (M-1), ECP_SWei_Proj] buildLookupTable(P, endomorphisms, lut) # TODO: Montgomery simultaneous inversion (or other simultaneous inversion techniques) # so that we use mixed addition formulas in the main loop @@ -341,29 +253,26 @@ func scalarMulGLV*[scalBits]( # 5. Recode the miniscalars # we need the base miniscalar (that encodes the sign) # to be odd, and this in constant-time to protect the secret least-significant bit. - var k0isOdd = miniScalars[0].isOdd() - discard miniScalars[0].csub(SecretWord(1), not k0isOdd) + let k0isOdd = miniScalars[0].isOdd() + discard miniScalars[0].cadd(SecretWord(1), not k0isOdd) var recoded: GLV_SAC[2, L] # zero-init required recoded.nDimMultiScalarRecoding(miniScalars) # 6. Proceed to GLV accelerated scalar multiplication - var Q: typeof(P) # TODO: zero-init not required + var Q {.noInit.}: typeof(P) Q.secretLookup(lut, recoded.tableIndex(L-1)) - Q.cneg(recoded.isNeg(L-1)) for i in countdown(L-2, 0): Q.double() - var tmp: typeof(Q) # TODO: zero-init not required + var tmp {.noInit.}: typeof(Q) tmp.secretLookup(lut, recoded.tableIndex(i)) - tmp.cneg(recoded.isNeg(i)) + tmp.cneg(SecretBool recoded[0][i]) Q += tmp # Now we need to correct if the sign miniscalar was not odd - # we missed an addition - lut[0] += Q # Contains Q + P0 - P = Q - P.ccopy(lut[0], not k0isOdd) + P.diff(Q, lut[0]) # Contains Q - P0 + P.ccopy(Q, k0isOdd) # Sanity checks # ---------------------------------------------------------------- @@ -384,7 +293,6 @@ when isMainModule: for i in countdown(glvSac.LengthInDigits-1, 0): result.add " " & (block: case glvSac[j][i] - of -1: "1\u{0305}" of 0: "0" of 1: "1" else: @@ -438,7 +346,7 @@ when isMainModule: const miniBitwidth = 4 # Bitwidth of the miniscalars resulting from scalar decomposition var k: MultiScalar[M, miniBitwidth] - var kRecoded: GLV_SAC[M, miniBitwidth+1] + var kRecoded: GLV_SAC[M, miniBitwidth] k[0].fromUint(11) k[1].fromUint(6) @@ -477,39 +385,40 @@ when isMainModule: const M = 2 const scalBits = BN254_Snarks.getCurveOrderBitwidth() const miniBits = (scalBits+M-1) div M + const L = miniBits + 1 block: let scalar = BigInt[scalBits].fromHex( "0x24a0b87203c7a8def0018c95d7fab106373aebf920265c696f0ae08f8229b3f3" ) - var decomp: MultiScalar[M, miniBits] + var decomp: MultiScalar[M, L] decomposeScalar_BN254_Snarks_G1(scalar, decomp) - doAssert: bool(decomp[0] == BigInt[127].fromHex"14928105460c820ccc9a25d0d953dbfe") - doAssert: bool(decomp[1] == BigInt[127].fromHex"13a2f911eb48a578844b901de6f41660") + doAssert: bool(decomp[0] == BigInt[L].fromHex"14928105460c820ccc9a25d0d953dbfe") + doAssert: bool(decomp[1] == BigInt[L].fromHex"13a2f911eb48a578844b901de6f41660") block: let scalar = BigInt[scalBits].fromHex( "24554fa6d0c06f6dc51c551dea8b058cd737fc8d83f7692fcebdd1842b3092c4" ) - var decomp: MultiScalar[M, miniBits] + var decomp: MultiScalar[M, L] decomposeScalar_BN254_Snarks_G1(scalar, decomp) - doAssert: bool(decomp[0] == BigInt[127].fromHex"28cf7429c3ff8f7e82fc419e90cc3a2") - doAssert: bool(decomp[1] == BigInt[127].fromHex"457efc201bdb3d2e6087df36430a6db6") + doAssert: bool(decomp[0] == BigInt[L].fromHex"28cf7429c3ff8f7e82fc419e90cc3a2") + doAssert: bool(decomp[1] == BigInt[L].fromHex"457efc201bdb3d2e6087df36430a6db6") block: let scalar = BigInt[scalBits].fromHex( "288c20b297b9808f4e56aeb70eabf269e75d055567ff4e05fe5fb709881e6717" ) - var decomp: MultiScalar[M, miniBits] + var decomp: MultiScalar[M, L] decomposeScalar_BN254_Snarks_G1(scalar, decomp) - doAssert: bool(decomp[0] == BigInt[127].fromHex"4da8c411566c77e00c902eb542aaa66b") - doAssert: bool(decomp[1] == BigInt[127].fromHex"5aa8f2f15afc3217f06677702bd4e41a") + doAssert: bool(decomp[0] == BigInt[L].fromHex"4da8c411566c77e00c902eb542aaa66b") + doAssert: bool(decomp[1] == BigInt[L].fromHex"5aa8f2f15afc3217f06677702bd4e41a") main_decomp() @@ -555,7 +464,7 @@ when isMainModule: const miniBitwidth = 4 # Bitwidth of the miniscalars resulting from scalar decomposition const L = miniBitwidth + 1 # Bitwidth of the recoded scalars - var k: MultiScalar[M, miniBitwidth] + var k: MultiScalar[M, L] var kRecoded: GLV_SAC[M, L] k[0].fromUint(11) @@ -582,7 +491,7 @@ when isMainModule: # Q = sign_l-1 P[K_l-1] let idx = kRecoded.tableIndex(L-1) for p in lut[int(idx)]: - Q[p] = if kRecoded.isNeg(L-1).bool: -1 else: 1 + Q[p] = if kRecoded[0][L-1] == 0: 1 else: -1 # Loop for i in countdown(L-2, 0): # Q = 2Q @@ -591,7 +500,7 @@ when isMainModule: # Q = Q + sign_l-1 P[K_l-1] let idx = kRecoded.tableIndex(i) for p in lut[int(idx)]: - Q[p] += (if kRecoded.isNeg(i).bool: -1 else: 1) + Q[p] += (if kRecoded[0][i] == 0: 1 else: -1) echo "Q + sign_l-1 P[K_l-1]: ", Q echo Q diff --git a/constantine/elliptic/ec_weierstrass_projective.nim b/constantine/elliptic/ec_weierstrass_projective.nim index 3bdcd51..97ea1f9 100644 --- a/constantine/elliptic/ec_weierstrass_projective.nim +++ b/constantine/elliptic/ec_weierstrass_projective.nim @@ -300,6 +300,15 @@ func double*[F](P: var ECP_SWei_Proj[F]) = tmp.double(P) P = tmp +func diff*[F](r: var ECP_SWei_Proj[F], + P, Q: ECP_SWei_Proj[F] + ) = + ## r = P - Q + ## Can handle r and Q aliasing + var nQ = Q + nQ.neg() + r.sum(P, nQ) + func affineFromProjective*[F](aff: var ECP_SWei_Proj[F], proj: ECP_SWei_Proj) = # TODO: for now we reuse projective coordinate backend # TODO: simultaneous inversion diff --git a/sage/lattice_decomposition_bls12_381_g1.sage b/sage/lattice_decomposition_bls12_381_g1.sage index 3d082c8..ef9ec82 100644 --- a/sage/lattice_decomposition_bls12_381_g1.sage +++ b/sage/lattice_decomposition_bls12_381_g1.sage @@ -117,13 +117,13 @@ def recodeScalars(k): L = ((int(r).bit_length() + m-1) // m) + 1 # l = ⌈log2 r/m⌉ + 1 b = [[0] * L, [0] * L] - b[0][L-1] = 1 + b[0][L-1] = 0 for i in range(0, L-1): # l-2 inclusive - b[0][i] = 2 * ((k[0] >> (i+1)) & 1) - 1 + b[0][i] = 1 - ((k[0] >> (i+1)) & 1) for j in range(1, m): for i in range(0, L): - b[j][i] = b[0][i] * (k[j] & 1) - k[j] = (k[j]//2) - (b[j][i] // 2) + b[j][i] = k[j] & 1 + k[j] = k[j]//2 + (b[j][i] & b[0][i]) return b @@ -160,9 +160,9 @@ def scalarMulGLV(scalar, P0): assert expected == decomp print('------ recode scalar -----------') - even = k0 & 1 == 1 + even = k0 & 1 == 0 if even: - k0 -= 1 + k0 += 1 b = recodeScalars([k0, k1]) print('b0: ' + str(list(reversed(b[0])))) @@ -173,14 +173,14 @@ def scalarMulGLV(scalar, P0): lut = buildLut(P0, P1) print('------------ mul ---------------') - print('b0 L-1: ' + str(b[0][L-1])) - Q = b[0][L-1] * lut[b[1][L-1] & 1] + # b[0][L-1] is always 0 + Q = lut[b[1][L-1]] for i in range(L-2, -1, -1): Q *= 2 - Q += b[0][i] * lut[b[1][i] & 1] + Q += (1 - 2 * b[0][i]) * lut[b[1][i]] if even: - Q += P0 + Q -= P0 print('final Q: ' + pointToString(Q)) print('expected: ' + pointToString(expected))