MSM tuning for high core count (#227)

* tune for high core count

* reentrancy: allow nesting of parallel functions by introducing precise scoped barriers

* increase collision queue depth
Mamy Ratsimbazafy 2023-04-14 20:02:59 +02:00 committed by GitHub
parent 6c48975aee
commit 93dac2503c
9 changed files with 432 additions and 138 deletions

View File

@ -65,35 +65,35 @@ proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
var startNaive, stopNaive, startMSMbaseline, stopMSMbaseline, startMSMopt, stopMSMopt, startMSMpara, stopMSMpara: MonoTime
if numPoints <= 100000:
startNaive = getMonotime()
bench("EC scalar muls " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startNaive = getMonotime()
var tmp: EC
r.setInf()
for i in 0 ..< points.len:
tmp.fromAffine(points[i])
tmp.scalarMul(scalars[i])
r += tmp
stopNaive = getMonotime()
stopNaive = getMonotime()
block:
if numPoints <= 100000:
startMSMbaseline = getMonotime()
bench("EC multi-scalar-mul baseline " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMbaseline = getMonotime()
r.multiScalarMul_reference_vartime(scalars, points)
stopMSMbaseline = getMonotime()
stopMSMbaseline = getMonotime()
block:
startMSMopt = getMonotime()
bench("EC multi-scalar-mul optimized " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMopt = getMonotime()
r.multiScalarMul_vartime(scalars, points)
stopMSMopt = getMonotime()
stopMSMopt = getMonotime()
block:
var tp = Threadpool.new()
startMSMpara = getMonotime()
bench("EC multi-scalar-mul" & align($tp.numThreads & " threads", 11) & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMpara = getMonotime()
tp.multiScalarMul_vartime_parallel(r, scalars, points)
stopMSMpara = getMonotime()
stopMSMpara = getMonotime()
tp.shutdown()
@ -109,8 +109,8 @@ proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
let speedupOpt = float(perfNaive) / float(perfMSMopt)
echo &"Speedup ratio optimized over naive linear combination: {speedupOpt:>6.3f}x"
let speedupOptBaseline = float(perfMSMbaseline) / float(perfMSMopt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"
let speedupOptBaseline = float(perfMSMbaseline) / float(perfMSMopt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"
let speedupParaOpt = float(perfMSMopt) / float(perfMSMpara)
echo &"Speedup ratio parallel over optimized linear combination: {speedupParaOpt:>6.3f}x"

View File

@ -118,12 +118,8 @@ func multiScalarMul_reference_vartime*[EC](r: var EC, coefs: openArray[BigInt],
of 13: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 13)
of 14: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 14)
of 15: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 15)
of 16: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 16)
of 17: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 17)
of 18: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 18)
of 19: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 19)
of 20: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 20)
of 21: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 21)
of 16..20: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 16)
else:
unreachable()
@ -271,6 +267,8 @@ func schedAccumulate*[NumBuckets, QueueLen, F, G; bits: static int](
const top = bits - excess
static: doAssert miniMsmKind != kTopWindow, "The top window is smaller in bits which increases collisions in scheduler."
sched.bucketInit()
var curSP, nextSP: ScheduledPoint
template getSignedWindow(j : int): tuple[val: SecretWord, neg: SecretBool] =
@ -295,14 +293,12 @@ func miniMSM_affine[NumBuckets, QueueLen, F, G; bits: static int](
## Apply a mini-Multi-Scalar-Multiplication on [bitIndex, bitIndex+window)
## slice of all (coef, point) pairs
sched.buckets[].init()
# 1. Bucket Accumulation
sched.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)
# 2. Bucket Reduction
var windowSum_jacext{.noInit.}: ECP_ShortW_JacExt[F, G]
windowSum_jacext.bucketReduce(sched.buckets[])
windowSum_jacext.bucketReduce(sched.buckets)
# 3. Mini-MSM on the slice [bitIndex, bitIndex+window)
var windowSum{.noInit.}: typeof(r)
@ -324,7 +320,6 @@ func multiScalarMulAffine_vartime[F, G; bits: static int](
# -----
const (numBuckets, queueLen) = c.deriveSchedulerConstants()
let buckets = allocHeap(Buckets[numBuckets, F, G])
buckets[].init()
let sched = allocHeap(Scheduler[numBuckets, queueLen, F, G])
sched.init(points, buckets, 0, numBuckets.int32)
@ -440,11 +435,10 @@ func multiScalarMul_dispatch_vartime[bits: static int, F, G](
of 11: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 11)
of 12: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 12)
of 13: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 13)
of 14: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 14)
of 15: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 15)
of 16: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 16)
of 17: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 17)
of 18: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 18)
of 14: multiScalarMulAffine_vartime(r, coefs, points, N, c = 14)
of 15: multiScalarMulAffine_vartime(r, coefs, points, N, c = 15)
of 16..17: multiScalarMulAffine_vartime(r, coefs, points, N, c = 16)
else:
unreachable()
@ -458,4 +452,4 @@ func multiScalarMul_vartime*[bits: static int, F, G](
debug: doAssert coefs.len == points.len
let N = points.len
multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)
multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)

View File

@ -11,7 +11,7 @@ import ./ec_multi_scalar_mul_scheduler,
./ec_endomorphism_accel,
../extension_fields,
../constants/zoo_endomorphisms,
../../platforms/threadpool/threadpool
../../platforms/threadpool/[threadpool, partitioners]
export bestBucketBitSize
# No exceptions allowed in core cryptographic operations
@ -133,8 +133,8 @@ export bestBucketBitSize
# Parallel MSM Jacobian Extended
# ------------------------------
proc bucketAccumReduce_jacext_zeroMem[F, G; bits: static int](
windowSum: ptr ECP_ShortW[F, G],
proc bucketAccumReduce_jacext_zeroMem[EC, F, G; bits: static int](
windowSum: ptr EC,
buckets: ptr ECP_ShortW_JacExt[F, G] or ptr UncheckedArray[ECP_ShortW_JacExt[F, G]],
bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) =
@ -143,9 +143,9 @@ proc bucketAccumReduce_jacext_zeroMem[F, G; bits: static int](
zeroMem(buckets, sizeof(ECP_ShortW_JacExt[F, G]) * numBuckets)
bucketAccumReduce_jacext(windowSum[], buckets, bitIndex, miniMsmKind, c, coefs, points, N)
proc msmJacExt_vartime_parallel*[bits: static int, F, G](
proc msmJacExt_vartime_parallel*[bits: static int, EC, F, G](
tp: Threadpool,
r: var ECP_ShortW[F, G],
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int) =
@ -158,7 +158,6 @@ proc msmJacExt_vartime_parallel*[bits: static int, F, G](
# Instead of storing the result in futures, risking them being scattered in memory
# we store them in a contiguous array, and the synchronizing future just returns a bool.
# top window is done on this thread
type EC = typeof(r)
let miniMSMsResults = allocHeapArray(EC, numFullWindows)
let miniMSMsReady = allocStackArray(FlowVar[bool], numFullWindows)
@ -188,41 +187,38 @@ proc msmJacExt_vartime_parallel*[bits: static int, F, G](
when top != 0:
when excess != 0:
bucketAccumReduce_jacext_zeroMem(
r.addr,
r,
bucketsMatrix[numFullWindows*numBuckets].addr,
bitIndex = top, kTopWindow, c,
coefs, points, N)
else:
r.setInf()
r[].setInf()
# 3. Final reduction, r initialized to what would be miniMSMsReady[numWindows-1]
when excess != 0:
for w in countdown(numWindows-2, 0):
for _ in 0 ..< c:
r.double()
r[].double()
discard sync miniMSMsReady[w]
r += miniMSMsResults[w]
r[] += miniMSMsResults[w]
elif numWindows >= 2:
discard sync miniMSMsReady[numWindows-2]
r = miniMSMsResults[numWindows-2]
r[] = miniMSMsResults[numWindows-2]
for w in countdown(numWindows-3, 0):
for _ in 0 ..< c:
r.double()
r[].double()
discard sync miniMSMsReady[w]
r += miniMSMsResults[w]
r[] += miniMSMsResults[w]
# Cleanup
# -------
miniMSMsResults.freeHeap()
bucketsMatrix.freeHeap()
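
The final reduction above (and its twin in the affine variant below) recombines the per-window sums Horner-style: double c times, then add the next window. A standalone Nim sketch, not part of the patch, with integers standing in for curve points and illustrative values only:

  let c = 4
  let scalar = 0b1011_0110_0001          # 2913
  var windowSums: seq[int]
  var s = scalar
  while s != 0:                          # per-window "mini-MSM" result for a single point P = 1
    windowSums.add(s and ((1 shl c) - 1))
    s = s shr c

  var r = windowSums[^1]                 # r initialized to the top window
  for w in countdown(windowSums.len-2, 0):
    for _ in 0 ..< c:
      r += r                             # r.double()
    r += windowSums[w]                   # sync miniMSMsReady[w]; r += miniMSMsResults[w]
  doAssert r == scalar                   # equals Σ windowSums[w]·2^(w·c)
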
# Parallel MSM Affine
# ------------------------------
proc bucketAccumReduce_parallel[bits: static int, F, G](
tp: Threadpool,
r: ptr ECP_ShortW[F, G],
# Parallel MSM Affine - bucket accumulation
# -----------------------------------------
proc bucketAccumReduce_serial[bits: static int, EC, F, G](
r: ptr EC,
bitIndex: int,
miniMsmKind: static MiniMsmKind, c: static int,
coefs: ptr UncheckedArray[BigInt[bits]],
@ -230,21 +226,46 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
N: int) =
const (numBuckets, queueLen) = c.deriveSchedulerConstants()
const outerParallelism = bits div c # It's actually ceilDiv instead of floorDiv, but the last iteration might be too small
let buckets = allocHeap(Buckets[numBuckets, F, G])
let sched = allocHeap(Scheduler[numBuckets, queueLen, F, G])
sched.init(points, buckets, 0, numBuckets.int32)
var innerParallelism = 1'i32
while outerParallelism*innerParallelism < tp.numThreads:
innerParallelism = innerParallelism shl 1
# 1. Bucket Accumulation
sched.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)
let numChunks = 1'i32 # innerParallelism # TODO: unfortunately trying to expose more parallelism slows down the performance
# 2. Bucket Reduction
var windowSum{.noInit.}: ECP_ShortW_JacExt[F, G]
windowSum.bucketReduce(sched.buckets)
r[].fromJacobianExtended_vartime(windowSum)
# Cleanup
# ----------------
sched.freeHeap()
buckets.freeHeap()
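
Steps 1 and 2 above are plain Pippenger bucketing. A self-contained integer sketch (illustrative digits and "points", not part of the patch): accumulation costs one addition per point, and the reduction recovers Σ b·buckets[b] with a running suffix sum:

  let c = 3
  let digits = @[5, 2, 7, 5]              # window digits in 1 .. 2^c - 1
  let pts    = @[10, 20, 30, 40]          # ints standing in for curve points

  var buckets = newSeq[int](1 shl c)      # bucket 0 is never used
  for i in 0 ..< digits.len:              # 1. Bucket Accumulation
    buckets[digits[i]] += pts[i]

  var accum = 0
  var windowSum = 0
  for b in countdown((1 shl c) - 1, 1):   # 2. Bucket Reduction
    accum += buckets[b]                   # accum     = Σ_{k ≥ b} buckets[k]
    windowSum += accum                    # windowSum = Σ_b b·buckets[b]

  doAssert windowSum == 5*10 + 2*20 + 7*30 + 5*40
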
proc bucketAccumReduce_parallel[bits: static int, EC, F, G](
tp: Threadpool,
r: ptr EC,
bitIndex: int,
miniMsmKind: static MiniMsmKind, c: static int,
coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int) =
const (numBuckets, queueLen) = c.deriveSchedulerConstants()
const windowParallelism = bits div c # It's actually ceilDiv instead of floorDiv, but the last iteration might be too small
var bucketParallelism = 1'i32
while windowParallelism*bucketParallelism < tp.numThreads:
bucketParallelism = bucketParallelism shl 1
let numChunks = bucketParallelism
let chunkSize = int32(numBuckets) shr log2_vartime(cast[uint32](numChunks)) # Both are power of 2 so exact division
let chunksReadiness = allocStackArray(FlowVar[bool], numChunks-1) # Last chunk is done on this thread
let buckets = allocHeap(Buckets[numBuckets, F, G])
let scheds = allocHeapArray(Scheduler[numBuckets, queueLen, F, G], numChunks)
buckets[].init()
block: # 1. Bucket Accumulation
for chunkID in 0'i32 ..< numChunks-1:
let idx = chunkID*chunkSize
@ -254,7 +275,7 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
scheds[numChunks-1].addr.init(points, buckets, (numChunks-1)*chunkSize, int32 numBuckets)
scheds[numChunks-1].addr.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)
block: # 2. Bucket reduction
block: # 2. Bucket reduction with latency hiding
var windowSum{.noInit.}: ECP_ShortW_JacExt[F, G]
var accumBuckets{.noinit.}: ECP_ShortW_JacExt[F, G]
@ -268,7 +289,7 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
else:
accumBuckets.setInf()
windowSum = accumBuckets
buckets[].reset(numBuckets-1)
buckets.reset(numBuckets-1)
var nextBatch = numBuckets-1-chunkSize
var nextFutureIdx = numChunks-2
@ -289,7 +310,7 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
elif kJacExt in buckets.status[k]:
accumBuckets += buckets.ptJacExt[k]
buckets[].reset(k)
buckets.reset(k)
windowSum += accumBuckets
r[].fromJacobianExtended_vartime(windowSum)
@ -299,11 +320,14 @@ proc bucketAccumReduce_parallel[bits: static int, F, G](
scheds.freeHeap()
buckets.freeHeap()
proc msmAffine_vartime_parallel*[bits: static int, F, G](
# Parallel MSM Affine - window-level only
# ---------------------------------------
proc msmAffine_vartime_parallel*[bits: static int, EC, F, G](
tp: Threadpool,
r: var ECP_ShortW[F, G],
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int) =
N: int, c: static int, useParallelBuckets: static bool) =
# Prologue
# --------
@ -314,24 +338,36 @@ proc msmAffine_vartime_parallel*[bits: static int, F, G](
# Instead of storing the result in futures, risking them being scattered in memory
# we store them in a contiguous array, and the synchronizing future just returns a bool.
# top window is done on this thread
type EC = typeof(r)
type EC = typeof(r[])
let miniMSMsResults = allocHeapArray(EC, numFullWindows)
let miniMSMsReady = allocStackArray(Flowvar[bool], numFullWindows)
# Algorithm
# ---------
block: # 1. Bucket accumulation and reduction
# 1. mini-MSMs: Bucket accumulation and reduction
when useParallelBuckets:
miniMSMsReady[0] = tp.spawnAwaitable bucketAccumReduce_parallel(
tp, miniMSMsResults[0].addr,
bitIndex = 0, kBottomWIndow, c,
coefs, points, N)
tp, miniMSMsResults[0].addr,
bitIndex = 0, kBottomWindow, c,
coefs, points, N)
for w in 1 ..< numFullWindows:
miniMSMsReady[w] = tp.spawnAwaitable bucketAccumReduce_parallel(
tp, miniMSMsResults[w].addr,
bitIndex = w*c, kFullWIndow, c,
coefs, points, N)
for w in 1 ..< numFullWindows:
miniMSMsReady[w] = tp.spawnAwaitable bucketAccumReduce_parallel(
tp, miniMSMsResults[w].addr,
bitIndex = w*c, kFullWindow, c,
coefs, points, N)
else:
miniMSMsReady[0] = tp.spawnAwaitable bucketAccumReduce_serial(
miniMSMsResults[0].addr,
bitIndex = 0, kBottomWindow, c,
coefs, points, N)
for w in 1 ..< numFullWindows:
miniMSMsReady[w] = tp.spawnAwaitable bucketAccumReduce_serial(
miniMSMsResults[w].addr,
bitIndex = w*c, kFullWindow, c,
coefs, points, N)
# Last window is done sync on this thread, directly initializing r
const excess = bits mod c
@ -341,32 +377,79 @@ proc msmAffine_vartime_parallel*[bits: static int, F, G](
when excess != 0:
let buckets = allocHeapArray(ECP_ShortW_JacExt[F, G], numBuckets)
zeroMem(buckets[0].addr, sizeof(ECP_ShortW_JacExt[F, G]) * numBuckets)
r.bucketAccumReduce_jacext(buckets, bitIndex = top, kTopWindow, c,
r[].bucketAccumReduce_jacext(buckets, bitIndex = top, kTopWindow, c,
coefs, points, N)
buckets.freeHeap()
else:
r.setInf()
r[].setInf()
# 3. Final reduction, r initialized to what would be miniMSMsReady[numWindows-1]
# 2. Final reduction with latency hiding, r initialized to what would be miniMSMsReady[numWindows-1]
when excess != 0:
for w in countdown(numWindows-2, 0):
for _ in 0 ..< c:
r.double()
r[].double()
discard sync miniMSMsReady[w]
r += miniMSMsResults[w]
r[] += miniMSMsResults[w]
elif numWindows >= 2:
discard sync miniMSMsReady[numWindows-2]
r = miniMSMsResults[numWindows-2]
r[] = miniMSMsResults[numWindows-2]
for w in countdown(numWindows-3, 0):
for _ in 0 ..< c:
r.double()
r[].double()
discard sync miniMSMsReady[w]
r += miniMSMsResults[w]
r[] += miniMSMsResults[w]
# Cleanup
# -------
miniMSMsResults.freeHeap()
proc msmAffine_vartime_parallel_split[bits: static int, EC, F, G](
tp: Threadpool,
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int, useParallelBuckets: static bool) =
# Parallelism levels:
# - MSM parallelism: compute independent MSMs; this increases the total number of ops
# - window parallelism: compute an MSM's outer loop on different threads; this has no tradeoffs
# - bucket parallelism: handle ranges of buckets on different threads; threads do superfluous overlapping memory reads
#
# It seems to be beneficial to have both MSM and bucket level parallelism.
# Probably by guaranteeing 2x more tasks than threads, we avoid starvation.
var windowParallelism = bits div c # It's actually ceilDiv instead of floorDiv, but the last iteration might be too small
var msmParallelism = 1'i32
while windowParallelism*msmParallelism < tp.numThreads:
windowParallelism = bits div c # This is an approximation
msmParallelism = msmParallelism shl 1
if msmParallelism == 1:
msmAffine_vartime_parallel(tp, r, coefs, points, N, c, useParallelBuckets)
return
let chunkingDescriptor = balancedChunksPrioNumber(0, N, msmParallelism)
let splitMSMsResults = allocHeapArray(typeof(r[]), msmParallelism-1)
let splitMSMsReady = allocStackArray(Flowvar[bool], msmParallelism-1)
for (i, start, len) in items(chunkingDescriptor):
if i != msmParallelism-1:
splitMSMsReady[i] = tp.spawnAwaitable msmAffine_vartime_parallel(
tp, splitMSMsResults[i].addr,
coefs +% start, points +% start, len,
c, useParallelBuckets)
else: # Run last on this thread
msmAffine_vartime_parallel(
tp, r,
coefs +% start, points +% start, len,
c, useParallelBuckets)
for i in countdown(msmParallelism-2, 0):
discard sync splitMSMsReady[i]
r[] += splitMSMsResults[i]
freeHeap(splitMSMsResults)
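
To see how the split is sized, a standalone sketch with assumed numbers (127 remaining scalar bits after endomorphism decomposition, c = 16, a 64-thread pool); the loop is simplified compared to the code above, which also reassigns windowParallelism inside the loop:

  let bits = 127                           # assumption: 254-bit scalars halved by the endomorphism
  let c = 16                               # assumption: window size picked by the dispatcher
  let numThreads = 64

  var windowParallelism = bits div c       # 7 windows can be processed concurrently
  var msmParallelism = 1
  while windowParallelism * msmParallelism < numThreads:
    msmParallelism = msmParallelism shl 1  # double the number of independent sub-MSMs

  echo windowParallelism, " windows x ", msmParallelism, " sub-MSMs = ",
       windowParallelism * msmParallelism, " window-level tasks for ", numThreads, " threads"
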
proc applyEndomorphism_parallel[bits: static int, F, G](
tp: Threadpool,
coefs: ptr UncheckedArray[BigInt[bits]],
@ -384,39 +467,38 @@ proc applyEndomorphism_parallel[bits: static int, F, G](
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
let endoBasis = allocHeapArray(array[M, ECP_ShortW_Aff[F, G]], N)
tp.parallelFor i in 0 ..< N:
captures: {coefs, points, splitCoefs, endoBasis}
syncScope:
tp.parallelFor i in 0 ..< N:
captures: {coefs, points, splitCoefs, endoBasis}
var negatePoints {.noinit.}: array[M, SecretBool]
splitCoefs[i].decomposeEndo(negatePoints, coefs[i], F)
if negatePoints[0].bool:
endoBasis[i][0].neg(points[i])
else:
endoBasis[i][0] = points[i]
when F is Fp:
endoBasis[i][1].x.prod(points[i].x, F.C.getCubicRootOfUnity_mod_p())
if negatePoints[1].bool:
endoBasis[i][1].y.neg(points[i].y)
var negatePoints {.noinit.}: array[M, SecretBool]
splitCoefs[i].decomposeEndo(negatePoints, coefs[i], F)
if negatePoints[0].bool:
endoBasis[i][0].neg(points[i])
else:
endoBasis[i][1].y = points[i].y
else:
staticFor m, 1, M:
endoBasis[i][m].frobenius_psi(points[i], m)
if negatePoints[m].bool:
endoBasis[i][m].neg()
endoBasis[i][0] = points[i]
tp.syncAll()
when F is Fp:
endoBasis[i][1].x.prod(points[i].x, F.C.getCubicRootOfUnity_mod_p())
if negatePoints[1].bool:
endoBasis[i][1].y.neg(points[i].y)
else:
endoBasis[i][1].y = points[i].y
else:
staticFor m, 1, M:
endoBasis[i][m].frobenius_psi(points[i], m)
if negatePoints[m].bool:
endoBasis[i][m].neg()
let endoCoefs = cast[ptr UncheckedArray[BigInt[L]]](splitCoefs)
let endoPoints = cast[ptr UncheckedArray[ECP_ShortW_Aff[F, G]]](endoBasis)
return (endoCoefs, endoPoints, M*N)
template withEndo[bits: static int, F, G](
template withEndo[bits: static int, EC, F, G](
msmProc: untyped,
tp: Threadpool,
r: var ECP_ShortW[F, G],
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int) =
@ -430,9 +512,26 @@ template withEndo[bits: static int, F, G](
else:
msmProc(tp, r, coefs, points, N, c)
proc multiScalarMul_dispatch_vartime_parallel[bits: static int, F, G](
template withEndo[bits: static int, EC, F, G](
msmProc: untyped,
tp: Threadpool,
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
N: int, c: static int, useParallelBuckets: static bool) =
when bits <= F.C.getCurveOrderBitwidth() and hasEndomorphismAcceleration(F.C):
let (endoCoefs, endoPoints, endoN) = applyEndomorphism_parallel(tp, coefs, points, N)
# Given that bits and N changed, we are able to use a bigger `c`
# but it has no significant impact on performance
msmProc(tp, r, endoCoefs, endoPoints, endoN, c, useParallelBuckets)
freeHeap(endoCoefs)
freeHeap(endoPoints)
else:
msmProc(tp, r, coefs, points, N, c, useParallelBuckets)
proc multiScalarMul_dispatch_vartime_parallel[bits: static int, EC, F, G](
tp: Threadpool,
r: var ECP_ShortW[F, G], coefs: ptr UncheckedArray[BigInt[bits]],
r: ptr EC, coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], N: int) =
## Multiscalar multiplication:
## r <- [a₀]P₀ + [a₁]P₁ + ... + [aₙ]Pₙ
@ -448,30 +547,32 @@ proc multiScalarMul_dispatch_vartime_parallel[bits: static int, F, G](
of 4: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 4)
of 5: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 5)
of 6: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 6)
of 7: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 7)
of 8: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 8)
of 9: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 9)
of 10: withEndo(msmJacExt_vartime_parallel, tp, r, coefs, points, N, c = 10)
of 11: withEndo(msmAffine_vartime_parallel, tp, r, coefs, points, N, c = 11)
of 7: msmJacExt_vartime_parallel(tp, r, coefs, points, N, c = 7)
of 8: msmJacExt_vartime_parallel(tp, r, coefs, points, N, c = 8)
of 12: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 12)
of 13: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 13)
of 14: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 14)
of 15: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 15)
of 16: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 16)
of 17: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 17)
of 18: msmAffine_vartime_parallel(tp, r, coefs, points, N, c = 18)
of 9: withEndo(msmAffine_vartime_parallel_split, tp, r, coefs, points, N, c = 9, useParallelBuckets = true)
of 10: withEndo(msmAffine_vartime_parallel_split, tp, r, coefs, points, N, c = 10, useParallelBuckets = true)
of 11: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 10, useParallelBuckets = true)
of 12: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 11, useParallelBuckets = true)
of 13: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 12, useParallelBuckets = true)
of 14: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 13, useParallelBuckets = true)
of 15: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 14, useParallelBuckets = true)
of 16: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 15, useParallelBuckets = true)
of 17: msmAffine_vartime_parallel_split(tp, r, coefs, points, N, c = 16, useParallelBuckets = true)
else:
unreachable()
proc multiScalarMul_vartime_parallel*[bits: static int, F, G](
proc multiScalarMul_vartime_parallel*[bits: static int, EC, F, G](
tp: Threadpool,
r: var ECP_ShortW[F, G],
r: var EC,
coefs: openArray[BigInt[bits]],
points: openArray[ECP_ShortW_Aff[F, G]]) {.meter, inline.} =
## Multiscalar multiplication:
## r <- [a₀]P₀ + [a₁]P₁ + ... + [aₙ]Pₙ
## This function can be nested in another parallel function
debug: doAssert coefs.len == points.len
let N = points.len
tp.multiScalarMul_dispatch_vartime_parallel(r, coefs.asUnchecked(), points.asUnchecked(), N)
tp.multiScalarMul_dispatch_vartime_parallel(r.addr, coefs.asUnchecked(), points.asUnchecked(), N)

View File

@ -190,7 +190,7 @@ func bestBucketBitSize*(inputSize: int, scalarBitwidth: static int, useSignedBuc
let n = inputSize
let b = float32(scalarBitwidth)
var minCost = float32(Inf)
for c in 2 .. 21:
for c in 2 .. 20: # cap return value at 17 after manual tuning
let b_over_c = b/c.float32
let bucket_accumulate_reduce = b_over_c * float32(n + (1 shl (c-s)) - 2) * A
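
A rough standalone sketch of the scan above, keeping only the accumulate/reduce term visible in this hunk (the real bestBucketBitSize includes further terms and tuned constants); n, b, s and A play the same roles as in the surrounding code:

  proc approxCost(n: int, b: float32, c: int, s = 1, A = 1.0'f32): float32 =
    let b_over_c = b / c.float32
    b_over_c * float32(n + (1 shl (c-s)) - 2) * A

  var bestC = 2
  for c in 2 .. 20:
    if approxCost(100_000, 255'f32, c) < approxCost(100_000, 255'f32, bestC):
      bestC = c
  echo "cheapest c for 100k points: ", bestC   # ≈14 under this simplified model
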
@ -255,14 +255,14 @@ type
numScheduled, numCollisions: int32
collisionsMap: BigInt[NumNZBuckets] # We use a BigInt as a bitmap, when all you have is an axe ...
queue: array[QueueLen, ScheduledPoint]
collisions: array[32, ScheduledPoint]
collisions: array[QueueLen, ScheduledPoint]
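
The queue and collisions arrays implement collision scheduling: a point whose bucket is already pending in the current batch is parked and retried later (this commit grows the parking space from 32 entries to QueueLen). A toy standalone sketch of the idea, not part of the patch:

  var inFlight: array[64, bool]         # stand-in for the collisionsMap bitmap
  var batch: seq[(int, int)]            # (bucket, point) pairs safe to add in one batch
  var parked: seq[(int, int)]           # collided points, retried in a later batch

  proc schedule(bucket, point: int) =
    if inFlight[bucket]:
      parked.add((bucket, point))       # bucket already touched in this batch -> defer
    else:
      inFlight[bucket] = true
      batch.add((bucket, point))

  for (b, p) in [(3, 100), (5, 101), (3, 102), (7, 103)]:
    schedule(b, p)

  doAssert batch  == @[(3, 100), (5, 101), (7, 103)]
  doAssert parked == @[(3, 102)]
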
const MinVectorAddThreshold = 32
func init*(buckets: var Buckets) {.inline.} =
func init*(buckets: ptr Buckets) {.inline.} =
zeroMem(buckets.status.addr, buckets.status.sizeof())
func reset*(buckets: var Buckets, index: int) {.inline.} =
func reset*(buckets: ptr Buckets, index: int) {.inline.} =
buckets.status[index] = {}
func deriveSchedulerConstants*(c: int): tuple[numNZBuckets, queueLen: int] {.compileTime.} =
@ -282,6 +282,9 @@ func init*[NumNZBuckets, QueueLen: static int, F; G: static Subgroup](
sched.numScheduled = 0
sched.numCollisions = 0
func bucketInit*(sched: ptr Scheduler) {.inline.} =
zeroMem(sched.buckets.status.addr +% sched.start, (sched.stopEx-sched.start)*sizeof(set[BucketStatus]))
func scheduledPointDescriptor*(pointIndex: int, pointDesc: tuple[val: SecretWord, neg: SecretBool]): ScheduledPoint {.inline.} =
ScheduledPoint(
bucket: cast[int64](pointDesc.val)-1, # shift bucket by 1 as bucket 0 is skipped
@ -548,7 +551,7 @@ func sparseVectorAddition[F, G](
func bucketReduce*[N, F, G](
r: var ECP_ShortW_JacExt[F, G],
buckets: var Buckets[N, F, G]) =
buckets: ptr Buckets[N, F, G]) =
var accumBuckets{.noinit.}: ECP_ShortW_JacExt[F, G]

View File

@ -31,7 +31,7 @@ proc sum_reduce_vartime_parallelChunks[F; G: static Subgroup](
## Batch addition of `points` into `r`
## `r` is overwritten
## Compute is parallelized, if beneficial.
## This function cannot be nested in another parallel function
## This function can be nested in another parallel function
# Chunking constants in ec_shortweierstrass_batch_ops.nim
const maxTempMem = 262144 # 2¹⁸ = 262144
@ -50,18 +50,17 @@ proc sum_reduce_vartime_parallelChunks[F; G: static Subgroup](
let partialResults = allocStackArray(r.typeof(), chunkDesc.numChunks)
for iter in items(chunkDesc):
proc sum_reduce_chunk_vartime_wrapper(res: ptr, p: ptr, pLen: int) {.nimcall.} =
# The borrow checker prevents capturing `var` and `openArray`
# so we capture pointers instead.
res[].setInf()
res[].accumSum_chunk_vartime(p, pLen)
syncScope:
for iter in items(chunkDesc):
proc sum_reduce_chunk_vartime_wrapper(res: ptr, p: ptr, pLen: int) {.nimcall.} =
# The borrow checker prevents capturing `var` and `openArray`
# so we capture pointers instead.
res[].setInf()
res[].accumSum_chunk_vartime(p, pLen)
tp.spawn partialResults[iter.chunkID].addr.sum_reduce_chunk_vartime_wrapper(
points.asUnchecked() +% iter.start,
iter.size)
tp.syncAll() # TODO: this prevents nesting in another parallel region
tp.spawn partialResults[iter.chunkID].addr.sum_reduce_chunk_vartime_wrapper(
points.asUnchecked() +% iter.start,
iter.size)
const minChunkSizeSerial = 32
if chunkDesc.numChunks < minChunkSizeSerial:

View File

@ -0,0 +1,105 @@
# Weave
# Copyright (c) 2019 Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
std/atomics,
../instrumentation
# Scoped barrier
# ----------------------------------------------------------------------------------
# A scoped barrier allows detecting if there are child, grandchild or descendant tasks
# still running so that a thread can wait for nested spawns.
#
# This solves the following problem:
#
# parallelFor i in 0 ..< M:
# awaitable: loopI
# parallelFor j in 0 ..< M:
# awaitable: loopJ
# discard
#
# Here, loop i will add all the parallel tasks for loop j to the task queue
# and continue on. Even if we await loopI, its task was really just to create loopJ,
# whose tasks might still be pending in the task queue.
#
# We could instead await loopJ before going to the next i iteration but that
# restricts the parallelism exposed.
#
# Alternatively we could use `syncRoot` after those statements, but `syncRoot`
# cannot be called from parallel code, so the enclosing function could not be nested.
#
# In short:
# - the scoped barrier should prevent a thread from continuing while any descendant task
# is still running.
# - it should be composable and parallel constructs need not worry about its presence
# in spawn functions.
# - the enclosed scope should be able to expose all parallelism opportunities
# in particular nested parallel for regions.
# TODO: transform into a Futex/eventcount and allow sleeping
type
ScopedBarrier* = object
## A scoped barrier allows detecting if there are child, grandchild or descendant tasks
## still running so that a thread can wait for nested spawns.
##
## ScopedBarriers can be nested and work like a stack.
## Only one can be active for a given thread for a given code section.
##
## They can be allocated on the stack given that a scoped barrier cannot be exited
## before all the descendant tasks have exited, and so the descendants cannot escape,
## i.e. they have a pointer to their scope which is always valid.
##
## This means that in case of nested scopes, only the inner scope needs to track its descendants.
descendants: Atomic[int]
# Note: if one hook is defined, the other destructor/move procs are not implicitly created inline,
# even if trivial.
proc `=`*(dst: var ScopedBarrier, src: ScopedBarrier) {.error: "A scoped barrier cannot be copied.".}
proc `=sink`*(dst: var ScopedBarrier, src: ScopedBarrier) {.inline.} =
# Nim doesn't respect noinit and tries to zeroMem then move the type
{.warning: "Moving a shared resource (an atomic type).".}
system.`=sink`(dst.descendants, src.descendants)
proc `=destroy`*(sb: var ScopedBarrier) {.inline.}=
preCondition: sb.descendants.load(moRelaxed) == 0
# system.`=destroy`(sb.descendants)
proc initialize*(scopedBarrier: var ScopedBarrier) {.inline.} =
## Initialize a scoped barrier
scopedBarrier.descendants.store(0, moRelaxed)
proc registerDescendant*(scopedBarrier: ptr ScopedBarrier) {.inline.} =
## Register a descendant task to the scoped barrier
## Important: the thread creating the task must register the descendant
## before handing them over to the runtime.
## This way, if the end of scope is reached and we have 0 descendants, it means that
## - either no task was created in the scope
## - tasks were created, but descendants increment/decrement cannot reach 0 before all descendants actually exited
if not scopedBarrier.isNil:
preCondition: scopedBarrier.descendants.load(moAcquire) >= 0
discard scopedBarrier.descendants.fetchAdd(1, moRelease)
postCondition: scopedBarrier.descendants.load(moAcquire) >= 1
proc unlistDescendant*(scopedBarrier: ptr ScopedBarrier) {.inline.} =
## Unlist a descendant task from the scoped barrier.
## Important: if that task spawned new tasks, it is fine even if those grandchild tasks
## are still running; however, they must have been registered with the scoped barrier to avoid race conditions.
if not scopedBarrier.isNil:
preCondition: scopedBarrier.descendants.load(moAcquire) >= 1
fence(moRelease)
discard scopedBarrier.descendants.fetchSub(1, moRelease)
preCondition: scopedBarrier.descendants.load(moAcquire) >= 0
proc hasDescendantTasks*(scopedBarrier: ptr ScopedBarrier): bool {.inline.} =
## Returns true if a scoped barrier has at least a descendant task.
## This should only be called from the thread that created the scoped barrier.
preCondition: scopedBarrier.descendants.load(moAcquire) >= 0
return scopedBarrier.descendants.load(moAcquire) != 0
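
The whole mechanism reduces to one atomic descendant counter per scope. A minimal standalone sketch of the invariant (nothing below is from this repository, std/atomics only): a scope is "done" once every registered descendant has unlisted itself, however deeply tasks were nested.

  import std/atomics

  type MiniScope = object
    descendants: Atomic[int]

  proc register(s: var MiniScope) = discard s.descendants.fetchAdd(1, moRelease)
  proc unlist(s: var MiniScope)   = discard s.descendants.fetchSub(1, moRelease)
  proc busy(s: var MiniScope): bool = s.descendants.load(moAcquire) != 0

  var scope: MiniScope
  scope.descendants.store(0, moRelaxed)
  scope.register()            # the spawning thread registers a child before enqueueing it
  scope.register()            # the child registers a grandchild before it exits
  scope.unlist()              # child done, grandchild still pending
  doAssert scope.busy()       # the scope cannot be exited yet
  scope.unlist()              # grandchild done
  doAssert not scope.busy()   # now the scoped barrier may be crossed
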

View File

@ -8,6 +8,7 @@
import
std/atomics,
./scoped_barriers,
../instrumentation,
../../allocs,
../primitives/futexes
@ -49,6 +50,7 @@ type
# ------------------
state: TaskState
parent*: ptr Task # Latency: When a task is awaited, a thread can quickly prioritize its direct children.
scopedBarrier*: ptr ScopedBarrier
hasFuture*: bool # Ownership: if a task has a future, the future deallocates it. Otherwise the worker thread does.
# Data parallelism
@ -170,14 +172,18 @@ proc setThief*(task: ptr Task, thiefID: int32) {.inline.} =
proc newSpawn*(
T: typedesc[Task],
parent: ptr Task,
scopedBarrier: ptr ScopedBarrier,
fn: proc (env: pointer) {.nimcall, gcsafe, raises: [].}
): ptr Task {.inline.} =
const size = sizeof(T)
scopedBarrier.registerDescendant()
result = allocHeapUnchecked(T, size)
result.initSynchroState()
result.parent = parent
result.scopedBarrier = scopedBarrier
result.hasFuture = false
result.fn = fn
@ -187,15 +193,19 @@ proc newSpawn*(
proc newSpawn*(
T: typedesc[Task],
parent: ptr Task,
scopedBarrier: ptr ScopedBarrier,
fn: proc (env: pointer) {.nimcall, gcsafe, raises: [].},
env: auto): ptr Task {.inline.} =
const size = sizeof(T) + # size without Unchecked
sizeof(env)
scopedBarrier.registerDescendant()
result = allocHeapUnchecked(T, size)
result.initSynchroState()
result.parent = parent
result.scopedBarrier = scopedBarrier
result.hasFuture = false
result.fn = fn
cast[ptr[type env]](result.env)[] = env
@ -209,6 +219,7 @@ func ceilDiv_vartime(a, b: auto): auto {.inline.} =
proc newLoop*(
T: typedesc[Task],
parent: ptr Task,
scopedBarrier: ptr ScopedBarrier,
start, stop, stride: int,
isFirstIter: bool,
fn: proc (env: pointer) {.nimcall, gcsafe, raises: [].}
@ -216,9 +227,12 @@ proc newLoop*(
const size = sizeof(T)
preCondition: start < stop
scopedBarrier.registerDescendant()
result = allocHeapUnchecked(T, size)
result.initSynchroState()
result.parent = parent
result.scopedBarrier = scopedBarrier
result.hasFuture = false
result.fn = fn
result.envSize = 0
@ -233,6 +247,7 @@ proc newLoop*(
proc newLoop*(
T: typedesc[Task],
parent: ptr Task,
scopedBarrier: ptr ScopedBarrier,
start, stop, stride: int,
isFirstIter: bool,
fn: proc (env: pointer) {.nimcall, gcsafe, raises: [].},
@ -242,9 +257,12 @@ proc newLoop*(
sizeof(env)
preCondition: start < stop
scopedBarrier.registerDescendant()
result = allocHeapUnchecked(T, size)
result.initSynchroState()
result.parent = parent
result.scopedBarrier = scopedBarrier
result.hasFuture = false
result.fn = fn
result.envSize = int32(sizeof(env))
@ -311,4 +329,4 @@ func getTask*[T](fv: FlowVar[T]): ptr Task {.inline.} =
proc newReductionDagNode*(task: ptr Task, next: ptr ReductionDagNode): ptr ReductionDagNode {.inline.} =
result = allocHeap(ReductionDagNode)
result.next = next
result.task = task
result.task = task

View File

@ -94,11 +94,13 @@ proc spawnVoid(funcCall: NimNode, args, argsTy: NimNode, workerContext, schedule
when bool(`withArgs`):
let `task` = Task.newSpawn(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
fn = `tpSpawn_closure`,
env = `envParams`)
else:
let `task` = Task.newSpawn(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
fn = `tpSpawn_closure`)
`scheduleBlock`
@ -161,6 +163,7 @@ proc spawnVoidAwaitable(funcCall: NimNode, args, argsTy: NimNode, workerContext,
let `task` = Task.newSpawn(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
fn = `tpSpawn_closure`,
env = `envParams`)
let `fut` = newFlowVar(bool, `task`)
@ -230,6 +233,7 @@ proc spawnRet(funcCall: NimNode, retTy, args, argsTy: NimNode, workerContext, sc
let `task` = Task.newSpawn(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
fn = `tpSpawn_closure`,
env = `envParams`)
let `fut` = newFlowVar(`retTy`, `task`)
@ -589,6 +593,7 @@ proc generateAndScheduleLoopTask(ld: LoopDescriptor): NimNode =
when bool(`withCaptures`):
let `task` = Task.newLoop(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
start, stopEx, `stride`,
isFirstIter = true,
fn = `closureName`,
@ -596,6 +601,7 @@ proc generateAndScheduleLoopTask(ld: LoopDescriptor): NimNode =
else:
let `task` = Task.newLoop(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
start, stopEx, `stride`,
isFirstIter = true,
fn = `closureName`)
@ -613,6 +619,7 @@ proc generateAndScheduleLoopTask(ld: LoopDescriptor): NimNode =
when bool(`withCaptures`):
let `task` = Task.newLoop(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
start, stopEx, `stride`,
isFirstIter = true,
fn = `closureName`,
@ -620,6 +627,7 @@ proc generateAndScheduleLoopTask(ld: LoopDescriptor): NimNode =
else:
let `task` = Task.newLoop(
parent = `workerContext`.currentTask,
scopedBarrier = `workerContext`.currentScope,
start, stopEx, `stride`,
isFirstIter = true,
fn = `closureName`,

View File

@ -16,6 +16,7 @@ import
./crossthread/[
taskqueues,
backoff,
scoped_barriers,
tasks_flowvars],
./instrumentation,
./primitives/barriers,
@ -173,6 +174,7 @@ type
currentTask: ptr Task
# Synchronization
currentScope*: ptr ScopedBarrier # need to be exported for syncScope template
signal: ptr Signal # owned signal
# Thefts
@ -273,6 +275,7 @@ proc setupWorker(ctx: var WorkerContext) =
ctx.rng.seed(0xEFFACED + ctx.id)
# Synchronization
ctx.currentScope = nil
ctx.signal = addr ctx.threadpool.workerSignals[ctx.id]
ctx.signal.terminate.store(false, moRelaxed)
@ -359,13 +362,21 @@ const RootTask = cast[ptr Task](0xEFFACED0)
proc run(ctx: var WorkerContext, task: ptr Task) {.raises:[].} =
## Run a task, frees it if it is not owned by a Flowvar
let suspendedTask = ctx.currentTask
let suspendedScope = ctx.currentScope
ctx.currentTask = task
ctx.currentScope = task.scopedBarrier
debug: log("Worker %3d: running task 0x%.08x (previous: 0x%.08x, %d pending, thiefID %d)\n", ctx.id, task, suspendedTask, ctx.taskqueue[].peek(), task.getThief())
profile(run_task):
task.fn(task.env.addr)
task.scopedBarrier.unlistDescendant()
debug: log("Worker %3d: completed task 0x%.08x (%d pending)\n", ctx.id, task, ctx.taskqueue[].peek())
ctx.currentTask = suspendedTask
ctx.currentScope = suspendedScope
ctx.incCounter(tasksExecuted)
ctx.incCounter(itersExecuted):
@ -387,7 +398,7 @@ proc schedule(ctx: var WorkerContext, task: ptr Task, forceWake = false) {.inlin
## Schedule a task in the threadpool
## This wakes another worker if our local queue is empty
## or forceWake is true.
debug: log("Worker %3d: schedule task 0x%.08x (parent/current task 0x%.08x)\n", ctx.id, task, task.parent)
debug: log("Worker %3d: schedule task 0x%.08x (parent/current task 0x%.08x, scope 0x%.08x)\n", ctx.id, task, task.parent, task.scopedBarrier)
# Instead of notifying every time a task is scheduled, we notify
# only when the worker queue is empty. This is a good approximation
@ -536,6 +547,7 @@ proc splitAndDispatchLoop(ctx: var WorkerContext, task: ptr Task, curLoopIndex:
upperSplit.initSynchroState()
upperSplit.parent = task
upperSplit.scopedBarrier.registerDescendant()
upperSplit.isFirstIter = false
upperSplit.loopStart = offset
@ -839,7 +851,7 @@ proc syncAll*(tp: Threadpool) {.raises: [].} =
template ctx: untyped = workerContext
debugTermination:
log(">>> Worker %3d enters barrier <<<\n", ctx.id)
log(">>> Worker %3d enters global barrier <<<\n", ctx.id)
preCondition: ctx.id == 0
preCondition: ctx.currentTask.isRootTask()
@ -880,6 +892,46 @@ proc syncAll*(tp: Threadpool) {.raises: [].} =
profileStart(run_task)
proc wait(scopedBarrier: ptr ScopedBarrier) {.raises:[], gcsafe.} =
## Wait at barrier until all descendant tasks are completed
template ctx: untyped = workerContext
debugTermination:
log(">>> Worker %3d enters scoped barrier 0x%.08x <<<\n", ctx.id, scopedBarrier)
while scopedBarrier.hasDescendantTasks():
# 1. Empty local tasks, the initial loop only has tasks from that scope or a child scope.
debug: log("Worker %3d: syncScope 1 - searching task from local queue\n", ctx.id)
while (let task = ctx.taskqueue[].pop(); not task.isNil):
debug: log("Worker %3d: syncScope 1 - running task 0x%.08x (parent 0x%.08x, current 0x%.08x, scope 0x%.08x)\n", ctx.id, task, task.parent, ctx.currentTask, task.scopedBarrier)
ctx.run(task)
if not scopedBarrier.hasDescendantTasks():
debugTermination:
log(">>> Worker %3d exits scoped barrier 0x%.08x <<<\n", ctx.id, scopedBarrier)
return
# TODO: consider leapfrogging
if (var stolenTask = ctx.tryStealOne(); not stolenTask.isNil):
# 2.a We stole some task
debug: log("Worker %3d: syncScope 2 - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x, scope 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask, stolenTask.scopedBarrier)
# Theft successful, there might be more work for idle threads, wake one
ctx.threadpool.globalBackoff.wake()
ctx.incCounter(backoffGlobalSignalSent)
ctx.incCounter(theftsIdle)
ctx.incCounter(itersStolen):
if stolenTask.loopStepsLeft == NotALoop: 0
else: stolenTask.loopStepsLeft
ctx.run(stolenTask)
else:
# TODO: backoff
cpuRelax()
debugTermination:
log(">>> Worker %3d exits scoped barrier 0x%.08x <<<\n", ctx.id, scopedBarrier)
# ############################################################
# #
# Runtime API #
@ -973,6 +1025,20 @@ proc shutdown*(tp: var Threadpool) {.raises:[].} =
# #
# ############################################################
# Structured parallelism API
# ---------------------------------------------
template syncScope*(body: untyped): untyped =
block:
let suspendedScope = workerContext.currentScope
var scopedBarrier {.noInit.}: ScopedBarrier
initialize(scopedBarrier)
workerContext.currentScope = addr scopedBarrier
block:
body
wait(scopedBarrier.addr)
workerContext.currentScope = suspendedScope
# Task parallel API
# ---------------------------------------------