diff --git a/benchmarks/bench_blueprint.nim b/benchmarks/bench_blueprint.nim index 1faa36c..e7a0faa 100644 --- a/benchmarks/bench_blueprint.nim +++ b/benchmarks/bench_blueprint.nim @@ -16,7 +16,7 @@ import # Internal ../constantine/config/common, # Helpers - ../helpers/[prng_unsafe, static_for], + ../helpers/prng_unsafe, ./platforms, # Standard library std/[monotimes, times, strformat, strutils, macros] diff --git a/benchmarks/bench_fields_template.nim b/benchmarks/bench_fields_template.nim index d046a33..3ff77e8 100644 --- a/benchmarks/bench_fields_template.nim +++ b/benchmarks/bench_fields_template.nim @@ -14,15 +14,17 @@ import # Internals - ../constantine/config/[curves, common], + ../constantine/config/[common, curves], ../constantine/arithmetic, ../constantine/towers, + ../constantine/curves/zoo_square_roots, # Helpers - ../helpers/[prng_unsafe, static_for], + ../helpers/prng_unsafe, ./bench_blueprint export notes proc separator*() = separator(165) +proc smallSeparator*() = separator(8) proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) = let ns = inNanoseconds((stop-start) div iters) @@ -119,37 +121,30 @@ proc invAddChainBench*(T: typedesc, iters: int) = proc sqrtBench*(T: typedesc, iters: int) = let x = rng.random_unsafe(T) - bench("Square Root + isSquare (constant-time default impl)", T, iters): + + const algoType = block: + when T.C.hasP3mod4_primeModulus(): + "p ≡ 3 (mod 4)" + elif T.C.hasP5mod8_primeModulus(): + "p ≡ 5 (mod 8)" + else: + "Tonelli-Shanks" + const addchain = block: + when T.C.hasSqrtAddchain() or T.C.hasTonelliShanksAddchain(): + "with addition chain" + else: + "without addition chain" + const desc = "Square Root (constant-time " & algoType & " " & addchain & ")" + bench(desc, T, iters): var r = x discard r.sqrt_if_square() -proc sqrtP3mod4Bench*(T: typedesc, iters: int) = +proc sqrtRatioBench*(T: typedesc, iters: int) = var r: T - let x = rng.random_unsafe(T) - bench("SquareRoot (p ≡ 3 (mod 4) 
exponentiation)", T, iters): - r.invsqrt_p3mod4(x) - r *= x - -proc sqrtAddChainBench*(T: typedesc, iters: int) = - var r: T - let x = rng.random_unsafe(T) - bench("SquareRoot (addition chain)", T, iters): - r.invsqrt_addchain(x) - r *= x - -proc sqrtTonelliBench*(T: typedesc, iters: int) = - var r: T - let x = rng.random_unsafe(T) - bench("SquareRoot (constant-time Tonelli-Shanks exponentiation)", T, iters): - r.invsqrt_tonelli_shanks(x, useAddChain = false) - r *= x - -proc sqrtTonelliAddChainBench*(T: typedesc, iters: int) = - var r: T - let x = rng.random_unsafe(T) - bench("SquareRoot (constant-time Tonelli-Shanks addchain)", T, iters): - r.invsqrt_tonelli_shanks(x, useAddChain = true) - r *= x + let u = rng.random_unsafe(T) + let v = rng.random_unsafe(T) + bench("Fused SquareRoot+Division+isSquare sqrt(u/v)", T, iters): + let isSquare = r.sqrt_ratio_if_square(u, v) proc powBench*(T: typedesc, iters: int) = let x = rng.random_unsafe(T) diff --git a/benchmarks/bench_fp.nim b/benchmarks/bench_fp.nim index 8b9d53c..9cae10a 100644 --- a/benchmarks/bench_fp.nim +++ b/benchmarks/bench_fp.nim @@ -14,9 +14,7 @@ import ../constantine/curves/[zoo_inversions, zoo_square_roots], # Helpers ../helpers/static_for, - ./bench_fields_template, - # Standard library - std/strutils + ./bench_fields_template # ############################################################ # @@ -31,9 +29,10 @@ const AvailableCurves = [ # P224, BN254_Nogami, BN254_Snarks, - # Curve25519, - # P256, - # Secp256k1, + Curve25519, + Bandersnatch, + P256, + Secp256k1, BLS12_377, BLS12_381, BW6_761 @@ -50,17 +49,11 @@ proc main() = div2Bench(Fp[curve], Iters) mulBench(Fp[curve], Iters) sqrBench(Fp[curve], Iters) + smallSeparator() invEuclidBench(Fp[curve], ExponentIters) invPowFermatBench(Fp[curve], ExponentIters) - when curve.hasInversionAddchain(): - invAddChainBench(Fp[curve], ExponentIters) - when (BaseType(curve.Mod.limbs[0]) and 3) == 3: - sqrtP3mod4Bench(Fp[curve], ExponentIters) - when 
curve.hasSqrtAddchain(): - sqrtAddChainBench(Fp[curve], ExponentIters) - when curve in {BLS12_377}: - sqrtTonelliBench(Fp[curve], ExponentIters) - sqrtTonelliAddChainBench(Fp[curve], ExponentIters) + sqrtBench(Fp[curve], ExponentIters) + sqrtRatioBench(Fp[curve], ExponentIters) # Exponentiation by a "secret" of size ~the curve order powBench(Fp[curve], ExponentIters) powUnsafeBench(Fp[curve], ExponentIters) diff --git a/benchmarks/bench_fp_double_precision.nim b/benchmarks/bench_fp_double_precision.nim index ba05464..7a1d100 100644 --- a/benchmarks/bench_fp_double_precision.nim +++ b/benchmarks/bench_fp_double_precision.nim @@ -230,8 +230,11 @@ proc main() = diff2xUnreduce(Fp[BLS12_381], iters = 10_000_000) neg2x(Fp[BLS12_381], iters = 10_000_000) separator() + prod2xBench(512, 256, 256, iters = 10_000_000) prod2xBench(768, 384, 384, iters = 10_000_000) + square2xBench(512, 256, iters = 10_000_000) square2xBench(768, 384, iters = 10_000_000) + reduce2x(Fp[BN254_Snarks], iters = 10_000_000) reduce2x(Fp[BLS12_381], iters = 10_000_000) separator() diff --git a/constantine/arithmetic/assembly/limbs_asm_montmul_x86.nim b/constantine/arithmetic/assembly/limbs_asm_montmul_x86.nim index 3a656e2..6c9e3f2 100644 --- a/constantine/arithmetic/assembly/limbs_asm_montmul_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_montmul_x86.nim @@ -34,7 +34,7 @@ static: doAssert UseASM_X86_64 # Montgomery multiplication # ------------------------------------------------------------ # Fallback when no ADX and BMI2 support (MULX, ADCX, ADOX) -macro montMul_CIOS_nocarry_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped = +macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped = ## Generate an optimized Montgomery Multiplication kernel ## using the CIOS method ## @@ -174,9 +174,9 @@ macro montMul_CIOS_nocarry_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_ 
result.add ctx.generate -func montMul_CIOS_nocarry_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = +func montMul_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = ## Constant-time modular multiplication - montMul_CIOS_nocarry_gen(r, a, b, M, m0ninv) + montMul_CIOS_sparebit_gen(r, a, b, M, m0ninv) # Montgomery Squaring # ------------------------------------------------------------ diff --git a/constantine/arithmetic/assembly/limbs_asm_montmul_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_montmul_x86_adx_bmi2.nim index be09592..7b38e69 100644 --- a/constantine/arithmetic/assembly/limbs_asm_montmul_x86_adx_bmi2.nim +++ b/constantine/arithmetic/assembly/limbs_asm_montmul_x86_adx_bmi2.nim @@ -174,7 +174,7 @@ proc partialRedx( ctx.adcx t[N-1], S ctx.adox t[N-1], C -macro montMul_CIOS_nocarry_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped = +macro montMul_CIOS_sparebit_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped = ## Generate an optimized Montgomery Multiplication kernel ## using the CIOS method ## This requires the most significant word of the Modulus @@ -270,9 +270,10 @@ macro montMul_CIOS_nocarry_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM, result.add ctx.generate -func montMul_CIOS_nocarry_asm_adx_bmi2*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = +func montMul_CIOS_sparebit_asm_adx_bmi2*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = ## Constant-time modular multiplication - montMul_CIOS_nocarry_adx_bmi2_gen(r, a, b, M, m0ninv) + ## Requires the prime modulus to have a spare bit in the representation. 
(Hence if using 64-bit words and 4 words, to be at most 255-bit) + montMul_CIOS_sparebit_adx_bmi2_gen(r, a, b, M, m0ninv) # Montgomery Squaring # ------------------------------------------------------------ diff --git a/constantine/arithmetic/finite_fields_inversion.nim b/constantine/arithmetic/finite_fields_inversion.nim index b120c7f..93c7186 100644 --- a/constantine/arithmetic/finite_fields_inversion.nim +++ b/constantine/arithmetic/finite_fields_inversion.nim @@ -37,10 +37,16 @@ func inv*(r: var FF, a: FF) = ## to convert Jacobian and Projective coordinates ## to affine for elliptic curve # For now we don't activate the addition chains - # neither for Secp256k1 nor BN curves - # Performance is slower than GCD - # To be revisited with faster squaring/multiplications - when FF is Fp and FF.C.hasInversionAddchain(): + # Performance is slower than Euclid-based inversion on newer CPUs + # + # - Montgomery multiplication/squaring can skip the final subtraction + # - For generalized Mersenne Prime curves, modular reduction can be made extremely cheap. + # - For BW6-761 the addition chain is over 2x slower than Euclid-based inversion + # due to multiplication being so costly with 12 limbs (grows quadratically) + # while Euclid cost grows linearly.
+ when false and + FF is Fp and FF.C.hasInversionAddchain() and + FF.C notin {BW6_761}: r.inv_addchain(a) else: r.inv_euclid(a) diff --git a/constantine/arithmetic/finite_fields_square_root.nim b/constantine/arithmetic/finite_fields_square_root.nim index 86e6574..c6c955f 100644 --- a/constantine/arithmetic/finite_fields_square_root.nim +++ b/constantine/arithmetic/finite_fields_square_root.nim @@ -10,8 +10,7 @@ import ../primitives, ../config/[common, type_ff, curves], ../curves/zoo_square_roots, - ./bigints, ./finite_fields, - ./finite_fields_inversion + ./bigints, ./finite_fields # ############################################################ # @@ -23,39 +22,14 @@ import {.push raises: [].} {.push inline.} -# Legendre symbol / Euler's Criterion / Kronecker's symbol -# ------------------------------------------------------------ - -func isSquare*(a: Fp): SecretBool = - ## Returns true if ``a`` is a square (quadratic residue) in 𝔽p - ## - ## Assumes that the prime modulus ``p`` is public. - # Implementation: we use exponentiation by (p-1)/2 (Euler's criterion) - # as it can reuse the exponentiation implementation - # Note that we don't care about leaking the bits of p - # as we assume that - var xi {.noInit.} = a # TODO: is noInit necessary? 
see https://github.com/mratsim/constantine/issues/21 - xi.powUnsafeExponent(Fp.getPrimeMinus1div2_BE()) - result = not(xi.isMinusOne()) - # xi can be: - # - 1 if a square - # - 0 if 0 - # - -1 if a quadratic non-residue - debug: - doAssert: bool( - xi.isZero or - xi.isOne or - xi.isMinusOne() - ) - # Specialized routine for p ≡ 3 (mod 4) # ------------------------------------------------------------ -func hasP3mod4_primeModulus(C: static Curve): static bool = +func hasP3mod4_primeModulus*(C: static Curve): static bool = ## Returns true iff p ≡ 3 (mod 4) (BaseType(C.Mod.limbs[0]) and 3) == 3 -func invsqrt_p3mod4*(r: var Fp, a: Fp) = +func invsqrt_p3mod4(r: var Fp, a: Fp) = ## Compute the inverse square root of ``a`` ## ## This requires ``a`` to be a square @@ -75,17 +49,20 @@ func invsqrt_p3mod4*(r: var Fp, a: Fp) = # a^((p-3)/2)) ≡ 1/a (mod p) # a^((p-3)/4)) ≡ 1/√a (mod p) # Requires p ≡ 3 (mod 4) static: doAssert Fp.C.hasP3mod4_primeModulus() - r = a - r.powUnsafeExponent(Fp.getPrimeMinus3div4_BE()) + when FP.C.hasSqrtAddchain(): + r.invsqrt_addchain(a) + else: + r = a + r.powUnsafeExponent(Fp.getPrimeMinus3div4_BE()) # Specialized routine for p ≡ 5 (mod 8) # ------------------------------------------------------------ -func hasP5mod8_primeModulus(C: static Curve): static bool = +func hasP5mod8_primeModulus*(C: static Curve): static bool = ## Returns true iff p ≡ 5 (mod 8) (BaseType(C.Mod.limbs[0]) and 7) == 5 -func invsqrt_p5mod8*(r: var Fp, a: Fp) = +func invsqrt_p5mod8(r: var Fp, a: Fp) = ## Compute the inverse square root of ``a`` ## ## This requires ``a`` to be a square @@ -141,7 +118,10 @@ func invsqrt_p5mod8*(r: var Fp, a: Fp) = # α = (2a)^((p-5)/8) alpha.double(a) beta = alpha - alpha.powUnsafeExponent(Fp.getPrimeMinus5div8_BE()) + when Fp.C.hasSqrtAddchain(): + alpha.invsqrt_addchain_pminus5over8(alpha) + else: + alpha.powUnsafeExponent(Fp.getPrimeMinus5div8_BE()) # Note: if r aliases a, for inverse square root we don't use `a` again @@ -163,13 +143,11 @@ 
func invsqrt_p5mod8*(r: var Fp, a: Fp) = # Tonelli Shanks for any prime # ------------------------------------------------------------ -func precompute_tonelli_shanks( - a_pre_exp: var Fp, - a: Fp, useAddChain: static bool) = - a_pre_exp = a - when useAddChain: +func precompute_tonelli_shanks(a_pre_exp: var Fp, a: Fp) = + when FP.C.hasTonelliShanksAddchain(): a_pre_exp.precompute_tonelli_shanks_addchain(a) else: + a_pre_exp = a a_pre_exp.powUnsafeExponent(Fp.C.tonelliShanks(exponent)) func isSquare_tonelli_shanks( @@ -232,7 +210,7 @@ func invsqrt_tonelli_shanks_pre( t.ccopy(buf, bNotOne) b = t -func invsqrt_tonelli_shanks*(r: var Fp, a: Fp, useAddChain: static bool) = +func invsqrt_tonelli_shanks*(r: var Fp, a: Fp) = ## Compute the inverse square root of ``a`` ## ## This requires ``a`` to be a square @@ -244,13 +222,11 @@ func invsqrt_tonelli_shanks*(r: var Fp, a: Fp, useAddChain: static bool) = ## This procedure returns a deterministic result ## This procedure is constant-time var a_pre_exp{.noInit.}: Fp - a_pre_exp.precompute_tonelli_shanks(a, useAddChain) + a_pre_exp.precompute_tonelli_shanks(a) invsqrt_tonelli_shanks_pre(r, a, a_pre_exp) # Public routines # ------------------------------------------------------------ -# Note: we export the inner sqrt_invsqrt_IMPL -# for benchmarking purposes. {.push inline.} @@ -265,14 +241,12 @@ func invsqrt*[C](r: var Fp[C], a: Fp[C]) = ## i.e. 
both x² == (-x)² ## This procedure returns a deterministic result ## This procedure is constant-time - when C.hasSqrtAddchain(): - r.invsqrt_addchain(a) - elif C.hasP3mod4_primeModulus(): + when C.hasP3mod4_primeModulus(): r.invsqrt_p3mod4(a) elif C.hasP5mod8_primeModulus(): r.invsqrt_p5mod8(a) else: - r.invsqrt_tonelli_shanks(a, useAddChain = C.hasTonelliShanksAddchain()) + r.invsqrt_tonelli_shanks(a) func sqrt*[C](a: var Fp[C]) = ## Compute the square root of ``a`` @@ -340,73 +314,46 @@ func invsqrt_if_square*[C](r: var Fp[C], a: Fp[C]): SecretBool = var sqrt{.noInit.}: Fp[C] result = sqrt_invsqrt_if_square(sqrt, r, a) +# Legendre symbol / Euler's Criterion / Kronecker's symbol +# ------------------------------------------------------------ + +func isSquare*(a: Fp): SecretBool = + ## Returns true if ``a`` is a square (quadratic residue) in 𝔽p + ## + ## Assumes that the prime modulus ``p`` is public. + when false: + # Implementation: we use exponentiation by (p-1)/2 (Euler's criterion) + # as it can reuse the exponentiation implementation + # Note that we don't care about leaking the bits of p + # as we assume that + var xi {.noInit.} = a # TODO: is noInit necessary? 
see https://github.com/mratsim/constantine/issues/21 + xi.powUnsafeExponent(Fp.getPrimeMinus1div2_BE()) + result = not(xi.isMinusOne()) + # xi can be: + # - 1 if a square + # - 0 if 0 + # - -1 if a quadratic non-residue + debug: + doAssert: bool( + xi.isZero or + xi.isOne or + xi.isMinusOne() + ) + else: + # We reuse the optimized addition chains instead of exponentiation by (p-1)/2 + when Fp.C.hasP3mod4_primeModulus() or Fp.C.hasP5mod8_primeModulus(): + var sqrt{.noInit.}, invsqrt{.noInit.}: Fp + return sqrt_invsqrt_if_square(sqrt, invsqrt, a) + else: + var a_pre_exp{.noInit.}: Fp + a_pre_exp.precompute_tonelli_shanks(a) + return isSquare_tonelli_shanks(a, a_pre_exp) + {.pop.} # inline # Fused routines # ------------------------------------------------------------ -func sqrt_ratio_if_square_p5mod8(r: var Fp, u, v: Fp): SecretBool = - ## If u/v is a square, compute √(u/v) - ## if not, the result is undefined - ## - ## Requires p ≡ 5 (mod 8) - ## r must not alias u or v - ## - ## The square root, if it exist is multivalued, - ## i.e. 
both (u/v)² == (-u/v)² - ## This procedure returns a deterministic result - ## This procedure is constant-time - - # References: - # - High-Speed High-Security Signature, Bernstein et al, p15 "Fast decompression", https://ed25519.cr.yp.to/ed25519-20110705.pdf - # - IETF Hash-to-Curve: https://github.com/cfrg/draft-irtf-cfrg-hash-to-curve/blob/9939a07/draft-irtf-cfrg-hash-to-curve.md#optimized-sqrt_ratio-for-q--5-mod-8 - # - Pasta curves divsqrt: https://github.com/zcash/pasta/blob/f0f7068/squareroottab.sage#L139-L193 - # - # p ≡ 5 (mod 8), hence 𝑖 ∈ Fp with 𝑖² ≡ −1 (mod p) - # if α is a square, with β ≡ α^((p+3)/8) (mod p) - # - either β² ≡ α (mod p), hence √α ≡ ± β (mod p) - # - or β² ≡ -α (mod p), hence √α ≡ ± 𝑖β (mod p) - # (see explanation in invsqrt_p5mod8) - # - # In our fused division and sqrt case we have - # β = (u/v)^((p+3)/8) - # = u^((p+3)/8).v^(p−1−(p+3)/8) via Fermat's little theorem - # = u^((p+3)/8).v^((7p−11)/8) - # = u.u^((p-5)/8).v³.v^((7p−35)/8) - # = uv³.u^((p-5)/8).v^(7(p-5)/8) - # = uv³(uv⁷)^((p−5)/8) - # - # We can check if β² ≡ -α (mod p) - # by checking vβ² ≡ -u (mod p), and then multiply by 𝑖 - # and if it's neither u or -u it wasn't a square. - static: doAssert Fp.C.hasP5mod8_primeModulus() - var t {.noInit.}: Fp - t.square(v) - t *= v - - # r = uv³ - r.prod(u, t) - - # t = (uv⁷)^((p−5)/8) - t *= r - t *= v - t.powUnsafeExponent(Fp.getPrimeMinus5div8_BE()) - - # r = β = uv³(uv⁷)^((p−5)/8) - r *= t - - # Check candidate square roots - t.square(r) - t *= v - block: - result = t == u - block: - t.neg() - let isSol = t == u - result = result or isSol - t.prod(r, Fp.C.sqrt_minus_one()) - r.ccopy(t, isSol) - func sqrt_ratio_if_square*(r: var Fp, u, v: Fp): SecretBool {.inline.} = ## If u/v is a square, compute √(u/v) ## if not, the result is undefined @@ -417,12 +364,13 @@ func sqrt_ratio_if_square*(r: var Fp, u, v: Fp): SecretBool {.inline.} = ## i.e. 
both (u/v)² == (-u/v)² ## This procedure returns a deterministic result ## This procedure is constant-time - when Fp.C.hasP5mod8_primeModulus(): - sqrt_ratio_if_square_p5mod8(r, u, v) - else: - # TODO: Fuse inversion and tonelli-shanks and legendre symbol - r.inv(v) - r *= u - result = r.sqrt_if_square() + + # u/v is square iff 𝛘(u/v) = 1 (mod p) + # As 𝛘(a) = 1 or -1, 𝛘(1/v) = 𝛘(v) and so + # 𝛘(u/v) = 𝛘(uv) + var uv{.noInit.}: Fp + uv.prod(u, v) # uv + result = r.invsqrt_if_square(uv) # 1/√uv + r *= u # √u/√v {.pop.} # raises no exceptions diff --git a/constantine/arithmetic/limbs_montgomery.nim b/constantine/arithmetic/limbs_montgomery.nim index b4b5608..d88a4d5 100644 --- a/constantine/arithmetic/limbs_montgomery.nim +++ b/constantine/arithmetic/limbs_montgomery.nim @@ -162,7 +162,7 @@ func montyRedc2x_Comba[N: static int]( # Montgomery Multiplication # ------------------------------------------------------------ -func montyMul_CIOS_nocarry(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = +func montyMul_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = ## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS) ## and no-carry optimization. ## This requires the most significant word of the Modulus @@ -373,11 +373,11 @@ func montyMul*( when UseASM_X86_64 and a.len in {2 ..
6}: # TODO: handle spilling # ADX implies BMI2 if ({.noSideEffect.}: hasAdx()): - montMul_CIOS_nocarry_asm_adx_bmi2(r, a, b, M, m0ninv) + montMul_CIOS_sparebit_asm_adx_bmi2(r, a, b, M, m0ninv) else: - montMul_CIOS_nocarry_asm(r, a, b, M, m0ninv) + montMul_CIOS_sparebit_asm(r, a, b, M, m0ninv) else: - montyMul_CIOS_nocarry(r, a, b, M, m0ninv) + montyMul_CIOS_sparebit(r, a, b, M, m0ninv) else: montyMul_FIPS(r, a, b, M, m0ninv) @@ -393,7 +393,7 @@ func montySquare*[N](r: var Limbs[N], a, M: Limbs[N], # which uses unfused squaring then Montgomery reduction # is slightly slower than fused Montgomery multiplication when spareBits >= 1: - montMul_CIOS_nocarry_asm_adx_bmi2(r, a, a, M, m0ninv) + montMul_CIOS_sparebit_asm_adx_bmi2(r, a, a, M, m0ninv) else: montSquare_CIOS_asm_adx_bmi2(r, a, M, m0ninv, spareBits >= 1) else: diff --git a/constantine/config/type_ff.nim b/constantine/config/type_ff.nim index f351608..50a465d 100644 --- a/constantine/config/type_ff.nim +++ b/constantine/config/type_ff.nim @@ -19,6 +19,8 @@ type ## P being the prime modulus of the Curve C ## Internally, data is stored in Montgomery n-residue form ## with the magic constant chosen for convenient division (a power of 2 depending on P bitsize) + # TODO: pseudo-Mersenne primes like 2²⁵⁵-19 have very fast modular reduction + # and don't need Montgomery representation mres*: matchingBigInt(C) Fr*[C: static Curve] = object diff --git a/constantine/curves/curve25519_inversion.nim b/constantine/curves/curve25519_inversion.nim new file mode 100644 index 0000000..d05a53e --- /dev/null +++ b/constantine/curves/curve25519_inversion.nim @@ -0,0 +1,60 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
+# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + ../config/curves, + ../arithmetic/finite_fields + +# ############################################################ +# +# Specialized inversion for Curve25519 +# +# ############################################################ + +func inv_addchain*(r: var Fp[Curve25519], a: Fp[Curve25519]) = + var + x10 {.noinit.}: Fp[Curve25519] + x1001 {.noinit.}: Fp[Curve25519] + x1011 {.noinit.}: Fp[Curve25519] + + x10 .square(a) # 2 + x1001 .square_repeated(x10, 2) # 8 + x1001 *= a # 9 + x1011 .prod(x10, x1001) # 11 + # 5 operations + + # TODO: we can accumulate in a partially reduced + # doubled-size `r` to avoid the final subtractions + # and only reduce at the end. + # This requires the number of op to be less than log2(p) == 255 + + template t: untyped = x10 + + t.square(x1011) # 22 + r.prod(t, x1001) # 31 = 2⁵-1 + + template u: untyped = x1001 + + t.square_repeated(r, 5) + r *= t # 2¹⁰-1 + t.square_repeated(r, 10) + t *= r # 2²⁰-1 + + u.square_repeated(t, 20) + t *= u # 2⁴⁰-1 + t.square_repeated(10) + t *= r # 2⁵⁰-1 + r.square_repeated(t, 50) + r *= t # 2¹⁰⁰-1 + + u.square_repeated(r, 100) + r *= u # 2²⁰⁰-1 + r.square_repeated(50) + r *= t # 2²⁵⁰-1 + r.square_repeated(5) + r *= x1011 # 2²⁵⁵-21 (note: 11 = 2⁵-21) diff --git a/constantine/curves/curve25519_sqrt.nim b/constantine/curves/curve25519_sqrt.nim index 752907f..322eb4a 100644 --- a/constantine/curves/curve25519_sqrt.nim +++ b/constantine/curves/curve25519_sqrt.nim @@ -7,21 +7,32 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms.
import - ../config/[curves, type_bigint, type_ff], - ../io/[io_bigints, io_fields], + ../config/[curves, type_ff], ../arithmetic/finite_fields -# p ≡ 5 (mod 8), hence 𝑖 ∈ Fp with 𝑖² ≡ −1 (mod p) -# Hence if α is a square -# with β ≡ α^((p+3)/8) (mod p) -# - either β² ≡ α (mod p), hence √α ≡ ±β (mod p) -# - or β² ≡ -α (mod p), hence √α ≡ ±𝑖β (mod p) - -# Sage: -# p = Integer('0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed') -# Fp = GF(p) -# sqrt_minus1 = Fp(-1).sqrt() -# print(Integer(sqrt_minus1).hex()) -const Curve25519_sqrt_minus_one* = Fp[Curve25519].fromHex( - "0x2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0" -) \ No newline at end of file +func invsqrt_addchain_pminus5over8*(r: var Fp[Curve25519], a: Fp[Curve25519]) = + ## Computes r = a^((p-5)/8) = a^(2²⁵²-3), used for inverse square root computation + + var t{.noInit.}, u{.noInit.}, v{.noinit.}: Fp[Curve25519] + u.square(a) # 2 + v.square_repeated(u, 2) # 8 + v *= a # 9 + u *= v # 11 + u.square() # 22 + u *= v # 31 = 2⁵-1 + v.square_repeated(u, 5) # + u *= v # 2¹⁰-1 + v.square_repeated(u, 10) # + v *= u # 2²⁰-1 + t.square_repeated(v, 20) # + v *= t # 2⁴⁰-1 + v.square_repeated(10) # + u *= v # 2⁵⁰-1 + v.square_repeated(u, 50) # + v *= u # 2¹⁰⁰-1 + t.square_repeated(v, 100) # + v *= t # 2²⁰⁰-1 + v.square_repeated(50) # + u *= v # 2²⁵⁰-1 + u.square_repeated(2) # 2²⁵²-4 + r.prod(a, u) # 2²⁵²-3 \ No newline at end of file diff --git a/constantine/curves/zoo_inversions.nim b/constantine/curves/zoo_inversions.nim index 8f1b07a..def85a2 100644 --- a/constantine/curves/zoo_inversions.nim +++ b/constantine/curves/zoo_inversions.nim @@ -7,13 +7,15 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms.
import - ../config/[curves, type_ff], + ../config/curves, ./bls12_377_inversion, ./bls12_381_inversion, ./bn254_nogami_inversion, ./bn254_snarks_inversion, ./bw6_761_inversion, - ./secp256k1_inversion + ./secp256k1_inversion, + ./curve25519_inversion + export bls12_377_inversion, @@ -21,16 +23,13 @@ export bn254_nogami_inversion, bn254_snarks_inversion, bw6_761_inversion, - secp256k1_inversion + secp256k1_inversion, + curve25519_inversion func hasInversionAddchain*(C: static Curve): static bool = - # TODO: For now we don't activate the addition chains - # for Secp256k1 - # Performance is slower than GCD (to investigate) - # For BW6-761 the addition chain is over 2x slower than Euclid-based inversion - # due to multiplication being so costly with 12 limbs (grows quadratically) - # while Euclid costs grows linearly. - when C in {BN254_Nogami, BN254_Snarks, BLS12_377, BLS12_381}: + ## Is an inversion addition chain implemented for the curve. + ## Note: the addition chain might be slower than Euclid-based inversion. 
+ when C in {BN254_Nogami, BN254_Snarks, BLS12_377, BLS12_381, BW6_761, Curve25519, Secp256k1}: true else: false diff --git a/constantine/curves/zoo_square_roots.nim b/constantine/curves/zoo_square_roots.nim index 06b080c..9683dbb 100644 --- a/constantine/curves/zoo_square_roots.nim +++ b/constantine/curves/zoo_square_roots.nim @@ -27,7 +27,7 @@ export curve25519_sqrt func hasSqrtAddchain*(C: static Curve): static bool = - when C in {BLS12_381, BN254_Nogami, BN254_Snarks, BW6_761}: + when C in {BLS12_381, BN254_Nogami, BN254_Snarks, BW6_761, Curve25519}: true else: false @@ -43,7 +43,3 @@ func hasTonelliShanksAddchain*(C: static Curve): static bool = true else: false - -macro sqrt_minus_one*(C: static Curve): untyped = - ## Return 𝑖 ∈ Fp with 𝑖² ≡ −1 (mod p) - return bindSym($C & "_sqrt_minus_one") \ No newline at end of file diff --git a/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim b/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim index 3a1c7d7..5417839 100644 --- a/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim +++ b/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim @@ -71,7 +71,7 @@ func sqrx2x_complex_asm_adx_bmi2*( t0.diff(a.c0, a.c1) r.c0.mul_asm_adx_bmi2_impl(t0, t1) -func sqrx_complex_asm_adx_bmi2*( +func sqrx_complex_sparebit_asm_adx_bmi2*( r: var array[2, Fp], a: array[2, Fp] ) = @@ -85,10 +85,10 @@ func sqrx_complex_asm_adx_bmi2*( var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) v0.diff(a.c0, a.c1) v1.sum(a.c0, a.c1) - r.c1.mres.limbs.montMul_CIOS_nocarry_asm_adx_bmi2(a.c0.mres.limbs, a.c1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) + r.c1.mres.limbs.montMul_CIOS_sparebit_asm_adx_bmi2(a.c0.mres.limbs, a.c1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) # aliasing: a unneeded now r.c1.double() - r.c0.mres.limbs.montMul_CIOS_nocarry_asm_adx_bmi2(v0.mres.limbs, v1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) + 
r.c0.mres.limbs.montMul_CIOS_sparebit_asm_adx_bmi2(v0.mres.limbs, v1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) # 𝔽p2 multiplication # ------------------------------------------------------------ diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 9ff921e..4f6a247 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -1230,9 +1230,9 @@ func square2x*(r: var QuadraticExt2x, a: QuadraticExt) = func square*(r: var QuadraticExt, a: QuadraticExt) = when r.fromComplexExtension(): when true: - when UseASM_X86_64 and a.c0.mres.limbs.len <= 6: + when UseASM_X86_64 and a.c0.mres.limbs.len <= 6 and r.typeof.has1extraBit(): if ({.noSideEffect.}: hasAdx()): - r.coords.sqrx_complex_asm_adx_bmi2(a.coords) + r.coords.sqrx_complex_sparebit_asm_adx_bmi2(a.coords) else: r.square_complex(a) else: diff --git a/tests/t_finite_fields_double_precision.nim b/tests/t_finite_fields_double_precision.nim index 4ad5790..5f42254 100644 --- a/tests/t_finite_fields_double_precision.nim +++ b/tests/t_finite_fields_double_precision.nim @@ -157,6 +157,22 @@ suite "Field Addition/Substraction/Negation via double-precision field elements" for _ in 0 ..< Iters: addsubneg_random_long01Seq(BLS12_381) + test "With Curve25519 field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(Curve25519) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(Curve25519) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(Curve25519) + + test "With Bandersnatch field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(Bandersnatch) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(Bandersnatch) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(Bandersnatch) + test "Negate 0 returns 0 (unique Montgomery repr)": var a: FpDbl[BN254_Snarks] var r {.noInit.}: FpDbl[BN254_Snarks] @@ -197,6 +213,22 @@ suite "Field 
Multiplication via double-precision field elements is consistent wi for _ in 0 ..< Iters: mul_random_long01Seq(BLS12_381) + test "With Curve25519 field modulus": + for _ in 0 ..< Iters: + mul_random_unsafe(Curve25519) + for _ in 0 ..< Iters: + mul_randomHighHammingWeight(Curve25519) + for _ in 0 ..< Iters: + mul_random_long01Seq(Curve25519) + + test "With Bandersnatch field modulus": + for _ in 0 ..< Iters: + mul_random_unsafe(Bandersnatch) + for _ in 0 ..< Iters: + mul_randomHighHammingWeight(Bandersnatch) + for _ in 0 ..< Iters: + mul_random_long01Seq(Bandersnatch) + suite "Field Squaring via double-precision field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]": test "With P-224 field modulus": for _ in 0 ..< Iters: @@ -229,3 +261,19 @@ suite "Field Squaring via double-precision field elements is consistent with sin sqr_randomHighHammingWeight(BLS12_381) for _ in 0 ..< Iters: sqr_random_long01Seq(BLS12_381) + + test "With Curve25519 field modulus": + for _ in 0 ..< Iters: + sqr_random_unsafe(Curve25519) + for _ in 0 ..< Iters: + sqr_randomHighHammingWeight(Curve25519) + for _ in 0 ..< Iters: + sqr_random_long01Seq(Curve25519) + + test "With Bandersnatch field modulus": + for _ in 0 ..< Iters: + sqr_random_unsafe(Bandersnatch) + for _ in 0 ..< Iters: + sqr_randomHighHammingWeight(Bandersnatch) + for _ in 0 ..< Iters: + sqr_random_long01Seq(Bandersnatch) \ No newline at end of file diff --git a/tests/t_finite_fields_mulsquare.nim b/tests/t_finite_fields_mulsquare.nim index 6edd470..d1fa93e 100644 --- a/tests/t_finite_fields_mulsquare.nim +++ b/tests/t_finite_fields_mulsquare.nim @@ -83,7 +83,10 @@ proc mainSanity() = sanity Mersenne127 sanity P224 # P224 uses the fast-path with 64-bit words and the slow path with 32-bit words sanity P256 + sanity Secp256k1 sanity BLS12_381 + sanity Curve25519 + sanity Bandersnatch mainSanity() @@ -152,6 +155,14 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" for 
_ in 0 ..< Iters: random_long01Seq(P256) + test "Random squaring mod Secp256k1 [FastSquaring = " & $(Fp[Secp256k1].getSpareBits() >= 2) & "]": + for _ in 0 ..< Iters: + randomCurve(Secp256k1) + for _ in 0 ..< Iters: + randomHighHammingWeight(Secp256k1) + for _ in 0 ..< Iters: + random_long01Seq(Secp256k1) + test "Random squaring mod BLS12_381 [FastSquaring = " & $(Fp[BLS12_381].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(BLS12_381) @@ -160,6 +171,22 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" for _ in 0 ..< Iters: random_long01Seq(BLS12_381) + test "Random squaring mod Curve25519 [FastSquaring = " & $(Fp[Curve25519].getSpareBits() >= 2) & "]": + for _ in 0 ..< Iters: + randomCurve(Curve25519) + for _ in 0 ..< Iters: + randomHighHammingWeight(Curve25519) + for _ in 0 ..< Iters: + random_long01Seq(Curve25519) + + test "Random squaring mod Bandersnatch [FastSquaring = " & $(Fp[Bandersnatch].getSpareBits() >= 2) & "]": + for _ in 0 ..< Iters: + randomCurve(Bandersnatch) + for _ in 0 ..< Iters: + randomHighHammingWeight(Bandersnatch) + for _ in 0 ..< Iters: + random_long01Seq(Bandersnatch) + suite "Modular squaring - bugs highlighted by property-based testing": test "a² == (-a)² on for Fp[2^127 - 1] - #61": var a{.noInit.}: Fp[Mersenne127] diff --git a/tests/t_finite_fields_powinv.nim b/tests/t_finite_fields_powinv.nim index 87b2c9b..b3386c8 100644 --- a/tests/t_finite_fields_powinv.nim +++ b/tests/t_finite_fields_powinv.nim @@ -198,6 +198,7 @@ proc main() = testRandomDiv2 Secp256k1 testRandomDiv2 BLS12_377 testRandomDiv2 BLS12_381 + testRandomDiv2 Bandersnatch suite "Modular inversion over prime fields" & " [" & $WordBitwidth & "-bit mode]": test "Specific tests on Fp[BLS12_381]": @@ -285,6 +286,7 @@ proc main() = testRandomInv Secp256k1 testRandomInv BLS12_377 testRandomInv BLS12_381 + testRandomInv Bandersnatch main() diff --git a/tests/t_finite_fields_vs_gmp.nim b/tests/t_finite_fields_vs_gmp.nim index 
30b4ab9..42f9e13 100644 --- a/tests/t_finite_fields_vs_gmp.nim +++ b/tests/t_finite_fields_vs_gmp.nim @@ -24,7 +24,7 @@ var RNG {.compileTime.} = initRand(1234) const AvailableCurves = [ P224, BN254_Nogami, BN254_Snarks, - P256, Secp256k1, + P256, Secp256k1, Curve25519, Bandersnatch, BLS12_377, BLS12_381, BW6_761 ]