mirror of
https://github.com/codex-storage/constantine.git
synced 2025-01-11 19:44:10 +00:00
Implement fused sqrt invsqrt on Fp: Accelerate sqrt on Fp2 by 20% (hashToG2 and property-based testing bottleneck, 4 times slower than inversion and 87 times slower than Fp2 multiplication)
This commit is contained in:
parent
53c94e8aab
commit
d22d981e9e
@ -71,15 +71,15 @@ when SupportsGetTicks:
|
||||
echo "\n=================================================================================================================\n"
|
||||
|
||||
proc separator*() =
|
||||
echo "-".repeat(110)
|
||||
echo "-".repeat(145)
|
||||
|
||||
proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) =
|
||||
let ns = inNanoseconds((stop-start) div iters)
|
||||
let throughput = 1e9 / float64(ns)
|
||||
when SupportsGetTicks:
|
||||
echo &"{op:<15} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
|
||||
echo &"{op:<50} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
|
||||
else:
|
||||
echo &"{op:<15} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"
|
||||
echo &"{op:<50} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"
|
||||
|
||||
macro fixFieldDisplay(T: typedesc): untyped =
|
||||
# At compile-time, enums are integers and their display is buggy
|
||||
@ -143,5 +143,31 @@ proc invBench*(T: typedesc, iters: int) =
|
||||
var r: T
|
||||
let x = rng.random_unsafe(T)
|
||||
preventOptimAway(r)
|
||||
bench("Inversion", T, iters):
|
||||
bench("Inversion (constant-time Euclid)", T, iters):
|
||||
r.inv(x)
|
||||
|
||||
proc powFermatInversionBench*(T: typedesc, iters: int) =
|
||||
let x = rng.random_unsafe(T)
|
||||
bench("Inversion via exponentiation p-2 (Little Fermat)", T, iters):
|
||||
var r = x
|
||||
r.powUnsafeExponent(T.C.getInvModExponent())
|
||||
|
||||
proc sqrtBench*(T: typedesc, iters: int) =
|
||||
let x = rng.random_unsafe(T)
|
||||
bench("Square Root + square check (constant-time)", T, iters):
|
||||
var r = x
|
||||
discard r.sqrt_if_square()
|
||||
|
||||
proc powBench*(T: typedesc, iters: int) =
|
||||
let x = rng.random_unsafe(T)
|
||||
let exponent = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()])
|
||||
bench("Exp curve order (constant-time) - " & $exponent.bits & "-bit", T, iters):
|
||||
var r = x
|
||||
r.pow(exponent)
|
||||
|
||||
proc powUnsafeBench*(T: typedesc, iters: int) =
|
||||
let x = rng.random_unsafe(T)
|
||||
let exponent = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()])
|
||||
bench("Exp curve order (Leak exponent bits) - " & $exponent.bits & "-bit", T, iters):
|
||||
var r = x
|
||||
r.powUnsafeExponent(exponent)
|
||||
|
@ -10,6 +10,7 @@ import
|
||||
# Internals
|
||||
../constantine/config/curves,
|
||||
../constantine/arithmetic,
|
||||
../constantine/io/io_bigints,
|
||||
# Helpers
|
||||
../helpers/static_for,
|
||||
./bench_fields_template,
|
||||
@ -24,20 +25,20 @@ import
|
||||
|
||||
|
||||
const Iters = 1_000_000
|
||||
const InvIters = 1000
|
||||
const ExponentIters = 1000
|
||||
const AvailableCurves = [
|
||||
P224,
|
||||
BN254_Nogami,
|
||||
# P224,
|
||||
# BN254_Nogami,
|
||||
BN254_Snarks,
|
||||
Curve25519,
|
||||
P256,
|
||||
Secp256k1,
|
||||
BLS12_377,
|
||||
# Curve25519,
|
||||
# P256,
|
||||
# Secp256k1,
|
||||
# BLS12_377,
|
||||
BLS12_381,
|
||||
BN446,
|
||||
FKM12_447,
|
||||
BLS12_461,
|
||||
BN462
|
||||
# BN446,
|
||||
# FKM12_447,
|
||||
# BLS12_461,
|
||||
# BN462
|
||||
]
|
||||
|
||||
proc main() =
|
||||
@ -49,7 +50,12 @@ proc main() =
|
||||
negBench(Fp[curve], Iters)
|
||||
mulBench(Fp[curve], Iters)
|
||||
sqrBench(Fp[curve], Iters)
|
||||
invBench(Fp[curve], InvIters)
|
||||
invBench(Fp[curve], ExponentIters)
|
||||
powFermatInversionBench(Fp[curve], ExponentIters)
|
||||
sqrtBench(Fp[curve], ExponentIters)
|
||||
# Exponentiation by a "secret" of size ~the curve order
|
||||
powBench(Fp[curve], ExponentIters)
|
||||
powUnsafeBench(Fp[curve], ExponentIters)
|
||||
separator()
|
||||
|
||||
main()
|
||||
|
@ -29,7 +29,7 @@ const AvailableCurves = [
|
||||
# Pairing-Friendly curves
|
||||
# BN254_Nogami,
|
||||
BN254_Snarks,
|
||||
BLS12_377,
|
||||
# BLS12_377,
|
||||
BLS12_381
|
||||
# BN446,
|
||||
# FKM12_447,
|
||||
@ -47,6 +47,7 @@ proc main() =
|
||||
mulBench(Fp2[curve], Iters)
|
||||
sqrBench(Fp2[curve], Iters)
|
||||
invBench(Fp2[curve], InvIters)
|
||||
sqrtBench(Fp2[curve], InvIters)
|
||||
separator()
|
||||
|
||||
main()
|
||||
|
@ -434,6 +434,30 @@ func montyPowUnsafeExponent*[mBits, eBits: static int](
|
||||
var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]
|
||||
montyPowUnsafeExponent(a.limbs, expBE, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare)
|
||||
|
||||
func montyPow*[mBits: static int](
|
||||
a: var BigInt[mBits], exponent: openarray[byte],
|
||||
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
|
||||
canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool
|
||||
) =
|
||||
## Compute a <- a^exponent (mod M)
|
||||
## ``a`` in the Montgomery domain
|
||||
## ``exponent`` is a BigInt in canonical big-endian representation
|
||||
##
|
||||
## Warning ⚠️ :
|
||||
## This is an optimization for public exponent
|
||||
## Otherwise bits of the exponent can be retrieved with:
|
||||
## - memory access analysis
|
||||
## - power analysis
|
||||
## - timing analysis
|
||||
##
|
||||
## This uses fixed window optimization
|
||||
## A window size in the range [1, 5] must be chosen
|
||||
|
||||
const scratchLen = if windowSize == 1: 2
|
||||
else: (1 shl windowSize) + 1
|
||||
var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]
|
||||
montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare)
|
||||
|
||||
func montyPowUnsafeExponent*[mBits: static int](
|
||||
a: var BigInt[mBits], exponent: openarray[byte],
|
||||
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
|
||||
@ -441,7 +465,7 @@ func montyPowUnsafeExponent*[mBits: static int](
|
||||
) =
|
||||
## Compute a <- a^exponent (mod M)
|
||||
## ``a`` in the Montgomery domain
|
||||
## ``exponent`` is a BigInt in canonical representation
|
||||
## ``exponent`` is a BigInt in canonical big-endian representation
|
||||
##
|
||||
## Warning ⚠️ :
|
||||
## This is an optimization for public exponent
|
||||
|
@ -191,6 +191,19 @@ func pow*(a: var Fp, exponent: BigInt) =
|
||||
Fp.C.canUseNoCarryMontySquare()
|
||||
)
|
||||
|
||||
func pow*(a: var Fp, exponent: openarray[byte]) =
|
||||
## Exponentiation modulo p
|
||||
## ``a``: a field element to be exponentiated
|
||||
## ``exponent``: a big integer in canonical big endian representation
|
||||
const windowSize = 5 # TODO: find best window size for each curves
|
||||
a.mres.montyPow(
|
||||
exponent,
|
||||
Fp.C.Mod, Fp.C.getMontyOne(),
|
||||
Fp.C.getNegInvModWord(), windowSize,
|
||||
Fp.C.canUseNoCarryMontyMul(),
|
||||
Fp.C.canUseNoCarryMontySquare()
|
||||
)
|
||||
|
||||
func powUnsafeExponent*(a: var Fp, exponent: BigInt) =
|
||||
## Exponentiation modulo p
|
||||
## ``a``: a field element to be exponentiated
|
||||
@ -214,7 +227,7 @@ func powUnsafeExponent*(a: var Fp, exponent: BigInt) =
|
||||
func powUnsafeExponent*(a: var Fp, exponent: openarray[byte]) =
|
||||
## Exponentiation modulo p
|
||||
## ``a``: a field element to be exponentiated
|
||||
## ``exponent``: a big integer
|
||||
## ``exponent``: a big integer a big integer in canonical big endian representation
|
||||
##
|
||||
## Warning ⚠️ :
|
||||
## This is an optimization for public exponent
|
||||
@ -241,7 +254,7 @@ func isSquare*[C](a: Fp[C]): SecretBool =
|
||||
## Returns true if ``a`` is a square (quadratic residue) in 𝔽p
|
||||
##
|
||||
## Assumes that the prime modulus ``p`` is public.
|
||||
# Implementation: we use exponentiation by (p-1)/2 (Euler(s criterion)
|
||||
# Implementation: we use exponentiation by (p-1)/2 (Euler's criterion)
|
||||
# as it can reuse the exponentiation implementation
|
||||
# Note that we don't care about leaking the bits of p
|
||||
# as we assume that
|
||||
@ -267,6 +280,39 @@ func sqrt_p3mod4[C](a: var Fp[C]) =
|
||||
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
|
||||
a.powUnsafeExponent(C.getPrimePlus1div4_BE())
|
||||
|
||||
func sqrt_invsqrt_p3mod4[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
|
||||
## If ``a`` is a square, compute the square root of ``a`` in sqrt
|
||||
## and the inverse square root of a in invsqrt
|
||||
##
|
||||
## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4)
|
||||
# TODO: deterministic sign
|
||||
#
|
||||
# Algorithm
|
||||
#
|
||||
#
|
||||
# From Euler's criterion: a^((p-1)/2)) ≡ 1 (mod p) if square
|
||||
# a^((p-1)/2)) * a^-1 ≡ 1/a (mod p)
|
||||
# a^((p-3)/2)) ≡ 1/a (mod p)
|
||||
# a^((p-3)/4)) ≡ 1/√a (mod p) # Requires p ≡ 3 (mod 4)
|
||||
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
|
||||
|
||||
invsqrt = a
|
||||
invsqrt.powUnsafeExponent(C.getPrimeMinus3div4_BE())
|
||||
# √a ≡ a * 1/√a ≡ a^((p+1)/4) (mod p)
|
||||
sqrt.prod(invsqrt, a)
|
||||
|
||||
func sqrt_invsqrt_if_square_p3mod4[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
|
||||
## If ``a`` is a square, compute the square root of ``a`` in sqrt
|
||||
## and the inverse square root of a in invsqrt
|
||||
##
|
||||
## If a is not square, sqrt and invsqrt are undefined
|
||||
##
|
||||
## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4)
|
||||
sqrt_invsqrt_p3mod4(sqrt, invsqrt, a)
|
||||
var euler {.noInit.}: Fp[C]
|
||||
euler.prod(sqrt, invsqrt)
|
||||
result = not(euler.mres == C.getMontyPrimeMinus1())
|
||||
|
||||
func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool =
|
||||
## If ``a`` is a square, compute the square root of ``a``
|
||||
## if not, ``a`` is unmodified.
|
||||
@ -278,19 +324,9 @@ func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool =
|
||||
## The square root, if it exist is multivalued,
|
||||
## i.e. both x² == (-x)²
|
||||
## This procedure returns a deterministic result
|
||||
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
|
||||
|
||||
var a1 {.noInit.} = a
|
||||
a1.powUnsafeExponent(C.getPrimeMinus3div4_BE())
|
||||
|
||||
var a1a {.noInit.}: Fp[C]
|
||||
a1a.prod(a1, a)
|
||||
|
||||
var a0 {.noInit.}: Fp[C]
|
||||
a0.prod(a1a, a1)
|
||||
|
||||
result = not(a0.mres == C.getMontyPrimeMinus1())
|
||||
a.ccopy(a1a, result)
|
||||
var sqrt {.noInit.}, invsqrt {.noInit.}: Fp[C]
|
||||
result = sqrt_invsqrt_if_square_p3mod4(sqrt, invsqrt, a)
|
||||
a.ccopy(sqrt, result)
|
||||
|
||||
func sqrt*[C](a: var Fp[C]) =
|
||||
## Compute the square root of ``a``
|
||||
@ -319,6 +355,36 @@ func sqrt_if_square*[C](a: var Fp[C]): SecretBool =
|
||||
else:
|
||||
{.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
|
||||
|
||||
func sqrt_invsqrt*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
|
||||
## Compute the square root and inverse square root of ``a``
|
||||
##
|
||||
## This requires ``a`` to be a square
|
||||
##
|
||||
## The result is undefined otherwise
|
||||
##
|
||||
## The square root, if it exist is multivalued,
|
||||
## i.e. both x² == (-x)²
|
||||
## This procedure returns a deterministic result
|
||||
when BaseType(C.Mod.limbs[0]) mod 4 == 3:
|
||||
sqrt_invsqrt_p3mod4(sqrt, invsqrt, a)
|
||||
else:
|
||||
{.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
|
||||
|
||||
func sqrt_invsqrt_if_square*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
|
||||
## Compute the square root and ivnerse square root of ``a``
|
||||
##
|
||||
## This returns true if ``a`` is square and sqrt/invsqrt contains the square root/inverse square root
|
||||
##
|
||||
## The result is undefined otherwise
|
||||
##
|
||||
## The square root, if it exist is multivalued,
|
||||
## i.e. both x² == (-x)²
|
||||
## This procedure returns a deterministic result
|
||||
when BaseType(C.Mod.limbs[0]) mod 4 == 3:
|
||||
result = sqrt_invsqrt_if_square_p3mod4(sqrt, invsqrt, a)
|
||||
else:
|
||||
{.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Field arithmetic ergonomic primitives
|
||||
|
@ -214,10 +214,9 @@ func sqrt_if_square*(a: var QuadraticExt): SecretBool =
|
||||
let quadResidTest = t2.isSquare()
|
||||
t2.ccopy(t3, not quadResidTest)
|
||||
|
||||
t2.sqrt()
|
||||
a.c0.ccopy(t2, result)
|
||||
sqrt_invsqrt(sqrt = t1, invsqrt = t3, t2)
|
||||
a.c0.ccopy(t1, result)
|
||||
|
||||
t2.double()
|
||||
t1.inv(t2)
|
||||
t1 *= a.c1
|
||||
a.c1.ccopy(t1, result)
|
||||
t3.div2()
|
||||
t3 *= a.c1
|
||||
a.c1.ccopy(t3, result)
|
||||
|
Loading…
x
Reference in New Issue
Block a user