Add optimized squaring (~15% speedup) (#18)
* Add optimized squaring (~15% speedup) * avoid repetitions in tests
This commit is contained in:
parent
4ff0e3d90b
commit
6423be0dfb
|
@ -29,52 +29,84 @@ proc test(flags, path: string) =
|
||||||
### tasks
|
### tasks
|
||||||
task test, "Run all tests":
|
task test, "Run all tests":
|
||||||
# -d:testingCurves is configured in a *.nim.cfg for convenience
|
# -d:testingCurves is configured in a *.nim.cfg for convenience
|
||||||
|
|
||||||
|
# Primitives
|
||||||
test "", "tests/test_primitives.nim"
|
test "", "tests/test_primitives.nim"
|
||||||
|
|
||||||
|
# Big ints
|
||||||
test "", "tests/test_io_bigints.nim"
|
test "", "tests/test_io_bigints.nim"
|
||||||
test "", "tests/test_bigints.nim"
|
test "", "tests/test_bigints.nim"
|
||||||
test "", "tests/test_bigints_multimod.nim"
|
test "", "tests/test_bigints_multimod.nim"
|
||||||
|
|
||||||
|
test "", "tests/test_bigints_vs_gmp.nim"
|
||||||
|
|
||||||
|
# Field
|
||||||
test "", "tests/test_io_fields"
|
test "", "tests/test_io_fields"
|
||||||
test "", "tests/test_finite_fields.nim"
|
test "", "tests/test_finite_fields.nim"
|
||||||
|
test "", "tests/test_finite_fields_mulsquare.nim"
|
||||||
test "", "tests/test_finite_fields_powinv.nim"
|
test "", "tests/test_finite_fields_powinv.nim"
|
||||||
|
|
||||||
test "", "tests/test_bigints_vs_gmp.nim"
|
|
||||||
test "", "tests/test_finite_fields_vs_gmp.nim"
|
test "", "tests/test_finite_fields_vs_gmp.nim"
|
||||||
|
|
||||||
|
# 𝔽p2
|
||||||
|
test "", "tests/test_fp2.nim"
|
||||||
|
|
||||||
if sizeof(int) == 8: # 32-bit tests
|
if sizeof(int) == 8: # 32-bit tests
|
||||||
|
# Primitives
|
||||||
test "-d:Constantine32", "tests/test_primitives.nim"
|
test "-d:Constantine32", "tests/test_primitives.nim"
|
||||||
|
|
||||||
|
# Big ints
|
||||||
test "-d:Constantine32", "tests/test_io_bigints.nim"
|
test "-d:Constantine32", "tests/test_io_bigints.nim"
|
||||||
test "-d:Constantine32", "tests/test_bigints.nim"
|
test "-d:Constantine32", "tests/test_bigints.nim"
|
||||||
test "-d:Constantine32", "tests/test_bigints_multimod.nim"
|
test "-d:Constantine32", "tests/test_bigints_multimod.nim"
|
||||||
|
|
||||||
|
test "-d:Constantine32", "tests/test_bigints_vs_gmp.nim"
|
||||||
|
|
||||||
|
# Field
|
||||||
test "-d:Constantine32", "tests/test_io_fields"
|
test "-d:Constantine32", "tests/test_io_fields"
|
||||||
test "-d:Constantine32", "tests/test_finite_fields.nim"
|
test "-d:Constantine32", "tests/test_finite_fields.nim"
|
||||||
|
test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim"
|
||||||
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
|
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
|
||||||
|
|
||||||
test "-d:Constantine32", "tests/test_bigints_vs_gmp.nim"
|
|
||||||
test "-d:Constantine32", "tests/test_finite_fields_vs_gmp.nim"
|
test "-d:Constantine32", "tests/test_finite_fields_vs_gmp.nim"
|
||||||
|
|
||||||
|
# 𝔽p2
|
||||||
|
test "", "tests/test_fp2.nim"
|
||||||
|
|
||||||
task test_no_gmp, "Run tests that don't require GMP":
|
task test_no_gmp, "Run tests that don't require GMP":
|
||||||
# -d:testingCurves is configured in a *.nim.cfg for convenience
|
# -d:testingCurves is configured in a *.nim.cfg for convenience
|
||||||
|
|
||||||
|
# Primitives
|
||||||
test "", "tests/test_primitives.nim"
|
test "", "tests/test_primitives.nim"
|
||||||
|
|
||||||
|
# Big ints
|
||||||
test "", "tests/test_io_bigints.nim"
|
test "", "tests/test_io_bigints.nim"
|
||||||
test "", "tests/test_bigints.nim"
|
test "", "tests/test_bigints.nim"
|
||||||
test "", "tests/test_bigints_multimod.nim"
|
test "", "tests/test_bigints_multimod.nim"
|
||||||
|
|
||||||
|
# Field
|
||||||
test "", "tests/test_io_fields"
|
test "", "tests/test_io_fields"
|
||||||
test "", "tests/test_finite_fields.nim"
|
test "", "tests/test_finite_fields.nim"
|
||||||
|
test "", "tests/test_finite_fields_mulsquare.nim"
|
||||||
test "", "tests/test_finite_fields_powinv.nim"
|
test "", "tests/test_finite_fields_powinv.nim"
|
||||||
|
|
||||||
|
# 𝔽p2
|
||||||
|
test "", "tests/test_fp2.nim"
|
||||||
|
|
||||||
if sizeof(int) == 8: # 32-bit tests
|
if sizeof(int) == 8: # 32-bit tests
|
||||||
|
# Primitives
|
||||||
test "-d:Constantine32", "tests/test_primitives.nim"
|
test "-d:Constantine32", "tests/test_primitives.nim"
|
||||||
|
|
||||||
|
# Big ints
|
||||||
test "-d:Constantine32", "tests/test_io_bigints.nim"
|
test "-d:Constantine32", "tests/test_io_bigints.nim"
|
||||||
test "-d:Constantine32", "tests/test_bigints.nim"
|
test "-d:Constantine32", "tests/test_bigints.nim"
|
||||||
test "-d:Constantine32", "tests/test_bigints_multimod.nim"
|
test "-d:Constantine32", "tests/test_bigints_multimod.nim"
|
||||||
|
|
||||||
|
# Field
|
||||||
test "-d:Constantine32", "tests/test_io_fields"
|
test "-d:Constantine32", "tests/test_io_fields"
|
||||||
test "-d:Constantine32", "tests/test_finite_fields.nim"
|
test "-d:Constantine32", "tests/test_finite_fields.nim"
|
||||||
|
test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim"
|
||||||
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
|
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
|
||||||
|
|
||||||
|
# 𝔽p2
|
||||||
|
test "", "tests/test_fp2.nim"
|
||||||
|
|
|
@ -148,7 +148,7 @@ func prod*(r: var Fp, a, b: Fp) =
|
||||||
|
|
||||||
func square*(r: var Fp, a: Fp) =
|
func square*(r: var Fp, a: Fp) =
|
||||||
## Squaring modulo p
|
## Squaring modulo p
|
||||||
r.mres.montySquare(a.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontyMul())
|
r.mres.montySquare(a.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontySquare())
|
||||||
|
|
||||||
func neg*(r: var Fp, a: Fp) =
|
func neg*(r: var Fp, a: Fp) =
|
||||||
## Negate modulo p
|
## Negate modulo p
|
||||||
|
|
|
@ -86,7 +86,7 @@ macro staticFor(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped)
|
||||||
# the code generated is already big enough for curve with different
|
# the code generated is already big enough for curve with different
|
||||||
# limb sizes, we want to use the same codepath when limb lengths are compatible.
|
# limb sizes, we want to use the same codepath when limb lengths are compatible.
|
||||||
|
|
||||||
func montyMul_CIOS_nocarry_unrolled(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
|
func montyMul_CIOS_nocarry(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
|
||||||
## Montgomery Multiplication using Coarsely Integrated Operand Scanning (CIOS)
|
## Montgomery Multiplication using Coarsely Integrated Operand Scanning (CIOS)
|
||||||
## and no-carry optimization.
|
## and no-carry optimization.
|
||||||
## This requires the most significant word of the Modulus
|
## This requires the most significant word of the Modulus
|
||||||
|
@ -137,23 +137,21 @@ func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
|
||||||
var tNp1: Carry
|
var tNp1: Carry
|
||||||
|
|
||||||
staticFor i, 0, N:
|
staticFor i, 0, N:
|
||||||
var C = Zero
|
var A = Zero
|
||||||
|
|
||||||
# Multiplication
|
# Multiplication
|
||||||
staticFor j, 0, N:
|
staticFor j, 0, N:
|
||||||
# (C, t[j]) <- a[j] * b[i] + t[j] + C
|
# (A, t[j]) <- a[j] * b[i] + t[j] + A
|
||||||
muladd2(C, t[j], a[j], b[i], t[j], C)
|
muladd2(A, t[j], a[j], b[i], t[j], A)
|
||||||
addC(tNp1, tN, tN, C, Carry(0))
|
addC(tNp1, tN, tN, A, Carry(0))
|
||||||
|
|
||||||
# Reduction
|
# Reduction
|
||||||
# m <- (t[0] * m0ninv) mod 2^w
|
# m <- (t[0] * m0ninv) mod 2^w
|
||||||
# (C, _) <- m * M[0] + t[0]
|
# (C, _) <- m * M[0] + t[0]
|
||||||
var lo: Word
|
var C, lo = Zero
|
||||||
C = Zero
|
|
||||||
let m = t[0] * Word(m0ninv)
|
let m = t[0] * Word(m0ninv)
|
||||||
muladd1(C, lo, m, M[0], t[0])
|
muladd1(C, lo, m, M[0], t[0])
|
||||||
staticFor j, 1, N:
|
staticFor j, 1, N:
|
||||||
# (C, t[j]) <- a[j] * b[i] + t[j] + C
|
# (C, t[j-1]) <- m*M[j] + t[j] + C
|
||||||
muladd2(C, t[j-1], m, M[j], t[j], C)
|
muladd2(C, t[j-1], m, M[j], t[j], C)
|
||||||
|
|
||||||
# (C,t[N-1]) <- t[N] + C
|
# (C,t[N-1]) <- t[N] + C
|
||||||
|
@ -168,6 +166,98 @@ func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
|
||||||
discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w
|
discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w
|
||||||
r = t
|
r = t
|
||||||
|
|
||||||
|
func montySquare_CIOS_nocarry(r: var Limbs, a, M: Limbs, m0ninv: BaseType) =
|
||||||
|
## Montgomery Squaring using Coarsely Integrated Operand Scanning (CIOS)
|
||||||
|
## and no-carry optimization.
|
||||||
|
## This requires the most significant word of the Modulus
|
||||||
|
## M[^1] < high(Word) shr 2 (i.e. less than 0b00111...1111)
|
||||||
|
## https://hackmd.io/@zkteam/modular_multiplication
|
||||||
|
|
||||||
|
# We want all the computation to be kept in registers
|
||||||
|
# hence we use a temporary `t`, hoping that the compiler does it.
|
||||||
|
var t: typeof(M) # zero-init
|
||||||
|
const N = t.len
|
||||||
|
staticFor i, 0, N:
|
||||||
|
# Squaring
|
||||||
|
var
|
||||||
|
A1: Carry
|
||||||
|
A0: Word
|
||||||
|
# (A0, t[i]) <- a[i] * a[i] + t[i]
|
||||||
|
muladd1(A0, t[i], a[i], a[i], t[i])
|
||||||
|
staticFor j, i+1, N:
|
||||||
|
# (A1, A0, t[j]) <- 2*a[j]*a[i] + t[j] + (A1, A0)
|
||||||
|
# 2*a[j]*a[i] can spill 1-bit on a 3rd word
|
||||||
|
mulDoubleAdd2(A1, A0, t[j], a[j], a[i], t[j], A1, A0)
|
||||||
|
|
||||||
|
# Reduction
|
||||||
|
# m <- (t[0] * m0ninv) mod 2^w
|
||||||
|
# (C, _) <- m * M[0] + t[0]
|
||||||
|
let m = t[0] * Word(m0ninv)
|
||||||
|
var C, lo: Word
|
||||||
|
muladd1(C, lo, m, M[0], t[0])
|
||||||
|
staticFor j, 1, N:
|
||||||
|
# (C, t[j-1]) <- m*M[j] + t[j] + C
|
||||||
|
muladd2(C, t[j-1], m, M[j], t[j], C)
|
||||||
|
|
||||||
|
t[N-1] = C + A0
|
||||||
|
|
||||||
|
discard t.csub(M, not(t < M))
|
||||||
|
r = t
|
||||||
|
|
||||||
|
func montySquare_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) =
|
||||||
|
## Montgomery Squaring using Coarsely Integrated Operand Scanning (CIOS)
|
||||||
|
##
|
||||||
|
## Architectural Support for Long Integer Modulo Arithmetic on Risc-Based Smart Cards
|
||||||
|
## Johann Großschädl, 2003
|
||||||
|
## https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=95950BAC26A728114431C0C7B425E022?doi=10.1.1.115.3276&rep=rep1&type=pdf
|
||||||
|
##
|
||||||
|
## Analyzing and Comparing Montgomery Multiplication Algorithms
|
||||||
|
## Koc, Acar, Kaliski, 1996
|
||||||
|
## https://www.semanticscholar.org/paper/Analyzing-and-comparing-Montgomery-multiplication-Ko%C3%A7-Acar/5e3941ff482ec3ee41dc53c3298f0be085c69483
|
||||||
|
|
||||||
|
# We want all the computation to be kept in registers
|
||||||
|
# hence we use a temporary `t`, hoping that the compiler does it.
|
||||||
|
var t: typeof(M) # zero-init
|
||||||
|
const N = t.len
|
||||||
|
# Extra words to handle up to 2 carries t[N] and t[N+1]
|
||||||
|
var tNp1: Word
|
||||||
|
var tN: Word
|
||||||
|
|
||||||
|
staticFor i, 0, N:
|
||||||
|
# Squaring
|
||||||
|
var
|
||||||
|
A1: Carry
|
||||||
|
A0: Word
|
||||||
|
# (A0, t[i]) <- a[i] * a[i] + t[i]
|
||||||
|
muladd1(A0, t[i], a[i], a[i], t[i])
|
||||||
|
staticFor j, i+1, N:
|
||||||
|
# (A1, A0, t[j]) <- 2*a[j]*a[i] + t[j] + (A1, A0)
|
||||||
|
# 2*a[j]*a[i] can spill 1-bit on a 3rd word
|
||||||
|
mulDoubleAdd2(A1, A0, t[j], a[j], a[i], t[j], A1, A0)
|
||||||
|
|
||||||
|
var carryS: Carry
|
||||||
|
addC(carryS, tN, tN, A0, Carry(0))
|
||||||
|
addC(carryS, tNp1, Word(A1), Zero, carryS)
|
||||||
|
|
||||||
|
# Reduction
|
||||||
|
# m <- (t[0] * m0ninv) mod 2^w
|
||||||
|
# (C, _) <- m * M[0] + t[0]
|
||||||
|
var C, lo: Word
|
||||||
|
let m = t[0] * Word(m0ninv)
|
||||||
|
muladd1(C, lo, m, M[0], t[0])
|
||||||
|
staticFor j, 1, N:
|
||||||
|
# (C, t[j-1]) <- m*M[j] + t[j] + C
|
||||||
|
muladd2(C, t[j-1], m, M[j], t[j], C)
|
||||||
|
|
||||||
|
# (C,t[N-1]) <- t[N] + C
|
||||||
|
# (_, t[N]) <- t[N+1] + C
|
||||||
|
var carryR: Carry
|
||||||
|
addC(carryR, t[N-1], tN, C, Carry(0))
|
||||||
|
addC(carryR, tN, Word(tNp1), Zero, carryR)
|
||||||
|
|
||||||
|
discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w
|
||||||
|
r = t
|
||||||
|
|
||||||
# Exported API
|
# Exported API
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
|
@ -206,15 +296,19 @@ func montyMul*(
|
||||||
# of Montgomery-friendly m0ninv if the compiler deems it interesting,
|
# of Montgomery-friendly m0ninv if the compiler deems it interesting,
|
||||||
# or we use `when m0ninv == 1` and enforce the inlining.
|
# or we use `when m0ninv == 1` and enforce the inlining.
|
||||||
when canUseNoCarryMontyMul:
|
when canUseNoCarryMontyMul:
|
||||||
montyMul_CIOS_nocarry_unrolled(r, a, b, M, m0ninv)
|
montyMul_CIOS_nocarry(r, a, b, M, m0ninv)
|
||||||
else:
|
else:
|
||||||
montyMul_CIOS(r, a, b, M, m0ninv)
|
montyMul_CIOS(r, a, b, M, m0ninv)
|
||||||
|
|
||||||
func montySquare*(r: var Limbs, a, M: Limbs,
|
func montySquare*(r: var Limbs, a, M: Limbs,
|
||||||
m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} =
|
m0ninv: static BaseType, canUseNoCarryMontySquare: static bool) {.inline.} =
|
||||||
## Compute r <- a^2 (mod M) in the Montgomery domain
|
## Compute r <- a^2 (mod M) in the Montgomery domain
|
||||||
## `negInvModWord` = -1/M (mod Word). Our words are 2^31 or 2^63
|
## `m0ninv` = -1/M (mod Word). Our words are 2^31 or 2^63
|
||||||
montyMul(r, a, a, M, m0ninv, canUseNoCarryMontyMul)
|
|
||||||
|
when canUseNoCarryMontySquare:
|
||||||
|
montySquare_CIOS_nocarry(r, a, M, m0ninv)
|
||||||
|
else:
|
||||||
|
montySquare_CIOS(r, a, M, m0ninv)
|
||||||
|
|
||||||
func redc*(r: var Limbs, a, one, M: Limbs,
|
func redc*(r: var Limbs, a, one, M: Limbs,
|
||||||
m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} =
|
m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} =
|
||||||
|
|
|
@ -128,6 +128,14 @@ func useNoCarryMontyMul*(M: BigInt): bool =
|
||||||
# https://github.com/nim-lang/Nim/issues/9679
|
# https://github.com/nim-lang/Nim/issues/9679
|
||||||
BaseType(M.limbs[^1]) < high(BaseType) shr 1
|
BaseType(M.limbs[^1]) < high(BaseType) shr 1
|
||||||
|
|
||||||
|
func useNoCarryMontySquare*(M: BigInt): bool =
|
||||||
|
## Returns if the modulus is compatible
|
||||||
|
## with the no-carry Montgomery Squaring
|
||||||
|
## from https://hackmd.io/@zkteam/modular_multiplication
|
||||||
|
# Indirection needed because static object are buggy
|
||||||
|
# https://github.com/nim-lang/Nim/issues/9679
|
||||||
|
BaseType(M.limbs[^1]) < high(BaseType) shr 2
|
||||||
|
|
||||||
func negInvModWord*(M: BigInt): BaseType =
|
func negInvModWord*(M: BigInt): BaseType =
|
||||||
## Returns the Montgomery domain magic constant for the input modulus:
|
## Returns the Montgomery domain magic constant for the input modulus:
|
||||||
##
|
##
|
||||||
|
|
|
@ -52,6 +52,9 @@ when not defined(testingCurves):
|
||||||
bitsize: 381
|
bitsize: 381
|
||||||
modulus: "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"
|
modulus: "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"
|
||||||
# Equation: y^2 = x^3 + 4
|
# Equation: y^2 = x^3 + 4
|
||||||
|
curve P224: # NIST P-224
|
||||||
|
bitsize: 224
|
||||||
|
modulus: "0xffffffff_ffffffff_ffffffff_ffffffff_00000000_00000000_00000001"
|
||||||
curve P256: # secp256r1 / NIST P-256
|
curve P256: # secp256r1 / NIST P-256
|
||||||
bitsize: 256
|
bitsize: 256
|
||||||
modulus: "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
|
modulus: "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
|
||||||
|
@ -70,6 +73,9 @@ else:
|
||||||
curve Mersenne127:
|
curve Mersenne127:
|
||||||
bitsize: 127
|
bitsize: 127
|
||||||
modulus: "0x7fffffffffffffffffffffffffffffff" # 2^127 - 1
|
modulus: "0x7fffffffffffffffffffffffffffffff" # 2^127 - 1
|
||||||
|
curve P224: # NIST P-224
|
||||||
|
bitsize: 224
|
||||||
|
modulus: "0xffffffff_ffffffff_ffffffff_ffffffff_00000000_00000000_00000001"
|
||||||
curve P256: # secp256r1 / NIST P-256
|
curve P256: # secp256r1 / NIST P-256
|
||||||
bitsize: 256
|
bitsize: 256
|
||||||
modulus: "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
|
modulus: "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
|
||||||
|
@ -120,6 +126,17 @@ macro genMontyMagics(T: typed): untyped =
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# const MyCurve_CanUseNoCarryMontySquare = useNoCarryMontySquare(MyCurve_Modulus)
|
||||||
|
result.add newConstStmt(
|
||||||
|
ident($curve & "_CanUseNoCarryMontySquare"), newCall(
|
||||||
|
bindSym"useNoCarryMontySquare",
|
||||||
|
nnkDotExpr.newTree(
|
||||||
|
bindSym($curve & "_Modulus"),
|
||||||
|
ident"mres"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# const MyCurve_R2modP = r2mod(MyCurve_Modulus)
|
# const MyCurve_R2modP = r2mod(MyCurve_Modulus)
|
||||||
result.add newConstStmt(
|
result.add newConstStmt(
|
||||||
ident($curve & "_R2modP"), newCall(
|
ident($curve & "_R2modP"), newCall(
|
||||||
|
@ -170,6 +187,11 @@ macro canUseNoCarryMontyMul*(C: static Curve): untyped =
|
||||||
## Montgomery multiplication that avoids many carries
|
## Montgomery multiplication that avoids many carries
|
||||||
result = bindSym($C & "_CanUseNoCarryMontyMul")
|
result = bindSym($C & "_CanUseNoCarryMontyMul")
|
||||||
|
|
||||||
|
macro canUseNoCarryMontySquare*(C: static Curve): untyped =
|
||||||
|
## Returns true if the Modulus is compatible with a fast
|
||||||
|
## Montgomery squaring that avoids many carries
|
||||||
|
result = bindSym($C & "_CanUseNoCarryMontySquare")
|
||||||
|
|
||||||
macro getR2modP*(C: static Curve): untyped =
|
macro getR2modP*(C: static Curve): untyped =
|
||||||
## Get the Montgomery "R^2 mod P" constant associated to a curve field modulus
|
## Get the Montgomery "R^2 mod P" constant associated to a curve field modulus
|
||||||
result = bindSym($C & "_R2modP")
|
result = bindSym($C & "_R2modP")
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#
|
#
|
||||||
# ############################################################
|
# ############################################################
|
||||||
|
|
||||||
import ./constant_time_types
|
import ./constant_time_types, ./addcarry_subborrow
|
||||||
|
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
|
@ -37,6 +37,16 @@ func unsafeDiv2n1n*(q, r: var Ct[uint32], n_hi, n_lo, d: Ct[uint32]) {.inline.}=
|
||||||
q = (Ct[uint32])(dividend div divisor)
|
q = (Ct[uint32])(dividend div divisor)
|
||||||
r = (Ct[uint32])(dividend mod divisor)
|
r = (Ct[uint32])(dividend mod divisor)
|
||||||
|
|
||||||
|
func mul*(hi, lo: var Ct[uint32], a, b: Ct[uint32]) {.inline.} =
|
||||||
|
## Extended precision multiplication
|
||||||
|
## (hi, lo) <- a*b
|
||||||
|
##
|
||||||
|
## This is constant-time on most hardware
|
||||||
|
## See: https://www.bearssl.org/ctmul.html
|
||||||
|
let dblPrec = uint64(a) * uint64(b)
|
||||||
|
lo = (Ct[uint32])(dblPrec)
|
||||||
|
hi = (Ct[uint32])(dblPrec shr 32)
|
||||||
|
|
||||||
func muladd1*(hi, lo: var Ct[uint32], a, b, c: Ct[uint32]) {.inline.} =
|
func muladd1*(hi, lo: var Ct[uint32], a, b, c: Ct[uint32]) {.inline.} =
|
||||||
## Extended precision multiplication + addition
|
## Extended precision multiplication + addition
|
||||||
## (hi, lo) <- a*b + c
|
## (hi, lo) <- a*b + c
|
||||||
|
@ -70,13 +80,49 @@ func muladd2*(hi, lo: var Ct[uint32], a, b, c1, c2: Ct[uint32]) {.inline.}=
|
||||||
|
|
||||||
when sizeof(int) == 8:
|
when sizeof(int) == 8:
|
||||||
when defined(vcc):
|
when defined(vcc):
|
||||||
from ./extended_precision_x86_64_msvc import unsafeDiv2n1n, muladd1, muladd2
|
from ./extended_precision_x86_64_msvc import unsafeDiv2n1n, mul, muladd1, muladd2
|
||||||
elif GCCCompatible:
|
elif GCCCompatible:
|
||||||
# TODO: constant-time div2n1n
|
# TODO: constant-time div2n1n
|
||||||
when X86:
|
when X86:
|
||||||
from ./extended_precision_x86_64_gcc import unsafeDiv2n1n
|
from ./extended_precision_x86_64_gcc import unsafeDiv2n1n
|
||||||
from ./extended_precision_64bit_uint128 import muladd1, muladd2
|
from ./extended_precision_64bit_uint128 import mul, muladd1, muladd2
|
||||||
else:
|
else:
|
||||||
from ./extended_precision_64bit_uint128 import unsafeDiv2n1n, muladd1, muladd2
|
from ./extended_precision_64bit_uint128 import unsafeDiv2n1n, mul, muladd1, muladd2
|
||||||
|
|
||||||
export unsafeDiv2n1n, muladd1, muladd2
|
export unsafeDiv2n1n, muladd1, muladd2
|
||||||
|
|
||||||
|
# ############################################################
|
||||||
|
#
|
||||||
|
# Composite primitives
|
||||||
|
#
|
||||||
|
# ############################################################
|
||||||
|
|
||||||
|
func mulDoubleAdd2*[T: Ct[uint32]|Ct[uint64]](r2: var Carry, r1, r0: var T, a, b, c: T, dHi: Carry, dLo: T) {.inline.} =
|
||||||
|
## (r2, r1, r0) <- 2*a*b + c + (dHi, dLo)
|
||||||
|
## with r = (r2, r1, r0) a triple-word number
|
||||||
|
## and d = (dHi, dLo) a double-word number
|
||||||
|
## r2 and dHi are carries, either 0 or 1
|
||||||
|
|
||||||
|
var carry: Carry
|
||||||
|
|
||||||
|
# (r1, r0) <- a*b
|
||||||
|
# Note: 0xFFFFFFFF_FFFFFFFF² -> (hi: 0xFFFFFFFF_FFFFFFFE, lo: 0x00000000_00000001)
|
||||||
|
mul(r1, r0, a, b)
|
||||||
|
|
||||||
|
# (r2, r1, r0) <- 2*a*b
|
||||||
|
# Then (hi: 0xFFFFFFFF_FFFFFFFE, lo: 0x00000000_00000001) * 2
|
||||||
|
# (carry: 1, hi: 0xFFFFFFFF_FFFFFFFC, lo: 0x00000000_00000002)
|
||||||
|
addC(carry, r0, r0, r0, Carry(0))
|
||||||
|
addC(r2, r1, r1, r1, carry)
|
||||||
|
|
||||||
|
# (r1, r0) <- (r1, r0) + c
|
||||||
|
# Adding any uint64 cannot overflow into r2 for example Adding 2^64-1
|
||||||
|
# (carry: 1, hi: 0xFFFFFFFF_FFFFFFFD, lo: 0x00000000_00000001)
|
||||||
|
addC(carry, r0, r0, c, Carry(0))
|
||||||
|
addC(carry, r1, r1, T(0), carry)
|
||||||
|
|
||||||
|
# (r1, r0) <- (r1, r0) + (dHi, dLo) with dHi a carry (previous limb r2)
|
||||||
|
# (dHi, dLo) is at most (dhi: 1, dlo: 0xFFFFFFFF_FFFFFFFF)
|
||||||
|
# summing into (carry: 1, hi: 0xFFFFFFFF_FFFFFFFD, lo: 0x00000000_00000001)
|
||||||
|
# result at most in (carry: 1, hi: 0xFFFFFFFF_FFFFFFFF, lo: 0x00000000_00000000)
|
||||||
|
addC(carry, r0, r0, dLo, Carry(0))
|
||||||
|
addC(carry, r1, r1, T(dHi), carry)
|
||||||
|
|
|
@ -36,6 +36,24 @@ func unsafeDiv2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}=
|
||||||
{.emit:["*",q, " = (NU64)(", dblPrec," / ", d, ");"].}
|
{.emit:["*",q, " = (NU64)(", dblPrec," / ", d, ");"].}
|
||||||
{.emit:["*",r, " = (NU64)(", dblPrec," % ", d, ");"].}
|
{.emit:["*",r, " = (NU64)(", dblPrec," % ", d, ");"].}
|
||||||
|
|
||||||
|
func mul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} =
|
||||||
|
## Extended precision multiplication
|
||||||
|
## (hi, lo) <- a*b
|
||||||
|
##
|
||||||
|
## This is constant-time on most hardware
|
||||||
|
## See: https://www.bearssl.org/ctmul.html
|
||||||
|
block:
|
||||||
|
var dblPrec {.noInit.}: uint128
|
||||||
|
{.emit:[dblPrec, " = (unsigned __int128)", a," * (unsigned __int128)", b,";"].}
|
||||||
|
|
||||||
|
# Don't forget to dereference the var param in C mode
|
||||||
|
when defined(cpp):
|
||||||
|
{.emit:[hi, " = (NU64)(", dblPrec," >> ", 64'u64, ");"].}
|
||||||
|
{.emit:[lo, " = (NU64)", dblPrec,";"].}
|
||||||
|
else:
|
||||||
|
{.emit:["*",hi, " = (NU64)(", dblPrec," >> ", 64'u64, ");"].}
|
||||||
|
{.emit:["*",lo, " = (NU64)", dblPrec,";"].}
|
||||||
|
|
||||||
func muladd1*(hi, lo: var Ct[uint64], a, b, c: Ct[uint64]) {.inline.} =
|
func muladd1*(hi, lo: var Ct[uint64], a, b, c: Ct[uint64]) {.inline.} =
|
||||||
## Extended precision multiplication + addition
|
## Extended precision multiplication + addition
|
||||||
## (hi, lo) <- a*b + c
|
## (hi, lo) <- a*b + c
|
||||||
|
|
|
@ -43,6 +43,14 @@ func unsafeDiv2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}=
|
||||||
# -> use uint128? Compiler might add unwanted branches
|
# -> use uint128? Compiler might add unwanted branches
|
||||||
q = udiv128(n_hi, n_lo, d, r)
|
q = udiv128(n_hi, n_lo, d, r)
|
||||||
|
|
||||||
|
func mul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} =
|
||||||
|
## Extended precision multiplication
|
||||||
|
## (hi, lo) <- a*b
|
||||||
|
##
|
||||||
|
## This is constant-time on most hardware
|
||||||
|
## See: https://www.bearssl.org/ctmul.html
|
||||||
|
lo = umul128(a, b, hi)
|
||||||
|
|
||||||
func muladd1*(hi, lo: var Ct[uint64], a, b, c: Ct[uint64]) {.inline.} =
|
func muladd1*(hi, lo: var Ct[uint64], a, b, c: Ct[uint64]) {.inline.} =
|
||||||
## Extended precision multiplication + addition
|
## Extended precision multiplication + addition
|
||||||
## (hi, lo) <- a*b + c
|
## (hi, lo) <- a*b + c
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
# Constantine
|
||||||
|
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||||
|
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||||
|
# Licensed and distributed under either of
|
||||||
|
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||||
|
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||||
|
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||||
|
|
||||||
|
import std/unittest, std/times,
|
||||||
|
../constantine/arithmetic/[bigints, finite_fields],
|
||||||
|
../constantine/io/[io_bigints, io_fields],
|
||||||
|
../constantine/config/curves,
|
||||||
|
# Test utilities
|
||||||
|
./prng
|
||||||
|
|
||||||
|
const Iters = 128
|
||||||
|
|
||||||
|
var rng: RngState
|
||||||
|
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
|
||||||
|
rng.seed(seed)
|
||||||
|
echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed
|
||||||
|
|
||||||
|
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
|
||||||
|
|
||||||
|
import ../constantine/config/common
|
||||||
|
|
||||||
|
proc sanity(C: static Curve) =
|
||||||
|
test "Squaring 0,1,2 with "& $C & " [FastSquaring = " & $Fake101.canUseNoCarryMontySquare & "]":
|
||||||
|
block: # 0² mod
|
||||||
|
var n: Fp[C]
|
||||||
|
|
||||||
|
n.fromUint(0'u32)
|
||||||
|
let expected = n
|
||||||
|
|
||||||
|
var r: Fp[C]
|
||||||
|
r.square(n)
|
||||||
|
|
||||||
|
check: bool(r == expected)
|
||||||
|
|
||||||
|
block: # 1² mod
|
||||||
|
var n: Fp[C]
|
||||||
|
|
||||||
|
n.fromUint(1'u32)
|
||||||
|
let expected = n
|
||||||
|
|
||||||
|
var r: Fp[C]
|
||||||
|
r.square(n)
|
||||||
|
|
||||||
|
check: bool(r == expected)
|
||||||
|
|
||||||
|
block: # 2² mod
|
||||||
|
var n, expected: Fp[C]
|
||||||
|
|
||||||
|
n.fromUint(2'u32)
|
||||||
|
expected.fromUint(4'u32)
|
||||||
|
|
||||||
|
var r: Fp[C]
|
||||||
|
r.square(n)
|
||||||
|
|
||||||
|
check: bool(r == expected)
|
||||||
|
|
||||||
|
proc mainSanity() =
|
||||||
|
suite "Modular squaring is consistent with multiplication on special elements":
|
||||||
|
sanity Fake101
|
||||||
|
sanity Mersenne61
|
||||||
|
sanity Mersenne127
|
||||||
|
sanity P224 # P224 uses the fast-path with 64-bit words and the slow path with 32-bit words
|
||||||
|
sanity P256
|
||||||
|
sanity BLS12_381
|
||||||
|
|
||||||
|
mainSanity()
|
||||||
|
|
||||||
|
proc mainSelectCases() =
|
||||||
|
suite "Modular Squaring: selected tricky cases":
|
||||||
|
test "P-256 [FastSquaring = " & $P256.canUseNoCarryMontySquare & "]":
|
||||||
|
block:
|
||||||
|
# Triggered an issue in the (t[N+1], t[N]) = t[N] + (A1, A0)
|
||||||
|
# between the squaring and reduction step, with t[N+1] and A1 being carry bits.
|
||||||
|
var a: Fp[P256]
|
||||||
|
a.fromHex"0xa0da36b4885df98997ee89a22a7ceb64fa431b2ecc87342fc083587da3d6ebc7"
|
||||||
|
|
||||||
|
var r_mul, r_sqr: Fp[P256]
|
||||||
|
|
||||||
|
r_mul.prod(a, a)
|
||||||
|
r_sqr.square(a)
|
||||||
|
|
||||||
|
doAssert bool(r_mul == r_sqr)
|
||||||
|
|
||||||
|
mainSelectCases()
|
||||||
|
|
||||||
|
proc randomCurve(C: static Curve) =
|
||||||
|
let a = rng.random(Fp[C])
|
||||||
|
|
||||||
|
var r_mul, r_sqr: Fp[C]
|
||||||
|
|
||||||
|
r_mul.prod(a, a)
|
||||||
|
r_sqr.square(a)
|
||||||
|
|
||||||
|
doAssert bool(r_mul == r_sqr)
|
||||||
|
|
||||||
|
suite "Random Modular Squaring is consistent with Modular Multiplication":
|
||||||
|
test "Random squaring mod P-224 [FastSquaring = " & $P224.canUseNoCarryMontySquare & "]":
|
||||||
|
for _ in 0 ..< Iters:
|
||||||
|
randomCurve(P224)
|
||||||
|
|
||||||
|
test "Random squaring mod P-256 [FastSquaring = " & $P256.canUseNoCarryMontySquare & "]":
|
||||||
|
for _ in 0 ..< Iters:
|
||||||
|
randomCurve(P256)
|
||||||
|
|
||||||
|
test "Random squaring mod BLS12_381 [FastSquaring = " & $BLS12_381.canUseNoCarryMontySquare & "]":
|
||||||
|
for _ in 0 ..< Iters:
|
||||||
|
randomCurve(BLS12_381)
|
|
@ -0,0 +1 @@
|
||||||
|
-d:testingCurves
|
|
@ -8,11 +8,9 @@
|
||||||
|
|
||||||
import unittest,
|
import unittest,
|
||||||
../constantine/arithmetic/[bigints, finite_fields],
|
../constantine/arithmetic/[bigints, finite_fields],
|
||||||
../constantine/io/io_fields,
|
../constantine/io/[io_bigints, io_fields],
|
||||||
../constantine/config/curves
|
../constantine/config/curves
|
||||||
|
|
||||||
import ../constantine/io/io_bigints
|
|
||||||
|
|
||||||
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
|
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
|
||||||
|
|
||||||
proc main() =
|
proc main() =
|
||||||
|
|
Loading…
Reference in New Issue