From d22d981e9e5cff1762df09cc6e96a926fa033fd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Wed, 17 Jun 2020 22:44:52 +0200 Subject: [PATCH] Implement fused sqrt invsqrt on Fp: Accelerate sqrt on Fp2 by 20% (hashToG2 and property-based testing bottleneck, 4 times slower than inversion and 87 times slower than Fp2 multiplication) --- benchmarks/bench_fields_template.nim | 34 ++++++- benchmarks/bench_fp.nim | 30 +++--- benchmarks/bench_fp2.nim | 3 +- constantine/arithmetic/bigints.nim | 26 ++++- constantine/arithmetic/finite_fields.nim | 96 ++++++++++++++++--- .../exponentiations.nim | 11 +-- 6 files changed, 161 insertions(+), 39 deletions(-) diff --git a/benchmarks/bench_fields_template.nim b/benchmarks/bench_fields_template.nim index 6227887..9135c23 100644 --- a/benchmarks/bench_fields_template.nim +++ b/benchmarks/bench_fields_template.nim @@ -71,15 +71,15 @@ when SupportsGetTicks: echo "\n=================================================================================================================\n" proc separator*() = - echo "-".repeat(110) + echo "-".repeat(145) proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) = let ns = inNanoseconds((stop-start) div iters) let throughput = 1e9 / float64(ns) when SupportsGetTicks: - echo &"{op:<15} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)" + echo &"{op:<50} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)" else: - echo &"{op:<15} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op" + echo &"{op:<50} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op" macro fixFieldDisplay(T: typedesc): untyped = # At compile-time, enums are integers and their display is buggy @@ -143,5 +143,31 @@ proc invBench*(T: typedesc, iters: int) = var r: T let x = rng.random_unsafe(T) preventOptimAway(r) - bench("Inversion", T, iters): + bench("Inversion (constant-time Euclid)", T, iters): r.inv(x) + +proc powFermatInversionBench*(T: typedesc, iters: int) = + let x = rng.random_unsafe(T) + bench("Inversion via exponentiation p-2 (Little Fermat)", T, iters): + var r = x + r.powUnsafeExponent(T.C.getInvModExponent()) + +proc sqrtBench*(T: typedesc, iters: int) = + let x = rng.random_unsafe(T) + bench("Square Root + square check (constant-time)", T, iters): + var r = x + discard r.sqrt_if_square() + +proc powBench*(T: typedesc, iters: int) = + let x = rng.random_unsafe(T) + let exponent = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()]) + bench("Exp curve order (constant-time) - " & $exponent.bits & "-bit", T, iters): + var r = x + r.pow(exponent) + +proc powUnsafeBench*(T: typedesc, iters: int) = + let x = rng.random_unsafe(T) + let exponent = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()]) + bench("Exp curve order (Leak exponent bits) - " & $exponent.bits & "-bit", T, iters): + var r = x + r.powUnsafeExponent(exponent) diff --git a/benchmarks/bench_fp.nim b/benchmarks/bench_fp.nim index 2d2b622..5ec224d 100644 --- a/benchmarks/bench_fp.nim +++ b/benchmarks/bench_fp.nim @@ -10,6 +10,7 @@ import # Internals ../constantine/config/curves, ../constantine/arithmetic, + ../constantine/io/io_bigints, # Helpers ../helpers/static_for, ./bench_fields_template, @@ -24,20 +25,20 @@ import const Iters = 1_000_000 -const InvIters = 1000 +const ExponentIters = 1000 const AvailableCurves = [ - P224, - BN254_Nogami, + # P224, + # BN254_Nogami, BN254_Snarks, - Curve25519, - P256, - Secp256k1, - BLS12_377, + # Curve25519, + # P256, + # Secp256k1, + # BLS12_377, BLS12_381, - BN446, - FKM12_447, - BLS12_461, - BN462 + # BN446, + # FKM12_447, + # BLS12_461, + # BN462 ] proc main() = @@ -49,7 +50,12 @@ proc main() = negBench(Fp[curve], Iters) mulBench(Fp[curve], Iters) sqrBench(Fp[curve], Iters) - invBench(Fp[curve], InvIters) + invBench(Fp[curve], ExponentIters) + powFermatInversionBench(Fp[curve], ExponentIters) + sqrtBench(Fp[curve], ExponentIters) + # Exponentiation by a "secret" of size ~the curve order + powBench(Fp[curve], ExponentIters) + powUnsafeBench(Fp[curve], ExponentIters) separator() main() diff --git a/benchmarks/bench_fp2.nim b/benchmarks/bench_fp2.nim index ef3db2c..db5579e 100644 --- a/benchmarks/bench_fp2.nim +++ b/benchmarks/bench_fp2.nim @@ -29,7 +29,7 @@ const AvailableCurves = [ # Pairing-Friendly curves # BN254_Nogami, BN254_Snarks, - BLS12_377, + # BLS12_377, BLS12_381 # BN446, # FKM12_447, @@ -47,6 +47,7 @@ proc main() = mulBench(Fp2[curve], Iters) sqrBench(Fp2[curve], Iters) invBench(Fp2[curve], InvIters) + sqrtBench(Fp2[curve], InvIters) separator() main() diff --git a/constantine/arithmetic/bigints.nim b/constantine/arithmetic/bigints.nim index 88654ff..da2106a 100644 --- a/constantine/arithmetic/bigints.nim +++ b/constantine/arithmetic/bigints.nim @@ -434,6 +434,30 @@ func montyPowUnsafeExponent*[mBits, eBits: static int]( var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] montyPowUnsafeExponent(a.limbs, expBE, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare) +func montyPow*[mBits: static int]( + a: var BigInt[mBits], exponent: openarray[byte], + M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, + canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + ) = + ## Compute a <- a^exponent (mod M) + ## ``a`` in the Montgomery domain + ## ``exponent`` is a BigInt in canonical big-endian representation + ## + ## Warning ⚠️ : + ## This is an optimization for public exponent + ## Otherwise bits of the exponent can be retrieved with: + ## - memory access analysis + ## - power analysis + ## - timing analysis + ## + ## This uses fixed window optimization + ## A window size in the range [1, 5] must be chosen + + const scratchLen = if windowSize == 1: 2 + else: (1 shl windowSize) + 1 + var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] + montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + func montyPowUnsafeExponent*[mBits: static int]( a: var BigInt[mBits], exponent: openarray[byte], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, @@ -441,7 +465,7 @@ func montyPowUnsafeExponent*[mBits: static int]( ) = ## Compute a <- a^exponent (mod M) ## ``a`` in the Montgomery domain - ## ``exponent`` is a BigInt in canonical representation + ## ``exponent`` is a BigInt in canonical big-endian representation ## ## Warning ⚠️ : ## This is an optimization for public exponent diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index 4ea0d3d..97b59a4 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -191,6 +191,19 @@ func pow*(a: var Fp, exponent: BigInt) = Fp.C.canUseNoCarryMontySquare() ) +func pow*(a: var Fp, exponent: openarray[byte]) = + ## Exponentiation modulo p + ## ``a``: a field element to be exponentiated + ## ``exponent``: a big integer in canonical big endian representation + const windowSize = 5 # TODO: find best window size for each curves + a.mres.montyPow( + exponent, + Fp.C.Mod, Fp.C.getMontyOne(), + Fp.C.getNegInvModWord(), windowSize, + Fp.C.canUseNoCarryMontyMul(), + Fp.C.canUseNoCarryMontySquare() + ) + func powUnsafeExponent*(a: var Fp, exponent: BigInt) = ## Exponentiation modulo p ## ``a``: a field element to be exponentiated @@ -214,7 +227,7 @@ func powUnsafeExponent*(a: var Fp, exponent: BigInt) = func powUnsafeExponent*(a: var Fp, exponent: openarray[byte]) = ## Exponentiation modulo p ## ``a``: a field element to be exponentiated - ## ``exponent``: a big integer + ## ``exponent``: a big integer a big integer in canonical big endian representation ## ## Warning ⚠️ : ## This is an optimization for public exponent @@ -241,7 +254,7 @@ func isSquare*[C](a: Fp[C]): SecretBool = ## Returns true if ``a`` is a square (quadratic residue) in 𝔽p ## ## Assumes that the prime modulus ``p`` is public. - # Implementation: we use exponentiation by (p-1)/2 (Euler(s criterion) + # Implementation: we use exponentiation by (p-1)/2 (Euler's criterion) # as it can reuse the exponentiation implementation # Note that we don't care about leaking the bits of p # as we assume that @@ -267,6 +280,39 @@ func sqrt_p3mod4[C](a: var Fp[C]) = static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3 a.powUnsafeExponent(C.getPrimePlus1div4_BE()) +func sqrt_invsqrt_p3mod4[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) = + ## If ``a`` is a square, compute the square root of ``a`` in sqrt + ## and the inverse square root of a in invsqrt + ## + ## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4) + # TODO: deterministic sign + # + # Algorithm + # + # + # From Euler's criterion: a^((p-1)/2)) ≡ 1 (mod p) if square + # a^((p-1)/2)) * a^-1 ≡ 1/a (mod p) + # a^((p-3)/2)) ≡ 1/a (mod p) + # a^((p-3)/4)) ≡ 1/√a (mod p) # Requires p ≡ 3 (mod 4) + static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3 + + invsqrt = a + invsqrt.powUnsafeExponent(C.getPrimeMinus3div4_BE()) + # √a ≡ a * 1/√a ≡ a^((p+1)/4) (mod p) + sqrt.prod(invsqrt, a) + +func sqrt_invsqrt_if_square_p3mod4[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool = + ## If ``a`` is a square, compute the square root of ``a`` in sqrt + ## and the inverse square root of a in invsqrt + ## + ## If a is not square, sqrt and invsqrt are undefined + ## + ## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4) + sqrt_invsqrt_p3mod4(sqrt, invsqrt, a) + var euler {.noInit.}: Fp[C] + euler.prod(sqrt, invsqrt) + result = not(euler.mres == C.getMontyPrimeMinus1()) + func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool = ## If ``a`` is a square, compute the square root of ``a`` ## if not, ``a`` is unmodified. @@ -278,19 +324,9 @@ func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool = ## The square root, if it exist is multivalued, ## i.e. both x² == (-x)² ## This procedure returns a deterministic result - static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3 - - var a1 {.noInit.} = a - a1.powUnsafeExponent(C.getPrimeMinus3div4_BE()) - - var a1a {.noInit.}: Fp[C] - a1a.prod(a1, a) - - var a0 {.noInit.}: Fp[C] - a0.prod(a1a, a1) - - result = not(a0.mres == C.getMontyPrimeMinus1()) - a.ccopy(a1a, result) + var sqrt {.noInit.}, invsqrt {.noInit.}: Fp[C] + result = sqrt_invsqrt_if_square_p3mod4(sqrt, invsqrt, a) + a.ccopy(sqrt, result) func sqrt*[C](a: var Fp[C]) = ## Compute the square root of ``a`` @@ -319,6 +355,36 @@ func sqrt_if_square*[C](a: var Fp[C]): SecretBool = else: {.error: "Square root is only implemented for p ≡ 3 (mod 4)".} +func sqrt_invsqrt*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) = + ## Compute the square root and inverse square root of ``a`` + ## + ## This requires ``a`` to be a square + ## + ## The result is undefined otherwise + ## + ## The square root, if it exist is multivalued, + ## i.e. both x² == (-x)² + ## This procedure returns a deterministic result + when BaseType(C.Mod.limbs[0]) mod 4 == 3: + sqrt_invsqrt_p3mod4(sqrt, invsqrt, a) + else: + {.error: "Square root is only implemented for p ≡ 3 (mod 4)".} + +func sqrt_invsqrt_if_square*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool = + ## Compute the square root and ivnerse square root of ``a`` + ## + ## This returns true if ``a`` is square and sqrt/invsqrt contains the square root/inverse square root + ## + ## The result is undefined otherwise + ## + ## The square root, if it exist is multivalued, + ## i.e. both x² == (-x)² + ## This procedure returns a deterministic result + when BaseType(C.Mod.limbs[0]) mod 4 == 3: + result = sqrt_invsqrt_if_square_p3mod4(sqrt, invsqrt, a) + else: + {.error: "Square root is only implemented for p ≡ 3 (mod 4)".} + # ############################################################ # # Field arithmetic ergonomic primitives diff --git a/constantine/tower_field_extensions/exponentiations.nim b/constantine/tower_field_extensions/exponentiations.nim index 2c8c7ea..955af3d 100644 --- a/constantine/tower_field_extensions/exponentiations.nim +++ b/constantine/tower_field_extensions/exponentiations.nim @@ -214,10 +214,9 @@ func sqrt_if_square*(a: var QuadraticExt): SecretBool = let quadResidTest = t2.isSquare() t2.ccopy(t3, not quadResidTest) - t2.sqrt() - a.c0.ccopy(t2, result) + sqrt_invsqrt(sqrt = t1, invsqrt = t3, t2) + a.c0.ccopy(t1, result) - t2.double() - t1.inv(t2) - t1 *= a.c1 - a.c1.ccopy(t1, result) + t3.div2() + t3 *= a.c1 + a.c1.ccopy(t3, result)