Field sqrt optimization (#168)

* add more Fp tests for Twisted Edwards curves

* add fused sqrt+division bench

* Significant fused sqrt+division improvement for any prime field over algorithm described in  "High-Speed High-Security Signature", Bernstein et al, p15 "Fast decompression", https://ed25519.cr.yp.to/ed25519-20110705.pdf

* Activate secp256k1 field benches + spring renaming of field multiplication

* addition chains for inversion and sqrt of Curve25519

* Make isSquare use addition chains

* add double-prec mul/square bench for <256-bit prime fields.
This commit is contained in:
Mamy Ratsimbazafy 2022-01-01 16:19:35 +01:00 committed by GitHub
parent 53f9708c2b
commit bea798e27c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 304 additions and 213 deletions

View File

@ -16,7 +16,7 @@ import
# Internal # Internal
../constantine/config/common, ../constantine/config/common,
# Helpers # Helpers
../helpers/[prng_unsafe, static_for], ../helpers/prng_unsafe,
./platforms, ./platforms,
# Standard library # Standard library
std/[monotimes, times, strformat, strutils, macros] std/[monotimes, times, strformat, strutils, macros]

View File

@ -14,15 +14,17 @@
import import
# Internals # Internals
../constantine/config/[curves, common], ../constantine/config/[common, curves],
../constantine/arithmetic, ../constantine/arithmetic,
../constantine/towers, ../constantine/towers,
../constantine/curves/zoo_square_roots,
# Helpers # Helpers
../helpers/[prng_unsafe, static_for], ../helpers/prng_unsafe,
./bench_blueprint ./bench_blueprint
export notes export notes
proc separator*() = separator(165) proc separator*() = separator(165)
proc smallSeparator*() = separator(8)
proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) = proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) =
let ns = inNanoseconds((stop-start) div iters) let ns = inNanoseconds((stop-start) div iters)
@ -119,37 +121,30 @@ proc invAddChainBench*(T: typedesc, iters: int) =
proc sqrtBench*(T: typedesc, iters: int) = proc sqrtBench*(T: typedesc, iters: int) =
let x = rng.random_unsafe(T) let x = rng.random_unsafe(T)
bench("Square Root + isSquare (constant-time default impl)", T, iters):
const algoType = block:
when T.C.hasP3mod4_primeModulus():
"p ≡ 3 (mod 4)"
elif T.C.hasP5mod8_primeModulus():
"p ≡ 5 (mod 8)"
else:
"Tonelli-Shanks"
const addchain = block:
when T.C.hasSqrtAddchain() or T.C.hasTonelliShanksAddchain():
"with addition chain"
else:
"without addition chain"
const desc = "Square Root (constant-time " & algoType & " " & addchain & ")"
bench(desc, T, iters):
var r = x var r = x
discard r.sqrt_if_square() discard r.sqrt_if_square()
proc sqrtP3mod4Bench*(T: typedesc, iters: int) = proc sqrtRatioBench*(T: typedesc, iters: int) =
var r: T var r: T
let x = rng.random_unsafe(T) let u = rng.random_unsafe(T)
bench("SquareRoot (p ≡ 3 (mod 4) exponentiation)", T, iters): let v = rng.random_unsafe(T)
r.invsqrt_p3mod4(x) bench("Fused SquareRoot+Division+isSquare sqrt(u/v)", T, iters):
r *= x let isSquare = r.sqrt_ratio_if_square(u, v)
proc sqrtAddChainBench*(T: typedesc, iters: int) =
var r: T
let x = rng.random_unsafe(T)
bench("SquareRoot (addition chain)", T, iters):
r.invsqrt_addchain(x)
r *= x
proc sqrtTonelliBench*(T: typedesc, iters: int) =
var r: T
let x = rng.random_unsafe(T)
bench("SquareRoot (constant-time Tonelli-Shanks exponentiation)", T, iters):
r.invsqrt_tonelli_shanks(x, useAddChain = false)
r *= x
proc sqrtTonelliAddChainBench*(T: typedesc, iters: int) =
var r: T
let x = rng.random_unsafe(T)
bench("SquareRoot (constant-time Tonelli-Shanks addchain)", T, iters):
r.invsqrt_tonelli_shanks(x, useAddChain = true)
r *= x
proc powBench*(T: typedesc, iters: int) = proc powBench*(T: typedesc, iters: int) =
let x = rng.random_unsafe(T) let x = rng.random_unsafe(T)

View File

@ -14,9 +14,7 @@ import
../constantine/curves/[zoo_inversions, zoo_square_roots], ../constantine/curves/[zoo_inversions, zoo_square_roots],
# Helpers # Helpers
../helpers/static_for, ../helpers/static_for,
./bench_fields_template, ./bench_fields_template
# Standard library
std/strutils
# ############################################################ # ############################################################
# #
@ -31,9 +29,10 @@ const AvailableCurves = [
# P224, # P224,
BN254_Nogami, BN254_Nogami,
BN254_Snarks, BN254_Snarks,
# Curve25519, Curve25519,
# P256, Bandersnatch,
# Secp256k1, P256,
Secp256k1,
BLS12_377, BLS12_377,
BLS12_381, BLS12_381,
BW6_761 BW6_761
@ -50,17 +49,11 @@ proc main() =
div2Bench(Fp[curve], Iters) div2Bench(Fp[curve], Iters)
mulBench(Fp[curve], Iters) mulBench(Fp[curve], Iters)
sqrBench(Fp[curve], Iters) sqrBench(Fp[curve], Iters)
smallSeparator()
invEuclidBench(Fp[curve], ExponentIters) invEuclidBench(Fp[curve], ExponentIters)
invPowFermatBench(Fp[curve], ExponentIters) invPowFermatBench(Fp[curve], ExponentIters)
when curve.hasInversionAddchain(): sqrtBench(Fp[curve], ExponentIters)
invAddChainBench(Fp[curve], ExponentIters) sqrtRatioBench(Fp[curve], ExponentIters)
when (BaseType(curve.Mod.limbs[0]) and 3) == 3:
sqrtP3mod4Bench(Fp[curve], ExponentIters)
when curve.hasSqrtAddchain():
sqrtAddChainBench(Fp[curve], ExponentIters)
when curve in {BLS12_377}:
sqrtTonelliBench(Fp[curve], ExponentIters)
sqrtTonelliAddChainBench(Fp[curve], ExponentIters)
# Exponentiation by a "secret" of size ~the curve order # Exponentiation by a "secret" of size ~the curve order
powBench(Fp[curve], ExponentIters) powBench(Fp[curve], ExponentIters)
powUnsafeBench(Fp[curve], ExponentIters) powUnsafeBench(Fp[curve], ExponentIters)

View File

@ -230,8 +230,11 @@ proc main() =
diff2xUnreduce(Fp[BLS12_381], iters = 10_000_000) diff2xUnreduce(Fp[BLS12_381], iters = 10_000_000)
neg2x(Fp[BLS12_381], iters = 10_000_000) neg2x(Fp[BLS12_381], iters = 10_000_000)
separator() separator()
prod2xBench(512, 256, 256, iters = 10_000_000)
prod2xBench(768, 384, 384, iters = 10_000_000) prod2xBench(768, 384, 384, iters = 10_000_000)
square2xBench(512, 256, iters = 10_000_000)
square2xBench(768, 384, iters = 10_000_000) square2xBench(768, 384, iters = 10_000_000)
reduce2x(Fp[BN254_Snarks], iters = 10_000_000)
reduce2x(Fp[BLS12_381], iters = 10_000_000) reduce2x(Fp[BLS12_381], iters = 10_000_000)
separator() separator()

View File

@ -34,7 +34,7 @@ static: doAssert UseASM_X86_64
# Montgomery multiplication # Montgomery multiplication
# ------------------------------------------------------------ # ------------------------------------------------------------
# Fallback when no ADX and BMI2 support (MULX, ADCX, ADOX) # Fallback when no ADX and BMI2 support (MULX, ADCX, ADOX)
macro montMul_CIOS_nocarry_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped = macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped =
## Generate an optimized Montgomery Multiplication kernel ## Generate an optimized Montgomery Multiplication kernel
## using the CIOS method ## using the CIOS method
## ##
@ -174,9 +174,9 @@ macro montMul_CIOS_nocarry_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_
result.add ctx.generate result.add ctx.generate
func montMul_CIOS_nocarry_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = func montMul_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
## Constant-time modular multiplication ## Constant-time modular multiplication
montMul_CIOS_nocarry_gen(r, a, b, M, m0ninv) montMul_CIOS_sparebit_gen(r, a, b, M, m0ninv)
# Montgomery Squaring # Montgomery Squaring
# ------------------------------------------------------------ # ------------------------------------------------------------

View File

@ -174,7 +174,7 @@ proc partialRedx(
ctx.adcx t[N-1], S ctx.adcx t[N-1], S
ctx.adox t[N-1], C ctx.adox t[N-1], C
macro montMul_CIOS_nocarry_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped = macro montMul_CIOS_sparebit_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped =
## Generate an optimized Montgomery Multiplication kernel ## Generate an optimized Montgomery Multiplication kernel
## using the CIOS method ## using the CIOS method
## This requires the most significant word of the Modulus ## This requires the most significant word of the Modulus
@ -270,9 +270,10 @@ macro montMul_CIOS_nocarry_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM,
result.add ctx.generate result.add ctx.generate
func montMul_CIOS_nocarry_asm_adx_bmi2*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = func montMul_CIOS_sparebit_asm_adx_bmi2*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
## Constant-time modular multiplication ## Constant-time modular multiplication
montMul_CIOS_nocarry_adx_bmi2_gen(r, a, b, M, m0ninv) ## Requires the prime modulus to have a spare bit in the representation. (Hence if using 64-bit words and 4 words, to be at most 255-bit)
montMul_CIOS_sparebit_adx_bmi2_gen(r, a, b, M, m0ninv)
# Montgomery Squaring # Montgomery Squaring
# ------------------------------------------------------------ # ------------------------------------------------------------

View File

@ -37,10 +37,16 @@ func inv*(r: var FF, a: FF) =
## to convert Jacobian and Projective coordinates ## to convert Jacobian and Projective coordinates
## to affine for elliptic curve ## to affine for elliptic curve
# For now we don't activate the addition chains # For now we don't activate the addition chains
# neither for Secp256k1 nor BN curves # Performance is slower than Euclid-based inversion on newer CPUs
# Performance is slower than GCD #
# To be revisited with faster squaring/multiplications # - Montgomery multiplication/squaring can skip the final substraction
when FF is Fp and FF.C.hasInversionAddchain(): # - For generalized Mersenne Prime curves, modular reduction can be made extremely cheap.
# - For BW6-761 the addition chain is over 2x slower than Euclid-based inversion
# due to multiplication being so costly with 12 limbs (grows quadratically)
# while Euclid costs grows linearly.
when false and
FF is Fp and FF.C.hasInversionAddchain() and
FF.C notin {BW6_761}:
r.inv_addchain(a) r.inv_addchain(a)
else: else:
r.inv_euclid(a) r.inv_euclid(a)

View File

@ -10,8 +10,7 @@ import
../primitives, ../primitives,
../config/[common, type_ff, curves], ../config/[common, type_ff, curves],
../curves/zoo_square_roots, ../curves/zoo_square_roots,
./bigints, ./finite_fields, ./bigints, ./finite_fields
./finite_fields_inversion
# ############################################################ # ############################################################
# #
@ -23,39 +22,14 @@ import
{.push raises: [].} {.push raises: [].}
{.push inline.} {.push inline.}
# Legendre symbol / Euler's Criterion / Kronecker's symbol
# ------------------------------------------------------------
func isSquare*(a: Fp): SecretBool =
## Returns true if ``a`` is a square (quadratic residue) in 𝔽p
##
## Assumes that the prime modulus ``p`` is public.
# Implementation: we use exponentiation by (p-1)/2 (Euler's criterion)
# as it can reuse the exponentiation implementation
# Note that we don't care about leaking the bits of p
# as we assume that
var xi {.noInit.} = a # TODO: is noInit necessary? see https://github.com/mratsim/constantine/issues/21
xi.powUnsafeExponent(Fp.getPrimeMinus1div2_BE())
result = not(xi.isMinusOne())
# xi can be:
# - 1 if a square
# - 0 if 0
# - -1 if a quadratic non-residue
debug:
doAssert: bool(
xi.isZero or
xi.isOne or
xi.isMinusOne()
)
# Specialized routine for p ≡ 3 (mod 4) # Specialized routine for p ≡ 3 (mod 4)
# ------------------------------------------------------------ # ------------------------------------------------------------
func hasP3mod4_primeModulus(C: static Curve): static bool = func hasP3mod4_primeModulus*(C: static Curve): static bool =
## Returns true iff p ≡ 3 (mod 4) ## Returns true iff p ≡ 3 (mod 4)
(BaseType(C.Mod.limbs[0]) and 3) == 3 (BaseType(C.Mod.limbs[0]) and 3) == 3
func invsqrt_p3mod4*(r: var Fp, a: Fp) = func invsqrt_p3mod4(r: var Fp, a: Fp) =
## Compute the inverse square root of ``a`` ## Compute the inverse square root of ``a``
## ##
## This requires ``a`` to be a square ## This requires ``a`` to be a square
@ -75,17 +49,20 @@ func invsqrt_p3mod4*(r: var Fp, a: Fp) =
# a^((p-3)/2)) ≡ 1/a (mod p) # a^((p-3)/2)) ≡ 1/a (mod p)
# a^((p-3)/4)) ≡ 1/√a (mod p) # Requires p ≡ 3 (mod 4) # a^((p-3)/4)) ≡ 1/√a (mod p) # Requires p ≡ 3 (mod 4)
static: doAssert Fp.C.hasP3mod4_primeModulus() static: doAssert Fp.C.hasP3mod4_primeModulus()
when FP.C.hasSqrtAddchain():
r.invsqrt_addchain(a)
else:
r = a r = a
r.powUnsafeExponent(Fp.getPrimeMinus3div4_BE()) r.powUnsafeExponent(Fp.getPrimeMinus3div4_BE())
# Specialized routine for p ≡ 5 (mod 8) # Specialized routine for p ≡ 5 (mod 8)
# ------------------------------------------------------------ # ------------------------------------------------------------
func hasP5mod8_primeModulus(C: static Curve): static bool = func hasP5mod8_primeModulus*(C: static Curve): static bool =
## Returns true iff p ≡ 5 (mod 8) ## Returns true iff p ≡ 5 (mod 8)
(BaseType(C.Mod.limbs[0]) and 7) == 5 (BaseType(C.Mod.limbs[0]) and 7) == 5
func invsqrt_p5mod8*(r: var Fp, a: Fp) = func invsqrt_p5mod8(r: var Fp, a: Fp) =
## Compute the inverse square root of ``a`` ## Compute the inverse square root of ``a``
## ##
## This requires ``a`` to be a square ## This requires ``a`` to be a square
@ -141,6 +118,9 @@ func invsqrt_p5mod8*(r: var Fp, a: Fp) =
# α = (2a)^((p-5)/8) # α = (2a)^((p-5)/8)
alpha.double(a) alpha.double(a)
beta = alpha beta = alpha
when Fp.C.hasSqrtAddchain():
alpha.invsqrt_addchain_pminus5over8(alpha)
else:
alpha.powUnsafeExponent(Fp.getPrimeMinus5div8_BE()) alpha.powUnsafeExponent(Fp.getPrimeMinus5div8_BE())
# Note: if r aliases a, for inverse square root we don't use `a` again # Note: if r aliases a, for inverse square root we don't use `a` again
@ -163,13 +143,11 @@ func invsqrt_p5mod8*(r: var Fp, a: Fp) =
# Tonelli Shanks for any prime # Tonelli Shanks for any prime
# ------------------------------------------------------------ # ------------------------------------------------------------
func precompute_tonelli_shanks( func precompute_tonelli_shanks(a_pre_exp: var Fp, a: Fp) =
a_pre_exp: var Fp, when FP.C.hasTonelliShanksAddchain():
a: Fp, useAddChain: static bool) =
a_pre_exp = a
when useAddChain:
a_pre_exp.precompute_tonelli_shanks_addchain(a) a_pre_exp.precompute_tonelli_shanks_addchain(a)
else: else:
a_pre_exp = a
a_pre_exp.powUnsafeExponent(Fp.C.tonelliShanks(exponent)) a_pre_exp.powUnsafeExponent(Fp.C.tonelliShanks(exponent))
func isSquare_tonelli_shanks( func isSquare_tonelli_shanks(
@ -232,7 +210,7 @@ func invsqrt_tonelli_shanks_pre(
t.ccopy(buf, bNotOne) t.ccopy(buf, bNotOne)
b = t b = t
func invsqrt_tonelli_shanks*(r: var Fp, a: Fp, useAddChain: static bool) = func invsqrt_tonelli_shanks*(r: var Fp, a: Fp) =
## Compute the inverse square root of ``a`` ## Compute the inverse square root of ``a``
## ##
## This requires ``a`` to be a square ## This requires ``a`` to be a square
@ -244,13 +222,11 @@ func invsqrt_tonelli_shanks*(r: var Fp, a: Fp, useAddChain: static bool) =
## This procedure returns a deterministic result ## This procedure returns a deterministic result
## This procedure is constant-time ## This procedure is constant-time
var a_pre_exp{.noInit.}: Fp var a_pre_exp{.noInit.}: Fp
a_pre_exp.precompute_tonelli_shanks(a, useAddChain) a_pre_exp.precompute_tonelli_shanks(a)
invsqrt_tonelli_shanks_pre(r, a, a_pre_exp) invsqrt_tonelli_shanks_pre(r, a, a_pre_exp)
# Public routines # Public routines
# ------------------------------------------------------------ # ------------------------------------------------------------
# Note: we export the inner sqrt_invsqrt_IMPL
# for benchmarking purposes.
{.push inline.} {.push inline.}
@ -265,14 +241,12 @@ func invsqrt*[C](r: var Fp[C], a: Fp[C]) =
## i.e. both x² == (-x)² ## i.e. both x² == (-x)²
## This procedure returns a deterministic result ## This procedure returns a deterministic result
## This procedure is constant-time ## This procedure is constant-time
when C.hasSqrtAddchain(): when C.hasP3mod4_primeModulus():
r.invsqrt_addchain(a)
elif C.hasP3mod4_primeModulus():
r.invsqrt_p3mod4(a) r.invsqrt_p3mod4(a)
elif C.hasP5mod8_primeModulus(): elif C.hasP5mod8_primeModulus():
r.invsqrt_p5mod8(a) r.invsqrt_p5mod8(a)
else: else:
r.invsqrt_tonelli_shanks(a, useAddChain = C.hasTonelliShanksAddchain()) r.invsqrt_tonelli_shanks(a)
func sqrt*[C](a: var Fp[C]) = func sqrt*[C](a: var Fp[C]) =
## Compute the square root of ``a`` ## Compute the square root of ``a``
@ -340,73 +314,46 @@ func invsqrt_if_square*[C](r: var Fp[C], a: Fp[C]): SecretBool =
var sqrt{.noInit.}: Fp[C] var sqrt{.noInit.}: Fp[C]
result = sqrt_invsqrt_if_square(sqrt, r, a) result = sqrt_invsqrt_if_square(sqrt, r, a)
# Legendre symbol / Euler's Criterion / Kronecker's symbol
# ------------------------------------------------------------
func isSquare*(a: Fp): SecretBool =
## Returns true if ``a`` is a square (quadratic residue) in 𝔽p
##
## Assumes that the prime modulus ``p`` is public.
when false:
# Implementation: we use exponentiation by (p-1)/2 (Euler's criterion)
# as it can reuse the exponentiation implementation
# Note that we don't care about leaking the bits of p
# as we assume that
var xi {.noInit.} = a # TODO: is noInit necessary? see https://github.com/mratsim/constantine/issues/21
xi.powUnsafeExponent(Fp.getPrimeMinus1div2_BE())
result = not(xi.isMinusOne())
# xi can be:
# - 1 if a square
# - 0 if 0
# - -1 if a quadratic non-residue
debug:
doAssert: bool(
xi.isZero or
xi.isOne or
xi.isMinusOne()
)
else:
# We reuse the optimized addition chains instead of exponentiation by (p-1)/2
when Fp.C.hasP3mod4_primeModulus() or Fp.C.hasP5mod8_primeModulus():
var sqrt{.noInit.}, invsqrt{.noInit.}: Fp
return sqrt_invsqrt_if_square(sqrt, invsqrt, a)
else:
var a_pre_exp{.noInit.}: Fp
a_pre_exp.precompute_tonelli_shanks(a)
return isSquare_tonelli_shanks(a, a_pre_exp)
{.pop.} # inline {.pop.} # inline
# Fused routines # Fused routines
# ------------------------------------------------------------ # ------------------------------------------------------------
func sqrt_ratio_if_square_p5mod8(r: var Fp, u, v: Fp): SecretBool =
## If u/v is a square, compute √(u/v)
## if not, the result is undefined
##
## Requires p ≡ 5 (mod 8)
## r must not alias u or v
##
## The square root, if it exist is multivalued,
## i.e. both (u/v)² == (-u/v)²
## This procedure returns a deterministic result
## This procedure is constant-time
# References:
# - High-Speed High-Security Signature, Bernstein et al, p15 "Fast decompression", https://ed25519.cr.yp.to/ed25519-20110705.pdf
# - IETF Hash-to-Curve: https://github.com/cfrg/draft-irtf-cfrg-hash-to-curve/blob/9939a07/draft-irtf-cfrg-hash-to-curve.md#optimized-sqrt_ratio-for-q--5-mod-8
# - Pasta curves divsqrt: https://github.com/zcash/pasta/blob/f0f7068/squareroottab.sage#L139-L193
#
# p ≡ 5 (mod 8), hence 𝑖 ∈ Fp with 𝑖² ≡ 1 (mod p)
# if α is a square, with β ≡ α^((p+3)/8) (mod p)
# - either β² ≡ α (mod p), hence √α ≡ ± β (mod p)
# - or β² ≡ -α (mod p), hence √α ≡ ± 𝑖β (mod p)
# (see explanation in invsqrt_p5mod8)
#
# In our fused division and sqrt case we have
# β = (u/v)^((p+3)/8)
# = u^((p+3)/8).v^(p1(p+3)/8) via Fermat's little theorem
# = u^((p+3)/8).v^((7p11)/8)
# = u.u^((p-5)/8).v³.v^((7p35)/8)
# = uv³.u^((p-5)/8).v^(7(p-5)/8)
# = uv³(uv⁷)^((p5)/8)
#
# We can check if β² ≡ -α (mod p)
# by checking vβ² ≡ -u (mod p), and then multiply by 𝑖
# and if it's neither u or -u it wasn't a square.
static: doAssert Fp.C.hasP5mod8_primeModulus()
var t {.noInit.}: Fp
t.square(v)
t *= v
# r = uv³
r.prod(u, t)
# t = (uv⁷)^((p5)/8)
t *= r
t *= v
t.powUnsafeExponent(Fp.getPrimeMinus5div8_BE())
# r = β = uv³(uv⁷)^((p5)/8)
r *= t
# Check candidate square roots
t.square(r)
t *= v
block:
result = t == u
block:
t.neg()
let isSol = t == u
result = result or isSol
t.prod(r, Fp.C.sqrt_minus_one())
r.ccopy(t, isSol)
func sqrt_ratio_if_square*(r: var Fp, u, v: Fp): SecretBool {.inline.} = func sqrt_ratio_if_square*(r: var Fp, u, v: Fp): SecretBool {.inline.} =
## If u/v is a square, compute √(u/v) ## If u/v is a square, compute √(u/v)
## if not, the result is undefined ## if not, the result is undefined
@ -417,12 +364,13 @@ func sqrt_ratio_if_square*(r: var Fp, u, v: Fp): SecretBool {.inline.} =
## i.e. both (u/v)² == (-u/v)² ## i.e. both (u/v)² == (-u/v)²
## This procedure returns a deterministic result ## This procedure returns a deterministic result
## This procedure is constant-time ## This procedure is constant-time
when Fp.C.hasP5mod8_primeModulus():
sqrt_ratio_if_square_p5mod8(r, u, v) # u/v is square iff 𝛘(u/v) = 1 (mod p)
else: # As 𝛘(a) = 1 or -1
# TODO: Fuse inversion and tonelli-shanks and legendre symbol # 𝛘(u/v) = 𝛘(ub)
r.inv(v) var uv{.noInit.}: Fp
r *= u uv.prod(u, v) # uv
result = r.sqrt_if_square() result = r.invsqrt_if_square(uv) # 1/√uv
r *= u # √u/√v
{.pop.} # raises no exceptions {.pop.} # raises no exceptions

View File

@ -162,7 +162,7 @@ func montyRedc2x_Comba[N: static int](
# Montgomery Multiplication # Montgomery Multiplication
# ------------------------------------------------------------ # ------------------------------------------------------------
func montyMul_CIOS_nocarry(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = func montyMul_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS) ## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS)
## and no-carry optimization. ## and no-carry optimization.
## This requires the most significant word of the Modulus ## This requires the most significant word of the Modulus
@ -373,11 +373,11 @@ func montyMul*(
when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling
# ADX implies BMI2 # ADX implies BMI2
if ({.noSideEffect.}: hasAdx()): if ({.noSideEffect.}: hasAdx()):
montMul_CIOS_nocarry_asm_adx_bmi2(r, a, b, M, m0ninv) montMul_CIOS_sparebit_asm_adx_bmi2(r, a, b, M, m0ninv)
else: else:
montMul_CIOS_nocarry_asm(r, a, b, M, m0ninv) montMul_CIOS_sparebit_asm(r, a, b, M, m0ninv)
else: else:
montyMul_CIOS_nocarry(r, a, b, M, m0ninv) montyMul_CIOS_sparebit(r, a, b, M, m0ninv)
else: else:
montyMul_FIPS(r, a, b, M, m0ninv) montyMul_FIPS(r, a, b, M, m0ninv)
@ -393,7 +393,7 @@ func montySquare*[N](r: var Limbs[N], a, M: Limbs[N],
# which uses unfused squaring then Montgomery reduction # which uses unfused squaring then Montgomery reduction
# is slightly slower than fused Montgomery multiplication # is slightly slower than fused Montgomery multiplication
when spareBits >= 1: when spareBits >= 1:
montMul_CIOS_nocarry_asm_adx_bmi2(r, a, a, M, m0ninv) montMul_CIOS_sparebit_asm_adx_bmi2(r, a, a, M, m0ninv)
else: else:
montSquare_CIOS_asm_adx_bmi2(r, a, M, m0ninv, spareBits >= 1) montSquare_CIOS_asm_adx_bmi2(r, a, M, m0ninv, spareBits >= 1)
else: else:

View File

@ -19,6 +19,8 @@ type
## P being the prime modulus of the Curve C ## P being the prime modulus of the Curve C
## Internally, data is stored in Montgomery n-residue form ## Internally, data is stored in Montgomery n-residue form
## with the magic constant chosen for convenient division (a power of 2 depending on P bitsize) ## with the magic constant chosen for convenient division (a power of 2 depending on P bitsize)
# TODO, pseudo mersenne priles like 2²⁵⁵-19 have very fast modular reduction
# and don't need Montgomery representation
mres*: matchingBigInt(C) mres*: matchingBigInt(C)
Fr*[C: static Curve] = object Fr*[C: static Curve] = object

View File

@ -0,0 +1,60 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
../config/curves,
../arithmetic/finite_fields
# ############################################################
#
# Specialized inversion for BLS12-381
#
# ############################################################
func inv_addchain*(r: var Fp[Curve25519], a: Fp[Curve25519]) =
var
x10 {.noinit.}: Fp[Curve25519]
x1001 {.noinit.}: Fp[Curve25519]
x1011 {.noinit.}: Fp[Curve25519]
x10 .square(a) # 2
x1001 .square_repeated(x10, 2) # 8
x1001 *= a # 9
x1011 .prod(x10, x1001) # 11
# 5 operations
# TODO: we can accumulate in a partially reduced
# doubled-size `r` to avoid the final substractions.
# and only reduce at the end.
# This requires the number of op to be less than log2(p) == 255
template t: untyped = x10
t.square(x1011) # 22
r.prod(t, x1001) # 31 = 2⁵-1
template u: untyped = x1001
t.square_repeated(r, 5)
r *= t # 2¹⁰-1
t.square_repeated(r, 10)
t *= r # 2²⁰-1
u.square_repeated(t, 20)
t *= u # 2⁴⁰-1
t.square_repeated(10)
t *= r # 2⁵⁰-1
r.square_repeated(t, 50)
r *= t # 2¹⁰⁰-1
u.square_repeated(r, 100)
r *= u # 2²⁰⁰-1
r.square_repeated(50)
r *= t # 2²⁵⁰-1
r.square_repeated(5)
r *= x1011 # 2²⁵⁵-21 (note: 11 = 2⁵-21)

View File

@ -7,21 +7,32 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms. # at your option. This file may not be copied, modified, or distributed except according to those terms.
import import
../config/[curves, type_bigint, type_ff], ../config/[curves, type_ff],
../io/[io_bigints, io_fields],
../arithmetic/finite_fields ../arithmetic/finite_fields
# p ≡ 5 (mod 8), hence 𝑖 ∈ Fp with 𝑖² ≡ 1 (mod p) func invsqrt_addchain_pminus5over8*(r: var Fp[Curve25519], a: Fp[Curve25519]) =
# Hence if α is a square ## Returns a^((p-5)/8) = 2²⁵²-3 for inverse square root computation
# with β ≡ α^((p+3)/8) (mod p)
# - either β² ≡ α (mod p), hence √α ≡ ±β (mod p)
# - or β² ≡ -α (mod p), hence √α ≡ ±𝑖β (mod p)
# Sage: var t{.noInit.}, u{.noInit.}, v{.noinit.}: Fp[Curve25519]
# p = Integer('0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed') u.square(a) # 2
# Fp = GF(p) v.square_repeated(u, 2) # 8
# sqrt_minus1 = Fp(-1).sqrt() v *= a # 9
# print(Integer(sqrt_minus1).hex()) u *= v # 11
const Curve25519_sqrt_minus_one* = Fp[Curve25519].fromHex( u.square() # 22
"0x2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0" u *= v # 31 = 2⁵-1
) v.square_repeated(u, 5) #
u *= v # 2¹⁰-1
v.square_repeated(u, 10) #
v *= u # 2²⁰-1
t.square_repeated(v, 20) #
v *= t # 2⁴⁰-1
v.square_repeated(10) #
u *= v # 2⁵⁰-1
v.square_repeated(u, 50) #
v *= u # 2¹⁰⁰-1
t.square_repeated(v, 100) #
v *= t # 2²⁰⁰-1
v.square_repeated(50) #
u *= v # 2²⁵⁰-1
u.square_repeated(2) # 2²⁵²-4
r.prod(a, u) # 2²⁵²-3

View File

@ -7,13 +7,15 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms. # at your option. This file may not be copied, modified, or distributed except according to those terms.
import import
../config/[curves, type_ff], ../config/curves,
./bls12_377_inversion, ./bls12_377_inversion,
./bls12_381_inversion, ./bls12_381_inversion,
./bn254_nogami_inversion, ./bn254_nogami_inversion,
./bn254_snarks_inversion, ./bn254_snarks_inversion,
./bw6_761_inversion, ./bw6_761_inversion,
./secp256k1_inversion ./secp256k1_inversion,
./curve25519_inversion
export export
bls12_377_inversion, bls12_377_inversion,
@ -21,16 +23,13 @@ export
bn254_nogami_inversion, bn254_nogami_inversion,
bn254_snarks_inversion, bn254_snarks_inversion,
bw6_761_inversion, bw6_761_inversion,
secp256k1_inversion secp256k1_inversion,
curve25519_inversion
func hasInversionAddchain*(C: static Curve): static bool = func hasInversionAddchain*(C: static Curve): static bool =
# TODO: For now we don't activate the addition chains ## Is an inversion addition chain implemented for the curve.
# for Secp256k1 ## Note: the addition chain might be slower than Euclid-based inversion.
# Performance is slower than GCD (to investigate) when C in {BN254_Nogami, BN254_Snarks, BLS12_377, BLS12_381, BW6_761, Curve25519, Secp256k1}:
# For BW6-761 the addition chain is over 2x slower than Euclid-based inversion
# due to multiplication being so costly with 12 limbs (grows quadratically)
# while Euclid costs grows linearly.
when C in {BN254_Nogami, BN254_Snarks, BLS12_377, BLS12_381}:
true true
else: else:
false false

View File

@ -27,7 +27,7 @@ export
curve25519_sqrt curve25519_sqrt
func hasSqrtAddchain*(C: static Curve): static bool = func hasSqrtAddchain*(C: static Curve): static bool =
when C in {BLS12_381, BN254_Nogami, BN254_Snarks, BW6_761}: when C in {BLS12_381, BN254_Nogami, BN254_Snarks, BW6_761, Curve25519}:
true true
else: else:
false false
@ -43,7 +43,3 @@ func hasTonelliShanksAddchain*(C: static Curve): static bool =
true true
else: else:
false false
macro sqrt_minus_one*(C: static Curve): untyped =
## Return 𝑖 ∈ Fp with 𝑖² ≡ 1 (mod p)
return bindSym($C & "_sqrt_minus_one")

View File

@ -71,7 +71,7 @@ func sqrx2x_complex_asm_adx_bmi2*(
t0.diff(a.c0, a.c1) t0.diff(a.c0, a.c1)
r.c0.mul_asm_adx_bmi2_impl(t0, t1) r.c0.mul_asm_adx_bmi2_impl(t0, t1)
func sqrx_complex_asm_adx_bmi2*( func sqrx_complex_sparebit_asm_adx_bmi2*(
r: var array[2, Fp], r: var array[2, Fp],
a: array[2, Fp] a: array[2, Fp]
) = ) =
@ -85,10 +85,10 @@ func sqrx_complex_asm_adx_bmi2*(
var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0)
v0.diff(a.c0, a.c1) v0.diff(a.c0, a.c1)
v1.sum(a.c0, a.c1) v1.sum(a.c0, a.c1)
r.c1.mres.limbs.montMul_CIOS_nocarry_asm_adx_bmi2(a.c0.mres.limbs, a.c1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) r.c1.mres.limbs.montMul_CIOS_sparebit_asm_adx_bmi2(a.c0.mres.limbs, a.c1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord())
# aliasing: a unneeded now # aliasing: a unneeded now
r.c1.double() r.c1.double()
r.c0.mres.limbs.montMul_CIOS_nocarry_asm_adx_bmi2(v0.mres.limbs, v1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) r.c0.mres.limbs.montMul_CIOS_sparebit_asm_adx_bmi2(v0.mres.limbs, v1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord())
# 𝔽p2 multiplication # 𝔽p2 multiplication
# ------------------------------------------------------------ # ------------------------------------------------------------

View File

@ -1230,9 +1230,9 @@ func square2x*(r: var QuadraticExt2x, a: QuadraticExt) =
func square*(r: var QuadraticExt, a: QuadraticExt) = func square*(r: var QuadraticExt, a: QuadraticExt) =
when r.fromComplexExtension(): when r.fromComplexExtension():
when true: when true:
when UseASM_X86_64 and a.c0.mres.limbs.len <= 6: when UseASM_X86_64 and a.c0.mres.limbs.len <= 6 and r.typeof.has1extraBit():
if ({.noSideEffect.}: hasAdx()): if ({.noSideEffect.}: hasAdx()):
r.coords.sqrx_complex_asm_adx_bmi2(a.coords) r.coords.sqrx_complex_sparebit_asm_adx_bmi2(a.coords)
else: else:
r.square_complex(a) r.square_complex(a)
else: else:

View File

@ -157,6 +157,22 @@ suite "Field Addition/Substraction/Negation via double-precision field elements"
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
addsubneg_random_long01Seq(BLS12_381) addsubneg_random_long01Seq(BLS12_381)
test "With Curve25519 field modulus":
for _ in 0 ..< Iters:
addsubneg_random_unsafe(Curve25519)
for _ in 0 ..< Iters:
addsubneg_randomHighHammingWeight(Curve25519)
for _ in 0 ..< Iters:
addsubneg_random_long01Seq(Curve25519)
test "With Bandersnatch field modulus":
for _ in 0 ..< Iters:
addsubneg_random_unsafe(Bandersnatch)
for _ in 0 ..< Iters:
addsubneg_randomHighHammingWeight(Bandersnatch)
for _ in 0 ..< Iters:
addsubneg_random_long01Seq(Bandersnatch)
test "Negate 0 returns 0 (unique Montgomery repr)": test "Negate 0 returns 0 (unique Montgomery repr)":
var a: FpDbl[BN254_Snarks] var a: FpDbl[BN254_Snarks]
var r {.noInit.}: FpDbl[BN254_Snarks] var r {.noInit.}: FpDbl[BN254_Snarks]
@ -197,6 +213,22 @@ suite "Field Multiplication via double-precision field elements is consistent wi
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
mul_random_long01Seq(BLS12_381) mul_random_long01Seq(BLS12_381)
test "With Curve25519 field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(Curve25519)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(Curve25519)
for _ in 0 ..< Iters:
mul_random_long01Seq(Curve25519)
test "With Bandersnatch field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(Bandersnatch)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(Bandersnatch)
for _ in 0 ..< Iters:
mul_random_long01Seq(Bandersnatch)
suite "Field Squaring via double-precision field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]": suite "Field Squaring via double-precision field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]":
test "With P-224 field modulus": test "With P-224 field modulus":
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
@ -229,3 +261,19 @@ suite "Field Squaring via double-precision field elements is consistent with sin
sqr_randomHighHammingWeight(BLS12_381) sqr_randomHighHammingWeight(BLS12_381)
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
sqr_random_long01Seq(BLS12_381) sqr_random_long01Seq(BLS12_381)
test "With Curve25519 field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(Curve25519)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(Curve25519)
for _ in 0 ..< Iters:
sqr_random_long01Seq(Curve25519)
test "With Bandersnatch field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(Bandersnatch)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(Bandersnatch)
for _ in 0 ..< Iters:
sqr_random_long01Seq(Bandersnatch)

View File

@ -83,7 +83,10 @@ proc mainSanity() =
sanity Mersenne127 sanity Mersenne127
sanity P224 # P224 uses the fast-path with 64-bit words and the slow path with 32-bit words sanity P224 # P224 uses the fast-path with 64-bit words and the slow path with 32-bit words
sanity P256 sanity P256
sanity Secp256k1
sanity BLS12_381 sanity BLS12_381
sanity Curve25519
sanity Bandersnatch
mainSanity() mainSanity()
@ -152,6 +155,14 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " ["
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
random_long01Seq(P256) random_long01Seq(P256)
test "Random squaring mod Secp256k1 [FastSquaring = " & $(Fp[Secp256k1].getSpareBits() >= 2) & "]":
for _ in 0 ..< Iters:
randomCurve(Secp256k1)
for _ in 0 ..< Iters:
randomHighHammingWeight(Secp256k1)
for _ in 0 ..< Iters:
random_long01Seq(Secp256k1)
test "Random squaring mod BLS12_381 [FastSquaring = " & $(Fp[BLS12_381].getSpareBits() >= 2) & "]": test "Random squaring mod BLS12_381 [FastSquaring = " & $(Fp[BLS12_381].getSpareBits() >= 2) & "]":
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
randomCurve(BLS12_381) randomCurve(BLS12_381)
@ -160,6 +171,22 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " ["
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
random_long01Seq(BLS12_381) random_long01Seq(BLS12_381)
test "Random squaring mod Curve25519 [FastSquaring = " & $(Fp[Curve25519].getSpareBits() >= 2) & "]":
for _ in 0 ..< Iters:
randomCurve(Curve25519)
for _ in 0 ..< Iters:
randomHighHammingWeight(Curve25519)
for _ in 0 ..< Iters:
random_long01Seq(Curve25519)
test "Random squaring mod Bandersnatch [FastSquaring = " & $(Fp[Bandersnatch].getSpareBits() >= 2) & "]":
for _ in 0 ..< Iters:
randomCurve(Bandersnatch)
for _ in 0 ..< Iters:
randomHighHammingWeight(Bandersnatch)
for _ in 0 ..< Iters:
random_long01Seq(Bandersnatch)
suite "Modular squaring - bugs highlighted by property-based testing": suite "Modular squaring - bugs highlighted by property-based testing":
test "a² == (-a)² on for Fp[2^127 - 1] - #61": test "a² == (-a)² on for Fp[2^127 - 1] - #61":
var a{.noInit.}: Fp[Mersenne127] var a{.noInit.}: Fp[Mersenne127]

View File

@ -198,6 +198,7 @@ proc main() =
testRandomDiv2 Secp256k1 testRandomDiv2 Secp256k1
testRandomDiv2 BLS12_377 testRandomDiv2 BLS12_377
testRandomDiv2 BLS12_381 testRandomDiv2 BLS12_381
testRandomDiv2 Bandersnatch
suite "Modular inversion over prime fields" & " [" & $WordBitwidth & "-bit mode]": suite "Modular inversion over prime fields" & " [" & $WordBitwidth & "-bit mode]":
test "Specific tests on Fp[BLS12_381]": test "Specific tests on Fp[BLS12_381]":
@ -285,6 +286,7 @@ proc main() =
testRandomInv Secp256k1 testRandomInv Secp256k1
testRandomInv BLS12_377 testRandomInv BLS12_377
testRandomInv BLS12_381 testRandomInv BLS12_381
testRandomInv Bandersnatch
main() main()

View File

@ -24,7 +24,7 @@ var RNG {.compileTime.} = initRand(1234)
const AvailableCurves = [ const AvailableCurves = [
P224, P224,
BN254_Nogami, BN254_Snarks, BN254_Nogami, BN254_Snarks,
P256, Secp256k1, P256, Secp256k1, Curve25519, Bandersnatch,
BLS12_377, BLS12_381, BW6_761 BLS12_377, BLS12_381, BW6_761
] ]