Low-level refactor part 2 (#176)

This commit is contained in:
Mamy Ratsimbazafy 2022-02-14 14:38:22 +01:00 committed by GitHub
parent 14af7e8724
commit 5db30ef68d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 154 additions and 129 deletions

View File

@ -38,7 +38,7 @@ static: doAssert UseASM_X86_64
macro mulMont_CIOS_sparebit_gen[N: static int]( macro mulMont_CIOS_sparebit_gen[N: static int](
r_PIR: var Limbs[N], a_PIR, b_PIR, r_PIR: var Limbs[N], a_PIR, b_PIR,
M_PIR: Limbs[N], m0ninv_REG: BaseType, M_PIR: Limbs[N], m0ninv_REG: BaseType,
skipReduction: static bool skipFinalSub: static bool
): untyped = ): untyped =
## Generate an optimized Montgomery Multiplication kernel ## Generate an optimized Montgomery Multiplication kernel
## using the CIOS method ## using the CIOS method
@ -175,7 +175,7 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
ctx.mov rax, r # move r away from scratchspace that will be used for final substraction ctx.mov rax, r # move r away from scratchspace that will be used for final substraction
let r2 = rax.asArrayAddr(len = N) let r2 = rax.asArrayAddr(len = N)
if skipReduction: if skipFinalSub:
for i in 0 ..< N: for i in 0 ..< N:
ctx.mov r2[i], t[i] ctx.mov r2[i], t[i]
else: else:
@ -185,14 +185,14 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
) )
result.add ctx.generate() result.add ctx.generate()
func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) = func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipFinalSub: static bool = false) =
## Constant-time Montgomery multiplication ## Constant-time Montgomery multiplication
## If "skipReduction" is set ## If "skipFinalSub" is set
## the result is in the range [0, 2M) ## the result is in the range [0, 2M)
## otherwise the result is in the range [0, M) ## otherwise the result is in the range [0, M)
## ##
## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation ## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation
r.mulMont_CIOS_sparebit_gen(a, b, M, m0ninv, skipReduction) r.mulMont_CIOS_sparebit_gen(a, b, M, m0ninv, skipFinalSub)
# Montgomery Squaring # Montgomery Squaring
# ------------------------------------------------------------ # ------------------------------------------------------------
@ -209,8 +209,8 @@ func squareMont_CIOS_asm*[N](
r: var Limbs[N], r: var Limbs[N],
a, M: Limbs[N], a, M: Limbs[N],
m0ninv: BaseType, m0ninv: BaseType,
hasSpareBit, skipReduction: static bool) = hasSpareBit, skipFinalSub: static bool) =
## Constant-time modular squaring ## Constant-time modular squaring
var r2x {.noInit.}: Limbs[2*N] var r2x {.noInit.}: Limbs[2*N]
r2x.square_asm_inline(a) r2x.square_asm_inline(a)
r.redcMont_asm_inline(r2x, M, m0ninv, hasSpareBit, skipReduction) r.redcMont_asm_inline(r2x, M, m0ninv, hasSpareBit, skipFinalSub)

View File

@ -179,7 +179,7 @@ proc partialRedx(
macro mulMont_CIOS_sparebit_adx_gen[N: static int]( macro mulMont_CIOS_sparebit_adx_gen[N: static int](
r_PIR: var Limbs[N], a_PIR, b_PIR, r_PIR: var Limbs[N], a_PIR, b_PIR,
M_PIR: Limbs[N], m0ninv_REG: BaseType, M_PIR: Limbs[N], m0ninv_REG: BaseType,
skipReduction: static bool): untyped = skipFinalSub: static bool): untyped =
## Generate an optimized Montgomery Multiplication kernel ## Generate an optimized Montgomery Multiplication kernel
## using the CIOS method ## using the CIOS method
## This requires the most significant word of the Modulus ## This requires the most significant word of the Modulus
@ -268,7 +268,7 @@ macro mulMont_CIOS_sparebit_adx_gen[N: static int](
lo, C lo, C
) )
if skipReduction: if skipFinalSub:
for i in 0 ..< N: for i in 0 ..< N:
ctx.mov r[i], t[i] ctx.mov r[i], t[i]
else: else:
@ -279,14 +279,14 @@ macro mulMont_CIOS_sparebit_adx_gen[N: static int](
result.add ctx.generate result.add ctx.generate
func mulMont_CIOS_sparebit_asm_adx*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) = func mulMont_CIOS_sparebit_asm_adx*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipFinalSub: static bool = false) =
## Constant-time Montgomery multiplication ## Constant-time Montgomery multiplication
## If "skipReduction" is set ## If "skipFinalSub" is set
## the result is in the range [0, 2M) ## the result is in the range [0, 2M)
## otherwise the result is in the range [0, M) ## otherwise the result is in the range [0, M)
## ##
## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation ## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation
r.mulMont_CIOS_sparebit_adx_gen(a, b, M, m0ninv, skipReduction) r.mulMont_CIOS_sparebit_adx_gen(a, b, M, m0ninv, skipFinalSub)
# Montgomery Squaring # Montgomery Squaring
# ------------------------------------------------------------ # ------------------------------------------------------------
@ -295,8 +295,8 @@ func squareMont_CIOS_asm_adx*[N](
r: var Limbs[N], r: var Limbs[N],
a, M: Limbs[N], a, M: Limbs[N],
m0ninv: BaseType, m0ninv: BaseType,
hasSpareBit, skipReduction: static bool) = hasSpareBit, skipFinalSub: static bool) =
## Constant-time modular squaring ## Constant-time modular squaring
var r2x {.noInit.}: Limbs[2*N] var r2x {.noInit.}: Limbs[2*N]
r2x.square_asm_adx_inline(a) r2x.square_asm_adx_inline(a)
r.redcMont_asm_adx(r2x, M, m0ninv, hasSpareBit, skipReduction) r.redcMont_asm_adx(r2x, M, m0ninv, hasSpareBit, skipFinalSub)

View File

@ -34,7 +34,7 @@ macro redc2xMont_gen*[N: static int](
a_PIR: array[N*2, SecretWord], a_PIR: array[N*2, SecretWord],
M_PIR: array[N, SecretWord], M_PIR: array[N, SecretWord],
m0ninv_REG: BaseType, m0ninv_REG: BaseType,
hasSpareBit, skipReduction: static bool hasSpareBit, skipFinalSub: static bool
) = ) =
# No register spilling handling # No register spilling handling
@ -153,7 +153,7 @@ macro redc2xMont_gen*[N: static int](
# v is invalidated from now on # v is invalidated from now on
let t = repackRegisters(v, u[N], u[N+1]) let t = repackRegisters(v, u[N], u[N+1])
if hasSpareBit and skipReduction: if hasSpareBit and skipFinalSub:
for i in 0 ..< N: for i in 0 ..< N:
ctx.mov r[i], t[i] ctx.mov r[i], t[i]
elif hasSpareBit: elif hasSpareBit:
@ -170,22 +170,22 @@ func redcMont_asm_inline*[N: static int](
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
hasSpareBit: static bool, hasSpareBit: static bool,
skipReduction: static bool = false skipFinalSub: static bool = false
) {.inline.} = ) {.inline.} =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
## Inline-version ## Inline-version
redc2xMont_gen(r, a, M, m0ninv, hasSpareBit, skipReduction) redc2xMont_gen(r, a, M, m0ninv, hasSpareBit, skipFinalSub)
func redcMont_asm*[N: static int]( func redcMont_asm*[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
hasSpareBit, skipReduction: static bool hasSpareBit, skipFinalSub: static bool
) = ) =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
static: doAssert UseASM_X86_64, "This requires x86-64." static: doAssert UseASM_X86_64, "This requires x86-64."
redcMont_asm_inline(r, a, M, m0ninv, hasSpareBit, skipReduction) redcMont_asm_inline(r, a, M, m0ninv, hasSpareBit, skipFinalSub)
# Montgomery conversion # Montgomery conversion
# ---------------------------------------------------------- # ----------------------------------------------------------
@ -351,8 +351,8 @@ when isMainModule:
var a_sqr{.noInit.}, na_sqr{.noInit.}: Limbs[2] var a_sqr{.noInit.}, na_sqr{.noInit.}: Limbs[2]
var a_sqr_comba{.noInit.}, na_sqr_comba{.noInit.}: Limbs[2] var a_sqr_comba{.noInit.}, na_sqr_comba{.noInit.}: Limbs[2]
a_sqr.redcMont_asm(adbl_sqr, M, 1, hasSpareBit = false, skipReduction = false) a_sqr.redcMont_asm(adbl_sqr, M, 1, hasSpareBit = false, skipFinalSub = false)
na_sqr.redcMont_asm(nadbl_sqr, M, 1, hasSpareBit = false, skipReduction = false) na_sqr.redcMont_asm(nadbl_sqr, M, 1, hasSpareBit = false, skipFinalSub = false)
a_sqr_comba.redc2xMont_Comba(adbl_sqr, M, 1) a_sqr_comba.redc2xMont_Comba(adbl_sqr, M, 1)
na_sqr_comba.redc2xMont_Comba(nadbl_sqr, M, 1) na_sqr_comba.redc2xMont_Comba(nadbl_sqr, M, 1)

View File

@ -38,7 +38,7 @@ macro redc2xMont_adx_gen[N: static int](
a_PIR: array[N*2, SecretWord], a_PIR: array[N*2, SecretWord],
M_PIR: array[N, SecretWord], M_PIR: array[N, SecretWord],
m0ninv_REG: BaseType, m0ninv_REG: BaseType,
hasSpareBit, skipReduction: static bool hasSpareBit, skipFinalSub: static bool
) = ) =
# No register spilling handling # No register spilling handling
@ -131,7 +131,7 @@ macro redc2xMont_adx_gen[N: static int](
let t = repackRegisters(v, u[N]) let t = repackRegisters(v, u[N])
if hasSpareBit and skipReduction: if hasSpareBit and skipFinalSub:
for i in 0 ..< N: for i in 0 ..< N:
ctx.mov r[i], t[i] ctx.mov r[i], t[i]
elif hasSpareBit: elif hasSpareBit:
@ -148,11 +148,11 @@ func redcMont_asm_adx_inline*[N: static int](
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
hasSpareBit: static bool, hasSpareBit: static bool,
skipReduction: static bool = false skipFinalSub: static bool = false
) {.inline.} = ) {.inline.} =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
## Inline-version ## Inline-version
redc2xMont_adx_gen(r, a, M, m0ninv, hasSpareBit, skipReduction) redc2xMont_adx_gen(r, a, M, m0ninv, hasSpareBit, skipFinalSub)
func redcMont_asm_adx*[N: static int]( func redcMont_asm_adx*[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
@ -160,10 +160,10 @@ func redcMont_asm_adx*[N: static int](
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
hasSpareBit: static bool, hasSpareBit: static bool,
skipReduction: static bool = false skipFinalSub: static bool = false
) = ) =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
redcMont_asm_adx_inline(r, a, M, m0ninv, hasSpareBit, skipReduction) redcMont_asm_adx_inline(r, a, M, m0ninv, hasSpareBit, skipFinalSub)
# Montgomery conversion # Montgomery conversion

View File

@ -24,7 +24,7 @@ import
# #
# ############################################################ # ############################################################
func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, spareBits: static int) = func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: BaseType, spareBits: static int) =
## Convert a BigInt from its natural representation ## Convert a BigInt from its natural representation
## to the Montgomery residue form ## to the Montgomery residue form
## ##
@ -41,7 +41,7 @@ func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, s
## and R = (2^WordBitWidth)^W ## and R = (2^WordBitWidth)^W
getMont(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits) getMont(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits)
func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, spareBits: static int) = func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: BaseType, spareBits: static int) =
## Convert a BigInt from its Montgomery residue form ## Convert a BigInt from its Montgomery residue form
## to the natural representation ## to the natural representation
## ##
@ -52,20 +52,20 @@ func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static
fromMont(r.limbs, a.limbs, M.limbs, m0ninv, spareBits) fromMont(r.limbs, a.limbs, M.limbs, m0ninv, spareBits)
func mulMont*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, func mulMont*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType,
spareBits: static int, skipReduction: static bool = false) = spareBits: static int, skipFinalSub: static bool = false) =
## Compute r <- a*b (mod M) in the Montgomery domain ## Compute r <- a*b (mod M) in the Montgomery domain
## ##
## This resets r to zero before processing. Use {.noInit.} ## This resets r to zero before processing. Use {.noInit.}
## to avoid duplicating with Nim zero-init policy ## to avoid duplicating with Nim zero-init policy
mulMont(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits, skipReduction) mulMont(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits, skipFinalSub)
func squareMont*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, func squareMont*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType,
spareBits: static int, skipReduction: static bool = false) = spareBits: static int, skipFinalSub: static bool = false) =
## Compute r <- a^2 (mod M) in the Montgomery domain ## Compute r <- a^2 (mod M) in the Montgomery domain
## ##
## This resets r to zero before processing. Use {.noInit.} ## This resets r to zero before processing. Use {.noInit.}
## to avoid duplicating with Nim zero-init policy ## to avoid duplicating with Nim zero-init policy
squareMont(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits, skipReduction) squareMont(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits, skipFinalSub)
func powMont*[mBits: static int]( func powMont*[mBits: static int](
a: var BigInt[mBits], exponent: openarray[byte], a: var BigInt[mBits], exponent: openarray[byte],

View File

@ -214,14 +214,14 @@ func double*(r: var FF, a: FF) {.meter.} =
overflowed = overflowed or not(r.mres < FF.fieldMod()) overflowed = overflowed or not(r.mres < FF.fieldMod())
discard csub(r.mres, FF.fieldMod(), overflowed) discard csub(r.mres, FF.fieldMod(), overflowed)
func prod*(r: var FF, a, b: FF, skipReduction: static bool = false) {.meter.} = func prod*(r: var FF, a, b: FF, skipFinalSub: static bool = false) {.meter.} =
## Store the product of ``a`` by ``b`` modulo p into ``r`` ## Store the product of ``a`` by ``b`` modulo p into ``r``
## ``r`` is initialized / overwritten ## ``r`` is initialized / overwritten
r.mres.mulMont(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipReduction) r.mres.mulMont(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)
func square*(r: var FF, a: FF, skipReduction: static bool = false) {.meter.} = func square*(r: var FF, a: FF, skipFinalSub: static bool = false) {.meter.} =
## Squaring modulo p ## Squaring modulo p
r.mres.squareMont(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipReduction) r.mres.squareMont(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)
func neg*(r: var FF, a: FF) {.meter.} = func neg*(r: var FF, a: FF) {.meter.} =
## Negate modulo p ## Negate modulo p
@ -413,38 +413,23 @@ func `*=`*(a: var FF, b: FF) {.meter.} =
## Multiplication modulo p ## Multiplication modulo p
a.prod(a, b) a.prod(a, b)
func square*(a: var FF, skipReduction: static bool = false) {.meter.} = func square*(a: var FF, skipFinalSub: static bool = false) {.meter.} =
## Squaring modulo p ## Squaring modulo p
a.square(a, skipReduction) a.square(a, skipFinalSub)
func square_repeated*(a: var FF, num: int, skipReduction: static bool = false) {.meter.} = func square_repeated*(a: var FF, num: int, skipFinalSub: static bool = false) {.meter.} =
## Repeated squarings ## Repeated squarings
# Except in Tonelli-Shanks, num is always known at compile-time ## Assumes at least 1 squaring
# and square repeated is inlined, so the compiler should optimize the branches away. for _ in 0 ..< num-1:
a.square(skipFinalSub = true)
a.square(skipFinalSub)
# TODO: understand the conditions to avoid the final substraction func square_repeated*(r: var FF, a: FF, num: int, skipFinalSub: static bool = false) {.meter.} =
for _ in 0 ..< num:
a.square(skipReduction = false)
func square_repeated*(r: var FF, a: FF, num: int, skipReduction: static bool = false) {.meter.} =
## Repeated squarings ## Repeated squarings
r.square(a, skipFinalSub = true)
# TODO: understand the conditions to avoid the final substraction for _ in 1 ..< num-1:
r.square(a) r.square(skipFinalSub = true)
for _ in 1 ..< num: r.square(skipFinalSub)
r.square()
func square_repeated_then_mul*(a: var FF, num: int, b: FF, skipReduction: static bool = false) {.meter.} =
## Square `a`, `num` times and then multiply by b
## Assumes at least 1 squaring
a.square_repeated(num, skipReduction = false)
a.prod(a, b, skipReduction = skipReduction)
func square_repeated_then_mul*(r: var FF, a: FF, num: int, b: FF, skipReduction: static bool = false) {.meter.} =
## Square `a`, `num` times and then multiply by b
## Assumes at least 1 squaring
r.square_repeated(a, num, skipReduction = false)
r.prod(r, b, skipReduction = skipReduction)
func `*=`*(a: var FF, b: static int) = func `*=`*(a: var FF, b: static int) =
## Multiplication by a small integer known at compile-time ## Multiplication by a small integer known at compile-time
@ -550,3 +535,46 @@ template mulCheckSparse*(a: var Fp, b: Fp) =
{.pop.} # inline {.pop.} # inline
{.pop.} # raises no exceptions {.pop.} # raises no exceptions
# ############################################################
#
# Field arithmetic ergonomic macros
#
# ############################################################
import std/macros
macro addchain*(fn: untyped): untyped =
## Modify all prod, `*=`, square, square_repeated calls
## to skipFinalSub except the very last call.
## This assumes straight-line code.
fn.expectKind(nnkFuncDef)
result = fn
var body = newStmtList()
for i, statement in fn[^1]:
statement.expectKind({nnkCommentStmt, nnkVarSection, nnkCall, nnkInfix})
var s = statement.copyNimTree()
if i + 1 != result[^1].len:
# Modify all but the last
if s.kind == nnkCall:
doAssert s[0].kind == nnkDotExpr, "Only method call syntax or infix syntax is supported in addition chains"
doAssert s[0][1].eqIdent"prod" or s[0][1].eqIdent"square" or s[0][1].eqIdent"square_repeated"
s.add newLit(true)
elif s.kind == nnkInfix:
doAssert s[0].eqIdent"*="
# a *= b -> prod(a, a, b, true)
s = newCall(
bindSym"prod",
s[1],
s[1],
s[2],
newLit(true)
)
body.add s
result[^1] = body
# echo result.toStrLit()

View File

@ -194,13 +194,14 @@ func invsqrt_tonelli_shanks_pre(
t.square(z) t.square(z)
t *= a t *= a
r = z r = z
var b = t var b {.noInit.} = t
var root = Fp.C.tonelliShanks(root_of_unity) var root {.noInit.} = Fp.C.tonelliShanks(root_of_unity)
var buf {.noInit.}: Fp var buf {.noInit.}: Fp
for i in countdown(e, 2, 1): for i in countdown(e, 2, 1):
b.square_repeated(i-2) if i-2 >= 1:
b.square_repeated(i-2)
let bNotOne = not b.isOne() let bNotOne = not b.isOne()
buf.prod(r, root) buf.prod(r, root)

View File

@ -57,15 +57,14 @@ func redc2xMont_CIOS[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, skipReduction: static bool = false) = m0ninv: BaseType, skipFinalSub: static bool = false) =
## Montgomery reduce a double-precision bigint modulo M ## Montgomery reduce a double-precision bigint modulo M
## ##
## This maps ## This maps
## - [0, 4p²) -> [0, 2p) with skipReduction ## - [0, 4p²) -> [0, 2p) with skipFinalSub
## - [0, 4p²) -> [0, p) without ## - [0, 4p²) -> [0, p) without
## ##
## SkipReduction skips the final substraction step. ## skipFinalSub skips the final substraction step.
## For skipReduction, M needs to have a spare bit in it's representation i.e. unused MSB.
# - Analyzing and Comparing Montgomery Multiplication Algorithms # - Analyzing and Comparing Montgomery Multiplication Algorithms
# Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr. # Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr.
# http://pdfs.semanticscholar.org/5e39/41ff482ec3ee41dc53c3298f0be085c69483.pdf # http://pdfs.semanticscholar.org/5e39/41ff482ec3ee41dc53c3298f0be085c69483.pdf
@ -119,7 +118,7 @@ func redc2xMont_CIOS[N: static int](
addC(carry, res[i], a[i+N], res[i], carry) addC(carry, res[i], a[i+N], res[i], carry)
# Final substraction # Final substraction
when not skipReduction: when not skipFinalSub:
discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M)) discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M))
r = res r = res
@ -127,15 +126,14 @@ func redc2xMont_Comba[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, skipReduction: static bool = false) = m0ninv: BaseType, skipFinalSub: static bool = false) =
## Montgomery reduce a double-precision bigint modulo M ## Montgomery reduce a double-precision bigint modulo M
## ##
## This maps ## This maps
## - [0, 4p²) -> [0, 2p) with skipReduction ## - [0, 4p²) -> [0, 2p) with skipFinalSub
## - [0, 4p²) -> [0, p) without ## - [0, 4p²) -> [0, p) without
## ##
## SkipReduction skips the final substraction step. ## skipFinalSub skips the final substraction step.
## For skipReduction, M needs to have a spare bit in it's representation i.e. unused MSB.
# We use Product Scanning / Comba multiplication # We use Product Scanning / Comba multiplication
var t, u, v = Zero var t, u, v = Zero
var carry: Carry var carry: Carry
@ -171,14 +169,14 @@ func redc2xMont_Comba[N: static int](
addC(carry, z[N-1], v, a[2*N-1], Carry(0)) addC(carry, z[N-1], v, a[2*N-1], Carry(0))
# Final substraction # Final substraction
when not skipReduction: when not skipFinalSub:
discard z.csub(M, SecretBool(carry) or not(z < M)) discard z.csub(M, SecretBool(carry) or not(z < M))
r = z r = z
# Montgomery Multiplication # Montgomery Multiplication
# ------------------------------------------------------------ # ------------------------------------------------------------
func mulMont_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) = func mulMont_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipFinalSub: static bool = false) =
## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS) ## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS)
## and no-carry optimization. ## and no-carry optimization.
## This requires the most significant word of the Modulus ## This requires the most significant word of the Modulus
@ -186,10 +184,10 @@ func mulMont_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipR
## https://hackmd.io/@gnark/modular_multiplication ## https://hackmd.io/@gnark/modular_multiplication
## ##
## This maps ## This maps
## - [0, 2p) -> [0, 2p) with skipReduction ## - [0, 2p) -> [0, 2p) with skipFinalSub
## - [0, 2p) -> [0, p) without ## - [0, 2p) -> [0, p) without
## ##
## SkipReduction skips the final substraction step. ## skipFinalSub skips the final substraction step.
# We want all the computation to be kept in registers # We want all the computation to be kept in registers
# hence we use a temporary `t`, hoping that the compiler does it. # hence we use a temporary `t`, hoping that the compiler does it.
@ -213,7 +211,7 @@ func mulMont_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipR
t[N-1] = C + A t[N-1] = C + A
when not skipReduction: when not skipFinalSub:
discard t.csub(M, not(t < M)) discard t.csub(M, not(t < M))
r = t r = t
@ -265,15 +263,14 @@ func mulMont_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) {.used.} =
discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)ʷ discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)ʷ
r = t r = t
func mulMont_FIPS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) = func mulMont_FIPS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipFinalSub: static bool = false) =
## Montgomery Multiplication using Finely Integrated Product Scanning (FIPS) ## Montgomery Multiplication using Finely Integrated Product Scanning (FIPS)
## ##
## This maps ## This maps
## - [0, 2p) -> [0, 2p) with skipReduction ## - [0, 2p) -> [0, 2p) with skipFinalSub
## - [0, 2p) -> [0, p) without ## - [0, 2p) -> [0, p) without
## ##
## SkipReduction skips the final substraction step. ## skipFinalSub skips the final substraction step.
## For skipReduction, M needs to have a spare bit in it's representation i.e. unused MSB.
# - Architectural Enhancements for Montgomery # - Architectural Enhancements for Montgomery
# Multiplication on Embedded RISC Processors # Multiplication on Embedded RISC Processors
# Johann Großschädl and Guy-Armand Kamendje, 2003 # Johann Großschädl and Guy-Armand Kamendje, 2003
@ -306,7 +303,7 @@ func mulMont_FIPS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction:
u = t u = t
t = Zero t = Zero
when not skipReduction: when not skipFinalSub:
discard z.csub(M, v.isNonZero() or not(z < M)) discard z.csub(M, v.isNonZero() or not(z < M))
r = z r = z
@ -386,39 +383,46 @@ func fromMont_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) =
# Exported API # Exported API
# ------------------------------------------------------------ # ------------------------------------------------------------
# Skipping reduction requires the modulus M <= R/4
# On 64-bit R is the multiple of 2⁶⁴ immediately larger than M
#
# Montgomery Arithmetic from a Software Perspective
# Bos and Montgomery, 2017
# https://eprint.iacr.org/2017/1057.pdf
# TODO upstream, using Limbs[N] breaks semcheck # TODO upstream, using Limbs[N] breaks semcheck
func redc2xMont*[N: static int]( func redc2xMont*[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
spareBits: static int, skipReduction: static bool = false) {.inline.} = spareBits: static int, skipFinalSub: static bool = false) {.inline.} =
## Montgomery reduce a double-precision bigint modulo M ## Montgomery reduce a double-precision bigint modulo M
const skipReduction = skipReduction and spareBits >= 1 const skipFinalSub = skipFinalSub and spareBits >= 2
when UseASM_X86_64 and r.len <= 6: when UseASM_X86_64 and r.len <= 6:
# ADX implies BMI2 # ADX implies BMI2
if ({.noSideEffect.}: hasAdx()): if ({.noSideEffect.}: hasAdx()):
redcMont_asm_adx(r, a, M, m0ninv, spareBits >= 1, skipReduction) redcMont_asm_adx(r, a, M, m0ninv, spareBits >= 1, skipFinalSub)
else: else:
when r.len in {3..6}: when r.len in {3..6}:
redcMont_asm(r, a, M, m0ninv, spareBits >= 1, skipReduction) redcMont_asm(r, a, M, m0ninv, spareBits >= 1, skipFinalSub)
else: else:
redc2xMont_CIOS(r, a, M, m0ninv, skipReduction) redc2xMont_CIOS(r, a, M, m0ninv, skipFinalSub)
# redc2xMont_Comba(r, a, M, m0ninv) # redc2xMont_Comba(r, a, M, m0ninv)
elif UseASM_X86_64 and r.len in {3..6}: elif UseASM_X86_64 and r.len in {3..6}:
# TODO: Assembly faster than GCC but slower than Clang # TODO: Assembly faster than GCC but slower than Clang
redcMont_asm(r, a, M, m0ninv, spareBits >= 1, skipReduction) redcMont_asm(r, a, M, m0ninv, spareBits >= 1, skipFinalSub)
else: else:
redc2xMont_CIOS(r, a, M, m0ninv, skipReduction) redc2xMont_CIOS(r, a, M, m0ninv, skipFinalSub)
# redc2xMont_Comba(r, a, M, m0ninv, skipReduction) # redc2xMont_Comba(r, a, M, m0ninv, skipFinalSub)
func mulMont*( func mulMont*(
r: var Limbs, a, b, M: Limbs, r: var Limbs, a, b, M: Limbs,
m0ninv: BaseType, m0ninv: BaseType,
spareBits: static int, spareBits: static int,
skipReduction: static bool = false) {.inline.} = skipFinalSub: static bool = false) {.inline.} =
## Compute r <- a*b (mod M) in the Montgomery domain ## Compute r <- a*b (mod M) in the Montgomery domain
## `m0ninv` = -1/M (mod SecretWord). Our words are 2^32 or 2^64 ## `m0ninv` = -1/M (mod SecretWord). Our words are 2^32 or 2^64
## ##
@ -438,36 +442,28 @@ func mulMont*(
# i.e. c'R <- a'R b'R * R^-1 (mod M) in the natural domain # i.e. c'R <- a'R b'R * R^-1 (mod M) in the natural domain
# as in the Montgomery domain all numbers are scaled by R # as in the Montgomery domain all numbers are scaled by R
# Many curve moduli are "Montgomery-friendly" which means that m0ninv is 1 const skipFinalSub = skipFinalSub and spareBits >= 2
# This saves N basic type multiplication and potentially many register mov
# as well as unless using "mulx" instruction, x86 "mul" requires very specific registers.
#
# The implementation is visible from here, the compiler can make decision whether to:
# - specialize/duplicate code for m0ninv == 1 (especially if only 1 curve is needed)
# - keep it generic and optimize code size
const skipReduction = skipReduction and spareBits >= 1
when spareBits >= 1: when spareBits >= 1:
when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling
# ADX implies BMI2 # ADX implies BMI2
if ({.noSideEffect.}: hasAdx()): if ({.noSideEffect.}: hasAdx()):
mulMont_CIOS_sparebit_asm_adx(r, a, b, M, m0ninv, skipReduction) mulMont_CIOS_sparebit_asm_adx(r, a, b, M, m0ninv, skipFinalSub)
else: else:
mulMont_CIOS_sparebit_asm(r, a, b, M, m0ninv, skipReduction) mulMont_CIOS_sparebit_asm(r, a, b, M, m0ninv, skipFinalSub)
else: else:
mulMont_CIOS_sparebit(r, a, b, M, m0ninv, skipReduction) mulMont_CIOS_sparebit(r, a, b, M, m0ninv, skipFinalSub)
else: else:
mulMont_FIPS(r, a, b, M, m0ninv, skipReduction) mulMont_FIPS(r, a, b, M, m0ninv, skipFinalSub)
func squareMont*[N](r: var Limbs[N], a, M: Limbs[N], func squareMont*[N](r: var Limbs[N], a, M: Limbs[N],
m0ninv: BaseType, m0ninv: BaseType,
spareBits: static int, spareBits: static int,
skipReduction: static bool = false) {.inline.} = skipFinalSub: static bool = false) {.inline.} =
## Compute r <- a^2 (mod M) in the Montgomery domain ## Compute r <- a^2 (mod M) in the Montgomery domain
## `m0ninv` = -1/M (mod SecretWord). Our words are 2^31 or 2^63 ## `m0ninv` = -1/M (mod SecretWord). Our words are 2^31 or 2^63
const skipReduction = skipReduction and spareBits >= 1 const skipFinalSub = skipFinalSub and spareBits >= 2
when UseASM_X86_64 and a.len in {4, 6}: when UseASM_X86_64 and a.len in {4, 6}:
# ADX implies BMI2 # ADX implies BMI2
@ -476,20 +472,20 @@ func squareMont*[N](r: var Limbs[N], a, M: Limbs[N],
# which uses unfused squaring then Montgomery reduction # which uses unfused squaring then Montgomery reduction
# is slightly slower than fused Montgomery multiplication # is slightly slower than fused Montgomery multiplication
when spareBits >= 1: when spareBits >= 1:
mulMont_CIOS_sparebit_asm_adx(r, a, a, M, m0ninv, skipReduction) mulMont_CIOS_sparebit_asm_adx(r, a, a, M, m0ninv, skipFinalSub)
else: else:
squareMont_CIOS_asm_adx(r, a, M, m0ninv, spareBits >= 1, skipReduction) squareMont_CIOS_asm_adx(r, a, M, m0ninv, spareBits >= 1, skipFinalSub)
else: else:
squareMont_CIOS_asm(r, a, M, m0ninv, spareBits >= 1, skipReduction) squareMont_CIOS_asm(r, a, M, m0ninv, spareBits >= 1, skipFinalSub)
elif UseASM_X86_64: elif UseASM_X86_64:
var r2x {.noInit.}: Limbs[2*N] var r2x {.noInit.}: Limbs[2*N]
r2x.square(a) r2x.square(a)
r.redc2xMont(r2x, M, m0ninv, spareBits, skipReduction) r.redc2xMont(r2x, M, m0ninv, spareBits, skipFinalSub)
else: else:
mulMont(r, a, a, M, m0ninv, spareBits, skipReduction) mulMont(r, a, a, M, m0ninv, spareBits, skipFinalSub)
func fromMont*(r: var Limbs, a, M: Limbs, func fromMont*(r: var Limbs, a, M: Limbs,
m0ninv: static BaseType, spareBits: static int) = m0ninv: BaseType, spareBits: static int) =
## Transform a bigint ``a`` from it's Montgomery N-residue representation (mod N) ## Transform a bigint ``a`` from it's Montgomery N-residue representation (mod N)
## to the regular natural representation (mod N) ## to the regular natural representation (mod N)
## ##
@ -517,7 +513,7 @@ func fromMont*(r: var Limbs, a, M: Limbs,
fromMont_CIOS(r, a, M, m0ninv) fromMont_CIOS(r, a, M, m0ninv)
func getMont*(r: var Limbs, a, M, r2modM: Limbs, func getMont*(r: var Limbs, a, M, r2modM: Limbs,
m0ninv: static BaseType, spareBits: static int) = m0ninv: BaseType, spareBits: static int) =
## Transform a bigint ``a`` from it's natural representation (mod N) ## Transform a bigint ``a`` from it's natural representation (mod N)
## to a the Montgomery n-residue representation ## to a the Montgomery n-residue representation
## ##
@ -580,7 +576,7 @@ func getWindowLen(bufLen: int): uint =
func powMontPrologue( func powMontPrologue(
a: var Limbs, M, one: Limbs, a: var Limbs, M, one: Limbs,
m0ninv: static BaseType, m0ninv: BaseType,
scratchspace: var openarray[Limbs], scratchspace: var openarray[Limbs],
spareBits: static int spareBits: static int
): uint = ): uint =
@ -605,7 +601,7 @@ func powMontSquarings(
a: var Limbs, a: var Limbs,
exponent: openarray[byte], exponent: openarray[byte],
M: Limbs, M: Limbs,
m0ninv: static BaseType, m0ninv: BaseType,
tmp: var Limbs, tmp: var Limbs,
window: uint, window: uint,
acc, acc_len: var uint, acc, acc_len: var uint,
@ -654,7 +650,7 @@ func powMont*(
a: var Limbs, a: var Limbs,
exponent: openarray[byte], exponent: openarray[byte],
M, one: Limbs, M, one: Limbs,
m0ninv: static BaseType, m0ninv: BaseType,
scratchspace: var openarray[Limbs], scratchspace: var openarray[Limbs],
spareBits: static int spareBits: static int
) = ) =
@ -721,7 +717,7 @@ func powMontUnsafeExponent*(
a: var Limbs, a: var Limbs,
exponent: openarray[byte], exponent: openarray[byte],
M, one: Limbs, M, one: Limbs,
m0ninv: static BaseType, m0ninv: BaseType,
scratchspace: var openarray[Limbs], scratchspace: var openarray[Limbs],
spareBits: static int spareBits: static int
) = ) =

View File

@ -28,7 +28,7 @@ const
func precompute_tonelli_shanks_addchain*( func precompute_tonelli_shanks_addchain*(
r: var Fp[BLS12_377], r: var Fp[BLS12_377],
a: Fp[BLS12_377]) = a: Fp[BLS12_377]) {.addchain.} =
## Does a^BLS12_377_TonelliShanks_exponent ## Does a^BLS12_377_TonelliShanks_exponent
## via an addition-chain ## via an addition-chain

View File

@ -16,7 +16,7 @@ import
# #
# ############################################################ # ############################################################
func invsqrt_addchain*(r: var Fp[BLS12_381], a: Fp[BLS12_381]) = func invsqrt_addchain*(r: var Fp[BLS12_381], a: Fp[BLS12_381]) {.addchain.} =
var var
x10 {.noInit.}: Fp[BLS12_381] x10 {.noInit.}: Fp[BLS12_381]
x100 {.noInit.}: Fp[BLS12_381] x100 {.noInit.}: Fp[BLS12_381]

View File

@ -16,7 +16,7 @@ import
# #
# ############################################################ # ############################################################
func invsqrt_addchain*(r: var Fp[BN254_Nogami], a: Fp[BN254_Nogami]) = func invsqrt_addchain*(r: var Fp[BN254_Nogami], a: Fp[BN254_Nogami]) {.addchain.} =
var var
x10 {.noInit.}: Fp[BN254_Nogami] x10 {.noInit.}: Fp[BN254_Nogami]
x11 {.noInit.}: Fp[BN254_Nogami] x11 {.noInit.}: Fp[BN254_Nogami]

View File

@ -16,7 +16,7 @@ import
# #
# ############################################################ # ############################################################
func invsqrt_addchain*(r: var Fp[BN254_Snarks], a: Fp[BN254_Snarks]) = func invsqrt_addchain*(r: var Fp[BN254_Snarks], a: Fp[BN254_Snarks]) {.addchain.} =
var var
x10 {.noInit.}: Fp[BN254_Snarks] x10 {.noInit.}: Fp[BN254_Snarks]
x11 {.noInit.}: Fp[BN254_Snarks] x11 {.noInit.}: Fp[BN254_Snarks]

View File

@ -16,7 +16,7 @@ import
# #
# ############################################################ # ############################################################
func invsqrt_addchain*(r: var Fp[BW6_761], a: Fp[BW6_761]) = func invsqrt_addchain*(r: var Fp[BW6_761], a: Fp[BW6_761]) {.addchain.} =
var var
x10 {.noInit.}: Fp[BW6_761] x10 {.noInit.}: Fp[BW6_761]
x11 {.noInit.}: Fp[BW6_761] x11 {.noInit.}: Fp[BW6_761]

View File

@ -220,7 +220,7 @@ func sum*[F; G: static Subgroup](
t4 *= SexticNonResidue t4 *= SexticNonResidue
x3.sum(P.x, P.z) # 14. X₃ <- X₁ + Z₁ x3.sum(P.x, P.z) # 14. X₃ <- X₁ + Z₁
y3.sum(Q.x, Q.z) # 15. Y₃ <- X₂ + Z₂ y3.sum(Q.x, Q.z) # 15. Y₃ <- X₂ + Z₂
x3 *= y3 # 16. X₃ <- X₃ Y₃ X₃ = (X₁Z₁)(X₂Z₂) x3 *= y3 # 16. X₃ <- X₃ Y₃ X₃ = (X₁+Z₁)(X₂+Z₂)
y3.sum(t0, t2) # 17. Y₃ <- t₀ + t₂ Y₃ = X₁ X₂ + Z₁ Z₂ y3.sum(t0, t2) # 17. Y₃ <- t₀ + t₂ Y₃ = X₁ X₂ + Z₁ Z₂
y3.diff(x3, y3) # 18. Y₃ <- X₃ - Y₃ Y₃ = (X₁ + Z₁)(X₂ + Z₂) - (X₁ X₂ + Z₁ Z₂) = X₁Z₂ + X₂Z₁ y3.diff(x3, y3) # 18. Y₃ <- X₃ - Y₃ Y₃ = (X₁ + Z₁)(X₂ + Z₂) - (X₁ X₂ + Z₁ Z₂) = X₁Z₂ + X₂Z₁
when G == G2 and F.C.getSexticTwist() == D_Twist: when G == G2 and F.C.getSexticTwist() == D_Twist:

View File

@ -84,7 +84,7 @@ The optimizations can be of algebraic, algorithmic or "implementation details" n
- [ ] NAF recoding - [ ] NAF recoding
- [ ] windowed-NAF recoding - [ ] windowed-NAF recoding
- [ ] SIMD vectorized select in window algorithm - [ ] SIMD vectorized select in window algorithm
- [ ] Montgomery Multiplication with no final substraction, - [x] Montgomery Multiplication with no final substraction,
- Bos and Montgomery, https://eprint.iacr.org/2017/1057.pdf - Bos and Montgomery, https://eprint.iacr.org/2017/1057.pdf
- Colin D Walter, https://colinandmargaret.co.uk/Research/CDW_ELL_99.pdf - Colin D Walter, https://colinandmargaret.co.uk/Research/CDW_ELL_99.pdf
- Hachez and Quisquater, https://link.springer.com/content/pdf/10.1007%2F3-540-44499-8_23.pdf - Hachez and Quisquater, https://link.springer.com/content/pdf/10.1007%2F3-540-44499-8_23.pdf