Double-Precision towering (#155)

* consistent naming for dbl-width

* Isolate double-width Fp2 mul

* Implement double-width complex multiplication

* Lay out Fp4 double-width mul

* Off by p in square Fp4 as well :/

* less copies and stack space in addition chains

* Address https://github.com/mratsim/constantine/issues/154 partly

* Fix #154, faster Fp4 square: less non-residue, no Mul, only square (bit more ops total)

* Fix typo

* better assembly scheduling for add/sub

* Double-width -> Double-precision

* Unred -> Unr

* double-precision modular addition

* Replace canUseNoCarryMontyMul and canUseNoCarryMontySquare by getSpareBits

* Complete the double-precision implementation

* Use double-precision path for Fp4 squaring and mul

* remove mixin annotations

* Lazy reduction in Fp4 prod

* Fix assembly for sum2xMod

* Assembly for double-precision negation

* reduce white spaces in pairing benchmarks

* ADX implies BMI2
This commit is contained in:
Mamy Ratsimbazafy 2021-02-09 22:57:45 +01:00 committed by GitHub
parent 491b4d4d21
commit 5806cc4638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 1570 additions and 697 deletions

View File

@ -88,7 +88,6 @@ proc notes*() =
echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc\" or \"nimble bench_ec_g1_clang\"." echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc\" or \"nimble bench_ec_g1_clang\"."
echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc_noasm\" or \"nimble bench_ec_g1_clang_noasm\"." echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc_noasm\" or \"nimble bench_ec_g1_clang_noasm\"."
echo " - The simplest operations might be optimized away by the compiler." echo " - The simplest operations might be optimized away by the compiler."
echo " - Fast Squaring and Fast Multiplication are possible if there are spare bits in the prime representation (i.e. the prime uses 254 bits out of 256 bits)"
template measure*(iters: int, template measure*(iters: int,
startTime, stopTime: untyped, startTime, stopTime: untyped,

View File

@ -89,11 +89,13 @@ proc notes*() =
echo "Notes:" echo "Notes:"
echo " - Compilers:" echo " - Compilers:"
echo " Compilers are severely limited on multiprecision arithmetic." echo " Compilers are severely limited on multiprecision arithmetic."
echo " Inline Assembly is used by default (nimble bench_fp)." echo " Constantine compile-time assembler is used by default (nimble bench_fp)."
echo " Bench without assembly can use \"nimble bench_fp_gcc\" or \"nimble bench_fp_clang\"."
echo " GCC is significantly slower than Clang on multiprecision arithmetic due to catastrophic handling of carries." echo " GCC is significantly slower than Clang on multiprecision arithmetic due to catastrophic handling of carries."
echo " GCC also seems to have issues with large temporaries and register spilling."
echo " This is somewhat alleviated by Constantine compile-time assembler."
echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc\" or \"nimble bench_ec_g1_clang\"."
echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc_noasm\" or \"nimble bench_ec_g1_clang_noasm\"."
echo " - The simplest operations might be optimized away by the compiler." echo " - The simplest operations might be optimized away by the compiler."
echo " - Fast Squaring and Fast Multiplication are possible if there are spare bits in the prime representation (i.e. the prime uses 254 bits out of 256 bits)"
template bench(op: string, desc: string, iters: int, body: untyped): untyped = template bench(op: string, desc: string, iters: int, body: untyped): untyped =
let start = getMonotime() let start = getMonotime()
@ -121,12 +123,12 @@ func random_unsafe(rng: var RngState, a: var FpDbl, Base: typedesc) =
for i in 0 ..< aHi.mres.limbs.len: for i in 0 ..< aHi.mres.limbs.len:
a.limbs2x[aLo.mres.limbs.len+i] = aHi.mres.limbs[i] a.limbs2x[aLo.mres.limbs.len+i] = aHi.mres.limbs[i]
proc sumNoReduce(T: typedesc, iters: int) = proc sumUnr(T: typedesc, iters: int) =
var r: T var r: T
let a = rng.random_unsafe(T) let a = rng.random_unsafe(T)
let b = rng.random_unsafe(T) let b = rng.random_unsafe(T)
bench("Addition no reduce", $T, iters): bench("Addition unreduced", $T, iters):
r.sumNoReduce(a, b) r.sumUnr(a, b)
proc sum(T: typedesc, iters: int) = proc sum(T: typedesc, iters: int) =
var r: T var r: T
@ -135,12 +137,12 @@ proc sum(T: typedesc, iters: int) =
bench("Addition", $T, iters): bench("Addition", $T, iters):
r.sum(a, b) r.sum(a, b)
proc diffNoReduce(T: typedesc, iters: int) = proc diffUnr(T: typedesc, iters: int) =
var r: T var r: T
let a = rng.random_unsafe(T) let a = rng.random_unsafe(T)
let b = rng.random_unsafe(T) let b = rng.random_unsafe(T)
bench("Substraction no reduce", $T, iters): bench("Substraction unreduced", $T, iters):
r.diffNoReduce(a, b) r.diffUnr(a, b)
proc diff(T: typedesc, iters: int) = proc diff(T: typedesc, iters: int) =
var r: T var r: T
@ -149,52 +151,86 @@ proc diff(T: typedesc, iters: int) =
bench("Substraction", $T, iters): bench("Substraction", $T, iters):
r.diff(a, b) r.diff(a, b)
proc diff2xNoReduce(T: typedesc, iters: int) = proc neg(T: typedesc, iters: int) =
var r, a, b: doubleWidth(T) var r: T
let a = rng.random_unsafe(T)
bench("Negation", $T, iters):
r.neg(a)
proc sum2xUnreduce(T: typedesc, iters: int) =
var r, a, b: doublePrec(T)
rng.random_unsafe(r, T) rng.random_unsafe(r, T)
rng.random_unsafe(a, T) rng.random_unsafe(a, T)
rng.random_unsafe(b, T) rng.random_unsafe(b, T)
bench("Substraction 2x no reduce", $doubleWidth(T), iters): bench("Addition 2x unreduced", $doublePrec(T), iters):
r.diffNoReduce(a, b) r.sum2xUnr(a, b)
proc sum2x(T: typedesc, iters: int) =
var r, a, b: doublePrec(T)
rng.random_unsafe(r, T)
rng.random_unsafe(a, T)
rng.random_unsafe(b, T)
bench("Addition 2x reduced", $doublePrec(T), iters):
r.sum2xMod(a, b)
proc diff2xUnreduce(T: typedesc, iters: int) =
var r, a, b: doublePrec(T)
rng.random_unsafe(r, T)
rng.random_unsafe(a, T)
rng.random_unsafe(b, T)
bench("Substraction 2x unreduced", $doublePrec(T), iters):
r.diff2xUnr(a, b)
proc diff2x(T: typedesc, iters: int) = proc diff2x(T: typedesc, iters: int) =
var r, a, b: doubleWidth(T) var r, a, b: doublePrec(T)
rng.random_unsafe(r, T) rng.random_unsafe(r, T)
rng.random_unsafe(a, T) rng.random_unsafe(a, T)
rng.random_unsafe(b, T) rng.random_unsafe(b, T)
bench("Substraction 2x", $doubleWidth(T), iters): bench("Substraction 2x reduced", $doublePrec(T), iters):
r.diff(a, b) r.diff2xMod(a, b)
proc mul2xBench*(rLen, aLen, bLen: static int, iters: int) = proc neg2x(T: typedesc, iters: int) =
var r, a: doublePrec(T)
rng.random_unsafe(a, T)
bench("Negation 2x reduced", $doublePrec(T), iters):
r.neg2xMod(a)
proc prod2xBench*(rLen, aLen, bLen: static int, iters: int) =
var r: BigInt[rLen] var r: BigInt[rLen]
let a = rng.random_unsafe(BigInt[aLen]) let a = rng.random_unsafe(BigInt[aLen])
let b = rng.random_unsafe(BigInt[bLen]) let b = rng.random_unsafe(BigInt[bLen])
bench("Multiplication", $rLen & " <- " & $aLen & " x " & $bLen, iters): bench("Multiplication 2x", $rLen & " <- " & $aLen & " x " & $bLen, iters):
r.prod(a, b) r.prod(a, b)
proc square2xBench*(rLen, aLen: static int, iters: int) = proc square2xBench*(rLen, aLen: static int, iters: int) =
var r: BigInt[rLen] var r: BigInt[rLen]
let a = rng.random_unsafe(BigInt[aLen]) let a = rng.random_unsafe(BigInt[aLen])
bench("Squaring", $rLen & " <- " & $aLen & "²", iters): bench("Squaring 2x", $rLen & " <- " & $aLen & "²", iters):
r.square(a) r.square(a)
proc reduce2x*(T: typedesc, iters: int) = proc reduce2x*(T: typedesc, iters: int) =
var r: T var r: T
var t: doubleWidth(T) var t: doublePrec(T)
rng.random_unsafe(t, T) rng.random_unsafe(t, T)
bench("Reduce 2x-width", $T & " <- " & $doubleWidth(T), iters): bench("Redc 2x", $T & " <- " & $doublePrec(T), iters):
r.reduce(t) r.redc2x(t)
proc main() = proc main() =
separator() separator()
sumNoReduce(Fp[BLS12_381], iters = 10_000_000)
diffNoReduce(Fp[BLS12_381], iters = 10_000_000)
sum(Fp[BLS12_381], iters = 10_000_000) sum(Fp[BLS12_381], iters = 10_000_000)
sumUnr(Fp[BLS12_381], iters = 10_000_000)
diff(Fp[BLS12_381], iters = 10_000_000) diff(Fp[BLS12_381], iters = 10_000_000)
diffUnr(Fp[BLS12_381], iters = 10_000_000)
neg(Fp[BLS12_381], iters = 10_000_000)
separator()
sum2x(Fp[BLS12_381], iters = 10_000_000)
sum2xUnreduce(Fp[BLS12_381], iters = 10_000_000)
diff2x(Fp[BLS12_381], iters = 10_000_000) diff2x(Fp[BLS12_381], iters = 10_000_000)
diff2xNoReduce(Fp[BLS12_381], iters = 10_000_000) diff2xUnreduce(Fp[BLS12_381], iters = 10_000_000)
mul2xBench(768, 384, 384, iters = 10_000_000) neg2x(Fp[BLS12_381], iters = 10_000_000)
separator()
prod2xBench(768, 384, 384, iters = 10_000_000)
square2xBench(768, 384, iters = 10_000_000) square2xBench(768, 384, iters = 10_000_000)
reduce2x(Fp[BLS12_381], iters = 10_000_000) reduce2x(Fp[BLS12_381], iters = 10_000_000)
separator() separator()

View File

@ -32,15 +32,15 @@ import
./bench_blueprint ./bench_blueprint
export notes export notes
proc separator*() = separator(177) proc separator*() = separator(132)
proc report(op, curve: string, startTime, stopTime: MonoTime, startClk, stopClk: int64, iters: int) = proc report(op, curve: string, startTime, stopTime: MonoTime, startClk, stopClk: int64, iters: int) =
let ns = inNanoseconds((stopTime-startTime) div iters) let ns = inNanoseconds((stopTime-startTime) div iters)
let throughput = 1e9 / float64(ns) let throughput = 1e9 / float64(ns)
when SupportsGetTicks: when SupportsGetTicks:
echo &"{op:<60} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)" echo &"{op:<40} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
else: else:
echo &"{op:<60} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op" echo &"{op:<40} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op"
template bench(op: string, C: static Curve, iters: int, body: untyped): untyped = template bench(op: string, C: static Curve, iters: int, body: untyped): untyped =
measure(iters, startTime, stopTime, startClk, stopClk, body) measure(iters, startTime, stopTime, startClk, stopClk, body)

View File

@ -43,13 +43,14 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
("tests/t_finite_fields_powinv.nim", false), ("tests/t_finite_fields_powinv.nim", false),
("tests/t_finite_fields_vs_gmp.nim", true), ("tests/t_finite_fields_vs_gmp.nim", true),
("tests/t_fp_cubic_root.nim", false), ("tests/t_fp_cubic_root.nim", false),
# Double-width finite fields # Double-precision finite fields
# ---------------------------------------------------------- # ----------------------------------------------------------
("tests/t_finite_fields_double_width.nim", false), ("tests/t_finite_fields_double_precision.nim", false),
# Towers of extension fields # Towers of extension fields
# ---------------------------------------------------------- # ----------------------------------------------------------
("tests/t_fp2.nim", false), ("tests/t_fp2.nim", false),
("tests/t_fp2_sqrt.nim", false), ("tests/t_fp2_sqrt.nim", false),
("tests/t_fp4.nim", false),
("tests/t_fp6_bn254_snarks.nim", false), ("tests/t_fp6_bn254_snarks.nim", false),
("tests/t_fp6_bls12_377.nim", false), ("tests/t_fp6_bls12_377.nim", false),
("tests/t_fp6_bls12_381.nim", false), ("tests/t_fp6_bls12_381.nim", false),
@ -259,7 +260,7 @@ proc buildAllBenches() =
echo "\n\n------------------------------------------------------\n" echo "\n\n------------------------------------------------------\n"
echo "Building benchmarks to ensure they stay relevant ..." echo "Building benchmarks to ensure they stay relevant ..."
buildBench("bench_fp") buildBench("bench_fp")
buildBench("bench_fp_double_width") buildBench("bench_fp_double_precision")
buildBench("bench_fp2") buildBench("bench_fp2")
buildBench("bench_fp6") buildBench("bench_fp6")
buildBench("bench_fp12") buildBench("bench_fp12")
@ -400,19 +401,19 @@ task bench_fp_clang_noasm, "Run benchmark 𝔽p with clang - no Assembly":
runBench("bench_fp", "clang", useAsm = false) runBench("bench_fp", "clang", useAsm = false)
task bench_fpdbl, "Run benchmark 𝔽pDbl with your default compiler": task bench_fpdbl, "Run benchmark 𝔽pDbl with your default compiler":
runBench("bench_fp_double_width") runBench("bench_fp_double_precision")
task bench_fpdbl_gcc, "Run benchmark 𝔽p with gcc": task bench_fpdbl_gcc, "Run benchmark 𝔽p with gcc":
runBench("bench_fp_double_width", "gcc") runBench("bench_fp_double_precision", "gcc")
task bench_fpdbl_clang, "Run benchmark 𝔽p with clang": task bench_fpdbl_clang, "Run benchmark 𝔽p with clang":
runBench("bench_fp_double_width", "clang") runBench("bench_fp_double_precision", "clang")
task bench_fpdbl_gcc_noasm, "Run benchmark 𝔽p with gcc - no Assembly": task bench_fpdbl_gcc_noasm, "Run benchmark 𝔽p with gcc - no Assembly":
runBench("bench_fp_double_width", "gcc", useAsm = false) runBench("bench_fp_double_precision", "gcc", useAsm = false)
task bench_fpdbl_clang_noasm, "Run benchmark 𝔽p with clang - no Assembly": task bench_fpdbl_clang_noasm, "Run benchmark 𝔽p with clang - no Assembly":
runBench("bench_fp_double_width", "clang", useAsm = false) runBench("bench_fp_double_precision", "clang", useAsm = false)
task bench_fp2, "Run benchmark with 𝔽p2 your default compiler": task bench_fp2, "Run benchmark with 𝔽p2 your default compiler":
runBench("bench_fp2") runBench("bench_fp2")

View File

@ -12,7 +12,7 @@ import
finite_fields, finite_fields,
finite_fields_inversion, finite_fields_inversion,
finite_fields_square_root, finite_fields_square_root,
finite_fields_double_width finite_fields_double_precision
] ]
export export
@ -21,4 +21,4 @@ export
finite_fields, finite_fields,
finite_fields_inversion, finite_fields_inversion,
finite_fields_square_root, finite_fields_square_root,
finite_fields_double_width finite_fields_double_precision

View File

@ -0,0 +1,244 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
# Standard library
std/macros,
# Internal
../../config/common,
../../primitives
# ############################################################
# #
# Assembly implementation of FpDbl #
# #
# ############################################################
# A FpDbl is a partially-reduced double-precision element of Fp
# The allowed range is [0, 2ⁿp)
# with n = w*WordBitSize
# and w the number of words necessary to represent p on the machine.
# Concretely a 381-bit p needs 6*64 bits limbs (hence 384 bits total)
# and so FpDbl would 768 bits.
static: doAssert UseASM_X86_64
{.localPassC:"-fomit-frame-pointer".} # Needed so that the compiler finds enough registers
# Double-precision field addition
# ------------------------------------------------------------
macro addmod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped =
## Generate an optimized out-of-place double-precision addition kernel
result = newStmtList()
var ctx = init(Assembler_x86, BaseType)
let
H = N div 2
r = init(OperandArray, nimSymbol = R, N, PointerInReg, InputOutput)
# We reuse the reg used for b for overflow detection
b = init(OperandArray, nimSymbol = B, N, PointerInReg, InputOutput)
# We could force m as immediate by specializing per moduli
M = init(OperandArray, nimSymbol = m, N, PointerInReg, Input)
# If N is too big, we need to spill registers. TODO.
u = init(OperandArray, nimSymbol = ident"U", H, ElemsInReg, InputOutput)
v = init(OperandArray, nimSymbol = ident"V", H, ElemsInReg, InputOutput)
let usym = u.nimSymbol
let vsym = v.nimSymbol
result.add quote do:
var `usym`{.noinit.}, `vsym` {.noInit.}: typeof(`A`)
staticFor i, 0, `H`:
`usym`[i] = `A`[i]
staticFor i, `H`, `N`:
`vsym`[i-`H`] = `A`[i]
# Addition
# u = a[0..<H] + b[0..<H], v = a[H..<N]
for i in 0 ..< H:
if i == 0:
ctx.add u[0], b[0]
else:
ctx.adc u[i], b[i]
ctx.mov r[i], u[i]
# v = a[H..<N] + b[H..<N], a[0..<H] = u, u = v
for i in H ..< N:
ctx.adc v[i-H], b[i]
ctx.mov u[i-H], v[i-H]
# Mask: overflowed contains 0xFFFF or 0x0000
# TODO: unnecessary if MSB never set, i.e. "Field.getSpareBits >= 1"
let overflowed = b.reuseRegister()
ctx.sbb overflowed, overflowed
# Now substract the modulus to test a < 2ⁿp
for i in 0 ..< H:
if i == 0:
ctx.sub v[0], M[0]
else:
ctx.sbb v[i], M[i]
# If it overflows here, it means that it was
# smaller than the modulus and we don't need v
ctx.sbb overflowed, 0
# Conditional Mov and
# and store result
for i in 0 ..< H:
ctx.cmovnc u[i], v[i]
ctx.mov r[i+H], u[i]
result.add ctx.generate
func addmod2x_asm*[N: static int](r: var Limbs[N], a, b: Limbs[N], M: Limbs[N div 2]) =
## Constant-time double-precision addition
## Output is conditionally reduced by 2ⁿp
## to stay in the [0, 2ⁿp) range
addmod2x_gen(r, a, b, M)
# Double-precision field substraction
# ------------------------------------------------------------
macro submod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped =
## Generate an optimized out-of-place double-precision substraction kernel
result = newStmtList()
var ctx = init(Assembler_x86, BaseType)
let
H = N div 2
r = init(OperandArray, nimSymbol = R, N, PointerInReg, InputOutput)
# We reuse the reg used for b for overflow detection
b = init(OperandArray, nimSymbol = B, N, PointerInReg, InputOutput)
# We could force m as immediate by specializing per moduli
M = init(OperandArray, nimSymbol = m, N, PointerInReg, Input)
# If N is too big, we need to spill registers. TODO.
u = init(OperandArray, nimSymbol = ident"U", H, ElemsInReg, InputOutput)
v = init(OperandArray, nimSymbol = ident"V", H, ElemsInReg, InputOutput)
let usym = u.nimSymbol
let vsym = v.nimSymbol
result.add quote do:
var `usym`{.noinit.}, `vsym` {.noInit.}: typeof(`A`)
staticFor i, 0, `H`:
`usym`[i] = `A`[i]
staticFor i, `H`, `N`:
`vsym`[i-`H`] = `A`[i]
# Substraction
# u = a[0..<H] - b[0..<H], v = a[H..<N]
for i in 0 ..< H:
if i == 0:
ctx.sub u[0], b[0]
else:
ctx.sbb u[i], b[i]
ctx.mov r[i], u[i]
# v = a[H..<N] - b[H..<N], a[0..<H] = u, u = M
for i in H ..< N:
ctx.sbb v[i-H], b[i]
ctx.mov u[i-H], M[i-H] # TODO, bottleneck 17% perf: prefetch or inline modulus?
# Mask: underflowed contains 0xFFFF or 0x0000
let underflowed = b.reuseRegister()
ctx.sbb underflowed, underflowed
# Now mask the adder, with 0 or the modulus limbs
for i in 0 ..< H:
ctx.`and` u[i], underflowed
# Add the masked modulus
for i in 0 ..< H:
if i == 0:
ctx.add u[0], v[0]
else:
ctx.adc u[i], v[i]
ctx.mov r[i+H], u[i]
result.add ctx.generate
func submod2x_asm*[N: static int](r: var Limbs[N], a, b: Limbs[N], M: Limbs[N div 2]) =
## Constant-time double-precision substraction
## Output is conditionally reduced by 2ⁿp
## to stay in the [0, 2ⁿp) range
submod2x_gen(r, a, b, M)
# Double-precision field negation
# ------------------------------------------------------------
macro negmod2x_gen[N: static int](R: var Limbs[N], A: Limbs[N], m: Limbs[N div 2]): untyped =
## Generate an optimized modular negation kernel
result = newStmtList()
var ctx = init(Assembler_x86, BaseType)
let
H = N div 2
a = init(OperandArray, nimSymbol = A, N, PointerInReg, Input)
r = init(OperandArray, nimSymbol = R, N, PointerInReg, InputOutput)
u = init(OperandArray, nimSymbol = ident"U", N, ElemsInReg, Output_EarlyClobber)
# We could force m as immediate by specializing per moduli
# We reuse the reg used for m for overflow detection
M = init(OperandArray, nimSymbol = m, N, PointerInReg, InputOutput)
isZero = Operand(
desc: OperandDesc(
asmId: "[isZero]",
nimSymbol: ident"isZero",
rm: Reg,
constraint: Output_EarlyClobber,
cEmit: "isZero"
)
)
# Substraction 2ⁿp - a
# The lower half of 2ⁿp is filled with zero
ctx.`xor` isZero, isZero
for i in 0 ..< H:
ctx.`xor` u[i], u[i]
ctx.`or` isZero, a[i]
for i in 0 ..< H:
# 0 - a[i]
if i == 0:
ctx.sub u[0], a[0]
else:
ctx.sbb u[i], a[i]
# store result, overwrite a[i] lower-half if aliasing.
ctx.mov r[i], u[i]
# Prepare second-half, u <- M
ctx.mov u[i], M[i]
for i in H ..< N:
# u = 2ⁿp higher half
ctx.sbb u[i-H], a[i]
# Deal with a == 0,
# we already accumulated 0 in the first half (which was destroyed if aliasing)
for i in H ..< N:
ctx.`or` isZero, a[i]
# Zero result if a == 0, only the upper half needs to be zero-ed here
for i in H ..< N:
ctx.cmovz u[i-H], isZero
ctx.mov r[i], u[i-H]
let isZerosym = isZero.desc.nimSymbol
let usym = u.nimSymbol
result.add quote do:
var `isZerosym`{.noInit.}: BaseType
var `usym`{.noinit.}: typeof(`A`)
result.add ctx.generate
func negmod2x_asm*[N: static int](r: var Limbs[N], a: Limbs[N], M: Limbs[N div 2]) =
## Constant-time double-precision negation
negmod2x_gen(r, a, M)

View File

@ -1,93 +0,0 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
# Standard library
std/macros,
# Internal
../../config/common,
../../primitives
# ############################################################
#
# Assembly implementation of FpDbl
#
# ############################################################
static: doAssert UseASM_X86_64
{.localPassC:"-fomit-frame-pointer".} # Needed so that the compiler finds enough registers
# TODO slower than intrinsics
# Substraction
# ------------------------------------------------------------
macro sub2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped =
## Generate an optimized out-of-place double-width substraction kernel
result = newStmtList()
var ctx = init(Assembler_x86, BaseType)
let
H = N div 2
r = init(OperandArray, nimSymbol = R, N, PointerInReg, InputOutput)
# We reuse the reg used for b for overflow detection
b = init(OperandArray, nimSymbol = B, N, PointerInReg, InputOutput)
# We could force m as immediate by specializing per moduli
M = init(OperandArray, nimSymbol = m, N, PointerInReg, Input)
# If N is too big, we need to spill registers. TODO.
u = init(OperandArray, nimSymbol = ident"U", H, ElemsInReg, InputOutput)
v = init(OperandArray, nimSymbol = ident"V", H, ElemsInReg, InputOutput)
let usym = u.nimSymbol
let vsym = v.nimSymbol
result.add quote do:
var `usym`{.noinit.}, `vsym` {.noInit.}: typeof(`A`)
staticFor i, 0, `H`:
`usym`[i] = `A`[i]
staticFor i, `H`, `N`:
`vsym`[i-`H`] = `A`[i]
# Substraction
# u = a[0..<H] - b[0..<H], v = a[H..<N]
for i in 0 ..< H:
if i == 0:
ctx.sub u[0], b[0]
else:
ctx.sbb u[i], b[i]
# Everything should be hot in cache now so movs are cheaper
# we can try using 2 per SBB
# v = a[H..<N] - b[H..<N], a[0..<H] = u, u = M
for i in H ..< N:
ctx.mov r[i-H], u[i-H]
ctx.sbb v[i-H], b[i]
ctx.mov u[i-H], M[i-H] # TODO, bottleneck 17% perf: prefetch or inline modulus?
# Mask: underflowed contains 0xFFFF or 0x0000
let underflowed = b.reuseRegister()
ctx.sbb underflowed, underflowed
# Now mask the adder, with 0 or the modulus limbs
for i in 0 ..< H:
ctx.`and` u[i], underflowed
# Add the masked modulus
for i in 0 ..< H:
if i == 0:
ctx.add u[0], v[0]
else:
ctx.adc u[i], v[i]
ctx.mov r[i+H], u[i]
result.add ctx.generate
func sub2x_asm*[N: static int](r: var Limbs[N], a, b: Limbs[N], M: Limbs[N div 2]) =
## Constant-time double-width substraction
sub2x_gen(r, a, b, M)

View File

@ -69,11 +69,11 @@ macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N]): untyped =
ctx.mov v[i], u[i] ctx.mov v[i], u[i]
# Mask: overflowed contains 0xFFFF or 0x0000 # Mask: overflowed contains 0xFFFF or 0x0000
# TODO: unnecessary if MSB never set, i.e. "canUseNoCarryMontyMul" # TODO: unnecessary if MSB never set, i.e. "Field.getSpareBits >= 1"
let overflowed = b.reuseRegister() let overflowed = b.reuseRegister()
ctx.sbb overflowed, overflowed ctx.sbb overflowed, overflowed
# Now substract the modulus # Now substract the modulus to test a < p
for i in 0 ..< N: for i in 0 ..< N:
if i == 0: if i == 0:
ctx.sub v[0], M[0] ctx.sub v[0], M[0]
@ -81,7 +81,7 @@ macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N]): untyped =
ctx.sbb v[i], M[i] ctx.sbb v[i], M[i]
# If it overflows here, it means that it was # If it overflows here, it means that it was
# smaller than the modulus and we don'u need V # smaller than the modulus and we don't need V
ctx.sbb overflowed, 0 ctx.sbb overflowed, 0
# Conditional Mov and # Conditional Mov and

View File

@ -83,12 +83,12 @@ proc finalSubCanOverflow*(
# Montgomery reduction # Montgomery reduction
# ------------------------------------------------------------ # ------------------------------------------------------------
macro montyRed_gen[N: static int]( macro montyRedc2x_gen[N: static int](
r_MR: var array[N, SecretWord], r_MR: var array[N, SecretWord],
a_MR: array[N*2, SecretWord], a_MR: array[N*2, SecretWord],
M_MR: array[N, SecretWord], M_MR: array[N, SecretWord],
m0ninv_MR: BaseType, m0ninv_MR: BaseType,
canUseNoCarryMontyMul: static bool spareBits: static int
) = ) =
# TODO, slower than Clang, in particular due to the shadowing # TODO, slower than Clang, in particular due to the shadowing
@ -236,7 +236,7 @@ macro montyRed_gen[N: static int](
let reuse = repackRegisters(t, scratch[N], scratch[N+1]) let reuse = repackRegisters(t, scratch[N], scratch[N+1])
if canUseNoCarryMontyMul: if spareBits >= 1:
ctx.finalSubNoCarry(r, scratch, M, reuse) ctx.finalSubNoCarry(r, scratch, M, reuse)
else: else:
ctx.finalSubCanOverflow(r, scratch, M, reuse, rRAX) ctx.finalSubCanOverflow(r, scratch, M, reuse, rRAX)
@ -249,7 +249,7 @@ func montRed_asm*[N: static int](
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
canUseNoCarryMontyMul: static bool spareBits: static int
) = ) =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
montyRed_gen(r, a, M, m0ninv, canUseNoCarryMontyMul) montyRedc2x_gen(r, a, M, m0ninv, spareBits)

View File

@ -35,12 +35,12 @@ static: doAssert UseASM_X86_64
# Montgomery reduction # Montgomery reduction
# ------------------------------------------------------------ # ------------------------------------------------------------
macro montyRedx_gen[N: static int]( macro montyRedc2xx_gen[N: static int](
r_MR: var array[N, SecretWord], r_MR: var array[N, SecretWord],
a_MR: array[N*2, SecretWord], a_MR: array[N*2, SecretWord],
M_MR: array[N, SecretWord], M_MR: array[N, SecretWord],
m0ninv_MR: BaseType, m0ninv_MR: BaseType,
canUseNoCarryMontyMul: static bool spareBits: static int
) = ) =
# TODO, slower than Clang, in particular due to the shadowing # TODO, slower than Clang, in particular due to the shadowing
@ -175,7 +175,7 @@ macro montyRedx_gen[N: static int](
let reuse = repackRegisters(t, scratch[N]) let reuse = repackRegisters(t, scratch[N])
if canUseNoCarryMontyMul: if spareBits >= 1:
ctx.finalSubNoCarry(r, scratch, M, reuse) ctx.finalSubNoCarry(r, scratch, M, reuse)
else: else:
ctx.finalSubCanOverflow(r, scratch, M, reuse, hi) ctx.finalSubCanOverflow(r, scratch, M, reuse, hi)
@ -188,7 +188,7 @@ func montRed_asm_adx_bmi2*[N: static int](
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, m0ninv: BaseType,
canUseNoCarryMontyMul: static bool spareBits: static int
) = ) =
## Constant-time Montgomery reduction ## Constant-time Montgomery reduction
montyRedx_gen(r, a, M, m0ninv, canUseNoCarryMontyMul) montyRedc2xx_gen(r, a, M, m0ninv, spareBits)

View File

@ -138,14 +138,16 @@ macro add_gen[N: static int](carry: var Carry, r: var Limbs[N], a, b: Limbs[N]):
var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType
# Algorithm # Algorithm
for i in 0 ..< N: ctx.mov t0, arrA[0] # Prologue
ctx.mov t0, arrA[i] ctx.add t0, arrB[0]
if i == 0:
ctx.add t0, arrB[0] for i in 1 ..< N:
else: ctx.mov t1, arrA[i] # Prepare the next iteration
ctx.adc t0, arrB[i] ctx.mov arrR[i-1], t0 # Save the previous result in an interleaved manner
ctx.mov arrR[i], t0 ctx.adc t1, arrB[i] # Compute
swap(t0, t1) swap(t0, t1) # Break dependency chain
ctx.mov arrR[N-1], t0 # Epilogue
ctx.setToCarryFlag(carry) ctx.setToCarryFlag(carry)
# Codegen # Codegen
@ -197,14 +199,16 @@ macro sub_gen[N: static int](borrow: var Borrow, r: var Limbs[N], a, b: Limbs[N]
var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType
# Algorithm # Algorithm
for i in 0 ..< N: ctx.mov t0, arrA[0] # Prologue
ctx.mov t0, arrA[i] ctx.sub t0, arrB[0]
if i == 0:
ctx.sub t0, arrB[0] for i in 1 ..< N:
else: ctx.mov t1, arrA[i] # Prepare the next iteration
ctx.sbb t0, arrB[i] ctx.mov arrR[i-1], t0 # Save the previous reult in an interleaved manner
ctx.mov arrR[i], t0 ctx.sbb t1, arrB[i] # Compute
swap(t0, t1) swap(t0, t1) # Break dependency chain
ctx.mov arrR[N-1], t0 # Epilogue
ctx.setToCarryFlag(borrow) ctx.setToCarryFlag(borrow)
# Codegen # Codegen

View File

@ -25,7 +25,7 @@ import
# #
# ############################################################ # ############################################################
func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, spareBits: static int) =
## Convert a BigInt from its natural representation ## Convert a BigInt from its natural representation
## to the Montgomery n-residue form ## to the Montgomery n-residue form
## ##
@ -40,9 +40,9 @@ func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseTy
## - `r2modM` is R² (mod M) ## - `r2modM` is R² (mod M)
## with W = M.len ## with W = M.len
## and R = (2^WordBitWidth)^W ## and R = (2^WordBitWidth)^W
montyResidue(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, canUseNoCarryMontyMul) montyResidue(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits)
func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, spareBits: static int) =
## Convert a BigInt from its Montgomery n-residue form ## Convert a BigInt from its Montgomery n-residue form
## to the natural representation ## to the natural representation
## ##
@ -54,26 +54,26 @@ func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static Base
var one {.noInit.}: BigInt[mBits] var one {.noInit.}: BigInt[mBits]
one.setOne() one.setOne()
one one
redc(r.limbs, a.limbs, one.limbs, M.limbs, m0ninv, canUseNoCarryMontyMul) redc(r.limbs, a.limbs, one.limbs, M.limbs, m0ninv, spareBits)
func montyMul*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, canUseNoCarryMontyMul: static bool) = func montyMul*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, spareBits: static int) =
## Compute r <- a*b (mod M) in the Montgomery domain ## Compute r <- a*b (mod M) in the Montgomery domain
## ##
## This resets r to zero before processing. Use {.noInit.} ## This resets r to zero before processing. Use {.noInit.}
## to avoid duplicating with Nim zero-init policy ## to avoid duplicating with Nim zero-init policy
montyMul(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, canUseNoCarryMontyMul) montyMul(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits)
func montySquare*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, canUseNoCarryMontyMul: static bool) = func montySquare*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, spareBits: static int) =
## Compute r <- a^2 (mod M) in the Montgomery domain ## Compute r <- a^2 (mod M) in the Montgomery domain
## ##
## This resets r to zero before processing. Use {.noInit.} ## This resets r to zero before processing. Use {.noInit.}
## to avoid duplicating with Nim zero-init policy ## to avoid duplicating with Nim zero-init policy
montySquare(r.limbs, a.limbs, M.limbs, negInvModWord, canUseNoCarryMontyMul) montySquare(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits)
func montyPow*[mBits: static int]( func montyPow*[mBits: static int](
a: var BigInt[mBits], exponent: openarray[byte], a: var BigInt[mBits], exponent: openarray[byte],
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool spareBits: static int
) = ) =
## Compute a <- a^exponent (mod M) ## Compute a <- a^exponent (mod M)
## ``a`` in the Montgomery domain ## ``a`` in the Montgomery domain
@ -92,12 +92,12 @@ func montyPow*[mBits: static int](
const scratchLen = if windowSize == 1: 2 const scratchLen = if windowSize == 1: 2
else: (1 shl windowSize) + 1 else: (1 shl windowSize) + 1
var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]
montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare) montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits)
func montyPowUnsafeExponent*[mBits: static int]( func montyPowUnsafeExponent*[mBits: static int](
a: var BigInt[mBits], exponent: openarray[byte], a: var BigInt[mBits], exponent: openarray[byte],
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool spareBits: static int
) = ) =
## Compute a <- a^exponent (mod M) ## Compute a <- a^exponent (mod M)
## ``a`` in the Montgomery domain ## ``a`` in the Montgomery domain
@ -116,7 +116,7 @@ func montyPowUnsafeExponent*[mBits: static int](
const scratchLen = if windowSize == 1: 2 const scratchLen = if windowSize == 1: 2
else: (1 shl windowSize) + 1 else: (1 shl windowSize) + 1
var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]]
montyPowUnsafeExponent(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare) montyPowUnsafeExponent(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits)
from ../io/io_bigints import exportRawUint from ../io/io_bigints import exportRawUint
# Workaround recursive dependencies # Workaround recursive dependencies
@ -124,7 +124,7 @@ from ../io/io_bigints import exportRawUint
func montyPow*[mBits, eBits: static int]( func montyPow*[mBits, eBits: static int](
a: var BigInt[mBits], exponent: BigInt[eBits], a: var BigInt[mBits], exponent: BigInt[eBits],
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool spareBits: static int
) = ) =
## Compute a <- a^exponent (mod M) ## Compute a <- a^exponent (mod M)
## ``a`` in the Montgomery domain ## ``a`` in the Montgomery domain
@ -138,12 +138,12 @@ func montyPow*[mBits, eBits: static int](
var expBE {.noInit.}: array[(ebits + 7) div 8, byte] var expBE {.noInit.}: array[(ebits + 7) div 8, byte]
expBE.exportRawUint(exponent, bigEndian) expBE.exportRawUint(exponent, bigEndian)
montyPow(a, expBE, M, one, negInvModWord, windowSize, canUseNoCarryMontyMul, canUseNoCarryMontySquare) montyPow(a, expBE, M, one, negInvModWord, windowSize, spareBits)
func montyPowUnsafeExponent*[mBits, eBits: static int]( func montyPowUnsafeExponent*[mBits, eBits: static int](
a: var BigInt[mBits], exponent: BigInt[eBits], a: var BigInt[mBits], exponent: BigInt[eBits],
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int,
canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool spareBits: static int
) = ) =
## Compute a <- a^exponent (mod M) ## Compute a <- a^exponent (mod M)
## ``a`` in the Montgomery domain ## ``a`` in the Montgomery domain
@ -161,7 +161,7 @@ func montyPowUnsafeExponent*[mBits, eBits: static int](
var expBE {.noInit.}: array[(ebits + 7) div 8, byte] var expBE {.noInit.}: array[(ebits + 7) div 8, byte]
expBE.exportRawUint(exponent, bigEndian) expBE.exportRawUint(exponent, bigEndian)
montyPowUnsafeExponent(a, expBE, M, one, negInvModWord, windowSize, canUseNoCarryMontyMul, canUseNoCarryMontySquare) montyPowUnsafeExponent(a, expBE, M, one, negInvModWord, windowSize, spareBits)
{.pop.} # inline {.pop.} # inline
{.pop.} # raises no exceptions {.pop.} # raises no exceptions

View File

@ -56,7 +56,7 @@ func fromBig*(dst: var FF, src: BigInt) =
when nimvm: when nimvm:
dst.mres.montyResidue_precompute(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord()) dst.mres.montyResidue_precompute(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord())
else: else:
dst.mres.montyResidue(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) dst.mres.montyResidue(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits())
func fromBig*[C: static Curve](T: type FF[C], src: BigInt): FF[C] {.noInit.} = func fromBig*[C: static Curve](T: type FF[C], src: BigInt): FF[C] {.noInit.} =
## Convert a BigInt to its Montgomery form ## Convert a BigInt to its Montgomery form
@ -65,7 +65,7 @@ func fromBig*[C: static Curve](T: type FF[C], src: BigInt): FF[C] {.noInit.} =
func toBig*(src: FF): auto {.noInit, inline.} = func toBig*(src: FF): auto {.noInit, inline.} =
## Convert a finite-field element to a BigInt in natural representation ## Convert a finite-field element to a BigInt in natural representation
var r {.noInit.}: typeof(src.mres) var r {.noInit.}: typeof(src.mres)
r.redc(src.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) r.redc(src.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits())
return r return r
# Copy # Copy
@ -169,7 +169,7 @@ func sum*(r: var FF, a, b: FF) {.meter.} =
overflowed = overflowed or not(r.mres < FF.fieldMod()) overflowed = overflowed or not(r.mres < FF.fieldMod())
discard csub(r.mres, FF.fieldMod(), overflowed) discard csub(r.mres, FF.fieldMod(), overflowed)
func sumNoReduce*(r: var FF, a, b: FF) {.meter.} = func sumUnr*(r: var FF, a, b: FF) {.meter.} =
## Sum ``a`` and ``b`` into ``r`` without reduction ## Sum ``a`` and ``b`` into ``r`` without reduction
discard r.mres.sum(a.mres, b.mres) discard r.mres.sum(a.mres, b.mres)
@ -183,7 +183,7 @@ func diff*(r: var FF, a, b: FF) {.meter.} =
var underflowed = r.mres.diff(a.mres, b.mres) var underflowed = r.mres.diff(a.mres, b.mres)
discard cadd(r.mres, FF.fieldMod(), underflowed) discard cadd(r.mres, FF.fieldMod(), underflowed)
func diffNoReduce*(r: var FF, a, b: FF) {.meter.} = func diffUnr*(r: var FF, a, b: FF) {.meter.} =
## Substract `b` from `a` and store the result into `r` ## Substract `b` from `a` and store the result into `r`
## without reduction ## without reduction
discard r.mres.diff(a.mres, b.mres) discard r.mres.diff(a.mres, b.mres)
@ -201,11 +201,11 @@ func double*(r: var FF, a: FF) {.meter.} =
func prod*(r: var FF, a, b: FF) {.meter.} = func prod*(r: var FF, a, b: FF) {.meter.} =
## Store the product of ``a`` by ``b`` modulo p into ``r`` ## Store the product of ``a`` by ``b`` modulo p into ``r``
## ``r`` is initialized / overwritten ## ``r`` is initialized / overwritten
r.mres.montyMul(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) r.mres.montyMul(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits())
func square*(r: var FF, a: FF) {.meter.} = func square*(r: var FF, a: FF) {.meter.} =
## Squaring modulo p ## Squaring modulo p
r.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontySquare()) r.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits())
func neg*(r: var FF, a: FF) {.meter.} = func neg*(r: var FF, a: FF) {.meter.} =
## Negate modulo p ## Negate modulo p
@ -279,8 +279,7 @@ func pow*(a: var FF, exponent: BigInt) =
exponent, exponent,
FF.fieldMod(), FF.getMontyOne(), FF.fieldMod(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize, FF.getNegInvModWord(), windowSize,
FF.canUseNoCarryMontyMul(), FF.getSpareBits()
FF.canUseNoCarryMontySquare()
) )
func pow*(a: var FF, exponent: openarray[byte]) = func pow*(a: var FF, exponent: openarray[byte]) =
@ -292,8 +291,7 @@ func pow*(a: var FF, exponent: openarray[byte]) =
exponent, exponent,
FF.fieldMod(), FF.getMontyOne(), FF.fieldMod(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize, FF.getNegInvModWord(), windowSize,
FF.canUseNoCarryMontyMul(), FF.getSpareBits()
FF.canUseNoCarryMontySquare()
) )
func powUnsafeExponent*(a: var FF, exponent: BigInt) = func powUnsafeExponent*(a: var FF, exponent: BigInt) =
@ -312,8 +310,7 @@ func powUnsafeExponent*(a: var FF, exponent: BigInt) =
exponent, exponent,
FF.fieldMod(), FF.getMontyOne(), FF.fieldMod(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize, FF.getNegInvModWord(), windowSize,
FF.canUseNoCarryMontyMul(), FF.getSpareBits()
FF.canUseNoCarryMontySquare()
) )
func powUnsafeExponent*(a: var FF, exponent: openarray[byte]) = func powUnsafeExponent*(a: var FF, exponent: openarray[byte]) =
@ -332,8 +329,7 @@ func powUnsafeExponent*(a: var FF, exponent: openarray[byte]) =
exponent, exponent,
FF.fieldMod(), FF.getMontyOne(), FF.fieldMod(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize, FF.getNegInvModWord(), windowSize,
FF.canUseNoCarryMontyMul(), FF.getSpareBits()
FF.canUseNoCarryMontySquare()
) )
# ############################################################ # ############################################################
@ -350,7 +346,7 @@ func `*=`*(a: var FF, b: FF) {.meter.} =
func square*(a: var FF) {.meter.} = func square*(a: var FF) {.meter.} =
## Squaring modulo p ## Squaring modulo p
a.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontySquare()) a.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits())
func square_repeated*(r: var FF, num: int) {.meter.} = func square_repeated*(r: var FF, num: int) {.meter.} =
## Repeated squarings ## Repeated squarings
@ -389,59 +385,57 @@ func `*=`*(a: var FF, b: static int) =
elif b == 2: elif b == 2:
a.double() a.double()
elif b == 3: elif b == 3:
let t1 = a var t {.noInit.}: typeof(a)
a.double() t.double(a)
a += t1 a += t
elif b == 4: elif b == 4:
a.double() a.double()
a.double() a.double()
elif b == 5: elif b == 5:
let t1 = a var t {.noInit.}: typeof(a)
a.double() t.double(a)
a.double() t.double()
a += t1 a += t
elif b == 6: elif b == 6:
a.double() var t {.noInit.}: typeof(a)
let t2 = a t.double(a)
a.double() # 4 t += a # 3
a += t2 a.double(t)
elif b == 7: elif b == 7:
let t1 = a var t {.noInit.}: typeof(a)
a.double() t.double(a)
let t2 = a t.double()
a.double() # 4 t.double()
a += t2 a.diff(t, a)
a += t1
elif b == 8: elif b == 8:
a.double() a.double()
a.double() a.double()
a.double() a.double()
elif b == 9: elif b == 9:
let t1 = a var t {.noInit.}: typeof(a)
a.double() t.double(a)
a.double() t.double()
a.double() # 8 t.double()
a += t1 a.sum(t, a)
elif b == 10: elif b == 10:
var t {.noInit.}: typeof(a)
t.double(a)
t.double()
a += t # 5
a.double() a.double()
let t2 = a
a.double()
a.double() # 8
a += t2
elif b == 11: elif b == 11:
let t1 = a var t {.noInit.}: typeof(a)
a.double() t.double(a)
let t2 = a t += a # 3
a.double() t.double() # 6
a.double() # 8 t.double() # 12
a += t2 a.diff(t, a) # 11
a += t1
elif b == 12: elif b == 12:
a.double() var t {.noInit.}: typeof(a)
a.double() # 4 t.double(a)
let t4 = a t += a # 3
a.double() # 8 t.double() # 6
a += t4 a.double(t) # 12
else: else:
{.error: "Multiplication by this small int not implemented".} {.error: "Multiplication by this small int not implemented".}

View File

@ -0,0 +1,243 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
../config/[common, curves, type_ff],
../primitives,
./bigints,
./finite_fields,
./limbs,
./limbs_extmul,
./limbs_montgomery
when UseASM_X86_64:
import assembly/limbs_asm_modular_dbl_prec_x86
type FpDbl*[C: static Curve] = object
  ## Double-precision Fp element
  ## A FpDbl is a partially-reduced double-precision element of Fp
  ## The allowed range is [0, 2ⁿp)
  ## with n = w*WordBitSize
  ## and w the number of words necessary to represent p on the machine.
  ## Concretely a 381-bit p needs 6*64 bits limbs (hence 384 bits total)
  ## and so FpDbl would be 768 bits.
  # We directly work with double the number of limbs,
  # instead of BigInt indirection.
  limbs2x*: matchingLimbs2x(C)
template doublePrec*(T: type Fp): type =
  ## Return the double-precision type matching with Fp
  FpDbl[T.C]
# No exceptions allowed
{.push raises: [].}
{.push inline.}
func `==`*(a, b: FpDbl): SecretBool =
  ## Limb-wise equality comparison, returning a SecretBool.
  # NOTE(review): compares the raw partially-reduced representation;
  # operands differing by a multiple of 2ⁿp compare unequal — callers
  # must use a canonical representative. TODO confirm at call sites.
  a.limbs2x == b.limbs2x
func isZero*(a: FpDbl): SecretBool =
  ## Returns true iff all limbs of ``a`` are zero.
  a.limbs2x.isZero()
func setZero*(a: var FpDbl) =
  ## Set ``a`` to zero (clears every limb).
  a.limbs2x.setZero()
func prod2x*(r: var FpDbl, a, b: Fp) =
  ## Double-precision multiplication
  ## Store the product of ``a`` by ``b`` into ``r``
  ##
  ## If a and b are in [0, p)
  ## Output is in [0, p²)
  ##
  ## Output can be up to [0, 2ⁿp) range
  ## provided spare bits are available in Fp representation
  # No Montgomery reduction is done here; pair with `redc2x`
  # to bring the result back into [0, p).
  r.limbs2x.prod(a.mres.limbs, b.mres.limbs)
func square2x*(r: var FpDbl, a: Fp) =
  ## Double-precision squaring
  ## Store the square of ``a`` into ``r``
  ##
  ## If a is in [0, p)
  ## Output is in [0, p²)
  ##
  ## Output can be up to [0, 2ⁿp) range
  ## provided spare bits are available in Fp representation
  # No Montgomery reduction is done here; pair with `redc2x`
  # to bring the result back into [0, p).
  r.limbs2x.square(a.mres.limbs)
func redc2x*(r: var Fp, a: FpDbl) =
  ## Reduce a double-precision field element into r
  ## from [0, 2ⁿp) range to [0, p) range
  ##
  ## Performs a Montgomery reduction of ``a``; the modulus,
  ## the Montgomery magic constant -1/p (mod word) and the
  ## spare-bits information all come from the `Fp` type.
  # Fix: the previous `const N = r.mres.limbs.len` was never used —
  # `montyRedc2x` infers the limb count from its array parameters.
  montyRedc2x(
    r.mres.limbs,
    a.limbs2x,
    Fp.C.Mod.limbs,
    Fp.getNegInvModWord(),
    Fp.getSpareBits()
  )
func diff2xUnr*(r: var FpDbl, a, b: FpDbl) =
  ## Double-precision substraction without reduction
  ##
  ## If the result is negative, fully reduced addition/substraction
  ## are necessary afterwards to guarantee the [0, 2ⁿp) range
  # The borrow out of the limb subtraction is deliberately discarded.
  discard r.limbs2x.diff(a.limbs2x, b.limbs2x)
func diff2xMod*(r: var FpDbl, a, b: FpDbl) =
  ## Double-precision modular substraction
  ## Output is conditionally reduced by 2ⁿp
  ## to stay in the [0, 2ⁿp) range
  when UseASM_X86_64:
    submod2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs)
  else:
    # Substraction step
    var underflowed = SecretBool r.limbs2x.diff(a.limbs2x, b.limbs2x)

    # Conditional reduction by 2ⁿp on underflow.
    # 2ⁿp has all-zero low limbs, so only the upper half of r
    # needs p added; the ccopy keeps the operation branchless.
    const N = r.limbs2x.len div 2
    const M = FpDbl.C.Mod
    var carry = Carry(0)
    var sum: SecretWord
    staticFor i, 0, N:
      addC(carry, sum, r.limbs2x[i+N], M.limbs[i], carry)
      underflowed.ccopy(r.limbs2x[i+N], sum)
func sum2xUnr*(r: var FpDbl, a, b: FpDbl) =
  ## Double-precision addition without reduction
  ##
  ## If the result is bigger than 2ⁿp, fully reduced addition/substraction
  ## are necessary afterwards to guarantee the [0, 2ⁿp) range
  # The carry out of the limb addition is deliberately discarded.
  discard r.limbs2x.sum(a.limbs2x, b.limbs2x)
func sum2xMod*(r: var FpDbl, a, b: FpDbl) =
  ## Double-precision modular addition
  ## Output is conditionally reduced by 2ⁿp
  ## to stay in the [0, 2ⁿp) range
  when UseASM_X86_64:
    addmod2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs)
  else:
    # Addition step
    var overflowed = SecretBool r.limbs2x.sum(a.limbs2x, b.limbs2x)

    const N = r.limbs2x.len div 2
    const M = FpDbl.C.Mod
    # Test >= 2ⁿp
    # 2ⁿp has all-zero low limbs, so the comparison is a trial
    # subtraction of p from the upper half only.
    var borrow = Borrow(0)
    var t{.noInit.}: Limbs[N]
    staticFor i, 0, N:
      subB(borrow, t[i], r.limbs2x[i+N], M.limbs[i], borrow)

    # If no borrow occured, r was bigger than 2ⁿp
    overflowed = overflowed or not(SecretBool borrow)

    # Conditional reduction by 2ⁿp
    # (branchless copy of the trial-subtracted upper half)
    staticFor i, 0, N:
      SecretBool(overflowed).ccopy(r.limbs2x[i+N], t[i])
func neg2xMod*(r: var FpDbl, a: FpDbl) =
  ## Double-precision modular negation
  ## Negate modulo 2ⁿp
  when UseASM_X86_64:
    negmod2x_asm(r.limbs2x, a.limbs2x, FpDbl.C.Mod.limbs)
  else:
    # If a = 0 we need r = 0 and not r = M
    # as comparison operator assume unicity
    # of the modular representation.
    # Also make sure to handle aliasing where r.addr = a.addr
    var t {.noInit.}: FpDbl
    let isZero = a.isZero()
    const N = r.limbs2x.len div 2
    const M = FpDbl.C.Mod
    var borrow = Borrow(0)
    # Compute t = 2ⁿp - a over all 2N limbs:
    # 2ⁿp is filled with 0 in the first half
    staticFor i, 0, N:
      subB(borrow, t.limbs2x[i], Zero, a.limbs2x[i], borrow)
    # 2ⁿp has p (shifted) for the rest of the limbs
    staticFor i, N, r.limbs2x.len:
      subB(borrow, t.limbs2x[i], M.limbs[i-N], a.limbs2x[i], borrow)

    # Zero the result if input was zero
    t.limbs2x.czero(isZero)
    r = t
func prod2xImpl(
       r {.noAlias.}: var FpDbl,
       a {.noAlias.}: FpDbl, b: static int) =
  ## Multiplication by a small integer known at compile-time
  ## Requires no aliasing and b positive
  ##
  ## Each branch is a short double-and-add / double-and-sub chain
  ## over `sum2xMod`/`diff2xMod`, so the result stays in [0, 2ⁿp).
  static: doAssert b >= 0
  when b == 0:
    r.setZero()
  elif b == 1:
    r = a
  elif b == 2:
    r.sum2xMod(a, a)
  elif b == 3:
    r.sum2xMod(a, a)
    r.sum2xMod(a, r)
  elif b == 4:
    r.sum2xMod(a, a)
    r.sum2xMod(r, r)
  elif b == 5:
    r.sum2xMod(a, a)
    r.sum2xMod(r, r)
    r.sum2xMod(r, a)
  elif b == 6:
    r.sum2xMod(a, a)
    let t2 = r          # 2
    r.sum2xMod(r, r)    # 4
    # Bugfix: was `r.sum2xMod(t, t2)` — `t` was never declared.
    r.sum2xMod(r, t2)   # 6
  elif b == 7:
    r.sum2xMod(a, a)
    r.sum2xMod(r, r)    # 4
    r.sum2xMod(r, r)    # 8
    r.diff2xMod(r, a)   # 7
  elif b == 8:
    r.sum2xMod(a, a)
    r.sum2xMod(r, r)
    r.sum2xMod(r, r)
  elif b == 9:
    r.sum2xMod(a, a)
    r.sum2xMod(r, r)
    r.sum2xMod(r, r)    # 8
    r.sum2xMod(r, a)    # 9
  elif b == 10:
    r.sum2xMod(a, a)
    r.sum2xMod(r, r)
    r.sum2xMod(r, a)    # 5
    r.sum2xMod(r, r)    # 10
  elif b == 11:
    r.sum2xMod(a, a)
    r.sum2xMod(r, r)
    r.sum2xMod(r, a)    # 5
    r.sum2xMod(r, r)    # 10
    r.sum2xMod(r, a)    # 11
  elif b == 12:
    r.sum2xMod(a, a)
    r.sum2xMod(r, r)    # 4
    # Bugfix: was `let t4 = a`, which made the chain compute 8·a + a = 9·a.
    # Snapshot r (= 4·a) instead, matching the single-precision `*=` chain.
    let t4 = r          # 4
    r.sum2xMod(r, r)    # 8
    r.sum2xMod(r, t4)   # 12
  else:
    {.error: "Multiplication by this small int not implemented".}
func prod2x*(r: var FpDbl, a: FpDbl, b: static int) =
  ## Multiplication by a small integer known at compile-time
  # Negative multipliers are handled by negating `a` modulo 2ⁿp,
  # then dispatching on the absolute value.
  const negate = b < 0
  const b = if negate: -b
            else: b
  when negate:
    var t {.noInit.}: typeof(r)
    t.neg2xMod(a)
  else:
    # Copy into `t` also satisfies prod2xImpl's no-alias requirement.
    let t = a
  prod2xImpl(r, t, b)
{.pop.} # inline
{.pop.} # raises no exceptions

View File

@ -1,81 +0,0 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
../config/[common, curves, type_ff],
../primitives,
./bigints,
./finite_fields,
./limbs,
./limbs_extmul,
./limbs_montgomery
when UseASM_X86_64:
import assembly/limbs_asm_modular_dbl_width_x86
type FpDbl*[C: static Curve] = object
## Double-width Fp element
## This allows saving on reductions
# We directly work with double the number of limbs
limbs2x*: matchingLimbs2x(C)
template doubleWidth*(T: typedesc[Fp]): typedesc =
## Return the double-width type matching with Fp
FpDbl[T.C]
# No exceptions allowed
{.push raises: [].}
{.push inline.}
func `==`*(a, b: FpDbl): SecretBool =
a.limbs2x == b.limbs2x
func mulNoReduce*(r: var FpDbl, a, b: Fp) =
## Store the product of ``a`` by ``b`` into ``r``
r.limbs2x.prod(a.mres.limbs, b.mres.limbs)
func squareNoReduce*(r: var FpDbl, a: Fp) =
## Store the square of ``a`` into ``r``
r.limbs2x.square(a.mres.limbs)
func reduce*(r: var Fp, a: FpDbl) =
## Reduce a double-width field element into r
const N = r.mres.limbs.len
montyRed(
r.mres.limbs,
a.limbs2x,
Fp.C.Mod.limbs,
Fp.getNegInvModWord(),
Fp.canUseNoCarryMontyMul()
)
func diffNoReduce*(r: var FpDbl, a, b: FpDbl) =
## Double-width substraction without reduction
discard r.limbs2x.diff(a.limbs2x, b.limbs2x)
func diff*(r: var FpDbl, a, b: FpDbl) =
## Double-width modular substraction
when UseASM_X86_64:
sub2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs)
else:
var underflowed = SecretBool r.limbs2x.diff(a.limbs2x, b.limbs2x)
const N = r.limbs2x.len div 2
const M = FpDbl.C.Mod
var carry = Carry(0)
var sum: SecretWord
for i in 0 ..< N:
addC(carry, sum, r.limbs2x[i+N], M.limbs[i], carry)
underflowed.ccopy(r.limbs2x[i+N], sum)
func `-=`*(a: var FpDbl, b: FpDbl) =
## Double-width modular substraction
a.diff(a, b)
{.pop.} # inline
{.pop.} # raises no exceptions

View File

@ -72,7 +72,8 @@ func prod*[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limbs[aLen], b:
## `r` must not alias ``a`` or ``b`` ## `r` must not alias ``a`` or ``b``
when UseASM_X86_64 and aLen <= 6: when UseASM_X86_64 and aLen <= 6:
if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): # ADX implies BMI2
if ({.noSideEffect.}: hasAdx()):
mul_asm_adx_bmi2(r, a, b) mul_asm_adx_bmi2(r, a, b)
else: else:
mul_asm(r, a, b) mul_asm(r, a, b)

View File

@ -281,12 +281,12 @@ func montySquare_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) {.used.}=
# Montgomery Reduction # Montgomery Reduction
# ------------------------------------------------------------ # ------------------------------------------------------------
func montyRed_CIOS[N: static int]( func montyRedc2x_CIOS[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType) = m0ninv: BaseType) =
## Montgomery reduce a double-width bigint modulo M ## Montgomery reduce a double-precision bigint modulo M
# - Analyzing and Comparing Montgomery Multiplication Algorithms # - Analyzing and Comparing Montgomery Multiplication Algorithms
# Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr. # Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr.
# http://pdfs.semanticscholar.org/5e39/41ff482ec3ee41dc53c3298f0be085c69483.pdf # http://pdfs.semanticscholar.org/5e39/41ff482ec3ee41dc53c3298f0be085c69483.pdf
@ -299,7 +299,7 @@ func montyRed_CIOS[N: static int](
# Algorithm # Algorithm
# Inputs: # Inputs:
# - N number of limbs # - N number of limbs
# - a[0 ..< 2N] (double-width input to reduce) # - a[0 ..< 2N] (double-precision input to reduce)
# - M[0 ..< N] The field modulus (must be odd for Montgomery reduction) # - M[0 ..< N] The field modulus (must be odd for Montgomery reduction)
# - m0ninv: Montgomery Reduction magic number = -1/M[0] # - m0ninv: Montgomery Reduction magic number = -1/M[0]
# Output: # Output:
@ -343,12 +343,12 @@ func montyRed_CIOS[N: static int](
discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M)) discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M))
r = res r = res
func montyRed_Comba[N: static int]( func montyRedc2x_Comba[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType) = m0ninv: BaseType) =
## Montgomery reduce a double-width bigint modulo M ## Montgomery reduce a double-precision bigint modulo M
# We use Product Scanning / Comba multiplication # We use Product Scanning / Comba multiplication
var t, u, v = Zero var t, u, v = Zero
var carry: Carry var carry: Carry
@ -392,7 +392,7 @@ func montyRed_Comba[N: static int](
func montyMul*( func montyMul*(
r: var Limbs, a, b, M: Limbs, r: var Limbs, a, b, M: Limbs,
m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} = m0ninv: static BaseType, spareBits: static int) {.inline.} =
## Compute r <- a*b (mod M) in the Montgomery domain ## Compute r <- a*b (mod M) in the Montgomery domain
## `m0ninv` = -1/M (mod SecretWord). Our words are 2^32 or 2^64 ## `m0ninv` = -1/M (mod SecretWord). Our words are 2^32 or 2^64
## ##
@ -419,9 +419,10 @@ func montyMul*(
# The implementation is visible from here, the compiler can make decision whether to: # The implementation is visible from here, the compiler can make decision whether to:
# - specialize/duplicate code for m0ninv == 1 (especially if only 1 curve is needed) # - specialize/duplicate code for m0ninv == 1 (especially if only 1 curve is needed)
# - keep it generic and optimize code size # - keep it generic and optimize code size
when canUseNoCarryMontyMul: when spareBits >= 1:
when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling
if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): # ADX implies BMI2
if ({.noSideEffect.}: hasAdx()):
montMul_CIOS_nocarry_asm_adx_bmi2(r, a, b, M, m0ninv) montMul_CIOS_nocarry_asm_adx_bmi2(r, a, b, M, m0ninv)
else: else:
montMul_CIOS_nocarry_asm(r, a, b, M, m0ninv) montMul_CIOS_nocarry_asm(r, a, b, M, m0ninv)
@ -431,14 +432,14 @@ func montyMul*(
montyMul_FIPS(r, a, b, M, m0ninv) montyMul_FIPS(r, a, b, M, m0ninv)
func montySquare*(r: var Limbs, a, M: Limbs, func montySquare*(r: var Limbs, a, M: Limbs,
m0ninv: static BaseType, canUseNoCarryMontySquare: static bool) {.inline.} = m0ninv: static BaseType, spareBits: static int) {.inline.} =
## Compute r <- a^2 (mod M) in the Montgomery domain ## Compute r <- a^2 (mod M) in the Montgomery domain
## `m0ninv` = -1/M (mod SecretWord). Our words are 2^31 or 2^63 ## `m0ninv` = -1/M (mod SecretWord). Our words are 2^31 or 2^63
# TODO: needs optimization similar to multiplication # TODO: needs optimization similar to multiplication
montyMul(r, a, a, M, m0ninv, canUseNoCarryMontySquare) montyMul(r, a, a, M, m0ninv, spareBits)
# when canUseNoCarryMontySquare: # when spareBits >= 2:
# # TODO: Deactivated # # TODO: Deactivated
# # Off-by one on 32-bit on the least significant bit # # Off-by one on 32-bit on the least significant bit
# # for Fp[BLS12-381] with inputs # # for Fp[BLS12-381] with inputs
@ -459,26 +460,27 @@ func montySquare*(r: var Limbs, a, M: Limbs,
# montyMul_FIPS(r, a, a, M, m0ninv) # montyMul_FIPS(r, a, a, M, m0ninv)
# TODO upstream, using Limbs[N] breaks semcheck # TODO upstream, using Limbs[N] breaks semcheck
func montyRed*[N: static int]( func montyRedc2x*[N: static int](
r: var array[N, SecretWord], r: var array[N, SecretWord],
a: array[N*2, SecretWord], a: array[N*2, SecretWord],
M: array[N, SecretWord], M: array[N, SecretWord],
m0ninv: BaseType, canUseNoCarryMontyMul: static bool) {.inline.} = m0ninv: BaseType, spareBits: static int) {.inline.} =
## Montgomery reduce a double-width bigint modulo M ## Montgomery reduce a double-precision bigint modulo M
when UseASM_X86_64 and r.len <= 6: when UseASM_X86_64 and r.len <= 6:
if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): # ADX implies BMI2
montRed_asm_adx_bmi2(r, a, M, m0ninv, canUseNoCarryMontyMul) if ({.noSideEffect.}: hasAdx()):
montRed_asm_adx_bmi2(r, a, M, m0ninv, spareBits)
else: else:
montRed_asm(r, a, M, m0ninv, canUseNoCarryMontyMul) montRed_asm(r, a, M, m0ninv, spareBits)
elif UseASM_X86_32 and r.len <= 6: elif UseASM_X86_32 and r.len <= 6:
# TODO: Assembly faster than GCC but slower than Clang # TODO: Assembly faster than GCC but slower than Clang
montRed_asm(r, a, M, m0ninv, canUseNoCarryMontyMul) montRed_asm(r, a, M, m0ninv, spareBits)
else: else:
montyRed_CIOS(r, a, M, m0ninv) montyRedc2x_CIOS(r, a, M, m0ninv)
# montyRed_Comba(r, a, M, m0ninv) # montyRedc2x_Comba(r, a, M, m0ninv)
func redc*(r: var Limbs, a, one, M: Limbs, func redc*(r: var Limbs, a, one, M: Limbs,
m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = m0ninv: static BaseType, spareBits: static int) =
## Transform a bigint ``a`` from it's Montgomery N-residue representation (mod N) ## Transform a bigint ``a`` from it's Montgomery N-residue representation (mod N)
## to the regular natural representation (mod N) ## to the regular natural representation (mod N)
## ##
@ -497,10 +499,10 @@ func redc*(r: var Limbs, a, one, M: Limbs,
# - http://langevin.univ-tln.fr/cours/MLC/extra/montgomery.pdf # - http://langevin.univ-tln.fr/cours/MLC/extra/montgomery.pdf
# Montgomery original paper # Montgomery original paper
# #
montyMul(r, a, one, M, m0ninv, canUseNoCarryMontyMul) montyMul(r, a, one, M, m0ninv, spareBits)
func montyResidue*(r: var Limbs, a, M, r2modM: Limbs, func montyResidue*(r: var Limbs, a, M, r2modM: Limbs,
m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = m0ninv: static BaseType, spareBits: static int) =
## Transform a bigint ``a`` from it's natural representation (mod N) ## Transform a bigint ``a`` from it's natural representation (mod N)
## to a the Montgomery n-residue representation ## to a the Montgomery n-residue representation
## ##
@ -518,7 +520,7 @@ func montyResidue*(r: var Limbs, a, M, r2modM: Limbs,
## Important: `r` is overwritten ## Important: `r` is overwritten
## The result `r` buffer size MUST be at least the size of `M` buffer ## The result `r` buffer size MUST be at least the size of `M` buffer
# Reference: https://eprint.iacr.org/2017/1057.pdf # Reference: https://eprint.iacr.org/2017/1057.pdf
montyMul(r, a, r2ModM, M, m0ninv, canUseNoCarryMontyMul) montyMul(r, a, r2ModM, M, m0ninv, spareBits)
# Montgomery Modular Exponentiation # Montgomery Modular Exponentiation
# ------------------------------------------ # ------------------------------------------
@ -565,7 +567,7 @@ func montyPowPrologue(
a: var Limbs, M, one: Limbs, a: var Limbs, M, one: Limbs,
m0ninv: static BaseType, m0ninv: static BaseType,
scratchspace: var openarray[Limbs], scratchspace: var openarray[Limbs],
canUseNoCarryMontyMul: static bool spareBits: static int
): uint = ): uint =
## Setup the scratchspace ## Setup the scratchspace
## Returns the fixed-window size for exponentiation with window optimization. ## Returns the fixed-window size for exponentiation with window optimization.
@ -579,7 +581,7 @@ func montyPowPrologue(
else: else:
scratchspace[2] = a scratchspace[2] = a
for k in 2 ..< 1 shl result: for k in 2 ..< 1 shl result:
scratchspace[k+1].montyMul(scratchspace[k], a, M, m0ninv, canUseNoCarryMontyMul) scratchspace[k+1].montyMul(scratchspace[k], a, M, m0ninv, spareBits)
# Set a to one # Set a to one
a = one a = one
@ -593,7 +595,7 @@ func montyPowSquarings(
window: uint, window: uint,
acc, acc_len: var uint, acc, acc_len: var uint,
e: var int, e: var int,
canUseNoCarryMontySquare: static bool spareBits: static int
): tuple[k, bits: uint] {.inline.}= ): tuple[k, bits: uint] {.inline.}=
## Squaring step of exponentiation by squaring ## Squaring step of exponentiation by squaring
## Get the next k bits in range [1, window) ## Get the next k bits in range [1, window)
@ -629,7 +631,7 @@ func montyPowSquarings(
# We have k bits and can do k squaring # We have k bits and can do k squaring
for i in 0 ..< k: for i in 0 ..< k:
tmp.montySquare(a, M, m0ninv, canUseNoCarryMontySquare) tmp.montySquare(a, M, m0ninv, spareBits)
a = tmp a = tmp
return (k, bits) return (k, bits)
@ -640,8 +642,7 @@ func montyPow*(
M, one: Limbs, M, one: Limbs,
m0ninv: static BaseType, m0ninv: static BaseType,
scratchspace: var openarray[Limbs], scratchspace: var openarray[Limbs],
canUseNoCarryMontyMul: static bool, spareBits: static int
canUseNoCarryMontySquare: static bool
) = ) =
## Modular exponentiation r = a^exponent mod M ## Modular exponentiation r = a^exponent mod M
## in the Montgomery domain ## in the Montgomery domain
@ -669,7 +670,7 @@ func montyPow*(
## A window of size 5 requires (2^5 + 1)*(381 + 7)/8 = 33 * 48 bytes = 1584 bytes ## A window of size 5 requires (2^5 + 1)*(381 + 7)/8 = 33 * 48 bytes = 1584 bytes
## of scratchspace (on the stack). ## of scratchspace (on the stack).
let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, canUseNoCarryMontyMul) let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, spareBits)
# We process bits with from most to least significant. # We process bits with from most to least significant.
# At each loop iteration with have acc_len bits in acc. # At each loop iteration with have acc_len bits in acc.
@ -684,7 +685,7 @@ func montyPow*(
a, exponent, M, m0ninv, a, exponent, M, m0ninv,
scratchspace[0], window, scratchspace[0], window,
acc, acc_len, e, acc, acc_len, e,
canUseNoCarryMontySquare spareBits
) )
# Window lookup: we set scratchspace[1] to the lookup value. # Window lookup: we set scratchspace[1] to the lookup value.
@ -699,7 +700,7 @@ func montyPow*(
# Multiply with the looked-up value # Multiply with the looked-up value
# we keep the product only if the exponent bits are not all zeroes # we keep the product only if the exponent bits are not all zeroes
scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, canUseNoCarryMontyMul) scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, spareBits)
a.ccopy(scratchspace[0], SecretWord(bits).isNonZero()) a.ccopy(scratchspace[0], SecretWord(bits).isNonZero())
func montyPowUnsafeExponent*( func montyPowUnsafeExponent*(
@ -708,8 +709,7 @@ func montyPowUnsafeExponent*(
M, one: Limbs, M, one: Limbs,
m0ninv: static BaseType, m0ninv: static BaseType,
scratchspace: var openarray[Limbs], scratchspace: var openarray[Limbs],
canUseNoCarryMontyMul: static bool, spareBits: static int
canUseNoCarryMontySquare: static bool
) = ) =
## Modular exponentiation r = a^exponent mod M ## Modular exponentiation r = a^exponent mod M
## in the Montgomery domain ## in the Montgomery domain
@ -723,7 +723,7 @@ func montyPowUnsafeExponent*(
# TODO: scratchspace[1] is unused when window > 1 # TODO: scratchspace[1] is unused when window > 1
let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, canUseNoCarryMontyMul) let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, spareBits)
var var
acc, acc_len: uint acc, acc_len: uint
@ -733,16 +733,16 @@ func montyPowUnsafeExponent*(
a, exponent, M, m0ninv, a, exponent, M, m0ninv,
scratchspace[0], window, scratchspace[0], window,
acc, acc_len, e, acc, acc_len, e,
canUseNoCarryMontySquare spareBits
) )
## Warning ⚠️: Exposes the exponent bits ## Warning ⚠️: Exposes the exponent bits
if bits != 0: if bits != 0:
if window > 1: if window > 1:
scratchspace[0].montyMul(a, scratchspace[1+bits], M, m0ninv, canUseNoCarryMontyMul) scratchspace[0].montyMul(a, scratchspace[1+bits], M, m0ninv, spareBits)
else: else:
# scratchspace[1] holds the original `a` # scratchspace[1] holds the original `a`
scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, canUseNoCarryMontyMul) scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, spareBits)
a = scratchspace[0] a = scratchspace[0]
{.pop.} # raises no exceptions {.pop.} # raises no exceptions

View File

@ -51,18 +51,10 @@ macro genDerivedConstants*(mode: static DerivedConstantMode): untyped =
let M = if mode == kModulus: bindSym(curve & "_Modulus") let M = if mode == kModulus: bindSym(curve & "_Modulus")
else: bindSym(curve & "_Order") else: bindSym(curve & "_Order")
# const MyCurve_CanUseNoCarryMontyMul = useNoCarryMontyMul(MyCurve_Modulus) # const MyCurve_SpareBits = countSpareBits(MyCurve_Modulus)
result.add newConstStmt( result.add newConstStmt(
used(curve & ff & "_CanUseNoCarryMontyMul"), newCall( used(curve & ff & "_SpareBits"), newCall(
bindSym"useNoCarryMontyMul", bindSym"countSpareBits",
M
)
)
# const MyCurve_CanUseNoCarryMontySquare = useNoCarryMontySquare(MyCurve_Modulus)
result.add newConstStmt(
used(curve & ff & "_CanUseNoCarryMontySquare"), newCall(
bindSym"useNoCarryMontySquare",
M M
) )
) )

View File

@ -61,15 +61,18 @@ template fieldMod*(Field: type FF): auto =
else: else:
Field.C.getCurveOrder() Field.C.getCurveOrder()
macro canUseNoCarryMontyMul*(ff: type FF): untyped = macro getSpareBits*(ff: type FF): untyped =
## Returns true if the Modulus is compatible with a fast ## Returns the number of extra bits
## Montgomery multiplication that avoids many carries ## in the modulus M representation.
result = bindConstant(ff, "CanUseNoCarryMontyMul") ##
## This is used for no-carry operations
macro canUseNoCarryMontySquare*(ff: type FF): untyped = ## or lazily reduced operations by allowing
## Returns true if the Modulus is compatible with a fast ## output in range:
## Montgomery squaring that avoids many carries ## - [0, 2p) if 1 bit is available
result = bindConstant(ff, "CanUseNoCarryMontySquare") ## - [0, 4p) if 2 bits are available
## - [0, 8p) if 3 bits are available
## - ...
result = bindConstant(ff, "SpareBits")
macro getR2modP*(ff: type FF): untyped = macro getR2modP*(ff: type FF): untyped =
## Get the Montgomery "R^2 mod P" constant associated to a curve field modulus ## Get the Montgomery "R^2 mod P" constant associated to a curve field modulus

View File

@ -238,21 +238,20 @@ func checkValidModulus(M: BigInt) =
doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those" doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those"
func useNoCarryMontyMul*(M: BigInt): bool = func countSpareBits*(M: BigInt): int =
## Returns if the modulus is compatible ## Count the number of extra bits
## with the no-carry Montgomery Multiplication ## in the modulus M representation.
## from https://hackmd.io/@zkteam/modular_multiplication ##
# Indirection needed because static object are buggy ## This is used for no-carry operations
# https://github.com/nim-lang/Nim/issues/9679 ## or lazily reduced operations by allowing
BaseType(M.limbs[^1]) < high(BaseType) shr 1 ## output in range:
## - [0, 2p) if 1 bit is available
func useNoCarryMontySquare*(M: BigInt): bool = ## - [0, 4p) if 2 bits are available
## Returns if the modulus is compatible ## - [0, 8p) if 3 bits are available
## with the no-carry Montgomery Squaring ## - ...
## from https://hackmd.io/@zkteam/modular_multiplication checkValidModulus(M)
# Indirection needed because static object are buggy let msb = log2(BaseType(M.limbs[^1]))
# https://github.com/nim-lang/Nim/issues/9679 result = WordBitWidth - 1 - msb.int
BaseType(M.limbs[^1]) < high(BaseType) shr 2
func invModBitwidth[T: SomeUnsignedInt](a: T): T = func invModBitwidth[T: SomeUnsignedInt](a: T): T =
# We use BaseType for return value because static distinct type # We use BaseType for return value because static distinct type

View File

@ -87,6 +87,16 @@ From Ben Edgington, https://hackmd.io/@benjaminion/bls12-381
Jean-Luc Beuchat and Jorge Enrique González Díaz and Shigeo Mitsunari and Eiji Okamoto and Francisco Rodríguez-Henríquez and Tadanori Teruya, 2010\ Jean-Luc Beuchat and Jorge Enrique González Díaz and Shigeo Mitsunari and Eiji Okamoto and Francisco Rodríguez-Henríquez and Tadanori Teruya, 2010\
https://eprint.iacr.org/2010/354 https://eprint.iacr.org/2010/354
- Faster Explicit Formulas for Computing Pairings over Ordinary Curves\
Diego F. Aranha and Koray Karabina and Patrick Longa and Catherine H. Gebotys and Julio López, 2010\
https://eprint.iacr.org/2010/526.pdf\
https://www.iacr.org/archive/eurocrypt2011/66320047/66320047.pdf
- Efficient Implementation of Bilinear Pairings on ARM Processors\
  Gurleen Grewal, Reza Azarderakhsh, Patrick Longa, Shi Hu, and David Jao, 2012\
  https://eprint.iacr.org/2012/408.pdf
- Choosing and generating parameters for low level pairing implementation on BN curves\ - Choosing and generating parameters for low level pairing implementation on BN curves\
Sylvain Duquesne and Nadia El Mrabet and Safia Haloui and Franck Rondepierre, 2015\ Sylvain Duquesne and Nadia El Mrabet and Safia Haloui and Franck Rondepierre, 2015\
https://eprint.iacr.org/2015/1212 https://eprint.iacr.org/2015/1212

View File

@ -47,11 +47,11 @@ template c1*(a: ExtensionField): auto =
template c2*(a: CubicExt): auto = template c2*(a: CubicExt): auto =
a.coords[2] a.coords[2]
template `c0=`*(a: ExtensionField, v: auto) = template `c0=`*(a: var ExtensionField, v: auto) =
a.coords[0] = v a.coords[0] = v
template `c1=`*(a: ExtensionField, v: auto) = template `c1=`*(a: var ExtensionField, v: auto) =
a.coords[1] = v a.coords[1] = v
template `c2=`*(a: CubicExt, v: auto) = template `c2=`*(a: var CubicExt, v: auto) =
a.coords[2] = v a.coords[2] = v
template C*(E: type ExtensionField): Curve = template C*(E: type ExtensionField): Curve =
@ -222,88 +222,343 @@ func csub*(a: var ExtensionField, b: ExtensionField, ctl: SecretBool) =
func `*=`*(a: var ExtensionField, b: static int) = func `*=`*(a: var ExtensionField, b: static int) =
## Multiplication by a small integer known at compile-time ## Multiplication by a small integer known at compile-time
for i in 0 ..< a.coords.len:
const negate = b < 0 a.coords[i] *= b
const b = if negate: -b
else: b
when negate:
a.neg(a)
when b == 0:
a.setZero()
elif b == 1:
return
elif b == 2:
a.double()
elif b == 3:
let t1 = a
a.double()
a += t1
elif b == 4:
a.double()
a.double()
elif b == 5:
let t1 = a
a.double()
a.double()
a += t1
elif b == 6:
a.double()
let t2 = a
a.double() # 4
a += t2
elif b == 7:
let t1 = a
a.double()
let t2 = a
a.double() # 4
a += t2
a += t1
elif b == 8:
a.double()
a.double()
a.double()
elif b == 9:
let t1 = a
a.double()
a.double()
a.double() # 8
a += t1
elif b == 10:
a.double()
let t2 = a
a.double()
a.double() # 8
a += t2
elif b == 11:
let t1 = a
a.double()
let t2 = a
a.double()
a.double() # 8
a += t2
a += t1
elif b == 12:
a.double()
a.double() # 4
let t4 = a
a.double() # 8
a += t4
else:
{.error: "Multiplication by this small int not implemented".}
func prod*(r: var ExtensionField, a: ExtensionField, b: static int) = func prod*(r: var ExtensionField, a: ExtensionField, b: static int) =
## Multiplication by a small integer known at compile-time ## Multiplication by a small integer known at compile-time
const negate = b < 0 r = a
const b = if negate: -b
else: b
when negate:
r.neg(a)
else:
r = a
r *= b r *= b
{.pop.} # inline {.pop.} # inline
# ############################################################
# #
# Lazy reduced extension fields #
# #
# ############################################################
type
  QuadraticExt2x[F] = object
    ## Quadratic Extension field for lazy reduced fields
    # Each coordinate holds an unreduced double-precision value of `F`.
    coords: array[2, F]
  CubicExt2x[F] = object
    ## Cubic Extension field for lazy reduced fields
    coords: array[3, F]
  # Typeclass covering both lazy-reduced (double-precision) extension shapes
  ExtensionField2x[F] = QuadraticExt2x[F] or CubicExt2x[F]
template doublePrec(T: type ExtensionField): type =
  ## Maps a single-precision extension field type to its
  ## double-precision (lazy-reduced) counterpart,
  ## e.g. Fp2 -> Fp2Dbl, Fp4 -> Fp4Dbl, Fp6 -> Fp6Dbl.
  # For now naive unrolling, recursive template don't match
  # and I don't want to deal with types in macros
  when T is QuadraticExt:
    when T.F is QuadraticExt: # Fp4Dbl
      QuadraticExt2x[QuadraticExt2x[doublePrec(T.F.F)]]
    elif T.F is Fp:           # Fp2Dbl
      QuadraticExt2x[doublePrec(T.F)]
  elif T is CubicExt:
    when T.F is QuadraticExt: # Fp6Dbl
      CubicExt2x[QuadraticExt2x[doublePrec(T.F.F)]]
func has1extraBit(F: type Fp): bool =
  ## True when the modulus representation leaves at least 1 spare bit,
  ## allowing lazily-reduced results in [0, 2p).
  ## We construct extensions only on Fp (and not Fr).
  F.getSpareBits() >= 1

func has2extraBits(F: type Fp): bool =
  ## True when the modulus representation leaves at least 2 spare bits,
  ## allowing lazily-reduced results in [0, 4p).
  ## We construct extensions only on Fp (and not Fr).
  F.getSpareBits() >= 2

func has1extraBit(E: type ExtensionField): bool =
  ## Spare-bit query lifted to towered extensions:
  ## checks the underlying base field Fp of the curve.
  Fp[E.F.C].getSpareBits() >= 1

func has2extraBits(E: type ExtensionField): bool =
  ## Spare-bits query lifted to towered extensions:
  ## checks the underlying base field Fp of the curve.
  Fp[E.F.C].getSpareBits() >= 2
# Coordinate and curve accessors mirroring the single-precision
# ExtensionField API (C, c0/c1/c2 getters and setters).
template C(E: type ExtensionField2x): Curve =
  E.F.C

template c0(a: ExtensionField2x): auto =
  a.coords[0]
template c1(a: ExtensionField2x): auto =
  a.coords[1]
template c2(a: CubicExt2x): auto =
  # Third coordinate only exists on cubic extensions
  a.coords[2]

template `c0=`(a: var ExtensionField2x, v: auto) =
  a.coords[0] = v
template `c1=`(a: var ExtensionField2x, v: auto) =
  a.coords[1] = v
template `c2=`(a: var CubicExt2x, v: auto) =
  a.coords[2] = v
# Initialization
# -------------------------------------------------------------------
func setZero*(a: var ExtensionField2x) =
  ## Set every coordinate of ``a`` to 0,
  ## i.e. ``a`` becomes 0 in the lazy-reduced extension field.
  staticFor k, 0, a.coords.len:
    a.coords[k].setZero()
# Abelian group
# -------------------------------------------------------------------
func sumUnr(r: var ExtensionField, a, b: ExtensionField) =
  ## Coordinate-wise single-precision addition without
  ## conditional reduction: r <- a + b
  staticFor k, 0, a.coords.len:
    r.coords[k].sumUnr(a.coords[k], b.coords[k])

func diff2xUnr(r: var ExtensionField2x, a, b: ExtensionField2x) =
  ## Coordinate-wise double-precision subtraction
  ## without reduction: r <- a - b
  staticFor k, 0, a.coords.len:
    r.coords[k].diff2xUnr(a.coords[k], b.coords[k])

func diff2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) =
  ## Coordinate-wise double-precision modular subtraction: r <- a - b
  staticFor k, 0, a.coords.len:
    r.coords[k].diff2xMod(a.coords[k], b.coords[k])

func sum2xUnr(r: var ExtensionField2x, a, b: ExtensionField2x) =
  ## Coordinate-wise double-precision addition
  ## without reduction: r <- a + b
  staticFor k, 0, a.coords.len:
    r.coords[k].sum2xUnr(a.coords[k], b.coords[k])

func sum2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) =
  ## Coordinate-wise double-precision modular addition: r <- a + b
  staticFor k, 0, a.coords.len:
    r.coords[k].sum2xMod(a.coords[k], b.coords[k])
func neg2xMod(r: var ExtensionField2x, a: ExtensionField2x) =
  ## Double-precision modular negation: r <- -a
  # Bug fix: the per-coordinate call previously passed `b.coords[i]`,
  # but no parameter `b` exists — negation is unary. This only compiles
  # until the generic is instantiated, which is why it went unnoticed.
  staticFor i, 0, a.coords.len:
    r.coords[i].neg2xMod(a.coords[i])
# Reductions
# -------------------------------------------------------------------
func redc2x(r: var ExtensionField, a: ExtensionField2x) =
  ## Reduction
  ## Montgomery-reduce each double-precision coordinate of `a`
  ## back into the single-precision extension field `r`.
  staticFor i, 0, a.coords.len:
    r.coords[i].redc2x(a.coords[i])
# Multiplication by a small integer known at compile-time
# -------------------------------------------------------------------
func prod2x(r: var ExtensionField2x, a: ExtensionField2x, b: static int) =
  ## Multiplication by a small integer known at compile-time,
  ## applied coordinate-wise on double-precision values.
  # Consistency fix: every sibling coordinate-wise operation in this
  # section uses `staticFor` for guaranteed compile-time unrolling;
  # this one used a plain runtime `for` loop.
  staticFor i, 0, a.coords.len:
    r.coords[i].prod2x(a.coords[i], b)
# NonResidue
# ----------------------------------------------------------------------
func prod2x(r: var FpDbl, a: FpDbl, _: type NonResidue){.inline.} =
  ## Multiply an element of 𝔽p by the quadratic non-residue
  ## chosen to construct 𝔽p2
  # Complex extensions (non-residue -1) must use their dedicated
  # specialization, hence the static rejection below.
  static: doAssert FpDbl.C.getNonResidueFp() != -1, "𝔽p2 should be specialized for complex extension"
  r.prod2x(a, FpDbl.C.getNonResidueFp())
func prod2x[C: static Curve](
       r {.noalias.}: var QuadraticExt2x[FpDbl[C]],
       a {.noalias.}: QuadraticExt2x[FpDbl[C]],
       _: type NonResidue) {.inline.} =
  ## Multiplication by non-residue
  ## ! no aliasing!
  ## Multiplies a double-precision 𝔽p2 element by the non-residue
  ## ξ = U + V x used to build the next tower level.
  const complex = C.getNonResidueFp() == -1
  const U = C.getNonResidueFp2()[0]
  const V = C.getNonResidueFp2()[1]
  const Beta {.used.} = C.getNonResidueFp()

  when complex and U == 1 and V == 1:
    # ξ = 1 + 𝑖 with 𝑖² = -1:
    # (c0 + c1 𝑖)(1 + 𝑖) = (c0 - c1) + (c0 + c1) 𝑖
    r.c0.diff2xMod(a.c0, a.c1)
    r.c1.sum2xMod(a.c0, a.c1)
  else:
    # Case:
    # - BN254_Snarks, QNR_Fp: -1, SNR_Fp2: 9+1𝑖 (𝑖 = √-1)
    # - BLS12_377, QNR_Fp: -5, SNR_Fp2: 0+1j (j = √-5)
    # - BW6_761, SNR_Fp: -4, CNR_Fp2: 0+1j (j = √-4)
    when U == 0:
      # mul_sparse_by_0v
      # r0 = β a1 v
      # r1 = a0 v
      # r and a don't alias, we use `r` as a temp location
      r.c1.prod2x(a.c1, V)
      r.c0.prod2x(r.c1, NonResidue)
      r.c1.prod2x(a.c0, V)
    else:
      # ξ = u + v x
      # and x² = β
      #
      # (c0 + c1 x) (u + v x) => u c0 + v c1 x² + (v c0 + u c1)x
      # => u c0 + β v c1 + (v c0 + u c1) x
      var t {.noInit.}: FpDbl[C]
      r.c0.prod2x(a.c0, U)
      when V == 1 and Beta == -1: # Case BN254_Snarks
        r.c0.diff2xMod(r.c0, a.c1) # r0 = u c0 + β v c1
      else:
        {.error: "Unimplemented".}
      r.c1.prod2x(a.c0, V)
      t.prod2x(a.c1, U)
      r.c1.sum2xMod(r.c1, t) # r1 = v c0 + u c1
# ############################################################
# #
# Quadratic extensions - Lazy Reductions #
# #
# ############################################################
# Forward declarations
# ----------------------------------------------------------------------
func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt)
func square2x(r: var QuadraticExt2x, a: QuadraticExt)
# Commutative ring implementation for complex quadratic extension fields
# ----------------------------------------------------------------------
func prod2x_complex(r: var QuadraticExt2x, a, b: QuadraticExt) =
  ## Double-precision unreduced complex multiplication
  ## Karatsuba over 𝔽p[𝑖] with 𝑖² = -1:
  ##   r0 = a0 b0 - a1 b1
  ##   r1 = (a0 + a1)(b0 + b1) - a0 b0 - a1 b1
  # r and a or b cannot alias
  mixin fromComplexExtension
  static: doAssert a.fromComplexExtension()

  var D {.noInit.}: typeof(r.c0)
  var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0)

  r.c0.prod2x(a.c0, b.c0)       # r0 = a0 b0
  D.prod2x(a.c1, b.c1)          # d = a1 b1
  when QuadraticExt.has1extraBit():
    # A spare bit lets the sums skip the conditional reduction
    t0.sumUnr(a.c0, a.c1)
    t1.sumUnr(b.c0, b.c1)
  else:
    t0.sum(a.c0, a.c1)
    t1.sum(b.c0, b.c1)
  r.c1.prod2x(t0, t1)           # r1 = (b0 + b1)(a0 + a1)
  when QuadraticExt.has1extraBit():
    # Unreduced subtractions are safe with a spare bit
    r.c1.diff2xUnr(r.c1, r.c0)  # r1 = (b0 + b1)(a0 + a1) - a0 b0
    r.c1.diff2xUnr(r.c1, D)     # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1
  else:
    r.c1.diff2xMod(r.c1, r.c0)
    r.c1.diff2xMod(r.c1, D)
  r.c0.diff2xMod(r.c0, D)       # r0 = a0 b0 - a1 b1
func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) =
  ## Double-precision unreduced complex squaring
  ## Uses the identity over 𝔽p[𝑖] (𝑖² = -1):
  ##   r0 = a0² - a1² = (a0 + a1)(a0 - a1)
  ##   r1 = 2 a0 a1
  mixin fromComplexExtension
  static: doAssert a.fromComplexExtension()

  var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0)

  # Require 2 extra bits
  when QuadraticExt.has2extraBits():
    t0.sumUnr(a.c1, a.c1)   # t0 = 2 a1 (unreduced)
    t1.sum(a.c0, a.c1)
  else:
    t0.double(a.c1)         # t0 = 2 a1 (reduced)
    t1.sum(a.c0, a.c1)

  r.c1.prod2x(t0, a.c0)     # r1 = 2a0a1
  # t0 can be reused now: it becomes (a0 - a1)
  t0.diff(a.c0, a.c1)
  r.c0.prod2x(t0, t1)       # r0 = (a0 + a1)(a0 - a1)
# Commutative ring implementation for generic quadratic extension fields
# ----------------------------------------------------------------------
#
# Some sparse functions, reconstruct a Fp4 from disjoint pieces
# to limit copies, we provide versions with disjoint elements
# prod2x_disjoint:
# - 2 products in mul_sparse_by_line_xyz000 (Fp4)
# - 2 products in mul_sparse_by_line_xy000z (Fp4)
# - mul_by_line_xy0 in mul_sparse_by_line_xy00z0 (Fp6)
#
# square2x_disjoint:
# - cyclotomic square in Fp2 -> Fp6 -> Fp12 towering
# needs Fp4 as special case
func prod2x_disjoint[Fdbl, F](
       r: var QuadraticExt2x[FDbl],
       a: QuadraticExt[F],
       b0, b1: F) =
  ## Return a * (b0, b1) in r
  ##
  ## Karatsuba product with lazy (double-precision) reduction:
  ##   r0 = a0 b0 + β a1 b1
  ##   r1 = (a0 + a1)(b0 + b1) - a0 b0 - a1 b1
  static: doAssert Fdbl is doublePrec(F)

  var V0 {.noInit.}, V1 {.noInit.}: typeof(r.c0) # Double-precision
  var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) # Single-width

  V0.prod2x(a.c0, b0)          # v0 = a0b0
  V1.prod2x(a.c1, b1)          # v1 = a1b1
  when F.has1extraBit():
    # A spare bit lets the sums skip the conditional reduction
    t0.sumUnr(a.c0, a.c1)
    t1.sumUnr(b0, b1)
  else:
    t0.sum(a.c0, a.c1)
    t1.sum(b0, b1)
  r.c1.prod2x(t0, t1)          # r1 = (a0 + a1)(b0 + b1)
  # Dead-code fix: the previous `when F.has1extraBit()` here had
  # byte-identical branches — both regimes use modular subtraction.
  r.c1.diff2xMod(r.c1, V0)     # r1 = (a0 + a1)(b0 + b1) - a0b0
  r.c1.diff2xMod(r.c1, V1)     # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1

  r.c0.prod2x(V1, NonResidue)  # r0 = β a1 b1
  r.c0.sum2xMod(r.c0, V0)      # r0 = a0 b0 + β a1 b1
func square2x_disjoint[Fdbl, F](
       r: var QuadraticExt2x[FDbl],
       a0, a1: F) =
  ## Return (a0, a1)² in r
  ## Generic quadratic squaring with lazy reduction, taking the
  ## coordinates disjointly so sparse callers avoid copies.
  var V0 {.noInit.}, V1 {.noInit.}: typeof(r.c0) # Double-precision
  var t {.noInit.}: F # Single-width

  # TODO: which is the best formulation? 3 squarings or 2 Mul?
  # It seems like the higher the tower the better squarings are
  # So for Fp12 = 2xFp6, prefer squarings.
  V0.square2x(a0)
  V1.square2x(a1)
  t.sum(a0, a1)

  # r0 = a0² + β a1² (option 1) <=> (a0 + a1)(a0 + β a1) - β a0a1 - a0a1 (option 2)
  r.c0.prod2x(V1, NonResidue)
  r.c0.sum2xMod(r.c0, V0)

  # r1 = 2 a0 a1 (option 1) = (a0 + a1)² - a0² - a1² (option 2)
  r.c1.square2x(t)
  r.c1.diff2xMod(r.c1, V0)
  r.c1.diff2xMod(r.c1, V1)
# Dispatch
# ----------------------------------------------------------------------
func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) =
  ## Double-precision multiplication dispatcher:
  ## picks the complex-extension fast path when the tower is built
  ## over 𝔽p[𝑖], otherwise the generic disjoint path.
  mixin fromComplexExtension
  when not a.fromComplexExtension():
    r.prod2x_disjoint(a, b.c0, b.c1)
  else:
    r.prod2x_complex(a, b)

func square2x(r: var QuadraticExt2x, a: QuadraticExt) =
  ## Double-precision squaring dispatcher:
  ## picks the complex-extension fast path when the tower is built
  ## over 𝔽p[𝑖], otherwise the generic disjoint path.
  mixin fromComplexExtension
  when not a.fromComplexExtension():
    r.square2x_disjoint(a.c0, a.c1)
  else:
    r.square2x_complex(a)
# ############################################################
# #
# Cubic extensions - Lazy Reductions #
# #
# ############################################################
# ############################################################ # ############################################################
# # # #
# Quadratic extensions # # Quadratic extensions #
@ -386,60 +641,18 @@ func prod_complex(r: var QuadraticExt, a, b: QuadraticExt) =
mixin fromComplexExtension mixin fromComplexExtension
static: doAssert r.fromComplexExtension() static: doAssert r.fromComplexExtension()
# TODO: GCC is adding an unexplainable 30 cycles tax to this function (~10% slow down) var a0b0 {.noInit.}, a1b1 {.noInit.}: typeof(r.c0)
# for seemingly no reason a0b0.prod(a.c0, b.c0) # [1 Mul]
a1b1.prod(a.c1, b.c1) # [2 Mul]
when false: # Single-width implementation - BLS12-381 r.c0.sum(a.c0, a.c1) # r0 = (a0 + a1) # [2 Mul, 1 Add]
# Clang 348 cycles on i9-9980XE @3.9 GHz r.c1.sum(b.c0, b.c1) # r1 = (b0 + b1) # [2 Mul, 2 Add]
var a0b0 {.noInit.}, a1b1 {.noInit.}: typeof(r.c0) # aliasing: a and b unneeded now
a0b0.prod(a.c0, b.c0) # [1 Mul] r.c1 *= r.c0 # r1 = (b0 + b1)(a0 + a1) # [3 Mul, 2 Add] - 𝔽p temporary
a1b1.prod(a.c1, b.c1) # [2 Mul]
r.c0.sum(a.c0, a.c1) # r0 = (a0 + a1) # [2 Mul, 1 Add] r.c0.diff(a0b0, a1b1) # r0 = a0 b0 - a1 b1 # [3 Mul, 2 Add, 1 Sub]
r.c1.sum(b.c0, b.c1) # r1 = (b0 + b1) # [2 Mul, 2 Add] r.c1 -= a0b0 # r1 = (b0 + b1)(a0 + a1) - a0b0 # [3 Mul, 2 Add, 2 Sub]
# aliasing: a and b unneeded now r.c1 -= a1b1 # r1 = (b0 + b1)(a0 + a1) - a0b0 - a1b1 # [3 Mul, 2 Add, 3 Sub]
r.c1 *= r.c0 # r1 = (b0 + b1)(a0 + a1) # [3 Mul, 2 Add] - 𝔽p temporary
r.c0.diff(a0b0, a1b1) # r0 = a0 b0 - a1 b1 # [3 Mul, 2 Add, 1 Sub]
r.c1 -= a0b0 # r1 = (b0 + b1)(a0 + a1) - a0b0 # [3 Mul, 2 Add, 2 Sub]
r.c1 -= a1b1 # r1 = (b0 + b1)(a0 + a1) - a0b0 - a1b1 # [3 Mul, 2 Add, 3 Sub]
else: # Double-width implementation with lazy reduction
# Clang 341 cycles on i9-9980XE @3.9 GHz
var a0b0 {.noInit.}, a1b1 {.noInit.}: doubleWidth(typeof(r.c0))
var d {.noInit.}: doubleWidth(typeof(r.c0))
const msbSet = r.c0.typeof.canUseNoCarryMontyMul()
a0b0.mulNoReduce(a.c0, b.c0) # 44 cycles - cumul 44
a1b1.mulNoReduce(a.c1, b.c1) # 44 cycles - cumul 88
when msbSet:
r.c0.sum(a.c0, a.c1)
r.c1.sum(b.c0, b.c1)
else:
r.c0.sumNoReduce(a.c0, a.c1) # 5 cycles - cumul 93
r.c1.sumNoReduce(b.c0, b.c1) # 5 cycles - cumul 98
# aliasing: a and b unneeded now
d.mulNoReduce(r.c0, r.c1) # 44 cycles - cumul 142
when msbSet:
d -= a0b0
d -= a1b1
else:
d.diffNoReduce(d, a0b0) # 11 cycles - cumul 153
d.diffNoReduce(d, a1b1) # 11 cycles - cumul 164
a0b0.diff(a0b0, a1b1) # 19 cycles - cumul 183
r.c0.reduce(a0b0) # 50 cycles - cumul 233
r.c1.reduce(d) # 50 cycles - cumul 288
# Single-width [3 Mul, 2 Add, 3 Sub]
# 3*88 + 2*14 + 3*14 = 334 theoretical cycles
# 348 measured
# Double-Width
# 288 theoretical cycles
# 329 measured
# Unexplained 40 cycles diff between theo and measured
# and unexplained 30 cycles between Clang and GCC
# - Function calls?
# - push/pop stack?
func mul_sparse_complex_by_0y( func mul_sparse_complex_by_0y(
r: var QuadraticExt, a: QuadraticExt, r: var QuadraticExt, a: QuadraticExt,
@ -497,31 +710,67 @@ func square_generic(r: var QuadraticExt, a: QuadraticExt) =
# #
# Alternative 2: # Alternative 2:
# c0² + β c1² <=> (c0 + c1)(c0 + β c1) - β c0c1 - c0c1 # c0² + β c1² <=> (c0 + c1)(c0 + β c1) - β c0c1 - c0c1
mixin prod #
var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) # This gives us 2 Mul and 2 mul-nonresidue (which is costly for BN254_Snarks)
#
# We can also reframe the 2nd term with only squarings
# which might be significantly faster on higher tower degrees
#
# 2 c0 c1 <=> (a0 + a1)² - a0² - a1²
#
# This gives us 3 Sqr and 1 Mul-non-residue
const costlyMul = block:
# No shortcutting in the VM :/
when a.c0 is ExtensionField:
when a.c0.c0 is ExtensionField:
true
else:
false
else:
false
# v1 <- (c0 + β c1) when QuadraticExt.C == BN254_Snarks or costlyMul:
v1.prod(a.c1, NonResidue) var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0)
v1 += a.c0 v0.square(a.c0)
v1.square(a.c1)
# v0 <- (c0 + c1)(c0 + β c1) # Aliasing: a unneeded now
v0.sum(a.c0, a.c1) r.c1.sum(a.c0, a.c1)
v0 *= v1
# v1 <- c0 c1 # r0 = c0² + β c1²
v1.prod(a.c0, a.c1) r.c0.prod(v1, NonResidue)
r.c0 += v0
# aliasing: a unneeded now # r1 = (a0 + a1)² - a0² - a1²
r.c1.square()
r.c1 -= v0
r.c1 -= v1
# r0 = (c0 + c1)(c0 + β c1) - c0c1 else:
v0 -= v1 var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0)
# r1 = 2 c0c1 # v1 <- (c0 + β c1)
r.c1.double(v1) v1.prod(a.c1, NonResidue)
v1 += a.c0
# r0 = (c0 + c1)(c0 + β c1) - c0c1 - β c0c1 # v0 <- (c0 + c1)(c0 + β c1)
v1 *= NonResidue v0.sum(a.c0, a.c1)
r.c0.diff(v0, v1) v0 *= v1
# v1 <- c0 c1
v1.prod(a.c0, a.c1)
# aliasing: a unneeded now
# r0 = (c0 + c1)(c0 + β c1) - c0c1
v0 -= v1
# r1 = 2 c0c1
r.c1.double(v1)
# r0 = (c0 + c1)(c0 + β c1) - c0c1 - β c0c1
v1 *= NonResidue
r.c0.diff(v0, v1)
func prod_generic(r: var QuadraticExt, a, b: QuadraticExt) = func prod_generic(r: var QuadraticExt, a, b: QuadraticExt) =
## Returns r = a * b ## Returns r = a * b
@ -529,7 +778,6 @@ func prod_generic(r: var QuadraticExt, a, b: QuadraticExt) =
# #
# r0 = a0 b0 + β a1 b1 # r0 = a0 b0 + β a1 b1
# r1 = (a0 + a1) (b0 + b1) - a0 b0 - a1 b1 (Karatsuba) # r1 = (a0 + a1) (b0 + b1) - a0 b0 - a1 b1 (Karatsuba)
mixin prod
var v0 {.noInit.}, v1 {.noInit.}, v2 {.noInit.}: typeof(r.c0) var v0 {.noInit.}, v1 {.noInit.}, v2 {.noInit.}: typeof(r.c0)
# v2 <- (a0 + a1)(b0 + b1) # v2 <- (a0 + a1)(b0 + b1)
@ -564,7 +812,6 @@ func mul_sparse_generic_by_x0(r: var QuadraticExt, a, sparseB: QuadraticExt) =
# #
# r0 = a0 b0 # r0 = a0 b0
# r1 = (a0 + a1) b0 - a0 b0 = a1 b0 # r1 = (a0 + a1) b0 - a0 b0 = a1 b0
mixin prod
template b(): untyped = sparseB template b(): untyped = sparseB
r.c0.prod(a.c0, b.c0) r.c0.prod(a.c0, b.c0)
@ -658,21 +905,52 @@ func invImpl(r: var QuadraticExt, a: QuadraticExt) =
# Exported quadratic symbols # Exported quadratic symbols
# ------------------------------------------------------------------- # -------------------------------------------------------------------
{.push inline.}
func square*(r: var QuadraticExt, a: QuadraticExt) = func square*(r: var QuadraticExt, a: QuadraticExt) =
mixin fromComplexExtension mixin fromComplexExtension
when r.fromComplexExtension(): when r.fromComplexExtension():
r.square_complex(a) when true:
r.square_complex(a)
else: # slower
var d {.noInit.}: doublePrec(typeof(r))
d.square2x_complex(a)
r.c0.redc2x(d.c0)
r.c1.redc2x(d.c1)
else: else:
r.square_generic(a) when true: # r.typeof.F.C in {BLS12_377, BW6_761}:
# BW6-761 requires too many registers for Dbl width path
r.square_generic(a)
else:
# TODO understand why Fp4[BLS12_377]
# is so slow in the branch
# TODO:
# - On Fp4, we can have a.c0.c0 off by p
# a reduction is missing
var d {.noInit.}: doublePrec(typeof(r))
d.square2x_disjoint(a.c0, a.c1)
r.c0.redc2x(d.c0)
r.c1.redc2x(d.c1)
func prod*(r: var QuadraticExt, a, b: QuadraticExt) = func prod*(r: var QuadraticExt, a, b: QuadraticExt) =
mixin fromComplexExtension mixin fromComplexExtension
when r.fromComplexExtension(): when r.fromComplexExtension():
r.prod_complex(a, b) when false:
r.prod_complex(a, b)
else: # faster
var d {.noInit.}: doublePrec(typeof(r))
d.prod2x_complex(a, b)
r.c0.redc2x(d.c0)
r.c1.redc2x(d.c1)
else: else:
r.prod_generic(a, b) when r.typeof.F.C == BW6_761 or typeof(r.c0) is Fp:
# BW6-761 requires too many registers for Dbl width path
r.prod_generic(a, b)
else:
var d {.noInit.}: doublePrec(typeof(r))
d.prod2x_disjoint(a, b.c0, b.c1)
r.c0.redc2x(d.c0)
r.c1.redc2x(d.c1)
{.push inline.}
func inv*(r: var QuadraticExt, a: QuadraticExt) = func inv*(r: var QuadraticExt, a: QuadraticExt) =
## Compute the multiplicative inverse of ``a`` ## Compute the multiplicative inverse of ``a``
@ -765,7 +1043,6 @@ func mul_sparse_by_x0*(a: var QuadraticExt, sparseB: QuadraticExt) =
func square_Chung_Hasan_SQR2(r: var CubicExt, a: CubicExt) {.used.}= func square_Chung_Hasan_SQR2(r: var CubicExt, a: CubicExt) {.used.}=
## Returns r = a² ## Returns r = a²
mixin prod, square, sum
var s0{.noInit.}, m01{.noInit.}, m12{.noInit.}: typeof(r.c0) var s0{.noInit.}, m01{.noInit.}, m12{.noInit.}: typeof(r.c0)
# precomputations that use a # precomputations that use a
@ -801,7 +1078,6 @@ func square_Chung_Hasan_SQR2(r: var CubicExt, a: CubicExt) {.used.}=
func square_Chung_Hasan_SQR3(r: var CubicExt, a: CubicExt) = func square_Chung_Hasan_SQR3(r: var CubicExt, a: CubicExt) =
## Returns r = a² ## Returns r = a²
mixin prod, square, sum
var s0{.noInit.}, t{.noInit.}, m12{.noInit.}: typeof(r.c0) var s0{.noInit.}, t{.noInit.}, m12{.noInit.}: typeof(r.c0)
# s₀ = (a₀ + a₁ + a₂)² # s₀ = (a₀ + a₁ + a₂)²

View File

@ -116,17 +116,24 @@ func prod*(r: var Fp2, a: Fp2, _: type NonResidue) {.inline.} =
# BLS12_377 and BW6_761, use small addition chain # BLS12_377 and BW6_761, use small addition chain
r.mul_sparse_by_0y(a, v) r.mul_sparse_by_0y(a, v)
else: else:
# BN254_Snarks, u = 9 # BN254_Snarks, u = 9, v = 1, β = -1
# Full 𝔽p2 multiplication is cheaper than addition chains # Even with u = 9, the 2x9 addition chains (8 additions total)
# for u*c0 and u*c1 # are cheaper than full Fp2 multiplication
static: var t {.noInit.}: typeof(a.c0)
doAssert u >= 0 and uint64(u) <= uint64(high(BaseType))
doAssert v >= 0 and uint64(v) <= uint64(high(BaseType)) t.prod(a.c0, u)
# TODO: compile-time when v == 1 and Beta == -1: # Case BN254_Snarks
var NR {.noInit.}: Fp2 t -= a.c1 # r0 = u c0 + β v c1
NR.c0.fromUint(uint u) else:
NR.c1.fromUint(uint v) {.error: "Unimplemented".}
r.prod(a, NR)
r.c1.prod(a.c1, u)
when v == 1: # r1 = v c0 + u c1
r.c1 += a.c0
# aliasing: a.c0 is unused
r.c0 = t
else:
{.error: "Unimplemented".}
func `*=`*(a: var Fp2, _: type NonResidue) {.inline.} = func `*=`*(a: var Fp2, _: type NonResidue) {.inline.} =
## Multiply an element of 𝔽p2 by the non-residue ## Multiply an element of 𝔽p2 by the non-residue

View File

@ -26,10 +26,10 @@ The optimizations can be of algebraic, algorithmic or "implementation details" n
- [x] x86: MULX, ADCX, ADOX instructions - [x] x86: MULX, ADCX, ADOX instructions
- [x] Fused Multiply + Shift-right by word (for Barrett Reduction and approximating multiplication by fractional constant) - [x] Fused Multiply + Shift-right by word (for Barrett Reduction and approximating multiplication by fractional constant)
- Squaring - Squaring
- [ ] Dedicated squaring functions - [x] Dedicated squaring functions
- [ ] int128 - [x] int128
- [ ] loop unrolling - [ ] loop unrolling
- [ ] x86: Full Assembly implementation - [x] x86: Full Assembly implementation
- [ ] x86: MULX, ADCX, ADOX instructions - [ ] x86: MULX, ADCX, ADOX instructions
## Finite Fields & Modular Arithmetic ## Finite Fields & Modular Arithmetic
@ -107,13 +107,13 @@ The optimizations can be of algebraic, algorithmic or "implementation details" n
## Extension Fields ## Extension Fields
- [ ] Lazy reduction via double-width base fields - [x] Lazy reduction via double-precision base fields
- [x] Sparse multiplication - [x] Sparse multiplication
- Fp2 - Fp2
- [x] complex multiplication - [x] complex multiplication
- [x] complex squaring - [x] complex squaring
- [x] sqrt via the constant-time complex method (Adj et al) - [x] sqrt via the constant-time complex method (Adj et al)
- [ ] sqrt using addition chain - [x] sqrt using addition chain
- [x] fused complex method sqrt by rotating in complex plane - [x] fused complex method sqrt by rotating in complex plane
- Cubic extension fields - Cubic extension fields
- [x] Toom-Cook polynomial multiplication (Chung-Hasan) - [x] Toom-Cook polynomial multiplication (Chung-Hasan)

View File

@ -146,7 +146,7 @@ func random_unsafe(rng: var RngState, a: var FF) =
# Note: a simple modulo will be biaised but it's simple and "fast" # Note: a simple modulo will be biaised but it's simple and "fast"
reduced.reduce(unreduced, FF.fieldMod()) reduced.reduce(unreduced, FF.fieldMod())
a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits())
func random_unsafe(rng: var RngState, a: var ExtensionField) = func random_unsafe(rng: var RngState, a: var ExtensionField) =
## Recursively initialize an extension Field element ## Recursively initialize an extension Field element
@ -177,7 +177,7 @@ func random_highHammingWeight(rng: var RngState, a: var FF) =
# Note: a simple modulo will be biaised but it's simple and "fast" # Note: a simple modulo will be biaised but it's simple and "fast"
reduced.reduce(unreduced, FF.fieldMod()) reduced.reduce(unreduced, FF.fieldMod())
a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits())
func random_highHammingWeight(rng: var RngState, a: var ExtensionField) = func random_highHammingWeight(rng: var RngState, a: var ExtensionField) =
## Recursively initialize an extension Field element ## Recursively initialize an extension Field element
@ -222,7 +222,7 @@ func random_long01Seq(rng: var RngState, a: var FF) =
# Note: a simple modulo will be biaised but it's simple and "fast" # Note: a simple modulo will be biaised but it's simple and "fast"
reduced.reduce(unreduced, FF.fieldMod()) reduced.reduce(unreduced, FF.fieldMod())
a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits())
func random_long01Seq(rng: var RngState, a: var ExtensionField) = func random_long01Seq(rng: var RngState, a: var ExtensionField) =
## Recursively initialize an extension Field element ## Recursively initialize an extension Field element

View File

@ -0,0 +1,228 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
# Standard library
std/[unittest, times],
# Internal
../constantine/arithmetic,
../constantine/io/[io_bigints, io_fields],
../constantine/config/[curves, common, type_bigint],
# Test utilities
../helpers/prng_unsafe
const Iters = 24
var rng: RngState
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
rng.seed(seed)
echo "\n------------------------------------------------------\n"
echo "test_finite_fields_double_precision xoshiro512** seed: ", seed
template addsubnegTest(rng_gen: untyped): untyped =
proc `addsubneg _ rng_gen`(C: static Curve) =
# Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg
# (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1
let aFp = rng_gen(rng, Fp[C])
let bFp = rng_gen(rng, Fp[C])
var accumFp {.noInit.}: Fp[C]
var OneFp {.noInit.}: Fp[C]
var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: FpDbl[C]
OneFp.setOne()
One.prod2x(OneFp, OneFp)
a.prod2x(aFp, OneFp)
b.prod2x(bFp, OneFp)
block: # sanity check
var t: Fp[C]
t.redc2x(One)
doAssert bool t.isOne()
a2.sum2xMod(a, a)
na.neg2xMod(a)
block: # sanity check
var t0, t1: Fp[C]
t0.redc2x(na)
t1.neg(aFp)
doAssert bool(t0 == t1),
"Beware, if the hex are the same, it means the outputs are the same (mod p),\n" &
"but one might not be completely reduced\n" &
" t0: " & t0.toHex() & "\n" &
" t1: " & t1.toHex() & "\n"
block: # sanity check
var t0, t1: Fp[C]
t0.redc2x(a2)
t1.double(aFp)
doAssert bool(t0 == t1),
"Beware, if the hex are the same, it means the outputs are the same (mod p),\n" &
"but one might not be completely reduced\n" &
" t0: " & t0.toHex() & "\n" &
" t1: " & t1.toHex() & "\n"
b2.sum2xMod(b, b)
nb.neg2xMod(b)
accum.diff2xMod(One, na)
accum.diff2xMod(accum, b)
accum.sum2xMod(accum, na)
accum.diff2xMod(accum, a2)
var t{.noInit.}: FpDbl[C]
t.sum2xMod(a2, b2)
t.sum2xMod(t, nb)
accum.sum2xMod(accum, t)
accumFp.redc2x(accum)
doAssert bool accumFp.isOne(),
"Beware, if the hex are the same, it means the outputs are the same (mod p),\n" &
"but one might not be completely reduced\n" &
" accumFp: " & accumFp.toHex()
template mulTest(rng_gen: untyped): untyped =
proc `mul _ rng_gen`(C: static Curve) =
let a = rng_gen(rng, Fp[C])
let b = rng_gen(rng, Fp[C])
var r_fp{.noInit.}, r_fpDbl{.noInit.}: Fp[C]
var tmpDbl{.noInit.}: FpDbl[C]
r_fp.prod(a, b)
tmpDbl.prod2x(a, b)
r_fpDbl.redc2x(tmpDbl)
doAssert bool(r_fp == r_fpDbl)
template sqrTest(rng_gen: untyped): untyped =
proc `sqr _ rng_gen`(C: static Curve) =
let a = rng_gen(rng, Fp[C])
var mulDbl{.noInit.}, sqrDbl{.noInit.}: FpDbl[C]
mulDbl.prod2x(a, a)
sqrDbl.square2x(a)
doAssert bool(mulDbl == sqrDbl)
addsubnegTest(random_unsafe)
addsubnegTest(randomHighHammingWeight)
addsubnegTest(random_long01Seq)
mulTest(random_unsafe)
mulTest(randomHighHammingWeight)
mulTest(random_long01Seq)
sqrTest(random_unsafe)
sqrTest(randomHighHammingWeight)
sqrTest(random_long01Seq)
suite "Field Addition/Substraction/Negation via double-precision field elements" & " [" & $WordBitwidth & "-bit mode]":
test "With P-224 field modulus":
for _ in 0 ..< Iters:
addsubneg_random_unsafe(P224)
for _ in 0 ..< Iters:
addsubneg_randomHighHammingWeight(P224)
for _ in 0 ..< Iters:
addsubneg_random_long01Seq(P224)
test "With P-256 field modulus":
for _ in 0 ..< Iters:
addsubneg_random_unsafe(P256)
for _ in 0 ..< Iters:
addsubneg_randomHighHammingWeight(P256)
for _ in 0 ..< Iters:
addsubneg_random_long01Seq(P256)
test "With BN254_Snarks field modulus":
for _ in 0 ..< Iters:
addsubneg_random_unsafe(BN254_Snarks)
for _ in 0 ..< Iters:
addsubneg_randomHighHammingWeight(BN254_Snarks)
for _ in 0 ..< Iters:
addsubneg_random_long01Seq(BN254_Snarks)
test "With BLS12_381 field modulus":
for _ in 0 ..< Iters:
addsubneg_random_unsafe(BLS12_381)
for _ in 0 ..< Iters:
addsubneg_randomHighHammingWeight(BLS12_381)
for _ in 0 ..< Iters:
addsubneg_random_long01Seq(BLS12_381)
test "Negate 0 returns 0 (unique Montgomery repr)":
var a: FpDbl[BN254_Snarks]
var r {.noInit.}: FpDbl[BN254_Snarks]
r.neg2xMod(a)
check: bool r.isZero()
suite "Field Multiplication via double-precision field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]":
test "With P-224 field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(P224)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(P224)
for _ in 0 ..< Iters:
mul_random_long01Seq(P224)
test "With P-256 field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(P256)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(P256)
for _ in 0 ..< Iters:
mul_random_long01Seq(P256)
test "With BN254_Snarks field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(BN254_Snarks)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(BN254_Snarks)
for _ in 0 ..< Iters:
mul_random_long01Seq(BN254_Snarks)
test "With BLS12_381 field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(BLS12_381)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(BLS12_381)
for _ in 0 ..< Iters:
mul_random_long01Seq(BLS12_381)
suite "Field Squaring via double-precision field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]":
test "With P-224 field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(P224)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(P224)
for _ in 0 ..< Iters:
sqr_random_long01Seq(P224)
test "With P-256 field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(P256)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(P256)
for _ in 0 ..< Iters:
sqr_random_long01Seq(P256)
test "With BN254_Snarks field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(BN254_Snarks)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(BN254_Snarks)
for _ in 0 ..< Iters:
sqr_random_long01Seq(BN254_Snarks)
test "With BLS12_381 field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(BLS12_381)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(BLS12_381)
for _ in 0 ..< Iters:
sqr_random_long01Seq(BLS12_381)

View File

@ -1,123 +0,0 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
# Standard library
std/[unittest, times],
# Internal
../constantine/arithmetic,
../constantine/io/[io_bigints, io_fields],
../constantine/config/[curves, common, type_bigint],
# Test utilities
../helpers/prng_unsafe
const Iters = 24
var rng: RngState
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
rng.seed(seed)
echo "\n------------------------------------------------------\n"
echo "test_finite_fields_double_width xoshiro512** seed: ", seed
template mulTest(rng_gen: untyped): untyped =
proc `mul _ rng_gen`(C: static Curve) =
let a = rng_gen(rng, Fp[C])
let b = rng.random_unsafe(Fp[C])
var r_fp{.noInit.}, r_fpDbl{.noInit.}: Fp[C]
var tmpDbl{.noInit.}: FpDbl[C]
r_fp.prod(a, b)
tmpDbl.mulNoReduce(a, b)
r_fpDbl.reduce(tmpDbl)
doAssert bool(r_fp == r_fpDbl)
template sqrTest(rng_gen: untyped): untyped =
proc `sqr _ rng_gen`(C: static Curve) =
let a = rng_gen(rng, Fp[C])
var mulDbl{.noInit.}, sqrDbl{.noInit.}: FpDbl[C]
mulDbl.mulNoReduce(a, a)
sqrDbl.squareNoReduce(a)
doAssert bool(mulDbl == sqrDbl)
mulTest(random_unsafe)
mulTest(randomHighHammingWeight)
mulTest(random_long01Seq)
sqrTest(random_unsafe)
sqrTest(randomHighHammingWeight)
sqrTest(random_long01Seq)
suite "Field Multiplication via double-width field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]":
test "With P-224 field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(P224)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(P224)
for _ in 0 ..< Iters:
mul_random_long01Seq(P224)
test "With P-256 field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(P256)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(P256)
for _ in 0 ..< Iters:
mul_random_long01Seq(P256)
test "With BN254_Snarks field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(BN254_Snarks)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(BN254_Snarks)
for _ in 0 ..< Iters:
mul_random_long01Seq(BN254_Snarks)
test "With BLS12_381 field modulus":
for _ in 0 ..< Iters:
mul_random_unsafe(BLS12_381)
for _ in 0 ..< Iters:
mul_randomHighHammingWeight(BLS12_381)
for _ in 0 ..< Iters:
mul_random_long01Seq(BLS12_381)
suite "Field Squaring via double-width field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]":
test "With P-224 field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(P224)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(P224)
for _ in 0 ..< Iters:
sqr_random_long01Seq(P224)
test "With P-256 field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(P256)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(P256)
for _ in 0 ..< Iters:
sqr_random_long01Seq(P256)
test "With BN254_Snarks field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(BN254_Snarks)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(BN254_Snarks)
for _ in 0 ..< Iters:
sqr_random_long01Seq(BN254_Snarks)
test "With BLS12_381 field modulus":
for _ in 0 ..< Iters:
sqr_random_unsafe(BLS12_381)
for _ in 0 ..< Iters:
sqr_randomHighHammingWeight(BLS12_381)
for _ in 0 ..< Iters:
sqr_random_long01Seq(BLS12_381)

View File

@ -27,7 +27,7 @@ echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
proc sanity(C: static Curve) = proc sanity(C: static Curve) =
test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $Fp[C].canUseNoCarryMontySquare & "]": test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $(Fp[C].getSpareBits() >= 2) & "]":
block: # 0² mod block: # 0² mod
var n: Fp[C] var n: Fp[C]
@ -89,7 +89,7 @@ mainSanity()
proc mainSelectCases() = proc mainSelectCases() =
suite "Modular Squaring: selected tricky cases" & " [" & $WordBitwidth & "-bit mode]": suite "Modular Squaring: selected tricky cases" & " [" & $WordBitwidth & "-bit mode]":
test "P-256 [FastSquaring = " & $Fp[P256].canUseNoCarryMontySquare & "]": test "P-256 [FastSquaring = " & $(Fp[P256].getSpareBits() >= 2) & "]":
block: block:
# Triggered an issue in the (t[N+1], t[N]) = t[N] + (A1, A0) # Triggered an issue in the (t[N+1], t[N]) = t[N] + (A1, A0)
# between the squaring and reduction step, with t[N+1] and A1 being carry bits. # between the squaring and reduction step, with t[N+1] and A1 being carry bits.
@ -136,7 +136,7 @@ proc random_long01Seq(C: static Curve) =
doAssert bool(r_mul == r_sqr) doAssert bool(r_mul == r_sqr)
suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" & $WordBitwidth & "-bit mode]": suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" & $WordBitwidth & "-bit mode]":
test "Random squaring mod P-224 [FastSquaring = " & $Fp[P224].canUseNoCarryMontySquare & "]": test "Random squaring mod P-224 [FastSquaring = " & $(Fp[P224].getSpareBits() >= 2) & "]":
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
randomCurve(P224) randomCurve(P224)
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
@ -144,7 +144,8 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " ["
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
random_long01Seq(P224) random_long01Seq(P224)
test "Random squaring mod P-256 [FastSquaring = " & $Fp[P256].canUseNoCarryMontySquare & "]": test "Random squaring mod P-256 [FastSquaring = " & $(Fp[P256].getSpareBits() >= 2) & "]":
echo "Fp[P256].getSpareBits(): ", Fp[P256].getSpareBits()
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
randomCurve(P256) randomCurve(P256)
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
@ -152,7 +153,7 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " ["
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
random_long01Seq(P256) random_long01Seq(P256)
test "Random squaring mod BLS12_381 [FastSquaring = " & $Fp[BLS12_381].canUseNoCarryMontySquare & "]": test "Random squaring mod BLS12_381 [FastSquaring = " & $(Fp[BLS12_381].getSpareBits() >= 2) & "]":
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
randomCurve(BLS12_381) randomCurve(BLS12_381)
for _ in 0 ..< Iters: for _ in 0 ..< Iters:

129
tests/t_fp4.nim Normal file
View File

@ -0,0 +1,129 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
std/unittest,
# Internals
../constantine/towers,
../constantine/io/io_towers,
../constantine/config/curves,
# Test utilities
./t_fp_tower_template
const TestCurves = [
BN254_Nogami,
BN254_Snarks,
BLS12_377,
BLS12_381,
BW6_761
]
runTowerTests(
ExtDegree = 4,
Iters = 12,
TestCurves = TestCurves,
moduleName = "test_fp4",
testSuiteDesc = "𝔽p4 = 𝔽p2[v]"
)
# Fuzzing failure
# Issue when using Fp4Dbl
suite "𝔽p4 - Anti-regression":
test "Partial reduction (off by p) on double-precision field":
proc partred1() =
type F = Fp4[BN254_Snarks]
var x: F
x.fromHex(
"0x0000000000000000000fffffffffffffffffe000000fffffffffcffffff80000",
"0x000000000000007ffffffffff800000001fffe000000000007ffffffffffffe0",
"0x000000c0ff0300fcffffffff7f00000000f0ffffffffffffffff00000000e0ff",
"0x0e0a77c19a07df27e5eea36f7879462c0a7ceb28e5c70b3dd35d438dc58f4d9c"
)
# echo "x: ", x.toHex()
# echo "\n----------------------"
var s: F
s.square(x)
# echo "s: ", s.toHex()
# echo "\ns raw: ", s
# echo "\n----------------------"
var p: F
p.prod(x, x)
# echo "p: ", p.toHex()
# echo "\np raw: ", p
check: bool(p == s)
partred1()
proc partred2() =
type F = Fp4[BN254_Snarks]
var x: F
x.fromHex(
"0x0660df54c75b67a0c32fc6208f08b13d8cc86cd93084180725a04884e7f45849",
"0x094185b0915ce1aa3bd3c63d33fd6d9cf3f04ea30fc88efe1e6e9b59117513bb",
"0x26c20beee711e46406372ab4f0e6d0069c67ded0a494bc0301bbfde48f7a4073",
"0x23c60254946def07120e46155466cc9b883b5c3d1c17d1d6516a6268a41dcc5d"
)
# echo "x: ", x.toHex()
# echo "\n----------------------"
var s: F
s.square(x)
# echo "s: ", s.toHex()
# echo "\ns raw: ", s
# echo "\n----------------------"
var p: F
p.prod(x, x)
# echo "p: ", p.toHex()
# echo "\np raw: ", p
check: bool(p == s)
partred2()
proc partred3() =
type F = Fp4[BN254_Snarks]
var x: F
x.fromHex(
"0x233066f735efcf7a0ad6e3ffa3afe4ed39bdfeffffb3f7d8b1fd7eeabfddfb36",
"0x1caba0b27fdfdfd512bdecf3fffbfebdb939fffffffbff8a14e663f7fef7fc85",
"0x212a64f0efefff1b7abe2ebe2bffbfc1b9335fb73ffd7c8815ffffffffffff8d",
"0x212ba4b1ff8feff552a61efff5ffffc5b839f7ffffffff71f477dffe7ffc7e08"
)
# echo "x: ", x.toHex()
# echo "\n----------------------"
var s: F
s.square(x)
# echo "s: ", s.toHex()
# echo "\ns raw: ", s
# echo "\n----------------------"
var n, s2: F
n.neg(x)
s2.prod(n, n)
# echo "s2: ", s2.toHex()
# echo "\ns2 raw: ", s2
check: bool(s == s2)
partred3()

View File

@ -20,6 +20,7 @@ import
../constantine/towers, ../constantine/towers,
../constantine/config/[common, curves], ../constantine/config/[common, curves],
../constantine/arithmetic, ../constantine/arithmetic,
../constantine/io/io_towers,
# Test utilities # Test utilities
../helpers/[prng_unsafe, static_for] ../helpers/[prng_unsafe, static_for]
@ -28,6 +29,8 @@ echo "\n------------------------------------------------------\n"
template ExtField(degree: static int, curve: static Curve): untyped = template ExtField(degree: static int, curve: static Curve): untyped =
when degree == 2: when degree == 2:
Fp2[curve] Fp2[curve]
elif degree == 4:
Fp4[curve]
elif degree == 6: elif degree == 6:
Fp6[curve] Fp6[curve]
elif degree == 12: elif degree == 12:
@ -273,7 +276,7 @@ proc runTowerTests*[N](
rMul.prod(a, a) rMul.prod(a, a)
rSqr.square(a) rSqr.square(a)
check: bool(rMul == rSqr) doAssert bool(rMul == rSqr), "Failure with a (" & $Field & "): " & a.toHex()
staticFor(curve, TestCurves): staticFor(curve, TestCurves):
test(ExtField(ExtDegree, curve), Iters, gen = Uniform) test(ExtField(ExtDegree, curve), Iters, gen = Uniform)
@ -292,7 +295,7 @@ proc runTowerTests*[N](
rSqr.square(a) rSqr.square(a)
rNegSqr.square(na) rNegSqr.square(na)
check: bool(rSqr == rNegSqr) doAssert bool(rSqr == rNegSqr), "Failure with a (" & $Field & "): " & a.toHex()
staticFor(curve, TestCurves): staticFor(curve, TestCurves):
test(ExtField(ExtDegree, curve), Iters, gen = Uniform) test(ExtField(ExtDegree, curve), Iters, gen = Uniform)

View File

@ -25,7 +25,7 @@ echo "\n------------------------------------------------------\n"
echo "test_fr xoshiro512** seed: ", seed echo "test_fr xoshiro512** seed: ", seed
proc sanity(C: static Curve) = proc sanity(C: static Curve) =
test "Fr: Squaring 0,1,2 with "& $Fr[C] & " [FastSquaring = " & $Fr[C].canUseNoCarryMontySquare & "]": test "Fr: Squaring 0,1,2 with "& $Fr[C] & " [FastSquaring = " & $(Fr[C].getSpareBits() >= 2) & "]":
block: # 0² mod block: # 0² mod
var n: Fr[C] var n: Fr[C]
@ -112,7 +112,7 @@ proc random_long01Seq(C: static Curve) =
doAssert bool(r_mul == r_sqr) doAssert bool(r_mul == r_sqr)
suite "Fr: Random Modular Squaring is consistent with Modular Multiplication" & " [" & $WordBitwidth & "-bit mode]": suite "Fr: Random Modular Squaring is consistent with Modular Multiplication" & " [" & $WordBitwidth & "-bit mode]":
test "Random squaring mod r_BN254_Snarks [FastSquaring = " & $Fr[BN254_Snarks].canUseNoCarryMontySquare & "]": test "Random squaring mod r_BN254_Snarks [FastSquaring = " & $(Fr[BN254_Snarks].getSpareBits() >= 2) & "]":
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
randomCurve(BN254_Snarks) randomCurve(BN254_Snarks)
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
@ -120,7 +120,7 @@ suite "Fr: Random Modular Squaring is consistent with Modular Multiplication" &
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
random_long01Seq(BN254_Snarks) random_long01Seq(BN254_Snarks)
test "Random squaring mod r_BLS12_381 [FastSquaring = " & $Fr[BLS12_381].canUseNoCarryMontySquare & "]": test "Random squaring mod r_BLS12_381 [FastSquaring = " & $(Fr[BLS12_381].getSpareBits() >= 2) & "]":
for _ in 0 ..< Iters: for _ in 0 ..< Iters:
randomCurve(BLS12_381) randomCurve(BLS12_381)
for _ in 0 ..< Iters: for _ in 0 ..< Iters: