diff --git a/constantine.nimble b/constantine.nimble index fe4ab29..5215559 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -23,3 +23,4 @@ task test, "Run all tests": test "", "tests/test_bigints_vs_gmp.nim" test "", "tests/test_finite_fields.nim" test "", "tests/test_finite_fields_vs_gmp.nim" + test "", "tests/test_finite_fields_powinv.nim" diff --git a/constantine/math/bigints_checked.nim b/constantine/math/bigints_checked.nim index 51261aa..0696708 100644 --- a/constantine/math/bigints_checked.nim +++ b/constantine/math/bigints_checked.nim @@ -185,3 +185,33 @@ func montyPow*[mBits, eBits: static int]( scratchPtrs[i] = scratchSpace[i].view() montyPow(a.view, expBE, M.view, one.view, Word(negInvModWord), scratchPtrs) + +func montyPowUnsafeExponent*[mBits, eBits: static int]( + a: var BigInt[mBits], exponent: BigInt[eBits], + M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int) = + ## Compute a <- a^exponent (mod M) + ## ``a`` in the Montgomery domain + ## ``exponent`` is any BigInt, in the canonical domain + ## + ## Warning ⚠️ : + ## This is an optimization for public exponent + ## Otherwise bits of the exponent can be retrieved with: + ## - memory access analysis + ## - power analysis + ## - timing analysis + ## + ## This uses fixed window optimization + ## A window size in the range [1, 5] must be chosen + mixin exportRawUint # exported in io_bigints which depends on this module ... + + var expBE {.noInit.}: array[(ebits + 7) div 8, byte] + expBE.exportRawUint(exponent, bigEndian) + + const scratchLen = if windowSize == 1: 2 + else: (1 shl windowSize) + 1 + var scratchSpace {.noInit.}: array[scratchLen, BigInt[mBits]] + var scratchPtrs {.noInit.}: array[scratchLen, BigIntViewMut] + for i in 0 ..< scratchLen: + scratchPtrs[i] = scratchSpace[i].view() + + montyPowUnsafeExponent(a.view, expBE, M.view, one.view, Word(negInvModWord), scratchPtrs) diff --git a/constantine/math/bigints_raw.nim b/constantine/math/bigints_raw.nim index e71321f..c8c99a8 100644 --- a/constantine/math/bigints_raw.nim +++ b/constantine/math/bigints_raw.nim @@ -602,6 +602,69 @@ func getWindowLen(bufLen: int): uint = while (1 shl result) + 1 > bufLen: dec result +func montyPowPrologue( + a: BigIntViewMut, M, one: BigIntViewConst, + negInvModWord: Word, + scratchspace: openarray[BigIntViewMut] + ): tuple[window: uint, bigIntSize: int] = + + result.window = scratchspace.len.getWindowLen() + result.bigIntSize = a.numLimbs() * sizeof(Word) + sizeof(BigIntView.bitLength) + + # Precompute window content, special case for window = 1 + # (i.e scratchspace has only space for 2 temporaries) + # The content scratchspace[2+k] is set at a^k + # with scratchspace[0] untouched + if result.window == 1: + copyMem(pointer scratchspace[1], pointer a, result.bigIntSize) + else: + copyMem(pointer scratchspace[2], pointer a, result.bigIntSize) + for k in 2 ..< 1 shl result.window: + scratchspace[k+1].montyMul(scratchspace[k], a, M, negInvModWord) + + scratchspace[1].setBitLength(bitSizeof(M)) + + # Set a to one + copyMem(pointer a, pointer one, result.bigIntSize) + +func montyPowSquarings( + a: BigIntViewMut, + exponent: openarray[byte], + M: BigIntViewConst, + negInvModWord: Word, + tmp: BigIntViewMut, + window: uint, + bigIntSize: int, + acc, acc_len: var uint, + e: var int, + ): tuple[k, bits: uint] = + ## Squaring step of exponentiation by squaring + ## Get the next k bits in range [1, window) + ## Square k times + ## Returns the number of squarings done and the corresponding bits + ## + ## Updates iteration variables and accumulators + + # Get the next bits + var k = window + if acc_len < window: + if e < exponent.len: + acc = (acc shl 8) or exponent[e].uint + inc e + acc_len += 8 + else: # Drained all exponent bits + k = acc_len + + let bits = (acc shr (acc_len - k)) and ((1'u32 shl k) - 1) + acc_len -= k + + # We have k bits and can do k squaring + for i in 0 ..< k: + tmp.montyMul(a, a, M, negInvModWord) + copyMem(pointer a, pointer tmp, bigIntSize) + + return (k, bits) + func montyPow*( a: BigIntViewMut, exponent: openarray[byte], @@ -610,7 +673,7 @@ func montyPow*( scratchspace: openarray[BigIntViewMut] ) = ## Modular exponentiation r = a^exponent mod M - ## in the montgomery domain + ## in the Montgomery domain ## ## This uses fixed-window optimization if possible ## @@ -635,24 +698,7 @@ func montyPow*( ## A window of size 5 requires (2^5 + 1)*(381 + 7)/8 = 33 * 48 bytes = 1584 bytes ## of scratchspace (on the stack). - let window = scratchspace.len.getWindowLen() - let bigIntSize = a.numLimbs() * sizeof(Word) + sizeof(BigIntView.bitLength) - - # Precompute window content, special case for window = 1 - # (i.e scratchspace has only space for 2 temporaries) - # The content scratchspace[2+k] is set at a^k - # with scratchspace[0] untouched - if window == 1: - copyMem(pointer scratchspace[1], pointer a, bigIntSize) - else: - copyMem(pointer scratchspace[2], pointer a, bigIntSize) - for k in 2 ..< 1 shl window: - scratchspace[k+1].montyMul(scratchspace[k], a, M, negInvModWord) - - scratchspace[1].setBitLength(bitSizeof(M)) - - # Set a to one - copyMem(pointer a, pointer one, bigIntSize) + let (window, bigIntSize) = montyPowPrologue(a, M, one, negInvModWord, scratchspace) # We process bits with from most to least significant. # At each loop iteration with have acc_len bits in acc. @@ -663,23 +709,12 @@ func montyPow*( acc, acc_len: uint e = 0 while acc_len > 0 or e < exponent.len: - # Get the next bits - var k = window - if acc_len < window: - if e < exponent.len: - acc = (acc shl 8) or exponent[e].uint - inc e - acc_len += 8 - else: # Drained all exponent bits - k = acc_len + let (k, bits) = montyPowSquarings( + a, exponent, M, negInvModWord, + scratchspace[0], window, bigIntSize, + acc, acc_len, e + ) - let bits = (acc shr (acc_len - k)) and ((1'u32 shl k) - 1) - acc_len -= k - - # We have k bits and can do k squaring - for i in 0 ..< k: - scratchspace[0].montyMul(a, a, M, negInvModWord) - copyMem(pointer a, pointer scratchspace[0], bigIntSize) # Window lookup: we set scratchspace[1] to the lookup value. # If the window length is 1, then it's already set. if window > 1: @@ -694,3 +729,44 @@ func montyPow*( # we keep the product only if the exponent bits are not all zero scratchspace[0].montyMul(a, scratchspace[1], M, negInvModWord) a.cmov(scratchspace[0], Word(bits) != Zero) + +func montyPowUnsafeExponent*( + a: BigIntViewMut, + exponent: openarray[byte], + M, one: BigIntViewConst, + negInvModWord: Word, + scratchspace: openarray[BigIntViewMut] + ) = + ## Modular exponentiation r = a^exponent mod M + ## in the Montgomery domain + ## + ## Warning ⚠️ : + ## This is an optimization for public exponent + ## Otherwise bits of the exponent can be retrieved with: + ## - memory access analysis + ## - power analysis + ## - timing analysis + + # TODO: scratchspace[1] is unused when window > 1 + + let (window, bigIntSize) = montyPowPrologue( + a, M, one, negInvModWord, scratchspace) + + var + acc, acc_len: uint + e = 0 + while acc_len > 0 or e < exponent.len: + let (k, bits) = montyPowSquarings( + a, exponent, M, negInvModWord, + scratchspace[0], window, bigIntSize, + acc, acc_len, e + ) + + ## Warning ⚠️: Exposes the exponent bits + if bits != 0: + if window > 1: + scratchspace[0].montyMul(a, scratchspace[1+bits], M, negInvModWord) + else: + # scratchspace[1] holds the original `a` + scratchspace[0].montyMul(a, scratchspace[1], M, negInvModWord) + copyMem(pointer a, pointer scratchspace[0], bigIntSize) diff --git a/constantine/math/finite_fields.nim b/constantine/math/finite_fields.nim index 1bda6a6..abbcbe9 100644 --- a/constantine/math/finite_fields.nim +++ b/constantine/math/finite_fields.nim @@ -135,3 +135,17 @@ func pow*(a: var Fq, exponent: BigInt) = ## ``exponent``: a big integer const windowSize = 5 # TODO: find best window size for each curves a.mres.montyPow(exponent, Fq.C.Mod.mres, Fq.C.getMontyOne(), Fq.C.getNegInvModWord(), windowSize) + +func powUnsafeExponent*(a: var Fq, exponent: BigInt) = + ## Exponentiation over Fq + ## ``a``: a field element to be exponentiated + ## ``exponent``: a big integer + ## + ## Warning ⚠️ : + ## This is an optimization for public exponent + ## Otherwise bits of the exponent can be retrieved with: + ## - memory access analysis + ## - power analysis + ## - timing analysis + const windowSize = 5 # TODO: find best window size for each curves + a.mres.montyPowUnsafeExponent(exponent, Fq.C.Mod.mres, Fq.C.getMontyOne(), Fq.C.getNegInvModWord(), windowSize) diff --git a/tests/test_finite_fields_powinv.nim b/tests/test_finite_fields_powinv.nim index dae0b99..1bb2ef3 100644 --- a/tests/test_finite_fields_powinv.nim +++ b/tests/test_finite_fields_powinv.nim @@ -93,19 +93,36 @@ proc main() = 20'u64 == r test "x^(p-2) mod p (modular inversion if p prime)": - var x: Fq[BLS12_381] + block: + var x: Fq[BLS12_381] - # BN254 field modulus - x.fromHex("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") - # BLS12-381 prime - 2 - let exponent = BigInt[381].fromHex("0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaa9") + # BN254 field modulus + x.fromHex("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") + # BLS12-381 prime - 2 + let exponent = BigInt[381].fromHex("0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaa9") - let expected = "0x0636759a0f3034fa47174b2c0334902f11e9915b7bd89c6a2b3082b109abbc9837da17201f6d8286fe6203caa1b9d4c8" + let expected = "0x0636759a0f3034fa47174b2c0334902f11e9915b7bd89c6a2b3082b109abbc9837da17201f6d8286fe6203caa1b9d4c8" - x.pow(exponent) - let computed = x.toHex() + x.pow(exponent) + let computed = x.toHex() - check: - computed == expected + check: + computed == expected + + block: + var x: Fq[BLS12_381] + + # BN254 field modulus + x.fromHex("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") + # BLS12-381 prime - 2 + let exponent = BigInt[381].fromHex("0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaa9") + + let expected = "0x0636759a0f3034fa47174b2c0334902f11e9915b7bd89c6a2b3082b109abbc9837da17201f6d8286fe6203caa1b9d4c8" + + x.powUnsafeExponent(exponent) + let computed = x.toHex() + + check: + computed == expected main()