Square ADX (#160)

* Add MULX/ADOX/ADCX assembly for squaring 4 limbs

* Add squarings for 6 limbs

* Use the new square assembly where relevant

* Fix 32-bit register name and calling convention

* Fix typo

* Disable MontRed ASM for 2 limbs or less
Mamy Ratsimbazafy 2021-02-20 13:18:49 +01:00 committed by GitHub
parent 8a7c35af59
commit aefd40f455
14 changed files with 932 additions and 422 deletions


@ -12,7 +12,8 @@ import
# Internal # Internal
../../config/common, ../../config/common,
../../primitives, ../../primitives,
./limbs_asm_montred_x86 ./limbs_asm_montred_x86,
./limbs_asm_mul_x86
# ############################################################ # ############################################################
# #
@ -176,3 +177,37 @@ macro montMul_CIOS_nocarry_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_
func montMul_CIOS_nocarry_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = func montMul_CIOS_nocarry_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
## Constant-time modular multiplication ## Constant-time modular multiplication
montMul_CIOS_nocarry_gen(r, a, b, M, m0ninv) montMul_CIOS_nocarry_gen(r, a, b, M, m0ninv)
# Montgomery Squaring
# ------------------------------------------------------------
func square_asm_inline[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) {.inline.} =
## Multi-precision Squaring
## Assumes r doesn't alias a
## Extra indirection as the generator assumes that
## arrays are pointers, which is true for parameters
## but not for stack variables
sqr_gen(r, a)
func montRed_asm_inline[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool
) {.inline.} =
## Constant-time Montgomery reduction
## Extra indirection as the generator assumes that
## arrays are pointers, which is true for parameters
## but not for stack variables
montyRedc2x_gen(r, a, M, m0ninv, hasSpareBit)
func montSquare_CIOS_asm*[N](
r: var Limbs[N],
a, M: Limbs[N],
m0ninv: BaseType,
hasSpareBit: static bool) =
## Constant-time modular squaring
var r2x {.noInit.}: Limbs[2*N]
r2x.square_asm_inline(a)
r.montRed_asm_inline(r2x, M, m0ninv, hasSpareBit)
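For intuition, here is a standalone toy sketch of the same two-step structure — a full squaring followed by a Montgomery reduction — using a single 16-bit word and plain uint64 arithmetic rather than Constantine's `Limbs`/`SecretWord` types (all names and constants below are made up for the sketch):

const
  WordBits = 16
  R = 1'u64 shl WordBits
  M = 0xFFF1'u64                  # arbitrary odd toy modulus, < R

proc negInvModR(m: uint64): uint64 =
  ## -m⁻¹ mod R via Newton iteration (m must be odd).
  var inv = m                     # correct to 3 bits
  for _ in 0 ..< 5:               # each step doubles the number of correct bits
    inv = inv * (2'u64 - m * inv)
  result = (R - (inv and (R - 1))) and (R - 1)

proc montRed(t, m0ninv: uint64): uint64 =
  ## REDC: for t < M*R, returns t * R⁻¹ mod M.
  let q = ((t and (R - 1)) * m0ninv) and (R - 1)
  result = (t + q * M) shr WordBits
  if result >= M: result -= M

proc montSquare(a, m0ninv: uint64): uint64 =
  ## Same shape as montSquare_CIOS_asm: full squaring, then reduction.
  montRed(a * a, m0ninv)          # a*a plays the role of the r2x buffer

when isMainModule:
  let m0ninv = negInvModR(M)
  let a = 0x1234'u64
  let aMont = (a * R) mod M               # to Montgomery form
  let sqMont = montSquare(aMont, m0ninv)  # = a²·R (mod M)
  doAssert montRed(sqMont, m0ninv) == (a * a) mod M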


@ -12,7 +12,9 @@ import
# Internal # Internal
../../config/common, ../../config/common,
../../primitives, ../../primitives,
./limbs_asm_montred_x86 ./limbs_asm_montred_x86,
./limbs_asm_montred_x86_adx_bmi2,
./limbs_asm_mul_x86_adx_bmi2
# ############################################################ # ############################################################
# #
@ -271,3 +273,36 @@ macro montMul_CIOS_nocarry_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM,
func montMul_CIOS_nocarry_asm_adx_bmi2*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = func montMul_CIOS_nocarry_asm_adx_bmi2*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
## Constant-time modular multiplication ## Constant-time modular multiplication
montMul_CIOS_nocarry_adx_bmi2_gen(r, a, b, M, m0ninv) montMul_CIOS_nocarry_adx_bmi2_gen(r, a, b, M, m0ninv)
# Montgomery Squaring
# ------------------------------------------------------------
func square_asm_adx_bmi2_inline[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) {.inline.} =
## Multi-precision Squaring
## Extra indirection as the generator assumes that
## arrays are pointers, which is true for parameters
## but not for stack variables.
sqrx_gen(r, a)
func montRed_asm_adx_bmi2_inline[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool
) {.inline.} =
## Constant-time Montgomery reduction
## Extra indirection as the generator assumes that
## arrays are pointers, which is true for parameters
## but not for stack variables.
montyRedc2x_adx_gen(r, a, M, m0ninv, hasSpareBit)
func montSquare_CIOS_asm_adx_bmi2*[N](
r: var Limbs[N],
a, M: Limbs[N],
m0ninv: BaseType,
hasSpareBit: static bool) =
## Constant-time modular squaring
var r2x {.noInit.}: Limbs[2*N]
r2x.square_asm_adx_bmi2_inline(a)
r.montRed_asm_adx_bmi2_inline(r2x, M, m0ninv, hasSpareBit)
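Background for why this variant is gated on ADX/BMI2: MULX leaves the flags untouched, and ADCX/ADOX drive two independent carry chains (carry flag vs. overflow flag), which is what lets the generated code interleave two addition chains without them clobbering each other. A rough standalone Nim model of that flag behaviour (toy semantics only, not the `Assembler_x86` DSL's `adcx`/`adox`):

type Flags = object
  cf, ovf: uint64                 # models of the x86 carry and overflow flags

proc adcx(dst: var uint64, src: uint64, f: var Flags) =
  ## Model of ADCX: dst <- dst + src + CF, updating only CF.
  let t = dst + src
  let c1 = uint64(ord(t < dst))
  dst = t + f.cf
  f.cf = c1 or uint64(ord(dst < t))

proc adox(dst: var uint64, src: uint64, f: var Flags) =
  ## Model of ADOX: dst <- dst + src + OF, updating only OF.
  let t = dst + src
  let c1 = uint64(ord(t < dst))
  dst = t + f.ovf
  f.ovf = c1 or uint64(ord(dst < t))

when isMainModule:
  var f = Flags(cf: 0, ovf: 0)
  # Two independent 128-bit increments, interleaved limb by limb:
  var x = [0xFFFF_FFFF_FFFF_FFFF'u64, 0'u64]
  var y = [0xFFFF_FFFF_FFFF_FFFF'u64, 0'u64]
  adcx(x[0], 1'u64, f)            # carries on the CF chain
  adox(y[0], 1'u64, f)            # carries on the OF chain, CF untouched
  adcx(x[1], 0'u64, f)            # absorbs x's pending carry
  adox(y[1], 0'u64, f)            # absorbs y's pending carry
  doAssert x == [0'u64, 1'u64] and y == [0'u64, 1'u64]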


@ -60,10 +60,13 @@ proc finalSubCanOverflow*(
## `overflowReg` should be a register that will be used ## `overflowReg` should be a register that will be used
## to store the carry flag ## to store the carry flag
ctx.comment "Final substraction (may carry)"
# Mask: overflowed contains 0xFFFF or 0x0000
ctx.sbb overflowReg, overflowReg ctx.sbb overflowReg, overflowReg
# Now substract the modulus to test a < p
let N = M.len let N = M.len
ctx.comment "Final substraction (may carry)"
for i in 0 ..< N: for i in 0 ..< N:
ctx.mov scratch[i], a[i] ctx.mov scratch[i], a[i]
if i == 0: if i == 0:
@ -71,6 +74,8 @@ proc finalSubCanOverflow*(
else: else:
ctx.sbb scratch[i], M[i] ctx.sbb scratch[i], M[i]
# If it overflows here, it means that it was
# smaller than the modulus and we don't need `scratch`
ctx.sbb overflowReg, 0 ctx.sbb overflowReg, 0
# If we borrowed it means that we were smaller than # If we borrowed it means that we were smaller than
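The trick described here — materialise the borrow/carry as an all-ones mask with `sbb overflowReg, overflowReg`, subtract the modulus, then fold the final borrow back in to decide which value to keep — can be modelled in plain Nim on a single limb plus an explicit overflow bit. This is a toy mask-select, not the actual register-level code, and `finalSubToy` is a made-up name:

proc finalSubToy(lo: uint64, overflowed: bool, M: uint64): uint64 =
  ## Reduce the 65-bit value `overflowed:lo` (assumed < 2*M) modulo M,
  ## selecting between `lo` and `lo - M` with an all-ones/all-zeroes mask.
  let diff = lo - M                          # low 64 bits of the subtraction
  let doSub = uint64(ord(overflowed or lo >= M))
  let mask = 0'u64 - doSub                   # 0xFF..FF when we keep `diff`
  result = (diff and mask) or (lo and not mask)

when isMainModule:
  let M = 0xFFFF_FFFF_FFFF_FFC5'u64          # arbitrary toy modulus, 2^64 - 59
  doAssert finalSubToy(123'u64, false, M) == 123
  doAssert finalSubToy(0xFFFF_FFFF_FFFF_FFFF'u64, false, M) == 58
  doAssert finalSubToy(5'u64, true, M) == 64 # the value here is 2^64 + 5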
@ -83,12 +88,12 @@ proc finalSubCanOverflow*(
# Montgomery reduction # Montgomery reduction
# ------------------------------------------------------------ # ------------------------------------------------------------
macro montyRedc2x_gen[N: static int]( macro montyRedc2x_gen*[N: static int](
r_MR: var array[N, SecretWord], r_MR: var array[N, SecretWord],
a_MR: array[N*2, SecretWord], a_MR: array[N*2, SecretWord],
M_MR: array[N, SecretWord], M_MR: array[N, SecretWord],
m0ninv_MR: BaseType, m0ninv_MR: BaseType,
spareBits: static int hasSpareBit: static bool
) = ) =
result = newStmtList() result = newStmtList()
@ -137,7 +142,8 @@ macro montyRedc2x_gen[N: static int](
# r -= M # r -= M
# No register spilling handling # No register spilling handling
doAssert N <= 6, "The Assembly-optimized montgomery multiplication requires at most 6 limbs." doAssert N > 2, "The Assembly-optimized montgomery reduction requires a minimum of 2 limbs."
doAssert N <= 6, "The Assembly-optimized montgomery reduction requires at most 6 limbs."
for i in 0 ..< N: for i in 0 ..< N:
ctx.mov u[i], a[i] ctx.mov u[i], a[i]
@ -205,7 +211,7 @@ macro montyRedc2x_gen[N: static int](
let t = repackRegisters(v, u[N], u[N+1]) let t = repackRegisters(v, u[N], u[N+1])
# v is invalidated # v is invalidated
if spareBits >= 1: if hasSpareBit:
ctx.finalSubNoCarry(r, u, M, t) ctx.finalSubNoCarry(r, u, M, t)
else: else:
ctx.finalSubCanOverflow(r, u, M, t, rax) ctx.finalSubCanOverflow(r, u, M, t, rax)
@ -218,7 +224,93 @@ func montRed_asm*[N: static int](
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
spareBits: static int hasSpareBit: static bool
) = ) =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
montyRedc2x_gen(r, a, M, m0ninv, spareBits) static: doAssert UseASM_X86_64, "This requires x86-64."
montyRedc2x_gen(r, a, M, m0ninv, hasSpareBit)
# Sanity checks
# ----------------------------------------------------------
when isMainModule:
import
../../config/[type_bigint, common],
../../arithmetic/limbs
type SW = SecretWord
# TODO: Properly handle low number of limbs
func montyRedc2x_Comba[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType) =
## Montgomery reduce a double-precision bigint modulo M
# We use Product Scanning / Comba multiplication
var t, u, v = Zero
var carry: Carry
var z: typeof(r) # zero-init, ensure on stack and removes in-place problems in tower fields
staticFor i, 0, N:
staticFor j, 0, i:
mulAcc(t, u, v, z[j], M[i-j])
addC(carry, v, v, a[i], Carry(0))
addC(carry, u, u, Zero, carry)
addC(carry, t, t, Zero, carry)
z[i] = v * SecretWord(m0ninv)
mulAcc(t, u, v, z[i], M[0])
v = u
u = t
t = Zero
staticFor i, N, 2*N-1:
staticFor j, i-N+1, N:
mulAcc(t, u, v, z[j], M[i-j])
addC(carry, v, v, a[i], Carry(0))
addC(carry, u, u, Zero, carry)
addC(carry, t, t, Zero, carry)
z[i-N] = v
v = u
u = t
t = Zero
addC(carry, z[N-1], v, a[2*N-1], Carry(0))
# Final substraction
discard z.csub(M, SecretBool(carry) or not(z < M))
r = z
proc main2L() =
let M = [SW 0xFFFFFFFF_FFFFFFFF'u64, SW 0x7FFFFFFF_FFFFFFFF'u64]
# a²
let adbl_sqr = [SW 0xFF677F6000000001'u64, SW 0xD79897153FA818FD'u64, SW 0x68BFF63DE35C5451'u64, SW 0x2D243FE4B480041F'u64]
# (-a)²
let nadbl_sqr = [SW 0xFECEFEC000000004'u64, SW 0xAE9896D43FA818FB'u64, SW 0x690C368DE35C5450'u64, SW 0x01A4400534800420'u64]
var a_sqr{.noInit.}, na_sqr{.noInit.}: Limbs[2]
var a_sqr_comba{.noInit.}, na_sqr_comba{.noInit.}: Limbs[2]
a_sqr.montRed_asm(adbl_sqr, M, 1, hasSpareBit = false)
na_sqr.montRed_asm(nadbl_sqr, M, 1, hasSpareBit = false)
a_sqr_comba.montyRedc2x_Comba(adbl_sqr, M, 1)
na_sqr_comba.montyRedc2x_Comba(nadbl_sqr, M, 1)
debugecho "--------------------------------"
debugecho "after:"
debugecho " a_sqr: ", a_sqr.toString()
debugecho " na_sqr: ", na_sqr.toString()
debugecho " a_sqr_comba: ", a_sqr_comba.toString()
debugecho " na_sqr_comba: ", na_sqr_comba.toString()
doAssert bool(a_sqr == na_sqr)
doAssert bool(a_sqr == a_sqr_comba)
main2L()


@ -35,12 +35,12 @@ static: doAssert UseASM_X86_64
# Montgomery reduction # Montgomery reduction
# ------------------------------------------------------------ # ------------------------------------------------------------
macro montyRedc2x_gen[N: static int]( macro montyRedc2x_adx_gen*[N: static int](
r_MR: var array[N, SecretWord], r_MR: var array[N, SecretWord],
a_MR: array[N*2, SecretWord], a_MR: array[N*2, SecretWord],
M_MR: array[N, SecretWord], M_MR: array[N, SecretWord],
m0ninv_MR: BaseType, m0ninv_MR: BaseType,
spareBits: static int hasSpareBit: static bool
) = ) =
result = newStmtList() result = newStmtList()
@ -132,7 +132,7 @@ macro montyRedc2x_gen[N: static int](
let t = repackRegisters(v, u[N]) let t = repackRegisters(v, u[N])
if spareBits >= 1: if hasSpareBit:
ctx.finalSubNoCarry(r, u, M, t) ctx.finalSubNoCarry(r, u, M, t)
else: else:
ctx.finalSubCanOverflow(r, u, M, t, hi) ctx.finalSubCanOverflow(r, u, M, t, hi)
@ -145,7 +145,7 @@ func montRed_asm_adx_bmi2*[N: static int](
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
spareBits: static int hasSpareBit: static bool
) = ) =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
montyRedc2x_gen(r, a, M, m0ninv, spareBits) montyRedc2x_adx_gen(r, a, M, m0ninv, hasSpareBit)


@ -133,7 +133,7 @@ func mul_asm*[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limbs[aLen],
# Squaring # Squaring
# ----------------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------------
macro square_gen[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) = macro sqr_gen*[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) =
## Comba squaring generator ## Comba squaring generator
## `a` and `r` can have a different number of limbs ## `a` and `r` can have a different number of limbs
## if `r`.limbs.len < a.limbs.len * 2 ## if `r`.limbs.len < a.limbs.len * 2
@ -240,4 +240,4 @@ macro square_gen[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) =
func square_asm*[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) = func square_asm*[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) =
## Multi-precision Squaring ## Multi-precision Squaring
## Assumes r doesn't alias a ## Assumes r doesn't alias a
square_gen(r, a) sqr_gen(r, a)


@ -40,8 +40,10 @@ proc mulx_by_word(
word0: Operand word0: Operand
) = ) =
## Multiply the `a[0..<N]` by `word` ## Multiply the `a[0..<N]` by `word`
## and store in `[t:r0]` ## and store in `[t[n..1]:r0]`
## with [t:r0] = tn, tn-1, ... t1, r0 ## with [t[n..1]:r0] = tn, tn-1, ... t1, r0
## This assumes that t will be rotated left and so
## t1 is in t[0] and tn in t[n-1]
doAssert a.len + 1 == t.len doAssert a.len + 1 == t.len
let N = a.len let N = a.len
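The "rotated left" convention in these doc comments can be pictured with an ordinary in-place rotation: once the lowest finished word has been stored into `r`, rotating `t` renames the remaining partial sums so the next-lowest one sits at `t[0]` again. A toy illustration using the standard library's `algorithm.rotateLeft` (not the `OperandArray` register machinery):

import std/algorithm

when isMainModule:
  var t = ["t1", "t2", "t3", "t4"]        # partial sums, lowest first
  discard t.rotateLeft(1)                 # lowest word stored, rotate...
  doAssert t == ["t2", "t3", "t4", "t1"]  # ...next-lowest is back at t[0]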
@ -76,8 +78,10 @@ proc mulaccx_by_word(
word: Operand word: Operand
) = ) =
## Multiply the `a[0..<N]` by `word` ## Multiply the `a[0..<N]` by `word`
## and store in `[t:r0]` ## and store in `[t[n..0]:r0]`
## with [t:r0] = tn, tn-1, ... t1, r0 ## with [t[n..0]:r0] = tn, tn-1, ... t1, r0
## This assumes that t will be rotated left and so
## t1 is in t[0] and tn in t[n-1]
doAssert a.len + 1 == t.len doAssert a.len + 1 == t.len
let N = min(a.len, r.len) let N = min(a.len, r.len)
let hi = t[a.len] let hi = t[a.len]
@ -130,7 +134,6 @@ macro mulx_gen[rLen, aLen, bLen: static int](rx: var Limbs[rLen], ax: Limbs[aLen
var # If aLen is too big, we need to spill registers. TODO. var # If aLen is too big, we need to spill registers. TODO.
t = init(OperandArray, nimSymbol = ident"t", tSlots, ElemsInReg, Output_EarlyClobber) t = init(OperandArray, nimSymbol = ident"t", tSlots, ElemsInReg, Output_EarlyClobber)
# Prologue # Prologue
let tsym = t.nimSymbol let tsym = t.nimSymbol
result.add quote do: result.add quote do:
@ -173,131 +176,516 @@ func mul_asm_adx_bmi2*[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limb
# Squaring # Squaring
# ----------------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------------
# TODO: We use 16 registers but GCC/Clang still complain :/
#
# Strategy:
# We want to use the same scheduling as mul_asm_adx_bmi2
# and so process `a[0..<N]` by `word`
# and store the intermediate result in `[t[n..1]:r0]`
#
# However for squarings, all the multiplications a[i,j] * a[i,j]
# with i != j occurs twice, hence we can do them only once and double them at an opportune time.
#
# Assuming a 4 limbs bigint we have the following multiplications to do:
#
# a₃a₂a₁a₀
# * a₃a₂a₁a₀
# ---------------------------
# a₀a₀
# a₁a₁
# a₂a₂
# a₃a₃
#
# a₁a₀ |
# a₂a₀ |
# a₃a₀ |
# | * 2
# a₂a₁ |
# a₃a₁ |
# |
# a₃a₂ |
#
# r₇r₆r₅r₄r₃r₂r₁r₀
#
# The multiplication strategy is to mulx+adox+adcx on a diagonal
# handling both carry into next mul and partial sums carry into t
# then saving the lowest word in t into r.
#
# We want `t` of size N+1 with N the number of limbs just like multiplication,
# and reuse the multiplication algorithm
# this means that we need to reorganize scheduling like so to maximize utilization
#
# a₃ a₂ a₁ a₀
# * a₃ a₂ a₁ a₀
# ------------------------------
# a₀*a₀
# a₁*a₁
# a₂*a₂
# a₃*a₃
#
# a₂*a₁ a₁*a₀ |
# a₃*a₁ a₂*a₀ | * 2
# a₃*a₂ a₃*a₀ |
#
# r₇ r₆ r₅ r₄ r₃ r₂ r₁ r₀
#
# Note that while processing the second diagonal we do
# a₂*a₁ then a₃*a₁ then we change word to a₃*a₂.
#
# We want to use an index as much as possible in the diagonal.
# - There is probably a clever solution using graphs
# - https://en.wikipedia.org/wiki/Longest_path_problem
# - https://en.wikipedia.org/wiki/Longest_increasing_subsequence
# - or polyhedral optimization: http://playground.pollylabs.org/
#
# but we only care about 4*4 and 6*6 at the moment, for 6*6 the schedule is
# a₅ a₄ a₃ a₂ a₁ a₀
# * a₅ a₄ a₃ a₂ a₁ a₀
# -------------------------------------
# a₀*a₀
# a₁*a₁
# a₂*a₂
# a₃*a₃
# a₄*a₄
# a₅*a₅
#
# a₃*a₂ a₂*a₁ a₁*a₀ |
# a₄*a₂ a₃*a₁ a₂*a₀ |
# a₄*a₃ a₄*a₁ a₃*a₀ | * 2
# a₅*a₃ a₅*a₁ a₄*a₀ |
# a₅*a₄ a₅*a₂ a₅*a₀ |
#
#
# r₁₁ r₁₀ r₉ r₈ r₇ r₆ r₅ r₄ r₃ r₂ r₁ r₀
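The identity this scheduling relies on — every cross product aᵢ*aⱼ with i ≠ j appears twice in the square, so it is computed once and doubled, while the diagonal aᵢ*aᵢ terms are added once — can be checked with a small standalone toy (16-bit limbs held in uint64 so no partial product overflows; the names below are made up for the sketch, this is not the assembler path):

const W = 16                               # toy limb width

proc addAt(r: var array[8, uint64], pos: int, v: uint64) =
  ## Add v at limb position `pos`, propagating carries in base 2^W.
  var carry = v
  var i = pos
  while carry != 0 and i < r.len:
    carry += r[i]
    r[i] = carry and ((1'u64 shl W) - 1)
    carry = carry shr W
    inc i

proc sqrSchoolbook(a: array[4, uint64]): array[8, uint64] =
  ## Reference: all 16 partial products aᵢ*aⱼ.
  for i in 0 ..< 4:
    for j in 0 ..< 4:
      addAt(result, i+j, a[i] * a[j])

proc sqrDoubledCross(a: array[4, uint64]): array[8, uint64] =
  ## Diagonal squares, plus each cross product computed once and doubled.
  for i in 0 ..< 4:
    addAt(result, 2*i, a[i] * a[i])        # aᵢ*aᵢ
    for j in i+1 ..< 4:
      addAt(result, i+j, 2'u64 * a[i] * a[j])

when isMainModule:
  let a = [0xFFFF'u64, 0x1234'u64, 0xABCD'u64, 0x8000'u64]
  doAssert sqrSchoolbook(a) == sqrDoubledCross(a)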
# macro sqrx_gen[rLen, aLen: static int](rx: var Limbs[rLen], ax: Limbs[aLen]) = template merge_diag_and_partsum(r, a, hi, lo, zero, i): untyped =
# ## Squaring ctx.mulx hi, lo, a[i], rdx
# ## `a` and `r` can have a different number of limbs if i+1 < a.len:
# ## if `r`.limbs.len < a.limbs.len * 2 ctx.mov rdx, a[i+1] # prepare next iteration
# ## The result will be truncated, i.e. it will be if i != 0:
# ## a² (mod (2^WordBitwidth)^r.limbs.len) ctx.adox lo, r[2*i]
# ## ctx.adcx lo, r[2*i]
# ## Assumes r doesn't aliases a ctx.mov r[2*i], lo
# result = newStmtList() if i+1 < a.len:
# ctx.adox hi, r[2*i+1]
# var ctx = init(Assembler_x86, BaseType) ctx.adcx hi, r[2*i+1]
# let else: # finish carry chain
# # Register count with 6 limbs: ctx.adox hi, zero
# # r + a + rax + rdx = 4 ctx.adcx hi, zero
# # t = 2 * a.len = 12 ctx.mov r[2*i+1], hi
# # We use the full x86 register set.
# func sqrx_gen4L(ctx: var Assembler_x86, r, a: OperandArray, t: var OperandArray) =
# r = init(OperandArray, nimSymbol = rx, rLen, PointerInReg, InputOutput) # a₃ a₂ a₁ a₀
# a = init(OperandArray, nimSymbol = ax, aLen, PointerInReg, Input) # * a₃ a₂ a₁ a₀
# # ------------------------------
# N = a.len # a₀*a₀
# tSlots = a.len * 2 # a₁*a₁
# # If aLen is too big, we need to spill registers. TODO. # a₂*a₂
# t = init(OperandArray, nimSymbol = ident"t", tSlots, ElemsInReg, Output_EarlyClobber) # a₃*a₃
# #
# # MULX requires RDX # a₂*a₁ a₁*a₀ |
# rRDX = Operand( # a₃*a₁ a₂*a₀ | * 2
# desc: OperandDesc( # a₃*a₂ a₃*a₀ |
# asmId: "[rdx]", #
# nimSymbol: ident"rdx", # r₇ r₆ r₅ r₄ r₃ r₂ r₁ r₀
# rm: RDX,
# constraint: Output_EarlyClobber, # First diagonal. a₀ * [aₙ₋₁ .. a₂ a₁]
# cEmit: "rdx" # ------------------------------------
# ) # This assumes that t will be rotated left and so
# ) # t1 is in t[0] and tn in t[n-1]
# ctx.mov rdx, a[0]
# # Scratch spaces for carries ctx.`xor` rax, rax # clear flags
# rRAX = Operand(
# desc: OperandDesc( ctx.comment "a₁*a₀"
# asmId: "[rax]", ctx.mulx t[1], rax, a[1], rdx # t₁ partial sum of r₂
# nimSymbol: ident"rax", ctx.mov r[1], rax
# rm: RAX,
# constraint: Output_EarlyClobber, ctx.comment "a₂*a₀"
# cEmit: "rax" ctx.mulx t[2], rax, a[2], rdx # t₂ partial sum of r₃
# ) ctx.add t[1], rax
# ) ctx.mov r[2], t[1] # r₂ finished
#
# # Prologue ctx.comment "a₃*a₀"
# # ------------------------------- ctx.mulx t[3], rax, a[3], rdx # t₃ partial sum of r₄
# let tsym = t.nimSymbol ctx.mov rdx, a[1] # prepare next iteration
# let eax = rRAX.desc.nimSymbol ctx.adc t[2], rax
# let edx = rRDX.desc.nimSymbol ctx.adc t[3], 0 # final carry in r₄
# result.add quote do:
# var `tsym`{.noInit.}: array[`N`, BaseType] # Second diagonal, a₂*a₁, a₃*a₁, a₃*a₂
# var `eax`{.noInit.}, `edx`{.noInit.}: BaseType # ------------------------------------
#
# # Algorithm ctx.`xor` t[a.len], t[a.len] # Clear flags and upper word
# # ------------------------------- t.rotateLeft() # Our schema are big-endian (rotate right)
# t.rotateLeft() # but we are little-endian (rotateLeft)
# block: # Triangle # Partial sums: t₀ is r₃, t₁ is r₄, t₂ is r₅, t₃ is r₆
# # i = 0 let hi = t[a.len]
# # ----------------
# ctx.mov rRDX, a[0] ctx.comment "a₂*a₁"
# # Put a[1..<N] in unused registers, 4 mov per cycle on x86 ctx.mulx hi, rax, a[2], rdx
# for i in 1 ..< N: ctx.adox t[0], rax # t₀ partial sum r₃
# ctx.mov t[i+N], a[i] ctx.mov r[3], t[0] # r₃ finished
# ctx.adcx t[1], hi # t₁ partial sum r₄
# let # Carry handlers
# hi = r.reuseRegister() ctx.comment "a₃*a₁"
# lo = rRAX ctx.mulx hi, rax, a[3], rdx
# ctx.mov rdx, a[2] # prepare next iteration
# for j in 1 ..< N: ctx.adox t[1], rax # t₁ partial sum r₄
# # (carry, t[j]) <- a[j] * a[0] with a[j] in t[j+N] ctx.mov r[4], t[1] # r₄ finished
# ctx.mulx t[j], rRAX, t[j+N], rdx ctx.adcx t[2], hi # t₂ partial sum r₅
# if j == 1:
# ctx.add t[j-1], rRAX ctx.comment "a₃*a₂"
# else: ctx.mulx hi, rax, a[3], rdx
# ctx.adc t[j-1], rRAX ctx.mov rdx, 0 # Set to 0 without clearing flags
# ctx.adc t[N-1], 0 ctx.adox t[2], rax # t₂ partial sum r₅
# ctx.mov r[5], t[2] # r₅ finished
# for i in 1 ..< N-1: ctx.adcx hi, rdx # Terminate carry chains
# ctx.comment " Process squaring triangle " & $i ctx.adox hi, rdx
# ctx.mov rRDX, a[i] ctx.mov rdx, a[0] # prepare next iteration
# ctx.`xor` t[i+N], t[i+N] # Clear flags ctx.mov r[6], hi # r₆ finished
# for j in i+1 ..< N:
# ctx.mulx hi, lo, t[j+N], rdx # a[i] * a[i] + 2 * r[2n-1 .. 1]
# ctx.adox t[i+j], lo # ------------------------------
# if j == N-1: #
# break # a₃ a₂ a₁ a₀
# ctx.adcx t[i+j+1], hi # * a₃ a₂ a₁ a₀
# # ------------------------------
# ctx.comment " Accumulate last carries in i+N word" # a₀*a₀
# # t[i+N] is already 0 # a₁*a₁
# ctx.adcx hi, t[i+N] # a₂*a₂
# ctx.adox t[i+N], hi # a₃*a₃
# #
# block: # a₂*a₁ a₁*a₀ |
# ctx.comment "Finish: (t[2*i+1], t[2*i]) <- 2*t[2*i] + a[i]*a[i]" # a₃*a₁ a₂*a₀ | * 2
# # a₃*a₂ a₃*a₀ |
# # Restore result #
# ctx.mov r.reuseRegister(), xmm0 # r₇ r₆ r₅ r₄ r₃ r₂ r₁ r₀
#
# ctx.mov rRDX, a[0] # a₀ in RDX
# var
# # (t[2*i+1], t[2*i]) <- 2*t[2*i] + a[i]*a[i] hi1 = hi
# for i in 0 ..< N: lo1 = rax
# ctx.mulx rRAX, rRDX, a[i], rdx var
# ctx.add t[2*i], t[2*i] hi2 = t[1]
# ctx.adc t[2*i+1], 0 lo2 = t[0]
# ctx.add t[2*i], rRDX
# if i != N - 1: ctx.comment "ai*ai + 2*r[1..<2*n-2]"
# ctx.mov rRDX, a[i+1] let zero = t[2]
# ctx.adc t[2*i+1], rRAX ctx.`xor` zero, zero # clear flags, break dependency chains
# ctx.mov r[i], t[i]
# merge_diag_and_partsum(r, a, hi1, lo1, zero, 0)
# # Move the rest merge_diag_and_partsum(r, a, hi2, lo2, zero, 1)
# for i in N ..< min(rLen, 2*N): merge_diag_and_partsum(r, a, hi1, lo1, zero, 2)
# ctx.mov r[i], t[i] merge_diag_and_partsum(r, a, hi2, lo2, zero, 3)
#
# # Codegen
# result.add ctx.generate func sqrx_gen6L(ctx: var Assembler_x86, r, a: OperandArray, t: var OperandArray) =
# # a₅ a₄ a₃ a₂ a₁ a₀
# func square_asm_adx_bmi2*[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) = # * a₅ a₄ a₃ a₂ a₁ a₀
# ## Multi-precision Squaring # -------------------------------------
# ## Assumes r doesn't alias a # a₀*a₀
# sqrx_gen(r, a) # a₁*a₁
# a₂*a₂
# a₃*a₃
# a₄*a₄
# a₅*a₅
#
# a₃*a₂ a₂*a₁ a₁*a₀ |
# a₄*a₂ a₃*a₁ a₂*a₀ |
# a₄*a₃ a₄*a₁ a₃*a₀ | * 2
# a₅*a₃ a₅*a₁ a₄*a₀ |
# a₅*a₄ a₅*a₂ a₅*a₀ |
#
#
# r₁₁ r₁₀ r₉ r₈ r₇ r₆ r₅ r₄ r₃ r₂ r₁ r₀
# First diagonal. a₀ * [aₙ₋₁ .. a₂ a₁]
# ------------------------------------
# This assumes that t will be rotated left and so
# t1 is in t[0] and tn in t[n-1]
ctx.mov rdx, a[0]
ctx.`xor` rax, rax # clear flags
ctx.comment "a₁*a₀"
ctx.mulx t[1], rax, a[1], rdx # t₁ partial sum of r₂
ctx.mov r[1], rax
ctx.comment "a₂*a₀"
ctx.mulx t[2], rax, a[2], rdx # t₂ partial sum of r₃
ctx.add t[1], rax
ctx.mov r[2], t[1] # r₂ finished
ctx.comment "a₃*a₀"
ctx.mulx t[3], rax, a[3], rdx # t₃ partial sum of r₄
ctx.adc t[2], rax
ctx.comment "a₄*a₀"
ctx.mulx t[4], rax, a[4], rdx # t₄ partial sum of r₅
ctx.adc t[3], rax
ctx.comment "a₅*a₀"
ctx.mulx t[5], rax, a[5], rdx # t₅ partial sum of r₆
ctx.mov rdx, a[1] # prepare next iteration
ctx.adc t[4], rax
ctx.adc t[5], 0 # final carry in r₆
# Second diagonal, a₂*a₁, a₃*a₁, a₄*a₁, a₅*a₁, a₅*a₂
# --------------------------------------------------
ctx.`xor` t[a.len], t[a.len] # Clear flags and upper word
t.rotateLeft() # Our schema are big-endian (rotate right)
t.rotateLeft() # but we are little-endian (rotateLeft)
# Partial sums: t₀ is r₃, t₁ is r₄, t₂ is r₅, t₃ is r₆, ...
block:
let hi = t[a.len]
ctx.comment "a₂*a₁"
ctx.mulx hi, rax, a[2], rdx
ctx.adox t[0], rax # t₀ partial sum r₃
ctx.mov r[3], t[0] # r₃ finished
ctx.adcx t[1], hi # t₁ partial sum r₄
ctx.comment "a₃*a₁"
ctx.mulx hi, rax, a[3], rdx
ctx.adox t[1], rax # t₁ partial sum r₄
ctx.mov r[4], t[1] # r₄ finished
ctx.adcx t[2], hi # t₂ partial sum r₅
ctx.comment "a₄*a₁"
ctx.mulx hi, rax, a[4], rdx
ctx.adox t[2], rax # t₂ partial sum r₅
ctx.adcx t[3], hi # t₃ partial sum r₆
ctx.comment "a₅*a₁"
ctx.mulx hi, rax, a[5], rdx
ctx.mov rdx, a[2] # prepare next iteration
ctx.adox t[3], rax # t₃ partial sum r₆
ctx.adcx t[4], hi # t₄ partial sum r₇
ctx.comment "a₅*a₂"
ctx.mulx t[5], rax, a[5], rdx
ctx.mov hi, 0 # Set to 0 `hi` (== t[6] = r₉) without clearing flags
ctx.adox t[4], rax # t₄ partial sum r₇
ctx.adcx t[5], hi # t₅ partial sum r₈, terminate carry chains
ctx.adox t[5], hi
# Third diagonal, a₃*a₂, a₄*a₂, a₄*a₃, a₅*a₃, a₅*a₄
# --------------------------------------------------
t.rotateLeft()
t.rotateLeft()
# Partial sums: t₀ is r₅, t₁ is r₆, t₂ is r₇, t₃ is r₈, t₄ is r₉, t₅ is r₁₀
# t₄ is r₉ and was set to zero, a₂ in RDX
block:
let hi = t[a.len]
ctx.`xor` hi, hi # t₅ is r₁₀ = 0, break dependency chains
ctx.comment "a₃*a₂"
ctx.mulx hi, rax, a[3], rdx
ctx.adox t[0], rax # t₀ partial sum r₅
ctx.mov r[5], t[0] # r₅ finished
ctx.adcx t[1], hi # t₁ partial sum r₆
ctx.comment "a₄*a₂"
ctx.mulx hi, rax, a[4], rdx
ctx.mov rdx, a[3] # prepare next iteration
ctx.adox t[1], rax # t₁ partial sum r₆
ctx.mov r[6], t[1] # r₆ finished
ctx.adcx t[2], hi # t₂ partial sum r₇
ctx.comment "a₄*a₃"
ctx.mulx hi, rax, a[4], rdx
ctx.adox t[2], rax # t₂ partial sum r₇
ctx.mov r[7], t[2] # r₇ finished
ctx.adcx t[3], hi # t₃ partial sum r₈
ctx.comment "a₅*a₃"
ctx.mulx hi, rax, a[5], rdx
ctx.mov rdx, a[4] # prepare next iteration
ctx.adox t[3], rax # t₃ partial sum r₈
ctx.mov r[8], t[3] # r₈ finished
ctx.adcx t[4], hi # t₄ partial sum r₉ (was zero)
ctx.comment "a₅*a₄"
ctx.mulx hi, rax, a[5], rdx
ctx.mov rdx, 0 # Set to 0 without clearing flags
ctx.adox t[4], rax # t₄ partial sum r₉
ctx.mov r[9], t[4] # r₉ finished
ctx.adcx hi, rdx # Terminate carry chains
ctx.adox hi, rdx
ctx.mov rdx, a[0] # prepare next iteration
ctx.mov r[10], hi # r₁₀ finished
# a[i] * a[i] + 2 * r[2n-1 .. 1]
# -------------------------------------
#
# a₅ a₄ a₃ a₂ a₁ a₀
# * a₅ a₄ a₃ a₂ a₁ a₀
# -------------------------------------
# a₀*a₀
# a₁*a₁
# a₂*a₂
# a₃*a₃
# a₄*a₄
# a₅*a₅
#
# a₃*a₂ a₂*a₁ a₁*a₀ |
# a₄*a₂ a₃*a₁ a₂*a₀ |
# a₄*a₃ a₄*a₁ a₃*a₀ | * 2
# a₅*a₃ a₅*a₁ a₄*a₀ |
# a₅*a₄ a₅*a₂ a₅*a₀ |
#
#
# r₁₁ r₁₀ r₉ r₈ r₇ r₆ r₅ r₄ r₃ r₂ r₁ r₀
# a₀ in RDX
var
hi1 = t[a.len]
lo1 = rax
var
hi2 = t[1]
lo2 = t[0]
ctx.comment "ai*ai + 2*r[1..<2*n-2]"
let zero = t[2]
ctx.`xor` zero, zero # clear flags, break dependency chains
merge_diag_and_partsum(r, a, hi1, lo1, zero, 0)
merge_diag_and_partsum(r, a, hi2, lo2, zero, 1)
merge_diag_and_partsum(r, a, hi1, lo1, zero, 2)
merge_diag_and_partsum(r, a, hi2, lo2, zero, 3)
merge_diag_and_partsum(r, a, hi1, lo1, zero, 4)
merge_diag_and_partsum(r, a, hi2, lo2, zero, 5)
macro sqrx_gen*[rLen, aLen: static int](rx: var Limbs[rLen], ax: Limbs[aLen]) =
## Squaring
## `a` and `r` can have a different number of limbs
## if `r`.limbs.len < a.limbs.len * 2
## The result will be truncated, i.e. it will be
## a² (mod (2^WordBitwidth)^r.limbs.len)
##
## Assumes r doesn't aliases a
result = newStmtList()
var ctx = init(Assembler_x86, BaseType)
let
# Register count with 6 limbs:
# r + a + rax + rdx = 4
# t = 2 * a.len = 12
# We use the full x86 register set.
r = init(OperandArray, nimSymbol = rx, rLen, PointerInReg, InputOutput)
a = init(OperandArray, nimSymbol = ax, aLen, PointerInReg, Input)
# MULX requires RDX
tSlots = aLen+1 # Extra for high word
var # If aLen is too big, we need to spill registers. TODO.
t = init(OperandArray, nimSymbol = ident"t", tSlots, ElemsInReg, Output_EarlyClobber)
# Prologue
# -------------------------------
let tsym = t.nimSymbol
result.add quote do:
var `tsym`{.noInit.}: array[`tSlots`, BaseType]
if aLen == 4:
ctx.sqrx_gen4L(r, a, t)
elif aLen == 6:
ctx.sqrx_gen6L(r, a, t)
else:
error: "Not implemented"
# Codegen
result.add ctx.generate
func square_asm_adx_bmi2*[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) =
## Multi-precision Squaring
## Assumes r doesn't alias a
sqrx_gen(r, a)
# Sanity checks
# ----------------------------------------------------------
when isMainModule:
import
../../config/[type_bigint, common],
../../arithmetic/limbs
type SW = SecretWord
# 4 limbs
# --------------------------------
proc mainSqr1() =
var a = [SW 0xFFFF_FFFF_FFFF_FFFF'u64, SW 0xFFFF_FFFF_FFFF_FFFF'u64, SW 0xFFFF_FFFF_FFFF_FFFF'u64, SW 0xFFFF_FFFF_FFFF_FFFF'u64]
var a2x, expected: Limbs[8]
a2x.square_asm_adx_bmi2(a)
expected.mul_asm_adx_bmi2(a, a)
debugecho "--------------------------------"
debugecho "before:"
debugecho " a : ", a.toString()
debugecho "after:"
debugecho " a2x: ", a2x.toString()
debugecho " ref: ", expected.toString()
doAssert bool(a2x == expected)
proc mainSqr2() =
var a = [SW 0x2'u64, SW 0x1'u64, SW 0x1'u64, SW 0x2'u64]
var a2x, expected: Limbs[8]
a2x.square_asm_adx_bmi2(a)
expected.mul_asm_adx_bmi2(a, a)
debugecho "--------------------------------"
debugecho "before:"
debugecho " a : ", a.toString()
debugecho "after:"
debugecho " a2x: ", a2x.toString()
debugecho " ref: ", expected.toString()
doAssert bool(a2x == expected)
mainSqr1()
mainSqr2()
# 6 limbs
# --------------------------------
proc mainSqr3() =
var a = [SW 0xFFFF_FFFF_FFFF_FFFF'u64, SW 0xFFFF_FFFF_FFFF_FFFF'u64, SW 0xFFFF_FFFF_FFFF_FFFF'u64, SW 0xFFFF_FFFF_FFFF_FFFF'u64, SW 0xFFFF_FFFF_FFFF_FFFF'u64, SW 0xFFFF_FFFF_FFFF_FFFF'u64]
var a2x, expected: Limbs[12]
a2x.square_asm_adx_bmi2(a)
expected.mul_asm_adx_bmi2(a, a)
debugecho "--------------------------------"
debugecho "before:"
debugecho " a : ", a.toString()
debugecho "after:"
debugecho " a2x: ", a2x.toString()
debugecho " ref: ", expected.toString()
doAssert bool(a2x == expected)
proc mainSqr4() =
var a = [SW 0x1'u64, SW 0x2'u64, SW 0x2'u64, SW 0x2'u64, SW 0x1'u64, SW 0x1'u64,]
var a2x, expected: Limbs[12]
a2x.square_asm_adx_bmi2(a)
expected.mul_asm_adx_bmi2(a, a)
debugecho "--------------------------------"
debugecho "before:"
debugecho " a : ", a.toString()
debugecho "after:"
debugecho " a2x: ", a2x.toString()
debugecho " ref: ", expected.toString()
doAssert bool(a2x == expected)
mainSqr3()
mainSqr4()


@ -200,7 +200,13 @@ func square*[rLen, aLen](
## a² (mod (2^WordBitwidth)^r.limbs.len) ## a² (mod (2^WordBitwidth)^r.limbs.len)
## ##
## `r` must not alias ``a`` or ``b`` ## `r` must not alias ``a`` or ``b``
when UseASM_X86_64: when UseASM_X86_64 and aLen in {4, 6} and rLen == 2*aLen:
# ADX implies BMI2
if ({.noSideEffect.}: hasAdx()):
square_asm_adx_bmi2(r, a)
else:
square_asm(r, a)
elif UseASM_X86_64:
square_asm(r, a) square_asm(r, a)
else: else:
square_comba(r, a) square_comba(r, a)


@ -12,7 +12,7 @@ import
# Internal # Internal
../config/common, ../config/common,
../primitives, ../primitives,
./limbs ./limbs, ./limbs_extmul
when UseASM_X86_32: when UseASM_X86_32:
import ./assembly/limbs_asm_montred_x86 import ./assembly/limbs_asm_montred_x86
@ -51,6 +51,114 @@ when UseASM_X86_64:
# No exceptions allowed # No exceptions allowed
{.push raises: [].} {.push raises: [].}
# Montgomery Reduction
# ------------------------------------------------------------
func montyRedc2x_CIOS[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType) =
## Montgomery reduce a double-precision bigint modulo M
# - Analyzing and Comparing Montgomery Multiplication Algorithms
# Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr.
# http://pdfs.semanticscholar.org/5e39/41ff482ec3ee41dc53c3298f0be085c69483.pdf
#
# - Arithmetic of Finite Fields
# Chapter 5 of Guide to Pairing-Based Cryptography
# Jean Luc Beuchat, Luis J. Dominguez Perez, Sylvain Duquesne, Nadia El Mrabet, Laura Fuentes-Castañeda, Francisco Rodríguez-Henríquez, 2017
# https://www.researchgate.net/publication/319538235_Arithmetic_of_Finite_Fields
#
# Algorithm
# Inputs:
# - N number of limbs
# - a[0 ..< 2N] (double-precision input to reduce)
# - M[0 ..< N] The field modulus (must be odd for Montgomery reduction)
# - m0ninv: Montgomery Reduction magic number = -1/M[0]
# Output:
# - r[0 ..< N], in the Montgomery domain
# Parameters:
# - w, the word width usually 64 on 64-bit platforms or 32 on 32-bit
#
# for i in 0 .. n-1:
# C <- 0
# m <- a[i] * m0ninv mod 2^w (i.e. simple multiplication)
# for j in 0 .. n-1:
# (C, S) <- a[i+j] + m * M[j] + C
# a[i+j] <- S
# a[i+n] += C
# for i in 0 .. n-1:
# r[i] = a[i+n]
# if r >= M:
# r -= M
#
# Important note: `a[i+n] += C` should propagate the carry
# to the higher limb if any, thank you "implementation detail"
# missing from paper.
var a = a # Copy "t" for mutation and ensure on stack
var res: typeof(r) # Accumulator
staticFor i, 0, N:
var C = Zero
let m = a[i] * SecretWord(m0ninv)
staticFor j, 0, N:
muladd2(C, a[i+j], m, M[j], a[i+j], C)
res[i] = C
# This does t[i+n] += C
# but in a separate carry chain, fused with the
# copy "r[i] = t[i+n]"
var carry = Carry(0)
staticFor i, 0, N:
addC(carry, res[i], a[i+N], res[i], carry)
# Final substraction
discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M))
r = res
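The commented algorithm can be exercised end to end with a toy base-2^16 version: 2 limbs, plain uint64 arithmetic, and a made-up spare-bit modulus. It follows the pseudocode above literally, including the `a[i+n] += C` carry propagation, and is not the `SecretWord` code path (`montyRedc2xToy` and friends are names invented for the sketch). The final doAssert checks the defining property REDC(t)·R² ≡ t (mod M):

const
  W = 16
  R = 1'u64 shl W
  Mask = R - 1

proc negInvModR(m0: uint64): uint64 =
  ## m0ninv = -M[0]⁻¹ mod 2^W (M[0] must be odd), via Newton iteration.
  var inv = m0
  for _ in 0 ..< 5:
    inv = inv * (2'u64 - m0 * inv)
  result = (R - (inv and Mask)) and Mask

proc montyRedc2xToy(a: array[4, uint64], M: array[2, uint64], m0ninv: uint64): array[2, uint64] =
  ## Montgomery-reduce the 4-limb value `a` modulo the 2-limb modulus `M`.
  var a = a
  for i in 0 ..< 2:
    var C = 0'u64
    let m = (a[i] * m0ninv) and Mask
    for j in 0 ..< 2:
      let s = a[i+j] + m * M[j] + C
      a[i+j] = s and Mask
      C = s shr W
    # a[i+n] += C, propagating the carry upward (the "implementation detail")
    var k = i + 2
    while C != 0 and k < 4:
      let s = a[k] + C
      a[k] = s and Mask
      C = s shr W
      inc k
  result = [a[2], a[3]]
  # Final subtraction (not constant-time, fine for a sanity check)
  let res = result[0] + (result[1] shl W)
  let modulus = M[0] + (M[1] shl W)
  if res >= modulus:
    let d = res - modulus
    result = [d and Mask, d shr W]

when isMainModule:
  let M = [0xFFF1'u64, 0x7FFF'u64]            # 0x7FFF_FFF1: odd, with a spare bit
  let m0ninv = negInvModR(M[0])
  let t = 0x1234_5678_9ABC_DEF0'u64           # < M·R², so one final sub suffices
  let a = [t and Mask, (t shr W) and Mask, (t shr (2*W)) and Mask, t shr (3*W)]
  let r = montyRedc2xToy(a, M, m0ninv)
  let rv = r[0] + (r[1] shl W)
  let Mv = M[0] + (M[1] shl W)
  doAssert (rv * (1'u64 shl (2*W))) mod Mv == t mod Mv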
func montyRedc2x_Comba[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType) =
## Montgomery reduce a double-precision bigint modulo M
# We use Product Scanning / Comba multiplication
var t, u, v = Zero
var carry: Carry
var z: typeof(r) # zero-init, ensure on stack and removes in-place problems in tower fields
staticFor i, 0, N:
staticFor j, 0, i:
mulAcc(t, u, v, z[j], M[i-j])
addC(carry, v, v, a[i], Carry(0))
addC(carry, u, u, Zero, carry)
addC(carry, t, t, Zero, carry)
z[i] = v * SecretWord(m0ninv)
mulAcc(t, u, v, z[i], M[0])
v = u
u = t
t = Zero
staticFor i, N, 2*N-1:
staticFor j, i-N+1, N:
mulAcc(t, u, v, z[j], M[i-j])
addC(carry, v, v, a[i], Carry(0))
addC(carry, u, u, Zero, carry)
addC(carry, t, t, Zero, carry)
z[i-N] = v
v = u
u = t
t = Zero
addC(carry, z[N-1], v, a[2*N-1], Carry(0))
# Final substraction
discard z.csub(M, SecretBool(carry) or not(z < M))
r = z
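For reference, the (t, u, v) triple used above is the classic product-scanning (Comba) accumulator: `v` is the column currently being finished, `u` and `t` catch the spill into the next two columns, and after each column the window shifts (`v = u; u = t; t = Zero`). The same pattern on a plain 2×2-limb toy multiplication, without any reduction (16-bit limbs in uint64; `mulAccToy`/`mulComba` are invented names):

const
  W = 16
  Mask = (1'u64 shl W) - 1

proc mulAccToy(t, u, v: var uint64, a, b: uint64) =
  ## (t:u:v) += a*b, keeping each column below 2^W.
  let prod = a * b
  v += prod and Mask
  u += (prod shr W) + (v shr W)
  v = v and Mask
  t += u shr W
  u = u and Mask

proc mulComba(a, b: array[2, uint64]): array[4, uint64] =
  var t, u, v: uint64
  mulAccToy(t, u, v, a[0], b[0])            # column 0
  result[0] = v; v = u; u = t; t = 0
  mulAccToy(t, u, v, a[0], b[1])            # column 1
  mulAccToy(t, u, v, a[1], b[0])
  result[1] = v; v = u; u = t; t = 0
  mulAccToy(t, u, v, a[1], b[1])            # column 2
  result[2] = v
  result[3] = u                             # column 3 (t is zero here)

when isMainModule:
  let a = [0xFFFF'u64, 0xABCD'u64]
  let b = [0x1234'u64, 0xFFFF'u64]
  let r = mulComba(a, b)
  let wide = (a[0] + (a[1] shl W)) * (b[0] + (b[1] shl W))
  doAssert r[0] + (r[1] shl W) + (r[2] shl (2*W)) + (r[3] shl (3*W)) == wide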
# Montgomery Multiplication # Montgomery Multiplication
# ------------------------------------------------------------ # ------------------------------------------------------------
@ -172,224 +280,66 @@ func montyMul_FIPS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) =
r = z r = z
# Montgomery Squaring # Montgomery Squaring
# ------------------------------------------------------------ # --------------------------------------------------------------------------------------------------------------------
#
func montySquare_CIOS_nocarry(r: var Limbs, a, M: Limbs, m0ninv: BaseType) {.used.}= # There are Montgomery squaring multiplications mentioned in the litterature
## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS) # - https://hackmd.io/@zkteam/modular_multiplication if M[^1] < high(SecretWord) shr 2 (i.e. less than 0b00111...1111)
## and no-carry optimization. # - Architectural Support for Long Integer Modulo Arithmetic on Risc-Based Smart Cards
## This requires the most significant word of the Modulus # Johann Großschädl, 2003
## M[^1] < high(SecretWord) shr 2 (i.e. less than 0b00111...1111) # https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=95950BAC26A728114431C0C7B425E022?doi=10.1.1.115.3276&rep=rep1&type=pdf
## https://hackmd.io/@zkteam/modular_multiplication # - Analyzing and Comparing Montgomery Multiplication Algorithms
# Koc, Acar, Kaliski, 1996
# TODO: Deactivated # https://www.semanticscholar.org/paper/Analyzing-and-comparing-Montgomery-multiplication-Ko%C3%A7-Acar/5e3941ff482ec3ee41dc53c3298f0be085c69483
# Off-by one on 32-bit on the least significant bit #
# for Fp[BLS12-381] with inputs # However fuzzing the implementation showed off-by-one on certain inputs especially in 32-bit mode
# - -0x091F02EFA1C9B99C004329E94CD3C6B308164CBE02037333D78B6C10415286F7C51B5CD7F917F77B25667AB083314B1B #
# - -0x0B7C8AFE5D43E9A973AF8649AD8C733B97D06A78CFACD214CBE9946663C3F682362E0605BC8318714305B249B505AFD9 # for Fp[BLS12-381] on 32 bit with inputs
# - 0x091F02EFA1C9B99C004329E94CD3C6B308164CBE02037333D78B6C10415286F7C51B5CD7F917F77B25667AB083314B1B
# We want all the computation to be kept in registers # - 0x0B7C8AFE5D43E9A973AF8649AD8C733B97D06A78CFACD214CBE9946663C3F682362E0605BC8318714305B249B505AFD9
# hence we use a temporary `t`, hoping that the compiler does it. # for Consensys/zkteam algorithm (off by one in least significant bit)
var t: typeof(M) # zero-init #
const N = t.len # for Fp[2^127 - 1] with inputs
staticFor i, 0, N: # - -0x75bfffefbfffffff7fd9dfd800000000
# Squaring # - -0x7ff7ffffffffffff1dfb7fafc0000000
var # Squaring the number and its opposite
A1: Carry # should give the same result, but those are off-by-one
A0: SecretWord # with Großschädl algorithm
# (A0, t[i]) <- a[i] * a[i] + t[i] #
muladd1(A0, t[i], a[i], a[i], t[i]) # I suspect either I did a twice the same mistake when translating 2 different algorithms
staticFor j, i+1, N: # or there is a carry propagation constraint that prevents interleaving squaring
# (A1, A0, t[j]) <- 2*a[j]*a[i] + t[j] + (A1, A0) # and Montgomery reduction in the following loops
# 2*a[j]*a[i] can spill 1-bit on a 3rd word # for i in 0 ..< N:
mulDoubleAdd2(A1, A0, t[j], a[j], a[i], t[j], A1, A0) # for j in i+1 ..< N: # <-- squaring, notice that we start at i+1 but carries from below may still impact us.
# ...
# Reduction # for j in 1 ..< N: # <- Montgomery reduce.
# m <- (t[0] * m0ninv) mod 2^w
# (C, _) <- m * M[0] + t[0]
let m = t[0] * SecretWord(m0ninv)
var C, lo: SecretWord
muladd1(C, lo, m, M[0], t[0])
staticFor j, 1, N:
# (C, t[j-1]) <- m*M[j] + t[j] + C
muladd2(C, t[j-1], m, M[j], t[j], C)
t[N-1] = C + A0
discard t.csub(M, not(t < M))
r = t
func montySquare_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) {.used.}=
## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS)
##
## Architectural Support for Long Integer Modulo Arithmetic on Risc-Based Smart Cards
## Johann Großschädl, 2003
## https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=95950BAC26A728114431C0C7B425E022?doi=10.1.1.115.3276&rep=rep1&type=pdf
##
## Analyzing and Comparing Montgomery Multiplication Algorithms
## Koc, Acar, Kaliski, 1996
## https://www.semanticscholar.org/paper/Analyzing-and-comparing-Montgomery-multiplication-Ko%C3%A7-Acar/5e3941ff482ec3ee41dc53c3298f0be085c69483
# TODO: Deactivated
# Off-by one on 32-bit on the least significant bit
# for Fp[2^127 - 1] with inputs
# - -0x75bfffefbfffffff7fd9dfd800000000
# - -0x7ff7ffffffffffff1dfb7fafc0000000
# Squaring the number and its opposite
# should give the same result, but those are off-by-one
# We want all the computation to be kept in registers
# hence we use a temporary `t`, hoping that the compiler does it.
var t: typeof(M) # zero-init
const N = t.len
# Extra words to handle up to 2 carries t[N] and t[N+1]
var tNp1: SecretWord
var tN: SecretWord
staticFor i, 0, N:
# Squaring
var A1 = Carry(0)
var A0: SecretWord
# (A0, t[i]) <- a[i] * a[i] + t[i]
muladd1(A0, t[i], a[i], a[i], t[i])
staticFor j, i+1, N:
# (A1, A0, t[j]) <- 2*a[j]*a[i] + t[j] + (A1, A0)
# 2*a[j]*a[i] can spill 1-bit on a 3rd word
mulDoubleAdd2(A1, A0, t[j], a[j], a[i], t[j], A1, A0)
var carryS: Carry
addC(carryS, tN, tN, A0, Carry(0))
addC(carryS, tNp1, SecretWord(A1), Zero, carryS)
# Reduction
# m <- (t[0] * m0ninv) mod 2^w
# (C, _) <- m * M[0] + t[0]
var C, lo: SecretWord
let m = t[0] * SecretWord(m0ninv)
muladd1(C, lo, m, M[0], t[0])
staticFor j, 1, N:
# (C, t[j-1]) <- m*M[j] + t[j] + C
muladd2(C, t[j-1], m, M[j], t[j], C)
# (C,t[N-1]) <- t[N] + C
# (_, t[N]) <- t[N+1] + C
var carryR: Carry
addC(carryR, t[N-1], tN, C, Carry(0))
addC(carryR, tN, tNp1, Zero, carryR)
discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w
r = t
# Montgomery Reduction
# ------------------------------------------------------------
func montyRedc2x_CIOS[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType) =
## Montgomery reduce a double-precision bigint modulo M
# - Analyzing and Comparing Montgomery Multiplication Algorithms
# Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr.
# http://pdfs.semanticscholar.org/5e39/41ff482ec3ee41dc53c3298f0be085c69483.pdf
#
# - Arithmetic of Finite Fields
# Chapter 5 of Guide to Pairing-Based Cryptography
# Jean Luc Beuchat, Luis J. Dominguez Perez, Sylvain Duquesne, Nadia El Mrabet, Laura Fuentes-Castañeda, Francisco Rodríguez-Henríquez, 2017
# https://www.researchgate.net/publication/319538235_Arithmetic_of_Finite_Fields
#
# Algorithm
# Inputs:
# - N number of limbs
# - a[0 ..< 2N] (double-precision input to reduce)
# - M[0 ..< N] The field modulus (must be odd for Montgomery reduction)
# - m0ninv: Montgomery Reduction magic number = -1/M[0]
# Output:
# - r[0 ..< N], in the Montgomery domain
# Parameters:
# - w, the word width usually 64 on 64-bit platforms or 32 on 32-bit
#
# for i in 0 .. n-1:
# C <- 0
# m <- a[i] * m0ninv mod 2^w (i.e. simple multiplication)
# for j in 0 .. n-1:
# (C, S) <- a[i+j] + m * M[j] + C
# a[i+j] <- S
# a[i+n] += C
# for i in 0 .. n-1:
# r[i] = a[i+n]
# if r >= M:
# r -= M
#
# Important note: `a[i+n] += C` should propagate the carry
# to the higher limb if any, thank you "implementation detail"
# missing from paper.
var a = a # Copy "t" for mutation and ensure on stack
var res: typeof(r) # Accumulator
staticFor i, 0, N:
var C = Zero
let m = a[i] * SecretWord(m0ninv)
staticFor j, 0, N:
muladd2(C, a[i+j], m, M[j], a[i+j], C)
res[i] = C
# This does t[i+n] += C
# but in a separate carry chain, fused with the
# copy "r[i] = t[i+n]"
var carry = Carry(0)
staticFor i, 0, N:
addC(carry, res[i], a[i+N], res[i], carry)
# Final substraction
discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M))
r = res
func montyRedc2x_Comba[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType) =
## Montgomery reduce a double-precision bigint modulo M
# We use Product Scanning / Comba multiplication
var t, u, v = Zero
var carry: Carry
var z: typeof(r) # zero-init, ensure on stack and removes in-place problems in tower fields
staticFor i, 0, N:
staticFor j, 0, i:
mulAcc(t, u, v, z[j], M[i-j])
addC(carry, v, v, a[i], Carry(0))
addC(carry, u, u, Zero, carry)
addC(carry, t, t, Zero, carry)
z[i] = v * SecretWord(m0ninv)
mulAcc(t, u, v, z[i], M[0])
v = u
u = t
t = Zero
staticFor i, N, 2*N-1:
staticFor j, i-N+1, N:
mulAcc(t, u, v, z[j], M[i-j])
addC(carry, v, v, a[i], Carry(0))
addC(carry, u, u, Zero, carry)
addC(carry, t, t, Zero, carry)
z[i-N] = v
v = u
u = t
t = Zero
addC(carry, z[N-1], v, a[2*N-1], Carry(0))
# Final substraction
discard z.csub(M, SecretBool(carry) or not(z < M))
r = z
# Exported API # Exported API
# ------------------------------------------------------------ # ------------------------------------------------------------
# TODO upstream, using Limbs[N] breaks semcheck
func montyRedc2x*[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType, spareBits: static int) {.inline.} =
## Montgomery reduce a double-precision bigint modulo M
when UseASM_X86_64 and r.len <= 6:
# ADX implies BMI2
if ({.noSideEffect.}: hasAdx()):
montRed_asm_adx_bmi2(r, a, M, m0ninv, spareBits >= 1)
else:
when r.len in {3..6}:
montRed_asm(r, a, M, m0ninv, spareBits >= 1)
else:
montyRedc2x_CIOS(r, a, M, m0ninv)
# montyRedc2x_Comba(r, a, M, m0ninv)
elif UseASM_X86_64 and r.len in {3..6}:
# TODO: Assembly faster than GCC but slower than Clang
montRed_asm(r, a, M, m0ninv, spareBits >= 1)
else:
montyRedc2x_CIOS(r, a, M, m0ninv)
# montyRedc2x_Comba(r, a, M, m0ninv)
func montyMul*( func montyMul*(
r: var Limbs, a, b, M: Limbs, r: var Limbs, a, b, M: Limbs,
m0ninv: static BaseType, spareBits: static int) {.inline.} = m0ninv: static BaseType, spareBits: static int) {.inline.} =
@ -431,53 +381,27 @@ func montyMul*(
else: else:
montyMul_FIPS(r, a, b, M, m0ninv) montyMul_FIPS(r, a, b, M, m0ninv)
func montySquare*(r: var Limbs, a, M: Limbs, func montySquare*[N](r: var Limbs[N], a, M: Limbs[N],
m0ninv: static BaseType, spareBits: static int) {.inline.} = m0ninv: static BaseType, spareBits: static int) {.inline.} =
## Compute r <- a^2 (mod M) in the Montgomery domain ## Compute r <- a^2 (mod M) in the Montgomery domain
## `m0ninv` = -1/M (mod SecretWord). Our words are 2^31 or 2^63 ## `m0ninv` = -1/M (mod SecretWord). Our words are 2^31 or 2^63
# TODO: needs optimization similar to multiplication when UseASM_X86_64 and a.len in {4, 6}:
montyMul(r, a, a, M, m0ninv, spareBits)
# when spareBits >= 2:
# # TODO: Deactivated
# # Off-by one on 32-bit on the least significant bit
# # for Fp[BLS12-381] with inputs
# # - -0x091F02EFA1C9B99C004329E94CD3C6B308164CBE02037333D78B6C10415286F7C51B5CD7F917F77B25667AB083314B1B
# # - -0x0B7C8AFE5D43E9A973AF8649AD8C733B97D06A78CFACD214CBE9946663C3F682362E0605BC8318714305B249B505AFD9
#
# # montySquare_CIOS_nocarry(r, a, M, m0ninv)
# montyMul_CIOS_nocarry(r, a, a, M, m0ninv)
# else:
# # TODO: Deactivated
# # Off-by one on 32-bit for Fp[2^127 - 1] with inputs
# # - -0x75bfffefbfffffff7fd9dfd800000000
# # - -0x7ff7ffffffffffff1dfb7fafc0000000
# # Squaring the number and its opposite
# # should give the same result, but those are off-by-one
#
# # montySquare_CIOS(r, a, M, m0ninv) # TODO <--- Fix this
# montyMul_FIPS(r, a, a, M, m0ninv)
# TODO upstream, using Limbs[N] breaks semcheck
func montyRedc2x*[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType, spareBits: static int) {.inline.} =
## Montgomery reduce a double-precision bigint modulo M
when UseASM_X86_64 and r.len <= 6:
# ADX implies BMI2 # ADX implies BMI2
if ({.noSideEffect.}: hasAdx()): if ({.noSideEffect.}: hasAdx()):
montRed_asm_adx_bmi2(r, a, M, m0ninv, spareBits) # With ADX and spare bit, montSquare_CIOS_asm_adx_bmi2
# which uses unfused squaring then Montgomery reduction
# is slightly slower than fused Montgomery multiplication
when spareBits >= 1:
montMul_CIOS_nocarry_asm_adx_bmi2(r, a, a, M, m0ninv)
else: else:
montRed_asm(r, a, M, m0ninv, spareBits) montSquare_CIOS_asm_adx_bmi2(r, a, M, m0ninv, spareBits >= 1)
elif UseASM_X86_32 and r.len <= 6:
# TODO: Assembly faster than GCC but slower than Clang
montRed_asm(r, a, M, m0ninv, spareBits)
else: else:
montyRedc2x_CIOS(r, a, M, m0ninv) montSquare_CIOS_asm(r, a, M, m0ninv, spareBits >= 1)
# montyRedc2x_Comba(r, a, M, m0ninv) else:
var r2x {.noInit.}: Limbs[2*N]
r2x.square(a)
r.montyRedc2x(r2x, M, m0ninv, spareBits)
func redc*(r: var Limbs, a, one, M: Limbs, func redc*(r: var Limbs, a, one, M: Limbs,
m0ninv: static BaseType, spareBits: static int) = m0ninv: static BaseType, spareBits: static int) =


@ -125,7 +125,7 @@ macro debugConsts(): untyped {.used.} =
for i in 1 ..< E.len: for i in 1 ..< E.len:
let curve = E[i] let curve = E[i]
let curveName = $curve let curveName = $curve
let modulus = bindSym(curveName & "_Fp_Modulus") let modulus = bindSym(curveName & "_Modulus")
let r2modp = bindSym(curveName & "_Fp_R2modP") let r2modp = bindSym(curveName & "_Fp_R2modP")
let negInvModWord = bindSym(curveName & "_Fp_NegInvModWord") let negInvModWord = bindSym(curveName & "_Fp_NegInvModWord")


@ -42,9 +42,24 @@ type
# Clobbered register # Clobbered register
ClobberedReg ClobberedReg
when sizeof(int) == 8 and not defined(Constantine32):
type
Register* = enum Register* = enum
rbx, rdx, r8, rax, xmm0 rbx
rdx
r8
rax
xmm0
else:
type
Register* = enum
rbx = "ebx"
rdx = "edx"
r8 = "r8d"
rax = "eax"
xmm0
type
Constraint* = enum Constraint* = enum
## GCC extended assembly modifier ## GCC extended assembly modifier
Input = "" Input = ""
@ -117,6 +132,9 @@ func len*(opArray: Operand): int =
func rotateLeft*(opArray: var OperandArray) = func rotateLeft*(opArray: var OperandArray) =
opArray.buf.rotateLeft(1) opArray.buf.rotateLeft(1)
func rotateRight*(opArray: var OperandArray) =
opArray.buf.rotateLeft(-1)
proc `[]`*(opArray: OperandArray, index: int): Operand = proc `[]`*(opArray: OperandArray, index: int): Operand =
opArray.buf[index] opArray.buf[index]
@ -941,7 +959,7 @@ func mulx*(a: var Assembler_x86, dHi, dLo: Register, src0: Operand, src1: Regist
a.regClobbers.incl dHi a.regClobbers.incl dHi
a.regClobbers.incl dLo a.regClobbers.incl dLo
func adcx*(a: var Assembler_x86, dst: Operand|OperandReuse, src: Operand|OperandReuse|Register) = func adcx*(a: var Assembler_x86, dst: Operand|OperandReuse|Register, src: Operand|OperandReuse|Register) =
## Does: dst <- dst + src + carry ## Does: dst <- dst + src + carry
## and only sets the carry flag ## and only sets the carry flag
when dst is Operand: when dst is Operand:
@ -950,7 +968,7 @@ func adcx*(a: var Assembler_x86, dst: Operand|OperandReuse, src: Operand|Operand
a.codeFragment("adcx", src, dst) a.codeFragment("adcx", src, dst)
a.areFlagsClobbered = true a.areFlagsClobbered = true
func adox*(a: var Assembler_x86, dst: Operand|OperandReuse, src: Operand|OperandReuse|Register) = func adox*(a: var Assembler_x86, dst: Operand|OperandReuse|Register, src: Operand|OperandReuse|Register) =
## Does: dst <- dst + src + overflow ## Does: dst <- dst + src + overflow
## and only sets the overflow flag ## and only sets the overflow flag
when dst is Operand: when dst is Operand:


@ -109,7 +109,10 @@ template sqrTest(rng_gen: untyped): untyped =
mulDbl.prod2x(a, a) mulDbl.prod2x(a, a)
sqrDbl.square2x(a) sqrDbl.square2x(a)
doAssert bool(mulDbl == sqrDbl) doAssert bool(mulDbl == sqrDbl),
"\nOriginal: " & a.mres.limbs.toString() &
"\n Mul: " & mulDbl.limbs2x.toString() &
"\n Sqr: " & sqrDbl.limbs2x.toString()
addsubnegTest(random_unsafe) addsubnegTest(random_unsafe)
addsubnegTest(randomHighHammingWeight) addsubnegTest(randomHighHammingWeight)


@ -0,0 +1 @@
-d:debugConstantine


@ -145,7 +145,6 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " ["
random_long01Seq(P224) random_long01Seq(P224)
test "Random squaring mod P-256 [FastSquaring = " & $(Fp[P256].getSpareBits() >= 2) & "]": test "Random squaring mod P-256 [FastSquaring = " & $(Fp[P256].getSpareBits() >= 2) & "]":
echo "Fp[P256].getSpareBits(): ", Fp[P256].getSpareBits()
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
randomCurve(P256) randomCurve(P256)
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
@ -173,8 +172,9 @@ suite "Modular squaring - bugs highlighted by property-based testing":
a.square() a.square()
na.square() na.square()
check: doAssert bool(a == na),
bool(a == na) "\n a² : " & a.mres.limbs.toString() &
"\n (-a)²: " & na.mres.limbs.toString()
var a2{.noInit.}, na2{.noInit.}: Fp[Mersenne127] var a2{.noInit.}, na2{.noInit.}: Fp[Mersenne127]
a2.fromHex"0x75bfffefbfffffff7fd9dfd800000000" a2.fromHex"0x75bfffefbfffffff7fd9dfd800000000"
@ -183,10 +183,13 @@ suite "Modular squaring - bugs highlighted by property-based testing":
a2 *= a2 a2 *= a2
na2 *= na2 na2 *= na2
check: doAssert(
bool(a2 == na2) bool(a2 == na2) and
bool(a2 == a) bool(a2 == a) and
bool(a2 == na) bool(a2 == na),
"\n a*a: " & a2.mres.limbs.toString() &
"\n (-a)*(-a): " & na2.mres.limbs.toString()
)
test "a² == (-a)² on for Fp[2^127 - 1] - #62": test "a² == (-a)² on for Fp[2^127 - 1] - #62":
var a{.noInit.}: Fp[Mersenne127] var a{.noInit.}: Fp[Mersenne127]
@ -199,8 +202,9 @@ suite "Modular squaring - bugs highlighted by property-based testing":
a.square() a.square()
na.square() na.square()
check: doAssert bool(a == na),
bool(a == na) "\n a² : " & a.mres.limbs.toString() &
"\n (-a)²: " & na.mres.limbs.toString()
var a2{.noInit.}, na2{.noInit.}: Fp[Mersenne127] var a2{.noInit.}, na2{.noInit.}: Fp[Mersenne127]
a2.fromHex"0x7ff7ffffffffffff1dfb7fafc0000000" a2.fromHex"0x7ff7ffffffffffff1dfb7fafc0000000"
@ -209,10 +213,13 @@ suite "Modular squaring - bugs highlighted by property-based testing":
a2 *= a2 a2 *= a2
na2 *= na2 na2 *= na2
check: doAssert(
bool(a2 == na2) bool(a2 == na2) and
bool(a2 == a) bool(a2 == a) and
bool(a2 == na) bool(a2 == na),
"\n a*a: " & a2.mres.limbs.toString() &
"\n (-a)*(-a): " & na2.mres.limbs.toString()
)
test "32-bit fast squaring on BLS12-381 - #42": test "32-bit fast squaring on BLS12-381 - #42":
# x = -(2^63 + 2^62 + 2^60 + 2^57 + 2^48 + 2^16) # x = -(2^63 + 2^62 + 2^60 + 2^57 + 2^48 + 2^16)


@ -1 +1,2 @@
-d:testingCurves -d:testingCurves
-d:debugConstantine