mirror of
https://github.com/logos-storage/constantine.git
synced 2026-01-05 14:43:08 +00:00
Implement fused sqrt invsqrt on Fp: Accelerate sqrt on Fp2 by 20% (hashToG2 and property-based testing bottleneck, 4 times slower than inversion and 87 times slower than Fp2 multiplication)
This commit is contained in:
parent
53c94e8aab
commit
d22d981e9e
@ -71,15 +71,15 @@ when SupportsGetTicks:
|
|||||||
echo "\n=================================================================================================================\n"
|
echo "\n=================================================================================================================\n"
|
||||||
|
|
||||||
proc separator*() =
|
proc separator*() =
|
||||||
echo "-".repeat(110)
|
echo "-".repeat(145)
|
||||||
|
|
||||||
proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) =
|
proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) =
|
||||||
let ns = inNanoseconds((stop-start) div iters)
|
let ns = inNanoseconds((stop-start) div iters)
|
||||||
let throughput = 1e9 / float64(ns)
|
let throughput = 1e9 / float64(ns)
|
||||||
when SupportsGetTicks:
|
when SupportsGetTicks:
|
||||||
echo &"{op:<15} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
|
echo &"{op:<50} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
|
||||||
else:
|
else:
|
||||||
echo &"{op:<15} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"
|
echo &"{op:<50} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"
|
||||||
|
|
||||||
macro fixFieldDisplay(T: typedesc): untyped =
|
macro fixFieldDisplay(T: typedesc): untyped =
|
||||||
# At compile-time, enums are integers and their display is buggy
|
# At compile-time, enums are integers and their display is buggy
|
||||||
@ -143,5 +143,31 @@ proc invBench*(T: typedesc, iters: int) =
|
|||||||
var r: T
|
var r: T
|
||||||
let x = rng.random_unsafe(T)
|
let x = rng.random_unsafe(T)
|
||||||
preventOptimAway(r)
|
preventOptimAway(r)
|
||||||
bench("Inversion", T, iters):
|
bench("Inversion (constant-time Euclid)", T, iters):
|
||||||
r.inv(x)
|
r.inv(x)
|
||||||
|
|
||||||
|
proc powFermatInversionBench*(T: typedesc, iters: int) =
|
||||||
|
let x = rng.random_unsafe(T)
|
||||||
|
bench("Inversion via exponentiation p-2 (Little Fermat)", T, iters):
|
||||||
|
var r = x
|
||||||
|
r.powUnsafeExponent(T.C.getInvModExponent())
|
||||||
|
|
||||||
|
proc sqrtBench*(T: typedesc, iters: int) =
|
||||||
|
let x = rng.random_unsafe(T)
|
||||||
|
bench("Square Root + square check (constant-time)", T, iters):
|
||||||
|
var r = x
|
||||||
|
discard r.sqrt_if_square()
|
||||||
|
|
||||||
|
proc powBench*(T: typedesc, iters: int) =
|
||||||
|
let x = rng.random_unsafe(T)
|
||||||
|
let exponent = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()])
|
||||||
|
bench("Exp curve order (constant-time) - " & $exponent.bits & "-bit", T, iters):
|
||||||
|
var r = x
|
||||||
|
r.pow(exponent)
|
||||||
|
|
||||||
|
proc powUnsafeBench*(T: typedesc, iters: int) =
|
||||||
|
let x = rng.random_unsafe(T)
|
||||||
|
let exponent = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()])
|
||||||
|
bench("Exp curve order (Leak exponent bits) - " & $exponent.bits & "-bit", T, iters):
|
||||||
|
var r = x
|
||||||
|
r.powUnsafeExponent(exponent)
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import
|
|||||||
# Internals
|
# Internals
|
||||||
../constantine/config/curves,
|
../constantine/config/curves,
|
||||||
../constantine/arithmetic,
|
../constantine/arithmetic,
|
||||||
|
../constantine/io/io_bigints,
|
||||||
# Helpers
|
# Helpers
|
||||||
../helpers/static_for,
|
../helpers/static_for,
|
||||||
./bench_fields_template,
|
./bench_fields_template,
|
||||||
@ -24,20 +25,20 @@ import
|
|||||||
|
|
||||||
|
|
||||||
const Iters = 1_000_000
|
const Iters = 1_000_000
|
||||||
const InvIters = 1000
|
const ExponentIters = 1000
|
||||||
const AvailableCurves = [
|
const AvailableCurves = [
|
||||||
P224,
|
# P224,
|
||||||
BN254_Nogami,
|
# BN254_Nogami,
|
||||||
BN254_Snarks,
|
BN254_Snarks,
|
||||||
Curve25519,
|
# Curve25519,
|
||||||
P256,
|
# P256,
|
||||||
Secp256k1,
|
# Secp256k1,
|
||||||
BLS12_377,
|
# BLS12_377,
|
||||||
BLS12_381,
|
BLS12_381,
|
||||||
BN446,
|
# BN446,
|
||||||
FKM12_447,
|
# FKM12_447,
|
||||||
BLS12_461,
|
# BLS12_461,
|
||||||
BN462
|
# BN462
|
||||||
]
|
]
|
||||||
|
|
||||||
proc main() =
|
proc main() =
|
||||||
@ -49,7 +50,12 @@ proc main() =
|
|||||||
negBench(Fp[curve], Iters)
|
negBench(Fp[curve], Iters)
|
||||||
mulBench(Fp[curve], Iters)
|
mulBench(Fp[curve], Iters)
|
||||||
sqrBench(Fp[curve], Iters)
|
sqrBench(Fp[curve], Iters)
|
||||||
invBench(Fp[curve], InvIters)
|
invBench(Fp[curve], ExponentIters)
|
||||||
|
powFermatInversionBench(Fp[curve], ExponentIters)
|
||||||
|
sqrtBench(Fp[curve], ExponentIters)
|
||||||
|
# Exponentiation by a "secret" of size ~the curve order
|
||||||
|
powBench(Fp[curve], ExponentIters)
|
||||||
|
powUnsafeBench(Fp[curve], ExponentIters)
|
||||||
separator()
|
separator()
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|||||||
@ -29,7 +29,7 @@ const AvailableCurves = [
|
|||||||
# Pairing-Friendly curves
|
# Pairing-Friendly curves
|
||||||
# BN254_Nogami,
|
# BN254_Nogami,
|
||||||
BN254_Snarks,
|
BN254_Snarks,
|
||||||
BLS12_377,
|
# BLS12_377,
|
||||||
BLS12_381
|
BLS12_381
|
||||||
# BN446,
|
# BN446,
|
||||||
# FKM12_447,
|
# FKM12_447,
|
||||||
@ -47,6 +47,7 @@ proc main() =
|
|||||||
mulBench(Fp2[curve], Iters)
|
mulBench(Fp2[curve], Iters)
|
||||||
sqrBench(Fp2[curve], Iters)
|
sqrBench(Fp2[curve], Iters)
|
||||||
invBench(Fp2[curve], InvIters)
|
invBench(Fp2[curve], InvIters)
|
||||||
|
sqrtBench(Fp2[curve], InvIters)
|
||||||
separator()
|
separator()
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|||||||
@ -434,6 +434,30 @@ func montyPowUnsafeExponent*[mBits, eBits: static int](
|
|||||||
var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]
|
var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]
|
||||||
montyPowUnsafeExponent(a.limbs, expBE, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare)
|
montyPowUnsafeExponent(a.limbs, expBE, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare)
|
||||||
|
|
||||||
|
func montyPow*[mBits: static int](
|
||||||
|
a: var BigInt[mBits], exponent: openarray[byte],
|
||||||
|
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
|
||||||
|
canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool
|
||||||
|
) =
|
||||||
|
## Compute a <- a^exponent (mod M)
|
||||||
|
## ``a`` in the Montgomery domain
|
||||||
|
## ``exponent`` is a BigInt in canonical big-endian representation
|
||||||
|
##
|
||||||
|
## Warning ⚠️ :
|
||||||
|
## This is an optimization for public exponent
|
||||||
|
## Otherwise bits of the exponent can be retrieved with:
|
||||||
|
## - memory access analysis
|
||||||
|
## - power analysis
|
||||||
|
## - timing analysis
|
||||||
|
##
|
||||||
|
## This uses fixed window optimization
|
||||||
|
## A window size in the range [1, 5] must be chosen
|
||||||
|
|
||||||
|
const scratchLen = if windowSize == 1: 2
|
||||||
|
else: (1 shl windowSize) + 1
|
||||||
|
var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]
|
||||||
|
montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare)
|
||||||
|
|
||||||
func montyPowUnsafeExponent*[mBits: static int](
|
func montyPowUnsafeExponent*[mBits: static int](
|
||||||
a: var BigInt[mBits], exponent: openarray[byte],
|
a: var BigInt[mBits], exponent: openarray[byte],
|
||||||
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
|
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
|
||||||
@ -441,7 +465,7 @@ func montyPowUnsafeExponent*[mBits: static int](
|
|||||||
) =
|
) =
|
||||||
## Compute a <- a^exponent (mod M)
|
## Compute a <- a^exponent (mod M)
|
||||||
## ``a`` in the Montgomery domain
|
## ``a`` in the Montgomery domain
|
||||||
## ``exponent`` is a BigInt in canonical representation
|
## ``exponent`` is a BigInt in canonical big-endian representation
|
||||||
##
|
##
|
||||||
## Warning ⚠️ :
|
## Warning ⚠️ :
|
||||||
## This is an optimization for public exponent
|
## This is an optimization for public exponent
|
||||||
|
|||||||
@ -191,6 +191,19 @@ func pow*(a: var Fp, exponent: BigInt) =
|
|||||||
Fp.C.canUseNoCarryMontySquare()
|
Fp.C.canUseNoCarryMontySquare()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func pow*(a: var Fp, exponent: openarray[byte]) =
|
||||||
|
## Exponentiation modulo p
|
||||||
|
## ``a``: a field element to be exponentiated
|
||||||
|
## ``exponent``: a big integer in canonical big endian representation
|
||||||
|
const windowSize = 5 # TODO: find best window size for each curves
|
||||||
|
a.mres.montyPow(
|
||||||
|
exponent,
|
||||||
|
Fp.C.Mod, Fp.C.getMontyOne(),
|
||||||
|
Fp.C.getNegInvModWord(), windowSize,
|
||||||
|
Fp.C.canUseNoCarryMontyMul(),
|
||||||
|
Fp.C.canUseNoCarryMontySquare()
|
||||||
|
)
|
||||||
|
|
||||||
func powUnsafeExponent*(a: var Fp, exponent: BigInt) =
|
func powUnsafeExponent*(a: var Fp, exponent: BigInt) =
|
||||||
## Exponentiation modulo p
|
## Exponentiation modulo p
|
||||||
## ``a``: a field element to be exponentiated
|
## ``a``: a field element to be exponentiated
|
||||||
@ -214,7 +227,7 @@ func powUnsafeExponent*(a: var Fp, exponent: BigInt) =
|
|||||||
func powUnsafeExponent*(a: var Fp, exponent: openarray[byte]) =
|
func powUnsafeExponent*(a: var Fp, exponent: openarray[byte]) =
|
||||||
## Exponentiation modulo p
|
## Exponentiation modulo p
|
||||||
## ``a``: a field element to be exponentiated
|
## ``a``: a field element to be exponentiated
|
||||||
## ``exponent``: a big integer
|
## ``exponent``: a big integer in canonical big endian representation
|
||||||
##
|
##
|
||||||
## Warning ⚠️ :
|
## Warning ⚠️ :
|
||||||
## This is an optimization for public exponent
|
## This is an optimization for public exponent
|
||||||
@ -241,7 +254,7 @@ func isSquare*[C](a: Fp[C]): SecretBool =
|
|||||||
## Returns true if ``a`` is a square (quadratic residue) in 𝔽p
|
## Returns true if ``a`` is a square (quadratic residue) in 𝔽p
|
||||||
##
|
##
|
||||||
## Assumes that the prime modulus ``p`` is public.
|
## Assumes that the prime modulus ``p`` is public.
|
||||||
# Implementation: we use exponentiation by (p-1)/2 (Euler(s criterion)
|
# Implementation: we use exponentiation by (p-1)/2 (Euler's criterion)
|
||||||
# as it can reuse the exponentiation implementation
|
# as it can reuse the exponentiation implementation
|
||||||
# Note that we don't care about leaking the bits of p
|
# Note that we don't care about leaking the bits of p
|
||||||
# as we assume that
|
# as we assume that
|
||||||
@ -267,6 +280,39 @@ func sqrt_p3mod4[C](a: var Fp[C]) =
|
|||||||
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
|
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
|
||||||
a.powUnsafeExponent(C.getPrimePlus1div4_BE())
|
a.powUnsafeExponent(C.getPrimePlus1div4_BE())
|
||||||
|
|
||||||
|
func sqrt_invsqrt_p3mod4[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
|
||||||
|
## If ``a`` is a square, compute the square root of ``a`` in sqrt
|
||||||
|
## and the inverse square root of a in invsqrt
|
||||||
|
##
|
||||||
|
## This assumes that the prime field modulus ``p`` satisfies p ≡ 3 (mod 4)
|
||||||
|
# TODO: deterministic sign
|
||||||
|
#
|
||||||
|
# Algorithm
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# From Euler's criterion: a^((p-1)/2) ≡ 1 (mod p) if square
|
||||||
|
# a^((p-1)/2) * a^-1 ≡ 1/a (mod p)
|
||||||
|
# a^((p-3)/2) ≡ 1/a (mod p)
|
||||||
|
# a^((p-3)/4) ≡ 1/√a (mod p) # Requires p ≡ 3 (mod 4)
|
||||||
|
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
|
||||||
|
|
||||||
|
invsqrt = a
|
||||||
|
invsqrt.powUnsafeExponent(C.getPrimeMinus3div4_BE())
|
||||||
|
# √a ≡ a * 1/√a ≡ a^((p+1)/4) (mod p)
|
||||||
|
sqrt.prod(invsqrt, a)
|
||||||
|
|
||||||
|
func sqrt_invsqrt_if_square_p3mod4[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
|
||||||
|
## If ``a`` is a square, compute the square root of ``a`` in sqrt
|
||||||
|
## and the inverse square root of a in invsqrt
|
||||||
|
##
|
||||||
|
## If a is not square, sqrt and invsqrt are undefined
|
||||||
|
##
|
||||||
|
## This assumes that the prime field modulus ``p`` satisfies p ≡ 3 (mod 4)
|
||||||
|
sqrt_invsqrt_p3mod4(sqrt, invsqrt, a)
|
||||||
|
var euler {.noInit.}: Fp[C]
|
||||||
|
euler.prod(sqrt, invsqrt)
|
||||||
|
result = not(euler.mres == C.getMontyPrimeMinus1())
|
||||||
|
|
||||||
func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool =
|
func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool =
|
||||||
## If ``a`` is a square, compute the square root of ``a``
|
## If ``a`` is a square, compute the square root of ``a``
|
||||||
## if not, ``a`` is unmodified.
|
## if not, ``a`` is unmodified.
|
||||||
@ -278,19 +324,9 @@ func sqrt_if_square_p3mod4[C](a: var Fp[C]): SecretBool =
|
|||||||
## The square root, if it exist is multivalued,
|
## The square root, if it exist is multivalued,
|
||||||
## i.e. both x² == (-x)²
|
## i.e. both x² == (-x)²
|
||||||
## This procedure returns a deterministic result
|
## This procedure returns a deterministic result
|
||||||
static: doAssert BaseType(C.Mod.limbs[0]) mod 4 == 3
|
var sqrt {.noInit.}, invsqrt {.noInit.}: Fp[C]
|
||||||
|
result = sqrt_invsqrt_if_square_p3mod4(sqrt, invsqrt, a)
|
||||||
var a1 {.noInit.} = a
|
a.ccopy(sqrt, result)
|
||||||
a1.powUnsafeExponent(C.getPrimeMinus3div4_BE())
|
|
||||||
|
|
||||||
var a1a {.noInit.}: Fp[C]
|
|
||||||
a1a.prod(a1, a)
|
|
||||||
|
|
||||||
var a0 {.noInit.}: Fp[C]
|
|
||||||
a0.prod(a1a, a1)
|
|
||||||
|
|
||||||
result = not(a0.mres == C.getMontyPrimeMinus1())
|
|
||||||
a.ccopy(a1a, result)
|
|
||||||
|
|
||||||
func sqrt*[C](a: var Fp[C]) =
|
func sqrt*[C](a: var Fp[C]) =
|
||||||
## Compute the square root of ``a``
|
## Compute the square root of ``a``
|
||||||
@ -319,6 +355,36 @@ func sqrt_if_square*[C](a: var Fp[C]): SecretBool =
|
|||||||
else:
|
else:
|
||||||
{.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
|
{.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
|
||||||
|
|
||||||
|
func sqrt_invsqrt*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
|
||||||
|
## Compute the square root and inverse square root of ``a``
|
||||||
|
##
|
||||||
|
## This requires ``a`` to be a square
|
||||||
|
##
|
||||||
|
## The result is undefined otherwise
|
||||||
|
##
|
||||||
|
## The square root, if it exists, is multivalued,
|
||||||
|
## i.e. both x² == (-x)²
|
||||||
|
## This procedure returns a deterministic result
|
||||||
|
when BaseType(C.Mod.limbs[0]) mod 4 == 3:
|
||||||
|
sqrt_invsqrt_p3mod4(sqrt, invsqrt, a)
|
||||||
|
else:
|
||||||
|
{.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
|
||||||
|
|
||||||
|
func sqrt_invsqrt_if_square*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
|
||||||
|
## Compute the square root and inverse square root of ``a``
|
||||||
|
##
|
||||||
|
## This returns true if ``a`` is a square, in which case sqrt/invsqrt contain the square root/inverse square root
|
||||||
|
##
|
||||||
|
## The result is undefined otherwise
|
||||||
|
##
|
||||||
|
## The square root, if it exists, is multivalued,
|
||||||
|
## i.e. both x² == (-x)²
|
||||||
|
## This procedure returns a deterministic result
|
||||||
|
when BaseType(C.Mod.limbs[0]) mod 4 == 3:
|
||||||
|
result = sqrt_invsqrt_if_square_p3mod4(sqrt, invsqrt, a)
|
||||||
|
else:
|
||||||
|
{.error: "Square root is only implemented for p ≡ 3 (mod 4)".}
|
||||||
|
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
# Field arithmetic ergonomic primitives
|
# Field arithmetic ergonomic primitives
|
||||||
|
|||||||
@ -214,10 +214,9 @@ func sqrt_if_square*(a: var QuadraticExt): SecretBool =
|
|||||||
let quadResidTest = t2.isSquare()
|
let quadResidTest = t2.isSquare()
|
||||||
t2.ccopy(t3, not quadResidTest)
|
t2.ccopy(t3, not quadResidTest)
|
||||||
|
|
||||||
t2.sqrt()
|
sqrt_invsqrt(sqrt = t1, invsqrt = t3, t2)
|
||||||
a.c0.ccopy(t2, result)
|
a.c0.ccopy(t1, result)
|
||||||
|
|
||||||
t2.double()
|
t3.div2()
|
||||||
t1.inv(t2)
|
t3 *= a.c1
|
||||||
t1 *= a.c1
|
a.c1.ccopy(t3, result)
|
||||||
a.c1.ccopy(t1, result)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user