From 6423be0dfba5553b5b0e530c3f904077b6847281 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Tue, 17 Mar 2020 22:04:37 +0100 Subject: [PATCH] Add optimized squaring (~15% speedup) (#18) * Add optimized squaring (~15% speedup) * avoid repetitions in tests --- constantine.nimble | 36 +++++- constantine/arithmetic/finite_fields.nim | 2 +- constantine/arithmetic/montgomery.nim | 118 ++++++++++++++++-- constantine/arithmetic/precomputed.nim | 8 ++ constantine/config/curves.nim | 22 ++++ constantine/primitives/extended_precision.nim | 56 ++++++++- .../extended_precision_64bit_uint128.nim | 18 +++ .../extended_precision_x86_64_msvc.nim | 8 ++ tests/test_finite_fields_mulsquare.nim | 112 +++++++++++++++++ tests/test_finite_fields_mulsquare.nim.cfg | 1 + tests/test_finite_fields_powinv.nim | 4 +- 11 files changed, 362 insertions(+), 23 deletions(-) create mode 100644 tests/test_finite_fields_mulsquare.nim create mode 100644 tests/test_finite_fields_mulsquare.nim.cfg diff --git a/constantine.nimble b/constantine.nimble index b9638b7..15d2e2a 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -29,52 +29,84 @@ proc test(flags, path: string) = ### tasks task test, "Run all tests": # -d:testingCurves is configured in a *.nim.cfg for convenience + + # Primitives test "", "tests/test_primitives.nim" + # Big ints test "", "tests/test_io_bigints.nim" test "", "tests/test_bigints.nim" test "", "tests/test_bigints_multimod.nim" + test "", "tests/test_bigints_vs_gmp.nim" + + # Field test "", "tests/test_io_fields" test "", "tests/test_finite_fields.nim" + test "", "tests/test_finite_fields_mulsquare.nim" test "", "tests/test_finite_fields_powinv.nim" - test "", "tests/test_bigints_vs_gmp.nim" test "", "tests/test_finite_fields_vs_gmp.nim" + # 𝔽p2 + test "", "tests/test_fp2.nim" + if sizeof(int) == 8: # 32-bit tests + # Primitives test "-d:Constantine32", "tests/test_primitives.nim" + # Big ints test "-d:Constantine32", "tests/test_io_bigints.nim" test 
"-d:Constantine32", "tests/test_bigints.nim" test "-d:Constantine32", "tests/test_bigints_multimod.nim" + test "-d:Constantine32", "tests/test_bigints_vs_gmp.nim" + + # Field test "-d:Constantine32", "tests/test_io_fields" test "-d:Constantine32", "tests/test_finite_fields.nim" + test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim" test "-d:Constantine32", "tests/test_finite_fields_powinv.nim" - test "-d:Constantine32", "tests/test_bigints_vs_gmp.nim" test "-d:Constantine32", "tests/test_finite_fields_vs_gmp.nim" + # 𝔽p2 + test "-d:Constantine32", "tests/test_fp2.nim" + task test_no_gmp, "Run tests that don't require GMP": # -d:testingCurves is configured in a *.nim.cfg for convenience + + # Primitives test "", "tests/test_primitives.nim" + # Big ints test "", "tests/test_io_bigints.nim" test "", "tests/test_bigints.nim" test "", "tests/test_bigints_multimod.nim" + # Field test "", "tests/test_io_fields" test "", "tests/test_finite_fields.nim" + test "", "tests/test_finite_fields_mulsquare.nim" test "", "tests/test_finite_fields_powinv.nim" + # 𝔽p2 + test "", "tests/test_fp2.nim" + if sizeof(int) == 8: # 32-bit tests + # Primitives test "-d:Constantine32", "tests/test_primitives.nim" + # Big ints test "-d:Constantine32", "tests/test_io_bigints.nim" test "-d:Constantine32", "tests/test_bigints.nim" test "-d:Constantine32", "tests/test_bigints_multimod.nim" + # Field test "-d:Constantine32", "tests/test_io_fields" test "-d:Constantine32", "tests/test_finite_fields.nim" + test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim" test "-d:Constantine32", "tests/test_finite_fields_powinv.nim" + + # 𝔽p2 + test "-d:Constantine32", "tests/test_fp2.nim" diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index fe7ceda..93c9f93 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -148,7 +148,7 @@ func prod*(r: var Fp, a, b: Fp) = func square*(r: var Fp, a: Fp) = ## Squaring modulo p - 
r.mres.montySquare(a.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontyMul()) + r.mres.montySquare(a.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontySquare()) func neg*(r: var Fp, a: Fp) = ## Negate modulo p diff --git a/constantine/arithmetic/montgomery.nim b/constantine/arithmetic/montgomery.nim index 4bf4100..964f975 100644 --- a/constantine/arithmetic/montgomery.nim +++ b/constantine/arithmetic/montgomery.nim @@ -86,7 +86,7 @@ macro staticFor(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped) # the code generated is already big enough for curve with different # limb sizes, we want to use the same codepath when limbs lenght are compatible. -func montyMul_CIOS_nocarry_unrolled(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = +func montyMul_CIOS_nocarry(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = ## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS) ## and no-carry optimization. ## This requires the most significant word of the Modulus @@ -137,23 +137,21 @@ func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = var tNp1: Carry staticFor i, 0, N: - var C = Zero - + var A = Zero # Multiplication staticFor j, 0, N: - # (C, t[j]) <- a[j] * b[i] + t[j] + C - muladd2(C, t[j], a[j], b[i], t[j], C) - addC(tNp1, tN, tN, C, Carry(0)) + # (A, t[j]) <- a[j] * b[i] + t[j] + A + muladd2(A, t[j], a[j], b[i], t[j], A) + addC(tNp1, tN, tN, A, Carry(0)) # Reduction # m <- (t[0] * m0ninv) mod 2^w # (C, _) <- m * M[0] + t[0] - var lo: Word - C = Zero + var C, lo = Zero let m = t[0] * Word(m0ninv) muladd1(C, lo, m, M[0], t[0]) staticFor j, 1, N: - # (C, t[j]) <- a[j] * b[i] + t[j] + C + # (C, t[j-1]) <- m*M[j] + t[j] + C muladd2(C, t[j-1], m, M[j], t[j], C) # (C,t[N-1]) <- t[N] + C @@ -168,6 +166,98 @@ func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w r = t +func 
montySquare_CIOS_nocarry(r: var Limbs, a, M: Limbs, m0ninv: BaseType) = + ## Montgomery Squaring using Coarsely Integrated Operand Scanning (CIOS) + ## and no-carry optimization. + ## This requires the most significant word of the Modulus + ## M[^1] < high(Word) shr 2 (i.e. less than 0b00111...1111) + ## https://hackmd.io/@zkteam/modular_multiplication + + # We want all the computation to be kept in registers + # hence we use a temporary `t`, hoping that the compiler does it. + var t: typeof(M) # zero-init + const N = t.len + staticFor i, 0, N: + # Squaring + var + A1: Carry + A0: Word + # (A0, t[i]) <- a[i] * a[i] + t[i] + muladd1(A0, t[i], a[i], a[i], t[i]) + staticFor j, i+1, N: + # (A1, A0, t[j]) <- 2*a[j]*a[i] + t[j] + (A1, A0) + # 2*a[j]*a[i] can spill 1-bit on a 3rd word + mulDoubleAdd2(A1, A0, t[j], a[j], a[i], t[j], A1, A0) + + # Reduction + # m <- (t[0] * m0ninv) mod 2^w + # (C, _) <- m * M[0] + t[0] + let m = t[0] * Word(m0ninv) + var C, lo: Word + muladd1(C, lo, m, M[0], t[0]) + staticFor j, 1, N: + # (C, t[j-1]) <- m*M[j] + t[j] + C + muladd2(C, t[j-1], m, M[j], t[j], C) + + t[N-1] = C + A0 + + discard t.csub(M, not(t < M)) + r = t + +func montySquare_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) = + ## Montgomery Squaring using Coarsely Integrated Operand Scanning (CIOS) + ## + ## Architectural Support for Long Integer Modulo Arithmetic on Risc-Based Smart Cards + ## Johann Großschädl, 2003 + ## https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=95950BAC26A728114431C0C7B425E022?doi=10.1.1.115.3276&rep=rep1&type=pdf + ## + ## Analyzing and Comparing Montgomery Multiplication Algorithms + ## Koc, Acar, Kaliski, 1996 + ## https://www.semanticscholar.org/paper/Analyzing-and-comparing-Montgomery-multiplication-Ko%C3%A7-Acar/5e3941ff482ec3ee41dc53c3298f0be085c69483 + + # We want all the computation to be kept in registers + # hence we use a temporary `t`, hoping that the compiler does it. 
+ var t: typeof(M) # zero-init + const N = t.len + # Extra words to handle up to 2 carries t[N] and t[N+1] + var tNp1: Word + var tN: Word + + staticFor i, 0, N: + # Squaring + var + A1: Carry + A0: Word + # (A0, t[i]) <- a[i] * a[i] + t[i] + muladd1(A0, t[i], a[i], a[i], t[i]) + staticFor j, i+1, N: + # (A1, A0, t[j]) <- 2*a[j]*a[i] + t[j] + (A1, A0) + # 2*a[j]*a[i] can spill 1-bit on a 3rd word + mulDoubleAdd2(A1, A0, t[j], a[j], a[i], t[j], A1, A0) + + var carryS: Carry + addC(carryS, tN, tN, A0, Carry(0)) + addC(carryS, tNp1, Word(A1), Zero, carryS) + + # Reduction + # m <- (t[0] * m0ninv) mod 2^w + # (C, _) <- m * M[0] + t[0] + var C, lo: Word + let m = t[0] * Word(m0ninv) + muladd1(C, lo, m, M[0], t[0]) + staticFor j, 1, N: + # (C, t[j-1]) <- m*M[j] + t[j] + C + muladd2(C, t[j-1], m, M[j], t[j], C) + + # (C,t[N-1]) <- t[N] + C + # (_, t[N]) <- t[N+1] + C + var carryR: Carry + addC(carryR, t[N-1], tN, C, Carry(0)) + addC(carryR, tN, Word(tNp1), Zero, carryR) + + discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w + r = t + # Exported API # ------------------------------------------------------------ @@ -206,15 +296,19 @@ func montyMul*( # of Montgomery-friendly m0ninv if the compiler deems it interesting, # or we use `when m0ninv == 1` and enforce the inlining. when canUseNoCarryMontyMul: - montyMul_CIOS_nocarry_unrolled(r, a, b, M, m0ninv) + montyMul_CIOS_nocarry(r, a, b, M, m0ninv) else: montyMul_CIOS(r, a, b, M, m0ninv) func montySquare*(r: var Limbs, a, M: Limbs, - m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} = + m0ninv: static BaseType, canUseNoCarryMontySquare: static bool) {.inline.} = ## Compute r <- a^2 (mod M) in the Montgomery domain ## `negInvModWord` = -1/M (mod Word). 
Our words are 2^31 or 2^63 - montyMul(r, a, a, M, m0ninv, canUseNoCarryMontyMul) + + when canUseNoCarryMontySquare: + montySquare_CIOS_nocarry(r, a, M, m0ninv) + else: + montySquare_CIOS(r, a, M, m0ninv) func redc*(r: var Limbs, a, one, M: Limbs, m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} = diff --git a/constantine/arithmetic/precomputed.nim b/constantine/arithmetic/precomputed.nim index 150db9d..ec93e07 100644 --- a/constantine/arithmetic/precomputed.nim +++ b/constantine/arithmetic/precomputed.nim @@ -128,6 +128,14 @@ func useNoCarryMontyMul*(M: BigInt): bool = # https://github.com/nim-lang/Nim/issues/9679 BaseType(M.limbs[^1]) < high(BaseType) shr 1 +func useNoCarryMontySquare*(M: BigInt): bool = + ## Returns if the modulus is compatible + ## with the no-carry Montgomery Squaring + ## from https://hackmd.io/@zkteam/modular_multiplication + # Indirection needed because static object are buggy + # https://github.com/nim-lang/Nim/issues/9679 + BaseType(M.limbs[^1]) < high(BaseType) shr 2 + func negInvModWord*(M: BigInt): BaseType = ## Returns the Montgomery domain magic constant for the input modulus: ## diff --git a/constantine/config/curves.nim b/constantine/config/curves.nim index a627454..9615043 100644 --- a/constantine/config/curves.nim +++ b/constantine/config/curves.nim @@ -52,6 +52,9 @@ when not defined(testingCurves): bitsize: 381 modulus: "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab" # Equation: y^2 = x^3 + 4 + curve P224: # NIST P-224 + bitsize: 224 + modulus: "0xffffffff_ffffffff_ffffffff_ffffffff_00000000_00000000_00000001" curve P256: # secp256r1 / NIST P-256 bitsize: 256 modulus: "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff" @@ -70,6 +73,9 @@ else: curve Mersenne127: bitsize: 127 modulus: "0x7fffffffffffffffffffffffffffffff" # 2^127 - 1 + curve P224: # NIST P-224 + bitsize: 224 + modulus: 
"0xffffffff_ffffffff_ffffffff_ffffffff_00000000_00000000_00000001" curve P256: # secp256r1 / NIST P-256 bitsize: 256 modulus: "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff" @@ -120,6 +126,17 @@ macro genMontyMagics(T: typed): untyped = ) ) + # const MyCurve_CanUseNoCarryMontySquare = useNoCarryMontySquare(MyCurve_Modulus) + result.add newConstStmt( + ident($curve & "_CanUseNoCarryMontySquare"), newCall( + bindSym"useNoCarryMontySquare", + nnkDotExpr.newTree( + bindSym($curve & "_Modulus"), + ident"mres" + ) + ) + ) + # const MyCurve_R2modP = r2mod(MyCurve_Modulus) result.add newConstStmt( ident($curve & "_R2modP"), newCall( @@ -170,6 +187,11 @@ macro canUseNoCarryMontyMul*(C: static Curve): untyped = ## Montgomery multiplication that avoids many carries result = bindSym($C & "_CanUseNoCarryMontyMul") +macro canUseNoCarryMontySquare*(C: static Curve): untyped = + ## Returns true if the Modulus is compatible with a fast + ## Montgomery squaring that avoids many carries + result = bindSym($C & "_CanUseNoCarryMontySquare") + macro getR2modP*(C: static Curve): untyped = ## Get the Montgomery "R^2 mod P" constant associated to a curve field modulus result = bindSym($C & "_R2modP") diff --git a/constantine/primitives/extended_precision.nim b/constantine/primitives/extended_precision.nim index 87fe11a..22a91f6 100644 --- a/constantine/primitives/extended_precision.nim +++ b/constantine/primitives/extended_precision.nim @@ -12,7 +12,7 @@ # # ############################################################ -import ./constant_time_types +import ./constant_time_types, ./addcarry_subborrow # ############################################################ # @@ -37,6 +37,16 @@ func unsafeDiv2n1n*(q, r: var Ct[uint32], n_hi, n_lo, d: Ct[uint32]) {.inline.}= q = (Ct[uint32])(dividend div divisor) r = (Ct[uint32])(dividend mod divisor) +func mul*(hi, lo: var Ct[uint32], a, b: Ct[uint32]) {.inline.} = + ## Extended precision multiplication + ## (hi, lo) <- a*b + ## + 
## This is constant-time on most hardware + ## See: https://www.bearssl.org/ctmul.html + let dblPrec = uint64(a) * uint64(b) + lo = (Ct[uint32])(dblPrec) + hi = (Ct[uint32])(dblPrec shr 32) + func muladd1*(hi, lo: var Ct[uint32], a, b, c: Ct[uint32]) {.inline.} = ## Extended precision multiplication + addition ## (hi, lo) <- a*b + c @@ -70,13 +80,49 @@ func muladd2*(hi, lo: var Ct[uint32], a, b, c1, c2: Ct[uint32]) {.inline.}= when sizeof(int) == 8: when defined(vcc): - from ./extended_precision_x86_64_msvc import unsafeDiv2n1n, muladd1, muladd2 + from ./extended_precision_x86_64_msvc import unsafeDiv2n1n, mul, muladd1, muladd2 elif GCCCompatible: # TODO: constant-time div2n1n when X86: from ./extended_precision_x86_64_gcc import unsafeDiv2n1n - from ./extended_precision_64bit_uint128 import muladd1, muladd2 + from ./extended_precision_64bit_uint128 import mul, muladd1, muladd2 else: - from ./extended_precision_64bit_uint128 import unsafeDiv2n1n, muladd1, muladd2 - + from ./extended_precision_64bit_uint128 import unsafeDiv2n1n, mul, muladd1, muladd2 export unsafeDiv2n1n, muladd1, muladd2 + +# ############################################################ +# +# Composite primitives +# +# ############################################################ + +func mulDoubleAdd2*[T: Ct[uint32]|Ct[uint64]](r2: var Carry, r1, r0: var T, a, b, c: T, dHi: Carry, dLo: T) {.inline.} = + ## (r2, r1, r0) <- 2*a*b + c + (dHi, dLo) + ## with r = (r2, r1, r0) a triple-word number + ## and d = (dHi, dLo) a double-word number + ## r2 and dHi are carries, either 0 or 1 + + var carry: Carry + + # (r1, r0) <- a*b + # Note: 0xFFFFFFFF_FFFFFFFF² -> (hi: 0xFFFFFFFF_FFFFFFFE, lo: 0x00000000_00000001) + mul(r1, r0, a, b) + + # (r2, r1, r0) <- 2*a*b + # Then (hi: 0xFFFFFFFF_FFFFFFFE, lo: 0x00000000_00000001) * 2 + # (carry: 1, hi: 0xFFFFFFFF_FFFFFFFC, lo: 0x00000000_00000002) + addC(carry, r0, r0, r0, Carry(0)) + addC(r2, r1, r1, r1, carry) + + # (r1, r0) <- (r1, r0) + c + # Adding any uint64 cannot 
overflow into r2 in that worst case; for example, adding 2^64-1 gives + # (carry: 1, hi: 0xFFFFFFFF_FFFFFFFD, lo: 0x00000000_00000001) + addC(carry, r0, r0, c, Carry(0)) + addC(carry, r1, r1, T(0), carry) + + # (r1, r0) <- (r1, r0) + (dHi, dLo) with dHi a carry (previous limb r2) + # (dHi, dLo) is at most (dhi: 1, dlo: 0xFFFFFFFF_FFFFFFFF) + # summing into (carry: 1, hi: 0xFFFFFFFF_FFFFFFFD, lo: 0x00000000_00000001) + # result at most in (carry: 1, hi: 0xFFFFFFFF_FFFFFFFF, lo: 0x00000000_00000000) -- NOTE(review): the carry out of r1 is discarded by the addC pair above and the one below; the bound is only shown for the maximal product, where r2 is already 1. If the doubling left r2 = 0 with r1 = 0xFFFFFFFF_FFFFFFFF (e.g. a = 2^63, b = 2^64-1), adding c and dLo can overflow r1 and that bit is lost. Confirm callers cannot reach this, or fold the carry into r2. + addC(carry, r0, r0, dLo, Carry(0)) + addC(carry, r1, r1, T(dHi), carry) diff --git a/constantine/primitives/extended_precision_64bit_uint128.nim b/constantine/primitives/extended_precision_64bit_uint128.nim index 7e1e70f..599ea95 100644 --- a/constantine/primitives/extended_precision_64bit_uint128.nim +++ b/constantine/primitives/extended_precision_64bit_uint128.nim @@ -36,6 +36,24 @@ func unsafeDiv2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}= {.emit:["*",q, " = (NU64)(", dblPrec," / ", d, ");"].} {.emit:["*",r, " = (NU64)(", dblPrec," % ", d, ");"].} +func mul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} = + ## Extended precision multiplication + ## (hi, lo) <- a*b + ## + ## This is constant-time on most hardware + ## See: https://www.bearssl.org/ctmul.html + block: + var dblPrec {.noInit.}: uint128 + {.emit:[dblPrec, " = (unsigned __int128)", a," * (unsigned __int128)", b,";"].} + + # Don't forget to dereference the var param in C mode + when defined(cpp): + {.emit:[hi, " = (NU64)(", dblPrec," >> ", 64'u64, ");"].} + {.emit:[lo, " = (NU64)", dblPrec,";"].} + else: + {.emit:["*",hi, " = (NU64)(", dblPrec," >> ", 64'u64, ");"].} + {.emit:["*",lo, " = (NU64)", dblPrec,";"].} + func muladd1*(hi, lo: var Ct[uint64], a, b, c: Ct[uint64]) {.inline.} = ## Extended precision multiplication + addition ## (hi, lo) <- a*b + c diff --git a/constantine/primitives/extended_precision_x86_64_msvc.nim b/constantine/primitives/extended_precision_x86_64_msvc.nim index 2be7d34..65fa75b
100644 --- a/constantine/primitives/extended_precision_x86_64_msvc.nim +++ b/constantine/primitives/extended_precision_x86_64_msvc.nim @@ -43,6 +43,14 @@ func unsafeDiv2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}= # -> use uint128? Compiler might add unwanted branches q = udiv128(n_hi, n_lo, d, r) +func mul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} = + ## Extended precision multiplication + ## (hi, lo) <- a*b + ## + ## This is constant-time on most hardware + ## See: https://www.bearssl.org/ctmul.html + lo = umul128(a, b, hi) + func muladd1*(hi, lo: var Ct[uint64], a, b, c: Ct[uint64]) {.inline.} = ## Extended precision multiplication + addition ## (hi, lo) <- a*b + c diff --git a/tests/test_finite_fields_mulsquare.nim b/tests/test_finite_fields_mulsquare.nim new file mode 100644 index 0000000..25066ae --- /dev/null +++ b/tests/test_finite_fields_mulsquare.nim @@ -0,0 +1,112 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. 
+ +import std/unittest, std/times, + ../constantine/arithmetic/[bigints, finite_fields], + ../constantine/io/[io_bigints, io_fields], + ../constantine/config/curves, + # Test utilities + ./prng + +const Iters = 128 + +var rng: RngState +let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32 +rng.seed(seed) +echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed + +static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" + +import ../constantine/config/common + +proc sanity(C: static Curve) = + test "Squaring 0,1,2 with "& $C & " [FastSquaring = " & $C.canUseNoCarryMontySquare & "]": + block: # 0² mod + var n: Fp[C] + + n.fromUint(0'u32) + let expected = n + + var r: Fp[C] + r.square(n) + + check: bool(r == expected) + + block: # 1² mod + var n: Fp[C] + + n.fromUint(1'u32) + let expected = n + + var r: Fp[C] + r.square(n) + + check: bool(r == expected) + + block: # 2² mod + var n, expected: Fp[C] + + n.fromUint(2'u32) + expected.fromUint(4'u32) + + var r: Fp[C] + r.square(n) + + check: bool(r == expected) + +proc mainSanity() = + suite "Modular squaring is consistent with multiplication on special elements": + sanity Fake101 + sanity Mersenne61 + sanity Mersenne127 + sanity P224 # P224 uses the fast-path with 64-bit words and the slow path with 32-bit words + sanity P256 + sanity BLS12_381 + +mainSanity() + +proc mainSelectCases() = + suite "Modular Squaring: selected tricky cases": + test "P-256 [FastSquaring = " & $P256.canUseNoCarryMontySquare & "]": + block: + # Triggered an issue in the (t[N+1], t[N]) = t[N] + (A1, A0) + # between the squaring and reduction step, with t[N+1] and A1 being carry bits. 
+ var a: Fp[P256] + a.fromHex"0xa0da36b4885df98997ee89a22a7ceb64fa431b2ecc87342fc083587da3d6ebc7" + + var r_mul, r_sqr: Fp[P256] + + r_mul.prod(a, a) + r_sqr.square(a) + + doAssert bool(r_mul == r_sqr) + +mainSelectCases() + +proc randomCurve(C: static Curve) = + let a = rng.random(Fp[C]) + + var r_mul, r_sqr: Fp[C] + + r_mul.prod(a, a) + r_sqr.square(a) + + doAssert bool(r_mul == r_sqr) + +suite "Random Modular Squaring is consistent with Modular Multiplication": + test "Random squaring mod P-224 [FastSquaring = " & $P224.canUseNoCarryMontySquare & "]": + for _ in 0 ..< Iters: + randomCurve(P224) + + test "Random squaring mod P-256 [FastSquaring = " & $P256.canUseNoCarryMontySquare & "]": + for _ in 0 ..< Iters: + randomCurve(P256) + + test "Random squaring mod BLS12_381 [FastSquaring = " & $BLS12_381.canUseNoCarryMontySquare & "]": + for _ in 0 ..< Iters: + randomCurve(BLS12_381) diff --git a/tests/test_finite_fields_mulsquare.nim.cfg b/tests/test_finite_fields_mulsquare.nim.cfg new file mode 100644 index 0000000..0922c18 --- /dev/null +++ b/tests/test_finite_fields_mulsquare.nim.cfg @@ -0,0 +1 @@ +-d:testingCurves diff --git a/tests/test_finite_fields_powinv.nim b/tests/test_finite_fields_powinv.nim index 20430f8..ffb712e 100644 --- a/tests/test_finite_fields_powinv.nim +++ b/tests/test_finite_fields_powinv.nim @@ -8,11 +8,9 @@ import unittest, ../constantine/arithmetic/[bigints, finite_fields], - ../constantine/io/io_fields, + ../constantine/io/[io_bigints, io_fields], ../constantine/config/curves -import ../constantine/io/io_bigints - static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" proc main() =