diff --git a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim
index f2ac174..eb2cb18 100644
--- a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim
+++ b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim
@@ -20,8 +20,6 @@ import
 #
 # ############################################################
 
-# TODO, MCL has an implementation about 14% faster
-
 static: doAssert UseASM_X86_64 # MULX/ADCX/ADOX
 
@@ -140,6 +138,17 @@ macro montyRedc2x_adx_gen*[N: static int](
   # Code generation
   result.add ctx.generate()
 
+func montRed_asm_adx_bmi2_impl*[N: static int](
+       r: var array[N, SecretWord],
+       a: array[N*2, SecretWord],
+       M: array[N, SecretWord],
+       m0ninv: BaseType,
+       hasSpareBit: static bool
+     ) {.inline.} =
+  ## Constant-time Montgomery reduction
+  ## Inline version
+  montyRedc2x_adx_gen(r, a, M, m0ninv, hasSpareBit)
+
 func montRed_asm_adx_bmi2*[N: static int](
        r: var array[N, SecretWord],
        a: array[N*2, SecretWord],
@@ -148,4 +157,4 @@ func montRed_asm_adx_bmi2*[N: static int](
        hasSpareBit: static bool
      ) =
   ## Constant-time Montgomery reduction
-  montyRedc2x_adx_gen(r, a, M, m0ninv, hasSpareBit)
+  montRed_asm_adx_bmi2_impl(r, a, M, m0ninv, hasSpareBit)
diff --git a/constantine/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim
index 3a7fde1..844a38c 100644
--- a/constantine/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim
+++ b/constantine/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim
@@ -168,11 +168,18 @@ macro mulx_gen[rLen, aLen, bLen: static int](rx: var Limbs[rLen], ax: Limbs[aLen
   # Codegen
   result.add ctx.generate
 
-func mul_asm_adx_bmi2*[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) =
+func mul_asm_adx_bmi2_impl*[rLen, aLen, bLen: static int](
+       r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) {.inline.} =
   ## Multi-precision Multiplication
   ## Assumes r doesn't alias a or b
+  ## Inline version
   mulx_gen(r, a, b)
 
+func mul_asm_adx_bmi2*[rLen, aLen, bLen: static int](
+       r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) =
+  ## Multi-precision Multiplication
+  ## Assumes r doesn't alias a or b
+  mul_asm_adx_bmi2_impl(r, a, b)
+
 # Squaring
 # -----------------------------------------------------------------------------------------------
diff --git a/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim b/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim
new file mode 100644
index 0000000..3a1c7d7
--- /dev/null
+++ b/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim
@@ -0,0 +1,139 @@
+# Constantine
+# Copyright (c) 2018-2019 Status Research & Development GmbH
+# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
+# Licensed and distributed under either of
+#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
+#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
+# at your option. This file may not be copied, modified, or distributed except according to those terms.
+
+import
+  # Internal
+  ../../config/[common, curves],
+  ../../primitives,
+  ../../arithmetic,
+  ../../arithmetic/assembly/[
+    limbs_asm_mul_x86_adx_bmi2,
+    limbs_asm_montmul_x86_adx_bmi2,
+    limbs_asm_montred_x86_adx_bmi2
+  ]
+
+
+# ############################################################
+#                                                            #
+#        Assembly implementation of 𝔽p2                      #
+#                                                            #
+# ############################################################
+
+static: doAssert UseASM_X86_64
+
+# MULX/ADCX/ADOX
+{.localPassC:"-madx -mbmi2".}
+# Necessary for the compiler to find enough registers (enabled at -O1)
+{.localPassC:"-fomit-frame-pointer".}
+
+# No exceptions allowed
+{.push raises: [].}
+
+template c0*(a: array): auto =
+  a[0]
+template c1*(a: array): auto =
+  a[1]
+
+func has1extraBit(F: type Fp): bool =
+  ## We construct extensions only on Fp (and not Fr)
+  getSpareBits(F) >= 1
+
+func has2extraBits(F: type Fp): bool =
+  ## We construct extensions only on Fp (and not Fr)
+  getSpareBits(F) >= 2
+
+# 𝔽p2 squaring
+# ------------------------------------------------------------
+
+func sqrx2x_complex_asm_adx_bmi2*(
+       r: var array[2, FpDbl],
+       a: array[2, Fp]
+     ) =
+  ## Complex squaring on 𝔽p2
+  # This specialized proc inlines all calls and avoids many ADX support checks
+  # and push/pop for parameter passing.
+
+  var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0)
+
+  when Fp.has1extraBit():
+    t0.sumUnr(a.c1, a.c1)
+    t1.sumUnr(a.c0, a.c1)
+  else:
+    t0.double(a.c1)
+    t1.sum(a.c0, a.c1)
+
+  r.c1.limbs2x.mul_asm_adx_bmi2_impl(t0.mres.limbs, a.c0.mres.limbs)
+  t0.diff(a.c0, a.c1)
+  r.c0.limbs2x.mul_asm_adx_bmi2_impl(t0.mres.limbs, t1.mres.limbs)
+
+func sqrx_complex_asm_adx_bmi2*(
+       r: var array[2, Fp],
+       a: array[2, Fp]
+     ) =
+  ## Complex squaring on 𝔽p2
+  # This specialized proc inlines all calls and avoids many ADX support checks
+  # and push/pop for parameter passing.
+
+  # Staying in 𝔽p and not using double-precision is faster for squaring
+
+  static: doAssert Fp.has1extraBit()
+
+  var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0)
+  v0.diff(a.c0, a.c1)
+  v1.sum(a.c0, a.c1)
+  r.c1.mres.limbs.montMul_CIOS_nocarry_asm_adx_bmi2(a.c0.mres.limbs, a.c1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord())
+  # aliasing: `a` is no longer needed
+  r.c1.double()
+  r.c0.mres.limbs.montMul_CIOS_nocarry_asm_adx_bmi2(v0.mres.limbs, v1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord())
+
+# 𝔽p2 multiplication
+# ------------------------------------------------------------
+
+func mulx2x_complex_asm_adx_bmi2*(
+       r: var array[2, FpDbl],
+       a, b: array[2, Fp]
+     ) =
+  ## Complex multiplication on 𝔽p2
+  var D {.noInit.}: typeof(r.c0)
+  var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0)
+
+  r.c0.limbs2x.mul_asm_adx_bmi2_impl(a.c0.mres.limbs, b.c0.mres.limbs)
+  D.limbs2x.mul_asm_adx_bmi2_impl(a.c1.mres.limbs, b.c1.mres.limbs)
+  when Fp.has1extraBit():
+    t0.sumUnr(a.c0, a.c1)
+    t1.sumUnr(b.c0, b.c1)
+  else:
+    t0.sum(a.c0, a.c1)
+    t1.sum(b.c0, b.c1)
+  r.c1.limbs2x.mul_asm_adx_bmi2_impl(t0.mres.limbs, t1.mres.limbs)
+  when Fp.has1extraBit():
+    r.c1.diff2xUnr(r.c1, r.c0)
+    r.c1.diff2xUnr(r.c1, D)
+  else:
+    r.c1.diff2xMod(r.c1, r.c0)
+    r.c1.diff2xMod(r.c1, D)
+  r.c0.diff2xMod(r.c0, D)
+
+func mulx_complex_asm_adx_bmi2*(
+       r: var array[2, Fp],
+       a, b: array[2, Fp]
+     ) =
+  ## Complex multiplication on 𝔽p2
+  var d {.noInit.}: array[2, doublePrec(Fp)]
+  d.mulx2x_complex_asm_adx_bmi2(a, b)
+  r.c0.mres.limbs.montRed_asm_adx_bmi2_impl(
+    d.c0.limbs2x,
+    Fp.fieldMod().limbs,
+    Fp.getNegInvModWord(),
+    Fp.has1extraBit()
+  )
+  r.c1.mres.limbs.montRed_asm_adx_bmi2_impl(
+    d.c1.limbs2x,
+    Fp.fieldMod().limbs,
+    Fp.getNegInvModWord(),
+    Fp.has1extraBit()
+  )
diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim
index c20ec65..4c64198 100644
--- a/constantine/tower_field_extensions/extension_fields.nim
+++ b/constantine/tower_field_extensions/extension_fields.nim
@@ -12,6 +12,10 @@ import
   ../arithmetic,
   ../io/io_fields
 
+when UseASM_X86_64:
+  import
+    ./assembly/fp2_asm_x86_adx_bmi2
+
 # Note: to avoid burdening the Nim compiler, we rely on generic extension
 # to complain if the base field procedures don't exist
 
@@ -807,8 +811,8 @@ func prod2x_complex(r: var QuadraticExt2x, a, b: Fp2) =
     t1.sum(b.c0, b.c1)
     r.c1.prod2x(t0, t1)         # r1 = (b0 + b1)(a0 + a1)
   when Fp2.has1extraBit():
-    r.c1.diff2xUnr(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0
-    r.c1.diff2xUnr(r.c1, D)    # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1
+    r.c1.diff2xUnr(r.c1, r.c0)  # r1 = (b0 + b1)(a0 + a1) - a0 b0
+    r.c1.diff2xUnr(r.c1, D)     # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1
   else:
     r.c1.diff2xMod(r.c1, r.c0)
     r.c1.diff2xMod(r.c1, D)
@@ -1227,7 +1231,13 @@ func square2x*(r: var QuadraticExt2x, a: QuadraticExt) =
 func square*(r: var QuadraticExt, a: QuadraticExt) =
   when r.fromComplexExtension():
     when true:
-      r.square_complex(a)
+      when UseASM_X86_64 and a.c0.mres.limbs.len <= 6:
+        if ({.noSideEffect.}: hasAdx()):
+          r.coords.sqrx_complex_asm_adx_bmi2(a.coords)
+        else:
+          r.square_complex(a)
+      else:
+        r.square_complex(a)
     else: # slower
       var d {.noInit.}: doublePrec(typeof(r))
       d.square2x_complex(a)
@@ -1259,10 +1269,19 @@ func prod*(r: var QuadraticExt, a, b: QuadraticExt) =
     when false:
       r.prod_complex(a, b)
     else: # faster
-      var d {.noInit.}: doublePrec(typeof(r))
-      d.prod2x_complex(a, b)
-      r.c0.redc2x(d.c0)
-      r.c1.redc2x(d.c1)
+      when UseASM_X86_64 and a.c0.mres.limbs.len <= 6:
+        if ({.noSideEffect.}: hasAdx()):
+          r.coords.mulx_complex_asm_adx_bmi2(a.coords, b.coords)
+        else:
+          var d {.noInit.}: doublePrec(typeof(r))
+          d.prod2x_complex(a, b)
+          r.c0.redc2x(d.c0)
+          r.c1.redc2x(d.c1)
+      else:
+        var d {.noInit.}: doublePrec(typeof(r))
+        d.prod2x_complex(a, b)
+        r.c0.redc2x(d.c0)
+        r.c1.redc2x(d.c1)
   else:
     when r.typeof.F.C == BW6_761 or typeof(r.c0) is Fp:
       # BW6-761 requires too many registers for Dbl width path
@@ -1287,7 +1306,13 @@ func prod2x_disjoint*[Fdbl, F](
 
 func prod2x*(r: var QuadraticExt2x, a, b: QuadraticExt) =
   ## Double-precision multiplication r <- a*b
   when a.fromComplexExtension():
-    r.prod2x_complex(a, b)
+    when UseASM_X86_64 and a.c0.mres.limbs.len <= 6:
+      if ({.noSideEffect.}: hasAdx()):
+        r.coords.mulx2x_complex_asm_adx_bmi2(a.coords, b.coords)
+      else:
+        r.prod2x_complex(a, b)
+    else:
+      r.prod2x_complex(a, b)
   else:
     r.prod2x_disjoint(a.c0, a.c1, b.c0, b.c1)
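
Reviewer note (not part of the patch): the new assembly routines implement the standard complex-arithmetic identities, 3 multiplications for a product and 2 for a square, then reduce once per coordinate. Below is a minimal standalone Nim sketch of just those identities; Complex, mulComplex and sqrComplex are illustrative names, plain int64 stands in for the Montgomery-form Fp/FpDbl limbs, and all modular reductions are elided.

type Complex = tuple[c0, c1: int64]

func mulComplex(a, b: Complex): Complex =
  ## r0 = a0 b0 - a1 b1
  ## r1 = (a0 + a1)(b0 + b1) - a0 b0 - a1 b1   (3 multiplications)
  let
    t0 = a.c0 * b.c0
    t1 = a.c1 * b.c1
    mixed = (a.c0 + a.c1) * (b.c0 + b.c1)
  (c0: t0 - t1, c1: mixed - t0 - t1)

func sqrComplex(a: Complex): Complex =
  ## r0 = (a0 - a1)(a0 + a1)
  ## r1 = 2 a0 a1                              (2 multiplications)
  (c0: (a.c0 - a.c1) * (a.c0 + a.c1), c1: 2 * a.c0 * a.c1)

when isMainModule:
  let x = (c0: 3'i64, c1: 4'i64)                      # 3 + 4i
  doAssert mulComplex(x, x) == sqrComplex(x)
  doAssert sqrComplex(x) == (c0: -7'i64, c1: 24'i64)  # (3 + 4i)^2 = -7 + 24i

The hasAdx() runtime check in square/prod/prod2x keeps these ASM paths behind a CPU-feature gate, so CPUs without ADX/BMI2 fall back to the generic square_complex/prod2x_complex code paths.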