Mirror of https://github.com/logos-storage/constantine.git (synced 2026-01-03 05:33:07 +00:00)
Parallel Multi-Scalar-Multiplication (#226)
* try parallel reduction in batch add, but alas it's slower than custom chunking. Except maybe on arch with performance/efficiency cores
* initial impl of parallel MSM - scaling to debug, threads not woken fast enough
* improve comment [skip ci]
* skip top window when c divides the number of bits
* for some reason parallel-for loops scale on 5+ threads while spawn only on 2x threads. Thread wakeup issue?
* Add counters and timers to audit threadpool bottlenecks
* metrics and profiling fixes, (slower) latency hiding, activate tests
* fix thief thread trying to wake another before canceling its own sleep
* easier to sort metrics and parallel endomorphism application
* selective endomorphism acceleration
* some tuning
* spawn can handle compile-time literals, static and type parameters. Also introduce spawnAwaitable to await void procs
* improve MSM overview [skip ci]
* bench cleanup
This commit is contained in:
parent
4dc2610557
commit
6c48975aee
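The `spawnAwaitable` primitive called out above is what the parallel MSM code further down in this diff uses to await void procs. A minimal, hedged sketch of the pattern, assuming only the Threadpool API that appears in this PR (`Threadpool.new`, `spawnAwaitable`, `sync`, `shutdown`); `doWork`, its argument, and the import path are illustrative, not from the source:

import constantine/platforms/threadpool/threadpool

proc doWork(x: int) =
  # Hypothetical void proc standing in for the bucketAccumReduce_* procs below.
  discard x

proc example() =
  var tp = Threadpool.new()
  let done = tp.spawnAwaitable doWork(42)  # Flowvar[bool] handle for a void proc
  # ... other work on this thread ...
  discard sync(done)                       # wait for the spawned proc to finish
  tp.shutdown()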
@ -11,37 +11,12 @@ import
|
||||
../constantine/math/config/curves,
|
||||
../constantine/math/arithmetic,
|
||||
../constantine/math/elliptic/[
|
||||
ec_shortweierstrass_affine,
|
||||
ec_shortweierstrass_projective,
|
||||
ec_shortweierstrass_jacobian,
|
||||
ec_shortweierstrass_jacobian_extended,
|
||||
ec_shortweierstrass_batch_ops_parallel],
|
||||
../constantine/platforms/threadpool/threadpool,
|
||||
ec_shortweierstrass_jacobian_extended],
|
||||
# Helpers
|
||||
../helpers/prng_unsafe,
|
||||
./bench_elliptic_template,
|
||||
./bench_blueprint
|
||||
./bench_elliptic_template, ./bench_elliptic_parallel_template
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Parallel Benchmark definitions
|
||||
#
|
||||
# ############################################################
|
||||
|
||||
proc multiAddParallelBench*(EC: typedesc, numPoints: int, iters: int) =
|
||||
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numPoints)
|
||||
|
||||
for i in 0 ..< numPoints:
|
||||
points[i] = rng.random_unsafe(ECP_ShortW_Aff[EC.F, EC.G])
|
||||
|
||||
var r{.noInit.}: EC
|
||||
|
||||
var tp = Threadpool.new()
|
||||
|
||||
bench("EC parallel batch add (" & align($tp.numThreads, 2) & " threads) " & $EC.G & " (" & $numPoints & " points)", EC, iters):
|
||||
tp.sum_reduce_vartime_parallel(r, points)
|
||||
|
||||
tp.shutdown()
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
@ -75,18 +50,18 @@ proc main() =
|
||||
doublingBench(ECP_ShortW_JacExt[Fp[curve], G1], Iters)
|
||||
mixedAddBench(ECP_ShortW_JacExt[Fp[curve], G1], Iters)
|
||||
separator()
|
||||
for numPoints in testNumPoints:
|
||||
let batchIters = max(1, Iters div numPoints)
|
||||
multiAddBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = false, batchIters)
|
||||
separator()
|
||||
for numPoints in testNumPoints:
|
||||
let batchIters = max(1, Iters div numPoints)
|
||||
multiAddBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = true, batchIters)
|
||||
separator()
|
||||
for numPoints in testNumPoints:
|
||||
let batchIters = max(1, Iters div numPoints)
|
||||
multiAddParallelBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, batchIters)
|
||||
separator()
|
||||
# for numPoints in testNumPoints:
|
||||
# let batchIters = max(1, Iters div numPoints)
|
||||
# multiAddBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = false, batchIters)
|
||||
# separator()
|
||||
# for numPoints in testNumPoints:
|
||||
# let batchIters = max(1, Iters div numPoints)
|
||||
# multiAddBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = true, batchIters)
|
||||
# separator()
|
||||
# for numPoints in testNumPoints:
|
||||
# let batchIters = max(1, Iters div numPoints)
|
||||
# multiAddParallelBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, batchIters)
|
||||
# separator()
|
||||
for numPoints in testNumPoints:
|
||||
let batchIters = max(1, Iters div numPoints)
|
||||
multiAddBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, useBatching = false, batchIters)
|
||||
|
||||
@ -11,16 +11,11 @@ import
|
||||
../constantine/math/config/curves,
|
||||
../constantine/math/arithmetic,
|
||||
../constantine/math/elliptic/[
|
||||
ec_shortweierstrass_affine,
|
||||
ec_shortweierstrass_projective,
|
||||
ec_shortweierstrass_jacobian,
|
||||
ec_scalar_mul,
|
||||
ec_multi_scalar_mul],
|
||||
../constantine/math/constants/zoo_subgroups,
|
||||
ec_shortweierstrass_jacobian],
|
||||
# Helpers
|
||||
../helpers/prng_unsafe,
|
||||
./bench_elliptic_template,
|
||||
./bench_blueprint
|
||||
./bench_elliptic_parallel_template
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
@ -45,14 +40,9 @@ proc main() =
|
||||
staticFor i, 0, AvailableCurves.len:
|
||||
const curve = AvailableCurves[i]
|
||||
separator()
|
||||
# for numPoints in testNumPoints:
|
||||
# let batchIters = max(1, Iters div numPoints)
|
||||
# msmBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, batchIters)
|
||||
# separator()
|
||||
# separator()
|
||||
for numPoints in testNumPoints:
|
||||
let batchIters = max(1, Iters div numPoints)
|
||||
msmBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
|
||||
msmParallelBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
|
||||
separator()
|
||||
separator()
|
||||
|
||||
|
||||
1
benchmarks/bench_ec_g1_msm_bls12_381.nim.cfg
Normal file
1
benchmarks/bench_ec_g1_msm_bls12_381.nim.cfg
Normal file
@ -0,0 +1 @@
|
||||
--threads:on
|
||||
@ -11,16 +11,11 @@ import
|
||||
../constantine/math/config/curves,
|
||||
../constantine/math/arithmetic,
|
||||
../constantine/math/elliptic/[
|
||||
ec_shortweierstrass_affine,
|
||||
ec_shortweierstrass_projective,
|
||||
ec_shortweierstrass_jacobian,
|
||||
ec_scalar_mul,
|
||||
ec_multi_scalar_mul],
|
||||
../constantine/math/constants/zoo_subgroups,
|
||||
ec_shortweierstrass_jacobian],
|
||||
# Helpers
|
||||
../helpers/prng_unsafe,
|
||||
./bench_elliptic_template,
|
||||
./bench_blueprint
|
||||
./bench_elliptic_parallel_template
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
@ -36,22 +31,18 @@ const AvailableCurves = [
|
||||
BN254_Snarks,
|
||||
]
|
||||
|
||||
const testNumPoints = [10, 100, 1000, 10000, 100000]
|
||||
# const testNumPoints = [10, 100, 1000, 10000, 100000]
|
||||
# const testNumPoints = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||
const testNumPoints = [1 shl 16, 1 shl 22]
|
||||
|
||||
proc main() =
|
||||
separator()
|
||||
staticFor i, 0, AvailableCurves.len:
|
||||
const curve = AvailableCurves[i]
|
||||
separator()
|
||||
# for numPoints in testNumPoints:
|
||||
# let batchIters = max(1, Iters div numPoints)
|
||||
# msmBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, batchIters)
|
||||
# separator()
|
||||
# separator()
|
||||
for numPoints in testNumPoints:
|
||||
let batchIters = max(1, Iters div numPoints)
|
||||
msmBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
|
||||
msmParallelBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
|
||||
separator()
|
||||
separator()
|
||||
|
||||
|
||||
1
benchmarks/bench_ec_g1_msm_bn254_snarks.nim.cfg
Normal file
1
benchmarks/bench_ec_g1_msm_bn254_snarks.nim.cfg
Normal file
@ -0,0 +1 @@
|
||||
--threads:on
|
||||
116
benchmarks/bench_elliptic_parallel_template.nim
Normal file
116
benchmarks/bench_elliptic_parallel_template.nim
Normal file
@ -0,0 +1,116 @@
|
||||
# Constantine
|
||||
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import
|
||||
# Internals
|
||||
../constantine/math/config/curves,
|
||||
../constantine/math/arithmetic,
|
||||
../constantine/math/elliptic/[
|
||||
ec_shortweierstrass_affine,
|
||||
ec_shortweierstrass_projective,
|
||||
ec_shortweierstrass_jacobian,
|
||||
ec_shortweierstrass_jacobian_extended,
|
||||
ec_shortweierstrass_batch_ops_parallel,
|
||||
ec_multi_scalar_mul,
|
||||
ec_scalar_mul,
|
||||
ec_multi_scalar_mul_parallel],
|
||||
../constantine/math/constants/zoo_subgroups,
|
||||
# Threadpool
|
||||
../constantine/platforms/threadpool/threadpool,
|
||||
# Helpers
|
||||
../helpers/prng_unsafe,
|
||||
./bench_elliptic_template,
|
||||
./bench_blueprint
|
||||
|
||||
export bench_elliptic_template
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Parallel Benchmark definitions
|
||||
#
|
||||
# ############################################################
|
||||
|
||||
proc multiAddParallelBench*(EC: typedesc, numPoints: int, iters: int) =
|
||||
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numPoints)
|
||||
|
||||
for i in 0 ..< numPoints:
|
||||
points[i] = rng.random_unsafe(ECP_ShortW_Aff[EC.F, EC.G])
|
||||
|
||||
var r{.noInit.}: EC
|
||||
|
||||
var tp = Threadpool.new()
|
||||
|
||||
bench("EC parallel batch add (" & align($tp.numThreads, 2) & " threads) " & $EC.G & " (" & $numPoints & " points)", EC, iters):
|
||||
tp.sum_reduce_vartime_parallel(r, points)
|
||||
|
||||
tp.shutdown()
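# Hedged usage sketch (illustrative, not part of this file): the proc above is driven
# the same way main() drives multiAddBench in the benchmark files of this PR; the curve
# and counts below are examples only.
#
#   multiAddParallelBench(ECP_ShortW_Prj[Fp[BLS12_381], G1], numPoints = 10_000, iters = 10)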
|
||||
|
||||
proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
|
||||
const bits = EC.F.C.getCurveOrderBitwidth()
|
||||
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numPoints)
|
||||
var scalars = newSeq[BigInt[bits]](numPoints)
|
||||
|
||||
for i in 0 ..< numPoints:
|
||||
var tmp = rng.random_unsafe(EC)
|
||||
tmp.clearCofactor()
|
||||
points[i].affine(tmp)
|
||||
scalars[i] = rng.random_unsafe(BigInt[bits])
|
||||
|
||||
var r{.noInit.}: EC
|
||||
var startNaive, stopNaive, startMSMbaseline, stopMSMbaseline, startMSMopt, stopMSMopt, startMSMpara, stopMSMpara: MonoTime
|
||||
|
||||
if numPoints <= 100000:
|
||||
bench("EC scalar muls " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
|
||||
startNaive = getMonotime()
|
||||
var tmp: EC
|
||||
r.setInf()
|
||||
for i in 0 ..< points.len:
|
||||
tmp.fromAffine(points[i])
|
||||
tmp.scalarMul(scalars[i])
|
||||
r += tmp
|
||||
stopNaive = getMonotime()
|
||||
|
||||
block:
|
||||
bench("EC multi-scalar-mul baseline " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
|
||||
startMSMbaseline = getMonotime()
|
||||
r.multiScalarMul_reference_vartime(scalars, points)
|
||||
stopMSMbaseline = getMonotime()
|
||||
|
||||
block:
|
||||
bench("EC multi-scalar-mul optimized " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
|
||||
startMSMopt = getMonotime()
|
||||
r.multiScalarMul_vartime(scalars, points)
|
||||
stopMSMopt = getMonotime()
|
||||
|
||||
block:
|
||||
var tp = Threadpool.new()
|
||||
|
||||
bench("EC multi-scalar-mul" & align($tp.numThreads & " threads", 11) & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
|
||||
startMSMpara = getMonotime()
|
||||
tp.multiScalarMul_vartime_parallel(r, scalars, points)
|
||||
stopMSMpara = getMonotime()
|
||||
|
||||
tp.shutdown()
|
||||
|
||||
let perfNaive = inNanoseconds((stopNaive-startNaive) div iters)
|
||||
let perfMSMbaseline = inNanoseconds((stopMSMbaseline-startMSMbaseline) div iters)
|
||||
let perfMSMopt = inNanoseconds((stopMSMopt-startMSMopt) div iters)
|
||||
let perfMSMpara = inNanoseconds((stopMSMpara-startMSMpara) div iters)
|
||||
|
||||
if numPoints <= 100000:
|
||||
let speedupBaseline = float(perfNaive) / float(perfMSMbaseline)
|
||||
echo &"Speedup ratio baseline over naive linear combination: {speedupBaseline:>6.3f}x"
|
||||
|
||||
let speedupOpt = float(perfNaive) / float(perfMSMopt)
|
||||
echo &"Speedup ratio optimized over naive linear combination: {speedupOpt:>6.3f}x"
|
||||
|
||||
let speedupOptBaseline = float(perfMSMbaseline) / float(perfMSMopt)
|
||||
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"
|
||||
|
||||
let speedupParaOpt = float(perfMSMopt) / float(perfMSMpara)
|
||||
echo &"Speedup ratio parallel over optimized linear combination: {speedupParaOpt:>6.3f}x"
|
||||
@ -36,7 +36,7 @@ import
|
||||
export notes
|
||||
export abstractions # generic sandwich on SecretBool and SecretBool in Jacobian sum
|
||||
|
||||
proc separator*() = separator(206)
|
||||
proc separator*() = separator(179)
|
||||
|
||||
macro fixEllipticDisplay(EC: typedesc): untyped =
|
||||
# At compile-time, enums are integers and their display is buggy
|
||||
@ -52,9 +52,9 @@ proc report(op, elliptic: string, start, stop: MonoTime, startClk, stopClk: int6
|
||||
let ns = inNanoseconds((stop-start) div iters)
|
||||
let throughput = 1e9 / float64(ns)
|
||||
when SupportsGetTicks:
|
||||
echo &"{op:<80} {elliptic:<40} {throughput:>15.3f} ops/s {ns:>12} ns/op {(stopClk - startClk) div iters:>12} CPU cycles (approx)"
|
||||
echo &"{op:<68} {elliptic:<32} {throughput:>15.3f} ops/s {ns:>16} ns/op {(stopClk - startClk) div iters:>12} CPU cycles (approx)"
|
||||
else:
|
||||
echo &"{op:<80} {elliptic:<40} {throughput:>15.3f} ops/s {ns:>12} ns/op"
|
||||
echo &"{op:<68} {elliptic:<32} {throughput:>15.3f} ops/s {ns:>16} ns/op"
|
||||
|
||||
template bench*(op: string, EC: typedesc, iters: int, body: untyped): untyped =
|
||||
measure(iters, startTime, stopTime, startClk, stopClk, body)
|
||||
|
||||
@ -260,7 +260,9 @@ const testDescThreadpool: seq[string] = @[
|
||||
|
||||
const testDescMultithreadedCrypto: seq[string] = @[
|
||||
"tests/parallel/t_ec_shortw_jac_g1_batch_add_parallel.nim",
|
||||
"tests/parallel/t_ec_shortw_prj_g1_batch_add_parallel.nim"
|
||||
"tests/parallel/t_ec_shortw_prj_g1_batch_add_parallel.nim",
|
||||
"tests/parallel/t_ec_shortw_jac_g1_msm_parallel.nim",
|
||||
"tests/parallel/t_ec_shortw_prj_g1_msm_parallel.nim",
|
||||
]
|
||||
|
||||
const benchDesc = [
|
||||
@ -574,6 +576,7 @@ task test, "Run all tests":
|
||||
cmdFile.addTestSet(requireGMP = true, testASM = true)
|
||||
cmdFile.addBenchSet(useASM = true) # Build (but don't run) benches to ensure they stay relevant
|
||||
cmdFile.addTestSetThreadpool()
|
||||
cmdFile.addTestSetMultithreadedCrypto()
|
||||
for cmd in cmdFile.splitLines():
|
||||
if cmd != "": # Windows doesn't like empty commands
|
||||
exec cmd
|
||||
@ -584,6 +587,7 @@ task test_no_asm, "Run all tests (no assembly)":
|
||||
cmdFile.addTestSet(requireGMP = true, testASM = false)
|
||||
cmdFile.addBenchSet(useASM = false) # Build (but don't run) benches to ensure they stay relevant
|
||||
cmdFile.addTestSetThreadpool()
|
||||
cmdFile.addTestSetMultithreadedCrypto(testASM = false)
|
||||
for cmd in cmdFile.splitLines():
|
||||
if cmd != "": # Windows doesn't like empty commands
|
||||
exec cmd
|
||||
@ -594,6 +598,7 @@ task test_no_gmp, "Run tests that don't require GMP":
|
||||
cmdFile.addTestSet(requireGMP = false, testASM = true)
|
||||
cmdFile.addBenchSet(useASM = true) # Build (but don't run) benches to ensure they stay relevant
|
||||
cmdFile.addTestSetThreadpool()
|
||||
cmdFile.addTestSetMultithreadedCrypto()
|
||||
for cmd in cmdFile.splitLines():
|
||||
if cmd != "": # Windows doesn't like empty commands
|
||||
exec cmd
|
||||
@ -604,11 +609,12 @@ task test_no_gmp_no_asm, "Run tests that don't require GMP using a pure Nim back
|
||||
cmdFile.addTestSet(requireGMP = false, testASM = false)
|
||||
cmdFile.addBenchSet(useASM = false) # Build (but don't run) benches to ensure they stay relevant
|
||||
cmdFile.addTestSetThreadpool()
|
||||
cmdFile.addTestSetMultithreadedCrypto(testASM = false)
|
||||
for cmd in cmdFile.splitLines():
|
||||
if cmd != "": # Windows doesn't like empty commands
|
||||
exec cmd
|
||||
|
||||
task test_parallel, "Run all tests in parallel (via GNU parallel)":
|
||||
task test_parallel, "Run all tests in parallel":
|
||||
# -d:testingCurves is configured in a *.nim.cfg for convenience
|
||||
clearParallelBuild()
|
||||
genParallelCmdRunner()
|
||||
@ -622,11 +628,12 @@ task test_parallel, "Run all tests in parallel (via GNU parallel)":
|
||||
# Threadpool tests done serially
|
||||
cmdFile = ""
|
||||
cmdFile.addTestSetThreadpool()
|
||||
cmdFile.addTestSetMultithreadedCrypto()
|
||||
for cmd in cmdFile.splitLines():
|
||||
if cmd != "": # Windows doesn't like empty commands
|
||||
exec cmd
|
||||
|
||||
task test_parallel_no_asm, "Run all tests (without macro assembler) in parallel (via GNU parallel)":
|
||||
task test_parallel_no_asm, "Run all tests (without macro assembler) in parallel":
|
||||
# -d:testingCurves is configured in a *.nim.cfg for convenience
|
||||
clearParallelBuild()
|
||||
genParallelCmdRunner()
|
||||
@ -640,11 +647,12 @@ task test_parallel_no_asm, "Run all tests (without macro assembler) in parallel
|
||||
# Threadpool tests done serially
|
||||
cmdFile = ""
|
||||
cmdFile.addTestSetThreadpool()
|
||||
cmdFile.addTestSetMultithreadedCrypto(testASM = false)
|
||||
for cmd in cmdFile.splitLines():
|
||||
if cmd != "": # Windows doesn't like empty commands
|
||||
exec cmd
|
||||
|
||||
task test_parallel_no_gmp, "Run all tests in parallel (via GNU parallel)":
|
||||
task test_parallel_no_gmp, "Run all tests in parallel":
|
||||
# -d:testingCurves is configured in a *.nim.cfg for convenience
|
||||
clearParallelBuild()
|
||||
genParallelCmdRunner()
|
||||
@ -658,11 +666,12 @@ task test_parallel_no_gmp, "Run all tests in parallel (via GNU parallel)":
|
||||
# Threadpool tests done serially
|
||||
cmdFile = ""
|
||||
cmdFile.addTestSetThreadpool()
|
||||
cmdFile.addTestSetMultithreadedCrypto()
|
||||
for cmd in cmdFile.splitLines():
|
||||
if cmd != "": # Windows doesn't like empty commands
|
||||
exec cmd
|
||||
|
||||
task test_parallel_no_gmp_no_asm, "Run all tests in parallel (via GNU parallel)":
|
||||
task test_parallel_no_gmp_no_asm, "Run all tests in parallel":
|
||||
# -d:testingCurves is configured in a *.nim.cfg for convenience
|
||||
clearParallelBuild()
|
||||
genParallelCmdRunner()
|
||||
@ -676,6 +685,7 @@ task test_parallel_no_gmp_no_asm, "Run all tests in parallel (via GNU parallel)"
|
||||
# Threadpool tests done serially
|
||||
cmdFile = ""
|
||||
cmdFile.addTestSetThreadpool()
|
||||
cmdFile.addTestSetMultithreadedCrypto(testASM = false)
|
||||
for cmd in cmdFile.splitLines():
|
||||
if cmd != "": # Windows doesn't like empty commands
|
||||
exec cmd
|
||||
|
||||
@ -369,9 +369,9 @@ func getWindowAt*(a: BigInt, bitIndex: int, windowSize: static int): SecretWord
|
||||
# This is constant-time, the branch does not depend on secret data.
|
||||
if pos + windowSize > WordBitWidth and slot+1 < a.limbs.len:
|
||||
# Read next word as well
|
||||
return SecretWord((word shr pos) or (a.limbs[slot+1] shl (WordBitWidth-pos))) and WindowMask
|
||||
return ((word shr pos) or (a.limbs[slot+1] shl (WordBitWidth-pos))) and WindowMask
|
||||
else:
|
||||
return SecretWord(word shr pos) and WindowMask
|
||||
return (word shr pos) and WindowMask
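# Worked example (hedged, not from the source), with WordBitWidth = 64, windowSize = 5,
# bitIndex = 126:
#   slot = 126 div 64 = 1 and pos = 126 mod 64 = 62, so the window straddles limbs 1 and 2:
#     window = ((limbs[1] shr 62) or (limbs[2] shl 2)) and 0b11111
#   i.e. the low 2 bits of the window come from limb 1 and the next 3 bits from limb 2.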
|
||||
|
||||
# Multiplication by small constants
|
||||
# ------------------------------------------------------------
|
||||
|
||||
@ -223,8 +223,7 @@ func nDimMultiScalarRecoding[M, L: static int](
|
||||
func buildLookupTable[M: static int, EC, ECaff](
|
||||
P: EC,
|
||||
endomorphisms: array[M-1, EC],
|
||||
lut: var array[1 shl (M-1), ECaff],
|
||||
) =
|
||||
lut: var array[1 shl (M-1), ECaff]) =
|
||||
## Build the lookup table from the base point P
|
||||
## and the curve endomorphism
|
||||
#
|
||||
@ -295,7 +294,6 @@ func scalarMulEndo*[scalBits; EC](
|
||||
## - Cofactor to be cleared
|
||||
## - 0 <= scalar < curve order
|
||||
mixin affine
|
||||
type ECaff = affine(EC)
|
||||
const C = P.F.C # curve
|
||||
static: doAssert scalBits <= C.getCurveOrderBitwidth(), "Do not use endomorphism to multiply beyond the curve order"
|
||||
when P.F is Fp:
|
||||
@ -341,7 +339,7 @@ func scalarMulEndo*[scalBits; EC](
|
||||
endomorphisms[i-1].cneg(negatePoints[i])
|
||||
|
||||
# 4. Precompute lookup table
|
||||
var lut {.noInit.}: array[1 shl (M-1), ECaff]
|
||||
var lut {.noInit.}: array[1 shl (M-1), affine(EC)]
|
||||
buildLookupTable(P, endomorphisms, lut)
|
||||
|
||||
# 5. Recode the miniscalars
|
||||
@ -355,7 +353,7 @@ func scalarMulEndo*[scalBits; EC](
|
||||
|
||||
# 6. Proceed to GLV accelerated scalar multiplication
|
||||
var Q {.noInit.}: EC
|
||||
var tmp {.noInit.}: ECaff
|
||||
var tmp {.noInit.}: affine(EC)
|
||||
tmp.secretLookup(lut, recoded.tableIndex(L-1))
|
||||
Q.fromAffine(tmp)
|
||||
|
||||
@ -493,7 +491,6 @@ func scalarMulGLV_m2w2*[scalBits; EC](
|
||||
## - Cofactor to be cleared
|
||||
## - 0 <= scalar < curve order
|
||||
mixin affine
|
||||
type ECaff = affine(EC)
|
||||
const C = P0.F.C # curve
|
||||
static: doAssert: scalBits <= C.getCurveOrderBitwidth()
|
||||
|
||||
@ -520,7 +517,7 @@ func scalarMulGLV_m2w2*[scalBits; EC](
|
||||
P1.cneg(negatePoints[1])
|
||||
|
||||
# 4. Precompute lookup table
|
||||
var lut {.noInit.}: array[8, ECaff]
|
||||
var lut {.noInit.}: array[8, affine(EC)]
|
||||
buildLookupTable_m2w2(P0, P1, lut)
|
||||
|
||||
# 5. Recode the miniscalars
|
||||
@ -534,7 +531,7 @@ func scalarMulGLV_m2w2*[scalBits; EC](
|
||||
|
||||
# 6. Proceed to GLV accelerated scalar multiplication
|
||||
var Q {.noInit.}: EC
|
||||
var tmp {.noInit.}: ECaff
|
||||
var tmp {.noInit.}: affine(EC)
|
||||
var isNeg: SecretBool
|
||||
|
||||
tmp.secretLookup(lut, recoded.w2TableIndex((L div 2) - 1, isNeg))
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
|
||||
import ./ec_multi_scalar_mul_scheduler,
|
||||
./ec_endomorphism_accel,
|
||||
../extension_fields,
|
||||
../constants/zoo_endomorphisms
|
||||
export bestBucketBitSize
|
||||
|
||||
@ -159,18 +160,16 @@ func bucketReduce[EC](r: var EC, buckets: ptr UncheckedArray[EC], numBuckets: st
|
||||
r += accumBuckets
|
||||
buckets[k].setInf()
|
||||
|
||||
type MiniMsmKind = enum
|
||||
type MiniMsmKind* = enum
|
||||
kTopWindow
|
||||
kFullWindow
|
||||
kBottomWindow
|
||||
|
||||
func miniMSM_jacext[F, G; bits: static int](
|
||||
func bucketAccumReduce_jacext*[F, G; bits: static int](
|
||||
r: var ECP_ShortW[F, G],
|
||||
buckets: ptr UncheckedArray[ECP_ShortW_JacExt[F, G]],
|
||||
bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int,
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) {.meter.} =
|
||||
## Apply a mini-Multi-Scalar-Multiplication on [bitIndex, bitIndex+window)
|
||||
## slice of all (coef, point) pairs
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) =
|
||||
|
||||
const excess = bits mod c
|
||||
const top = bits - excess
|
||||
@ -197,19 +196,31 @@ func miniMSM_jacext[F, G; bits: static int](
|
||||
buckets.accumulate(curVal, curNeg, points[N-1])
|
||||
|
||||
# 2. Bucket Reduction
|
||||
var sliceSum{.noinit.}: ECP_ShortW_JacExt[F, G]
|
||||
sliceSum.bucketReduce(buckets, numBuckets = 1 shl (c-1))
|
||||
var windowSum{.noinit.}: ECP_ShortW_JacExt[F, G]
|
||||
windowSum.bucketReduce(buckets, numBuckets = 1 shl (c-1))
|
||||
|
||||
r.fromJacobianExtended_vartime(windowSum)
|
||||
|
||||
func miniMSM_jacext[F, G; bits: static int](
|
||||
r: var ECP_ShortW[F, G],
|
||||
buckets: ptr UncheckedArray[ECP_ShortW_JacExt[F, G]],
|
||||
bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int,
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) {.meter.} =
|
||||
## Apply a mini-Multi-Scalar-Multiplication on [bitIndex, bitIndex+window)
|
||||
## slice of all (coef, point) pairs
|
||||
|
||||
var windowSum{.noInit.}: typeof(r)
|
||||
windowSum.bucketAccumReduce_jacext(
|
||||
buckets, bitIndex, miniMsmKind, c,
|
||||
coefs, points, N)
|
||||
|
||||
# 3. Mini-MSM on the slice [bitIndex, bitIndex+window)
|
||||
var windowSum{.noInit.}: typeof(r)
|
||||
windowSum.fromJacobianExtended_vartime(sliceSum)
|
||||
r += windowSum
|
||||
|
||||
when miniMsmKind != kBottomWindow:
|
||||
for _ in 0 ..< c:
|
||||
r.double()
|
||||
|
||||
func multiScalarMulJacExt_vartime[F, G; bits: static int](
|
||||
func multiScalarMulJacExt_vartime*[F, G; bits: static int](
|
||||
r: var ECP_ShortW[F, G],
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
N: int, c: static int) {.tags:[VarTime, HeapAlloc], meter.} =
|
||||
@ -231,36 +242,35 @@ func multiScalarMulJacExt_vartime[F, G; bits: static int](
|
||||
var w = top
|
||||
r.setInf()
|
||||
|
||||
if excess != 0 and w != 0: # Prologue
|
||||
r.miniMSM_jacext(buckets, w, kTopWindow, c, coefs, points, N)
|
||||
w -= c
|
||||
when top != 0: # Prologue
|
||||
when excess != 0:
|
||||
r.miniMSM_jacext(buckets, w, kTopWindow, c, coefs, points, N)
|
||||
w -= c
|
||||
else:
|
||||
# If c divides bits exactly, the signed windowed recoding still needs to see an extra 0
|
||||
# Since we did r.setInf() earlier, this is a no-op
|
||||
w -= c
|
||||
|
||||
while w != 0: # Steady state
|
||||
while w != 0: # Steady state
|
||||
r.miniMSM_jacext(buckets, w, kFullWindow, c, coefs, points, N)
|
||||
w -= c
|
||||
|
||||
block: # Epilogue
|
||||
block: # Epilogue
|
||||
r.miniMSM_jacext(buckets, w, kBottomWindow, c, coefs, points, N)
|
||||
|
||||
# Cleanup
|
||||
# -------
|
||||
buckets.freeHeap()
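# Worked example of the window schedule above (hedged, not from the source):
# for 255-bit coefficients and c = 8, excess = 255 mod 8 = 7 and top = 248, so the
# prologue handles the 7-bit slice [248, 255) as kTopWindow, the steady state runs
# kFullWindow for bitIndex 240, 232, ..., 8, and the epilogue handles [0, 8) as kBottomWindow.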
|
||||
|
||||
func miniMSM_affine[NumBuckets, QueueLen, F, G; bits: static int](
|
||||
r: var ECP_ShortW[F, G],
|
||||
sched: var Scheduler[NumBuckets, QueueLen, F, G],
|
||||
func schedAccumulate*[NumBuckets, QueueLen, F, G; bits: static int](
|
||||
sched: ptr Scheduler[NumBuckets, QueueLen, F, G],
|
||||
bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int,
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], N: int) {.meter.} =
|
||||
## Apply a mini-Multi-Scalar-Multiplication on [bitIndex, bitIndex+window)
|
||||
## slice of all (coef, point) pairs
|
||||
|
||||
const excess = bits mod c
|
||||
const top = bits - excess
|
||||
static: doAssert miniMsmKind != kTopWindow, "The top window is smaller in bits which increases collisions in scheduler."
|
||||
|
||||
sched.buckets[].init()
|
||||
|
||||
# 1. Bucket Accumulation
|
||||
var curSP, nextSP: ScheduledPoint
|
||||
|
||||
template getSignedWindow(j : int): tuple[val: SecretWord, neg: SecretBool] =
|
||||
@ -277,13 +287,26 @@ func miniMSM_affine[NumBuckets, QueueLen, F, G; bits: static int](
|
||||
sched.schedule(curSP)
|
||||
sched.flushPendingAndReset()
|
||||
|
||||
func miniMSM_affine[NumBuckets, QueueLen, F, G; bits: static int](
|
||||
r: var ECP_ShortW[F, G],
|
||||
sched: ptr Scheduler[NumBuckets, QueueLen, F, G],
|
||||
bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int,
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], N: int) {.meter.} =
|
||||
## Apply a mini-Multi-Scalar-Multiplication on [bitIndex, bitIndex+window)
|
||||
## slice of all (coef, point) pairs
|
||||
|
||||
sched.buckets[].init()
|
||||
|
||||
# 1. Bucket Accumulation
|
||||
sched.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)
|
||||
|
||||
# 2. Bucket Reduction
|
||||
var sliceSum{.noInit.}: ECP_ShortW_JacExt[F, G]
|
||||
sliceSum.bucketReduce(sched.buckets[])
|
||||
var windowSum_jacext{.noInit.}: ECP_ShortW_JacExt[F, G]
|
||||
windowSum_jacext.bucketReduce(sched.buckets[])
|
||||
|
||||
# 3. Mini-MSM on the slice [bitIndex, bitIndex+window)
|
||||
var windowSum{.noInit.}: typeof(r)
|
||||
windowSum.fromJacobianExtended_vartime(sliceSum)
|
||||
windowSum.fromJacobianExtended_vartime(windowSum_jacext)
|
||||
r += windowSum
|
||||
|
||||
when miniMsmKind != kBottomWindow:
|
||||
@ -303,7 +326,7 @@ func multiScalarMulAffine_vartime[F, G; bits: static int](
|
||||
let buckets = allocHeap(Buckets[numBuckets, F, G])
|
||||
buckets[].init()
|
||||
let sched = allocHeap(Scheduler[numBuckets, queueLen, F, G])
|
||||
sched[].init(points, buckets, 0, numBuckets.int32)
|
||||
sched.init(points, buckets, 0, numBuckets.int32)
|
||||
|
||||
# Algorithm
|
||||
# ---------
|
||||
@ -312,24 +335,86 @@ func multiScalarMulAffine_vartime[F, G; bits: static int](
|
||||
var w = top
|
||||
r.setInf()
|
||||
|
||||
if excess != 0 and w != 0: # Prologue
|
||||
# The top might use only a few bits, the affine scheduler would likely have significant collisions
|
||||
zeroMem(sched.buckets.ptJacExt.addr, buckets.ptJacExt.sizeof())
|
||||
r.miniMSM_jacext(sched.buckets.ptJacExt.asUnchecked(), w, kTopWindow, c, coefs, points, N)
|
||||
when top != 0: # Prologue
|
||||
when excess != 0:
|
||||
# The top might use only a few bits, the affine scheduler would likely have significant collisions
|
||||
zeroMem(sched.buckets.ptJacExt.addr, buckets.ptJacExt.sizeof())
|
||||
r.miniMSM_jacext(sched.buckets.ptJacExt.asUnchecked(), w, kTopWindow, c, coefs, points, N)
|
||||
w -= c
|
||||
else:
|
||||
# If c divides bits exactly, the signed windowed recoding still needs to see an extra 0
|
||||
# Since we did r.setInf() earlier, this is a no-op
|
||||
w -= c
|
||||
|
||||
while w != 0: # Steady state
|
||||
r.miniMSM_affine(sched, w, kFullWindow, c, coefs, N)
|
||||
w -= c
|
||||
|
||||
while w != 0: # Steady state
|
||||
r.miniMSM_affine(sched[], w, kFullWindow, c, coefs, N)
|
||||
w -= c
|
||||
|
||||
block: # Epilogue
|
||||
r.miniMSM_affine(sched[], w, kBottomWindow, c, coefs, N)
|
||||
block: # Epilogue
|
||||
r.miniMSM_affine(sched, w, kBottomWindow, c, coefs, N)
|
||||
|
||||
# Cleanup
|
||||
# -------
|
||||
sched.freeHeap()
|
||||
buckets.freeHeap()
|
||||
|
||||
proc applyEndomorphism[bits: static int, F, G](
|
||||
coefs: ptr UncheckedArray[BigInt[bits]],
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
N: int): auto =
|
||||
## Decompose (coefs, points) into mini-scalars
|
||||
## Returns a new triplet (endoCoefs, endoPoints, N)
|
||||
## endoCoefs and endoPoints MUST be freed afterwards
|
||||
|
||||
const M = when F is Fp: 2
|
||||
elif F is Fp2: 4
|
||||
else: {.error: "Unconfigured".}
|
||||
|
||||
const L = bits.ceilDiv_vartime(M) + 1
|
||||
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
|
||||
let endoBasis = allocHeapArray(array[M, ECP_ShortW_Aff[F, G]], N)
|
||||
|
||||
for i in 0 ..< N:
|
||||
var negatePoints {.noinit.}: array[M, SecretBool]
|
||||
splitCoefs[i].decomposeEndo(negatePoints, coefs[i], F)
|
||||
if negatePoints[0].bool:
|
||||
endoBasis[i][0].neg(points[i])
|
||||
else:
|
||||
endoBasis[i][0] = points[i]
|
||||
|
||||
when F is Fp:
|
||||
endoBasis[i][1].x.prod(points[i].x, F.C.getCubicRootOfUnity_mod_p())
|
||||
if negatePoints[1].bool:
|
||||
endoBasis[i][1].y.neg(points[i].y)
|
||||
else:
|
||||
endoBasis[i][1].y = points[i].y
|
||||
else:
|
||||
staticFor m, 1, M:
|
||||
endoBasis[i][m].frobenius_psi(points[i], m)
|
||||
if negatePoints[m].bool:
|
||||
endoBasis[i][m].neg()
|
||||
|
||||
let endoCoefs = cast[ptr UncheckedArray[BigInt[L]]](splitCoefs)
|
||||
let endoPoints = cast[ptr UncheckedArray[ECP_ShortW_Aff[F, G]]](endoBasis)
|
||||
|
||||
return (endoCoefs, endoPoints, M*N)
|
||||
|
||||
template withEndo[bits: static int, F, G](
|
||||
msmProc: untyped,
|
||||
r: var ECP_ShortW[F, G],
|
||||
coefs: ptr UncheckedArray[BigInt[bits]],
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
N: int, c: static int) =
|
||||
when bits <= F.C.getCurveOrderBitwidth() and hasEndomorphismAcceleration(F.C):
|
||||
let (endoCoefs, endoPoints, endoN) = applyEndomorphism(coefs, points, N)
|
||||
# Given that bits and N changed, we are able to use a bigger `c`
|
||||
# but it has no significant impact on performance
|
||||
msmProc(r, endoCoefs, endoPoints, endoN, c)
|
||||
freeHeap(endoCoefs)
|
||||
freeHeap(endoPoints)
|
||||
else:
|
||||
msmProc(r, coefs, points, N, c)
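# Hedged numeric instance (not from the source): on a curve with a 255-bit order and GLV
# acceleration over Fp, M = 2 and L = ceilDiv(255, 2) + 1 = 129, so withEndo trades an MSM
# of N points with 255-bit coefficients for one of 2*N points with 129-bit mini-scalars,
# roughly halving the number of windows to process.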
|
||||
|
||||
func multiScalarMul_dispatch_vartime[bits: static int, F, G](
|
||||
r: var ECP_ShortW[F, G], coefs: ptr UncheckedArray[BigInt[bits]],
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) =
|
||||
@ -337,24 +422,29 @@ func multiScalarMul_dispatch_vartime[bits: static int, F, G](
|
||||
## r <- [a₀]P₀ + [a₁]P₁ + ... + [aₙ]Pₙ
|
||||
let c = bestBucketBitSize(N, bits, useSignedBuckets = true, useManualTuning = true)
|
||||
|
||||
# Given that bits and N change after applying an endomorphism,
|
||||
# we are able to use a bigger `c`
|
||||
# but it has no significant impact on performance
|
||||
|
||||
case c
|
||||
of 2: multiScalarMulJacExt_vartime(r, coefs, points, N, c = 2)
|
||||
of 3: multiScalarMulJacExt_vartime(r, coefs, points, N, c = 3)
|
||||
of 4: multiScalarMulJacExt_vartime(r, coefs, points, N, c = 4)
|
||||
of 5: multiScalarMulJacExt_vartime(r, coefs, points, N, c = 5)
|
||||
of 6: multiScalarMulJacExt_vartime(r, coefs, points, N, c = 6)
|
||||
of 7: multiScalarMulJacExt_vartime(r, coefs, points, N, c = 7)
|
||||
of 8: multiScalarMulJacExt_vartime(r, coefs, points, N, c = 8)
|
||||
of 9: multiScalarMulAffine_vartime(r, coefs, points, N, c = 9)
|
||||
of 10: multiScalarMulAffine_vartime(r, coefs, points, N, c = 10)
|
||||
of 11: multiScalarMulAffine_vartime(r, coefs, points, N, c = 11)
|
||||
of 12: multiScalarMulAffine_vartime(r, coefs, points, N, c = 12)
|
||||
of 13: multiScalarMulAffine_vartime(r, coefs, points, N, c = 13)
|
||||
of 14: multiScalarMulAffine_vartime(r, coefs, points, N, c = 14)
|
||||
of 15: multiScalarMulAffine_vartime(r, coefs, points, N, c = 15)
|
||||
of 16: multiScalarMulAffine_vartime(r, coefs, points, N, c = 16)
|
||||
of 17: multiScalarMulAffine_vartime(r, coefs, points, N, c = 17)
|
||||
of 18: multiScalarMulAffine_vartime(r, coefs, points, N, c = 18)
|
||||
of 2: withEndo(multiScalarMulJacExt_vartime, r, coefs, points, N, c = 2)
|
||||
of 3: withEndo(multiScalarMulJacExt_vartime, r, coefs, points, N, c = 3)
|
||||
of 4: withEndo(multiScalarMulJacExt_vartime, r, coefs, points, N, c = 4)
|
||||
of 5: withEndo(multiScalarMulJacExt_vartime, r, coefs, points, N, c = 5)
|
||||
of 6: withEndo(multiScalarMulJacExt_vartime, r, coefs, points, N, c = 6)
|
||||
of 7: withEndo(multiScalarMulJacExt_vartime, r, coefs, points, N, c = 7)
|
||||
of 8: withEndo(multiScalarMulJacExt_vartime, r, coefs, points, N, c = 8)
|
||||
|
||||
of 9: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 9)
|
||||
of 10: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 10)
|
||||
of 11: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 11)
|
||||
of 12: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 12)
|
||||
of 13: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 13)
|
||||
of 14: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 14)
|
||||
of 15: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 15)
|
||||
of 16: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 16)
|
||||
of 17: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 17)
|
||||
of 18: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 18)
|
||||
else:
|
||||
unreachable()
|
||||
|
||||
@ -368,44 +458,4 @@ func multiScalarMul_vartime*[bits: static int, F, G](
|
||||
debug: doAssert coefs.len == points.len
|
||||
let N = points.len
|
||||
|
||||
when bits <= F.C.getCurveOrderBitwidth() and
|
||||
F.C.hasEndomorphismAcceleration():
|
||||
# TODO, min amount of bits for endomorphisms?
|
||||
|
||||
const M = when F is Fp: 2
|
||||
elif F is Fp2: 4
|
||||
else: {.error: "Unconfigured".}
|
||||
|
||||
const L = bits.ceilDiv_vartime(M) + 1
|
||||
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
|
||||
let endoBasis = allocHeapArray(array[M, ECP_ShortW_Aff[F, G]], N)
|
||||
|
||||
for i in 0 ..< N:
|
||||
var negatePoints {.noinit.}: array[M, SecretBool]
|
||||
splitCoefs[i].decomposeEndo(negatePoints, coefs[i], F)
|
||||
if negatePoints[0].bool:
|
||||
endoBasis[i][0].neg(points[i])
|
||||
else:
|
||||
endoBasis[i][0] = points[i]
|
||||
|
||||
when F is Fp:
|
||||
endoBasis[i][1].x.prod(points[i].x, F.C.getCubicRootOfUnity_mod_p())
|
||||
if negatePoints[1].bool:
|
||||
endoBasis[i][1].y.neg(points[i].y)
|
||||
else:
|
||||
endoBasis[i][1].y = points[i].y
|
||||
else:
|
||||
staticFor m, 1, M:
|
||||
endoBasis[i][m].frobenius_psi(points[i], m)
|
||||
if negatePoints[m].bool:
|
||||
endoBasis[i][m].neg()
|
||||
|
||||
let endoCoefs = cast[ptr UncheckedArray[BigInt[L]]](splitCoefs)
|
||||
let endoPoints = cast[ptr UncheckedArray[ECP_ShortW_Aff[F, G]]](endoBasis)
|
||||
multiScalarMul_dispatch_vartime(r, endoCoefs, endoPoints, M*N)
|
||||
|
||||
endoBasis.freeHeap()
|
||||
splitCoefs.freeHeap()
|
||||
|
||||
else:
|
||||
multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)
|
||||
multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)
|
||||
477
constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim
Normal file
477
constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim
Normal file
@ -0,0 +1,477 @@
|
||||
# Constantine
|
||||
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import ./ec_multi_scalar_mul_scheduler,
|
||||
./ec_multi_scalar_mul,
|
||||
./ec_endomorphism_accel,
|
||||
../extension_fields,
|
||||
../constants/zoo_endomorphisms,
|
||||
../../platforms/threadpool/threadpool
|
||||
export bestBucketBitSize
|
||||
|
||||
# No exceptions allowed in core cryptographic operations
|
||||
{.push raises: [].}
|
||||
{.push checks: off.}
|
||||
|
||||
# ########################################################### #
|
||||
# #
|
||||
# Parallel Multi Scalar Multiplication #
|
||||
# #
|
||||
# ########################################################### #
|
||||
#
|
||||
# Writeup
|
||||
#
|
||||
# Recall the reference implementation in pseudocode
|
||||
#
|
||||
# func multiScalarMulImpl_reference_vartime
|
||||
#
|
||||
# c <- fn(numPoints) with `fn` a function that minimizes the total number of Elliptic Curve additions
|
||||
# in the order of log2(numPoints) - 3
|
||||
# numWindows <- ⌈coefBits/c⌉
|
||||
# numBuckets <- 2ᶜ⁻¹
|
||||
# r <- ∅ (The elliptic curve infinity point)
|
||||
#
|
||||
# miniMSMs[0..<numWindows] <- ∅
|
||||
#
|
||||
# // 0.a MiniMSMs accumulation
|
||||
# for w in 0 ..< numWindows:
|
||||
#
|
||||
# // 1.a Bucket accumulation
|
||||
# buckets[0..<numBuckets] <- ∅
|
||||
# for j in 0 ..< numPoints:
|
||||
# b <- coefs[j].getWindowAt(w*c, c)
|
||||
# buckets[b] += points[j]
|
||||
#
|
||||
# // 1.r Bucket reduction
|
||||
# accumBuckets <- ∅
|
||||
# for k in countdown(numBuckets-1, 0):
|
||||
# accumBuckets += buckets[k]
|
||||
# miniMSMs[w] += accumBuckets
|
||||
#
|
||||
# // 0.r MiniMSM reduction
|
||||
# for w in countdown(numWindows-1, 0):
|
||||
# for _ in 0 ..< c:
|
||||
# r.double()
|
||||
# r += miniMSMs[w]
|
||||
#
|
||||
# return r
|
||||
#
|
||||
# A comprehensive mapping: inputSize, c, numBuckets is in ec_multi_scalar_mul_scheduler.nim
|
||||
#
|
||||
# -------inputs------- c ----buckets---- queue length collision map bytes num collisions collision %
|
||||
# 2^5 32 5 2^4 16 -108 8 -216 -675.0%
|
||||
# 2^10 1024 9 2^8 256 52 32 208 20.3%
|
||||
# 2^13 8192 11 2^10 1024 180 128 1440 17.6%
|
||||
# 2^16 65536 14 2^13 8192 432 1024 3456 5.3%
|
||||
# 2^18 262144 16 2^15 32768 640 4096 5120 2.0% <- ~10MB of buckets
|
||||
# 2^20 1048576 17 2^16 65536 756 8192 12096 1.2%
|
||||
# 2^26 67108864 23 2^22 4194304 1620 524288 25920 0.0%
|
||||
#
|
||||
# The coef bits is usually between 128 to 377 depending on endomorphisms and the elliptic curve.
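# A hedged numeric reading of the c = 14 row above (assuming 255-bit coefficients):
#   numWindows = ceil(255 / 14) = 19 mini-MSMs
#   numBuckets = 2^(14-1)       = 8192 signed buckets per window
# so bucket accumulation costs about numPoints additions per window, and each bucket
# reduction costs on the order of 2*8192 additions.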
|
||||
#
|
||||
# Starting from 64 points, parallelism seems to always be beneficial (serial takes over 1 ms on laptop)
|
||||
#
|
||||
# There are 3 parallelism opportunities:
# - 0.a MiniMSMs accumulation, a.k.a. "window-level parallelism",
# is straightforward as there are no data dependencies at all.
# - 1.a Bucket accumulation, a.k.a. "bucket-level parallelism".
# Bucket accumulation needs to be parallelized over buckets and not over points to avoid synchronization between threads.
# The disadvantage is that all threads scan all the points.
# - and doing separate MSMs over parts of the points, a.k.a. "msm-level parallelism".
# As the number of points grows, the cost of scalar-mul per point diminishes at the rate O(n/log n), as we can increase the window size `c`
# to reduce the number of operations. However, when `c` reaches 16, memory bandwidth becomes another bottleneck,
# hence parallelizing at this level becomes interesting.
|
||||
#
|
||||
# We can also parallelize the reductions but they would require extra doublings to "place the reduction" at the right bits.
|
||||
# Example:
|
||||
# let's say we compute the binary number 0b11010110
|
||||
# Each 1 is add+double, each 0 is just double.
|
||||
# We can split the computation in parallel 0b1101 << 4 and 0b0110
|
||||
# but now we need 4 extra doublings to shift the high part in the correct place.
|
||||
# Alternatively we can do "latency hiding", we start the computation before all results are available, and wait for the next part to finish.
|
||||
#
|
||||
# Now, with a small c, say 1024 inputs, c=9, the window-level parallelism is large: 14 to 28 for 128-bit to 256-bit coefs.
|
||||
# For large c, say 262k inputs, c=16, the window-level parallelism is small: 8 to 16 for 128-bit to 256-bit coefs.
|
||||
#
|
||||
# Zero Knowledge protocols need to operate on millions of points, so we want to fully occupy high-end CPUs
|
||||
# - AMD EPYC 9654 96C/192T on 2 sockets hence 384 threads
|
||||
# - Intel Xeon Platinum 8490H 60C/120T on 8 sockets hence 960 threads
|
||||
#
|
||||
# Bucket-level parallelism has a multiplicative factor on the parallelism exposed.
# Impact, in order of importance, of a high chunking factor:
# + the more parallelism opportunities we offer.
# - the more collisions we have when setting up sparse vector affine addition.
# + the more fine-grained the bucket reduction can be interleaved.
# - the more passes over the data there are.
# - the more memory we use.
|
||||
#
|
||||
# Hence do we want inner parallelism to be
|
||||
# - a fraction of the number of threads?
|
||||
# - a multiple of the number of threads?
|
||||
# - a fraction of the number of buckets?
|
||||
#
|
||||
# Back to latency hiding,
|
||||
# 1. The reductions can be done bottom bits to top bits or top bits to bottom bits.
|
||||
# Often bits/c has remainder top bits that are smaller so top to bottom would allow
|
||||
# to start the reduction while the rest of the windows are still being processed.
|
||||
# 2. a thread processes its own tasks in a LIFO manner.
|
||||
# But thieves steal tasks in FIFO manner.
|
||||
#
|
||||
# Due to 1, it's probably best to reduce in a staggered manner from top to bottom.
# But then in which order do we issue accumulations?
# 1. Do we schedule the top bits first, in the hope they will be stolen (FIFO thefts),
# 2. or do we schedule the top bits last, so that once we reduce, we directly schedule a related task (LIFO dequeueing)?
|
||||
#
|
||||
# Lastly we can go further on latency hiding for the bucket-level parallelism,
|
||||
# having decreasing range sizes so that the top ranges are ready earlier for interleaving reduction.
|
||||
|
||||
# Parallel MSM Jacobian Extended
|
||||
# ------------------------------
|
||||
|
||||
proc bucketAccumReduce_jacext_zeroMem[F, G; bits: static int](
|
||||
windowSum: ptr ECP_ShortW[F, G],
|
||||
buckets: ptr ECP_ShortW_JacExt[F, G] or ptr UncheckedArray[ECP_ShortW_JacExt[F, G]],
|
||||
bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int,
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) =
|
||||
const numBuckets = 1 shl (c-1)
|
||||
let buckets = cast[ptr UncheckedArray[ECP_ShortW_JacExt[F, G]]](buckets)
|
||||
zeroMem(buckets, sizeof(ECP_ShortW_JacExt[F, G]) * numBuckets)
|
||||
bucketAccumReduce_jacext(windowSum[], buckets, bitIndex, miniMsmKind, c, coefs, points, N)
|
||||
|
||||
proc msmJacExt_vartime_parallel*[bits: static int, F, G](
|
||||
tp: Threadpool,
|
||||
r: var ECP_ShortW[F, G],
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
N: int, c: static int) =
|
||||
|
||||
# Prologue
|
||||
# --------
|
||||
const numBuckets = 1 shl (c-1)
|
||||
const numFullWindows = bits div c
|
||||
const numWindows = numFullWindows + 1 # Even if `bits div c` is exact, the signed recoding needs to see an extra 0 after the MSB
|
||||
|
||||
# Instead of storing the result in futures, risking them being scattered in memory
|
||||
# we store them in a contiguous array, and the synchronizing future just returns a bool.
|
||||
# top window is done on this thread
|
||||
type EC = typeof(r)
|
||||
let miniMSMsResults = allocHeapArray(EC, numFullWindows)
|
||||
let miniMSMsReady = allocStackArray(FlowVar[bool], numFullWindows)
|
||||
|
||||
let bucketsMatrix = allocHeapArray(ECP_ShortW_JacExt[F, G], numBuckets*numWindows)
|
||||
|
||||
# Algorithm
|
||||
# ---------
|
||||
|
||||
block: # 1. Bucket accumulation and reduction
|
||||
miniMSMsReady[0] = tp.spawnAwaitable bucketAccumReduce_jacext_zeroMem(
|
||||
miniMSMsResults[0].addr,
|
||||
bucketsMatrix[0].addr,
|
||||
bitIndex = 0, kBottomWindow, c,
|
||||
coefs, points, N)
|
||||
|
||||
for w in 1 ..< numFullWindows:
|
||||
miniMSMsReady[w] = tp.spawnAwaitable bucketAccumReduce_jacext_zeroMem(
|
||||
miniMSMsResults[w].addr,
|
||||
bucketsMatrix[w*numBuckets].addr,
|
||||
bitIndex = w*c, kFullWindow, c,
|
||||
coefs, points, N)
|
||||
|
||||
# Last window is done sync on this thread, directly initializing r
|
||||
const excess = bits mod c
|
||||
const top = bits-excess
|
||||
|
||||
when top != 0:
|
||||
when excess != 0:
|
||||
bucketAccumReduce_jacext_zeroMem(
|
||||
r.addr,
|
||||
bucketsMatrix[numFullWindows*numBuckets].addr,
|
||||
bitIndex = top, kTopWindow, c,
|
||||
coefs, points, N)
|
||||
else:
|
||||
r.setInf()
|
||||
|
||||
# 3. Final reduction, r initialized to what would be miniMSMsReady[numWindows-1]
|
||||
when excess != 0:
|
||||
for w in countdown(numWindows-2, 0):
|
||||
for _ in 0 ..< c:
|
||||
r.double()
|
||||
discard sync miniMSMsReady[w]
|
||||
r += miniMSMsResults[w]
|
||||
elif numWindows >= 2:
|
||||
discard sync miniMSMsReady[numWindows-2]
|
||||
r = miniMSMsResults[numWindows-2]
|
||||
for w in countdown(numWindows-3, 0):
|
||||
for _ in 0 ..< c:
|
||||
r.double()
|
||||
discard sync miniMSMsReady[w]
|
||||
r += miniMSMsResults[w]
|
||||
|
||||
# Cleanup
|
||||
# -------
|
||||
miniMSMsResults.freeHeap()
|
||||
bucketsMatrix.freeHeap()
|
||||
|
||||
# Parallel MSM Affine
|
||||
# ------------------------------
|
||||
|
||||
|
||||
proc bucketAccumReduce_parallel[bits: static int, F, G](
|
||||
tp: Threadpool,
|
||||
r: ptr ECP_ShortW[F, G],
|
||||
bitIndex: int,
|
||||
miniMsmKind: static MiniMsmKind, c: static int,
|
||||
coefs: ptr UncheckedArray[BigInt[bits]],
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
N: int) =
|
||||
|
||||
const (numBuckets, queueLen) = c.deriveSchedulerConstants()
|
||||
const outerParallelism = bits div c # It's actually ceilDiv instead of floorDiv, but the last iteration might be too small
|
||||
|
||||
var innerParallelism = 1'i32
|
||||
while outerParallelism*innerParallelism < tp.numThreads:
|
||||
innerParallelism = innerParallelism shl 1
|
||||
|
||||
let numChunks = 1'i32 # innerParallelism # TODO: unfortunately trying to expose more parallelism slows down the performance
|
||||
let chunkSize = int32(numBuckets) shr log2_vartime(cast[uint32](numChunks)) # Both are power of 2 so exact division
|
||||
let chunksReadiness = allocStackArray(FlowVar[bool], numChunks-1) # Last chunk is done on this thread
|
||||
|
||||
let buckets = allocHeap(Buckets[numBuckets, F, G])
|
||||
let scheds = allocHeapArray(Scheduler[numBuckets, queueLen, F, G], numChunks)
|
||||
|
||||
buckets[].init()
|
||||
|
||||
block: # 1. Bucket Accumulation
|
||||
for chunkID in 0'i32 ..< numChunks-1:
|
||||
let idx = chunkID*chunkSize
|
||||
scheds[chunkID].addr.init(points, buckets, idx, idx+chunkSize)
|
||||
chunksReadiness[chunkID] = tp.spawnAwaitable schedAccumulate(scheds[chunkID].addr, bitIndex, miniMsmKind, c, coefs, N)
|
||||
# Last bucket is done sync on this thread
|
||||
scheds[numChunks-1].addr.init(points, buckets, (numChunks-1)*chunkSize, int32 numBuckets)
|
||||
scheds[numChunks-1].addr.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)
|
||||
|
||||
block: # 2. Bucket reduction
|
||||
var windowSum{.noInit.}: ECP_ShortW_JacExt[F, G]
|
||||
var accumBuckets{.noinit.}: ECP_ShortW_JacExt[F, G]
|
||||
|
||||
if kAffine in buckets.status[numBuckets-1]:
|
||||
if kJacExt in buckets.status[numBuckets-1]:
|
||||
accumBuckets.madd_vartime(buckets.ptJacExt[numBuckets-1], buckets.ptAff[numBuckets-1])
|
||||
else:
|
||||
accumBuckets.fromAffine(buckets.ptAff[numBuckets-1])
|
||||
elif kJacExt in buckets.status[numBuckets-1]:
|
||||
accumBuckets = buckets.ptJacExt[numBuckets-1]
|
||||
else:
|
||||
accumBuckets.setInf()
|
||||
windowSum = accumBuckets
|
||||
buckets[].reset(numBuckets-1)
|
||||
|
||||
var nextBatch = numBuckets-1-chunkSize
|
||||
var nextFutureIdx = numChunks-2
|
||||
|
||||
for k in countdown(numBuckets-2, 0):
|
||||
if k == nextBatch:
|
||||
discard sync(chunksReadiness[nextFutureIdx])
|
||||
nextBatch -= chunkSize
|
||||
nextFutureIdx -= 1
|
||||
|
||||
if kAffine in buckets.status[k]:
|
||||
if kJacExt in buckets.status[k]:
|
||||
var t{.noInit.}: ECP_ShortW_JacExt[F, G]
|
||||
t.madd_vartime(buckets.ptJacExt[k], buckets.ptAff[k])
|
||||
accumBuckets += t
|
||||
else:
|
||||
accumBuckets += buckets.ptAff[k]
|
||||
elif kJacExt in buckets.status[k]:
|
||||
accumBuckets += buckets.ptJacExt[k]
|
||||
|
||||
buckets[].reset(k)
|
||||
windowSum += accumBuckets
|
||||
|
||||
r[].fromJacobianExtended_vartime(windowSum)
|
||||
|
||||
# Cleanup
|
||||
# ----------------
|
||||
scheds.freeHeap()
|
||||
buckets.freeHeap()
|
||||
|
||||
proc msmAffine_vartime_parallel*[bits: static int, F, G](
|
||||
tp: Threadpool,
|
||||
r: var ECP_ShortW[F, G],
|
||||
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
N: int, c: static int) =
|
||||
|
||||
# Prologue
|
||||
# --------
|
||||
const numBuckets = 1 shl (c-1)
|
||||
const numFullWindows = bits div c
|
||||
const numWindows = numFullWindows + 1 # Even if `bits div c` is exact, the signed recoding needs to see an extra 0 after the MSB
|
||||
|
||||
# Instead of storing the result in futures, risking them being scattered in memory
|
||||
# we store them in a contiguous array, and the synchronizing future just returns a bool.
|
||||
# top window is done on this thread
|
||||
type EC = typeof(r)
|
||||
let miniMSMsResults = allocHeapArray(EC, numFullWindows)
|
||||
let miniMSMsReady = allocStackArray(Flowvar[bool], numFullWindows)
|
||||
|
||||
# Algorithm
|
||||
# ---------
|
||||
|
||||
block: # 1. Bucket accumulation and reduction
|
||||
miniMSMsReady[0] = tp.spawnAwaitable bucketAccumReduce_parallel(
|
||||
tp, miniMSMsResults[0].addr,
|
||||
bitIndex = 0, kBottomWindow, c,
|
||||
coefs, points, N)
|
||||
|
||||
for w in 1 ..< numFullWindows:
|
||||
miniMSMsReady[w] = tp.spawnAwaitable bucketAccumReduce_parallel(
|
||||
tp, miniMSMsResults[w].addr,
|
||||
bitIndex = w*c, kFullWindow, c,
|
||||
coefs, points, N)
|
||||
|
||||
# Last window is done sync on this thread, directly initializing r
|
||||
const excess = bits mod c
|
||||
const top = bits-excess
|
||||
|
||||
when top != 0:
|
||||
when excess != 0:
|
||||
let buckets = allocHeapArray(ECP_ShortW_JacExt[F, G], numBuckets)
|
||||
zeroMem(buckets[0].addr, sizeof(ECP_ShortW_JacExt[F, G]) * numBuckets)
|
||||
r.bucketAccumReduce_jacext(buckets, bitIndex = top, kTopWindow, c,
|
||||
coefs, points, N)
|
||||
buckets.freeHeap()
|
||||
else:
|
||||
r.setInf()
|
||||
|
||||
# 3. Final reduction, r initialized to what would be miniMSMsReady[numWindows-1]
|
||||
when excess != 0:
|
||||
for w in countdown(numWindows-2, 0):
|
||||
for _ in 0 ..< c:
|
||||
r.double()
|
||||
discard sync miniMSMsReady[w]
|
||||
r += miniMSMsResults[w]
|
||||
elif numWindows >= 2:
|
||||
discard sync miniMSMsReady[numWindows-2]
|
||||
r = miniMSMsResults[numWindows-2]
|
||||
for w in countdown(numWindows-3, 0):
|
||||
for _ in 0 ..< c:
|
||||
r.double()
|
||||
discard sync miniMSMsReady[w]
|
||||
r += miniMSMsResults[w]
|
||||
|
||||
# Cleanup
|
||||
# -------
|
||||
miniMSMsResults.freeHeap()
|
||||
|
||||
proc applyEndomorphism_parallel[bits: static int, F, G](
|
||||
tp: Threadpool,
|
||||
coefs: ptr UncheckedArray[BigInt[bits]],
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
N: int): auto =
|
||||
## Decompose (coefs, points) into mini-scalars
|
||||
## Returns a new triplet (endoCoefs, endoPoints, N)
|
||||
## endoCoefs and endoPoints MUST be freed afterwards
|
||||
|
||||
const M = when F is Fp: 2
|
||||
elif F is Fp2: 4
|
||||
else: {.error: "Unconfigured".}
|
||||
|
||||
const L = bits.ceilDiv_vartime(M) + 1
|
||||
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
|
||||
let endoBasis = allocHeapArray(array[M, ECP_ShortW_Aff[F, G]], N)
|
||||
|
||||
tp.parallelFor i in 0 ..< N:
|
||||
captures: {coefs, points, splitCoefs, endoBasis}
|
||||
|
||||
var negatePoints {.noinit.}: array[M, SecretBool]
|
||||
splitCoefs[i].decomposeEndo(negatePoints, coefs[i], F)
|
||||
if negatePoints[0].bool:
|
||||
endoBasis[i][0].neg(points[i])
|
||||
else:
|
||||
endoBasis[i][0] = points[i]
|
||||
|
||||
when F is Fp:
|
||||
endoBasis[i][1].x.prod(points[i].x, F.C.getCubicRootOfUnity_mod_p())
|
||||
if negatePoints[1].bool:
|
||||
endoBasis[i][1].y.neg(points[i].y)
|
||||
else:
|
||||
endoBasis[i][1].y = points[i].y
|
||||
else:
|
||||
staticFor m, 1, M:
|
||||
endoBasis[i][m].frobenius_psi(points[i], m)
|
||||
if negatePoints[m].bool:
|
||||
endoBasis[i][m].neg()
|
||||
|
||||
tp.syncAll()
|
||||
|
||||
let endoCoefs = cast[ptr UncheckedArray[BigInt[L]]](splitCoefs)
|
||||
let endoPoints = cast[ptr UncheckedArray[ECP_ShortW_Aff[F, G]]](endoBasis)
|
||||
|
||||
return (endoCoefs, endoPoints, M*N)
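# Hedged sketch (illustrative, not part of this file) of the tp.parallelFor + captures
# pattern used above, applied to a plain array:
#
#   proc scaleAll(tp: Threadpool, data: ptr UncheckedArray[int], n: int) =
#     tp.parallelFor i in 0 ..< n:
#       captures: {data}
#       data[i] = 2*data[i]
#     tp.syncAll()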
|
||||
|
||||
template withEndo[bits: static int, F, G](
|
||||
msmProc: untyped,
|
||||
tp: Threadpool,
|
||||
r: var ECP_ShortW[F, G],
|
||||
coefs: ptr UncheckedArray[BigInt[bits]],
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
N: int, c: static int) =
|
||||
when bits <= F.C.getCurveOrderBitwidth() and hasEndomorphismAcceleration(F.C):
|
||||
let (endoCoefs, endoPoints, endoN) = applyEndomorphism_parallel(tp, coefs, points, N)
|
||||
# Given that bits and N changed, we are able to use a bigger `c`
|
||||
# but it has no significant impact on performance
|
||||
msmProc(tp, r, endoCoefs, endoPoints, endoN, c)
|
||||
freeHeap(endoCoefs)
|
||||
freeHeap(endoPoints)
|
||||
else:
|
||||
msmProc(tp, r, coefs, points, N, c)
|
||||
|
||||
proc multiScalarMul_dispatch_vartime_parallel[bits: static int, F, G](
|
||||
tp: Threadpool,
|
||||
r: var ECP_ShortW[F, G], coefs: ptr UncheckedArray[BigInt[bits]],
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) =
|
||||
## Multiscalar multiplication:
|
||||
## r <- [a₀]P₀ + [a₁]P₁ + ... + [aₙ]Pₙ
|
||||
let c = bestBucketBitSize(N, bits, useSignedBuckets = true, useManualTuning = true)
|
||||
|
||||
# Given that bits and N change after applying an endomorphism,
|
||||
# we are able to use a bigger `c`
|
||||
# but it has no significant impact on performance
|
||||
|
||||
case c
|
||||
of 2: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 2)
|
||||
of 3: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 3)
|
||||
of 4: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 4)
|
||||
of 5: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 5)
|
||||
of 6: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 6)
|
||||
of 7: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 7)
|
||||
of 8: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 8)
|
||||
of 9: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 9)
|
||||
of 10: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 10)
|
||||
|
||||
of 11: withEndo(msmAffine_vartime_parallel, tp, r, coefs, points, N, c = 11)
|
||||
|
||||
of 12: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 12)
|
||||
of 13: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 13)
|
||||
of 14: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 14)
|
||||
of 15: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 15)
|
||||
of 16: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 16)
|
||||
of 17: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 17)
|
||||
of 18: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 18)
|
||||
else:
|
||||
unreachable()
|
||||
|
||||
proc multiScalarMul_vartime_parallel*[bits: static int, F, G](
|
||||
tp: Threadpool,
|
||||
r: var ECP_ShortW[F, G],
|
||||
coefs: openArray[BigInt[bits]],
|
||||
points: openArray[ECP_ShortW_Aff[F, G]]) {.meter, inline.} =
|
||||
|
||||
debug: doAssert coefs.len == points.len
|
||||
let N = points.len
|
||||
|
||||
tp.multiScalarMul_dispatch_vartime_parallel(r, coefs.asUnchecked(), points.asUnchecked(), N)
|
||||
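# For reference, a minimal end-to-end usage sketch of the parallel MSM entry point above.
# The curve, the 255-bit scalar width, the module paths and the RNG helper are illustrative
# assumptions mirroring Constantine's benchmark setup; they are not part of this diff.
import
  constantine/math/config/curves,
  constantine/math/arithmetic,
  constantine/math/elliptic/[ec_shortweierstrass_affine, ec_shortweierstrass_projective,
                             ec_multi_scalar_mul_parallel],   # adjust paths to your layout
  constantine/platforms/threadpool/threadpool,
  helpers/prng_unsafe

proc msmDemo() =
  const N = 1 shl 14                       # 16384 (scalar, point) pairs
  var rng: RngState
  rng.seed(0xDECAF)

  var coefs = newSeq[BigInt[255]](N)       # BLS12-381 scalar field is 255 bits
  var points = newSeq[ECP_ShortW_Aff[Fp[BLS12_381], G1]](N)
  for i in 0 ..< N:
    coefs[i] = rng.random_unsafe(BigInt[255])
    points[i] = rng.random_unsafe(ECP_ShortW_Aff[Fp[BLS12_381], G1])

  var r {.noInit.}: ECP_ShortW_Prj[Fp[BLS12_381], G1]
  let tp = Threadpool.new()                # defaults to one worker per hardware thread
  tp.multiScalarMul_vartime_parallel(r, coefs, points)
  tp.shutdown()

msmDemo()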
@ -173,11 +173,11 @@ func bestBucketBitSize*(inputSize: int, scalarBitwidth: static int, useSignedBuc
  # 1. Bucket accumulation
  #      n - (2ᶜ-1) additions for b/c windows or n - (2ᶜ⁻¹-1) if using signed buckets
  # 2. Bucket reduction
  #      2x(2ᶜ-2) additions for b/c windows or 2x(2ᶜ⁻¹-2)
  #      2x(2ᶜ-2) additions for b/c windows or 2*(2ᶜ⁻¹-2)
  # 3. Final reduction
  #      (b/c - 1) x (c doublings + 1 addition)
  # Total
  #      b/c (n + 2ᶜ - 2) A + (b/c - 1) x (c*D + A)
  #      b/c (n + 2ᶜ - 2) A + (b/c - 1) * (c*D + A)
  # https://www.youtube.com/watch?v=Bl5mQA7UL2I

  # A doubling costs 50% of an addition with jacobian coordinates
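# To make the cost model above concrete, here is a small self-contained sketch
# (an illustration only, not the library's `bestBucketBitSize`) that evaluates
# ⌈b/c⌉·(n + 2ᶜ⁻¹ - 2)·A + (⌈b/c⌉ - 1)·(c·D + A) for candidate window sizes,
# with signed buckets and a doubling costed at half an addition:
proc estimateBestC(n, b: int): int =
  var bestCost = high(float64)
  for c in 2 .. 20:
    let numWindows = (b + c - 1) div c              # ⌈b/c⌉
    let numBuckets = 1 shl (c - 1)                  # signed buckets: 2ᶜ⁻¹
    let accumulation = float64(numWindows * (n + numBuckets - 2))
    let finalReduction = float64(numWindows - 1) * (0.5 * float64(c) + 1.0)
    let cost = accumulation + finalReduction
    if cost < bestCost:
      bestCost = cost
      result = c

echo estimateBestC(n = 1 shl 16, b = 255)  # around 13 for 2¹⁶ points on a 255-bit curve, before manual tuning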
@ -234,15 +234,16 @@ func `-=`*[F; G: static Subgroup](P: var ECP_ShortW_JacExt[F, G], Q: ECP_ShortW_
# "Sharpening the axe will not delay cutting the wood" - Chinese proverb

type
  BucketStatus = enum
  BucketStatus* = enum
    kAffine, kJacExt

  Buckets*[N: static int, F; G: static Subgroup] = object
    status: array[N, set[BucketStatus]]
    ptAff: array[N, ECP_ShortW_Aff[F, G]]
    status*: array[N, set[BucketStatus]]
    ptAff*: array[N, ECP_ShortW_Aff[F, G]]
    ptJacExt*: array[N, ECP_ShortW_JacExt[F, G]] # Public for the top window

  ScheduledPoint* = object
    # Note: we cannot compute the size at compile-time due to https://github.com/nim-lang/Nim/issues/19040
    bucket {.bitsize:26.}: int64 # Supports up to 2²⁵ = 33 554 432 buckets and -1 for the skipped bucket 0
    sign {.bitsize: 1.}: int64
    pointID {.bitsize:37.}: int64 # Supports up to 2³⁷ = 137 438 953 472 points
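# The three bitfields above are sized to pack into a single 64-bit word (26 + 1 + 37 = 64),
# so a ScheduledPoint is intended to move through the scheduler as one machine word.
# Quick illustrative check of the advertised capacities:
static:
  doAssert 26 + 1 + 37 == 64
  doAssert 1 shl 25 == 33_554_432                   # bucket capacity, plus -1 for the skipped bucket 0
  doAssert 1'i64 shl 37 == 137_438_953_472'i64      # point capacity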
@ -261,7 +262,7 @@ const MinVectorAddThreshold = 32
|
||||
func init*(buckets: var Buckets) {.inline.} =
|
||||
zeroMem(buckets.status.addr, buckets.status.sizeof())
|
||||
|
||||
func reset(buckets: var Buckets, index: int) {.inline.} =
|
||||
func reset*(buckets: var Buckets, index: int) {.inline.} =
|
||||
buckets.status[index] = {}
|
||||
|
||||
func deriveSchedulerConstants*(c: int): tuple[numNZBuckets, queueLen: int] {.compileTime.} =
|
||||
@ -270,7 +271,7 @@ func deriveSchedulerConstants*(c: int): tuple[numNZBuckets, queueLen: int] {.com
|
||||
result.queueLen = max(MinVectorAddThreshold, 4*c*c - 16*c - 128)
|
||||
|
||||
func init*[NumNZBuckets, QueueLen: static int, F; G: static Subgroup](
|
||||
sched: var Scheduler[NumNZBuckets, QueueLen, F, G], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
sched: ptr Scheduler[NumNZBuckets, QueueLen, F, G], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
buckets: ptr Buckets[NumNZBuckets, F, G], start, stopEx: int32) {.inline.} =
|
||||
## init a scheduler overseeing buckets [start, stopEx)
|
||||
## within the indices [0, NumNZBuckets). Bucket for value 0 is considered at index -1.
|
||||
@ -287,13 +288,13 @@ func scheduledPointDescriptor*(pointIndex: int, pointDesc: tuple[val: SecretWord
|
||||
sign: cast[int64](pointDesc.neg),
|
||||
pointID: cast[int64](pointIndex))
|
||||
|
||||
func enqueuePoint(sched: var Scheduler, sp: ScheduledPoint) {.inline.} =
|
||||
func enqueuePoint(sched: ptr Scheduler, sp: ScheduledPoint) {.inline.} =
|
||||
sched.queue[sched.numScheduled] = sp
|
||||
sched.collisionsMap.setBit(sp.bucket.int)
|
||||
sched.numScheduled += 1
|
||||
|
||||
func handleCollision(sched: var Scheduler, sp: ScheduledPoint)
|
||||
func rescheduleCollisions(sched: var Scheduler)
|
||||
func handleCollision(sched: ptr Scheduler, sp: ScheduledPoint)
|
||||
func rescheduleCollisions(sched: ptr Scheduler)
|
||||
func sparseVectorAddition[F, G](
|
||||
buckets: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
bucketStatuses: ptr UncheckedArray[set[BucketStatus]],
|
||||
@ -301,7 +302,7 @@ func sparseVectorAddition[F, G](
|
||||
scheduledPoints: ptr UncheckedArray[ScheduledPoint],
|
||||
numScheduled: int32) {.noInline, tags:[VarTime, Alloca].}
|
||||
|
||||
func prefetch*(sched: Scheduler, sp: ScheduledPoint) =
|
||||
func prefetch*(sched: ptr Scheduler, sp: ScheduledPoint) =
|
||||
let bucket = sp.bucket
|
||||
if bucket == -1:
|
||||
return
|
||||
@ -310,7 +311,7 @@ func prefetch*(sched: Scheduler, sp: ScheduledPoint) =
|
||||
prefetchLarge(sched.buckets.ptAff[bucket].addr, Write, HighTemporalLocality, maxCacheLines = 1)
|
||||
prefetchLarge(sched.buckets.ptJacExt[bucket].addr, Write, HighTemporalLocality, maxCacheLines = 1)
|
||||
|
||||
func schedule*(sched: var Scheduler, sp: ScheduledPoint) =
|
||||
func schedule*(sched: ptr Scheduler, sp: ScheduledPoint) =
|
||||
## Schedule a point for accumulating in buckets
|
||||
|
||||
let bucket = int sp.bucket
|
||||
@ -339,7 +340,7 @@ func schedule*(sched: var Scheduler, sp: ScheduledPoint) =
|
||||
sched.collisionsMap.setZero()
|
||||
sched.rescheduleCollisions()
|
||||
|
||||
func handleCollision(sched: var Scheduler, sp: ScheduledPoint) =
|
||||
func handleCollision(sched: ptr Scheduler, sp: ScheduledPoint) =
|
||||
if sched.numCollisions < sched.collisions.len:
|
||||
sched.collisions[sched.numCollisions] = sp
|
||||
sched.numCollisions += 1
|
||||
@ -358,7 +359,7 @@ func handleCollision(sched: var Scheduler, sp: ScheduledPoint) =
|
||||
else:
|
||||
sched.buckets.ptJacExt[sp.bucket] -= sched.points[sp.pointID]
|
||||
|
||||
func rescheduleCollisions(sched: var Scheduler) =
|
||||
func rescheduleCollisions(sched: ptr Scheduler) =
|
||||
template last: untyped = sched.numCollisions-1
|
||||
var i = last()
|
||||
while i >= 0:
|
||||
@ -370,7 +371,7 @@ func rescheduleCollisions(sched: var Scheduler) =
|
||||
sched.numCollisions -= 1
|
||||
i -= 1
|
||||
|
||||
func flushBuffer(sched: var Scheduler, buf: ptr UncheckedArray[ScheduledPoint], count: var int32) =
|
||||
func flushBuffer(sched: ptr Scheduler, buf: ptr UncheckedArray[ScheduledPoint], count: var int32) =
|
||||
for i in 0 ..< count:
|
||||
let sp = buf[i]
|
||||
if kJacExt in sched.buckets.status[sp.bucket]:
|
||||
@ -385,7 +386,7 @@ func flushBuffer(sched: var Scheduler, buf: ptr UncheckedArray[ScheduledPoint],
|
||||
sched.buckets.status[sp.bucket].incl(kJacExt)
|
||||
count = 0
|
||||
|
||||
func flushPendingAndReset*(sched: var Scheduler) =
|
||||
func flushPendingAndReset*(sched: ptr Scheduler) =
|
||||
if sched.numScheduled >= MinVectorAddThreshold:
|
||||
sparseVectorAddition(
|
||||
sched.buckets.ptAff.asUnchecked(), sched.buckets.status.asUnchecked(),
|
||||
|
||||
@ -236,7 +236,7 @@ func affineAdd*[F; G: static Subgroup](
|
||||
|
||||
func accum_half_vartime[F; G: static Subgroup](
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
|
||||
len: uint) {.noInline, tags:[VarTime, Alloca].} =
|
||||
len: int) {.noInline, tags:[VarTime, Alloca].} =
|
||||
## Affine accumulation of half the points into the other half
|
||||
## Warning ⚠️ : variable-time
|
||||
##
|
||||
@ -252,7 +252,7 @@ func accum_half_vartime[F; G: static Subgroup](
|
||||
|
||||
debug: doAssert len and 1 == 0, "There must be an even number of points"
|
||||
|
||||
let N = int(len div 2)
|
||||
let N = len shr 1
|
||||
let lambdas = allocStackArray(tuple[num, den: F], N)
|
||||
|
||||
# Step 1: Compute numerators and denominators of λᵢ = λᵢ_num / λᵢ_den
|
||||
@ -368,13 +368,19 @@ template `+=`[F; G: static Subgroup](P: var ECP_ShortW_JacExt[F, G], Q: ECP_Shor
|
||||
# we create a local `+=` for this module only
|
||||
madd_vartime(P, P, Q)
|
||||
|
||||
func accumSum_chunk_vartime[F; G: static Subgroup](
|
||||
func accumSum_chunk_vartime*[F; G: static Subgroup](
|
||||
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G] or ECP_ShortW_JacExt[F, G]),
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], len: uint) =
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], len: int) {.noInline, tags:[VarTime, Alloca].} =
|
||||
## Accumulate `points` into r.
|
||||
## `r` is NOT overwritten
|
||||
## r += ∑ points
|
||||
## `points` are destroyed
|
||||
##
|
||||
## `len` should be chosen so that `len` points
|
||||
## use cache efficiently
|
||||
|
||||
let accumulators = allocStackArray(ECP_ShortW_Aff[F, G], len)
|
||||
let size = len * sizeof(ECP_ShortW_Aff[F, G])
|
||||
copyMem(accumulators[0].addr, points[0].unsafeAddr, size)
|
||||
|
||||
const minNumPointsSerial = 16
|
||||
var n = len
|
||||
@ -392,12 +398,12 @@ func accumSum_chunk_vartime[F; G: static Subgroup](
|
||||
n = n div 2
|
||||
|
||||
# Tail
|
||||
for i in 0'u ..< n:
|
||||
for i in 0 ..< n:
|
||||
r += points[i]
|
||||
|
||||
func accum_batch_vartime[F; G: static Subgroup](
|
||||
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G] or ECP_ShortW_JacExt[F, G]),
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], pointsLen: int) {.noInline, tags:[VarTime, Alloca].} =
|
||||
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], pointsLen: int) =
|
||||
## Batch accumulation of `points` into `r`
|
||||
## `r` is accumulated into
|
||||
|
||||
@ -423,14 +429,9 @@ func accum_batch_vartime[F; G: static Subgroup](
|
||||
const maxTempMem = 262144 # 2¹⁸ = 262144
|
||||
const maxStride = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
|
||||
|
||||
let n = min(maxStride, pointsLen)
|
||||
let accumulators = allocStackArray(ECP_ShortW_Aff[F, G], n)
|
||||
|
||||
for i in countup(0, pointsLen-1, maxStride):
|
||||
let n = min(maxStride, pointsLen - i)
|
||||
let size = n * sizeof(ECP_ShortW_Aff[F, G])
|
||||
copyMem(accumulators[0].addr, points[i].unsafeAddr, size)
|
||||
r.accumSum_chunk_vartime(accumulators, uint n)
|
||||
r.accumSum_chunk_vartime(points +% i, n)
|
||||
|
||||
func sum_reduce_vartime*[F; G: static Subgroup](
|
||||
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G] or ECP_ShortW_JacExt[F, G]),
|
||||
@ -477,8 +478,7 @@ func consumeBuffer[EC, F; G: static Subgroup; AccumMax: static int](
|
||||
if ctx.cur == 0:
|
||||
return
|
||||
|
||||
let lambdas = allocStackArray(tuple[num, den: F], ctx.cur.int)
|
||||
ctx.accum.accumSum_chunk_vartime(ctx.buffer.asUnchecked(), lambdas, ctx.cur.uint)
|
||||
ctx.accum.accumSum_chunk_vartime(ctx.buffer.asUnchecked(), ctx.cur)
|
||||
ctx.cur = 0
|
||||
|
||||
func update*[EC, F, G; AccumMax: static int](
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
|
||||
import
|
||||
../../platforms/abstractions,
|
||||
../../platforms/threadpool/threadpool,
|
||||
../../platforms/threadpool/[threadpool, partitioners],
|
||||
./ec_shortweierstrass_affine,
|
||||
./ec_shortweierstrass_jacobian,
|
||||
./ec_shortweierstrass_projective,
|
||||
@ -24,48 +24,7 @@ import
|
||||
#
|
||||
# ############################################################
|
||||
|
||||
type ChunkDescriptor = object
|
||||
start, totalIters: int
|
||||
numChunks, baseChunkSize, cutoff: int
|
||||
|
||||
func computeBalancedChunks(start, stopEx, minChunkSize, maxChunkSize, targetNumChunks: int): ChunkDescriptor =
|
||||
## Balanced chunking algorithm for a range [start, stopEx)
|
||||
## This ideally splits a range into min(stopEx-start, targetNumChunks) balanced regions
|
||||
## unless the chunk size isn't in the range [minChunkSize, maxChunkSize]
|
||||
#
|
||||
# see constantine/platforms/threadpool/docs/partitioner.md
|
||||
let totalIters = stopEx - start
|
||||
var numChunks = max(targetNumChunks, 1)
|
||||
var baseChunkSize = totalIters div numChunks
|
||||
var cutoff = totalIters mod numChunks # Should be computed in a single instruction with baseChunkSize
|
||||
|
||||
if baseChunkSize < minChunkSize:
|
||||
numChunks = max(totalIters div minChunkSize, 1)
|
||||
baseChunkSize = totalIters div numChunks
|
||||
cutoff = totalIters mod numChunks
|
||||
elif baseChunkSize > maxChunkSize or (baseChunkSize == maxChunkSize and cutoff != 0):
|
||||
# After cutoff, we do baseChunkSize+1, and would run afoul of the maxChunkSize constraint (unless no remainder), hence ceildiv
|
||||
numChunks = totalIters.ceilDiv_vartime(maxChunkSize)
|
||||
baseChunkSize = totalIters div numChunks
|
||||
cutoff = totalIters mod numChunks
|
||||
|
||||
return ChunkDescriptor(
|
||||
start: start, totaliters: totalIters,
|
||||
numChunks: numChunks, baseChunkSize: baseChunkSize, cutoff: cutoff
|
||||
)
|
||||
|
||||
iterator items(c: ChunkDescriptor): tuple[chunkID, start, stopEx: int] =
|
||||
for chunkID in 0 ..< min(c.numChunks, c.totalIters):
|
||||
if chunkID < c.cutoff:
|
||||
let offset = c.start + ((c.baseChunkSize + 1) * chunkID)
|
||||
let chunkSize = c.baseChunkSize + 1
|
||||
yield (chunkID, offset, min(offset+chunkSize, c.totalIters))
|
||||
else:
|
||||
let offset = c.start + (c.baseChunkSize * chunkID) + c.cutoff
|
||||
let chunkSize = c.baseChunkSize
|
||||
yield (chunkID, offset, min(offset+chunkSize, c.totalIters))
|
||||
|
||||
proc sum_reduce_vartime_parallel*[F; G: static Subgroup](
|
||||
proc sum_reduce_vartime_parallelChunks[F; G: static Subgroup](
|
||||
tp: Threadpool,
|
||||
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
|
||||
points: openArray[ECP_ShortW_Aff[F, G]]) {.noInline.} =
|
||||
@ -73,49 +32,39 @@ proc sum_reduce_vartime_parallel*[F; G: static Subgroup](
|
||||
## `r` is overwritten
|
||||
## Compute is parallelized, if beneficial.
|
||||
## This function cannot be nested in another parallel function
|
||||
##
|
||||
## Side-effects due to thread-local threadpool variable accesses.
|
||||
|
||||
# TODO:
|
||||
# This function is needed in Multi-Scalar Multiplication (MSM)
|
||||
# The main bottleneck (~80% of the time) of zero-knowledge proof systems.
|
||||
# MSM is difficult to scale above 16 cores,
|
||||
# allowing nested parallelism will expose more parallelism opportunities.
|
||||
|
||||
# Chunking constants in ec_shortweierstrass_batch_ops.nim
|
||||
|
||||
const minNumPointsParallel = 1024 # For 256-bit curves that's 1024*(32+32) = 65536 temp mem
|
||||
const maxTempMem = 262144 # 2¹⁸ = 262144
|
||||
const maxNumPoints = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
|
||||
const maxChunkSize = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
|
||||
const minChunkSize = (maxChunkSize * 60) div 100 # We want 60%~100% full chunks
|
||||
|
||||
# 262144 / (2*1024) = 128 bytes allowed per coordinates. Largest curve BW6-761 requires 96 bytes per coordinate. And G2 is on Fp, not Fp2.
|
||||
static: doAssert minNumPointsParallel <= maxNumPoints, "The curve " & $r.typeof & " requires large size and needs to be tuned."
|
||||
|
||||
if points.len < minNumPointsParallel:
|
||||
r.sum_reduce_vartime(points)
|
||||
if points.len <= maxChunkSize:
|
||||
r.setInf()
|
||||
r.accumSum_chunk_vartime(points.asUnchecked(), points.len)
|
||||
return
|
||||
|
||||
let chunkDesc = computeBalancedChunks(
|
||||
let chunkDesc = balancedChunksPrioSize(
|
||||
start = 0, stopEx = points.len,
|
||||
minNumPointsParallel, maxNumPoints,
|
||||
targetNumChunks = tp.numThreads.int)
|
||||
minChunkSize, maxChunkSize,
|
||||
numChunksHint = tp.numThreads.int)
|
||||
|
||||
let partialResults = allocStackArray(r.typeof(), chunkDesc.numChunks)
|
||||
|
||||
for iter in items(chunkDesc):
|
||||
proc sum_reduce_vartime_wrapper(res: ptr, p: ptr, pLen: int) {.nimcall.} =
|
||||
proc sum_reduce_chunk_vartime_wrapper(res: ptr, p: ptr, pLen: int) {.nimcall.} =
|
||||
# The borrow checker prevents capturing `var` and `openArray`
|
||||
# so we capture pointers instead.
|
||||
res[].sum_reduce_vartime(p, pLen)
|
||||
res[].setInf()
|
||||
res[].accumSum_chunk_vartime(p, pLen)
|
||||
|
||||
tp.spawn partialResults[iter.chunkID].addr.sum_reduce_vartime_wrapper(
|
||||
tp.spawn partialResults[iter.chunkID].addr.sum_reduce_chunk_vartime_wrapper(
|
||||
points.asUnchecked() +% iter.start,
|
||||
iter.stopEx - iter.start)
|
||||
iter.size)
|
||||
|
||||
tp.syncAll() # TODO: this prevents nesting in another parallel region
|
||||
|
||||
const minNumPointsSerial = 16
|
||||
if chunkDesc.numChunks < minNumPointsSerial:
|
||||
const minChunkSizeSerial = 32
|
||||
if chunkDesc.numChunks < minChunkSizeSerial:
|
||||
r.setInf()
|
||||
for i in 0 ..< chunkDesc.numChunks:
|
||||
r += partialResults[i]
|
||||
@ -124,16 +73,48 @@ proc sum_reduce_vartime_parallel*[F; G: static Subgroup](
|
||||
partialResultsAffine.batchAffine(partialResults, chunkDesc.numChunks)
|
||||
r.sum_reduce_vartime(partialResultsAffine, chunkDesc.numChunks)
|
||||
|
||||
# Sanity checks
|
||||
# ---------------------------------------
|
||||
proc sum_reduce_vartime_parallelFor[F; G: static Subgroup](
|
||||
tp: Threadpool,
|
||||
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
|
||||
points: openArray[ECP_ShortW_Aff[F, G]]) =
|
||||
## Batch addition of `points` into `r`
|
||||
## `r` is overwritten
|
||||
## Compute is parallelized, if beneficial.
|
||||
|
||||
when isMainModule:
|
||||
block:
|
||||
let chunkDesc = computeBalancedChunks(start = 0, stopEx = 40, minChunkSize = 16, maxChunkSize = 128, targetNumChunks = 12)
|
||||
for chunk in chunkDesc:
|
||||
echo chunk
|
||||
mixin globalSum
|
||||
|
||||
block:
|
||||
let chunkDesc = computeBalancedChunks(start = 0, stopEx = 10000, minChunkSize = 16, maxChunkSize = 128, targetNumChunks = 12)
|
||||
for chunk in chunkDesc:
|
||||
echo chunk
|
||||
const maxTempMem = 262144 # 2¹⁸ = 262144
|
||||
const maxStride = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
|
||||
|
||||
let p = points.asUnchecked
|
||||
let pointsLen = points.len
|
||||
|
||||
tp.parallelFor i in 0 ..< points.len:
|
||||
stride: maxStride
|
||||
captures: {p, pointsLen}
|
||||
reduceInto(globalSum: typeof(r)):
|
||||
prologue:
|
||||
var localSum {.noInit.}: typeof(r)
|
||||
localSum.setInf()
|
||||
forLoop:
|
||||
let n = min(maxStride, pointsLen-i)
|
||||
localSum.accumSum_chunk_vartime(p +% i, n)
|
||||
merge(remoteSum: Flowvar[typeof(r)]):
|
||||
localSum += sync(remoteSum)
|
||||
epilogue:
|
||||
return localSum
|
||||
|
||||
r = sync(globalSum)
|
||||
|
||||
proc sum_reduce_vartime_parallel*[F; G: static Subgroup](
|
||||
tp: Threadpool,
|
||||
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
|
||||
points: openArray[ECP_ShortW_Aff[F, G]]) {.inline.} =
|
||||
## Batch addition of `points` into `r`
|
||||
## `r` is overwritten
|
||||
## Compute is parallelized, if beneficial.
|
||||
## This function cannot be nested in another parallel function
|
||||
when false:
|
||||
tp.sum_reduce_vartime_parallelFor(r, points)
|
||||
else:
|
||||
tp.sum_reduce_vartime_parallelChunks(r, points)
|
||||
|
||||
@ -80,7 +80,7 @@ template allocStackUnchecked*(T: typedesc, size: int): ptr T =
|
||||
cast[ptr T](alloca(size))
|
||||
|
||||
template allocStackArray*(T: typedesc, len: SomeInteger): ptr UncheckedArray[T] =
|
||||
cast[ptr UncheckedArray[T]](alloca(sizeof(T) * len))
|
||||
cast[ptr UncheckedArray[T]](alloca(sizeof(T) * cast[int](len)))
|
||||
|
||||
# Heap allocation
|
||||
# ----------------------------------------------------------------------------------
|
||||
@ -93,14 +93,14 @@ proc allocHeapUnchecked*(T: typedesc, size: int): ptr T {.inline.} =
|
||||
cast[type result](malloc(size))
|
||||
|
||||
proc allocHeapArray*(T: typedesc, len: SomeInteger): ptr UncheckedArray[T] {.inline.} =
|
||||
cast[type result](malloc(sizeof(T) * len))
|
||||
cast[type result](malloc(sizeof(T) * cast[int](len)))
|
||||
|
||||
proc freeHeap*(p: pointer) {.inline.} =
|
||||
free(p)
|
||||
|
||||
proc allocHeapAligned*(T: typedesc, alignment: static Natural): ptr T {.inline.} =
|
||||
# aligned_alloc requires allocating in multiple of the alignment.
|
||||
const
|
||||
let # Cannot be static with bitfields. Workaround https://github.com/nim-lang/Nim/issues/19040
|
||||
size = sizeof(T)
|
||||
requiredMem = size.roundNextMultipleOf(alignment)
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ proc rebuildUntypedAst*(ast: NimNode, dropRootStmtList = false): NimNode =
|
||||
return rTree
|
||||
else:
|
||||
return defaultMultipleChildren(node)
|
||||
of nnkClosedSymChoice:
|
||||
of {nnkOpenSymChoice, nnkClosedSymChoice}:
|
||||
return rebuild(node[0])
|
||||
else:
|
||||
return defaultMultipleChildren(node)
|
||||
|
||||
@ -143,6 +143,8 @@ template runBench(tp: Threadpool, procName: untyped, matrix: Matrix, bins: int,
|
||||
else:
|
||||
max = procName(matrix, hist)
|
||||
|
||||
wv_free(hist.buffer)
|
||||
|
||||
# Algo
|
||||
# -------------------------------------------------------
|
||||
|
||||
|
||||
@ -57,8 +57,8 @@ const # bitfield setup
|
||||
# - Linux CPUSET supports up to 1024 threads (https://man7.org/linux/man-pages/man3/CPU_SET.3.html)
|
||||
#
|
||||
# Hardware limitations:
|
||||
# - Xeon Platinum 9282, 56 cores - 112 threads per socket
|
||||
# - up to 8 sockets: 896 threads
|
||||
# - Xeon Platinum 8490H, 60C/120T per socket
|
||||
# - up to 8 sockets: 960 threads
|
||||
|
||||
kPreWaitShift = 8'u32
|
||||
kPreWait = 1'u32 shl kPreWaitShift
|
||||
|
||||
@ -25,6 +25,8 @@ import
|
||||
# Flowvars are also called future interchangeably.
|
||||
# (The name future is already used for IO scheduling)
|
||||
|
||||
const NotALoop* = -1
|
||||
|
||||
type
|
||||
TaskState = object
|
||||
## This state allows synchronization between:
|
||||
@ -66,6 +68,8 @@ type
|
||||
env*{.align:sizeof(int).}: UncheckedArray[byte]
|
||||
|
||||
Flowvar*[T] = object
|
||||
# Flowvar is a public object, but we don't want
|
||||
# end-user to access the underlying task, so keep the field private.
|
||||
task: ptr Task
|
||||
|
||||
ReductionDagNode* = object
|
||||
@ -177,6 +181,9 @@ proc newSpawn*(
|
||||
result.hasFuture = false
|
||||
result.fn = fn
|
||||
|
||||
when defined(TP_Metrics):
|
||||
result.loopStepsLeft = NotALoop
|
||||
|
||||
proc newSpawn*(
|
||||
T: typedesc[Task],
|
||||
parent: ptr Task,
|
||||
@ -193,9 +200,10 @@ proc newSpawn*(
|
||||
result.fn = fn
|
||||
cast[ptr[type env]](result.env)[] = env
|
||||
|
||||
func ceilDiv_vartime*(a, b: auto): auto {.inline.} =
|
||||
## ceil division, to be used only on length or at compile-time
|
||||
## ceil(a / b)
|
||||
when defined(TP_Metrics):
|
||||
result.loopStepsLeft = NotALoop
|
||||
|
||||
func ceilDiv_vartime(a, b: auto): auto {.inline.} =
|
||||
(a + b - 1) div b
|
||||
|
||||
proc newLoop*(
|
||||
@ -287,16 +295,15 @@ func isReady*[T](fv: Flowvar[T]): bool {.inline.} =
|
||||
func readyWith*[T](task: ptr Task, childResult: T) {.inline.} =
|
||||
## Send the Flowvar result from the child thread processing the task
|
||||
## to its parent thread.
|
||||
precondition: not task.isCompleted()
|
||||
cast[ptr (ptr Task, T)](task.env.addr)[1] = childResult
|
||||
|
||||
proc sync*[T](fv: sink Flowvar[T]): T {.noInit, inline, gcsafe.} =
|
||||
## Blocks the current thread until the flowvar is available
|
||||
## and returned.
|
||||
## The thread is not idle and will complete pending tasks.
|
||||
mixin completeFuture
|
||||
completeFuture(fv, result)
|
||||
cleanup(fv)
|
||||
func copyResult*[T](dst: var T, fv: FlowVar[T]) {.inline.} =
|
||||
## Copy the result of a ready Flowvar to `dst`
|
||||
dst = cast[ptr (ptr Task, T)](fv.task.env.addr)[1]
|
||||
|
||||
func getTask*[T](fv: FlowVar[T]): ptr Task {.inline.} =
|
||||
## Copy the result of a ready Flowvar to `dst`
|
||||
fv.task
|
||||
|
||||
# ReductionDagNodes
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@ -24,7 +24,7 @@ block:
|
||||
|
||||
let sum1M = tp.sumReduce(1000000)
|
||||
echo "Sum reduce(0..1000000): ", sum1M
|
||||
doAssert sum1M == 500_000_500_000'i64
|
||||
doAssert sum1M == 500_000_500_000'i64, "incorrect sum was " & $sum1M
|
||||
|
||||
tp.shutdown()
|
||||
|
||||
|
||||
@ -135,12 +135,136 @@ template ascertain*(check: untyped) =
|
||||
## Optional runtime check in the middle of processing
|
||||
assertContract("transient condition", check)
|
||||
|
||||
# Metrics
|
||||
# ----------------------------------------------------------------------------------
|
||||
|
||||
macro defCountersType*(name: untyped, countersDesc: static seq[tuple[field, desc: string]]): untyped =
|
||||
var records = nnkRecList.newTree()
|
||||
|
||||
for (field, _) in countersDesc:
|
||||
records.add newIdentDefs(ident(field), ident"int64")
|
||||
|
||||
result = nnkTypeSection.newTree(
|
||||
nnkTypeDef.newTree(
|
||||
name,
|
||||
newEmptyNode(),
|
||||
nnkObjectTy.newTree(
|
||||
newEmptyNode(),
|
||||
newEmptyNode(),
|
||||
records
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
macro getCounter*(counters: untyped, counterField: static string): untyped =
|
||||
return nnkDotExpr.newTree(counters, ident(counterField))
|
||||
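# For illustration, with a two-entry description list
#   defCountersType(Counters, @[("tasksScheduled", "tasks scheduled"),
#                               ("tasksStolen",    "tasks stolen")])
# expands to roughly
#   type Counters = object
#     tasksScheduled: int64
#     tasksStolen: int64
# while getCounter(ctx.counters, "tasksScheduled") rewrites to ctx.counters.tasksScheduled,
# letting the metrics-printing code later in this PR index counters by their compile-time name.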
|
||||
# Profiling
|
||||
# ----------------------------------------------------------------------------------
|
||||
|
||||
when defined(TP_Profile):
|
||||
import ./primitives/timers
|
||||
# On windows and Mac, timers.nim uses globals which we want to avoid where possible
|
||||
|
||||
var ProfilerRegistry {.compileTime.}: seq[string]
|
||||
|
||||
template checkName(name: untyped) {.used.} =
|
||||
static:
|
||||
if astToStr(name) notin ProfilerRegistry:
|
||||
raise newException(
|
||||
ValueError,
|
||||
"Invalid profile name: \"" & astToStr(name) & "\"\n" &
|
||||
"Only " & $ProfilerRegistry & " are valid")
|
||||
|
||||
# With untyped dirty templates we need to bind the symbol early
|
||||
# otherwise they are resolved too late, in a scope where they don't exist.
|
||||
# Alternatively we export ./timer.nim.
|
||||
|
||||
template profileDecl*(name: untyped): untyped {.dirty.} =
|
||||
bind ProfilerRegistry, Timer
|
||||
static: ProfilerRegistry.add astToStr(name)
|
||||
var `timer _ name`{.inject, threadvar.}: Timer
|
||||
|
||||
template profileInit*(name: untyped) {.dirty.} =
|
||||
bind checkName, reset
|
||||
checkName(name)
|
||||
reset(`timer _ name`)
|
||||
|
||||
macro profileStart*(name: untyped): untyped =
|
||||
newCall(bindSym"start", ident("timer_" & $name))
|
||||
|
||||
macro profileStop*(name: untyped): untyped =
|
||||
newCall(bindSym"stop", ident("timer_" & $name))
|
||||
|
||||
template profile*(name, body: untyped): untyped =
|
||||
profile_start(name)
|
||||
body
|
||||
profile_stop(name)
|
||||
|
||||
macro printWorkerProfiling*(workerID: SomeInteger): untyped =
|
||||
|
||||
let timerUnit = bindSym"kMilliseconds"
|
||||
|
||||
result = newStmtList()
|
||||
let strUnit = ident"strUnit"
|
||||
result.add newConstStmt(strUnit, newCall(bindSym"$", timerUnit))
|
||||
|
||||
var formatString = "Worker %3d: timerId %2d, %10.3lf, %s, %s\n"
|
||||
|
||||
var cumulated = newCall(bindSym"getElapsedCumulatedTime")
|
||||
for i in 0 ..< ProfilerRegistry.len:
|
||||
var fnCall = newCall(bindSym"c_printf", newLit(formatString), workerID, newLit(i))
|
||||
let timer = ident("timer_" & ProfilerRegistry[i])
|
||||
fnCall.add newCall(bindSym"getElapsedTime", timer, timerUnit)
|
||||
fnCall.add strUnit
|
||||
fnCall.add newLit(ProfilerRegistry[i])
|
||||
|
||||
cumulated.add timer
|
||||
result.add fnCall
|
||||
|
||||
cumulated.add timerUnit
|
||||
result.add newCall(
|
||||
bindSym"c_printf",
|
||||
newLit(formatString),
|
||||
workerID,
|
||||
newLit(ProfilerRegistry.len),
|
||||
cumulated,
|
||||
strUnit,
|
||||
newLit"cumulated_time")
|
||||
|
||||
result.add newCall(bindSym"flushFile", bindSym"stdout")
|
||||
|
||||
else:
|
||||
template profileDecl*(name: untyped): untyped = discard
|
||||
template profileInit*(name: untyped) = discard
|
||||
template profileStart*(name: untyped) = discard
|
||||
template profileStop*(name: untyped) = discard
|
||||
template profile*(name, body: untyped): untyped =
|
||||
body
|
||||
template printWorkerProfiling*(workerID: untyped): untyped = discard
|
||||
|
||||
# Sanity checks
|
||||
# ----------------------------------------------------------------------------------
|
||||
|
||||
when isMainModule:
|
||||
proc assertGreater(x, y: int) =
|
||||
postcondition(x > y)
|
||||
|
||||
# We should get a nicely formatted exception
|
||||
assertGreater(10, 12)
|
||||
block:
|
||||
proc assertGreater(x, y: int) =
|
||||
postcondition(x > y)
|
||||
|
||||
# We should get a nicely formatted exception
|
||||
# assertGreater(10, 12)
|
||||
|
||||
block:
|
||||
let ID = 0
|
||||
|
||||
profileDecl(run_task)
|
||||
profileDecl(idle)
|
||||
|
||||
profileInit(run_task)
|
||||
profileInit(idle)
|
||||
|
||||
profile(run_task):
|
||||
discard
|
||||
|
||||
printWorkerProfiling(ID)
|
||||
|
||||
@ -35,8 +35,27 @@ import
|
||||
# #
|
||||
# ############################################################
|
||||
|
||||
proc needTempStorage(argTy: NimNode): bool =
|
||||
case argTy.kind
|
||||
of nnkVarTy:
|
||||
error("It is unsafe to capture a `var` parameter and pass it to another thread. Its memory location could be invalidated if the spawning proc returns before the worker thread finishes.")
|
||||
of nnkStaticTy:
|
||||
return false
|
||||
of nnkBracketExpr:
|
||||
if argTy[0].typeKind == ntyTypeDesc:
|
||||
return false
|
||||
else:
|
||||
return true
|
||||
of nnkCharLit..nnkNilLit:
|
||||
return false
|
||||
else:
|
||||
return true
|
||||
|
||||
proc spawnVoid(funcCall: NimNode, args, argsTy: NimNode, workerContext, schedule: NimNode): NimNode =
|
||||
# Create the async function
|
||||
## Spawn a function that can be scheduled on another thread
|
||||
## without return value.
|
||||
result = newStmtList()
|
||||
|
||||
let fn = funcCall[0]
|
||||
let fnName = $fn
|
||||
let withArgs = args.len > 0
|
||||
@ -44,29 +63,31 @@ proc spawnVoid(funcCall: NimNode, args, argsTy: NimNode, workerContext, schedule
|
||||
var fnCall = newCall(fn)
|
||||
let env = ident("ctt_tpSpawnVoidEnv_") # typed pointer to env
|
||||
|
||||
# Schedule
|
||||
let task = ident"ctt_tpSpawnVoidTask_"
|
||||
let scheduleBlock = newCall(schedule, workerContext, task)
|
||||
# Closure unpacker
|
||||
var envParams = nnkTupleConstr.newTree()
|
||||
var envParamsTy = nnkTupleConstr.newTree()
|
||||
var envOffset = 0
|
||||
|
||||
result = newStmtList()
|
||||
|
||||
if funcCall.len == 2:
|
||||
# With only 1 arg, the tuple syntax doesn't construct a tuple
|
||||
# let env = (123) # is an int
|
||||
fnCall.add nnkDerefExpr.newTree(env)
|
||||
else: # This handles the 0 arg case as well
|
||||
for i in 1 ..< funcCall.len:
|
||||
fnCall.add nnkBracketExpr.newTree(
|
||||
env,
|
||||
newLit i-1)
|
||||
for i in 0 ..< args.len:
|
||||
if argsTy[i].needTempStorage():
|
||||
envParamsTy.add argsTy[i]
|
||||
envParams.add args[i]
|
||||
fnCall.add nnkBracketExpr.newTree(env, newLit envOffset)
|
||||
envOffset += 1
|
||||
else:
|
||||
fnCall.add args[i]
|
||||
|
||||
# Create the async call
|
||||
result.add quote do:
|
||||
proc `tpSpawn_closure`(env: pointer) {.nimcall, gcsafe, raises: [].} =
|
||||
when bool(`withArgs`):
|
||||
let `env` = cast[ptr `argsTy`](env)
|
||||
let `env` = cast[ptr `envParamsTy`](env)
|
||||
`fnCall`
|
||||
|
||||
# Schedule
|
||||
let task = ident"ctt_tpSpawnVoidTask_"
|
||||
let scheduleBlock = newCall(schedule, workerContext, task)
|
||||
|
||||
# Create the task
|
||||
result.add quote do:
|
||||
block enq_deq_task:
|
||||
@ -74,15 +95,82 @@ proc spawnVoid(funcCall: NimNode, args, argsTy: NimNode, workerContext, schedule
|
||||
let `task` = Task.newSpawn(
|
||||
parent = `workerContext`.currentTask,
|
||||
fn = `tpSpawn_closure`,
|
||||
env = `args`)
|
||||
env = `envParams`)
|
||||
else:
|
||||
let `task` = Task.newSpawn(
|
||||
parent = `workerContext`.currentTask,
|
||||
fn = `tpSpawn_closure`)
|
||||
`scheduleBlock`
|
||||
|
||||
proc spawnVoidAwaitable(funcCall: NimNode, args, argsTy: NimNode, workerContext, schedule: NimNode): NimNode =
|
||||
## Spawn a function that can be scheduled on another thread
|
||||
## with a dummy awaitable return value
|
||||
result = newStmtList()
|
||||
|
||||
let fn = funcCall[0]
|
||||
let fnName = $fn
|
||||
let tpSpawn_closure = ident("ctt_tpSpawnVoidAwaitableClosure_" & fnName)
|
||||
var fnCall = newCall(fn)
|
||||
let env = ident("ctt_tpSpawnVoidAwaitableEnv_") # typed pointer to env
|
||||
|
||||
# tasks have no return value.
|
||||
# 1. The start of the task `env` buffer will store the return value for the flowvar and awaiter/sync
|
||||
# 2. We create a wrapper tpSpawn_closure without return value that send the return value in the channel
|
||||
# 3. We package that wrapper function in a task
|
||||
|
||||
# We store the following in task.env:
|
||||
#
|
||||
# | ptr Task | result | arg₀ | arg₁ | ... | argₙ
|
||||
let fut = ident"ctt_tpSpawnVoidAwaitableFut_"
|
||||
let taskSelfReference = ident"ctt_taskSelfReference"
|
||||
|
||||
# Closure unpacker
|
||||
# env stores | ptr Task | result | arg₀ | arg₁ | ... | argₙ
|
||||
# so arguments starts at env[2] in the wrapping funcCall functions
|
||||
var envParams = nnkTupleConstr.newTree()
|
||||
var envParamsTy = nnkTupleConstr.newTree()
|
||||
envParams.add taskSelfReference
|
||||
envParamsTy.add nnkPtrTy.newTree(bindSym"Task")
|
||||
envParams.add newLit(false)
|
||||
envParamsTy.add getType(bool)
|
||||
var envOffset = 2
|
||||
|
||||
for i in 0 ..< args.len:
|
||||
if argsTy[i].needTempStorage():
|
||||
envParamsTy.add argsTy[i]
|
||||
envParams.add args[i]
|
||||
fnCall.add nnkBracketExpr.newTree(env, newLit envOffset)
|
||||
envOffset += 1
|
||||
else:
|
||||
fnCall.add args[i]
|
||||
|
||||
result.add quote do:
|
||||
proc `tpSpawn_closure`(env: pointer) {.nimcall, gcsafe, raises: [].} =
|
||||
let `env` = cast[ptr `envParamsTy`](env)
|
||||
`fnCall`
|
||||
readyWith(`env`[0], true)
|
||||
|
||||
# Schedule
|
||||
let task = ident"ctt_tpSpawnVoidAwaitableTask_"
|
||||
let scheduleBlock = newCall(schedule, workerContext, task)
|
||||
|
||||
# Create the task
|
||||
result.add quote do:
|
||||
block enq_deq_task:
|
||||
let `taskSelfReference` = cast[ptr Task](0xDEADBEEF)
|
||||
|
||||
let `task` = Task.newSpawn(
|
||||
parent = `workerContext`.currentTask,
|
||||
fn = `tpSpawn_closure`,
|
||||
env = `envParams`)
|
||||
let `fut` = newFlowVar(bool, `task`)
|
||||
`scheduleBlock`
|
||||
# Return the future
|
||||
`fut`
|
||||
|
||||
proc spawnRet(funcCall: NimNode, retTy, args, argsTy: NimNode, workerContext, schedule: NimNode): NimNode =
|
||||
# Create the async function
|
||||
## Spawn a function that can be scheduled on another thread
|
||||
## with an awaitable future return value.
|
||||
result = newStmtList()
|
||||
|
||||
let fn = funcCall[0]
|
||||
@ -103,21 +191,25 @@ proc spawnRet(funcCall: NimNode, retTy, args, argsTy: NimNode, workerContext, sc
|
||||
let taskSelfReference = ident"ctt_taskSelfReference"
|
||||
let retVal = ident"ctt_retVal"
|
||||
|
||||
var envParams = nnkPar.newTree
|
||||
var envParamsTy = nnkPar.newTree
|
||||
# Closure unpacker
|
||||
# env stores | ptr Task | result | arg₀ | arg₁ | ... | argₙ
|
||||
# so arguments starts at env[2] in the wrapping funcCall functions
|
||||
var envParams = nnkTupleConstr.newTree()
|
||||
var envParamsTy = nnkTupleConstr.newTree()
|
||||
envParams.add taskSelfReference
|
||||
envParamsTy.add nnkPtrTy.newTree(bindSym"Task")
|
||||
envParams.add retVal
|
||||
envParamsTy.add retTy
|
||||
var envOffset = 2
|
||||
|
||||
for i in 1 ..< funcCall.len:
|
||||
envParamsTy.add getTypeInst(funcCall[i])
|
||||
envParams.add funcCall[i]
|
||||
|
||||
# env stores | ptr Task | result | arg₀ | arg₁ | ... | argₙ
|
||||
# so arguments starts at env[2] in the wrapping funcCall functions
|
||||
for i in 1 ..< funcCall.len:
|
||||
fnCall.add nnkBracketExpr.newTree(env, newLit i+1)
|
||||
for i in 0 ..< args.len:
|
||||
if argsTy[i].needTempStorage():
|
||||
envParamsTy.add argsTy[i]
|
||||
envParams.add args[i]
|
||||
fnCall.add nnkBracketExpr.newTree(env, newLit envOffset)
|
||||
envOffset += 1
|
||||
else:
|
||||
fnCall.add args[i]
|
||||
|
||||
result.add quote do:
|
||||
proc `tpSpawn_closure`(env: pointer) {.nimcall, gcsafe, raises: [].} =
|
||||
@ -125,8 +217,8 @@ proc spawnRet(funcCall: NimNode, retTy, args, argsTy: NimNode, workerContext, sc
|
||||
let res = `fnCall`
|
||||
readyWith(`env`[0], res)
|
||||
|
||||
# Regenerate fresh ident, retTy has been tagged as a function call param
|
||||
let retTy = ident($retTy)
|
||||
# Schedule
|
||||
let retTy = ident($retTy) # Regenerate fresh ident, retTy has been tagged as a function call param
|
||||
let task = ident"ctt_tpSpawnRetTask_"
|
||||
let scheduleBlock = newCall(schedule, workerContext, task)
|
||||
|
||||
@ -154,8 +246,8 @@ proc spawnImpl*(tp: NimNode{nkSym}, funcCall: NimNode, workerContext, schedule:
|
||||
|
||||
# Get a serialized type and data for all function arguments
|
||||
# We use adhoc tuple
|
||||
var argsTy = nnkPar.newTree()
|
||||
var args = nnkPar.newTree()
|
||||
var argsTy = nnkTupleConstr.newTree()
|
||||
var args = nnkTupleConstr.newTree()
|
||||
for i in 1 ..< funcCall.len:
|
||||
argsTy.add getTypeInst(funcCall[i])
|
||||
args.add funcCall[i]
|
||||
@ -169,6 +261,29 @@ proc spawnImpl*(tp: NimNode{nkSym}, funcCall: NimNode, workerContext, schedule:
|
||||
# Wrap in a block for namespacing
|
||||
result = nnkBlockStmt.newTree(newEmptyNode(), result)
|
||||
|
||||
proc spawnAwaitableImpl*(tp: NimNode{nkSym}, funcCall: NimNode, workerContext, schedule: NimNode): NimNode =
|
||||
funcCall.expectKind(nnkCall)
|
||||
|
||||
# Get the return type if any
|
||||
let retTy = funcCall[0].getImpl[3][0]
|
||||
let needFuture = retTy.kind != nnkEmpty
|
||||
if needFuture:
|
||||
error "spawnAwaitable can only be used with procedures without returned values"
|
||||
|
||||
# Get a serialized type and data for all function arguments
|
||||
# We use adhoc tuple
|
||||
var argsTy = nnkTupleConstr.newTree()
|
||||
var args = nnkTupleConstr.newTree()
|
||||
for i in 1 ..< funcCall.len:
|
||||
argsTy.add getTypeInst(funcCall[i])
|
||||
args.add funcCall[i]
|
||||
|
||||
# Package in a task
|
||||
result = spawnVoidAwaitable(funcCall, args, argsTy, workerContext, schedule)
|
||||
|
||||
# Wrap in a block for namespacing
|
||||
result = nnkBlockStmt.newTree(newEmptyNode(), result)
|
||||
|
||||
# ############################################################
|
||||
# #
|
||||
# Data parallelism #
|
||||
|
||||
116
constantine/platforms/threadpool/partitioners.nim
Normal file
@ -0,0 +1,116 @@
|
||||
# Constantine
|
||||
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
# ########################################################### #
|
||||
# #
|
||||
# Static Partitioning algorithms #
|
||||
# #
|
||||
# ########################################################### #
|
||||
|
||||
# This file implements static/eager partitioning algorithms.
|
||||
#
|
||||
# see docs/partitioners.md
|
||||
#
|
||||
# Note:
|
||||
# Those algorithms cannot take into account:
|
||||
# - Other workloads on the computer
|
||||
# - Heterogenous cores (for example Big.Little on ARM or Performance/Efficiency on x86)
|
||||
# - Load imbalance (i.e. raytracing a wall vs raytracing a mirror)
|
||||
# - CPU performance (see https://github.com/zy97140/omp-benchmark-for-pytorch)
|
||||
|
||||
type ChunkDescriptor* = object
  start, numSteps: int
  numChunks*, baseChunkSize, cutoff: int

iterator items*(c: ChunkDescriptor): tuple[chunkID, start, size: int] =
  for chunkID in 0 ..< c.numChunks:
    if chunkID < c.cutoff:
      let offset = c.start + ((c.baseChunkSize + 1) * chunkID)
      let chunkSize = c.baseChunkSize + 1
      yield (chunkID, offset, min(chunkSize, c.numSteps-offset))
    else:
      let offset = c.start + (c.baseChunkSize * chunkID) + c.cutoff
      let chunkSize = c.baseChunkSize
      yield (chunkID, offset, min(chunkSize, c.numSteps-offset))

func ceilDiv_vartime(a, b: auto): auto {.inline.} =
  (a + b - 1) div b

func balancedChunksPrioNumber*(start, stopEx, numChunks: int): ChunkDescriptor {.inline.} =
  ## Balanced chunking algorithm for a range [start, stopEx)
  ## This splits a range into min(stopEx-start, numChunks) balanced regions
  # Rationale
  # The following simple chunking scheme can lead to severe load imbalance
  #
  #   let chunk_offset = chunk_size * thread_id
  #   let chunk_size   = if thread_id < nb_chunks - 1: chunk_size
  #                      else: omp_size - chunk_offset
  #
  # For example dividing 40 items on 12 threads will lead to
  # a base_chunk_size of 40/12 = 3 so work on the first 11 threads
  # will be 3 * 11 = 33, and the remainder 7 on the last thread.
  #
  # Instead of dividing 40 work items on 12 cores into:
  #   3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 7 = 3*11 + 7 = 40
  # the following scheme will divide into
  #   4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3 = 4*4 + 3*8 = 40
  #
  # This is compliant with OpenMP spec (page 60)
  # http://www.openmp.org/mp-documents/openmp-4.5.pdf
  # "When no chunk_size is specified, the iteration space is divided into chunks
  # that are approximately equal in size, and at most one chunk is distributed to
  # each thread. The size of the chunks is unspecified in this case."
  # ---> chunks are the same ±1

  let
    numSteps = stopEx - start
    baseChunkSize = numSteps div numChunks
    cutoff = numSteps mod numChunks

  return ChunkDescriptor(
    start: start, numSteps: numSteps,
    numChunks: numChunks, baseChunkSize: baseChunkSize, cutoff: cutoff)
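# Quick check of the rationale above (illustrative): splitting 40 iterations
# into 12 chunks yields 4 chunks of 4 iterations followed by 8 chunks of 3.
when isMainModule:
  for chunk in balancedChunksPrioNumber(start = 0, stopEx = 40, numChunks = 12):
    echo chunk   # (chunkID: 0, start: 0, size: 4) ... (chunkID: 11, start: 37, size: 3)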
func balancedChunksPrioSize*(start, stopEx, minChunkSize, maxChunkSize, numChunksHint: int): ChunkDescriptor =
  ## Balanced chunking algorithm for a range [start, stopEx)
  ## This ideally splits a range into min(stopEx-start, numChunksHint) balanced regions
  ## unless the chunk size isn't in the range [minChunkSize, maxChunkSize]
  #
  # so many division/modulo. Can we do better?
  let numSteps = stopEx - start
  var numChunks = max(numChunksHint, 1)
  var baseChunkSize = numSteps div numChunks
  var cutoff = numSteps mod numChunks # Should be computed in a single instruction with baseChunkSize

  if baseChunkSize < minChunkSize:
    numChunks = max(numSteps div minChunkSize, 1)
    baseChunkSize = numSteps div numChunks
    cutoff = numSteps mod numChunks
  elif baseChunkSize > maxChunkSize or (baseChunkSize == maxChunkSize and cutoff != 0):
    # After cutoff, we do baseChunkSize+1, and would run afoul of the maxChunkSize constraint (unless no remainder), hence ceildiv
    numChunks = numSteps.ceilDiv_vartime(maxChunkSize)
    baseChunkSize = numSteps div numChunks
    cutoff = numSteps mod numChunks

  return ChunkDescriptor(
    start: start, numSteps: numSteps,
    numChunks: numChunks, baseChunkSize: baseChunkSize, cutoff: cutoff)

# Sanity checks
# ---------------------------------------

when isMainModule:
  block:
    let chunkDesc = balancedChunksPrioSize(start = 0, stopEx = 40, minChunkSize = 16, maxChunkSize = 128, numChunksHint = 12)
    for chunk in chunkDesc:
      echo chunk

  block:
    let chunkDesc = balancedChunksPrioSize(start = 0, stopEx = 10000, minChunkSize = 16, maxChunkSize = 128, numChunksHint = 12)
    for chunk in chunkDesc:
      echo chunk
148
constantine/platforms/threadpool/primitives/timers.nim
Normal file
@ -0,0 +1,148 @@
|
||||
# Weave
|
||||
# Copyright (c) 2019 Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
# Timers
|
||||
# ----------------------------------------------------------------------------------
|
||||
#
|
||||
# While for benchmarking we can enclose our microbenchmark target between clock utilities,
# when timing the inner parts of complex code the issue becomes the overhead introduced,
# in particular syscall overhead, since the kernel maintains the system time.
# As shown in https://gms.tf/on-the-costs-of-syscalls.html
|
||||
# something like clock_gettime_mono_raw can take from 20ns to 760ns.
|
||||
# Furthermore real syscalls will pollute the cache, which isn't a problem
|
||||
# when benchmarking steal code (since there is no work) but is when benchmarking loop splitting a tight loop.
|
||||
#
|
||||
# Ideally we would use the RDTSC instruction, it takes 5.5 cycles
|
||||
# https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=rdtsc&expand=5578&ig_expand=5803
|
||||
# https://www.intel.com/content/dam/www/public/us/en/documents/white-papers/ia-32-ia-64-benchmark-code-execution-paper.pdf
|
||||
# and has a latency of 42 cycles (i.e. 2 RDTSCs back-to-back will take 42 cycles, preventing accurate measurement of something that costs less).
|
||||
# But converting timestamp counter (TSC) to time is very tricky business, and would require a lot of code for accuracy
|
||||
# - https://stackoverflow.com/a/42190816
|
||||
# - https://github.com/torvalds/linux/blob/master/arch/x86/kernel/tsc.c
|
||||
# - https://github.com/torvalds/linux/blob/master/tools/power/x86/turbostat/turbostat.c
|
||||
# especially when even within the same CPU family you have quirks:
|
||||
# - https://lore.kernel.org/lkml/ff6dcea166e8ff8f2f6a03c17beab2cb436aa779.1513920414.git.len.brown@intel.com/
|
||||
# "while SKX servers use a 25 MHz crystal, SKX workstations (with same model #) use a 24 MHz crystal.
|
||||
# This results in a -4.0% time drift rate on SKX workstations."
|
||||
# "While SKX servers do have a 25 MHz crystal, but they too have a problem.
|
||||
# All SKX subject the crystal to an EMI reduction circuit that
|
||||
# reduces its actual frequency by (approximately) -0.25%.
|
||||
# This results in -1 second per 10 minute time drift
|
||||
# as compared to network time."
|
||||
#
|
||||
# So we either use the monotonic clock, hoping it uses the vDSO instead of a full syscall so that overhead is just ~20ns (60 cycles on a 3GHz CPU),
# or we use RDTSC, in which case the TSC frequency can be obtained by installing the `turbostat` package.
#
# We choose the monotonic clock for portability and to not deal
# with TSC quirks on x86 (and possibly other architectures). Due to this, timers aren't meaningful
# on scheduler-overhead workloads like fibonacci or DFS, as we would mostly be measuring the clock.
|
||||
type Ticks = distinct int64
|
||||
|
||||
when defined(linux):
|
||||
# https://github.com/torvalds/linux/blob/v6.2/include/uapi/linux/time.h
|
||||
type
|
||||
Timespec {.pure, final, importc: "struct timespec", header: "<time.h>".} = object
|
||||
tv_sec: clong ## Seconds.
|
||||
tv_nsec: clong ## Nanoseconds.
|
||||
|
||||
const CLOCK_MONOTONIC = cint 1
|
||||
const SecondsInNanoseconds = 1_000_000_000
|
||||
|
||||
proc clock_gettime(clockKind: cint, dst: var Timespec): cint {.sideeffect, discardable, importc, header: "<time.h>".}
|
||||
## Returns the clock kind value in dst
|
||||
## Returns 0 on success or -1 on failure
|
||||
|
||||
proc getTicks(): Ticks {.inline.} =
|
||||
var ts {.noInit.}: Timespec
|
||||
clock_gettime(CLOCK_MONOTONIC, ts)
|
||||
return Ticks(ts.tv_sec.int64 * SecondsInNanoseconds + ts.tv_nsec.int64)
|
||||
|
||||
func elapsedNs(start, stop: Ticks): int64 {.inline.} =
|
||||
## Returns the elapsed time in nano-seconds from ticks
|
||||
stop.int64 - start.int64
|
||||
|
||||
elif defined(macosx):
|
||||
type
|
||||
MachTimebaseInfoData {.pure, final, importc: "mach_timebase_info_data_t", header: "<mach/mach_time.h>".} = object
|
||||
numer, denom: int32
|
||||
|
||||
proc mach_absolute_time(): Ticks {.sideeffect, importc, header: "<mach/mach.h>".}
|
||||
proc mach_timebase_info(info: var MachTimebaseInfoData) {.importc, header: "<mach/mach_time.h>".}
|
||||
|
||||
## initialize MTI once at program startup
|
||||
var mti: MachTimebaseInfoData
|
||||
mach_timebase_info(mti)
|
||||
let mti_f64_num = float64(mti.numer)
|
||||
let mti_f64_den = float64(mti.denom)
|
||||
|
||||
proc getTicks(): Ticks {.inline.} =
|
||||
## On OSX, Ticks to nanoseconds is done via multiplying by MachTimeBasedInfo fraction
|
||||
return mach_absolute_time()
|
||||
|
||||
proc elapsedNs(start, stop: Ticks): int64 {.inline.} =
|
||||
## Returns the elapsed time in nano-seconds from ticks
|
||||
# Integer division is slow ~ 55 cycles at least.
|
||||
# Also division is imprecise but we don't really care about the error there
|
||||
# only the relative magnitude between various timers.
|
||||
# Otherwise we can use 128-bit precision or continued fractions: https://stackoverflow.com/questions/23378063/how-can-i-use-mach-absolute-time-without-overflowing
|
||||
int64(float64(stop.int64 - start.int64) * mti_f64_num / mti_f64_den)
|
||||
|
||||
elif defined(windows):
|
||||
proc QueryPerformanceCounter(res: var Ticks) {.importc: "QueryPerformanceCounter", stdcall, dynlib: "kernel32".}
|
||||
proc QueryPerformanceFrequency(res: var uint64) {.importc: "QueryPerformanceFrequency", stdcall, dynlib: "kernel32".}
|
||||
|
||||
# initialize performance frequency once at startup
|
||||
# https://learn.microsoft.com/en-us/windows/win32/api/profileapi/nf-profileapi-queryperformancefrequency
|
||||
var perfFreq: uint64
|
||||
QueryPerformanceFrequency(perfFreq)
|
||||
let nsRatio = 1e9'f64 / float64(perfFreq)
|
||||
|
||||
proc getTicks(): Ticks {.inline.} =
|
||||
QueryPerformanceCounter(result)
|
||||
|
||||
proc elapsedNs(start, stop: Ticks): int64 {.inline.} =
|
||||
# Because 10⁹ is so large, multiplying by it first then dividing will accumulate a lot of FP errors
|
||||
int64(float64(stop.int64 - start.int64) * nsRatio)
|
||||
|
||||
else:
|
||||
{.error: "Timers are not implemented for this OS".}
|
||||
|
||||
type
|
||||
Timer* = object
|
||||
## A timer, resolution in nanoseconds
|
||||
startTicks: Ticks
|
||||
elapsedNS: int64
|
||||
|
||||
TimerUnit* = enum
|
||||
kMicroseconds
|
||||
kMilliseconds
|
||||
kSeconds
|
||||
|
||||
func reset*(timer: var Timer) {.inline.} =
|
||||
timer.startTicks = Ticks(0)
|
||||
timer.elapsedNS = 0
|
||||
|
||||
proc start*(timer: var Timer) {.inline.} =
|
||||
timer.startTicks = getTicks()
|
||||
|
||||
proc stop*(timer: var Timer) {.inline.} =
|
||||
let stop = getTicks()
|
||||
timer.elapsedNS += elapsedNs(timer.startTicks, stop)
|
||||
|
||||
func getElapsedTime*(timer: Timer, kind: TimerUnit): float64 {.inline.} =
|
||||
case kind
|
||||
of kMicroseconds:
|
||||
return timer.elapsedNS.float64 * 1e-3
|
||||
of kMilliseconds:
|
||||
return timer.elapsedNS.float64 * 1e-6
|
||||
of kSeconds:
|
||||
return timer.elapsedNS.float64 * 1e-9
|
||||
|
||||
func getElapsedCumulatedTime*(timers: varargs[Timer], kind: TimerUnit): float64 {.inline.} =
|
||||
for timer in timers:
|
||||
result += timer.getElapsedTime(kind)
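# Usage sketch (illustrative, using only the API defined above): a Timer
# accumulates elapsed time across start/stop pairs, so it can wrap a hot
# section that is entered many times.
when isMainModule:
  var t: Timer
  t.reset()
  for _ in 0 ..< 1000:
    t.start()
    # ... section being profiled ...
    t.stop()
  echo "elapsed:   ", t.getElapsedTime(kMilliseconds), " ms"
  echo "cumulated: ", getElapsedCumulatedTime(t, kMilliseconds), " ms"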
|
||||
@ -24,7 +24,11 @@ import
|
||||
|
||||
export
|
||||
# flowvars
|
||||
Flowvar, isSpawned, isReady, sync
|
||||
Flowvar, isSpawned, isReady
|
||||
|
||||
when defined(TP_Metrics):
|
||||
import ../static_for
|
||||
import system/ansi_c
|
||||
|
||||
# ############################################################
|
||||
# #
|
||||
@ -133,6 +137,25 @@ iterator pseudoRandomPermutation(randomSeed: uint32, maxExclusive: int32): int32
|
||||
# #
|
||||
# ############################################################
|
||||
|
||||
let countersDesc {.compileTime.} = @[
|
||||
("tasksScheduled", "tasks scheduled"),
|
||||
("tasksStolen", "tasks stolen"),
|
||||
("tasksExecuted", "tasks executed"),
|
||||
("unrelatedTasksExecuted", "unrelated tasks executed"),
|
||||
("loopsSplit", "loops split"),
|
||||
("itersScheduled", "iterations scheduled"),
|
||||
("itersStolen", "iterations stolen"),
|
||||
("itersExecuted", "iterations executed"),
|
||||
("theftsIdle", "thefts while idle in event loop"),
|
||||
("theftsAwaiting", "thefts while awaiting a future"),
|
||||
("theftsLeapfrog", "leapfrogging thefts"),
|
||||
("backoffGlobalSleep", "sleeps on global backoff"),
|
||||
("backoffGlobalSignalSent", "signals sent on global backoff"),
|
||||
("backoffTaskAwaited", "sleeps on task-local backoff")
|
||||
]
|
||||
|
||||
defCountersType(Counters, countersDesc)
|
||||
|
||||
type
|
||||
WorkerID = int32
|
||||
Signal = object
|
||||
@ -155,16 +178,69 @@ type
|
||||
# Thefts
|
||||
rng: WorkStealingRng # RNG state to select victims
|
||||
|
||||
when defined(TP_Metrics):
|
||||
counters: Counters
|
||||
|
||||
Threadpool* = ptr object
|
||||
barrier: SyncBarrier # Barrier for initialization and teardown
|
||||
# All synchronization objects are put in their own cache-line to avoid invalidating
|
||||
# and reloading cache for immutable fields
|
||||
barrier{.align: 64.}: SyncBarrier # Barrier for initialization and teardown
|
||||
# -- align: 64
|
||||
globalBackoff: EventCount # Multi-Producer Multi-Consumer backoff
|
||||
globalBackoff{.align: 64.}: EventCount # Multi-Producer Multi-Consumer backoff
|
||||
# -- align: 64
|
||||
numThreads*{.align: 64.}: int32 # N regular workers
|
||||
workerQueues: ptr UncheckedArray[Taskqueue] # size N
|
||||
workers: ptr UncheckedArray[Thread[(Threadpool, WorkerID)]] # size N
|
||||
workerSignals: ptr UncheckedArray[Signal] # size N
|
||||
|
||||
# ############################################################
|
||||
# #
|
||||
# Metrics #
|
||||
# #
|
||||
# ############################################################
|
||||
|
||||
template metrics(body: untyped): untyped =
|
||||
when defined(TP_Metrics):
|
||||
block: {.noSideEffect, gcsafe.}: body
|
||||
|
||||
template incCounter(ctx: var WorkerContext, name: untyped{ident}, amount = 1) =
|
||||
bind name
|
||||
metrics:
|
||||
# Assumes workerContext is in the calling context
|
||||
ctx.counters.name += amount
|
||||

proc calcDerivedMetrics(ctx: var WorkerContext) {.used.} =
metrics:
ctx.counters.tasksStolen = ctx.counters.theftsIdle + ctx.counters.theftsAwaiting

proc printWorkerMetrics(ctx: var WorkerContext) =
metrics:
if ctx.id == 0:
c_printf("\n")
c_printf("+========================================+\n")
c_printf("| Per-worker statistics |\n")
c_printf("+========================================+\n")
flushFile(stdout)

ctx.calcDerivedMetrics()
discard ctx.threadpool.barrier.wait()

staticFor i, 0, countersDesc.len:
const (propName, propDesc) = countersDesc[i]
c_printf("Worker %3d: counterId %2d, %10d, %-32s\n", ctx.id, i, ctx.counters.getCounter(propName), propDesc)

flushFile(stdout)

# ############################################################
# #
# Profiling #
# #
# ############################################################

profileDecl(run_task)
profileDecl(backoff_idle)
profileDecl(backoff_awaiting)

# ############################################################
# #
# Workers #
@ -238,6 +314,9 @@ proc workerEntryFn(params: tuple[threadpool: Threadpool, id: WorkerID]) {.raises
# 1 matching barrier in threadpool.shutdown() for root thread
discard params.threadpool.barrier.wait()

ctx.printWorkerMetrics()
ctx.id.printWorkerProfiling()

ctx.teardownWorker()

# ############################################################
@ -245,44 +324,85 @@ proc workerEntryFn(params: tuple[threadpool: Threadpool, id: WorkerID]) {.raises
# Tasks #
# #
# ############################################################
#
# Task notification overview
#
# 2 strategies can be used to notify idle workers of new tasks entering the runtime:
#
# 1. "notify-on-new": always try to wake a worker on the backoff for each new task.
# 2. "notify-on-transition": wake a worker if-and-only-if our queue was empty when scheduling the task.
#
# In the second case, we also need a notify on successful theft to maintain the invariant that
# there is at least one thread looking for work if work is available, or all threads are busy.
#
# The notify-on-transition strategy minimizes kernel syscalls at the expense of reading an atomic,
# our dequeue status (a guaranteed cache miss).
# This is almost always the better tradeoff.
# Furthermore, in work-stealing, an empty dequeue is a good approximation of starvation.
#
# We can minimize syscalls in the "notify-on-new" strategy as well by reading the backoff status
# and checking if there is an idle worker not yet parked, or no parked threads at all.
# In that case we also need a notification on successful theft,
# in case 2 threads enqueueing work concurrently find the same non-parked idle worker; otherwise one task would be missed.
#
# The "notify-on-new" strategy minimizes latency when a producer enqueues tasks quickly.
#
# Lastly, when awaiting a future, a worker can give up its own queue if its tasks are unrelated to the awaited task.
# In the "notify-on-transition" strategy, that worker needs to wake up a relay.
#
# Concretely, on almost-empty tasks like fibonacci or DFS, "notify-on-new" is 10x slower.
# However, when tasks are enqueued quickly, as in Multi-Scalar Multiplication,
# there is a noticeable ramp-up. This might be solved with steal-half.
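#
# Illustrative sketch (not actual procs in this file) of how the two strategies differ;
# the real implementation is `schedule` below, which follows "notify-on-transition":
#
#   proc scheduleNotifyOnNew(ctx: var WorkerContext, task: ptr Task) =
#     ctx.taskqueue[].push(task)
#     ctx.threadpool.globalBackoff.wake()          # always signal: lowest latency, most syscalls
#
#   proc scheduleNotifyOnTransition(ctx: var WorkerContext, task: ptr Task) =
#     let wasEmpty = ctx.taskqueue[].peek() == 0   # read our own deque status (atomic read, likely cache miss)
#     ctx.taskqueue[].push(task)
#     if wasEmpty:                                 # signal only on the empty -> non-empty transition
#       ctx.threadpool.globalBackoff.wake()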

# Sentinel values
const RootTask = cast[ptr Task](0xEFFACED0)

proc run*(ctx: var WorkerContext, task: ptr Task) {.raises:[].} =
proc run(ctx: var WorkerContext, task: ptr Task) {.raises:[].} =
## Run a task, frees it if it is not owned by a Flowvar
let suspendedTask = ctx.currentTask
ctx.currentTask = task
debug: log("Worker %3d: running task 0x%.08x (previous: 0x%.08x, %d pending, thiefID %d)\n", ctx.id, task, suspendedTask, ctx.taskqueue[].peek(), task.getThief())
task.fn(task.env.addr)
profile(run_task):
task.fn(task.env.addr)
debug: log("Worker %3d: completed task 0x%.08x (%d pending)\n", ctx.id, task, ctx.taskqueue[].peek())
ctx.currentTask = suspendedTask

ctx.incCounter(tasksExecuted)
ctx.incCounter(itersExecuted):
if task.loopStepsLeft == NotALoop: 0
else: (task.loopStop - task.loopStart + task.loopStride-1) div task.loopStride

if not task.hasFuture: # Are we the final owner?
debug: log("Worker %3d: freeing task 0x%.08x with no future\n", ctx.id, task)
freeHeap(task)
return

# Sync with an awaiting thread in completeFuture that didn't find work
# and transfer ownership of the task to it.
debug: log("Worker %3d: transfering task 0x%.08x to future holder\n", ctx.id, task)
task.setCompleted()
task.setGcReady()

proc schedule(ctx: var WorkerContext, tn: ptr Task, forceWake = false) {.inline.} =
proc schedule(ctx: var WorkerContext, task: ptr Task, forceWake = false) {.inline.} =
## Schedule a task in the threadpool
## This wakes a sibling thread if our local queue is empty
## This wakes another worker if our local queue is empty
## or forceWake is true.
debug: log("Worker %3d: schedule task 0x%.08x (parent/current task 0x%.08x)\n", ctx.id, tn, tn.parent)
debug: log("Worker %3d: schedule task 0x%.08x (parent/current task 0x%.08x)\n", ctx.id, task, task.parent)

# Instead of notifying every time a task is scheduled, we notify
# only when the worker queue is empty. This is a good approximation
# of starvation in work-stealing.
let wasEmpty = ctx.taskqueue[].peek() == 0
ctx.taskqueue[].push(tn)
ctx.taskqueue[].push(task)

ctx.incCounter(tasksScheduled)
ctx.incCounter(itersScheduled):
if task.loopStepsLeft == NotALoop: 0
else: task.loopStepsLeft

if forceWake or wasEmpty:
ctx.threadpool.globalBackoff.wake()
ctx.incCounter(backoffGlobalSignalSent)

# ############################################################
# #
@ -400,10 +520,9 @@ func increase(backoff: var BalancerBackoff) {.inline.} =
func decrease(backoff: var BalancerBackoff) {.inline.} =
# On success, we exponentially reduce check window.
# Note: the thieves will start contributing as well.
backoff.windowLogSize -= 1
if backoff.windowLogSize > 0:
backoff.windowLogSize -= 1
backoff.round = 0
if backoff.windowLogSize < 0:
backoff.windowLogSize = 0

proc splitAndDispatchLoop(ctx: var WorkerContext, task: ptr Task, curLoopIndex: int, approxIdle: int32) =
# The iterator mutates the task with the first chunk metadata
@ -436,6 +555,7 @@ proc splitAndDispatchLoop(ctx: var WorkerContext, task: ptr Task, curLoopIndex:
ctx.taskqueue[].push(upperSplit)

ctx.threadpool.globalBackoff.wakeAll()
ctx.incCounter(backoffGlobalSignalSent)

proc loadBalanceLoop(ctx: var WorkerContext, task: ptr Task, curLoopIndex: int, backoff: var BalancerBackoff) =
## Split a parallel loop when necessary
@ -448,6 +568,7 @@ proc loadBalanceLoop(ctx: var WorkerContext, task: ptr Task, curLoopIndex: int,
let approxIdle = waiters.preSleep + waiters.committedSleep + cast[int32](task.isFirstIter)
if approxIdle > 0:
ctx.splitAndDispatchLoop(task, curLoopIndex, approxIdle)
ctx.incCounter(loopsSplit)
backoff.decrease()
else:
backoff.increase()
@ -528,8 +649,6 @@ proc tryStealOne(ctx: var WorkerContext): ptr Task =
let stolenTask = ctx.id.steal(ctx.threadpool.workerQueues[targetId])

if not stolenTask.isNil():
# Theft successful, there might be more work for idle threads, wake one
ctx.threadpool.globalBackoff.wake()
return stolenTask
return nil

@ -555,8 +674,6 @@ proc tryLeapfrog(ctx: var WorkerContext, awaitedTask: ptr Task): ptr Task =

let leapTask = ctx.id.steal(ctx.threadpool.workerQueues[thiefID])
if not leapTask.isNil():
# Theft successful, there might be more work for idle threads, wake one
ctx.threadpool.globalBackoff.wake()
return leapTask
return nil

@ -575,6 +692,16 @@ proc eventLoop(ctx: var WorkerContext) {.raises:[], gcsafe.} =
if (var stolenTask = ctx.tryStealOne(); not stolenTask.isNil):
# We managed to steal a task, cancel sleep
ctx.threadpool.globalBackoff.cancelSleep()
# Theft successful, there might be more work for idle threads, wake one
# cancelSleep must be done before, as wake has an optimization
# to not notify when a thread is sleepy
ctx.threadpool.globalBackoff.wake()
ctx.incCounter(backoffGlobalSignalSent)

ctx.incCounter(theftsIdle)
ctx.incCounter(itersStolen):
if stolenTask.loopStepsLeft == NotALoop: 0
else: stolenTask.loopStepsLeft
# 2.a Run task
debug: log("Worker %3d: eventLoop 2.a - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask)
ctx.run(stolenTask)
@ -586,7 +713,9 @@ proc eventLoop(ctx: var WorkerContext) {.raises:[], gcsafe.} =
else:
# 2.c Park the thread until a new task enters the threadpool
debugTermination: log("Worker %3d: eventLoop 2.b - sleeping\n", ctx.id)
ctx.threadpool.globalBackoff.sleep(ticket)
ctx.incCounter(backoffGlobalSleep)
profile(backoff_idle):
ctx.threadpool.globalBackoff.sleep(ticket)
debugTermination: log("Worker %3d: eventLoop 2.b - waking\n", ctx.id)

# ############################################################
@ -595,14 +724,14 @@ proc eventLoop(ctx: var WorkerContext) {.raises:[], gcsafe.} =
# #
# ############################################################

proc completeFuture*[T](fv: Flowvar[T], parentResult: var T) {.raises:[].} =
proc completeFuture[T](fv: Flowvar[T], parentResult: var T) {.raises:[].} =
## Eagerly complete an awaited FlowVar
template ctx: untyped = workerContext

template isFutReady(): untyped =
let isReady = fv.task.isCompleted()
let isReady = fv.isReady()
if isReady:
parentResult = cast[ptr (ptr Task, T)](fv.task.env.addr)[1]
parentResult.copyResult(fv)
isReady

if isFutReady():
@ -662,25 +791,46 @@ proc completeFuture*[T](fv: Flowvar[T], parentResult: var T) {.raises:[].} =

debug: log("Worker %3d: sync 2 - future not ready, becoming a thief (currentTask 0x%.08x, awaitedTask 0x%.08x)\n", ctx.id, ctx.currentTask, fv.task)
while not isFutReady():
if (let leapTask = ctx.tryLeapfrog(fv.task); not leapTask.isNil):
if (let leapTask = ctx.tryLeapfrog(fv.getTask()); not leapTask.isNil):
# Theft successful, there might be more work for idle threads, wake one
ctx.threadpool.globalBackoff.wake()
ctx.incCounter(backoffGlobalSignalSent)

# Leapfrogging: the thief had an empty queue, hence if there are tasks in its queue, they were generated by our blocked task.
# Help the thief clear those; if it did not finish, it's likely blocked on those child tasks.
ctx.incCounter(theftsLeapfrog)
ctx.incCounter(itersStolen):
if leapTask.loopStepsLeft == NotALoop: 0
else: leapTask.loopStepsLeft

debug: log("Worker %3d: sync 2.1 - leapfrog task 0x%.08x (parent 0x%.08x, current 0x%.08x, awaitedTask 0x%.08x)\n", ctx.id, leapTask, leapTask.parent, ctx.currentTask, fv.task)
ctx.run(leapTask)
elif (let stolenTask = ctx.tryStealOne(); not stolenTask.isNil):
# Theft successful, there might be more work for idle threads, wake one
ctx.threadpool.globalBackoff.wake()
ctx.incCounter(backoffGlobalSignalSent)

# We stole a task, we hope it advances our awaited task.
ctx.incCounter(theftsAwaiting)
ctx.incCounter(itersStolen):
if stolenTask.loopStepsLeft == NotALoop: 0
else: stolenTask.loopStepsLeft

debug: log("Worker %3d: sync 2.2 - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x, awaitedTask 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask, fv.task)
ctx.run(stolenTask)
elif (let ownTask = ctx.taskqueue[].pop(); not ownTask.isNil):
# We advance our own queue; this increases global throughput but may impact latency on the awaited task.
debug: log("Worker %3d: sync 2.3 - couldn't steal, running own task (awaitedTask 0x%.08x)\n", ctx.id, fv.task)
ctx.incCounter(unrelatedTasksExecuted)
ctx.run(ownTask)
else:
# Nothing to do, we park.
# - On today's hyperthreaded systems, this might reduce contention on core resources like memory caches and execution ports
# - If more work is created, we won't be notified, as we need to park on a dedicated notifier for a precise wakeup when the future is ready
debugTermination: log("Worker %3d: sync 2.4 - Empty runtime, parking (awaitedTask 0x%.08x)\n", ctx.id, fv.task)
fv.task.sleepUntilComplete(ctx.id)
ctx.incCounter(backoffTaskAwaited)
profile(backoff_awaiting):
fv.getTask().sleepUntilComplete(ctx.id)
debugTermination: log("Worker %3d: sync 2.4 - signaled, waking (awaitedTask 0x%.08x)\n", ctx.id, fv.task)

proc syncAll*(tp: Threadpool) {.raises: [].} =
@ -694,6 +844,8 @@ proc syncAll*(tp: Threadpool) {.raises: [].} =
preCondition: ctx.id == 0
preCondition: ctx.currentTask.isRootTask()

profileStop(run_task)

while true:
# 1. Empty local tasks
debug: log("Worker %3d: syncAll 1 - searching task from local queue\n", ctx.id)
@ -705,6 +857,15 @@ proc syncAll*(tp: Threadpool) {.raises: [].} =
if (var stolenTask = ctx.tryStealOne(); not stolenTask.isNil):
# 2.a We stole some task
debug: log("Worker %3d: syncAll 2.a - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask)

# Theft successful, there might be more work for idle threads, wake one
ctx.threadpool.globalBackoff.wake()
ctx.incCounter(backoffGlobalSignalSent)

ctx.incCounter(theftsIdle)
ctx.incCounter(itersStolen):
if stolenTask.loopStepsLeft == NotALoop: 0
else: stolenTask.loopStepsLeft
ctx.run(stolenTask)
elif tp.globalBackoff.getNumWaiters() == (0'i32, tp.numThreads-1): # Don't count ourselves
# 2.b all threads besides the current are parked
@ -717,6 +878,8 @@ proc syncAll*(tp: Threadpool) {.raises: [].} =
debugTermination:
log(">>> Worker %3d leaves barrier <<<\n", ctx.id)

profileStart(run_task)

# ############################################################
# #
# Runtime API #
@ -760,6 +923,7 @@ proc new*(T: type Threadpool, numThreads = countProcessors()): T {.raises: [Reso

# Wait for the child threads
discard tp.barrier.wait()
profileStart(run_task)
return tp

proc cleanup(tp: var Threadpool) {.raises: [].} =
@ -781,6 +945,7 @@ proc shutdown*(tp: var Threadpool) {.raises:[].} =
## Wait until all tasks are processed and then shutdown the threadpool
preCondition: workerContext.currentTask.isRootTask()
tp.syncAll()
profileStop(run_task)

# Signal termination to all threads
for i in 0 ..< tp.numThreads:
@ -791,6 +956,9 @@ proc shutdown*(tp: var Threadpool) {.raises:[].} =
# 1 matching barrier in workerEntryFn
discard tp.barrier.wait()

workerContext.printWorkerMetrics()
workerContext.id.printWorkerProfiling()

workerContext.teardownWorker()
tp.cleanup()

@ -813,12 +981,34 @@ macro spawn*(tp: Threadpool, fnCall: typed): untyped =
##
## If the function call returns a result, spawn will wrap it in a Flowvar.
## You can use `sync` to block the current thread and extract the asynchronous result from the flowvar.
## You can use `isReady` to check if result is available and if subsequent
## `spawn` returns immediately.
## You can use `isReady` to check if result is available and if a subsequent
## `sync` returns immediately.
##
## Tasks are processed approximately in Last-In-First-Out (LIFO) order
result = spawnImpl(tp, fnCall, bindSym"workerContext", bindSym"schedule")

macro spawnAwaitable*(tp: Threadpool, fnCall: typed): untyped =
## Spawns the input function call asynchronously, potentially on another thread of execution.
##
## This allows awaiting a void function.
## The result, once ready, is always `true`.
##
## You can use `sync` to block the current thread until the function is finished.
## You can use `isReady` to check if result is available and if a subsequent
## `sync` returns immediately.
##
## Tasks are processed approximately in Last-In-First-Out (LIFO) order
result = spawnAwaitableImpl(tp, fnCall, bindSym"workerContext", bindSym"schedule")

proc sync*[T](fv: sink Flowvar[T]): T {.noInit, inline, gcsafe.} =
## Blocks the current thread until the flowvar is available
## and returned.
## The thread is not idle and will complete pending tasks.
profileStop(run_task)
completeFuture(fv, result)
cleanup(fv)
profileStart(run_task)

# Data parallel API
# ---------------------------------------------

@ -855,4 +1045,4 @@ macro parallelFor*(tp: Threadpool, loopParams: untyped, body: untyped): untyped
loopParams, body)

result.add quote do:
{.pop.}
{.pop.}

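A minimal usage sketch of the API above (the import path and the `double`/`hello` helpers are
illustrative, not part of this commit):

  import constantine/platforms/threadpool/threadpool

  proc double(x: int): int = 2 * x
  proc hello() = echo "hello from a worker"

  proc main() =
    var tp = Threadpool.new()             # defaults to countProcessors() threads

    let fut  = tp.spawn double(21)        # Flowvar[int]
    let done = tp.spawnAwaitable hello()  # Flowvar[bool], always `true` once the call finished

    echo fut.sync()                       # blocks (while helping with pending tasks), prints 42
    discard done.sync()

    tp.syncAll()                          # wait for all outstanding tasks
    tp.shutdown()

  main()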
@ -972,7 +972,7 @@ proc run_EC_multi_scalar_mul_impl*[N: static int](
naive.setInf()
for i in 0 ..< n:
naive_tmp.fromAffine(points[i])
naive_tmp.scalarMulGeneric(coefs[i])
naive_tmp.scalarMul(coefs[i])
naive += naive_tmp

var msm_ref, msm: EC

29 tests/parallel/t_ec_shortw_jac_g1_msm_parallel.nim Normal file
@ -0,0 +1,29 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
# Internals
../../constantine/math/config/curves,
../../constantine/math/elliptic/ec_shortweierstrass_jacobian,
../../constantine/math/arithmetic,
# Test utilities
./t_ec_template_parallel

const numPoints = [1, 2, 8, 16, 32, 64, 128, 1024, 2048, 16384] # 32768, 262144, 1048576]

run_EC_multi_scalar_mul_parallel_impl(
ec = ECP_ShortW_Jac[Fp[BN254_Snarks], G1],
numPoints = numPoints,
moduleName = "test_ec_shortweierstrass_jacobian_multi_scalar_mul_" & $BN254_Snarks
)

run_EC_multi_scalar_mul_parallel_impl(
ec = ECP_ShortW_Jac[Fp[BLS12_381], G1],
numPoints = numPoints,
moduleName = "test_ec_shortweierstrass_jacobian_multi_scalar_mul_" & $BLS12_381
)

29 tests/parallel/t_ec_shortw_prj_g1_msm_parallel.nim Normal file
@ -0,0 +1,29 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
# Internals
../../constantine/math/config/curves,
../../constantine/math/elliptic/ec_shortweierstrass_projective,
../../constantine/math/arithmetic,
# Test utilities
./t_ec_template_parallel

const numPoints = [1, 2, 8, 16, 128, 1024, 2048, 16384] # 32768, 262144, 1048576]

run_EC_multi_scalar_mul_parallel_impl(
ec = ECP_ShortW_Prj[Fp[BN254_Snarks], G1],
numPoints = numPoints,
moduleName = "test_ec_shortweierstrass_projective_multi_scalar_mul_" & $BN254_Snarks
)

run_EC_multi_scalar_mul_parallel_impl(
ec = ECP_ShortW_Prj[Fp[BLS12_381], G1],
numPoints = numPoints,
moduleName = "test_ec_shortweierstrass_projective_multi_scalar_mul_" & $BLS12_381
)
@ -17,12 +17,16 @@ import
std/[unittest, times],
# Internals
../../constantine/platforms/abstractions,
../../constantine/math/constants/zoo_subgroups,
../../constantine/math/[arithmetic, extension_fields],
../../constantine/math/elliptic/[
ec_shortweierstrass_affine,
ec_shortweierstrass_jacobian,
ec_shortweierstrass_projective,
ec_shortweierstrass_batch_ops_parallel],
ec_shortweierstrass_batch_ops_parallel,
ec_scalar_mul,
ec_multi_scalar_mul,
ec_multi_scalar_mul_parallel],
../../constantine/platforms/threadpool/threadpool,
# Test utilities
../../helpers/prng_unsafe
@ -63,27 +67,21 @@ func random_point*(rng: var RngState, EC: typedesc, randZ: bool, gen: RandomGen)
proc run_EC_batch_add_parallel_impl*[N: static int](
ec: typedesc,
numPoints: array[N, int],
moduleName: string
) =
moduleName: string) =

# Random seed for reproducibility
var rng: RngState
let seed = 1674654772 # uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
rng.seed(seed)
echo "\n------------------------------------------------------\n"
echo moduleName, " xoshiro512** seed: ", seed

when ec.G == G1:
const G1_or_G2 = "G1"
else:
const G1_or_G2 = "G2"

const testSuiteDesc = "Elliptic curve parallel sum reduction for Short Weierstrass form"

suite testSuiteDesc & " - " & $ec & " - [" & $WordBitWidth & "-bit mode]":
suite testSuiteDesc & " - " & $ec.G & " - [" & $WordBitWidth & "-bit mode]":

for n in numPoints:
test $ec & " sum reduction (N=" & $n & ")":
test $ec & " parallel sum reduction (N=" & $n & ")":
proc test(EC: typedesc, gen: RandomGen) =
var tp = Threadpool.new()
defer: tp.shutdown()
@ -108,7 +106,7 @@ proc run_EC_batch_add_parallel_impl*[N: static int](
test(ec, gen = HighHammingWeight)
test(ec, gen = Long01Sequence)

test "EC " & G1_or_G2 & " sum reduction (N=" & $n & ") - special cases":
test "EC " & $ec.G & " parallel sum reduction (N=" & $n & ") - special cases":
proc test(EC: typedesc, gen: RandomGen) =
var tp = Threadpool.new()
defer: tp.shutdown()
@ -144,3 +142,49 @@ proc run_EC_batch_add_parallel_impl*[N: static int](
test(ec, gen = Uniform)
test(ec, gen = HighHammingWeight)
test(ec, gen = Long01Sequence)


proc run_EC_multi_scalar_mul_parallel_impl*[N: static int](
ec: typedesc,
numPoints: array[N, int],
moduleName: string) =
# Random seed for reproducibility
var rng: RngState
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
rng.seed(seed)
echo "\n------------------------------------------------------\n"
echo moduleName, " xoshiro512** seed: ", seed

const testSuiteDesc = "Elliptic curve parallel multi-scalar-multiplication for Short Weierstrass form"

suite testSuiteDesc & " - " & $ec & " - [" & $WordBitWidth & "-bit mode]":
for n in numPoints:
let bucketBits = bestBucketBitSize(n, ec.F.C.getCurveOrderBitwidth(), useSignedBuckets = false, useManualTuning = false)
test $ec & " Parallel Multi-scalar-mul (N=" & $n & ", bucket bits: " & $bucketBits & ")":
proc test(EC: typedesc, gen: RandomGen) =
var tp = Threadpool.new()
defer: tp.shutdown()
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](n)
var coefs = newSeq[BigInt[EC.F.C.getCurveOrderBitwidth()]](n)

for i in 0 ..< n:
var tmp = rng.random_unsafe(EC)
tmp.clearCofactor()
points[i].affine(tmp)
coefs[i] = rng.random_unsafe(BigInt[EC.F.C.getCurveOrderBitwidth()])

var naive, naive_tmp: EC
naive.setInf()
for i in 0 ..< n:
naive_tmp.fromAffine(points[i])
naive_tmp.scalarMul(coefs[i])
naive += naive_tmp

var msm: EC
tp.multiScalarMul_vartime_parallel(msm, coefs, points)

doAssert bool(naive == msm)

test(ec, gen = Uniform)
test(ec, gen = HighHammingWeight)
test(ec, gen = Long01Sequence)
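For reference, the parallel MSM entry point exercised above can also be called on its own; a
minimal sketch reusing the names from the test template (`tp`, `coefs`, `points` set up as above):

  var r: ECP_ShortW_Jac[Fp[BLS12_381], G1]
  tp.multiScalarMul_vartime_parallel(r, coefs, points)   # r = sum_i coefs[i] * points[i]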