Faster inversion with addition chains (#80)

This commit is contained in:
Mamy Ratsimbazafy 2020-09-04 19:04:32 +02:00 committed by GitHub
parent c2313ad697
commit 28e83e7b49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 404 additions and 32 deletions

View File

@ -160,14 +160,29 @@ proc invBench*(T: typedesc, iters: int) =
var r: T var r: T
let x = rng.random_unsafe(T) let x = rng.random_unsafe(T)
preventOptimAway(r) preventOptimAway(r)
bench("Inversion (constant-time Euclid)", T, iters): bench("Inversion (constant-time default method)", T, iters):
r.inv(x) r.inv(x)
proc powFermatInversionBench*(T: typedesc, iters: int) = proc invEuclidBench*(T: typedesc, iters: int) =
var r: T
let x = rng.random_unsafe(T) let x = rng.random_unsafe(T)
preventOptimAway(r)
bench("Inversion via constant-time Euclid", T, iters):
r.inv_euclid(x)
proc invPowFermatBench*(T: typedesc, iters: int) =
let x = rng.random_unsafe(T)
const exponent = T.C.getInvModExponent()
bench("Inversion via exponentiation p-2 (Little Fermat)", T, iters): bench("Inversion via exponentiation p-2 (Little Fermat)", T, iters):
var r = x var r = x
r.powUnsafeExponent(T.C.getInvModExponent()) r.powUnsafeExponent(exponent)
proc invAddChainBench*(T: typedesc, iters: int) =
var r: T
let x = rng.random_unsafe(T)
preventOptimAway(r)
bench("Inversion via addition chain", T, iters):
r.inv_addchain(x)
proc sqrtBench*(T: typedesc, iters: int) = proc sqrtBench*(T: typedesc, iters: int) =
let x = rng.random_unsafe(T) let x = rng.random_unsafe(T)

View File

@ -50,8 +50,9 @@ proc main() =
negBench(Fp[curve], Iters) negBench(Fp[curve], Iters)
mulBench(Fp[curve], Iters) mulBench(Fp[curve], Iters)
sqrBench(Fp[curve], Iters) sqrBench(Fp[curve], Iters)
invBench(Fp[curve], ExponentIters) invEuclidBench(Fp[curve], ExponentIters)
powFermatInversionBench(Fp[curve], ExponentIters) invPowFermatBench(Fp[curve], ExponentIters)
invAddChainBench(Fp[curve], ExponentIters)
sqrtBench(Fp[curve], ExponentIters) sqrtBench(Fp[curve], ExponentIters)
# Exponentiation by a "secret" of size ~the curve order # Exponentiation by a "secret" of size ~the curve order
powBench(Fp[curve], ExponentIters) powBench(Fp[curve], ExponentIters)

View File

@ -18,13 +18,14 @@ import
# ############################################################ # ############################################################
# Field-specific inversion routines # Field-specific inversion routines
template repeat(num: int, body: untyped) = func square_repeated(r: var Fp, num: int) =
## Repeated squarings
for _ in 0 ..< num: for _ in 0 ..< num:
body r.square()
# Secp256k1 # Secp256k1
# ------------------------------------------------------------ # ------------------------------------------------------------
func invmod_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) {.used.}= func inv_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) {.used.}=
## We invert via Little Fermat's theorem ## We invert via Little Fermat's theorem
## a^(-1) ≡ a^(p-2) (mod p) ## a^(-1) ≡ a^(p-2) (mod p)
## with p = "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F" ## with p = "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F"
@ -57,52 +58,256 @@ func invmod_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) {.used.}=
x3 *= a x3 *= a
x6 = x3 x6 = x3
repeat 3: x6.square() x6.square_repeated(3)
x6 *= x3 x6 *= x3
x9 = x6 x9 = x6
repeat 3: x9.square() x9.square_repeated(3)
x9 *= x3 x9 *= x3
x11 = x9 x11 = x9
repeat 2: x11.square() x11.square_repeated(2)
x11 *= x2 x11 *= x2
x22 = x11 x22 = x11
repeat 11: x22.square() x22.square_repeated(11)
x22 *= x11 x22 *= x11
x44 = x22 x44 = x22
repeat 22: x44.square() x44.square_repeated(22)
x44 *= x22 x44 *= x22
x88 = x44 x88 = x44
repeat 44: x88.square() x88.square_repeated(44)
x88 *= x44 x88 *= x44
x176 = x88 x176 = x88
repeat 88: x88.square() x88.square_repeated(88)
x176 *= x88 x176 *= x88
x220 = x176 x220 = x176
repeat 44: x220.square() x220.square_repeated(44)
x220 *= x44 x220 *= x44
x223 = x220 x223 = x220
repeat 3: x223.square() x223.square_repeated(3)
x223 *= x3 x223 *= x3
# The final result is then assembled using a sliding window over the blocks # The final result is then assembled using a sliding window over the blocks
r = x223 r = x223
repeat 23: r.square() r.square_repeated(23)
r *= x22 r *= x22
repeat 5: r.square() r.square_repeated(5)
r *= a r *= a
repeat 3: r.square() r.square_repeated(3)
r *= x2 r *= x2
repeat 2: r.square() r.square_repeated(2)
r *= a r *= a
# BLS12-381
# ------------------------------------------------------------
func inv_addchain*(r: var Fp[BLS12_381], a: Fp[BLS12_381]) =
var
x10 {.noinit.}: Fp[BLS12_381]
x100 {.noinit.}: Fp[BLS12_381]
x1000 {.noinit.}: Fp[BLS12_381]
x1001 {.noinit.}: Fp[BLS12_381]
x1011 {.noinit.}: Fp[BLS12_381]
x1101 {.noinit.}: Fp[BLS12_381]
x10001 {.noinit.}: Fp[BLS12_381]
x10100 {.noinit.}: Fp[BLS12_381]
x11001 {.noinit.}: Fp[BLS12_381]
x11010 {.noinit.}: Fp[BLS12_381]
x110100 {.noinit.}: Fp[BLS12_381]
x110110 {.noinit.}: Fp[BLS12_381]
x110111 {.noinit.}: Fp[BLS12_381]
x1001101 {.noinit.}: Fp[BLS12_381]
x1001111 {.noinit.}: Fp[BLS12_381]
x1010101 {.noinit.}: Fp[BLS12_381]
x1011101 {.noinit.}: Fp[BLS12_381]
x1100111 {.noinit.}: Fp[BLS12_381]
x1101001 {.noinit.}: Fp[BLS12_381]
x1110111 {.noinit.}: Fp[BLS12_381]
x1111011 {.noinit.}: Fp[BLS12_381]
x10001001 {.noinit.}: Fp[BLS12_381]
x10010101 {.noinit.}: Fp[BLS12_381]
x10010111 {.noinit.}: Fp[BLS12_381]
x10101001 {.noinit.}: Fp[BLS12_381]
x10110001 {.noinit.}: Fp[BLS12_381]
x10111111 {.noinit.}: Fp[BLS12_381]
x11000011 {.noinit.}: Fp[BLS12_381]
x11010000 {.noinit.}: Fp[BLS12_381]
x11010111 {.noinit.}: Fp[BLS12_381]
x11100001 {.noinit.}: Fp[BLS12_381]
x11100101 {.noinit.}: Fp[BLS12_381]
x11101011 {.noinit.}: Fp[BLS12_381]
x11110101 {.noinit.}: Fp[BLS12_381]
x11111111 {.noinit.}: Fp[BLS12_381]
x10 .square(a)
x100 .square(x10)
x1000 .square(x100)
x1001 .prod(a, x1000)
x1011 .prod(x10, x1001)
x1101 .prod(x10, x1011)
x10001 .prod(x100, x1101)
x10100 .prod(x1001, x1011)
x11001 .prod(x1000, x10001)
x11010 .prod(a, x11001)
x110100 .square(x11010)
x110110 .prod(x10, x110100)
x110111 .prod(a, x110110)
x1001101 .prod(x11001, x110100)
x1001111 .prod(x10, x1001101)
x1010101 .prod(x1000, x1001101)
x1011101 .prod(x1000, x1010101)
x1100111 .prod(x11010, x1001101)
x1101001 .prod(x10, x1100111)
x1110111 .prod(x11010, x1011101)
x1111011 .prod(x100, x1110111)
x10001001 .prod(x110100, x1010101)
x10010101 .prod(x11010, x1111011)
x10010111 .prod(x10, x10010101)
x10101001 .prod(x10100, x10010101)
x10110001 .prod(x1000, x10101001)
x10111111 .prod(x110110, x10001001)
x11000011 .prod(x100, x10111111)
x11010000 .prod(x1101, x11000011)
x11010111 .prod(x10100, x11000011)
x11100001 .prod(x10001, x11010000)
x11100101 .prod(x100, x11100001)
x11101011 .prod(x10100, x11010111)
x11110101 .prod(x10100, x11100001)
x11111111 .prod(x10100, x11101011) # 35 operations
# TODO: we can accumulate in a partially reduced
# doubled-size `r` to avoid the final substractions.
# and only reduce at the end.
# This requires the number of op to be less than log2(p) == 381
# 35 + 22 = 57 operations
r.prod(x10111111, x11100001)
r.square_repeated(8)
r *= x10001
r.square_repeated(11)
r *= x11110101
# 57 + 28 = 85 operations
r.square_repeated(11)
r *= x11100101
r.square_repeated(8)
r *= x11111111
r.square_repeated(7)
# 88 + 22 = 107 operations
r *= x1001101
r.square_repeated(9)
r *= x1101001
r.square_repeated(10)
r *= x10110001
# 107+24 = 131 operations
r.square_repeated(7)
r *= x1011101
r.square_repeated(9)
r *= x1111011
r.square_repeated(6)
# 131+23 = 154 operations
r *= x11001
r.square_repeated(11)
r *= x1101001
r.square_repeated(9)
r *= x11101011
# 154+28 = 182 operations
r.square_repeated(10)
r *= x11010111
r.square_repeated(6)
r *= x11001
r.square_repeated(10)
# 182+23 = 205 operations
r *= x1110111
r.square_repeated(9)
r *= x10010111
r.square_repeated(11)
r *= x1001111
# 205+30 = 235 operations
r.square_repeated(10)
r *= x11100001
r.square_repeated(9)
r *= x10001001
r.square_repeated(9)
# 235+21 = 256 operations
r *= x10111111
r.square_repeated(8)
r *= x1100111
r.square_repeated(10)
r *= x11000011
# 256+28 = 284 operations
r.square_repeated(9)
r *= x10010101
r.square_repeated(12)
r *= x1111011
r.square_repeated(5)
# 284 + 21 = 305 operations
r *= x1011
r.square_repeated(11)
r *= x1111011
r.square_repeated(7)
r *= x1001
# 305+32 = 337 operations
r.square_repeated(13)
r *= x11110101
r.square_repeated(9)
r *= x10111111
r.square_repeated(8)
# 337+22 = 359 operations
r *= x11111111
r.square_repeated(8)
r *= x11101011
r.square_repeated(11)
r *= x10101001
# 359+24 = 383 operations
r.square_repeated(8)
r *= x11111111
r.square_repeated(8)
r *= x11111111
r.square_repeated(6)
# 383+22 = 405 operations
r *= x110111
r.square_repeated(10)
r *= x11111111
r.square_repeated(9)
r *= x11111111
# 405+26 = 431 operations
r.square_repeated(8)
r *= x11111111
r.square_repeated(8)
r *= x11111111
r.square_repeated(8)
# 431+19 = 450 operations
r *= x11111111
r.square_repeated(7)
r *= x1010101
r.square_repeated(9)
r *= x10101001
# Total 450 operations:
# - 74 multiplications
# - 376 squarings
# BN Curves # BN Curves
# ------------------------------------------------------------ # ------------------------------------------------------------
# Efficient Pairings and ECC for Embedded Systems # Efficient Pairings and ECC for Embedded Systems
@ -117,16 +322,20 @@ func invmod_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) {.used.}=
# = a^(36 u^4 + 36 u^3 + 24 u^2 + 6u + 1 - 2) mod p # = a^(36 u^4 + 36 u^3 + 24 u^2 + 6u + 1 - 2) mod p
# = a^(36 u^4) . a^(36 u^3) . a^(24 u^2) . a^(6u-1) mod p # = a^(36 u^4) . a^(36 u^3) . a^(24 u^2) . a^(6u-1) mod p
# #
# Note: it only works for u positive, in particular BN254 doesn't work :/ # Note: it only works for u positive, in particular BN254_Nogami doesn't work :/
# Is there a way to only use a^-u or even powers? # Is there a way to only use a^-u or even powers?
func invmod_addchain_bn[C](r: var Fp[C], a: Fp[C]) {.used.}= func inv_addchain_bn[C](r: var Fp[C], a: Fp[C]) {.used.}=
## Inversion on BN prime fields with positive base parameter `u` ## Inversion on BN prime fields with positive base parameter `u`
## via Little Fermat theorem and leveraging the prime low Hamming weight ## via Little Fermat theorem and leveraging the prime low Hamming weight
## ##
## Requires a `bn` curve with a positive parameter `u` ## Requires a `bn` curve with a positive parameter `u`
# TODO: debug for input "0x0d2007d8aaface1b8501bfbe792974166e8f9ad6106e5b563604f0aea9ab06f6" # TODO: debug for input "0x0d2007d8aaface1b8501bfbe792974166e8f9ad6106e5b563604f0aea9ab06f6"
# see test suite # on BN254_Snarks see test suite (but works in Sage so aliasing issue?)
#
# For BN254_Snarks `u` and `6u-1` exponentiation are not fast enough
# (even with dedicated addchains)
# compared to an addchain on the full prime modulus
static: doAssert C.canUse_BN_AddchainInversion() static: doAssert C.canUse_BN_AddchainInversion()
var v0 {.noInit.}, v1 {.noInit.}: Fp[C] var v0 {.noInit.}, v1 {.noInit.}: Fp[C]
@ -144,12 +353,159 @@ func invmod_addchain_bn[C](r: var Fp[C], a: Fp[C]) {.used.}=
v1.powUnsafeExponent(C.getBN_param_u_BE()) # v1 <- a^(36u⁴) v1.powUnsafeExponent(C.getBN_param_u_BE()) # v1 <- a^(36u⁴)
r *= v1 # r <- a^(36u⁴) a^(36u³) a^(24u²) a^(6u-1) = a^(p-2) = a^(-1) r *= v1 # r <- a^(36u⁴) a^(36u³) a^(24u²) a^(6u-1) = a^(p-2) = a^(-1)
func inv_addchain*(r: var Fp[BN254_Snarks], a: Fp[BN254_Snarks]) =
var
x10 {.noInit.}: Fp[BN254_Snarks]
x11 {.noInit.}: Fp[BN254_Snarks]
x101 {.noInit.}: Fp[BN254_Snarks]
x110 {.noInit.}: Fp[BN254_Snarks]
x1000 {.noInit.}: Fp[BN254_Snarks]
x1101 {.noInit.}: Fp[BN254_Snarks]
x10010 {.noInit.}: Fp[BN254_Snarks]
x10011 {.noInit.}: Fp[BN254_Snarks]
x10100 {.noInit.}: Fp[BN254_Snarks]
x10111 {.noInit.}: Fp[BN254_Snarks]
x11100 {.noInit.}: Fp[BN254_Snarks]
x100000 {.noInit.}: Fp[BN254_Snarks]
x100011 {.noInit.}: Fp[BN254_Snarks]
x101011 {.noInit.}: Fp[BN254_Snarks]
x101111 {.noInit.}: Fp[BN254_Snarks]
x1000001 {.noInit.}: Fp[BN254_Snarks]
x1010011 {.noInit.}: Fp[BN254_Snarks]
x1011011 {.noInit.}: Fp[BN254_Snarks]
x1100001 {.noInit.}: Fp[BN254_Snarks]
x1110101 {.noInit.}: Fp[BN254_Snarks]
x10010001 {.noInit.}: Fp[BN254_Snarks]
x10010101 {.noInit.}: Fp[BN254_Snarks]
x10110101 {.noInit.}: Fp[BN254_Snarks]
x10111011 {.noInit.}: Fp[BN254_Snarks]
x11000001 {.noInit.}: Fp[BN254_Snarks]
x11000011 {.noInit.}: Fp[BN254_Snarks]
x11010011 {.noInit.}: Fp[BN254_Snarks]
x11100001 {.noInit.}: Fp[BN254_Snarks]
x11100011 {.noInit.}: Fp[BN254_Snarks]
x11100111 {.noInit.}: Fp[BN254_Snarks]
x10 .square(a)
x11 .prod(x10, a)
x101 .prod(x10, x11)
x110 .prod(x101, a)
x1000 .prod(x10, x110)
x1101 .prod(x101, x1000)
x10010 .prod(x101, x1101)
x10011 .prod(x10010, a)
x10100 .prod(x10011, a)
x10111 .prod(x11, x10100)
x11100 .prod(x101, x10111)
x100000 .prod(x1101, x10011)
x100011 .prod(x11, x100000)
x101011 .prod(x1000, x100011)
x101111 .prod(x10011, x11100)
x1000001 .prod(x10010, x101111)
x1010011 .prod(x10010, x1000001)
x1011011 .prod(x1000, x1010011)
x1100001 .prod(x110, x1011011)
x1110101 .prod(x10100, x1100001)
x10010001 .prod(x11100, x1110101)
x10010101 .prod(x100000, x1110101)
x10110101 .prod(x100000, x10010101)
x10111011 .prod(x110, x10110101)
x11000001 .prod(x110, x10111011)
x11000011 .prod(x10, x11000001)
x11010011 .prod(x10010, x11000001)
x11100001 .prod(x100000, x11000001)
x11100011 .prod(x10, x11100001)
x11100111 .prod(x110, x11100001) # 30 operations
# 30 + 27 = 57 operations
r.square(x11000001)
r.square_repeated(7)
r *= x10010001
r.square_repeated(10)
r *= x11100111
r.square_repeated(7)
# 57 + 19 = 76 operations
r *= x10111
r.square_repeated(9)
r *= x10011
r.square_repeated(7)
r *= x1101
# 76 + 33 = 109 operations
r.square_repeated(14)
r *= x1010011
r.square_repeated(9)
r *= x11100001
r.square_repeated(8)
# 109 + 18 = 127 operations
r *= x1000001
r.square_repeated(10)
r *= x1011011
r.square_repeated(5)
r *= x1101
# 127 + 34 = 161 operations
r.square_repeated(8)
r *= x11
r.square_repeated(12)
r *= x101011
r.square_repeated(12)
# 161 + 25 = 186 operations
r *= x10111011
r.square_repeated(8)
r *= x101111
r.square_repeated(14)
r *= x10110101
# 186 + 28 = 214
r.square_repeated(9)
r *= x10010001
r.square_repeated(5)
r *= x1101
r.square_repeated(12)
# 214 + 22 = 236
r *= x11100011
r.square_repeated(8)
r *= x10010101
r.square_repeated(11)
r *= x11010011
# 236 + 32 = 268
r.square_repeated(7)
r *= x1100001
r.square_repeated(11)
r *= x100011
r.square_repeated(12)
# 268 + 20 = 288
r *= x1011011
r.square_repeated(9)
r *= x11000011
r.square_repeated(8)
r *= x11100111
# 288 + 15 = 303
r.square_repeated(7)
r *= x1110101
r.square_repeated(6)
r *= x101
# ############################################################ # ############################################################
# #
# Dispatch # Dispatch
# #
# ############################################################ # ############################################################
func inv_euclid*(r: var Fp, a: Fp) =
## Inversion modulo p via
## Niels Moller constant-time version of
## Stein's GCD derived from extended binary Euclid algorithm
r.mres.steinsGCD(a.mres, Fp.C.getR2modP(), Fp.C.Mod, Fp.C.getPrimePlus1div2())
func inv*(r: var Fp, a: Fp) = func inv*(r: var Fp, a: Fp) =
## Inversion modulo p ## Inversion modulo p
## ##
@ -161,10 +517,10 @@ func inv*(r: var Fp, a: Fp) =
# neither for Secp256k1 nor BN curves # neither for Secp256k1 nor BN curves
# Performance is slower than GCD # Performance is slower than GCD
# To be revisited with faster squaring/multiplications # To be revisited with faster squaring/multiplications
when false: # Fp.C.canUse_BN_AddchainInversion(): when Fp.C in {BN254_Snarks, BLS12_381}:
r.invmod_addchain_bn(a) r.inv_addchain(a)
else: else:
r.mres.steinsGCD(a.mres, Fp.C.getR2modP(), Fp.C.Mod, Fp.C.getPrimePlus1div2()) r.inv_euclid(a)
func inv*(a: var Fp) = func inv*(a: var Fp) =
## Inversion modulo p ## Inversion modulo p
@ -174,9 +530,9 @@ func inv*(a: var Fp) =
## to convert Jacobian and Projective coordinates ## to convert Jacobian and Projective coordinates
## to affine for elliptic curve ## to affine for elliptic curve
# For now we don't activate the addition chains # For now we don't activate the addition chains
# neither for Secp256k1 nor BN curves # for Secp256k1 nor BN curves
# Performance is slower than GCD # Performance is slower than GCD
# To be revisited with faster squaring/multiplications when Fp.C in {BN254_Snarks, BLS12_381}:
var t: typeof(a) # TODO: zero-init needed? a.inv_addchain(a)
t.inv(a) else:
a = t a.inv_euclid(a)