Implement fused sqrt invsqrt on Fp: Accelerate sqrt on Fp2 by 20% (hashToG2 and property-based testing bottleneck, 4 times slower than inversion and 87 times slower than Fp2 multiplication)

This commit is contained in:
Mamy André-Ratsimbazafy 2020-06-17 22:44:52 +02:00
parent 53c94e8aab
commit d22d981e9e
No known key found for this signature in database
GPG Key ID: 7B88AD1FE79492E1
6 changed files with 161 additions and 39 deletions

View File

@ -71,15 +71,15 @@ when SupportsGetTicks:
echo "\n=================================================================================================================\n"
proc separator*() =
echo "-".repeat(110)
echo "-".repeat(145)
proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) =
let ns = inNanoseconds((stop-start) div iters)
let throughput = 1e9 / float64(ns)
when SupportsGetTicks:
echo &"{op:<15} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
echo &"{op:<50} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
else:
echo &"{op:<15} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"
echo &"{op:<50} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"
macro fixFieldDisplay(T: typedesc): untyped =
# At compile-time, enums are integers and their display is buggy
@ -143,5 +143,31 @@ proc invBench*(T: typedesc, iters: int) =
var r: T
let x = rng.random_unsafe(T)
preventOptimAway(r)
bench("Inversion", T, iters):
bench("Inversion (constant-time Euclid)", T, iters):
r.inv(x)
proc powFermatInversionBench*(T: typedesc, iters: int) =
  ## Benchmark inversion in ``T`` done as an exponentiation by the
  ## exponent returned by ``T.C.getInvModExponent()`` (reported as
  ## "p-2", Little Fermat), using ``powUnsafeExponent``.
  ##
  ## The random input is drawn once, outside the benchmarked body.
  let sample = rng.random_unsafe(T)
  bench("Inversion via exponentiation p-2 (Little Fermat)", T, iters):
    # The exponentiation is in-place, so work on a copy of the input
    var acc = sample
    powUnsafeExponent(acc, T.C.getInvModExponent())
proc sqrtBench*(T: typedesc, iters: int) =
  ## Benchmark ``sqrt_if_square`` on ``T``: square root computation
  ## together with the "is it a square" check.
  ##
  ## The random input is drawn once, outside the benchmarked body;
  ## it is not checked to be a square beforehand.
  let sample = rng.random_unsafe(T)
  bench("Square Root + square check (constant-time)", T, iters):
    # ``sqrt_if_square`` works in-place, so operate on a copy
    var acc = sample
    discard sqrt_if_square(acc)
proc powBench*(T: typedesc, iters: int) =
  ## Benchmark ``pow`` on ``T`` with a random exponent whose bit width
  ## is ``T.C.getCurveOrderBitwidth()``.
  ##
  ## The base and the exponent are drawn once, outside the
  ## benchmarked body. The exponent bit width is part of the label.
  let base = rng.random_unsafe(T)
  let secret = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()])
  let label = "Exp curve order (constant-time) - " & $secret.bits & "-bit"
  bench(label, T, iters):
    # ``pow`` works in-place, so operate on a copy of the base
    var acc = base
    pow(acc, secret)
proc powUnsafeBench*(T: typedesc, iters: int) =
  ## Benchmark ``powUnsafeExponent`` on ``T`` with a random exponent
  ## whose bit width is ``T.C.getCurveOrderBitwidth()``.
  ##
  ## The base and the exponent are drawn once, outside the
  ## benchmarked body. The exponent bit width is part of the label.
  let base = rng.random_unsafe(T)
  let publicExp = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()])
  let label = "Exp curve order (Leak exponent bits) - " & $publicExp.bits & "-bit"
  bench(label, T, iters):
    # ``powUnsafeExponent`` works in-place, so operate on a copy
    var acc = base
    powUnsafeExponent(acc, publicExp)

View File

@ -10,6 +10,7 @@ import
# Internals
../constantine/config/curves,
../constantine/arithmetic,
../constantine/io/io_bigints,
# Helpers
../helpers/static_for,
./bench_fields_template,
@ -24,20 +25,20 @@ import
const Iters = 1_000_000
const InvIters = 1000
const ExponentIters = 1000
const AvailableCurves = [
P224,
BN254_Nogami,
# P224,
# BN254_Nogami,
BN254_Snarks,
Curve25519,
P256,
Secp256k1,
BLS12_377,
# Curve25519,
# P256,
# Secp256k1,
# BLS12_377,
BLS12_381,
BN446,
FKM12_447,
BLS12_461,
BN462
# BN446,
# FKM12_447,
# BLS12_461,
# BN462
]
proc main() =
@ -49,7 +50,12 @@ proc main() =
negBench(Fp[curve], Iters)
mulBench(Fp[curve], Iters)
sqrBench(Fp[curve], Iters)
invBench(Fp[curve], InvIters)
invBench(Fp[curve], ExponentIters)
powFermatInversionBench(Fp[curve], ExponentIters)
sqrtBench(Fp[curve], ExponentIters)
# Exponentiation by a "secret" of size ~the curve order
powBench(Fp[curve], ExponentIters)
powUnsafeBench(Fp[curve], ExponentIters)
separator()
main()

View File

@ -29,7 +29,7 @@ const AvailableCurves = [
# Pairing-Friendly curves
# BN254_Nogami,
BN254_Snarks,
BLS12_377,
# BLS12_377,
BLS12_381
# BN446,
# FKM12_447,
@ -47,6 +47,7 @@ proc main() =
mulBench(Fp2[curve], Iters)
sqrBench(Fp2[curve], Iters)
invBench(Fp2[curve], InvIters)
sqrtBench(Fp2[curve], InvIters)
separator()
main()

View File

@ -434,6 +434,30 @@ func montyPowUnsafeExponent*[mBits, eBits: static int](
var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]
montyPowUnsafeExponent(a.limbs, expBE, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare)
func montyPow*[mBits: static int](
       a: var BigInt[mBits], exponent: openarray[byte],
       M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
       canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool
      ) =
  ## Compute a <- a^exponent (mod M), in-place
  ##
  ## - ``a`` is in the Montgomery domain
  ## - ``exponent`` is a big integer given as bytes,
  ##   in canonical big-endian representation
  ##
  ## Warning ⚠️ :
  ##   This is an optimization for public exponent
  ##   Otherwise bits of the exponent can be retrieved with:
  ##   - memory access analysis
  ##   - power analysis
  ##   - timing analysis
  ##
  ## This uses fixed window optimization
  ## A window size in the range [1, 5] must be chosen
  #
  # NOTE(review): the warning above is word-for-word the one carried by
  # ``montyPowUnsafeExponent``, while the benchmarks label ``pow`` as
  # constant-time. Confirm whether it really applies to this routine.

  # Scratch buffer handed to the limbs-level routine:
  # 2 entries for a window of 1, otherwise 2^windowSize + 1 entries.
  const scratchLen =
    if windowSize == 1:
      2
    else:
      (1 shl windowSize) + 1
  var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]

  montyPow(
    a.limbs, exponent,
    M.limbs, one.limbs,
    negInvModWord, scratchSpace,
    canUseNoCarryMontyMul, canUseNoCarryMontySquare
  )
func montyPowUnsafeExponent*[mBits: static int](
a: var BigInt[mBits], exponent: openarray[byte],
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
@ -441,7 +465,7 @@ func montyPowUnsafeExponent*[mBits: static int](
) =
## Compute a <- a^exponent (mod M)
## ``a`` in the Montgomery domain
## ``exponent`` is a BigInt in canonical representation
## ``exponent`` is a BigInt in canonical big-endian representation
##
## Warning ⚠️ :
## This is an optimization for public exponent

View File

@ -191,6 +191,19 @@ func pow*(a: var Fp, exponent: BigInt) =
Fp.C.canUseNoCarryMontySquare()
)
func pow*(a: var Fp, exponent: openarray[byte]) =
  ## Exponentiation modulo p, in-place: a <- a^exponent
  ##
  ## - ``a``: the field element to be exponentiated
  ## - ``exponent``: a big integer given as bytes,
  ##   in canonical big-endian representation
  ##
  ## Forwards to ``montyPow`` on the Montgomery representation of ``a``
  ## with the curve constants of ``Fp.C`` and a fixed window of 5.
  const windowSize = 5 # TODO: find best window size for each curves
  montyPow(
    a.mres, exponent,
    Fp.C.Mod, Fp.C.getMontyOne(),
    Fp.C.getNegInvModWord(), windowSize,
    Fp.C.canUseNoCarryMontyMul(),
    Fp.C.canUseNoCarryMontySquare()
  )
func powUnsafeExponent*(a: var Fp, exponent: BigInt) =
## Exponentiation modulo p
## ``a``: a field element to be exponentiated
@ -214,7 +227,7 @@ func powUnsafeExponent*(a: var Fp, exponent: BigInt) =
func powUnsafeExponent*(a: var Fp, exponent: openarray[byte]) =
## Exponentiation modulo p
## ``a``: a field element to be exponentiated
## ``exponent``: a big integer
## ``exponent``: a big integer in canonical big endian representation
##
## Warning ⚠️ :
## This is an optimization for public exponent
@ -241,7 +254,7 @@ func isSquare*[C](a: Fp[C]): SecretBool =
## Returns true if ``a`` is a square (quadratic residue) in 𝔽p
##
## Assumes that the prime modulus ``p`` is public.
# Implementation: we use exponentiation by (p-1)/2 (Euler(s criterion)
# Implementation: we use exponentiation by (p-1)/2 (Euler's criterion)
# as it can reuse the exponentiation implementation
# Note that we don't care about leaking the bits of p
# as we assume that
@ -267,6 +280,39 @@ func sqrt_p3mod4[C](a: var Fp[C]) =
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
a.powUnsafeExponent(C.getPrimePlus1div4_BE())
func sqrt_invsqrt_p3mod4[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
  ## Fused square root and inverse square root.
  ##
  ## If ``a`` is a square, ``sqrt`` receives a square root of ``a``
  ## and ``invsqrt`` receives the inverse of that square root.
  ## A single exponentiation is shared by both outputs.
  ##
  ## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4)
  ## (checked at compile-time).
  # TODO: deterministic sign
  #
  # Algorithm
  #
  # From Euler's criterion, if ``a`` is a square:
  #   a^((p-1)/2)        ≡ 1    (mod p)
  #   a^((p-1)/2) * a^-1 ≡ 1/a  (mod p)
  #   a^((p-3)/2)        ≡ 1/a  (mod p)
  #   a^((p-3)/4)        ≡ 1/√a (mod p)   # Requires p ≡ 3 (mod 4)
  static:
    doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3

  # 1/√a via exponentiation by (p-3)/4
  invsqrt = a
  powUnsafeExponent(invsqrt, C.getPrimeMinus3div4_BE())

  # √a ≡ a * 1/√a ≡ a^((p+1)/4) (mod p)
  prod(sqrt, invsqrt, a)
func sqrt_invsqrt_if_square_p3mod4[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
  ## Fused square root, inverse square root and square check.
  ##
  ## If ``a`` is a square, ``sqrt`` receives a square root of ``a``
  ## and ``invsqrt`` the inverse of that square root.
  ## If ``a`` is not a square, ``sqrt`` and ``invsqrt`` are undefined.
  ##
  ## The result is false only when the check value below compares equal
  ## to ``C.getMontyPrimeMinus1()``, and true otherwise.
  ##
  ## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4)
  sqrt_invsqrt_p3mod4(sqrt, invsqrt, a)

  # The product of both outputs is
  #   a^((p+1)/4) * a^((p-3)/4) = a^((p-1)/2)
  # i.e. the value tested by Euler's criterion, obtained without
  # an extra exponentiation.
  var criterion {.noInit.}: Fp[C]
  prod(criterion, sqrt, invsqrt)

  # NOTE(review): presumably ``getMontyPrimeMinus1`` is p-1 (i.e. -1)
  # in Montgomery form - confirm against its definition.
  result = not(criterion.mres == C.getMontyPrimeMinus1())
func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool =
## If ``a`` is a square, compute the square root of ``a``
## if not, ``a`` is unmodified.
@ -278,19 +324,9 @@ func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool =
## The square root, if it exist is multivalued,
## i.e. both x² == (-x)²
## This procedure returns a deterministic result
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
var a1 {.noInit.} = a
a1.powUnsafeExponent(C.getPrimeMinus3div4_BE())
var a1a {.noInit.}: Fp[C]
a1a.prod(a1, a)
var a0 {.noInit.}: Fp[C]
a0.prod(a1a, a1)
result = not(a0.mres == C.getMontyPrimeMinus1())
a.ccopy(a1a, result)
var sqrt {.noInit.}, invsqrt {.noInit.}: Fp[C]
result = sqrt_invsqrt_if_square_p3mod4(sqrt, invsqrt, a)
a.ccopy(sqrt, result)
func sqrt*[C](a: var Fp[C]) =
## Compute the square root of ``a``
@ -319,6 +355,36 @@ func sqrt_if_square*[C](a: var Fp[C]): SecretBool =
else:
{.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
func sqrt_invsqrt*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
  ## Compute both the square root and the inverse square root of ``a``,
  ## stored respectively in ``sqrt`` and ``invsqrt``.
  ##
  ## This requires ``a`` to be a square.
  ## The result is undefined otherwise.
  ##
  ## The square root, if it exists, is multivalued,
  ## i.e. both x and -x square to the same value.
  ## This procedure returns a deterministic result
  ##
  ## Only moduli with p ≡ 3 (mod 4) are supported,
  ## any other modulus is rejected at compile-time.
  when BaseType(C.Mod.limbs[0]) mod 4 != 3:
    {.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
  else:
    sqrt_invsqrt_p3mod4(sqrt, invsqrt, a)
func sqrt_invsqrt_if_square*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
  ## Compute the square root and inverse square root of ``a``
  ## and report whether ``a`` is a square.
  ##
  ## This returns true if ``a`` is square, in which case
  ## ``sqrt``/``invsqrt`` contain the square root/inverse square root.
  ##
  ## The content of ``sqrt`` and ``invsqrt`` is undefined otherwise.
  ##
  ## The square root, if it exists, is multivalued,
  ## i.e. both x and -x square to the same value.
  ## This procedure returns a deterministic result
  ##
  ## Only moduli with p ≡ 3 (mod 4) are supported,
  ## any other modulus is rejected at compile-time.
  when BaseType(C.Mod.limbs[0]) mod 4 != 3:
    {.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
  else:
    result = sqrt_invsqrt_if_square_p3mod4(sqrt, invsqrt, a)
# ############################################################
#
# Field arithmetic ergonomic primitives

View File

@ -214,10 +214,9 @@ func sqrt_if_square*(a: var QuadraticExt): SecretBool =
let quadResidTest = t2.isSquare()
t2.ccopy(t3, not quadResidTest)
t2.sqrt()
a.c0.ccopy(t2, result)
sqrt_invsqrt(sqrt = t1, invsqrt = t3, t2)
a.c0.ccopy(t1, result)
t2.double()
t1.inv(t2)
t1 *= a.c1
a.c1.ccopy(t1, result)
t3.div2()
t3 *= a.c1
a.c1.ccopy(t3, result)