Assembly for Fp2 (#161)

* Assembly for Fp2

* fix import
This commit is contained in:
Mamy Ratsimbazafy 2021-02-20 15:21:23 +01:00 committed by GitHub
parent aefd40f455
commit afb33a5a77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 192 additions and 12 deletions

View File

@ -20,8 +20,6 @@ import
# #
# ############################################################ # ############################################################
# TODO, MCL has an implementation about 14% faster
static: doAssert UseASM_X86_64 static: doAssert UseASM_X86_64
# MULX/ADCX/ADOX # MULX/ADCX/ADOX
@ -140,6 +138,17 @@ macro montyRedc2x_adx_gen*[N: static int](
# Code generation # Code generation
result.add ctx.generate() result.add ctx.generate()
func montRed_asm_adx_bmi2_impl*[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool
) =
## Constant-time Montgomery reduction
## Inline-version
montyRedc2x_adx_gen(r, a, M, m0ninv, hasSpareBit)
func montRed_asm_adx_bmi2*[N: static int]( func montRed_asm_adx_bmi2*[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
@ -148,4 +157,4 @@ func montRed_asm_adx_bmi2*[N: static int](
hasSpareBit: static bool hasSpareBit: static bool
) = ) =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
montyRedc2x_adx_gen(r, a, M, m0ninv, hasSpareBit) montRed_asm_adx_bmi2_impl(r, a, M, m0ninv, hasSpareBit)

View File

@ -168,11 +168,18 @@ macro mulx_gen[rLen, aLen, bLen: static int](rx: var Limbs[rLen], ax: Limbs[aLen
# Codegen # Codegen
result.add ctx.generate result.add ctx.generate
func mul_asm_adx_bmi2*[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) = func mul_asm_adx_bmi2_impl*[rLen, aLen, bLen: static int](
r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) {.inline.} =
## Multi-precision Multiplication ## Multi-precision Multiplication
## Assumes r doesn't alias a or b ## Assumes r doesn't alias a or b
## Inline version
mulx_gen(r, a, b) mulx_gen(r, a, b)
func mul_asm_adx_bmi2*[rLen, aLen, bLen: static int](
r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) =
## Multi-precision Multiplication
## Assumes r doesn't alias a or b
mul_asm_adx_bmi2_impl(r, a, b)
# Squaring # Squaring
# ----------------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------------

View File

@ -0,0 +1,139 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
# Internal
../../config/[common, curves],
../../primitives,
../../arithmetic,
../../arithmetic/assembly/[
limbs_asm_mul_x86_adx_bmi2,
limbs_asm_montmul_x86_adx_bmi2,
limbs_asm_montred_x86_adx_bmi2
]
# ############################################################
# #
# Assembly implementation of 𝔽p2 #
# #
# ############################################################
static: doAssert UseASM_X86_64
# MULX/ADCX/ADOX
{.localPassC:"-madx -mbmi2".}
# Necessary for the compiler to find enough registers (enabled at -O1)
{.localPassC:"-fomit-frame-pointer".}
# No exceptions allowed
{.push raises: [].}
template c0*(a: array): auto =
a[0]
template c1*(a: array): auto =
a[1]
func has1extraBit(F: type Fp): bool =
## We construct extensions only on Fp (and not Fr)
getSpareBits(F) >= 1
func has2extraBits(F: type Fp): bool =
## We construct extensions only on Fp (and not Fr)
getSpareBits(F) >= 2
# 𝔽p2 squaring
# ------------------------------------------------------------
func sqrx2x_complex_asm_adx_bmi2*(
r: var array[2, FpDbl],
a: array[2, Fp]
) =
## Complex squaring on 𝔽p2
# This specialized proc inlines all calls and avoids many ADX support checks.
# and push/pop for paramater passing.
var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0)
when Fp.has1extraBit():
t0.sumUnr(a.c1, a.c1)
t1.sumUnr(a.c0, a.c1)
else:
t0.double(a.c1)
t1.sum(a.c0, a.c1)
r.c1.mul_asm_adx_bmi2_impl(t0, a.c0)
t0.diff(a.c0, a.c1)
r.c0.mul_asm_adx_bmi2_impl(t0, t1)
func sqrx_complex_asm_adx_bmi2*(
r: var array[2, Fp],
a: array[2, Fp]
) =
## Complex squaring on 𝔽p2
# This specialized proc inlines all calls and avoids many ADX support checks.
# and push/pop for paramater passing.
# Staying in 𝔽p and not using double-precision is faster for squaring
static: doAssert Fp.has1extraBit()
var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0)
v0.diff(a.c0, a.c1)
v1.sum(a.c0, a.c1)
r.c1.mres.limbs.montMul_CIOS_nocarry_asm_adx_bmi2(a.c0.mres.limbs, a.c1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord())
# aliasing: a unneeded now
r.c1.double()
r.c0.mres.limbs.montMul_CIOS_nocarry_asm_adx_bmi2(v0.mres.limbs, v1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord())
# 𝔽p2 multiplication
# ------------------------------------------------------------
func mulx2x_complex_asm_adx_bmi2*(
r: var array[2, FpDbl],
a, b: array[2, Fp]
) =
## Complex multiplication on 𝔽p2
var D {.noInit.}: typeof(r.c0)
var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0)
r.c0.limbs2x.mul_asm_adx_bmi2_impl(a.c0.mres.limbs, b.c0.mres.limbs)
D.limbs2x.mul_asm_adx_bmi2_impl(a.c1.mres.limbs, b.c1.mres.limbs)
when Fp.has1extraBit():
t0.sumUnr(a.c0, a.c1)
t1.sumUnr(b.c0, b.c1)
else:
t0.sum(a.c0, a.c1)
t1.sum(b.c0, b.c1)
r.c1.limbs2x.mul_asm_adx_bmi2_impl(t0.mres.limbs, t1.mres.limbs)
when Fp.has1extraBit():
r.c1.diff2xUnr(r.c1, r.c0)
r.c1.diff2xUnr(r.c1, D)
else:
r.c1.diff2xMod(r.c1, r.c0)
r.c1.diff2xMod(r.c1, D)
r.c0.diff2xMod(r.c0, D)
func mulx_complex_asm_adx_bmi2*(
r: var array[2, Fp],
a, b: array[2, Fp]
) =
## Complex multiplication on 𝔽p2
var d {.noInit.}: array[2,doublePrec(Fp)]
d.mulx2x_complex_asm_adx_bmi2(a, b)
r.c0.mres.limbs.montRed_asm_adx_bmi2_impl(
d.c0.limbs2x,
Fp.fieldMod().limbs,
Fp.getNegInvModWord(),
Fp.has1extraBit()
)
r.c1.mres.limbs.montRed_asm_adx_bmi2_impl(
d.c1.limbs2x,
Fp.fieldMod().limbs,
Fp.getNegInvModWord(),
Fp.has1extraBit()
)

View File

@ -12,6 +12,10 @@ import
../arithmetic, ../arithmetic,
../io/io_fields ../io/io_fields
when UseASM_X86_64:
import
./assembly/fp2_asm_x86_adx_bmi2
# Note: to avoid burdening the Nim compiler, we rely on generic extension # Note: to avoid burdening the Nim compiler, we rely on generic extension
# to complain if the base field procedures don't exist # to complain if the base field procedures don't exist
@ -807,8 +811,8 @@ func prod2x_complex(r: var QuadraticExt2x, a, b: Fp2) =
t1.sum(b.c0, b.c1) t1.sum(b.c0, b.c1)
r.c1.prod2x(t0, t1) # r1 = (b0 + b1)(a0 + a1) r.c1.prod2x(t0, t1) # r1 = (b0 + b1)(a0 + a1)
when Fp2.has1extraBit(): when Fp2.has1extraBit():
r.c1.diff2xUnr(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0 r.c1.diff2xUnr(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0
r.c1.diff2xUnr(r.c1, D) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1 r.c1.diff2xUnr(r.c1, D) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1
else: else:
r.c1.diff2xMod(r.c1, r.c0) r.c1.diff2xMod(r.c1, r.c0)
r.c1.diff2xMod(r.c1, D) r.c1.diff2xMod(r.c1, D)
@ -1227,7 +1231,13 @@ func square2x*(r: var QuadraticExt2x, a: QuadraticExt) =
func square*(r: var QuadraticExt, a: QuadraticExt) = func square*(r: var QuadraticExt, a: QuadraticExt) =
when r.fromComplexExtension(): when r.fromComplexExtension():
when true: when true:
r.square_complex(a) when UseASM_X86_64 and a.c0.mres.limbs.len <= 6:
if ({.noSideEffect.}: hasAdx()):
r.coords.sqrx_complex_asm_adx_bmi2(a.coords)
else:
r.square_complex(a)
else:
r.square_complex(a)
else: # slower else: # slower
var d {.noInit.}: doublePrec(typeof(r)) var d {.noInit.}: doublePrec(typeof(r))
d.square2x_complex(a) d.square2x_complex(a)
@ -1259,10 +1269,19 @@ func prod*(r: var QuadraticExt, a, b: QuadraticExt) =
when false: when false:
r.prod_complex(a, b) r.prod_complex(a, b)
else: # faster else: # faster
var d {.noInit.}: doublePrec(typeof(r)) when UseASM_X86_64 and a.c0.mres.limbs.len <= 6:
d.prod2x_complex(a, b) if ({.noSideEffect.}: hasAdx()):
r.c0.redc2x(d.c0) r.coords.mulx_complex_asm_adx_bmi2(a.coords, b.coords)
r.c1.redc2x(d.c1) else:
var d {.noInit.}: doublePrec(typeof(r))
d.prod2x_complex(a, b)
r.c0.redc2x(d.c0)
r.c1.redc2x(d.c1)
else:
var d {.noInit.}: doublePrec(typeof(r))
d.prod2x_complex(a, b)
r.c0.redc2x(d.c0)
r.c1.redc2x(d.c1)
else: else:
when r.typeof.F.C == BW6_761 or typeof(r.c0) is Fp: when r.typeof.F.C == BW6_761 or typeof(r.c0) is Fp:
# BW6-761 requires too many registers for Dbl width path # BW6-761 requires too many registers for Dbl width path
@ -1287,7 +1306,13 @@ func prod2x_disjoint*[Fdbl, F](
func prod2x*(r: var QuadraticExt2x, a, b: QuadraticExt) = func prod2x*(r: var QuadraticExt2x, a, b: QuadraticExt) =
## Double-precision multiplication r <- a*b ## Double-precision multiplication r <- a*b
when a.fromComplexExtension(): when a.fromComplexExtension():
r.prod2x_complex(a, b) when UseASM_X86_64 and a.c0.mres.limbs.len <= 6:
if ({.noSideEffect.}: hasAdx()):
r.coords.mulx2x_complex_asm_adx_bmi2(a.coords, b.coords)
else:
r.prod2x_complex(a, b)
else:
r.prod2x_complex(a, b)
else: else:
r.prod2x_disjoint(a.c0, a.c1, b.c0, b.c1) r.prod2x_disjoint(a.c0, a.c1, b.c0, b.c1)