MSM tuning for high core count (#227)
* tune for high core count
* reentrancy: allow nesting of parallel functions by introducing precise scoped barriers
* increase collision queue depth
parent 6c48975aee
commit 93dac2503c
@@ -65,35 +65,35 @@ proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
var startNaive, stopNaive, startMSMbaseline, stopMSMbaseline, startMSMopt, stopMSMopt, startMSMpara, stopMSMpara: MonoTime

if numPoints <= 100000:
startNaive = getMonotime()
bench("EC scalar muls " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startNaive = getMonotime()
var tmp: EC
r.setInf()
for i in 0 ..< points.len:
tmp.fromAffine(points[i])
tmp.scalarMul(scalars[i])
r += tmp
stopNaive = getMonotime()
stopNaive = getMonotime()

block:
if numPoints <= 100000:
startMSMbaseline = getMonotime()
bench("EC multi-scalar-mul baseline " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMbaseline = getMonotime()
r.multiScalarMul_reference_vartime(scalars, points)
stopMSMbaseline = getMonotime()
stopMSMbaseline = getMonotime()

block:
startMSMopt = getMonotime()
bench("EC multi-scalar-mul optimized " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMopt = getMonotime()
r.multiScalarMul_vartime(scalars, points)
stopMSMopt = getMonotime()
stopMSMopt = getMonotime()

block:
var tp = Threadpool.new()

startMSMpara = getMonotime()
bench("EC multi-scalar-mul" & align($tp.numThreads & " threads", 11) & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMpara = getMonotime()
tp.multiScalarMul_vartime_parallel(r, scalars, points)
stopMSMpara = getMonotime()
stopMSMpara = getMonotime()

tp.shutdown()

@@ -109,8 +109,8 @@ proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
let speedupOpt = float(perfNaive) / float(perfMSMopt)
echo &"Speedup ratio optimized over naive linear combination: {speedupOpt:>6.3f}x"

let speedupOptBaseline = float(perfMSMbaseline) / float(perfMSMopt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"
let speedupOptBaseline = float(perfMSMbaseline) / float(perfMSMopt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"

let speedupParaOpt = float(perfMSMopt) / float(perfMSMpara)
echo &"Speedup ratio parallel over optimized linear combination: {speedupParaOpt:>6.3f}x"
@@ -118,12 +118,8 @@ func multiScalarMul_reference_vartime*[EC](r: var EC, coefs: openArray[BigInt],
of 13: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 13)
of 14: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 14)
of 15: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 15)
of 16: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 16)
of 17: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 17)
of 18: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 18)
of 19: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 19)
of 20: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 20)
of 21: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 21)

of 16..20: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 16)
else:
unreachable()

@@ -271,6 +267,8 @@ func schedAccumulate*[NumBuckets, QueueLen, F, G; bits: static int](
const top = bits - excess
static: doAssert miniMsmKind != kTopWindow, "The top window is smaller in bits which increases collisions in scheduler."

sched.bucketInit()

var curSP, nextSP: ScheduledPoint

template getSignedWindow(j : int): tuple[val: SecretWord, neg: SecretBool] =

@@ -295,14 +293,12 @@ func miniMSM_affine[NumBuckets, QueueLen, F, G; bits: static int](
## Apply a mini-Multi-Scalar-Multiplication on [bitIndex, bitIndex+window)
## slice of all (coef, point) pairs

sched.buckets[].init()

# 1. Bucket Accumulation
sched.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)

# 2. Bucket Reduction
var windowSum_jacext{.noInit.}: ECP_ShortW_JacExt[F, G]
windowSum_jacext.bucketReduce(sched.buckets[])
windowSum_jacext.bucketReduce(sched.buckets)

# 3. Mini-MSM on the slice [bitIndex, bitIndex+window)
var windowSum{.noInit.}: typeof(r)

@@ -324,7 +320,6 @@ func multiScalarMulAffine_vartime[F, G; bits: static int](
# -----
const (numBuckets, queueLen) = c.deriveSchedulerConstants()
let buckets = allocHeap(Buckets[numBuckets, F, G])
buckets[].init()
let sched = allocHeap(Scheduler[numBuckets, queueLen, F, G])
sched.init(points, buckets, 0, numBuckets.int32)

@@ -440,11 +435,10 @@ func multiScalarMul_dispatch_vartime[bits: static int, F, G](
of 11: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 11)
of 12: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 12)
of 13: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 13)
of 14: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 14)
of 15: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 15)
of 16: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 16)
of 17: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 17)
of 18: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 18)
of 14: multiScalarMulAffine_vartime(r, coefs, points, N, c = 14)
of 15: multiScalarMulAffine_vartime(r, coefs, points, N, c = 15)

of 16..17: multiScalarMulAffine_vartime(r, coefs, points, N, c = 16)
else:
unreachable()

@@ -458,4 +452,4 @@ func multiScalarMul_vartime*[bits: static int, F, G](
debug: doAssert coefs.len == points.len
let N = points.len

multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)
multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)
@@ -11,7 +11,7 @@ import ./ec_multi_scalar_mul_scheduler,
./ec_endomorphism_accel,
../extension_fields,
../constants/zoo_endomorphisms,
../../platforms/threadpool/threadpool
../../platforms/threadpool/[threadpool, partitioners]
export bestBucketBitSize

# No exceptions allowed in core cryptographic operations

@@ -133,8 +133,8 @@ export bestBucketBitSize
# Parallel MSM Jacobian Extended
# ------------------------------

proc bucketAccumReduce_jacext_zeroMem[F, G; bits: static int](
windowSum: ptr ECP_ShortW[F, G],
proc bucketAccumReduce_jacext_zeroMem[EC, F, G; bits: static int](
windowSum: ptr EC,
buckets: ptr ECP_ShortW_JacExt[F, G] or ptr UncheckedArray[ECP_ShortW_JacExt[F, G]],
bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) =

@@ -143,9 +143,9 @@ proc bucketAccumReduce_jacext_zeroMem[F, G; bits: static int](
zeroMem(buckets, sizeof(ECP_ShortW_JacExt[F, G]) * numBuckets)
bucketAccumReduce_jacext(windowSum[], buckets, bitIndex, miniMsmKind, c, coefs, points, N)

proc msmJacExt_vartime_parallel*[bits: static int, F, G](
proc msmJacExt_vartime_parallel*[bits: static int, EC, F, G](
tp: Threadpool,
r: var ECP_ShortW[F, G],
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int) =

@@ -158,7 +158,6 @@ proc msmJacExt_vartime_parallel*[bits: static int, F, G](
# Instead of storing the result in futures, risking them being scattered in memory
# we store them in a contiguous array, and the synchronizing future just returns a bool.
# top window is done on this thread
type EC = typeof(r)
let miniMSMsResults = allocHeapArray(EC, numFullWindows)
let miniMSMsReady = allocStackArray(FlowVar[bool], numFullWindows)

@@ -188,41 +187,38 @@ proc msmJacExt_vartime_parallel*[bits: static int, F, G](
when top != 0:
when excess != 0:
bucketAccumReduce_jacext_zeroMem(
r.addr,
r,
bucketsMatrix[numFullWindows*numBuckets].addr,
bitIndex = top, kTopWindow, c,
coefs, points, N)
else:
r.setInf()
r[].setInf()

# 3. Final reduction, r initialized to what would be miniMSMsReady[numWindows-1]
when excess != 0:
for w in countdown(numWindows-2, 0):
for _ in 0 ..< c:
r.double()
r[].double()
discard sync miniMSMsReady[w]
r += miniMSMsResults[w]
r[] += miniMSMsResults[w]
elif numWindows >= 2:
discard sync miniMSMsReady[numWindows-2]
r = miniMSMsResults[numWindows-2]
r[] = miniMSMsResults[numWindows-2]
for w in countdown(numWindows-3, 0):
for _ in 0 ..< c:
r.double()
r[].double()
discard sync miniMSMsReady[w]
r += miniMSMsResults[w]
r[] += miniMSMsResults[w]

# Cleanup
# -------
miniMSMsResults.freeHeap()
bucketsMatrix.freeHeap()

# Parallel MSM Affine
# ------------------------------

proc bucketAccumReduce_parallel[bits: static int, F, G](
tp: Threadpool,
r: ptr ECP_ShortW[F, G],
# Parallel MSM Affine - bucket accumulation
# -----------------------------------------
proc bucketAccumReduce_serial[bits: static int, EC, F, G](
r: ptr EC,
bitIndex: int,
miniMsmKind: static MiniMsmKind, c: static int,
coefs: ptr UncheckedArray[BigInt[bits]],
@@ -230,21 +226,46 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
N: int) =

const (numBuckets, queueLen) = c.deriveSchedulerConstants()
const outerParallelism = bits div c # It's actually ceilDiv instead of floorDiv, but the last iteration might be too small
let buckets = allocHeap(Buckets[numBuckets, F, G])
let sched = allocHeap(Scheduler[numBuckets, queueLen, F, G])
sched.init(points, buckets, 0, numBuckets.int32)

var innerParallelism = 1'i32
while outerParallelism*innerParallelism < tp.numThreads:
innerParallelism = innerParallelism shl 1
# 1. Bucket Accumulation
sched.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)

let numChunks = 1'i32 # innerParallelism # TODO: unfortunately trying to expose more parallelism slows down the performance
# 2. Bucket Reduction
var windowSum{.noInit.}: ECP_ShortW_JacExt[F, G]
windowSum.bucketReduce(sched.buckets)
r[].fromJacobianExtended_vartime(windowSum)

# Cleanup
# ----------------
sched.freeHeap()
buckets.freeHeap()

proc bucketAccumReduce_parallel[bits: static int, EC, F, G](
tp: Threadpool,
r: ptr EC,
bitIndex: int,
miniMsmKind: static MiniMsmKind, c: static int,
coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int) =

const (numBuckets, queueLen) = c.deriveSchedulerConstants()
const windowParallelism = bits div c # It's actually ceilDiv instead of floorDiv, but the last iteration might be too small

var bucketParallelism = 1'i32
while windowParallelism*bucketParallelism < tp.numThreads:
bucketParallelism = bucketParallelism shl 1

let numChunks = bucketParallelism
let chunkSize = int32(numBuckets) shr log2_vartime(cast[uint32](numChunks)) # Both are power of 2 so exact division
let chunksReadiness = allocStackArray(FlowVar[bool], numChunks-1) # Last chunk is done on this thread

let buckets = allocHeap(Buckets[numBuckets, F, G])
let scheds = allocHeapArray(Scheduler[numBuckets, queueLen, F, G], numChunks)

buckets[].init()

block: # 1. Bucket Accumulation
for chunkID in 0'i32 ..< numChunks-1:
let idx = chunkID*chunkSize
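For intuition on the sizing above: numBuckets and numChunks are both powers of two, so the shift by log2_vartime(numChunks) is an exact division. The following standalone sketch mirrors the derivation with illustrative numbers; the thread count, scalar width and bucket count are assumptions, not values taken from this commit.

# Hypothetical values, for illustration only.
let numThreads = 64              # assumed tp.numThreads
const bits = 255                 # assumed scalar bit width
const c = 16                     # window size
const numBuckets = 1 shl (c-1)   # assumed signed-bucket count: 32768

const windowParallelism = bits div c   # 15 windows processed in parallel
var bucketParallelism = 1
while windowParallelism * bucketParallelism < numThreads:
  bucketParallelism = bucketParallelism shl 1   # 1 -> 2 -> 4 -> 8

let numChunks = bucketParallelism               # 8 bucket ranges per window
let chunkSize = numBuckets div numChunks        # 32768 div 8 = 4096 buckets per chunk
echo windowParallelism * numChunks, " tasks for ", numThreads, " threads"  # 120 tasks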
@@ -254,7 +275,7 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
scheds[numChunks-1].addr.init(points, buckets, (numChunks-1)*chunkSize, int32 numBuckets)
scheds[numChunks-1].addr.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)

block: # 2. Bucket reduction
block: # 2. Bucket reduction with latency hiding
var windowSum{.noInit.}: ECP_ShortW_JacExt[F, G]
var accumBuckets{.noinit.}: ECP_ShortW_JacExt[F, G]

@@ -268,7 +289,7 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
else:
accumBuckets.setInf()
windowSum = accumBuckets
buckets[].reset(numBuckets-1)
buckets.reset(numBuckets-1)

var nextBatch = numBuckets-1-chunkSize
var nextFutureIdx = numChunks-2

@@ -289,7 +310,7 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
elif kJacExt in buckets.status[k]:
accumBuckets += buckets.ptJacExt[k]

buckets[].reset(k)
buckets.reset(k)
windowSum += accumBuckets

r[].fromJacobianExtended_vartime(windowSum)

@@ -299,11 +320,14 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
scheds.freeHeap()
buckets.freeHeap()

proc msmAffine_vartime_parallel*[bits: static int, F, G](
# Parallel MSM Affine - window-level only
# ---------------------------------------

proc msmAffine_vartime_parallel*[bits: static int, EC, F, G](
tp: Threadpool,
r: var ECP_ShortW[F, G],
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int) =
N: int, c: static int, useParallelBuckets: static bool) =

# Prologue
# --------

@@ -314,24 +338,36 @@ proc msmAffine_vartime_parallel*[bits: static int, F, G](
# Instead of storing the result in futures, risking them being scattered in memory
# we store them in a contiguous array, and the synchronizing future just returns a bool.
# top window is done on this thread
type EC = typeof(r)
type EC = typeof(r[])
let miniMSMsResults = allocHeapArray(EC, numFullWindows)
let miniMSMsReady = allocStackArray(Flowvar[bool], numFullWindows)

# Algorithm
# ---------

block: # 1. Bucket accumulation and reduction
# 1. mini-MSMs: Bucket accumulation and reduction
when useParallelBuckets:
miniMSMsReady[0] = tp.spawnAwaitable bucketAccumReduce_parallel(
tp, miniMSMsResults[0].addr,
bitIndex = 0, kBottomWIndow, c,
coefs, points, N)
tp, miniMSMsResults[0].addr,
bitIndex = 0, kBottomWindow, c,
coefs, points, N)

for w in 1 ..< numFullWindows:
miniMSMsReady[w] = tp.spawnAwaitable bucketAccumReduce_parallel(
tp, miniMSMsResults[w].addr,
bitIndex = w*c, kFullWIndow, c,
coefs, points, N)
for w in 1 ..< numFullWindows:
miniMSMsReady[w] = tp.spawnAwaitable bucketAccumReduce_parallel(
tp, miniMSMsResults[w].addr,
bitIndex = w*c, kFullWindow, c,
coefs, points, N)
else:
miniMSMsReady[0] = tp.spawnAwaitable bucketAccumReduce_serial(
miniMSMsResults[0].addr,
bitIndex = 0, kBottomWindow, c,
coefs, points, N)

for w in 1 ..< numFullWindows:
miniMSMsReady[w] = tp.spawnAwaitable bucketAccumReduce_serial(
miniMSMsResults[w].addr,
bitIndex = w*c, kFullWindow, c,
coefs, points, N)

# Last window is done sync on this thread, directly initializing r
const excess = bits mod c
@@ -341,32 +377,79 @@ proc msmAffine_vartime_parallel*[bits: static int, F, G](
when excess != 0:
let buckets = allocHeapArray(ECP_ShortW_JacExt[F, G], numBuckets)
zeroMem(buckets[0].addr, sizeof(ECP_ShortW_JacExt[F, G]) * numBuckets)
r.bucketAccumReduce_jacext(buckets, bitIndex = top, kTopWindow, c,
r[].bucketAccumReduce_jacext(buckets, bitIndex = top, kTopWindow, c,
coefs, points, N)
buckets.freeHeap()
else:
r.setInf()
r[].setInf()

# 3. Final reduction, r initialized to what would be miniMSMsReady[numWindows-1]
# 2. Final reduction with latency hiding, r initialized to what would be miniMSMsReady[numWindows-1]
when excess != 0:
for w in countdown(numWindows-2, 0):
for _ in 0 ..< c:
r.double()
r[].double()
discard sync miniMSMsReady[w]
r += miniMSMsResults[w]
r[] += miniMSMsResults[w]
elif numWindows >= 2:
discard sync miniMSMsReady[numWindows-2]
r = miniMSMsResults[numWindows-2]
r[] = miniMSMsResults[numWindows-2]
for w in countdown(numWindows-3, 0):
for _ in 0 ..< c:
r.double()
r[].double()
discard sync miniMSMsReady[w]
r += miniMSMsResults[w]
r[] += miniMSMsResults[w]

# Cleanup
# -------
miniMSMsResults.freeHeap()

proc msmAffine_vartime_parallel_split[bits: static int, EC, F, G](
tp: Threadpool,
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int, useParallelBuckets: static bool) =

# Parallelism levels:
# - MSM parallelism: compute independent MSMs, this increases the number of total ops
# - window parallelism: compute a MSM outer loop on different threads, this has no tradeoffs
# - bucket parallelism: handle range of buckets on different threads, threads do superfluous overlapping memory reads
#
# It seems to be beneficial to have both MSM and bucket level parallelism.
# Probably by guaranteeing 2x more tasks than threads, we avoid starvation.

var windowParallelism = bits div c # It's actually ceilDiv instead of floorDiv, but the last iteration might be too small
var msmParallelism = 1'i32

while windowParallelism*msmParallelism < tp.numThreads:
windowParallelism = bits div c # This is an approximation
msmParallelism = msmParallelism shl 1

if msmParallelism == 1:
msmAffine_vartime_parallel(tp, r, coefs, points, N, c, useParallelBuckets)
return

let chunkingDescriptor = balancedChunksPrioNumber(0, N, msmParallelism)
let splitMSMsResults = allocHeapArray(typeof(r[]), msmParallelism-1)
let splitMSMsReady = allocStackArray(Flowvar[bool], msmParallelism-1)

for (i, start, len) in items(chunkingDescriptor):
if i != msmParallelism-1:
splitMSMsReady[i] = tp.spawnAwaitable msmAffine_vartime_parallel(
tp, splitMSMsResults[i].addr,
coefs +% start, points +% start, len,
c, useParallelBuckets)
else: # Run last on this thread
msmAffine_vartime_parallel(
tp, r,
coefs +% start, points +% start, len,
c, useParallelBuckets)

for i in countdown(msmParallelism-2, 0):
discard sync splitMSMsReady[i]
r[] += splitMSMsResults[i]

freeHeap(splitMSMsResults)

proc applyEndomorphism_parallel[bits: static int, F, G](
tp: Threadpool,
coefs: ptr UncheckedArray[BigInt[bits]],
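Splitting at the MSM level is sound because a multi-scalar multiplication decomposes over any partition of its (coefficient, point) pairs: each chunk produced by balancedChunksPrioNumber is an independent sub-MSM, and the partial results are added back in the final countdown loop above. In symbols (notation mine, with m = msmParallelism):

\sum_{i=0}^{N-1} [a_i]P_i \;=\; \sum_{k=0}^{m-1} \;\sum_{i \in \mathrm{chunk}_k} [a_i]P_i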
@@ -384,39 +467,38 @@ proc applyEndomorphism_parallel[bits: static int, F, G](
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
let endoBasis = allocHeapArray(array[M, ECP_ShortW_Aff[F, G]], N)

tp.parallelFor i in 0 ..< N:
captures: {coefs, points, splitCoefs, endoBasis}
syncScope:
tp.parallelFor i in 0 ..< N:
captures: {coefs, points, splitCoefs, endoBasis}

var negatePoints {.noinit.}: array[M, SecretBool]
splitCoefs[i].decomposeEndo(negatePoints, coefs[i], F)
if negatePoints[0].bool:
endoBasis[i][0].neg(points[i])
else:
endoBasis[i][0] = points[i]

when F is Fp:
endoBasis[i][1].x.prod(points[i].x, F.C.getCubicRootOfUnity_mod_p())
if negatePoints[1].bool:
endoBasis[i][1].y.neg(points[i].y)
var negatePoints {.noinit.}: array[M, SecretBool]
splitCoefs[i].decomposeEndo(negatePoints, coefs[i], F)
if negatePoints[0].bool:
endoBasis[i][0].neg(points[i])
else:
endoBasis[i][1].y = points[i].y
else:
staticFor m, 1, M:
endoBasis[i][m].frobenius_psi(points[i], m)
if negatePoints[m].bool:
endoBasis[i][m].neg()
endoBasis[i][0] = points[i]

tp.syncAll()
when F is Fp:
endoBasis[i][1].x.prod(points[i].x, F.C.getCubicRootOfUnity_mod_p())
if negatePoints[1].bool:
endoBasis[i][1].y.neg(points[i].y)
else:
endoBasis[i][1].y = points[i].y
else:
staticFor m, 1, M:
endoBasis[i][m].frobenius_psi(points[i], m)
if negatePoints[m].bool:
endoBasis[i][m].neg()

let endoCoefs = cast[ptr UncheckedArray[BigInt[L]]](splitCoefs)
let endoPoints = cast[ptr UncheckedArray[ECP_ShortW_Aff[F, G]]](endoBasis)

return (endoCoefs, endoPoints, M*N)

template withEndo[bits: static int, F, G](
template withEndo[bits: static int, EC, F, G](
msmProc: untyped,
tp: Threadpool,
r: var ECP_ShortW[F, G],
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int) =
@@ -430,9 +512,26 @@ template withEndo[bits: static int, F, G](
else:
msmProc(tp, r, coefs, points, N, c)

proc multiScalarMul_dispatch_vartime_parallel[bits: static int, F, G](
template withEndo[bits: static int, EC, F, G](
msmProc: untyped,
tp: Threadpool,
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int, useParallelBuckets: static bool) =
when bits <= F.C.getCurveOrderBitwidth() and hasEndomorphismAcceleration(F.C):
let (endoCoefs, endoPoints, endoN) = applyEndomorphism_parallel(tp, coefs, points, N)
# Given that bits and N changed, we are able to use a bigger `c`
# but it has no significant impact on performance
msmProc(tp, r, endoCoefs, endoPoints, endoN, c, useParallelBuckets)
freeHeap(endoCoefs)
freeHeap(endoPoints)
else:
msmProc(tp, r, coefs, points, N, c, useParallelBuckets)

proc multiScalarMul_dispatch_vartime_parallel[bits: static int, EC, F, G](
tp: Threadpool,
r: var ECP_ShortW[F, G], coefs: ptr UncheckedArray[BigInt[bits]],
r: ptr EC, coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) =
## Multiscalar multiplication:
## r <- [a₀]P₀ + [a₁]P₁ + ... + [aₙ]Pₙ

@@ -448,30 +547,32 @@ proc multiScalarMul_dispatch_vartime_parallel[bits: static int, F, G](
of 4: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 4)
of 5: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 5)
of 6: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 6)
of 7: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 7)
of 8: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 8)
of 9: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 9)
of 10: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 10)

of 11: withEndo(msmAffine_vartime_parallel, tp, r, coefs, points, N, c = 11)
of 7: msmJacExt_vartime_parallel(tp, r, coefs, points, N, c = 7)
of 8: msmJacExt_vartime_parallel(tp, r, coefs, points, N, c = 8)

of 12: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 12)
of 13: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 13)
of 14: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 14)
of 15: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 15)
of 16: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 16)
of 17: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 17)
of 18: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 18)
of 9: withEndo(msmAffine_vartime_parallel_split, tp, r, coefs, points, N, c = 9, useParallelBuckets = true)
of 10: withEndo(msmAffine_vartime_parallel_split, tp, r, coefs, points, N, c = 10, useParallelBuckets = true)

of 11: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 10, useParallelBuckets = true)
of 12: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 11, useParallelBuckets = true)
of 13: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 12, useParallelBuckets = true)
of 14: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 13, useParallelBuckets = true)
of 15: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 14, useParallelBuckets = true)
of 16: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 15, useParallelBuckets = true)
of 17: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 16, useParallelBuckets = true)
else:
unreachable()

proc multiScalarMul_vartime_parallel*[bits: static int, F, G](
proc multiScalarMul_vartime_parallel*[bits: static int, EC, F, G](
tp: Threadpool,
r: var ECP_ShortW[F, G],
r: var EC,
coefs: openArray[BigInt[bits]],
points: openArray[ECP_ShortW_Aff[F, G]]) {.meter, inline.} =

## Multiscalar multiplication:
## r <- [a₀]P₀ + [a₁]P₁ + ... + [aₙ]Pₙ
## This function can be nested in another parallel function
debug: doAssert coefs.len == points.len
let N = points.len

tp.multiScalarMul_dispatch_vartime_parallel(r, coefs.asUnchecked(), points.asUnchecked(), N)
tp.multiScalarMul_dispatch_vartime_parallel(r.addr, coefs.asUnchecked(), points.asUnchecked(), N)
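For reference, the benchmark at the top of this diff drives the new public entry point as follows. This is a minimal usage sketch: the curve type EC and the scalars/points inputs are placeholders provided by the caller, not definitions from this commit.

var tp = Threadpool.new()
var r: EC                                                # placeholder result type
tp.multiScalarMul_vartime_parallel(r, scalars, points)   # r <- sum_i [scalars[i]] points[i]
tp.shutdown()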
@@ -190,7 +190,7 @@ func bestBucketBitSize*(inputSize: int, scalarBitwidth: static int, useSignedBuc
let n = inputSize
let b = float32(scalarBitwidth)
var minCost = float32(Inf)
for c in 2 .. 21:
for c in 2 .. 20: # cap return value at 17 after manual tuning
let b_over_c = b/c.float32

let bucket_accumulate_reduce = b_over_c * float32(n + (1 shl (c-s)) - 2) * A
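For context on the cap: the loop searches for the window size c that minimizes an estimated cost. The accumulation/reduction term visible in this hunk corresponds roughly to the following (notation mine, reconstructed from the line above), where b is the scalar bitwidth, n the number of points, s is 1 when signed buckets are used, and A the cost of a point addition:

\mathrm{cost}_{\mathrm{accum+reduce}}(c) \approx \frac{b}{c}\,\bigl(n + 2^{\,c-s} - 2\bigr)\,A

A larger c means fewer windows (b/c) but exponentially more buckets (2^(c-s)); per the comment, the search range is now lowered so that the returned window size stays at or below 17 after manual tuning on high core counts.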
@@ -255,14 +255,14 @@ type
numScheduled, numCollisions: int32
collisionsMap: BigInt[NumNZBuckets] # We use a BigInt as a bitmap, when all you have is an axe ...
queue: array[QueueLen, ScheduledPoint]
collisions: array[32, ScheduledPoint]
collisions: array[QueueLen, ScheduledPoint]

const MinVectorAddThreshold = 32

func init*(buckets: var Buckets) {.inline.} =
func init*(buckets: ptr Buckets) {.inline.} =
zeroMem(buckets.status.addr, buckets.status.sizeof())

func reset*(buckets: var Buckets, index: int) {.inline.} =
func reset*(buckets: ptr Buckets, index: int) {.inline.} =
buckets.status[index] = {}

func deriveSchedulerConstants*(c: int): tuple[numNZBuckets, queueLen: int] {.compileTime.} =

@@ -282,6 +282,9 @@ func init*[NumNZBuckets, QueueLen: static int, F; G: static Subgroup](
sched.numScheduled = 0
sched.numCollisions = 0

func bucketInit*(sched: ptr Scheduler) {.inline.} =
zeroMem(sched.buckets.status.addr +% sched.start, (sched.stopEx-sched.start)*sizeof(set[BucketStatus]))

func scheduledPointDescriptor*(pointIndex: int, pointDesc: tuple[val: SecretWord, neg: SecretBool]): ScheduledPoint {.inline.} =
ScheduledPoint(
bucket: cast[int64](pointDesc.val)-1, # shift bucket by 1 as bucket 0 is skipped

@@ -548,7 +551,7 @@ func sparseVectorAddition[F, G](

func bucketReduce*[N, F, G](
r: var ECP_ShortW_JacExt[F, G],
buckets: var Buckets[N, F, G]) =
buckets: ptr Buckets[N, F, G]) =

var accumBuckets{.noinit.}: ECP_ShortW_JacExt[F, G]
@@ -31,7 +31,7 @@ proc sum_reduce_vartime_parallelChunks[F; G: static Subgroup](
## Batch addition of `points` into `r`
## `r` is overwritten
## Compute is parallelized, if beneficial.
## This function cannot be nested in another parallel function
## This function can be nested in another parallel function

# Chunking constants in ec_shortweierstrass_batch_ops.nim
const maxTempMem = 262144 # 2¹⁸ = 262144

@@ -50,18 +50,17 @@ proc sum_reduce_vartime_parallelChunks[F; G: static Subgroup](

let partialResults = allocStackArray(r.typeof(), chunkDesc.numChunks)

for iter in items(chunkDesc):
proc sum_reduce_chunk_vartime_wrapper(res: ptr, p: ptr, pLen: int) {.nimcall.} =
# The borrow checker prevents capturing `var` and `openArray`
# so we capture pointers instead.
res[].setInf()
res[].accumSum_chunk_vartime(p, pLen)
syncScope:
for iter in items(chunkDesc):
proc sum_reduce_chunk_vartime_wrapper(res: ptr, p: ptr, pLen: int) {.nimcall.} =
# The borrow checker prevents capturing `var` and `openArray`
# so we capture pointers instead.
res[].setInf()
res[].accumSum_chunk_vartime(p, pLen)

tp.spawn partialResults[iter.chunkID].addr.sum_reduce_chunk_vartime_wrapper(
points.asUnchecked() +% iter.start,
iter.size)

tp.syncAll() # TODO: this prevents nesting in another parallel region
tp.spawn partialResults[iter.chunkID].addr.sum_reduce_chunk_vartime_wrapper(
points.asUnchecked() +% iter.start,
iter.size)

const minChunkSizeSerial = 32
if chunkDesc.numChunks < minChunkSizeSerial:
constantine/platforms/threadpool/crossthread/scoped_barriers.nim (new file, 105 lines)
@@ -0,0 +1,105 @@
# Weave
# Copyright (c) 2019 Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
std/atomics,
../instrumentation

# Scoped barrier
# ----------------------------------------------------------------------------------

# A scoped barrier allows detecting if there are child, grandchild or descendant tasks
# still running so that a thread can wait for nested spawns.
#
# This solves the following problem:
#
# parallelFor i in 0 ..< M:
# awaitable: loopI
# parallelFor j in 0 ..< M:
# awaitable: loopJ
# discard
#
# Here, the loop i will add all the parallel tasks for loop j to the task queue
# and continue on. Even if we await loopI, the task was really to just create the loopJ,
# but the loop itself might still be pending in the task queue
#
# We could instead await loopJ before going to the next i iteration but that
# restrict the parallelism exposed.
#
# Alternatively we could use `syncRoot` after those statements but that means
# that it cannot be called from parallel code.
#
# In short:
# - the scoped barrier should prevent a thread from continuing while any descendant task
# is still running.
# - it should be composable and parallel constructs need not to worry about its presence
# in spawn functions.
# - the enclosed scope should be able to expose all parallelism opportunities
# in particular nested parallel for regions.

# TODO: transform into a Futex/eventcount and allow sleeping

type
ScopedBarrier* = object
## A scoped barrier allows detecting if there are child, grandchild or descendant tasks
## still running so that a thread can wait for nested spawns.
##
## ScopedBarriers can be nested and work like a stack.
## Only one can be active for a given thread for a given code section.
##
## They can be allocated on the stack given that a scoped barrier can not be exited
## before all the descendant tasks exited and so the descendants cannot escape,
## i.e. they have a pointer to their scope which is always valid.
##
## This means that in case of nested scopes, only the inner scope needs to track its descendants.
descendants: Atomic[int]

# Note: If one is defined, the other destructors proc are not implictly create inline
# even if trivial

proc `=`*(dst: var ScopedBarrier, src: ScopedBarrier) {.error: "A scoped barrier cannot be copied.".}

proc `=sink`*(dst: var ScopedBarrier, src: ScopedBarrier) {.inline.} =
# Nim doesn't respect noinit and tries to zeroMem then move the type
{.warning: "Moving a shared resource (an atomic type).".}
system.`=sink`(dst.descendants, src.descendants)

proc `=destroy`*(sb: var ScopedBarrier) {.inline.}=
preCondition: sb.descendants.load(moRelaxed) == 0
# system.`=destroy`(sb.descendants)

proc initialize*(scopedBarrier: var ScopedBarrier) {.inline.} =
## Initialize a scoped barrier
scopedBarrier.descendants.store(0, moRelaxed)

proc registerDescendant*(scopedBarrier: ptr ScopedBarrier) {.inline.} =
## Register a descendant task to the scoped barrier
## Important: the thread creating the task must register the descendant
## before handing them over to the runtime.
## This way, if the end of scope is reached and we have 0 descendant it means that
## - either no task was created in the scope
## - tasks were created, but descendants increment/decrement cannot reach 0 before all descendants actually exited
if not scopedBarrier.isNil:
preCondition: scopedBarrier.descendants.load(moAcquire) >= 0
discard scopedBarrier.descendants.fetchAdd(1, moRelease)
postCondition: scopedBarrier.descendants.load(moAcquire) >= 1

proc unlistDescendant*(scopedBarrier: ptr ScopedBarrier) {.inline.} =
## Unlist a descendant task from the scoped barrier.
## Important: if that task spawned new tasks, it is fine even if those grandchild tasks
## are still running, however they must have been registered to the scoped barrier to avoid race conditions.
if not scopedBarrier.isNil:
preCondition: scopedBarrier.descendants.load(moAcquire) >= 1
fence(moRelease)
discard scopedBarrier.descendants.fetchSub(1, moRelease)
preCondition: scopedBarrier.descendants.load(moAcquire) >= 0

proc hasDescendantTasks*(scopedBarrier: ptr ScopedBarrier): bool {.inline.} =
## Returns true if a scoped barrier has at least a descendant task.
## This should only be called from the thread that created the scoped barrier.
preCondition: scopedBarrier.descendants.load(moAcquire) >= 0
return scopedBarrier.descendants.load(moAcquire) != 0
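The intended counter protocol is create-side increment, completion-side decrement. A simplified sketch of how the rest of this commit wires these calls together follows; it is an illustration only and omits the real Task plumbing.

var scope: ScopedBarrier
scope.initialize()

# Creating thread: register before the task is handed to the runtime,
# as newSpawn/newLoop do below in tasks_flowvars.nim.
scope.addr.registerDescendant()
# ... enqueue the task ...

# Worker thread: unlist after the task body has run,
# as `run` does below in threadpool.nim.
scope.addr.unlistDescendant()

# Scope owner: do not leave the scope while descendants are still running.
while scope.addr.hasDescendantTasks():
  discard # in the real code the waiting thread runs or steals other tasks here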
@@ -8,6 +8,7 @@

import
std/atomics,
./scoped_barriers,
../instrumentation,
../../allocs,
../primitives/futexes

@@ -49,6 +50,7 @@ type
# ------------------
state: TaskState
parent*: ptr Task # Latency: When a task is awaited, a thread can quickly prioritize its direct children.
scopedBarrier*: ptr ScopedBarrier
hasFuture*: bool # Ownership: if a task has a future, the future deallocates it. Otherwise the worker thread does.

# Data parallelism

@@ -170,14 +172,18 @@ proc setThief*(task: ptr Task, thiefID: int32) {.inline.} =
proc newSpawn*(
T: typedesc[Task],
parent: ptr Task,
scopedBarrier: ptr ScopedBarrier,
fn: proc (env: pointer) {.nimcall, gcsafe, raises: [].}
): ptr Task {.inline.} =

const size = sizeof(T)

scopedBarrier.registerDescendant()

result = allocHeapUnchecked(T, size)
result.initSynchroState()
result.parent = parent
result.scopedBarrier = scopedBarrier
result.hasFuture = false
result.fn = fn

@@ -187,15 +193,19 @@ proc newSpawn*(
proc newSpawn*(
T: typedesc[Task],
parent: ptr Task,
scopedBarrier: ptr ScopedBarrier,
fn: proc (env: pointer) {.nimcall, gcsafe, raises: [].},
env: auto): ptr Task {.inline.} =

const size = sizeof(T) + # size without Unchecked
sizeof(env)

scopedBarrier.registerDescendant()

result = allocHeapUnchecked(T, size)
result.initSynchroState()
result.parent = parent
result.scopedBarrier = scopedBarrier
result.hasFuture = false
result.fn = fn
cast[ptr[type env]](result.env)[] = env

@@ -209,6 +219,7 @@ func ceilDiv_vartime(a, b: auto): auto {.inline.} =
proc newLoop*(
T: typedesc[Task],
parent: ptr Task,
scopedBarrier: ptr ScopedBarrier,
start, stop, stride: int,
isFirstIter: bool,
fn: proc (env: pointer) {.nimcall, gcsafe, raises: [].}

@@ -216,9 +227,12 @@ proc newLoop*(
const size = sizeof(T)
preCondition: start < stop

scopedBarrier.registerDescendant()

result = allocHeapUnchecked(T, size)
result.initSynchroState()
result.parent = parent
result.scopedBarrier = scopedBarrier
result.hasFuture = false
result.fn = fn
result.envSize = 0

@@ -233,6 +247,7 @@ proc newLoop*(
proc newLoop*(
T: typedesc[Task],
parent: ptr Task,
scopedBarrier: ptr ScopedBarrier,
start, stop, stride: int,
isFirstIter: bool,
fn: proc (env: pointer) {.nimcall, gcsafe, raises: [].},

@@ -242,9 +257,12 @@ proc newLoop*(
sizeof(env)
preCondition: start < stop

scopedBarrier.registerDescendant()

result = allocHeapUnchecked(T, size)
result.initSynchroState()
result.parent = parent
result.scopedBarrier = scopedBarrier
result.hasFuture = false
result.fn = fn
result.envSize = int32(sizeof(env))

@@ -311,4 +329,4 @@ func getTask*[T](fv: FlowVar[T]): ptr Task {.inline.} =
proc newReductionDagNode*(task: ptr Task, next: ptr ReductionDagNode): ptr ReductionDagNode {.inline.} =
result = allocHeap(ReductionDagNode)
result.next = next
result.task = task
result.task = task
@@ -94,11 +94,13 @@ proc spawnVoid(funcCall: NimNode, args, argsTy: NimNode, workerContext, schedule
when bool(`withArgs`):
let `task` = Task.newSpawn(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
fn = `tpSpawn_closure`,
env = `envParams`)
else:
let `task` = Task.newSpawn(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
fn = `tpSpawn_closure`)
`scheduleBlock`

@@ -161,6 +163,7 @@ proc spawnVoidAwaitable(funcCall: NimNode, args, argsTy: NimNode, workerContext,

let `task` = Task.newSpawn(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
fn = `tpSpawn_closure`,
env = `envParams`)
let `fut` = newFlowVar(bool, `task`)

@@ -230,6 +233,7 @@ proc spawnRet(funcCall: NimNode, retTy, args, argsTy: NimNode, workerContext, sc

let `task` = Task.newSpawn(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
fn = `tpSpawn_closure`,
env = `envParams`)
let `fut` = newFlowVar(`retTy`, `task`)

@@ -589,6 +593,7 @@ proc generateAndScheduleLoopTask(ld: LoopDescriptor): NimNode =
when bool(`withCaptures`):
let `task` = Task.newLoop(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
start, stopEx, `stride`,
isFirstIter = true,
fn = `closureName`,

@@ -596,6 +601,7 @@ proc generateAndScheduleLoopTask(ld: LoopDescriptor): NimNode =
else:
let `task` = Task.newLoop(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
start, stopEx, `stride`,
isFirstIter = true,
fn = `closureName`)

@@ -613,6 +619,7 @@ proc generateAndScheduleLoopTask(ld: LoopDescriptor): NimNode =
when bool(`withCaptures`):
let `task` = Task.newLoop(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
start, stopEx, `stride`,
isFirstIter = true,
fn = `closureName`,

@@ -620,6 +627,7 @@ proc generateAndScheduleLoopTask(ld: LoopDescriptor): NimNode =
else:
let `task` = Task.newLoop(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
start, stopEx, `stride`,
isFirstIter = true,
fn = `closureName`,
@@ -16,6 +16,7 @@ import
./crossthread/[
taskqueues,
backoff,
scoped_barriers,
tasks_flowvars],
./instrumentation,
./primitives/barriers,

@@ -173,6 +174,7 @@ type
currentTask: ptr Task

# Synchronization
currentScope*: ptr ScopedBarrier # need to be exported for syncScope template
signal: ptr Signal # owned signal

# Thefts

@@ -273,6 +275,7 @@ proc setupWorker(ctx: var WorkerContext) =
ctx.rng.seed(0xEFFACED + ctx.id)

# Synchronization
ctx.currentScope = nil
ctx.signal = addr ctx.threadpool.workerSignals[ctx.id]
ctx.signal.terminate.store(false, moRelaxed)

@@ -359,13 +362,21 @@ const RootTask = cast[ptr Task](0xEFFACED0)

proc run(ctx: var WorkerContext, task: ptr Task) {.raises:[].} =
## Run a task, frees it if it is not owned by a Flowvar

let suspendedTask = ctx.currentTask
let suspendedScope = ctx.currentScope

ctx.currentTask = task
ctx.currentScope = task.scopedBarrier

debug: log("Worker %3d: running task 0x%.08x (previous: 0x%.08x, %d pending, thiefID %d)\n", ctx.id, task, suspendedTask, ctx.taskqueue[].peek(), task.getThief())
profile(run_task):
task.fn(task.env.addr)
task.scopedBarrier.unlistDescendant()
debug: log("Worker %3d: completed task 0x%.08x (%d pending)\n", ctx.id, task, ctx.taskqueue[].peek())

ctx.currentTask = suspendedTask
ctx.currentScope = suspendedScope

ctx.incCounter(tasksExecuted)
ctx.incCounter(itersExecuted):

@@ -387,7 +398,7 @@ proc schedule(ctx: var WorkerContext, task: ptr Task, forceWake = false) {.inlin
## Schedule a task in the threadpool
## This wakes another worker if our local queue is empty
## or forceWake is true.
debug: log("Worker %3d: schedule task 0x%.08x (parent/current task 0x%.08x)\n", ctx.id, task, task.parent)
debug: log("Worker %3d: schedule task 0x%.08x (parent/current task 0x%.08x, scope 0x%.08x)\n", ctx.id, task, task.parent, task.scopedBarrier)

# Instead of notifying every time a task is scheduled, we notify
# only when the worker queue is empty. This is a good approximation

@@ -536,6 +547,7 @@ proc splitAndDispatchLoop(ctx: var WorkerContext, task: ptr Task, curLoopIndex:

upperSplit.initSynchroState()
upperSplit.parent = task
upperSplit.scopedBarrier.registerDescendant()

upperSplit.isFirstIter = false
upperSplit.loopStart = offset

@@ -839,7 +851,7 @@ proc syncAll*(tp: Threadpool) {.raises: [].} =
template ctx: untyped = workerContext

debugTermination:
log(">>> Worker %3d enters barrier <<<\n", ctx.id)
log(">>> Worker %3d enters global barrier <<<\n", ctx.id)

preCondition: ctx.id == 0
preCondition: ctx.currentTask.isRootTask()

@@ -880,6 +892,46 @@ proc syncAll*(tp: Threadpool) {.raises: [].} =

profileStart(run_task)

proc wait(scopedBarrier: ptr ScopedBarrier) {.raises:[], gcsafe.} =
## Wait at barrier until all descendant tasks are completed
template ctx: untyped = workerContext

debugTermination:
log(">>> Worker %3d enters scoped barrier 0x%.08x <<<\n", ctx.id, scopedBarrier)

while scopedBarrier.hasDescendantTasks():
# 1. Empty local tasks, the initial loop only has tasks from that scope or a child scope.
debug: log("Worker %3d: syncScope 1 - searching task from local queue\n", ctx.id)
while (let task = ctx.taskqueue[].pop(); not task.isNil):
debug: log("Worker %3d: syncScope 1 - running task 0x%.08x (parent 0x%.08x, current 0x%.08x, scope 0x%.08x)\n", ctx.id, task, task.parent, ctx.currentTask, task.scopedBarrier)
ctx.run(task)
if not scopedBarrier.hasDescendantTasks():
debugTermination:
log(">>> Worker %3d exits scoped barrier 0x%.08x <<<\n", ctx.id, scopedBarrier)
return

# TODO: consider leapfrogging

if (var stolenTask = ctx.tryStealOne(); not stolenTask.isNil):
# 2.a We stole some task
debug: log("Worker %3d: syncScope 2 - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x, scope 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask, stolenTask.scopedBarrier)

# Theft successful, there might be more work for idle threads, wake one
ctx.threadpool.globalBackoff.wake()
ctx.incCounter(backoffGlobalSignalSent)

ctx.incCounter(theftsIdle)
ctx.incCounter(itersStolen):
if stolenTask.loopStepsLeft == NotALoop: 0
else: stolenTask.loopStepsLeft
ctx.run(stolenTask)
else:
# TODO: backoff
cpuRelax()

debugTermination:
log(">>> Worker %3d exits scoped barrier 0x%.08x <<<\n", ctx.id, scopedBarrier)

# ############################################################
# #
# Runtime API #

@@ -973,6 +1025,20 @@ proc shutdown*(tp: var Threadpool) {.raises:[].} =
# #
# ############################################################

# Structured parallelism API
# ---------------------------------------------

template syncScope*(body: untyped): untyped =
block:
let suspendedScope = workerContext.currentScope
var scopedBarrier {.noInit.}: ScopedBarrier
initialize(scopedBarrier)
workerContext.currentScope = addr scopedBarrier
block:
body
wait(scopedBarrier.addr)
workerContext.currentScope = suspendedScope
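syncScope is what makes the parallel MSM entry points nestable: it installs a fresh ScopedBarrier, runs the body, then waits only for tasks spawned within that scope rather than draining the whole pool with syncAll. A hypothetical usage sketch follows; the proc and data are illustrative, but the shape matches applyEndomorphism_parallel and sum_reduce_vartime_parallelChunks earlier in this diff.

# Hypothetical example: this proc can itself be called from inside another
# spawned task, because it never calls the global syncAll.
proc doubleAll(tp: Threadpool, xs: ptr UncheckedArray[int], n: int) =
  syncScope:
    tp.parallelFor i in 0 ..< n:
      captures: {xs}
      xs[i] = 2*xs[i]
  # All n iterations above are complete here; tasks from other scopes
  # may still be running elsewhere in the threadpool.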
# Task parallel API
# ---------------------------------------------