From 28e83e7b49de6312d9b61c2816fe503e9bd53ce3 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Fri, 4 Sep 2020 19:04:32 +0200 Subject: [PATCH] Faster inversion with addition chains (#80) --- benchmarks/bench_fields_template.nim | 21 +- benchmarks/bench_fp.nim | 5 +- .../arithmetic/finite_fields_inversion.nim | 410 ++++++++++++++++-- 3 files changed, 404 insertions(+), 32 deletions(-) diff --git a/benchmarks/bench_fields_template.nim b/benchmarks/bench_fields_template.nim index 4bafa5b..da7c6fa 100644 --- a/benchmarks/bench_fields_template.nim +++ b/benchmarks/bench_fields_template.nim @@ -160,14 +160,29 @@ proc invBench*(T: typedesc, iters: int) = var r: T let x = rng.random_unsafe(T) preventOptimAway(r) - bench("Inversion (constant-time Euclid)", T, iters): + bench("Inversion (constant-time default method)", T, iters): r.inv(x) -proc powFermatInversionBench*(T: typedesc, iters: int) = +proc invEuclidBench*(T: typedesc, iters: int) = + var r: T let x = rng.random_unsafe(T) + preventOptimAway(r) + bench("Inversion via constant-time Euclid", T, iters): + r.inv_euclid(x) + +proc invPowFermatBench*(T: typedesc, iters: int) = + let x = rng.random_unsafe(T) + const exponent = T.C.getInvModExponent() bench("Inversion via exponentiation p-2 (Little Fermat)", T, iters): var r = x - r.powUnsafeExponent(T.C.getInvModExponent()) + r.powUnsafeExponent(exponent) + +proc invAddChainBench*(T: typedesc, iters: int) = + var r: T + let x = rng.random_unsafe(T) + preventOptimAway(r) + bench("Inversion via addition chain", T, iters): + r.inv_addchain(x) proc sqrtBench*(T: typedesc, iters: int) = let x = rng.random_unsafe(T) diff --git a/benchmarks/bench_fp.nim b/benchmarks/bench_fp.nim index ea510ee..2a718ce 100644 --- a/benchmarks/bench_fp.nim +++ b/benchmarks/bench_fp.nim @@ -50,8 +50,9 @@ proc main() = negBench(Fp[curve], Iters) mulBench(Fp[curve], Iters) sqrBench(Fp[curve], Iters) - invBench(Fp[curve], ExponentIters) - powFermatInversionBench(Fp[curve], ExponentIters) + 
invEuclidBench(Fp[curve], ExponentIters) + invPowFermatBench(Fp[curve], ExponentIters) + invAddChainBench(Fp[curve], ExponentIters) sqrtBench(Fp[curve], ExponentIters) # Exponentiation by a "secret" of size ~the curve order powBench(Fp[curve], ExponentIters) diff --git a/constantine/arithmetic/finite_fields_inversion.nim b/constantine/arithmetic/finite_fields_inversion.nim index 0057075..e64373d 100644 --- a/constantine/arithmetic/finite_fields_inversion.nim +++ b/constantine/arithmetic/finite_fields_inversion.nim @@ -18,13 +18,14 @@ import # ############################################################ # Field-specific inversion routines -template repeat(num: int, body: untyped) = +func square_repeated(r: var Fp, num: int) = + ## Repeated squarings for _ in 0 ..< num: - body + r.square() # Secp256k1 # ------------------------------------------------------------ -func invmod_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) {.used.}= +func inv_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) {.used.}= ## We invert via Little Fermat's theorem ## a^(-1) ≡ a^(p-2) (mod p) ## with p = "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F" @@ -57,52 +58,256 @@ func invmod_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) {.used.}= x3 *= a x6 = x3 - repeat 3: x6.square() + x6.square_repeated(3) x6 *= x3 x9 = x6 - repeat 3: x9.square() + x9.square_repeated(3) x9 *= x3 x11 = x9 - repeat 2: x11.square() + x11.square_repeated(2) x11 *= x2 x22 = x11 - repeat 11: x22.square() + x22.square_repeated(11) x22 *= x11 x44 = x22 - repeat 22: x44.square() + x44.square_repeated(22) x44 *= x22 x88 = x44 - repeat 44: x88.square() + x88.square_repeated(44) x88 *= x44 x176 = x88 - repeat 88: x88.square() + x88.square_repeated(88) x176 *= x88 x220 = x176 - repeat 44: x220.square() + x220.square_repeated(44) x220 *= x44 x223 = x220 - repeat 3: x223.square() + x223.square_repeated(3) x223 *= x3 # The final result is then assembled using a sliding window over the blocks r = 
x223 - repeat 23: r.square() + r.square_repeated(23) r *= x22 - repeat 5: r.square() + r.square_repeated(5) r *= a - repeat 3: r.square() + r.square_repeated(3) r *= x2 - repeat 2: r.square() + r.square_repeated(2) r *= a +# BLS12-381 +# ------------------------------------------------------------ +func inv_addchain*(r: var Fp[BLS12_381], a: Fp[BLS12_381]) = + var + x10 {.noinit.}: Fp[BLS12_381] + x100 {.noinit.}: Fp[BLS12_381] + x1000 {.noinit.}: Fp[BLS12_381] + x1001 {.noinit.}: Fp[BLS12_381] + x1011 {.noinit.}: Fp[BLS12_381] + x1101 {.noinit.}: Fp[BLS12_381] + x10001 {.noinit.}: Fp[BLS12_381] + x10100 {.noinit.}: Fp[BLS12_381] + x11001 {.noinit.}: Fp[BLS12_381] + x11010 {.noinit.}: Fp[BLS12_381] + x110100 {.noinit.}: Fp[BLS12_381] + x110110 {.noinit.}: Fp[BLS12_381] + x110111 {.noinit.}: Fp[BLS12_381] + x1001101 {.noinit.}: Fp[BLS12_381] + x1001111 {.noinit.}: Fp[BLS12_381] + x1010101 {.noinit.}: Fp[BLS12_381] + x1011101 {.noinit.}: Fp[BLS12_381] + x1100111 {.noinit.}: Fp[BLS12_381] + x1101001 {.noinit.}: Fp[BLS12_381] + x1110111 {.noinit.}: Fp[BLS12_381] + x1111011 {.noinit.}: Fp[BLS12_381] + x10001001 {.noinit.}: Fp[BLS12_381] + x10010101 {.noinit.}: Fp[BLS12_381] + x10010111 {.noinit.}: Fp[BLS12_381] + x10101001 {.noinit.}: Fp[BLS12_381] + x10110001 {.noinit.}: Fp[BLS12_381] + x10111111 {.noinit.}: Fp[BLS12_381] + x11000011 {.noinit.}: Fp[BLS12_381] + x11010000 {.noinit.}: Fp[BLS12_381] + x11010111 {.noinit.}: Fp[BLS12_381] + x11100001 {.noinit.}: Fp[BLS12_381] + x11100101 {.noinit.}: Fp[BLS12_381] + x11101011 {.noinit.}: Fp[BLS12_381] + x11110101 {.noinit.}: Fp[BLS12_381] + x11111111 {.noinit.}: Fp[BLS12_381] + + x10 .square(a) + x100 .square(x10) + x1000 .square(x100) + x1001 .prod(a, x1000) + x1011 .prod(x10, x1001) + x1101 .prod(x10, x1011) + x10001 .prod(x100, x1101) + x10100 .prod(x1001, x1011) + x11001 .prod(x1000, x10001) + x11010 .prod(a, x11001) + x110100 .square(x11010) + x110110 .prod(x10, x110100) + x110111 .prod(a, x110110) + x1001101 
.prod(x11001, x110100) + x1001111 .prod(x10, x1001101) + x1010101 .prod(x1000, x1001101) + x1011101 .prod(x1000, x1010101) + x1100111 .prod(x11010, x1001101) + x1101001 .prod(x10, x1100111) + x1110111 .prod(x11010, x1011101) + x1111011 .prod(x100, x1110111) + x10001001 .prod(x110100, x1010101) + x10010101 .prod(x11010, x1111011) + x10010111 .prod(x10, x10010101) + x10101001 .prod(x10100, x10010101) + x10110001 .prod(x1000, x10101001) + x10111111 .prod(x110110, x10001001) + x11000011 .prod(x100, x10111111) + x11010000 .prod(x1101, x11000011) + x11010111 .prod(x10100, x11000011) + x11100001 .prod(x10001, x11010000) + x11100101 .prod(x100, x11100001) + x11101011 .prod(x10100, x11010111) + x11110101 .prod(x10100, x11100001) + x11111111 .prod(x10100, x11101011) # 35 operations + + # TODO: we can accumulate in a partially reduced + # doubled-size `r` to avoid the final subtractions + # and only reduce at the end. + # This requires the number of operations to be less than log2(p) == 381 + + # 35 + 22 = 57 operations + r.prod(x10111111, x11100001) + r.square_repeated(8) + r *= x10001 + r.square_repeated(11) + r *= x11110101 + + # 57 + 28 = 85 operations + r.square_repeated(11) + r *= x11100101 + r.square_repeated(8) + r *= x11111111 + r.square_repeated(7) + + # 85 + 22 = 107 operations + r *= x1001101 + r.square_repeated(9) + r *= x1101001 + r.square_repeated(10) + r *= x10110001 + + # 107+24 = 131 operations + r.square_repeated(7) + r *= x1011101 + r.square_repeated(9) + r *= x1111011 + r.square_repeated(6) + + # 131+23 = 154 operations + r *= x11001 + r.square_repeated(11) + r *= x1101001 + r.square_repeated(9) + r *= x11101011 + + # 154+28 = 182 operations + r.square_repeated(10) + r *= x11010111 + r.square_repeated(6) + r *= x11001 + r.square_repeated(10) + + # 182+23 = 205 operations + r *= x1110111 + r.square_repeated(9) + r *= x10010111 + r.square_repeated(11) + r *= x1001111 + + # 205+30 = 235 operations + r.square_repeated(10) + r *= x11100001 + r.square_repeated(9) + r
*= x10001001 + r.square_repeated(9) + + # 235+21 = 256 operations + r *= x10111111 + r.square_repeated(8) + r *= x1100111 + r.square_repeated(10) + r *= x11000011 + + # 256+28 = 284 operations + r.square_repeated(9) + r *= x10010101 + r.square_repeated(12) + r *= x1111011 + r.square_repeated(5) + + # 284 + 21 = 305 operations + r *= x1011 + r.square_repeated(11) + r *= x1111011 + r.square_repeated(7) + r *= x1001 + + # 305+32 = 337 operations + r.square_repeated(13) + r *= x11110101 + r.square_repeated(9) + r *= x10111111 + r.square_repeated(8) + + # 337+22 = 359 operations + r *= x11111111 + r.square_repeated(8) + r *= x11101011 + r.square_repeated(11) + r *= x10101001 + + # 359+24 = 383 operations + r.square_repeated(8) + r *= x11111111 + r.square_repeated(8) + r *= x11111111 + r.square_repeated(6) + + # 383+22 = 405 operations + r *= x110111 + r.square_repeated(10) + r *= x11111111 + r.square_repeated(9) + r *= x11111111 + + # 405+26 = 431 operations + r.square_repeated(8) + r *= x11111111 + r.square_repeated(8) + r *= x11111111 + r.square_repeated(8) + + # 431+19 = 450 operations + r *= x11111111 + r.square_repeated(7) + r *= x1010101 + r.square_repeated(9) + r *= x10101001 + + # Total 450 operations: + # - 74 multiplications + # - 376 squarings + # BN Curves # ------------------------------------------------------------ # Efficient Pairings and ECC for Embedded Systems @@ -117,16 +322,20 @@ func invmod_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) {.used.}= # = a^(36 u^4 + 36 u^3 + 24 u^2 + 6u + 1 - 2) mod p # = a^(36 u^4) . a^(36 u^3) . a^(24 u^2) . a^(6u-1) mod p # -# Note: it only works for u positive, in particular BN254 doesn't work :/ +# Note: it only works for u positive, in particular BN254_Nogami doesn't work :/ # Is there a way to only use a^-u or even powers? 
-func invmod_addchain_bn[C](r: var Fp[C], a: Fp[C]) {.used.}= +func inv_addchain_bn[C](r: var Fp[C], a: Fp[C]) {.used.}= ## Inversion on BN prime fields with positive base parameter `u` ## via Little Fermat theorem and leveraging the prime low Hamming weight ## ## Requires a `bn` curve with a positive parameter `u` # TODO: debug for input "0x0d2007d8aaface1b8501bfbe792974166e8f9ad6106e5b563604f0aea9ab06f6" - # see test suite + # on BN254_Snarks see test suite (but works in Sage so aliasing issue?) + # + # For BN254_Snarks `u` and `6u-1` exponentiation are not fast enough + # (even with dedicated addchains) + # compared to an addchain on the full prime modulus static: doAssert C.canUse_BN_AddchainInversion() var v0 {.noInit.}, v1 {.noInit.}: Fp[C] @@ -144,12 +353,159 @@ func invmod_addchain_bn[C](r: var Fp[C], a: Fp[C]) {.used.}= v1.powUnsafeExponent(C.getBN_param_u_BE()) # v1 <- a^(36u⁴) r *= v1 # r <- a^(36u⁴) a^(36u³) a^(24u²) a^(6u-1) = a^(p-2) = a^(-1) +func inv_addchain*(r: var Fp[BN254_Snarks], a: Fp[BN254_Snarks]) = + var + x10 {.noInit.}: Fp[BN254_Snarks] + x11 {.noInit.}: Fp[BN254_Snarks] + x101 {.noInit.}: Fp[BN254_Snarks] + x110 {.noInit.}: Fp[BN254_Snarks] + x1000 {.noInit.}: Fp[BN254_Snarks] + x1101 {.noInit.}: Fp[BN254_Snarks] + x10010 {.noInit.}: Fp[BN254_Snarks] + x10011 {.noInit.}: Fp[BN254_Snarks] + x10100 {.noInit.}: Fp[BN254_Snarks] + x10111 {.noInit.}: Fp[BN254_Snarks] + x11100 {.noInit.}: Fp[BN254_Snarks] + x100000 {.noInit.}: Fp[BN254_Snarks] + x100011 {.noInit.}: Fp[BN254_Snarks] + x101011 {.noInit.}: Fp[BN254_Snarks] + x101111 {.noInit.}: Fp[BN254_Snarks] + x1000001 {.noInit.}: Fp[BN254_Snarks] + x1010011 {.noInit.}: Fp[BN254_Snarks] + x1011011 {.noInit.}: Fp[BN254_Snarks] + x1100001 {.noInit.}: Fp[BN254_Snarks] + x1110101 {.noInit.}: Fp[BN254_Snarks] + x10010001 {.noInit.}: Fp[BN254_Snarks] + x10010101 {.noInit.}: Fp[BN254_Snarks] + x10110101 {.noInit.}: Fp[BN254_Snarks] + x10111011 {.noInit.}: Fp[BN254_Snarks] + x11000001 {.noInit.}: 
Fp[BN254_Snarks] + x11000011 {.noInit.}: Fp[BN254_Snarks] + x11010011 {.noInit.}: Fp[BN254_Snarks] + x11100001 {.noInit.}: Fp[BN254_Snarks] + x11100011 {.noInit.}: Fp[BN254_Snarks] + x11100111 {.noInit.}: Fp[BN254_Snarks] + + x10 .square(a) + x11 .prod(x10, a) + x101 .prod(x10, x11) + x110 .prod(x101, a) + x1000 .prod(x10, x110) + x1101 .prod(x101, x1000) + x10010 .prod(x101, x1101) + x10011 .prod(x10010, a) + x10100 .prod(x10011, a) + x10111 .prod(x11, x10100) + x11100 .prod(x101, x10111) + x100000 .prod(x1101, x10011) + x100011 .prod(x11, x100000) + x101011 .prod(x1000, x100011) + x101111 .prod(x10011, x11100) + x1000001 .prod(x10010, x101111) + x1010011 .prod(x10010, x1000001) + x1011011 .prod(x1000, x1010011) + x1100001 .prod(x110, x1011011) + x1110101 .prod(x10100, x1100001) + x10010001 .prod(x11100, x1110101) + x10010101 .prod(x100000, x1110101) + x10110101 .prod(x100000, x10010101) + x10111011 .prod(x110, x10110101) + x11000001 .prod(x110, x10111011) + x11000011 .prod(x10, x11000001) + x11010011 .prod(x10010, x11000001) + x11100001 .prod(x100000, x11000001) + x11100011 .prod(x10, x11100001) + x11100111 .prod(x110, x11100001) # 30 operations + + # 30 + 27 = 57 operations + r.square(x11000001) + r.square_repeated(7) + r *= x10010001 + r.square_repeated(10) + r *= x11100111 + r.square_repeated(7) + + # 57 + 19 = 76 operations + r *= x10111 + r.square_repeated(9) + r *= x10011 + r.square_repeated(7) + r *= x1101 + + # 76 + 33 = 109 operations + r.square_repeated(14) + r *= x1010011 + r.square_repeated(9) + r *= x11100001 + r.square_repeated(8) + + # 109 + 18 = 127 operations + r *= x1000001 + r.square_repeated(10) + r *= x1011011 + r.square_repeated(5) + r *= x1101 + + # 127 + 34 = 161 operations + r.square_repeated(8) + r *= x11 + r.square_repeated(12) + r *= x101011 + r.square_repeated(12) + + # 161 + 25 = 186 operations + r *= x10111011 + r.square_repeated(8) + r *= x101111 + r.square_repeated(14) + r *= x10110101 + + # 186 + 28 = 214 + r.square_repeated(9) + 
r *= x10010001 + r.square_repeated(5) + r *= x1101 + r.square_repeated(12) + + # 214 + 22 = 236 + r *= x11100011 + r.square_repeated(8) + r *= x10010101 + r.square_repeated(11) + r *= x11010011 + + # 236 + 32 = 268 + r.square_repeated(7) + r *= x1100001 + r.square_repeated(11) + r *= x100011 + r.square_repeated(12) + + # 268 + 20 = 288 + r *= x1011011 + r.square_repeated(9) + r *= x11000011 + r.square_repeated(8) + r *= x11100111 + + # 288 + 15 = 303 + r.square_repeated(7) + r *= x1110101 + r.square_repeated(6) + r *= x101 + # ############################################################ # # Dispatch # # ############################################################ +func inv_euclid*(r: var Fp, a: Fp) = + ## Inversion modulo p via + ## Niels Moller constant-time version of + ## Stein's GCD derived from extended binary Euclid algorithm + r.mres.steinsGCD(a.mres, Fp.C.getR2modP(), Fp.C.Mod, Fp.C.getPrimePlus1div2()) + func inv*(r: var Fp, a: Fp) = ## Inversion modulo p ## @@ -161,10 +517,10 @@ func inv*(r: var Fp, a: Fp) = # neither for Secp256k1 nor BN curves # Performance is slower than GCD # To be revisited with faster squaring/multiplications - when false: # Fp.C.canUse_BN_AddchainInversion(): - r.invmod_addchain_bn(a) + when Fp.C in {BN254_Snarks, BLS12_381}: + r.inv_addchain(a) else: - r.mres.steinsGCD(a.mres, Fp.C.getR2modP(), Fp.C.Mod, Fp.C.getPrimePlus1div2()) + r.inv_euclid(a) func inv*(a: var Fp) = ## Inversion modulo p @@ -174,9 +530,9 @@ func inv*(a: var Fp) = ## to convert Jacobian and Projective coordinates ## to affine for elliptic curve # For now we don't activate the addition chains - # neither for Secp256k1 nor BN curves + # for Secp256k1 or the generic BN routine (`inv_addchain_bn`) # Performance is slower than GCD - # To be revisited with faster squaring/multiplications - var t: typeof(a) # TODO: zero-init needed? - t.inv(a) - a = t + when Fp.C in {BN254_Snarks, BLS12_381}: + a.inv_addchain(a) + else: + a.inv_euclid(a)