From 37354e9ca8e433abd3f03abbba5be0738984d597 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Sun, 7 Aug 2022 20:50:24 +0200 Subject: [PATCH] faster isSquare: faster hash_to_curve (BN254) and point deserialization (BLS12-377) closes #199 --- constantine/hash_to_curve/hash_to_curve.nim | 4 - constantine/math/arithmetic/bigints.nim | 2 +- .../arithmetic/finite_fields_square_root.nim | 61 +---- .../{limbs_invmod.nim => limbs_exgcd.nim} | 255 ++++++++++++++++-- .../math/arithmetic/limbs_unsaturated.nim | 35 +++ .../math/extension_fields/square_root_fp2.nim | 2 +- 6 files changed, 280 insertions(+), 79 deletions(-) rename constantine/math/arithmetic/{limbs_invmod.nim => limbs_exgcd.nim} (62%) diff --git a/constantine/hash_to_curve/hash_to_curve.nim b/constantine/hash_to_curve/hash_to_curve.nim index ec5b2bf..984c52e 100644 --- a/constantine/hash_to_curve/hash_to_curve.nim +++ b/constantine/hash_to_curve/hash_to_curve.nim @@ -79,10 +79,6 @@ func mapToCurve_svdw[F, G]( gx1.curve_eq_rhs(x1, G) gx2.curve_eq_rhs(x2, G) - # TODO: faster Legendre symbol. - # We can optimize the 2 legendre symbols + 3 sqrt to - # - either 2 legendre and 1 sqrt - # - or 3 fused legendre+sqrt let e1 = gx1.isSquare() let e2 = gx2.isSquare() and not e1 diff --git a/constantine/math/arithmetic/bigints.nim b/constantine/math/arithmetic/bigints.nim index e99210d..7e2bb22 100644 --- a/constantine/math/arithmetic/bigints.nim +++ b/constantine/math/arithmetic/bigints.nim @@ -11,7 +11,7 @@ import ../config/type_bigint, ./limbs, ./limbs_extmul, - ./limbs_invmod, + ./limbs_exgcd, ./limbs_division export BigInt diff --git a/constantine/math/arithmetic/finite_fields_square_root.nim b/constantine/math/arithmetic/finite_fields_square_root.nim index 3ebc076..615236c 100644 --- a/constantine/math/arithmetic/finite_fields_square_root.nim +++ b/constantine/math/arithmetic/finite_fields_square_root.nim @@ -10,7 +10,7 @@ import ../../platforms/abstractions, ../config/curves, ../curves/zoo_square_roots, - ./bigints, ./finite_fields + ./bigints, ./finite_fields, ./limbs_exgcd # ############################################################ # @@ -142,34 +142,6 @@ func precompute_tonelli_shanks(a_pre_exp: var Fp, a: Fp) = a_pre_exp = a a_pre_exp.powUnsafeExponent(Fp.C.tonelliShanks(exponent)) -func isSquare_tonelli_shanks( - a, a_pre_exp: Fp): SecretBool {.used.} = - ## Returns if `a` is a quadratic residue - ## This uses common precomputation for - ## Tonelli-Shanks based square root and inverse square root - ## - ## a^((p-1-2^e)/(2*2^e)) - ## - ## Note: if we need to compute a candidate square root anyway - ## it's faster to square it to check if we get ``a`` - const e = Fp.C.tonelliShanks(twoAdicity) - var r {.noInit.}: Fp - r.square(a_pre_exp) # a^(2(q-1-2^e)/(2*2^e)) = a^((q-1)/2^e - 1) - r *= a # a^((q-1)/2^e) - r.square_repeated(e-1) # a^((q-1)/2) - - result = not(r.isMinusOne()) - # r can be: - # - 1 if a square - # - 0 if 0 - # - -1 if a quadratic non-residue - debug: - doAssert: bool( - r.isZero or - r.isOne or - r.isMinusOne() - ) - func invsqrt_tonelli_shanks_pre( invsqrt: var Fp, a, a_pre_exp: Fp) = @@ -314,33 +286,10 @@ func isSquare*(a: Fp): SecretBool = ## Returns true if ``a`` is a square (quadratic residue) in 𝔽p ## ## Assumes that the prime modulus ``p`` is public. - when false: - # Implementation: we use exponentiation by (p-1)/2 (Euler's criterion) - # as it can reuse the exponentiation implementation - # Note that we don't care about leaking the bits of p - # as we assume that - var xi {.noInit.} = a # TODO: is noInit necessary? see https://github.com/mratsim/constantine/issues/21 - xi.powUnsafeExponent(Fp.getPrimeMinus1div2_BE()) - result = not(xi.isMinusOne()) - # xi can be: - # - 1 if a square - # - 0 if 0 - # - -1 if a quadratic non-residue - debug: - doAssert: bool( - xi.isZero or - xi.isOne or - xi.isMinusOne() - ) - else: - # We reuse the optimized addition chains instead of exponentiation by (p-1)/2 - when Fp.C.has_P_3mod4_primeModulus() or Fp.C.has_P_5mod8_primeModulus(): - var sqrt{.noInit.}, invsqrt{.noInit.}: Fp - return sqrt_invsqrt_if_square(sqrt, invsqrt, a) - else: - var a_pre_exp{.noInit.}: Fp - a_pre_exp.precompute_tonelli_shanks(a) - return isSquare_tonelli_shanks(a, a_pre_exp) + var aa {.noInit.}: matchingBigInt(Fp.C) + aa.fromField(a) + let symbol = legendre(aa.limbs, Fp.fieldMod().limbs, aa.bits) + return not(symbol == MaxWord) {.pop.} # inline diff --git a/constantine/math/arithmetic/limbs_invmod.nim b/constantine/math/arithmetic/limbs_exgcd.nim similarity index 62% rename from constantine/math/arithmetic/limbs_invmod.nim rename to constantine/math/arithmetic/limbs_exgcd.nim index 8fddcdc..b65b378 100644 --- a/constantine/math/arithmetic/limbs_invmod.nim +++ b/constantine/math/arithmetic/limbs_exgcd.nim @@ -13,6 +13,28 @@ import # No exceptions allowed {.push raises: [].} +# ############################################################ +# +# Primitives based on Bézout's identity +# +# ############################################################ +# +# Bézout's identity is the linear Diophantine equation +# au + bv = c +# +# The solution c is gcd(a, b) +# if a and b are coprime, gcd(a, b) = 1 +# au + bv = 1 +# +# Hence modulo b we have +# au + bv ≡ 1 (mod b) +# au ≡ 1 (mod b) +# So u is the modular multiplicative inverse of a (mod b) +# +# As we can use the Extended Euclidean Algorithm to find +# the GCD and the Bézout coefficient, we can use it to find the +# modular multiplicaative inverse. + # ############################################################ # # Modular inversion (mod 2ᵏ) @@ -175,15 +197,15 @@ proc partitionDivsteps(bits, wordBitWidth: int): tuple[totalIters, numChunks, ch func batchedDivsteps( t: var TransitionMatrix, - theta: SignedSecretWord, + hdelta: SignedSecretWord, f0, g0: SignedSecretWord, numIters: int, k: static int ): SignedSecretWord = - ## Bernstein-Yang half-delta (theta) batch of divsteps + ## Bernstein-Yang half-delta (hdelta) batch of divsteps ## ## Output: - ## - return theta for the next batch of divsteps + ## - return hdelta for the next batch of divsteps ## - mutate t, the transition matrix to apply `numIters` divsteps at once ## t is scaled by 2ᵏ ## @@ -200,13 +222,13 @@ func batchedDivsteps( f = f0 g = g0 - theta = theta + hdelta = hdelta for i in k-numIters ..< k: debug: func reportLoop() = debugEcho " iterations: [", k-numIters, ", ", k, ")", " (", numIters, " iterations in total)" - debugEcho " i: ", i, ", theta: ", int(theta) + debugEcho " i: ", i, ", hdelta: ", int(hdelta) # debugEcho " f: 0b", BiggestInt(f).toBin(64), ", g: 0b", BiggestInt(g).toBin(64), " | f: ", int(f), ", g: ", int(g) # debugEcho " u: 0b", BiggestInt(u).toBin(64), ", v: 0b", BiggestInt(v).toBin(64), " | u: ", int(u), ", v: ", int(v) # debugEcho " q: 0b", BiggestInt(q).toBin(64), ", r: 0b", BiggestInt(r).toBin(64), " | q: ", int(q), ", r: ", int(r) @@ -216,8 +238,8 @@ func batchedDivsteps( doAssert bool(u.ashr(k-i)*f0 + v.ashr(k-i)*g0 == f.lshl(i)), (reportLoop(); "Applying the transition matrix to (f₀, g₀) returns current (f, g)") doAssert bool(q.ashr(k-i)*f0 + r.ashr(k-i)*g0 == g.lshl(i)), (reportLoop(); "Applying the transition matrix to (f₀, g₀) returns current (f, g)") - # Conditional masks for (theta < 0) and g odd - let c1 = theta.isNegMask() + # Conditional masks for (hdelta < 0) and g odd + let c1 = hdelta.isNegMask() let c2 = g.isOddMask() # x, y, z, conditional complement of f, u, v let x = f xor c1 @@ -227,10 +249,10 @@ func batchedDivsteps( g.csub(x, c2) q.csub(y, c2) r.csub(z, c2) - # c3 = (theta >= 0) and g odd + # c3 = (hdelta >= 0) and g odd let c3 = c2 and not c1 - # theta = -theta or theta+1 - theta = (theta xor c3) + SignedSecretWord(1) + # hdelta = -hdelta or hdelta+1 + hdelta = (hdelta xor c3) + SignedSecretWord(1) # Conditional rollback substraction f.cadd(g, c3) u.cadd(q, c3) @@ -249,7 +271,7 @@ func batchedDivsteps( doAssert bool(q*f0 + r*g0 == g.lshl(k)), "Applying the final matrix to (f₀, g₀) gives the final (f, g)" doAssert checkDeterminant(t, u, v, q, r, k, numIters) - return theta + return hdelta func matVecMul_shr_k_mod_M[N, E: static int]( t: TransitionMatrix, @@ -373,8 +395,8 @@ func invmodImpl[N, E]( ## Modular inversion using Bernstein-Yang algorithm ## r ≡ F.a⁻¹ (mod M) - # theta = delta-1/2, delta starts at 1/2 for the half-delta variant - var theta = SignedSecretWord(0) + # hdelta = delta-1/2, delta starts at 1/2 for the half-delta variant + var hdelta = SignedSecretWord(0) var d{.noInit.}, e{.noInit.}: LimbsUnsaturated[N, E] var f{.noInit.}, g{.noInit.}: LimbsUnsaturated[N, E] @@ -389,8 +411,8 @@ func invmodImpl[N, E]( for i in 0 ..< partition.numChunks: var t{.noInit.}: TransitionMatrix let numIters = partition.chunkSize + int(i < partition.cutoff) - # Compute transition matrix and next theta - theta = t.batchedDivsteps(theta, f[0], g[0], numIters, k) + # Compute transition matrix and next hdelta + hdelta = t.batchedDivsteps(hdelta, f[0], g[0], numIters, k) # Apply the transition matrix # [u v] [d] # [q r]/2ᵏ.[e] mod M @@ -430,7 +452,7 @@ func invmod*( r: var Limbs, a: Limbs, F, M: static Limbs, bits: static int) = ## Compute the scaled modular inverse of ``a`` modulo M - ## r ≡ F.a⁻¹ (mod M) + ## r ≡ F.a⁻¹ (mod M) (compile-time factor and modulus overload) ## ## with F and M known at compile-time ## @@ -449,4 +471,203 @@ func invmod*( var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess] a2.fromPackedRepr(a) a2.invmodImpl(factor, m2, m0invK, k, bits) - r.fromUnsatRepr(a2) \ No newline at end of file + r.fromUnsatRepr(a2) + +# ############################################################ +# +# Euler criterion, Legendre/Jacobi/Krönecker symbol +# +# ############################################################ +# +# The Euler criterion, i.e. the quadratic residuosity test, for p an odd prime, is: +# a^((p-1)/2) ≡ 1 (mod p), iff a is a square +# ≡ -1 (mod p), iff a is quadratic non-residue +# ≡ 0 (mod p), iff a is 0 +# derived from Fermat's Little Theorem +# +# The Legendre symbol is a function with p odd prime +# (a/p)ₗ ≡ 1 (mod p), iff a is a square +# ≡ -1 (mod p), iff a is quadratic non-residue +# ≡ 0 (mod p), iff a is 0 +# +# The Jacobi symbol generalizes the Legendre symbol for any odd n: +# (a/n)ⱼ = ∏ᵢ (a/pᵢ)ₗ +# is the product of legendre symbol (a/pᵢ)ₗ with pᵢ the prime factors of n +# +# Those symbols can be computed either via exponentiation (Fermat's Little Theorem) +# or using slight modifications to the Extended Euclidean Algorithm for GCD. +# +# See +# - Algorithm II.7 in Blake, Seroussi, Smart: "Elliptic Curves in Cryptography" +# - Algorithm 5.9.2 in Bach and Shallit: "Algorithmic Number Theory" +# - Pornin: https://github.com/pornin/x25519-cm0/blob/75a53f2/src/x25519-cm0.S#L89-L155 + +func batchedDivstepsSymbol( + t: var TransitionMatrix, + hdelta: SignedSecretWord, + f0, g0: SignedSecretWord, + numIters: int, + k: static int + ): tuple[hdelta, L: SignedSecretWord] = + ## Bernstein-Yang half-delta (hdelta) batch of divsteps + ## with Legendre symbol tracking + ## + ## Output: + ## - return hdelta for the next batch of divsteps + ## - Returns the intermediate Legendre symbol + ## - mutate t, the transition matrix to apply `numIters` divsteps at once + ## t is scaled by 2ᵏ + ## + ## Input: + ## - f0, bottom limb of f + ## - g0, bottom limb of g + ## - numIters, number of iterations requested in this batch of divsteps + ## - k, the maximum batch size, transition matrix is scaled by 2ᵏ + + var + u = SignedSecretWord(1 shl (k-numIters)) + v = SignedSecretWord(0) + q = SignedSecretWord(0) + r = SignedSecretWord(1 shl (k-numIters)) + f = f0 + g = g0 + + hdelta = hdelta + L = SignedSecretWord(0) + + for i in k-numIters ..< k: + debug: + func reportLoop() = + debugEcho " iterations: [", k-numIters, ", ", k, ")", " (", numIters, " iterations in total)" + debugEcho " i: ", i, ", hdelta: ", int(hdelta) + # debugEcho " f: 0b", BiggestInt(f).toBin(64), ", g: 0b", BiggestInt(g).toBin(64), " | f: ", int(f), ", g: ", int(g) + # debugEcho " u: 0b", BiggestInt(u).toBin(64), ", v: 0b", BiggestInt(v).toBin(64), " | u: ", int(u), ", v: ", int(v) + # debugEcho " q: 0b", BiggestInt(q).toBin(64), ", r: 0b", BiggestInt(r).toBin(64), " | q: ", int(q), ", r: ", int(r) + + doAssert (BaseType(f) and 1) == 1, (reportLoop(); "f must be odd)") + doAssert bool(u*f0 + v*g0 == f.lshl(i)), (reportLoop(); "Applying the transition matrix to (f₀, g₀) returns current (f, g)") + doAssert bool(q*f0 + r*g0 == g.lshl(i)), (reportLoop(); "Applying the transition matrix to (f₀, g₀) returns current (f, g)") + + let fi = f + + # Conditional masks for (hdelta < 0) and g odd + let c1 = hdelta.isNegMask() + let c2 = g.isOddMask() + # x, y, z, conditional negated complement of f, u, v + let x = (f xor c1) - c1 + let y = (u xor c1) - c1 + let z = (v xor c1) - c1 + # conditional addition g, q, r + g.cadd(x, c2) + q.cadd(y, c2) + r.cadd(z, c2) + # c3 = (hdelta < 0) and g odd + let c3 = c2 and c1 + # hdelta = -hdelta-2 or hdelta-1 + hdelta = (hdelta xor c3) - SignedSecretWord(1) + # Conditionally rollback + f.cadd(g, c3) + u.cadd(q, c3) + v.cadd(r, c3) + # Shifts + g = g.lshr(1) + u = u.lshl(1) + v = v.lshl(1) + + L = L + (((fi and f) xor f.lshr(1)) and SignedSecretWord(2)) + L = L + (L.isOdd() xor v.isNeg()) + L = L and SignedSecretWord(3) + + t.u = u + t.v = v + t.q = q + t.r = r + debug: + doAssert bool(u*f0 + v*g0 == f.lshl(k)), "Applying the final matrix to (f₀, g₀) gives the final (f, g)" + doAssert bool(q*f0 + r*g0 == g.lshl(k)), "Applying the final matrix to (f₀, g₀) gives the final (f, g)" + doAssert checkDeterminant(t, u, v, q, r, k, numIters) + + return (hdelta, L) + +func legendreImpl[N, E]( + a: var LimbsUnsaturated[N, E], + M: LimbsUnsaturated[N, E], + k, bits: static int): SecretWord = + ## Legendre symbol / Quadratic Residuosity Test + ## using Bernstein-Yang algorithm + + # hdelta = delta-1/2, delta starts at 1/2 for the half-delta variant + var hdelta = SignedSecretWord(0) + var f{.noInit.}, g{.noInit.}: LimbsUnsaturated[N, E] + + # g < f for partitioning / iteration count formula + f = M + g = a + const partition = partitionDivsteps(bits, k) + const UnsatBitWidth = WordBitWidth - a.Excess + + var # Track and accumulate Legendre symbol transitions + accL = SignedSecretWord(0) + L = SignedSecretWord(0) + + for i in 0 ..< partition.numChunks: + var t{.noInit.}: TransitionMatrix + let numIters = partition.chunkSize + int(i < partition.cutoff) + # Compute transition matrix and next hdelta + when f.words.len > 1: + (hdelta, L) = t.batchedDivstepsSymbol( + hdelta, + # the symbol computation needs to see the extra 2 next bits. + f[0] or f[1].lshl(UnsatBitWidth), + g[0] or g[1].lshl(UnsatBitWidth), + numIters, k) + else: + (hdelta, L) = t.batchedDivstepsSymbol(hdelta, f[0], g[0], numIters, k) + # [u v] [f] + # [q r]/2ᵏ.[g] + t.matVecMul_shr_k(f, g, k) + accL = (accL + L) and SignedSecretWord(3) + accL = (accL + ((accL.isOdd() xor f.isNeg()))) and SignedSecretWord(3) + + accL = (accL + accL.isOdd()) and SignedSecretWord(3) + accL = SignedSecretWord(1)-accL + accL.csetZero(f.isZeroMask()) # f = gcd = 1 as M is prime or f = 0 if a = 0 + return SecretWord(accL) + +func legendre*(a, M: Limbs, bits: static int): SecretWord = + ## Compute the Legendre symbol + ## + ## (a/p)ₗ ≡ a^((p-1)/2) ≡ 1 (mod p), iff a is a square + ## ≡ -1 (mod p), iff a is quadratic non-residue + ## ≡ 0 (mod p), iff a is 0 + const Excess = 2 + const k = WordBitwidth - Excess + const NumUnsatWords = (bits + k - 1) div k + + # Convert values to unsaturated repr + var m2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess] + m2.fromPackedRepr(M) + + var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess] + a2.fromPackedRepr(a) + + legendreImpl(a2, m2, k, bits) + +func legendre*(a: Limbs, M: static Limbs, bits: static int): SecretWord = + ## Compute the Legendre symbol (compile-time modulus overload) + ## + ## (a/p)ₗ ≡ a^((p-1)/2) ≡ 1 (mod p), iff a is a square + ## ≡ -1 (mod p), iff a is quadratic non-residue + ## ≡ 0 (mod p), iff a is 0 + + const Excess = 2 + const k = WordBitwidth - Excess + const NumUnsatWords = (bits + k - 1) div k + + # Convert values to unsaturated repr + const m2 = LimbsUnsaturated[NumUnsatWords, Excess].fromPackedRepr(M) + + var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess] + a2.fromPackedRepr(a) + + legendreImpl(a2, m2, k, bits) diff --git a/constantine/math/arithmetic/limbs_unsaturated.nim b/constantine/math/arithmetic/limbs_unsaturated.nim index 4139a69..aff37e6 100644 --- a/constantine/math/arithmetic/limbs_unsaturated.nim +++ b/constantine/math/arithmetic/limbs_unsaturated.nim @@ -269,6 +269,21 @@ template `==`*(x, y: SignedSecretWord): SecretBool = # SignedSecretWord # ---------------- +func isNeg*(a: SignedSecretWord): SignedSecretWord {.inline.} = + ## Returns 1 if a is negative + ## and 0 otherwise + a.lshr(WordBitWidth-1) + +func isOdd*(a: SignedSecretWord): SignedSecretWord {.inline.} = + ## Returns 1 if a is odd + ## and 0 otherwise + a and SignedSecretWord(1) + +func isZeroMask*(a: SignedSecretWord): SignedSecretWord {.inline.} = + ## Produce the -1 mask if a is negative + ## and 0 otherwise + not SignedSecretWord(a.SecretWord().isZero()) + func isNegMask*(a: SignedSecretWord): SignedSecretWord {.inline.} = ## Produce the -1 mask if a is negative ## and 0 otherwise @@ -279,6 +294,12 @@ func isOddMask*(a: SignedSecretWord): SignedSecretWord {.inline.} = ## and 0 otherwise -(a and SignedSecretWord(1)) +func csetZero*(a: var SignedSecretWord, mask: SignedSecretWord) {.inline.} = + ## Conditionally set `a` to 0 + ## mask must be 0 (0x00000...0000) (kept as is) + ## or -1 (0xFFFF...FFFF) (zeroed) + a = a and mask + func cneg*( a: SignedSecretWord, mask: SignedSecretWord): SignedSecretWord {.inline.} = @@ -308,6 +329,20 @@ func csub*( # UnsaturatedLimbs # ---------------- +func isZeroMask*(a: LimbsUnsaturated): SignedSecretWord {.inline.} = + ## Produce the -1 mask if a is zero + ## and 0 otherwise + var accum = SignedSecretWord(0) + for i in 0 ..< a.words.len: + accum = accum or a.words[i] + + return accum.isZeroMask() + +func isNeg*(a: LimbsUnsaturated): SignedSecretWord {.inline.} = + ## Returns 1 if a is negative + ## and 0 otherwise + a[a.words.len-1].lshr(WordBitWidth - a.Excess + 1) + func isNegMask*(a: LimbsUnsaturated): SignedSecretWord {.inline.} = ## Produce the -1 mask if a is negative ## and 0 otherwise diff --git a/constantine/math/extension_fields/square_root_fp2.nim b/constantine/math/extension_fields/square_root_fp2.nim index 3e5b4b2..b721376 100644 --- a/constantine/math/extension_fields/square_root_fp2.nim +++ b/constantine/math/extension_fields/square_root_fp2.nim @@ -54,7 +54,7 @@ func sqrt_rotate_extension*( ## if there is one, update out_sqrt with it and return true ## return false otherwise, out_sqrt is undefined in this case ## - ## This avoids expensive trial "isSquare" checks (450+ field multiplications) + ## This avoids expensive trial "isSquare" checks ## This requires the sqrt of sqrt of the quadratic non-residue ## to be in Fp2 var coeff{.noInit.}, cand2{.noInit.}, t{.noInit.}: Fp2