Add optimized squaring (~15% speedup) (#18)

* Add optimized squaring (~15% speedup)

* avoid repetitions in tests
This commit is contained in:
Mamy Ratsimbazafy 2020-03-17 22:04:37 +01:00 committed by GitHub
parent 4ff0e3d90b
commit 6423be0dfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 362 additions and 23 deletions

View File

@ -29,52 +29,84 @@ proc test(flags, path: string) =
### tasks
task test, "Run all tests":
# -d:testingCurves is configured in a *.nim.cfg for convenience
# Primitives
test "", "tests/test_primitives.nim"
# Big ints
test "", "tests/test_io_bigints.nim"
test "", "tests/test_bigints.nim"
test "", "tests/test_bigints_multimod.nim"
test "", "tests/test_bigints_vs_gmp.nim"
# Field
test "", "tests/test_io_fields"
test "", "tests/test_finite_fields.nim"
test "", "tests/test_finite_fields_mulsquare.nim"
test "", "tests/test_finite_fields_powinv.nim"
test "", "tests/test_bigints_vs_gmp.nim"
test "", "tests/test_finite_fields_vs_gmp.nim"
# 𝔽p2
test "", "tests/test_fp2.nim"
if sizeof(int) == 8: # 32-bit tests
# Primitives
test "-d:Constantine32", "tests/test_primitives.nim"
# Big ints
test "-d:Constantine32", "tests/test_io_bigints.nim"
test "-d:Constantine32", "tests/test_bigints.nim"
test "-d:Constantine32", "tests/test_bigints_multimod.nim"
test "-d:Constantine32", "tests/test_bigints_vs_gmp.nim"
# Field
test "-d:Constantine32", "tests/test_io_fields"
test "-d:Constantine32", "tests/test_finite_fields.nim"
test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim"
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
test "-d:Constantine32", "tests/test_bigints_vs_gmp.nim"
test "-d:Constantine32", "tests/test_finite_fields_vs_gmp.nim"
# 𝔽p2
test "", "tests/test_fp2.nim"
task test_no_gmp, "Run tests that don't require GMP":
# -d:testingCurves is configured in a *.nim.cfg for convenience
# Primitives
test "", "tests/test_primitives.nim"
# Big ints
test "", "tests/test_io_bigints.nim"
test "", "tests/test_bigints.nim"
test "", "tests/test_bigints_multimod.nim"
# Field
test "", "tests/test_io_fields"
test "", "tests/test_finite_fields.nim"
test "", "tests/test_finite_fields_mulsquare.nim"
test "", "tests/test_finite_fields_powinv.nim"
# 𝔽p2
test "", "tests/test_fp2.nim"
if sizeof(int) == 8: # 32-bit tests
# Primitives
test "-d:Constantine32", "tests/test_primitives.nim"
# Big ints
test "-d:Constantine32", "tests/test_io_bigints.nim"
test "-d:Constantine32", "tests/test_bigints.nim"
test "-d:Constantine32", "tests/test_bigints_multimod.nim"
# Field
test "-d:Constantine32", "tests/test_io_fields"
test "-d:Constantine32", "tests/test_finite_fields.nim"
test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim"
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
# 𝔽p2
test "", "tests/test_fp2.nim"

View File

@ -148,7 +148,7 @@ func prod*(r: var Fp, a, b: Fp) =
func square*(r: var Fp, a: Fp) =
## Squaring modulo p
r.mres.montySquare(a.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontyMul())
r.mres.montySquare(a.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontySquare())
func neg*(r: var Fp, a: Fp) =
## Negate modulo p

View File

@ -86,7 +86,7 @@ macro staticFor(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped)
# the code generated is already big enough for curve with different
# limb sizes, we want to use the same codepath when limbs lenght are compatible.
func montyMul_CIOS_nocarry_unrolled(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
func montyMul_CIOS_nocarry(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS)
## and no-carry optimization.
## This requires the most significant word of the Modulus
@ -137,23 +137,21 @@ func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
var tNp1: Carry
staticFor i, 0, N:
var C = Zero
var A = Zero
# Multiplication
staticFor j, 0, N:
# (C, t[j]) <- a[j] * b[i] + t[j] + C
muladd2(C, t[j], a[j], b[i], t[j], C)
addC(tNp1, tN, tN, C, Carry(0))
# (A, t[j]) <- a[j] * b[i] + t[j] + A
muladd2(A, t[j], a[j], b[i], t[j], A)
addC(tNp1, tN, tN, A, Carry(0))
# Reduction
# m <- (t[0] * m0ninv) mod 2^w
# (C, _) <- m * M[0] + t[0]
var lo: Word
C = Zero
var C, lo = Zero
let m = t[0] * Word(m0ninv)
muladd1(C, lo, m, M[0], t[0])
staticFor j, 1, N:
# (C, t[j]) <- a[j] * b[i] + t[j] + C
# (C, t[j-1]) <- m*M[j] + t[j] + C
muladd2(C, t[j-1], m, M[j], t[j], C)
# (C,t[N-1]) <- t[N] + C
@ -168,6 +166,98 @@ func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w
r = t
func montySquare_CIOS_nocarry(r: var Limbs, a, M: Limbs, m0ninv: BaseType) =
## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS)
## and no-carry optimization.
## This requires the most significant word of the Modulus
## M[^1] < high(Word) shr 2 (i.e. less than 0b00111...1111)
## https://hackmd.io/@zkteam/modular_multiplication
# We want all the computation to be kept in registers
# hence we use a temporary `t`, hoping that the compiler does it.
var t: typeof(M) # zero-init
const N = t.len
staticFor i, 0, N:
# Squaring
var
A1: Carry
A0: Word
# (A0, t[i]) <- a[i] * a[i] + t[i]
muladd1(A0, t[i], a[i], a[i], t[i])
staticFor j, i+1, N:
# (A1, A0, t[j]) <- 2*a[j]*a[i] + t[j] + (A1, A0)
# 2*a[j]*a[i] can spill 1-bit on a 3rd word
mulDoubleAdd2(A1, A0, t[j], a[j], a[i], t[j], A1, A0)
# Reduction
# m <- (t[0] * m0ninv) mod 2^w
# (C, _) <- m * M[0] + t[0]
let m = t[0] * Word(m0ninv)
var C, lo: Word
muladd1(C, lo, m, M[0], t[0])
staticFor j, 1, N:
# (C, t[j-1]) <- m*M[j] + t[j] + C
muladd2(C, t[j-1], m, M[j], t[j], C)
t[N-1] = C + A0
discard t.csub(M, not(t < M))
r = t
func montySquare_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) =
## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS)
##
## Architectural Support for Long Integer Modulo Arithmetic on Risc-Based Smart Cards
## Johann Großschädl, 2003
## https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=95950BAC26A728114431C0C7B425E022?doi=10.1.1.115.3276&rep=rep1&type=pdf
##
## Analyzing and Comparing Montgomery Multiplication Algorithms
## Koc, Acar, Kaliski, 1996
## https://www.semanticscholar.org/paper/Analyzing-and-comparing-Montgomery-multiplication-Ko%C3%A7-Acar/5e3941ff482ec3ee41dc53c3298f0be085c69483
# We want all the computation to be kept in registers
# hence we use a temporary `t`, hoping that the compiler does it.
var t: typeof(M) # zero-init
const N = t.len
# Extra words to handle up to 2 carries t[N] and t[N+1]
var tNp1: Word
var tN: Word
staticFor i, 0, N:
# Squaring
var
A1: Carry
A0: Word
# (A0, t[i]) <- a[i] * a[i] + t[i]
muladd1(A0, t[i], a[i], a[i], t[i])
staticFor j, i+1, N:
# (A1, A0, t[j]) <- 2*a[j]*a[i] + t[j] + (A1, A0)
# 2*a[j]*a[i] can spill 1-bit on a 3rd word
mulDoubleAdd2(A1, A0, t[j], a[j], a[i], t[j], A1, A0)
var carryS: Carry
addC(carryS, tN, tN, A0, Carry(0))
addC(carryS, tNp1, Word(A1), Zero, carryS)
# Reduction
# m <- (t[0] * m0ninv) mod 2^w
# (C, _) <- m * M[0] + t[0]
var C, lo: Word
let m = t[0] * Word(m0ninv)
muladd1(C, lo, m, M[0], t[0])
staticFor j, 1, N:
# (C, t[j-1]) <- m*M[j] + t[j] + C
muladd2(C, t[j-1], m, M[j], t[j], C)
# (C,t[N-1]) <- t[N] + C
# (_, t[N]) <- t[N+1] + C
var carryR: Carry
addC(carryR, t[N-1], tN, C, Carry(0))
addC(carryR, tN, Word(tNp1), Zero, carryR)
discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w
r = t
# Exported API
# ------------------------------------------------------------
@ -206,15 +296,19 @@ func montyMul*(
# of Montgomery-friendly m0ninv if the compiler deems it interesting,
# or we use `when m0ninv == 1` and enforce the inlining.
when canUseNoCarryMontyMul:
montyMul_CIOS_nocarry_unrolled(r, a, b, M, m0ninv)
montyMul_CIOS_nocarry(r, a, b, M, m0ninv)
else:
montyMul_CIOS(r, a, b, M, m0ninv)
func montySquare*(r: var Limbs, a, M: Limbs,
m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} =
m0ninv: static BaseType, canUseNoCarryMontySquare: static bool) {.inline.} =
## Compute r <- a^2 (mod M) in the Montgomery domain
## `negInvModWord` = -1/M (mod Word). Our words are 2^31 or 2^63
montyMul(r, a, a, M, m0ninv, canUseNoCarryMontyMul)
when canUseNoCarryMontySquare:
montySquare_CIOS_nocarry(r, a, M, m0ninv)
else:
montySquare_CIOS(r, a, M, m0ninv)
func redc*(r: var Limbs, a, one, M: Limbs,
m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} =

View File

@ -128,6 +128,14 @@ func useNoCarryMontyMul*(M: BigInt): bool =
# https://github.com/nim-lang/Nim/issues/9679
BaseType(M.limbs[^1]) < high(BaseType) shr 1
func useNoCarryMontySquare*(M: BigInt): bool =
## Returns if the modulus is compatible
## with the no-carry Montgomery Squaring
## from https://hackmd.io/@zkteam/modular_multiplication
# Indirection needed because static object are buggy
# https://github.com/nim-lang/Nim/issues/9679
BaseType(M.limbs[^1]) < high(BaseType) shr 2
func negInvModWord*(M: BigInt): BaseType =
## Returns the Montgomery domain magic constant for the input modulus:
##

View File

@ -52,6 +52,9 @@ when not defined(testingCurves):
bitsize: 381
modulus: "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"
# Equation: y^2 = x^3 + 4
curve P224: # NIST P-224
bitsize: 224
modulus: "0xffffffff_ffffffff_ffffffff_ffffffff_00000000_00000000_00000001"
curve P256: # secp256r1 / NIST P-256
bitsize: 256
modulus: "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
@ -70,6 +73,9 @@ else:
curve Mersenne127:
bitsize: 127
modulus: "0x7fffffffffffffffffffffffffffffff" # 2^127 - 1
curve P224: # NIST P-224
bitsize: 224
modulus: "0xffffffff_ffffffff_ffffffff_ffffffff_00000000_00000000_00000001"
curve P256: # secp256r1 / NIST P-256
bitsize: 256
modulus: "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
@ -120,6 +126,17 @@ macro genMontyMagics(T: typed): untyped =
)
)
# const MyCurve_CanUseNoCarryMontySquare = useNoCarryMontySquare(MyCurve_Modulus)
result.add newConstStmt(
ident($curve & "_CanUseNoCarryMontySquare"), newCall(
bindSym"useNoCarryMontySquare",
nnkDotExpr.newTree(
bindSym($curve & "_Modulus"),
ident"mres"
)
)
)
# const MyCurve_R2modP = r2mod(MyCurve_Modulus)
result.add newConstStmt(
ident($curve & "_R2modP"), newCall(
@ -170,6 +187,11 @@ macro canUseNoCarryMontyMul*(C: static Curve): untyped =
## Montgomery multiplication that avoids many carries
result = bindSym($C & "_CanUseNoCarryMontyMul")
macro canUseNoCarryMontySquare*(C: static Curve): untyped =
## Returns true if the Modulus is compatible with a fast
## Montgomery squaring that avoids many carries
result = bindSym($C & "_CanUseNoCarryMontySquare")
macro getR2modP*(C: static Curve): untyped =
## Get the Montgomery "R^2 mod P" constant associated to a curve field modulus
result = bindSym($C & "_R2modP")

View File

@ -12,7 +12,7 @@
#
# ############################################################
import ./constant_time_types
import ./constant_time_types, ./addcarry_subborrow
# ############################################################
#
@ -37,6 +37,16 @@ func unsafeDiv2n1n*(q, r: var Ct[uint32], n_hi, n_lo, d: Ct[uint32]) {.inline.}=
q = (Ct[uint32])(dividend div divisor)
r = (Ct[uint32])(dividend mod divisor)
func mul*(hi, lo: var Ct[uint32], a, b: Ct[uint32]) {.inline.} =
## Extended precision multiplication
## (hi, lo) <- a*b
##
## This is constant-time on most hardware
## See: https://www.bearssl.org/ctmul.html
let dblPrec = uint64(a) * uint64(b)
lo = (Ct[uint32])(dblPrec)
hi = (Ct[uint32])(dblPrec shr 32)
func muladd1*(hi, lo: var Ct[uint32], a, b, c: Ct[uint32]) {.inline.} =
## Extended precision multiplication + addition
## (hi, lo) <- a*b + c
@ -70,13 +80,49 @@ func muladd2*(hi, lo: var Ct[uint32], a, b, c1, c2: Ct[uint32]) {.inline.}=
when sizeof(int) == 8:
when defined(vcc):
from ./extended_precision_x86_64_msvc import unsafeDiv2n1n, muladd1, muladd2
from ./extended_precision_x86_64_msvc import unsafeDiv2n1n, mul, muladd1, muladd2
elif GCCCompatible:
# TODO: constant-time div2n1n
when X86:
from ./extended_precision_x86_64_gcc import unsafeDiv2n1n
from ./extended_precision_64bit_uint128 import muladd1, muladd2
from ./extended_precision_64bit_uint128 import mul, muladd1, muladd2
else:
from ./extended_precision_64bit_uint128 import unsafeDiv2n1n, muladd1, muladd2
from ./extended_precision_64bit_uint128 import unsafeDiv2n1n, mul, muladd1, muladd2
export unsafeDiv2n1n, muladd1, muladd2
# ############################################################
#
# Composite primitives
#
# ############################################################
func mulDoubleAdd2*[T: Ct[uint32]|Ct[uint64]](r2: var Carry, r1, r0: var T, a, b, c: T, dHi: Carry, dLo: T) {.inline.} =
## (r2, r1, r0) <- 2*a*b + c + (dHi, dLo)
## with r = (r2, r1, r0) a triple-word number
## and d = (dHi, dLo) a double-word number
## r2 and dHi are carries, either 0 or 1
var carry: Carry
# (r1, r0) <- a*b
# Note: 0xFFFFFFFF_FFFFFFFF² -> (hi: 0xFFFFFFFF_FFFFFFFE, lo: 0x00000000_00000001)
mul(r1, r0, a, b)
# (r2, r1, r0) <- 2*a*b
# Then (hi: 0xFFFFFFFF_FFFFFFFE, lo: 0x00000000_00000001) * 2
# (carry: 1, hi: 0xFFFFFFFF_FFFFFFFC, lo: 0x00000000_00000002)
addC(carry, r0, r0, r0, Carry(0))
addC(r2, r1, r1, r1, carry)
# (r1, r0) <- (r1, r0) + c
# Adding any uint64 cannot overflow into r2 for example Adding 2^64-1
# (carry: 1, hi: 0xFFFFFFFF_FFFFFFFD, lo: 0x00000000_00000001)
addC(carry, r0, r0, c, Carry(0))
addC(carry, r1, r1, T(0), carry)
# (r1, r0) <- (r1, r0) + (dHi, dLo) with dHi a carry (previous limb r2)
# (dHi, dLo) is at most (dhi: 1, dlo: 0xFFFFFFFF_FFFFFFFF)
# summing into (carry: 1, hi: 0xFFFFFFFF_FFFFFFFD, lo: 0x00000000_00000001)
# result at most in (carry: 1, hi: 0xFFFFFFFF_FFFFFFFF, lo: 0x00000000_00000000)
addC(carry, r0, r0, dLo, Carry(0))
addC(carry, r1, r1, T(dHi), carry)

View File

@ -36,6 +36,24 @@ func unsafeDiv2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}=
{.emit:["*",q, " = (NU64)(", dblPrec," / ", d, ");"].}
{.emit:["*",r, " = (NU64)(", dblPrec," % ", d, ");"].}
func mul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} =
## Extended precision multiplication
## (hi, lo) <- a*b
##
## This is constant-time on most hardware
## See: https://www.bearssl.org/ctmul.html
block:
var dblPrec {.noInit.}: uint128
{.emit:[dblPrec, " = (unsigned __int128)", a," * (unsigned __int128)", b,";"].}
# Don't forget to dereference the var param in C mode
when defined(cpp):
{.emit:[hi, " = (NU64)(", dblPrec," >> ", 64'u64, ");"].}
{.emit:[lo, " = (NU64)", dblPrec,";"].}
else:
{.emit:["*",hi, " = (NU64)(", dblPrec," >> ", 64'u64, ");"].}
{.emit:["*",lo, " = (NU64)", dblPrec,";"].}
func muladd1*(hi, lo: var Ct[uint64], a, b, c: Ct[uint64]) {.inline.} =
## Extended precision multiplication + addition
## (hi, lo) <- a*b + c

View File

@ -43,6 +43,14 @@ func unsafeDiv2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}=
# -> use uint128? Compiler might add unwanted branches
q = udiv128(n_hi, n_lo, d, r)
func mul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} =
## Extended precision multiplication
## (hi, lo) <- a*b
##
## This is constant-time on most hardware
## See: https://www.bearssl.org/ctmul.html
lo = umul128(a, b, hi)
func muladd1*(hi, lo: var Ct[uint64], a, b, c: Ct[uint64]) {.inline.} =
## Extended precision multiplication + addition
## (hi, lo) <- a*b + c

View File

@ -0,0 +1,112 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import std/unittest, std/times,
../constantine/arithmetic/[bigints, finite_fields],
../constantine/io/[io_bigints, io_fields],
../constantine/config/curves,
# Test utilities
./prng
const Iters = 128
var rng: RngState
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
rng.seed(seed)
echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
import ../constantine/config/common
proc sanity(C: static Curve) =
test "Squaring 0,1,2 with "& $C & " [FastSquaring = " & $Fake101.canUseNoCarryMontySquare & "]":
block: # 0² mod
var n: Fp[C]
n.fromUint(0'u32)
let expected = n
var r: Fp[C]
r.square(n)
check: bool(r == expected)
block: # 1² mod
var n: Fp[C]
n.fromUint(1'u32)
let expected = n
var r: Fp[C]
r.square(n)
check: bool(r == expected)
block: # 2² mod
var n, expected: Fp[C]
n.fromUint(2'u32)
expected.fromUint(4'u32)
var r: Fp[C]
r.square(n)
check: bool(r == expected)
proc mainSanity() =
suite "Modular squaring is consistent with multiplication on special elements":
sanity Fake101
sanity Mersenne61
sanity Mersenne127
sanity P224 # P224 uses the fast-path with 64-bit words and the slow path with 32-bit words
sanity P256
sanity BLS12_381
mainSanity()
proc mainSelectCases() =
suite "Modular Squaring: selected tricky cases":
test "P-256 [FastSquaring = " & $P256.canUseNoCarryMontySquare & "]":
block:
# Triggered an issue in the (t[N+1], t[N]) = t[N] + (A1, A0)
# between the squaring and reduction step, with t[N+1] and A1 being carry bits.
var a: Fp[P256]
a.fromHex"0xa0da36b4885df98997ee89a22a7ceb64fa431b2ecc87342fc083587da3d6ebc7"
var r_mul, r_sqr: Fp[P256]
r_mul.prod(a, a)
r_sqr.square(a)
doAssert bool(r_mul == r_sqr)
mainSelectCases()
proc randomCurve(C: static Curve) =
let a = rng.random(Fp[C])
var r_mul, r_sqr: Fp[C]
r_mul.prod(a, a)
r_sqr.square(a)
doAssert bool(r_mul == r_sqr)
suite "Random Modular Squaring is consistent with Modular Multiplication":
test "Random squaring mod P-224 [FastSquaring = " & $P224.canUseNoCarryMontySquare & "]":
for _ in 0 ..< Iters:
randomCurve(P224)
test "Random squaring mod P-256 [FastSquaring = " & $P256.canUseNoCarryMontySquare & "]":
for _ in 0 ..< Iters:
randomCurve(P256)
test "Random squaring mod BLS12_381 [FastSquaring = " & $BLS12_381.canUseNoCarryMontySquare & "]":
for _ in 0 ..< Iters:
randomCurve(BLS12_381)

View File

@ -0,0 +1 @@
-d:testingCurves

View File

@ -8,11 +8,9 @@
import unittest,
../constantine/arithmetic/[bigints, finite_fields],
../constantine/io/io_fields,
../constantine/io/[io_bigints, io_fields],
../constantine/config/curves
import ../constantine/io/io_bigints
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
proc main() =