Parallel batchadd (#215)

* [Threadpool] Fix syncAll releasing while a thread was attempting to steal + force no exception in tasks

* fix unguarded access on MacOS barriers

* parallel batchadd

* moved import
Mamy Ratsimbazafy 2023-01-29 01:06:37 +01:00 committed by GitHub
parent a385acf2b8
commit 495ef4497b
25 changed files with 945 additions and 491 deletions

View File

@ -21,7 +21,7 @@ import
# Standard library
std/[monotimes, times, strformat, strutils, macros]
-export strformat, platforms, times, monotimes, macros
export strutils, strformat, platforms, times, monotimes, macros

var rng*: RngState
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32

View File

@ -13,10 +13,34 @@ import
../constantine/math/elliptic/[
ec_shortweierstrass_affine,
ec_shortweierstrass_projective,
-ec_shortweierstrass_jacobian],
ec_shortweierstrass_jacobian,
ec_shortweierstrass_batch_ops_parallel],
../constantine/platforms/threadpool/threadpool,
# Helpers
-../helpers/static_for,
-./bench_elliptic_template
../helpers/[static_for, prng_unsafe],
./bench_elliptic_template,
./bench_blueprint
# ############################################################
#
# Parallel Benchmark definitions
#
# ############################################################
proc multiAddParallelBench*(EC: typedesc, numPoints: int, iters: int) =
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numPoints)
for i in 0 ..< numPoints:
points[i] = rng.random_unsafe(ECP_ShortW_Aff[EC.F, EC.G])
var r{.noInit.}: EC
var tp = Threadpool.new()
bench("EC parallel batch add (" & align($tp.numThreads, 2) & " threads) " & $EC.G & " (" & $numPoints & " points)", EC, iters):
tp.sum_batch_vartime_parallel(r, points)
tp.shutdown()
# ############################################################
#
@ -52,6 +76,10 @@ proc main() =
let batchIters = max(1, Iters div numPoints)
multiAddBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = true, batchIters)
separator()
for numPoints in [10, 100, 1000, 10000, 100000, 1000000]:
let batchIters = max(1, Iters div numPoints)
multiAddParallelBench(ECP_ShortW_Prj[Fp[curve], G1], numPoints, batchIters)
separator()
for numPoints in [10, 100, 1000, 10000, 100000, 1000000]:
let batchIters = max(1, Iters div numPoints)
multiAddBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, useBatching = false, batchIters)
@ -60,6 +88,10 @@ proc main() =
let batchIters = max(1, Iters div numPoints)
multiAddBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, useBatching = true, batchIters)
separator()
for numPoints in [10, 100, 1000, 10000, 100000, 1000000]:
let batchIters = max(1, Iters div numPoints)
multiAddParallelBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
separator()
separator()
main()

View File

@ -0,0 +1 @@
--threads:on

View File

@ -25,7 +25,7 @@ import
ec_shortweierstrass_batch_ops,
ec_scalar_mul, ec_endomorphism_accel],
# Helpers
-../helpers/[prng_unsafe, static_for],
../helpers/prng_unsafe,
./platforms,
./bench_blueprint,
# Reference unsafe scalar multiplication
@ -54,7 +54,7 @@ proc report(op, elliptic: string, start, stop: MonoTime, startClk, stopClk: int6
else:
echo &"{op:<60} {elliptic:<40} {throughput:>15.3f} ops/s {ns:>9} ns/op"

-template bench(op: string, EC: typedesc, iters: int, body: untyped): untyped =
template bench*(op: string, EC: typedesc, iters: int, body: untyped): untyped =
measure(iters, startTime, stopTime, startClk, stopClk, body)
report(op, fixEllipticDisplay(EC), startTime, stopTime, startClk, stopClk, iters)
@ -143,7 +143,7 @@ proc scalarMulUnsafeDoubleAddBench*(EC: typedesc, iters: int) =
r = P
r.unsafe_ECmul_double_add(exponent)

proc multiAddBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int) =
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numPoints)
for i in 0 ..< numPoints:
@ -152,10 +152,10 @@ proc multiAddBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int)
var r{.noInit.}: EC
if useBatching:
-bench("EC Multi-Addition batched " & $EC.G & " (" & $numPoints & " points)", EC, iters):
bench("EC Multi Add batched " & $EC.G & " (" & $numPoints & " points)", EC, iters):
r.sum_batch_vartime(points)
else:
-bench("EC Multi-Addition unbatched mixed add " & $EC.G & " (" & $numPoints & " points)", EC, iters):
bench("EC Multi Mixed-Add unbatched " & $EC.G & " (" & $numPoints & " points)", EC, iters):
r.setInf()
for i in 0 ..< numPoints:
r += points[i]

View File

@ -245,6 +245,11 @@ const testDescThreadpool: seq[string] = @[
# "constantine/platforms/threadpool/benchmarks/single_task_producer/threadpool_spc.nim", # Need timing not implemented on Windows # "constantine/platforms/threadpool/benchmarks/single_task_producer/threadpool_spc.nim", # Need timing not implemented on Windows
] ]
const testDescMultithreadedCrypto: seq[string] = @[
"tests/parallel/t_ec_shortw_jac_g1_batch_add_parallel.nim",
"tests/parallel/t_ec_shortw_prj_g1_batch_add_parallel.nim"
]
const benchDesc = [
"bench_fp",
"bench_fp_double_precision",
@ -408,7 +413,25 @@ proc addTestSetThreadpool(cmdFile: var string) =
echo "Found " & $testDescThreadpool.len & " tests to run." echo "Found " & $testDescThreadpool.len & " tests to run."
for path in testDescThreadpool: for path in testDescThreadpool:
cmdFile.testBatch(flags = "--threads:on --linetrace:on", path) cmdFile.testBatch(flags = "--threads:on --linetrace:on --debugger:native", path)
proc addTestSetMultithreadedCrypto(cmdFile: var string, test32bit = false, testASM = true) =
if not dirExists "build":
mkDir "build"
echo "Found " & $testDescMultithreadedCrypto.len & " tests to run."
for td in testDescMultithreadedCrypto:
var flags = " --threads:on --debugger:native"
if not testASM:
flags &= " -d:CttASM=false"
if test32bit:
flags &= " -d:Constantine32"
if td in useDebug:
flags &= " -d:debugConstantine"
if td notin skipSanitizers:
flags &= sanitizers
cmdFile.testBatch(flags, td)
proc addBenchSet(cmdFile: var string, useAsm = true) =
if not dirExists "build":
@ -643,6 +666,13 @@ task test_threadpool, "Run all tests for the builtin threadpool":
if cmd != "": # Windows doesn't like empty commands if cmd != "": # Windows doesn't like empty commands
exec cmd exec cmd
task test_multithreaded_crypto, "Run all tests for multithreaded cryptography":
var cmdFile: string
cmdFile.addTestSetMultithreadedCrypto()
for cmd in cmdFile.splitLines():
if cmd != "": # Windows doesn't like empty commands
exec cmd
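(These tasks are presumably invoked the standard nimble way, e.g. `nimble test_multithreaded_crypto`; the invocation itself is not shown in this diff.)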
task test_nvidia, "Run all tests for Nvidia GPUs":
var cmdFile: string
cmdFile.addTestSetNvidia()

View File

@ -18,10 +18,12 @@ import
ec_shortweierstrass_affine,
ec_shortweierstrass_jacobian,
ec_shortweierstrass_projective,
ec_shortweierstrass_batch_ops,
ec_scalar_mul
]

-export ec_shortweierstrass_affine, ec_shortweierstrass_jacobian, ec_shortweierstrass_projective, ec_scalar_mul
export ec_shortweierstrass_affine, ec_shortweierstrass_jacobian, ec_shortweierstrass_projective,
       ec_shortweierstrass_batch_ops, ec_scalar_mul

type ECP_ShortW*[F; G: static Subgroup] = ECP_ShortW_Aff[F, G] | ECP_ShortW_Jac[F, G] | ECP_ShortW_Prj[F, G]

View File

@ -16,7 +16,8 @@ import
../arithmetic,
../extension_fields,
../isogenies/frobenius,
-./ec_shortweierstrass_affine
./ec_shortweierstrass_affine,
./ec_shortweierstrass_batch_ops

# ############################################################
#

View File

@ -14,8 +14,141 @@ import
./ec_shortweierstrass_jacobian,
./ec_shortweierstrass_projective

-# No exceptions allowed
-{.push raises: [].}
# No exceptions allowed, or array bound checks or integer overflow
{.push raises: [], checks:off.}
# ############################################################
#
# Elliptic Curve in Short Weierstrass form
# Batch conversion
#
# ############################################################
func batchAffine*[F, G](
affs: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
projs: ptr UncheckedArray[ECP_ShortW_Prj[F, G]],
N: int) =
# Algorithm: Montgomery's batch inversion
# - Speeding the Pollard and Elliptic Curve Methods of Factorization
# Section 10.3.1
# Peter L. Montgomery
# https://www.ams.org/journals/mcom/1987-48-177/S0025-5718-1987-0866113-7/S0025-5718-1987-0866113-7.pdf
# - Modern Computer Arithmetic
# Section 2.5.1 Several inversions at once
# Richard P. Brent and Paul Zimmermann
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
# To avoid temporaries, we store partial accumulations
# in affs[i].x
let zeroes = allocStackArray(SecretBool, N)
affs[0].x = projs[0].z
zeroes[0] = affs[0].x.isZero()
affs[0].x.csetOne(zeroes[0])
for i in 1 ..< N:
# Skip zero z-coordinates (infinity points)
var z = projs[i].z
zeroes[i] = z.isZero()
z.csetOne(zeroes[i])
if i != N-1:
affs[i].x.prod(affs[i-1].x, z, skipFinalSub = true)
else:
affs[i].x.prod(affs[i-1].x, z, skipFinalSub = false)
var accInv {.noInit.}: F
accInv.inv(affs[N-1].x)
for i in countdown(N-1, 1):
# Extract 1/Pᵢ
var invi {.noInit.}: F
invi.prod(accInv, affs[i-1].x, skipFinalSub = true)
invi.csetZero(zeroes[i])
# Now convert Pᵢ to affine
affs[i].x.prod(projs[i].x, invi)
affs[i].y.prod(projs[i].y, invi)
# next iteration
invi = projs[i].z
invi.csetOne(zeroes[i])
accInv.prod(accInv, invi, skipFinalSub = true)
block: # tail
accInv.csetZero(zeroes[0])
affs[0].x.prod(projs[0].x, accInv)
affs[0].y.prod(projs[0].y, accInv)
func batchAffine*[N: static int, F, G](
affs: var array[N, ECP_ShortW_Aff[F, G]],
projs: array[N, ECP_ShortW_Prj[F, G]]) {.inline.} =
batchAffine(affs.asUnchecked(), projs.asUnchecked(), N)
func batchAffine*[F, G](
affs: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
jacs: ptr UncheckedArray[ECP_ShortW_Jac[F, G]],
N: int) =
# Algorithm: Montgomery's batch inversion
# - Speeding the Pollard and Elliptic Curve Methods of Factorization
# Section 10.3.1
# Peter L. Montgomery
# https://www.ams.org/journals/mcom/1987-48-177/S0025-5718-1987-0866113-7/S0025-5718-1987-0866113-7.pdf
# - Modern Computer Arithmetic
# Section 2.5.1 Several inversions at once
# Richard P. Brent and Paul Zimmermann
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
# To avoid temporaries, we store partial accumulations
# in affs[i].x (whether z == 0 is tracked separately in the `zeroes` scratch array)
var zeroes = allocStackArray(SecretBool, N)
affs[0].x = jacs[0].z
zeroes[0] = affs[0].x.isZero()
affs[0].x.csetOne(zeroes[0])
for i in 1 ..< N:
# Skip zero z-coordinates (infinity points)
var z = jacs[i].z
zeroes[i] = z.isZero()
z.csetOne(zeroes[i])
if i != N-1:
affs[i].x.prod(affs[i-1].x, z, skipFinalSub = true)
else:
affs[i].x.prod(affs[i-1].x, z, skipFinalSub = false)
var accInv {.noInit.}: F
accInv.inv(affs[N-1].x)
for i in countdown(N-1, 1):
# Extract 1/Pᵢ
var invi {.noInit.}: F
invi.prod(accInv, affs[i-1].x, skipFinalSub = true)
invi.csetZero(zeroes[i])
# Now convert Pᵢ to affine
var invi2 {.noinit.}: F
invi2.square(invi, skipFinalSub = true)
affs[i].x.prod(jacs[i].x, invi2)
invi.prod(invi, invi2, skipFinalSub = true)
affs[i].y.prod(jacs[i].y, invi)
# next iteration
invi = jacs[i].z
invi.csetOne(zeroes[i])
accInv.prod(accInv, invi, skipFinalSub = true)
block: # tail
var invi2 {.noinit.}: F
accInv.csetZero(zeroes[0])
invi2.square(accInv, skipFinalSub = true)
affs[0].x.prod(jacs[0].x, invi2)
accInv.prod(accInv, invi2, skipFinalSub = true)
affs[0].y.prod(jacs[0].y, accInv)
func batchAffine*[N: static int, F, G](
affs: var array[N, ECP_ShortW_Aff[F, G]],
jacs: array[N, ECP_ShortW_Jac[F, G]]) {.inline.} =
batchAffine(affs.asUnchecked(), jacs.asUnchecked(), N)
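For intuition, here is a toy version of Montgomery's batch-inversion trick over plain integers modulo a small prime. It is illustrative only and does not use Constantine's field API; the prime and the `invMod`/`batchInv` helpers below are mine.

```nim
const p = 97  # small prime for the toy example

proc invMod(a: int): int =
  ## Fermat inversion: a^(p-2) mod p, valid since p is prime.
  var base = a mod p
  var expo = p - 2
  result = 1
  while expo > 0:
    if (expo and 1) == 1: result = result * base mod p
    base = base * base mod p
    expo = expo shr 1

proc batchInv(xs: openArray[int]): seq[int] =
  ## Inverts every element with a single modular inversion,
  ## mirroring the accumulate-then-backward-sweep structure of batchAffine.
  let n = xs.len
  var acc = newSeq[int](n)            # acc[i] = x₀·x₁·…·xᵢ (mod p)
  acc[0] = xs[0] mod p
  for i in 1 ..< n: acc[i] = acc[i-1] * xs[i] mod p
  var inv = invMod(acc[n-1])          # one inversion of the full product
  result = newSeq[int](n)
  for i in countdown(n-1, 1):
    result[i] = inv * acc[i-1] mod p  # 1/xᵢ = (∏ all)⁻¹ · ∏_{j<i} xⱼ
    inv = inv * xs[i] mod p           # strip xᵢ from the running inverse
  result[0] = inv

when isMainModule:
  echo batchInv([3, 5, 7, 11])  # each output times its input is ≡ 1 (mod 97)
```

The real `batchAffine` follows the same walk, but stores the prefix products in `affs[i].x` and conditionally replaces zero z-coordinates with one so that infinity points pass through harmlessly.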
# ############################################################
#
@ -28,19 +161,19 @@ import
# ------------------------------------------------------------
#
# The equation for elliptic curve addition is in affine (x, y) coordinates:
#
#   P + Q = R
#   (Px, Py) + (Qx, Qy) = (Rx, Ry)
#
# with
#   Rx = λ² - Px - Qx
#   Ry = λ(Px - Rx) - Py
#
# in the case of addition
#   λ = (Qy - Py) / (Qx - Px)
#
# which is undefined for P == Q or P == -Q as -(x, y) = (x, -y)
#
# if P = Q, the doubling formula uses the slope of the tangent at the limit
#   λ = (3 Px² + a) / (2 Px)
#
@ -85,7 +218,7 @@ func affineAdd[F; G: static Subgroup](
r: var ECP_ShortW_Aff[F, G],
lambda: var F,
P, Q: ECP_ShortW_Aff[F, G]) =
r.x.square(lambda)
r.x -= P.x
r.x -= Q.x
@ -94,31 +227,30 @@ func affineAdd[F; G: static Subgroup](
r.y *= lambda
r.y -= P.y
-{.push checks:off.}
func accum_half_vartime[F; G: static Subgroup](
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
lambdas: ptr UncheckedArray[tuple[num, den: F]],
len: uint) {.noinline.} =
## Affine accumulation of half the points into the other half
## Warning ⚠️ : variable-time
##
## Accumulate `len` points pairwise into `len/2`
##
## Input/output:
## - points: `len/2` affine points to add (must be even)
##   Partial sums are stored in [0, len/2)
##   [len/2, len) data has been destroyed
##
## Scratchspace:
## - Lambdas
##
## Output:
## - r
##
## Warning ⚠️ : cannot be inlined if used in loop due to the use of alloca

debug: doAssert len and 1 == 0, "There must be an even number of points"

let N = len div 2

# Step 1: Compute numerators and denominators of λᵢ = λᵢ_num / λᵢ_den
@ -157,12 +289,12 @@ func accum_half_vartime[F; G: static Subgroup](
continue
else:
lambdaAdd(lambdas[i].num, lambdas[i].den, points[p], points[q])

# Step 2: Accumulate denominators in Qy, which is not used anymore.
if i == 0:
points[q].y = lambdas[i].den
else:
points[q].y.prod(points[q_prev].y, lambdas[i].den, skipFinalSub = true)

# Step 3: batch invert
var accInv {.noInit.}: F
@ -195,10 +327,10 @@ func accum_half_vartime[F; G: static Subgroup](
# Compute lambda
points[q].y.prod(accInv, points[q_prev].y, skipFinalSub = true)
points[q].y.prod(points[q].y, lambdas[i].num, skipFinalSub = true)

# Compute EC addition
var r{.noInit.}: ECP_ShortW_Aff[F, G]
r.affineAdd(lambda = points[q].y, points[p], points[q])

# Store result
points[i] = r
@ -216,20 +348,17 @@ func accum_half_vartime[F; G: static Subgroup](
else:
# Compute lambda
points[q].y.prod(lambdas[0].num, accInv, skipFinalSub = true)

# Compute EC addition
var r{.noInit.}: ECP_ShortW_Aff[F, G]
r.affineAdd(lambda = points[q].y, points[p], points[q])

# Store result
points[0] = r

-{.pop.}

# Batch addition: jacobian
# ------------------------------------------------------------

-{.push checks:off.}
func accumSum_chunk_vartime[F; G: static Subgroup](
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
@ -238,16 +367,16 @@ func accumSum_chunk_vartime[F; G: static Subgroup](
## Accumulate `points` into r.
## `r` is NOT overwritten
## r += ∑ points

-const ChunkThreshold = 16
const minNumPointsSerial = 16
var n = len

-while n >= ChunkThreshold:
while n >= minNumPointsSerial:
if (n and 1) == 1: # odd number of points
## Accumulate the last
r += points[n-1]
n -= 1

# Compute [0, n/2) += [n/2, n)
accum_half_vartime(points, lambdas, n)
@ -257,15 +386,13 @@ func accumSum_chunk_vartime[F; G: static Subgroup](
# Tail
for i in 0'u ..< n:
r += points[i]

-{.pop.}
-{.push checks:off.}
func sum_batch_vartime*[F; G: static Subgroup](
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
-points: openArray[ECP_ShortW_Aff[F, G]]) =
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], pointsLen: int) =
## Batch addition of `points` into `r`
## `r` is overwritten

# We chunk the addition to limit memory usage
# especially as we allocate on the stack.
@ -287,17 +414,25 @@ func sum_batch_vartime*[F; G: static Subgroup](
r.setInf()

-const maxChunkSize = 262144 # 2¹⁸ = 262144
-const maxStride = maxChunkSize div sizeof(ECP_ShortW_Aff[F, G])
const maxTempMem = 262144 # 2¹⁸ = 262144
const maxStride = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])

-let n = min(maxStride, points.len)
let n = min(maxStride, pointsLen)
let accumulators = allocStackArray(ECP_ShortW_Aff[F, G], n)
let lambdas = allocStackArray(tuple[num, den: F], n)

-for i in countup(0, points.len-1, maxStride):
-let n = min(maxStride, points.len - i)
for i in countup(0, pointsLen-1, maxStride):
let n = min(maxStride, pointsLen - i)
let size = n * sizeof(ECP_ShortW_Aff[F, G])
copyMem(accumulators[0].addr, points[i].unsafeAddr, size)
r.accumSum_chunk_vartime(accumulators, lambdas, uint n)

-{.pop.}
func sum_batch_vartime*[F; G: static Subgroup](
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
points: openArray[ECP_ShortW_Aff[F, G]]) {.inline.} =
## Batch addition of `points` into `r`
## `r` is overwritten
if points.len == 0:
r.setInf()
return
r.sum_batch_vartime(points.asUnchecked(), points.len)
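For orientation, a hypothetical call site for the `openArray` overload above, modelled on the benchmark earlier in this PR; the surrounding imports and the `rng.random_unsafe` helper are assumptions, not shown here.

```nim
proc sumExample(EC: typedesc, numPoints: int) =
  # Fill an affine point buffer (as the benchmarks do), then batch-add it.
  var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numPoints)
  for i in 0 ..< numPoints:
    points[i] = rng.random_unsafe(ECP_ShortW_Aff[EC.F, EC.G])

  var r {.noInit.}: EC
  r.sum_batch_vartime(points)  # r is overwritten; work proceeds in ≤ 2¹⁸-byte stack chunks
```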

View File

@ -0,0 +1,139 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
../../platforms/[abstractions, allocs],
../../platforms/threadpool/threadpool,
./ec_shortweierstrass_affine,
./ec_shortweierstrass_jacobian,
./ec_shortweierstrass_projective,
./ec_shortweierstrass_batch_ops
# No exceptions allowed
{.push raises:[], checks: off.}
# ############################################################
#
# Elliptic Curve in Short Weierstrass form
# Parallel Batch addition
#
# ############################################################
type ChunkDescriptor = object
start, totalIters: int
numChunks, baseChunkSize, cutoff: int
func computeBalancedChunks(start, stopEx, minChunkSize, maxChunkSize, targetNumChunks: int): ChunkDescriptor =
## Balanced chunking algorithm for a range [start, stopEx)
## This ideally splits a range into min(stopEx-start, targetNumChunks) balanced regions
## while keeping the chunk size within the range [minChunkSize, maxChunkSize]
#
# see constantine/platforms/threadpool/docs/partitioner.md
let totalIters = stopEx - start
var numChunks = max(targetNumChunks, 1)
var baseChunkSize = totalIters div numChunks
var cutoff = totalIters mod numChunks # Should be computed in a single instruction with baseChunkSize
if baseChunkSize < minChunkSize:
numChunks = max(totalIters div minChunkSize, 1)
baseChunkSize = totalIters div numChunks
cutoff = totalIters mod numChunks
elif baseChunkSize > maxChunkSize or (baseChunkSize == maxChunkSize and cutoff != 0):
# After cutoff, we do baseChunkSize+1, and would run afoul of the maxChunkSize constraint (unless no remainder), hence ceildiv
numChunks = (totalIters + maxChunkSize - 1) div maxChunkSize # ceildiv
baseChunkSize = totalIters div numChunks
cutoff = totalIters mod numChunks
return ChunkDescriptor(
start: start, totaliters: totalIters,
numChunks: numChunks, baseChunkSize: baseChunkSize, cutoff: cutoff
)
iterator items(c: ChunkDescriptor): tuple[chunkID, start, stopEx: int] =
for chunkID in 0 ..< min(c.numChunks, c.totalIters):
if chunkID < c.cutoff:
let offset = c.start + ((c.baseChunkSize + 1) * chunkID)
let chunkSize = c.baseChunkSize + 1
yield (chunkID, offset, min(offset+chunkSize, c.totalIters))
else:
let offset = c.start + (c.baseChunkSize * chunkID) + c.cutoff
let chunkSize = c.baseChunkSize
yield (chunkID, offset, min(offset+chunkSize, c.totalIters))
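Worked by hand from the code above (not a test run): computeBalancedChunks(start = 0, stopEx = 40, minChunkSize = 16, maxChunkSize = 128, targetNumChunks = 12) first tries 12 chunks of size 3 with cutoff 4, notices that 3 < minChunkSize, and settles on numChunks = 2, baseChunkSize = 20, cutoff = 0; the iterator then yields (0, 0, 20) and (1, 20, 40).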
proc sum_batch_vartime_parallel*[F; G: static Subgroup](
tp: Threadpool,
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
points: openArray[ECP_ShortW_Aff[F, G]]) =
## Batch addition of `points` into `r`
## `r` is overwritten
## Compute is parallelized, if beneficial.
## This function cannot be nested in another parallel function
##
## Side-effects due to thread-local threadpool variable accesses.
# TODO:
# This function is needed in Multi-Scalar Multiplication (MSM),
# the main bottleneck (~80% of the time) of zero-knowledge proof systems.
# MSM is difficult to scale above 16 cores;
# allowing nested parallelism would expose more parallelism opportunities.
# Chunking constants in ec_shortweierstrass_batch_ops.nim
const minNumPointsParallel = 1024 # For 256-bit curves that's 1024*(32+32) = 65536 temp mem
const maxTempMem = 262144 # 2¹⁸ = 262144
const maxNumPoints = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
# 262144 / (2*1024) = 128 bytes allowed per coordinate. The largest curve, BW6-761, requires 96 bytes per coordinate. And G2 is on Fp, not Fp2.
static: doAssert minNumPointsParallel <= maxNumPoints, "The curve " & $r.typeof & " requires large size and needs to be tuned."
if points.len < minNumPointsParallel:
r.sum_batch_vartime(points)
return
let chunkDesc = computeBalancedChunks(
start = 0, stopEx = points.len,
minNumPointsParallel, maxNumPoints,
targetNumChunks = tp.numThreads.int)
let partialResults = allocStackArray(r.typeof(), chunkDesc.numChunks)
for iter in items(chunkDesc):
proc sum_batch_vartime_wrapper(res: ptr, p: ptr, pLen: int) {.nimcall.} =
# The borrow checker prevents capturing `var` and `openArray`
# so we capture pointers instead.
res[].sum_batch_vartime(p, pLen)
tp.spawn partialResults[iter.chunkID].addr.sum_batch_vartime_wrapper(
points.asUnchecked() +% iter.start,
iter.stopEx - iter.start)
tp.syncAll() # TODO: this prevents nesting in another parallel region
const minNumPointsSerial = 16
if chunkDesc.numChunks < minNumPointsSerial:
r.setInf()
for i in 0 ..< chunkDesc.numChunks:
r += partialResults[i]
else:
let partialResultsAffine = allocStackArray(ECP_ShortW_Aff[F, G], chunkDesc.numChunks)
partialResultsAffine.batchAffine(partialResults, chunkDesc.numChunks)
r.sum_batch_vartime(partialResultsAffine, chunkDesc.numChunks)
# Sanity checks
# ---------------------------------------
when isMainModule:
block:
let chunkDesc = computeBalancedChunks(start = 0, stopEx = 40, minChunkSize = 16, maxChunkSize = 128, targetNumChunks = 12)
for chunk in chunkDesc:
echo chunk
block:
let chunkDesc = computeBalancedChunks(start = 0, stopEx = 10000, minChunkSize = 16, maxChunkSize = 128, targetNumChunks = 12)
for chunk in chunkDesc:
echo chunk

View File

@ -0,0 +1,224 @@
Deriving efficient and complete Jacobian formulae
=================================================
We are looking for a complete addition formula
that minimizes overhead over the classic addition formulae
from the literature
and can handle all inputs.
We recall the basic affine addition and doubling formulae
```
P + Q = R
(Px, Py) + (Qx, Qy) = (Rx, Ry)
with
Rx = λ² - Px - Qx
Ry = λ(Px - Rx) - Py
and
λadd = (Qy - Py) / (Px - Qx)
λdbl = (3 Px² + a) / (2 Px)
```
Which is also called the "chord-and-tangent" group law.
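As a toy sanity check of the addition formula (a hand-worked example over a tiny curve, unrelated to the curves Constantine ships):

```
On y² = x³ + 7 over F₁₇, take P = (1, 5) and Q = (8, 3):
  λ  = (Qy - Py) / (Qx - Px) = -2 / 7 = 15 · 5 = 7
  Rx = λ² - Px - Qx = 49 - 1 - 8 = 6
  Ry = λ(Px - Rx) - Py = 7·(1 - 6) - 5 = 11
and (6, 11) is indeed on the curve: 11² = 2 = 6³ + 7 (mod 17).
```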
Notice that if Px == Qx, addition is undefined; this can happen in 2 cases:
- P == Q, in that case we need to double
- P == -Q, since -(x,y) = (x,-y) for elliptic curves. In that case we need infinity.
Concretely, that means that it is non-trivial to make the code constant-time
whichever case we are in.
Furthermore, Renes et al 2015, which introduced complete addition formulae for projective coordinates,
demonstrated that such a law cannot be as efficient for the Jacobian coordinates we are interested in.
Since we can't rely on math, we will need to rely on implementation "details" to achieve our goals.
First we look back in history at the Brier and Joye 2002 unified formulae, which use the same code for addition and doubling:
```
λ = ((x₁+x₂)² - x₁x₂ + a)/(y₁+y₂)
x₃ = λ² - (x₁+x₂)
2y₃= λ(x₁+x₂-2x₃) - (y₁+y₂)
```
Alas we traded exceptions depending on the same coordinate x
for exceptions on negated coordinate y.
This can still happen for P=-Q but also for "unrelated" numbers.
> We recall that curves with equation `y² = x³ + b` are chosen so that there exists a cube root of unity modulo r, the curve order.
> We call that root ω, so ω³ ≡ 1 (mod r).
> And so y² = (ωx)³ + b describes a valid point with coordinates (ωx, y).
> Hence the unified formula cannot handle (x, y) + (ωx, -y).
> All pairing curves and secp256k1 have that equation form.
Now, all hope is not lost: we recall that, unlike in math,
in an actual implementation we have an excellent tool called conditional copy,
so we can ninja-swap our terms
provided addition and doubling resemble each other.
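To make that concrete, here is a minimal sketch of the ccopy-based swap for an a = 0 curve. It only uses field routines that appear elsewhere in this PR (`square`, `isZero`, `ccopy`, `+=`, `-=`) and is not the exact Constantine code path:

```
# num/den: slope numerator and denominator; F is the base field type
var num = Q.y;  num -= P.y            # λ_add numerator   (Qy - Py)
var den = Q.x;  den -= P.x            # λ_add denominator (Qx - Px)

var numDbl {.noInit.}: F
numDbl.square(P.x)                    # Px²
var t = numDbl
numDbl += t;  numDbl += t             # 3·Px²   (tangent numerator, a = 0)
var denDbl = P.y;  denDbl += P.y      # 2·Py    (tangent denominator)

let isDbl = den.isZero() and num.isZero()   # P == Q shows up as 0/0
num.ccopy(numDbl, isDbl)                    # constant-time swap-in of the
den.ccopy(denDbl, isDbl)                    # doubling slope
# a single division num/den now yields λ for both addition and doubling
```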
So let's look at the current state of the art formulae.
I have added the spots where we can detect special conditions like infinity points, doubling and negation,
and reorganized doubling operations so that they match multiplication/squarings in the addition law
Let's look first at Cohen et al, 1998 formulae
```
| Addition - Cohen et al | Doubling any a - Cohen et al | Doubling = -3 | Doubling a = 0 |
| 12M + 4S + 6add + 1*2 | 3M + 6S + 1*a + 4add + 2*2 + 1*3 + 1*4 + 1*8 | | |
|------------------------------|----------------------------------------------|-----------------|----------------|
| Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
| Z₂Z₂ = Z₂² | | | |
| | | | |
| U₁ = X₁*Z₂Z₂ | | | |
| U₂ = X₂*Z₁Z₁ | | | |
| S₁ = Y₁*Z₂*Z₂Z₂ | | | |
| S₂ = Y₂*Z₁*Z₁Z₁ | | | |
| H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
| F = S₂-S₁ # Q=Inf | | | |
| | | | |
| HH = H² | YY = Y₁² | | |
| HHH = H*HH | M = 3*X₁²+a*ZZ² | 3(X₁-Z₁)(X₁+Z₁) | 3*X₁² |
| V = U₁*HH | S = 4*X₁*YY | | |
| | | | |
| X₃ = R²-HHH-2*V | X₃ = M²-2*S | | |
| Y₃ = R*(V-X₃)-S₁*HHH | Y₃ = M*(S-X₃)-8*YY² | | |
| Z₃ = Z₁*Z₂*H | Z₃ = 2*Y₁*Z₁ | | |
```
This is very promising, as the expensive multiplies and squares in doubling all have a corresponding sister operation in the addition.
Now for Bernstein et al 2007 formulae.
```
| Addition - Bernstein et al | Doubling any a - Bernstein et al | Doubling = -3 | Doubling a = 0 |
| 11M + 5S + 9add + 4*2 | 1M + 8S + 1*a + 10add + 2*2 + 1*3 + 1*8 | | |
|----------------------------------|-----------------------------------------------------|-----------------|----------------|
| Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
| Z₂Z₂ = Z₂² | | | |
| | | | |
| U₁ = X₁*Z₂Z₂ | | | |
| U₂ = X₂*Z₁Z₁ | | | |
| S₁ = Y₁*Z₂*Z₂Z₂ | | | |
| S₂ = Y₂*Z₁*Z₁Z₁ | | | |
| H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
| R = 2*(S₂-S₁) # Q=Inf | | | |
| | | | |
| | XX = X₁² (no matching op in addition, extra square) | | |
| | YYYY (no matching op in addition, extra 2 squares) | | |
| | | | |
| I = (2*H)² | YY = Y₁² | | |
| J = H*I | M = 3*X₁²+a*ZZ² | 3(X₁-Z₁)(X₁+Z₁) | 3*X₁² |
| V = U₁*I | S = 2*((X₁+YY)²-XX-YYYY) = 4*X₁*YY | | |
| | | | |
| X₃ = R²-J-2*V | X₃ = M²-2*S | | |
| Y₃ = R*(V-X₃)-2*S₁*J | Y₃ = M*(S-X₃)-8*YYYY | | |
| Z₃ = ((Z₁+Z₂)²-Z₁Z₁-Z₂Z₂)*H | Z₃ = (Y₁+Z₁)² - YY - ZZ = 2*Y₁*Z₁ | | |
```
Bernstein et al rewrite a multiplication into a squaring and 2 subtractions.
The first thing to note is that we can't use that trick to compute S in doubling
and keep doubling resembling addition as we have not computed XX or YYYY yet
and have no auspicious place to do so before.
The second thing to note is that in the addition they had to scale Z₃ by 2,
which scales X₃ by 4 and Y₃ by 8, leading to the doubled coefficients in I and R.
Ultimately, it saves 1 mul but it costs 1S 3A 3*2. Here are some benchmarks for reference
| Operation | Fp[BLS12_381] (cycles) | Fp2[BLS12_381] (cycles) | Fp4[BLS12_381] (cycles) |
|-----------|------------------------|-------------------------|-------------------------|
| Add | 14 | 24 | 47 |
| Sub | 12 | 24 | 46 |
| Ccopy | 5 | 10 | 20 |
| Div2 | 14 | 23 | 42 |
| Mul | 81 | 337 | 1229 |
| Sqr | 81 | 231 | 939 |
On G1 this is not good enough
On G2 it is still not good enough
On G4 (BLS24) or G8 (BLS48) we can revisit the decision.
Let's focus back on the Cohen formulae
```
| Addition - Cohen et al | Doubling any a - Cohen et al | Doubling = -3 | Doubling a = 0 |
| 12M + 4S + 6add + 1*2 | 3M + 6S + 1*a + 4add + 2*2 + 1*3 + 1*4 + 1*8 | | |
|------------------------------|----------------------------------------------|-----------------|----------------|
| Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
| Z₂Z₂ = Z₂² | | | |
| | | | |
| U₁ = X₁*Z₂Z₂ | | | |
| U₂ = X₂*Z₁Z₁ | | | |
| S₁ = Y₁*Z₂*Z₂Z₂ | | | |
| S₂ = Y₂*Z₁*Z₁Z₁ | | | |
| H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
| R = S₂-S₁ # Q=Inf | | | |
| | | | |
| HH = H² | YY = Y₁² | | |
| HHH = H*HH | M = 3*X₁²+a*ZZ² | 3(X₁-Z₁)(X₁+Z₁) | 3*X₁² |
| V = U₁*HH | S = 4*X₁*YY | | |
| | | | |
| X₃ = R²-HHH-2*V | X₃ = M²-2*S | | |
| Y₃ = R*(V-X₃)-S₁*HHH | Y₃ = M*(S-X₃)-8*YY² | | |
| Z₃ = Z₁*Z₂*H | Z₃ = 2*Y₁*Z₁ | | |
```
> Reminder: Jacobian coordinates are related to affine coordinate
> the following way (X, Y) <-> (X Z², Y Z³, Z)
The 2, 4, 8 coefficients in respectively `Z₃=2Y₁Z₁`, `S=4X₁YY` and `Y₃=M(S-X₃)-8YY²`
are not in line with the addition.
2 solutions:
- either we scale the addition Z₃ by 2, which will scale X₃ by 4 and Y₃ by 8 just like Bernstein et al.
- or we scale the doubling Z₃ by ½, which will scale X₃ by ¼ and Y₃ by ⅛. This is what Bos et al 2014 does for a=-3 curves.
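Spelling out the scaling behind those two options (just the Jacobian relation from the reminder above, applied to Z₃):

```
(X, Y, Z) and (t²·X, t³·Y, t·Z) describe the same affine point (X/Z², Y/Z³), so
  t = 2 :  Z₃ ← 2·Z₃  forces  X₃ ← 4·X₃  and  Y₃ ← 8·Y₃   (the Bernstein et al route)
  t = ½ :  Z₃ ← Z₃/2  forces  X₃ ← X₃/4  and  Y₃ ← Y₃/8   (the Bos et al route)
```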
We generalize their approach to all curves and obtain
```
| Addition (Cohen et al) | Doubling any a (adapted Bos et al, Cohen et al) | Doubling = -3 | Doubling a = 0 |
| 12M + 4S + 6add + 1*2 | 3M + 6S + 1*a + 4add + 1*2 + 1*3 + 1half | | |
| ----------------------------- | ----------------------------------------------- | ----------------- | -------------- |
| Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
| Z₂Z₂ = Z₂² | | | |
| | | | |
| U₁ = X₁*Z₂Z₂ | | | |
| U₂ = X₂*Z₁Z₁ | | | |
| S₁ = Y₁*Z₂*Z₂Z₂ | | | |
| S₂ = Y₂*Z₁*Z₁Z₁ | | | |
| H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
| R = S₂-S₁ # Q=Inf | | | |
| | | | |
| HH = H² | YY = Y₁² | | |
| HHH = H*HH | M = (3*X₁²+a*ZZ²)/2 | 3(X₁-Z₁)(X₁+Z₁)/2 | 3X₁²/2 |
| V = U₁*HH | S = X₁*YY | | |
| | | | |
| X₃ = R²-HHH-2*V | X₃ = M²-2*S | | |
| Y₃ = R*(V-X₃)-S₁*HHH | Y₃ = M*(S-X₃)-YY² | | |
| Z₃ = Z₁*Z₂*H | Z₃ = Y₁*Z₁ | | |
```
So we actually replaced 1 doubling, 1 quadrupling, 1 octupling by 1 halving, which has the same cost as doubling/addition.
We could use that for elliptic curves over Fp and Fp2.
For elliptic curves over Fp4 and Fp8 (BLS24 and BLS48) the gap between multiplication and squaring is large enough
that replacing a multiplication by a squaring + 2 subtractions and extra bookkeeping is worth it;
we could use this formula instead:
```
| Addition (adapted Bernstein et al) | Doubling any a (adapted Bernstein) | Doubling = -3 | Doubling a = 0 |
| 11M + 5S + 9add + 4*2 | 2M + 7S + 1*a + 7add + 2*2+1*3+1*4+1*8 | | |
| ---------------------------------- | ---------------------------------------- | --------------- | -------------- |
| Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
| Z₂Z₂ = Z₂² | | | |
| | | | |
| U₁ = X₁*Z₂Z₂ | | | |
| U₂ = X₂*Z₁Z₁ | | | |
| S₁ = Y₁*Z₂*Z₂Z₂ | | | |
| S₂ = Y₂*Z₁*Z₁Z₁ | | | |
| H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
| R = 2*(S₂-S₁) # Q=Inf | | | |
| | | | |
| I = (2*H)² | YY = Y₁² | | |
| J = H*I | M = 3*X₁²+a*ZZ² | 3(X₁-Z₁)(X₁+Z₁) | 3*X₁² |
| V = U₁*I | S = 4*X₁*YY | | |
| | | | |
| X₃ = R²-J-2*V | X₃ = M²-2*S | | |
| Y₃ = R*(V-X₃)-2*S₁*J | Y₃ = M*(S-X₃)-8*YY² | | |
| Z₃ = ((Z₁+Z₂)²-Z₁Z₁-Z₂Z₂)*H | Z₃ = (Y₁+Z₁)² - YY - ZZ | | |
```

View File

@ -98,7 +98,7 @@ func trySetFromCoordsXandZ*[F; G](
##
## Note: Dedicated robust procedures for hashing-to-curve
## will be provided, this is intended for testing purposes.
##
## For **test case generation only**,
## this is preferred to generating random point
## via random scalar multiplication of the curve generator
@ -130,7 +130,7 @@ func trySetFromCoordX*[F; G](
##
## Note: Dedicated robust procedures for hashing-to-curve
## will be provided, this is intended for testing purposes.
##
## For **test case generation only**,
## this is preferred to generating random point
## via random scalar multiplication of the curve generator
@ -516,7 +516,7 @@ func madd*[F; G: static Subgroup](
b.square(Z1Z1)
# b.mulCheckSparse(CoefA) # TODO: broken static compile-time type inference
b *= CoefA # b = αZZ, with α the "a" coefficient of the curve
a += b
a.div2()
R_or_M.ccopy(a, isDbl) # (3X₁² - αZZ)/2
@ -550,7 +550,7 @@ func madd*[F; G: static Subgroup](
o.x.ccopy(Q.x, P.isInf())
o.y.ccopy(Q.y, P.isInf())
o.z.csetOne(P.isInf())
o.ccopy(P, Q.isInf())
r = o
@ -655,291 +655,3 @@ func fromAffine*[F; G](
jac.x = aff.x
jac.y = aff.y
jac.z.setOne()
func batchAffine*[N: static int, F, G](
affs: var array[N, ECP_ShortW_Aff[F, G]],
jacs: array[N, ECP_ShortW_Jac[F, G]]) =
# Algorithm: Montgomery's batch inversion
# - Speeding the Pollard and Elliptic Curve Methods of Factorization
# Section 10.3.1
# Peter L. Montgomery
# https://www.ams.org/journals/mcom/1987-48-177/S0025-5718-1987-0866113-7/S0025-5718-1987-0866113-7.pdf
# - Modern Computer Arithmetic
# Section 2.5.1 Several inversions at once
# Richard P. Brent and Paul Zimmermann
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
# To avoid temporaries, we store partial accumulations
# in affs[i].x and whether z == 0 in affs[i].y
var zeroes: array[N, SecretBool]
affs[0].x = jacs[0].z
zeroes[0] = affs[0].x.isZero()
affs[0].x.csetOne(zeroes[0])
for i in 1 ..< N:
# Skip zero z-coordinates (infinity points)
var z = jacs[i].z
zeroes[i] = z.isZero()
z.csetOne(zeroes[i])
if i != N-1:
affs[i].x.prod(affs[i-1].x, z, skipFinalSub = true)
else:
affs[i].x.prod(affs[i-1].x, z, skipFinalSub = false)
var accInv {.noInit.}: F
accInv.inv(affs[N-1].x)
for i in countdown(N-1, 1):
# Extract 1/Pᵢ
var invi {.noInit.}: F
invi.prod(accInv, affs[i-1].x, skipFinalSub = true)
invi.csetZero(zeroes[i])
# Now convert Pᵢ to affine
var invi2 {.noinit.}: F
invi2.square(invi, skipFinalSub = true)
affs[i].x.prod(jacs[i].x, invi2)
invi.prod(invi, invi2, skipFinalSub = true)
affs[i].y.prod(jacs[i].y, invi)
# next iteration
invi = jacs[i].z
invi.csetOne(zeroes[i])
accInv.prod(accInv, invi, skipFinalSub = true)
block: # tail
var invi2 {.noinit.}: F
accInv.csetZero(zeroes[0])
invi2.square(accInv, skipFinalSub = true)
affs[0].x.prod(jacs[0].x, invi2)
accInv.prod(accInv, invi2, skipFinalSub = true)
affs[0].y.prod(jacs[0].y, accInv)
# ############################################################
# #
# Deriving efficient and complete Jacobian formulae #
# #
# ############################################################
#
# We are looking for a complete addition formula,
# that minimize overhead over classic addition formulae
# from the litterature
# and can handle all inputs.
#
# We recall the basic affine addition and doubling formulae
#
# ```
# P + Q = R
# (Px, Py) + (Qx, Qy) = (Rx, Ry)
#
# with
# Rx = λ² - Px - Qx
# Ry = λ(Px - Rx) - Py
# and
# λadd = (Qy - Py) / (Px - Qx)
# λdbl = (3 Px² + a) / (2 Px)
# ```
#
# Which is also called the "chord-and-tangent" group law.
# Notice that if Px == Qx, addition is undefined, this can happen in 2 cases
# - P == Q, in that case we need to double
# - P == -Q, since -(x,y) = (x,-y) for elliptic curves. In that case we need infinity.
#
# Concretely, that means that it is non-trivial to make the code constant-time
# whichever case we are.
# Furthermore, Renes et al 2015 which introduced complete addition formulae for projective coordinates
# demonstrated that such a law cannot be as efficient for the Jacobian coordinates we are interested in.
#
# Since we can't rely on math, we will need to rely on implementation "details" to achieve our goals.
# First we look back in history at Brier and Joye 2002 unified formulae which uses the same code for addition and doubling:
#
# ```
# λ = ((x₁+x₂)² - x₁x₂ + a)/(y₁+y₂)
# x₃ = λ² - (x₁+x₂)
# 2y₃= λ(x₁+x₂-2x₃) - (y₁+y₂)
# ```
#
# Alas we traded exceptions depending on the same coordinate x
# for exceptions on negated coordinate y.
# This can still happen for P=-Q but also for "unrelated" numbers.
# > We recall that curves with equation `y² = x³ + b` are chosen so that there exist a cubic root of unity modulo r the curve order.
# > Hence x³ ≡ 1 (mod r), we call that root ω
# > And so we have y² = (ωx)³ + b describing a valid point with coordinate (ωx, y)
# > Hence the unified formula cannot handle (x, y) + (ωx, -y)
# > All pairings curves and secp256k1 have that equation form.
#
# Now, all hope is not lost, we recall that unlike in math,
# in actual implementation we havean excellent tool called conditional copy
# so that we can ninja-swap our terms
# provided addition and doubling are resembling each other.
#
# So let's look at the current state of the art formulae.
# I have added the spots where we can detect special conditions like infinity points, doubling and negation,
# and reorganized doubling operations so that they match multiplication/squarings in the addition law
#
# Let's look first at Cohen et al, 1998 formulae
#
# ```
# | Addition - Cohen et al | Doubling any a - Cohen et al | Doubling = -3 | Doubling a = 0 |
# | 12M + 4S + 6add + 1*2 | 3M + 6S + 1*a + 4add + 2*2 + 1*3 + 1*4 + 1*8 | | |
# |------------------------------|----------------------------------------------|-----------------|-------|
# | Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
# | Z₂Z₂ = Z₂² | | | |
# | | | | |
# | U₁ = X₁*Z₂Z₂ | | | |
# | U₂ = X₂*Z₁Z₁ | | | |
# | S₁ = Y₁*Z₂*Z₂Z₂ | | | |
# | S₂ = Y₂*Z₁*Z₁Z₁ | | | |
# | H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
# | F = S₂-S₁ # Q=Inf | | | |
# | | | | |
# | HH = H² | YY = Y₁² | | |
# | HHH = H*HH | M = 3*X₁²+a*ZZ² | 3(X₁-Z₁)(X₁+Z₁) | 3*X₁² |
# | V = U₁*HH | S = 4*X₁*YY | | |
# | | | | |
# | X₃ = R²-HHH-2*V | X₃ = M²-2*S | | |
# | Y₃ = R*(V-X₃)-S₁*HHH | Y₃ = M*(S-X₃)-8*YY² | | |
# | Z₃ = Z₁*Z₂*H | Z₃ = 2*Y₁*Z₁ | | |
# ```
#
# This is very promising, as the expensive multiplies and squares n doubling all have a corresponding sister operation.
# Now for Bernstein et al 2007 formulae.
#
# ```
# | Addition - Bernstein et al | Doubling any a - Bernstein et al | Doubling = -3 | Doubling a = 0 |
# | 11M + 5S + 9add + 4*2 | 1M + 8S + 1*a + 10add + 2*2 + 1*3 + 1*8 | | |
# |----------------------------------|-----------------------------------------------------|-----------------|-------|
# | Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
# | Z₂Z₂ = Z₂² | | | |
# | | | | |
# | U₁ = X₁*Z₂Z₂ | | | |
# | U₂ = X₂*Z₁Z₁ | | | |
# | S₁ = Y₁*Z₂*Z₂Z₂ | | | |
# | S₂ = Y₂*Z₁*Z₁Z₁ | | | |
# | H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
# | R = 2*(S₂-S₁) # Q=Inf | | | |
# | | | | |
# | | XX = X₁² (no matching op in addition, extra square) | | |
# | | YYYY (no matching op in addition, extra 2 squares) | | |
# | | | | |
# | I = (2*H)² | YY = Y₁² | | |
# | J = H*I | M = 3*X₁²+a*ZZ² | 3(X₁-Z₁)(X₁+Z₁) | 3*X₁² |
# | V = U₁*I | S = 2*((X₁+YY)²-XX-YYYY) = 4*X₁*YY | | |
# | | | | |
# | X₃ = R²-J-2*V | X₃ = M²-2*S | | |
# | Y₃ = R*(V-X₃)-2*S₁*J | Y₃ = M*(S-X₃)-8*YYYY | | |
# | Z₃ = ((Z₁+Z₂)²-Z₁Z₁-Z₂Z₂)*H | Z₃ = (Y₁+Z₁)² - YY - ZZ = 2*Y₁*Z₁ | | |
# ```
#
# Bernstein et al rewrites multiplication into squaring and 2 substraction.
#
# The first thing to note is that we can't use that trick to compute S in doubling
# and keep doubling resembling addition as we have not computed XX or YYYY yet
# and have no auspicious place to do so before.
#
# The second thing to note is that in the addition, they had to scale Z₃ by 2
# which scaled X₃ by 4 and Y₃ by 8, leading to the doubling in I, r coefficients
#
# Ultimately, it saves 1 mul but it costs 1S 3A 3*2. Here are some benchmarks for reference
#
# | Operation | Fp[BLS12_381] (cycles) | Fp2[BLS12_381] (cycles) | Fp4[BLS12_381] (cycles) |
# |-----------|------------------------|-------------------------|-------------------------|
# | Add | 14 | 24 | 47 |
# | Sub | 12 | 24 | 46 |
# | Ccopy | 5 | 10 | 20 |
# | Div2 | 14 | 23 | 42 |
# | Mul | 81 | 337 | 1229 |
# | Sqr | 81 | 231 | 939 |
#
# On G1 this is not good enough
# On G2 it is still not good enough
# On G4 (BLS24) or G8 (BLS48) we can revisit the decision.
#
# Let's focus back to Cohen formulae
#
# ```
# | Addition - Cohen et al | Doubling any a - Cohen et al | Doubling = -3 | Doubling a = 0 |
# | 12M + 4S + 6add + 1*2 | 3M + 6S + 1*a + 4add + 2*2 + 1*3 + 1*4 + 1*8 | | |
# |------------------------------|----------------------------------------------|-----------------|-------|
# | Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
# | Z₂Z₂ = Z₂² | | | |
# | | | | |
# | U₁ = X₁*Z₂Z₂ | | | |
# | U₂ = X₂*Z₁Z₁ | | | |
# | S₁ = Y₁*Z₂*Z₂Z₂ | | | |
# | S₂ = Y₂*Z₁*Z₁Z₁ | | | |
# | H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
# | R = S₂-S₁ # Q=Inf | | | |
# | | | | |
# | HH = H² | YY = Y₁² | | |
# | HHH = H*HH | M = 3*X₁²+a*ZZ² | 3(X₁-Z₁)(X₁+Z₁) | 3*X₁² |
# | V = U₁*HH | S = 4*X₁*YY | | |
# | | | | |
# | X₃ = R²-HHH-2*V | X₃ = M²-2*S | | |
# | Y₃ = R*(V-X₃)-S₁*HHH | Y₃ = M*(S-X₃)-8*YY² | | |
# | Z₃ = Z₁*Z₂*H | Z₃ = 2*Y₁*Z₁ | | |
# ```
#
# > Reminder: Jacobian coordinates are related to affine coordinate
# > the following way (X, Y) <-> (X Z², Y Z³, Z)
#
# The 2, 4, 8 coefficients in respectively `Z₃=2Y₁Z₁`, `S=4X₁YY` and `Y₃=M(S-X₃)-8YY²`
# are not in line with the addition.
# 2 solutions:
# - either we scale the addition Z₃ by 2, which will scale X₃ by 4 and Y₃ by 8 just like Bernstein et al.
# - or we scale the doubling Z₃ by ½, which will scale X₃ by ¼ and Y₃ by ⅛. This is what Bos et al 2014 does for a=-3 curves.
#
# We generalize their approach to all curves and obtain
#
# ```
# | Addition (Cohen et al) | Doubling any a (adapted Bos et al, Cohen et al) | Doubling = -3 | Doubling a = 0 |
# | 12M + 4S + 6add + 1*2 | 3M + 6S + 1*a + 4add + 1*2 + 1*3 + 1half | | |
# | ----------------------------- | ----------------------------------------------- | ----------------- | -------------- |
# | Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
# | Z₂Z₂ = Z₂² | | | |
# | | | | |
# | U₁ = X₁*Z₂Z₂ | | | |
# | U₂ = X₂*Z₁Z₁ | | | |
# | S₁ = Y₁*Z₂*Z₂Z₂ | | | |
# | S₂ = Y₂*Z₁*Z₁Z₁ | | | |
# | H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
# | R = S₂-S₁ # Q=Inf | | | |
# | | | | |
# | HH = H² | YY = Y₁² | | |
# | HHH = H*HH | M = (3*X₁²+a*ZZ²)/2 | 3(X₁-Z₁)(X₁+Z₁)/2 | 3X₁²/2 |
# | V = U₁*HH | S = X₁*YY | | |
# | | | | |
# | X₃ = R²-HHH-2*V | X₃ = M²-2*S | | |
# | Y₃ = R*(V-X₃)-S₁*HHH | Y₃ = M*(S-X₃)-YY² | | |
# | Z₃ = Z₁*Z₂*H | Z₃ = Y₁*Z₁ | | |
# ```
#
# So we actually replaced 1 doubling, 1 quadrupling, 1 octupling by 1 halving, which has the same cost as doubling/addition.
# We could use that for elliptic curve over Fp and Fp2.
# For elliptic curve over Fp4 and Fp8 (BLS24 and BLS48) the gap between multiplication and square is large enough
# that replacing a multiplication by squaring + 2 substractions and extra bookkeeping is worth it,
# we could use this formula instead:
#
# ```
# | Addition (adapted Bernstein et al) | Doubling any a (adapted Bernstein) | Doubling = -3 | Doubling a = 0 |
# | 11M + 5S + 9add + 4*2 | 2M + 7S + 1*a + 7add + 2*2+1*3+1*4+1*8 | | |
# | ---------------------------------- | ---------------------------------------- | --------------- | -------------- |
# | Z₁Z₁ = Z₁² | Z₁Z₁ = Z₁² | | |
# | Z₂Z₂ = Z₂² | | | |
# | | | | |
# | U₁ = X₁*Z₂Z₂ | | | |
# | U₂ = X₂*Z₁Z₁ | | | |
# | S₁ = Y₁*Z₂*Z₂Z₂ | | | |
# | S₂ = Y₂*Z₁*Z₁Z₁ | | | |
# | H = U₂-U₁ # P=-Q, P=Inf, P=Q | | | |
# | R = 2*(S₂-S₁) # Q=Inf | | | |
# | | | | |
# | I = (2*H)² | YY = Y₁² | | |
# | J = H*I | M = 3*X₁²+a*ZZ² | 3(X₁-Z₁)(X₁+Z₁) | 3*X₁² |
# | V = U₁*I | S = 4*X₁*YY | | |
# | | | | |
# | X₃ = R²-J-2*V | X₃ = M²-2*S | | |
# | Y₃ = R*(V-X₃)-2*S₁*J | Y₃ = M*(S-X₃)-8*YY² | | |
# | Z₃ = ((Z₁+Z₂)²-Z₁Z₁-Z₂Z₂)*H | Z₃ = (Y₁+Z₁)² - YY - ZZ | | |
# ```

View File

@ -92,7 +92,7 @@ func trySetFromCoordsXandZ*[F; G](
##
## Note: Dedicated robust procedures for hashing-to-curve
## will be provided, this is intended for testing purposes.
##
## For **test case generation only**,
## this is preferred to generating random point
## via random scalar multiplication of the curve generator
@ -121,7 +121,7 @@ func trySetFromCoordX*[F; G](
##
## Note: Dedicated robust procedures for hashing-to-curve
## will be provided, this is intended for testing purposes.
##
## For **test case generation only**,
## this is preferred to generating random point
## via random scalar multiplication of the curve generator
@ -258,7 +258,7 @@ func madd*[F; G: static Subgroup](
## with p in Projective coordinates and Q in affine coordinates
##
## R = P + Q
##
## ``r`` may alias P
when F.C.getCoefA() == 0:
@ -275,7 +275,7 @@ func madd*[F; G: static Subgroup](
#
# Note¹⁰ mentions that due to Qz = 1, cannot be
# the point at infinity.
# We solve that by conditional copies.
t0.prod(P.x, Q.x) # 1. t₀ <- X₁ X₂
t1.prod(P.y, Q.y) # 2. t₁ <- Y₁ Y₂
t3.sum(P.x, P.y)  # 3. t₃ <- X₁ + Y₁ ! error in paper
@ -314,7 +314,7 @@ func madd*[F; G: static Subgroup](
t0 *= t3 # 31. t₀ <- t₀ t₃, t₀ = 3X₁X₂ (X₁Y₂ + X₂Y₁)
z3 *= t4 # 32. Z₃ <- Z₃ t₄, Z₃ = (Y₁Y₂ + 3bZ₁)(Y₁ + Y₂Z₁)
z3 += t0 # 33. Z₃ <- Z₃ + t₀, Z₃ = (Y₁ + Y₂Z₁)(Y₁Y₂ + 3bZ₁) + 3X₁X₂ (X₁Y₂ + X₂Y₁)

# Deal with infinity point. r and P might alias.
let inf = Q.isInf()
x3.ccopy(P.x, inf)
@ -441,57 +441,3 @@ func fromAffine*[F, G](
proj.x = aff.x
proj.y = aff.y
proj.z.setOne()
func batchAffine*[N: static int, F, G](
affs: var array[N, ECP_ShortW_Aff[F, G]],
projs: array[N, ECP_ShortW_Prj[F, G]]) =
# Algorithm: Montgomery's batch inversion
# - Speeding the Pollard and Elliptic Curve Methods of Factorization
# Section 10.3.1
# Peter L. Montgomery
# https://www.ams.org/journals/mcom/1987-48-177/S0025-5718-1987-0866113-7/S0025-5718-1987-0866113-7.pdf
# - Modern Computer Arithmetic
# Section 2.5.1 Several inversions at once
# Richard P. Brent and Paul Zimmermann
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
# To avoid temporaries, we store partial accumulations
# in affs[i].x
var zeroes: array[N, SecretBool]
affs[0].x = projs[0].z
zeroes[0] = affs[0].x.isZero()
affs[0].x.csetOne(zeroes[0])
for i in 1 ..< N:
# Skip zero z-coordinates (infinity points)
var z = projs[i].z
zeroes[i] = z.isZero()
z.csetOne(zeroes[i])
if i != N-1:
affs[i].x.prod(affs[i-1].x, z, skipFinalSub = true)
else:
affs[i].x.prod(affs[i-1].x, z, skipFinalSub = false)
var accInv {.noInit.}: F
accInv.inv(affs[N-1].x)
for i in countdown(N-1, 1):
# Extract 1/Pᵢ
var invi {.noInit.}: F
invi.prod(accInv, affs[i-1].x, skipFinalSub = true)
invi.csetZero(zeroes[i])
# Now convert Pᵢ to affine
affs[i].x.prod(projs[i].x, invi)
affs[i].y.prod(projs[i].y, invi)
# next iteration
invi = projs[i].z
invi.csetOne(zeroes[i])
accInv.prod(accInv, invi, skipFinalSub = true)
block: # tail
accInv.csetZero(zeroes[0])
affs[0].x.prod(projs[0].x, accInv)
affs[0].y.prod(projs[0].y, accInv)

View File

@ -74,7 +74,7 @@ template allocStackUnchecked*(T: typedesc, size: int): ptr T =
## Stack allocation for types containing a variable-sized UncheckedArray field
cast[ptr T](alloca(size))

-template allocStackArray*(T: typedesc, len: Natural): ptr UncheckedArray[T] =
template allocStackArray*(T: typedesc, len: SomeInteger): ptr UncheckedArray[T] =
cast[ptr UncheckedArray[T]](alloca(sizeof(T) * len))

# Heap allocation
@ -88,7 +88,7 @@ proc allocHeapUnchecked*(T: typedesc, size: int): ptr T {.inline.} =
cast[type result](malloc(size))

proc allocHeapArray*(T: typedesc, len: SomeInteger): ptr UncheckedArray[T] {.inline.} =
-cast[type result](malloc(len*sizeof(T)))
cast[type result](malloc(sizeof(T) * len))

proc freeHeap*(p: pointer) {.inline.} =
free(p)
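A minimal usage sketch of the helpers touched above (the element type and loop are mine; `allocStackArray` is analogous but alloca-backed, so its memory only lives for the current function call):

```nim
let buf = allocHeapArray(int, 4)   # typed view over malloc'ed memory
for i in 0 ..< 4:
  buf[i] = i * i                   # UncheckedArray indexing: no bounds checks
freeHeap(buf)
```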

View File

@ -85,6 +85,8 @@ proc notify*(en: var EventNotifier) {.inline.} =
type
Eventcount* = object
## The lock-free equivalent of a condition variable.
## Supports up to 256 threads on 32-bit.
## Supports up to 65536 threads on 64-bit.
##
## Usage, if a thread needs to be parked until a condition is true
## and signaled by another thread:
@ -101,18 +103,20 @@ type
## ec.sleep()
## ```
-state: Atomic[uint32]
state: Atomic[uint]
# State is actually the equivalent of a bitfield
# type State = object
-#   waiters {.bitsize: 16.}: uint16
-#   when sizeof(pointer) == 4:
-#     epoch {.bitsize: 16.}: uint16
-#   else:
-#     epoch {.bitsize: 48.}: uint48
-#
-# of size, the native integer size
-# and so can be used for atomic operations on 32-bit or 64-bit platforms.
-# but there is no native fetchAdd for bitfield
#   when sizeof(uint) == 8:
#     waiters {.bitsize: 16.}: uint16
#     preWaiters {.bitsize: 16.}: uint16
#     epoch {.bitsize: 32.}: uint32
#   else:
#     waiters {.bitsize: 8.}: uint8
#     preWaiters {.bitsize: 8.}: uint8
#     epoch {.bitsize: 16.}: uint16
#
# but there is no native fetchAdd for bitfields.
futex: Futex
# Technically we could use the futex as the state.
# When you wait on a Futex, it waits only if the value of the futex
@ -128,18 +132,42 @@ type
epoch: uint32 epoch: uint32
const # bitfield
-# Low 16 bits are waiters, up to 2¹⁶ = 65536 threads are supported
-# High 16 or 48 bits are epochs.
-# We can deal with the ABA problem o:
-# - up to 65536 wake requests on 32-bit
-# - up to 281 474 976 710 656 wake requests on 64-bit
-# Epoch rolling over to 0 are not a problem, they won't change the low 16 bits
-kEpochShift = 16
-kAddEpoch = 1 shl kEpochShift
-kWaiterMask = kAddEpoch - 1
-kEpochMask {.used.} = not kWaiterMask
-kAddWaiter = 1
-kSubWaiter = 1
# On 32-bit
# Low 8 bits are waiters, up to 2⁸ = 256 threads are supported
# Next 8 bits are pre-waiters, planning to wait but not committed.
# Next 16 bits is the epoch.
# The epoch deals with the ABA problem
# - up to 65536 wake requests on 32-bit
# Epoch rolling over to 0 are not a problem, they won't change the low 16 bits.
# On 64-bit
# Low 16 bits are waiters, up to 2¹⁶ = 65536 threads are supported
# Next 16 bits are pre-waiters, planning to wait but not committed.
# Next 32 bits is the epoch.
# The epoch deals with the ABA problem
# - up to 4 294 967 296 wake requests on 64-bit
# Epochs rolling over to 0 are not a problem, they won't change the low 16 bits.
#
# OS limitations:
# - Windows 10 supports up to 256 cores (https://www.microsoft.com/en-us/microsoft-365/blog/2017/12/15/windows-10-pro-workstations-power-advanced-workloads/)
# - Linux CPUSET supports up to 1024 threads (https://man7.org/linux/man-pages/man3/CPU_SET.3.html)
#
# Hardware limitations:
# - Xeon Platinum 9282, 56 cores - 112 threads
# - 8 sockets: 896 threads
scale = sizeof(uint) div 4 # 2 for 64-bit, 1 for 32-bit.
kEpochShift = 16'u * scale
kPreWaitShift = 8'u * scale
kEpoch = 1'u shl kEpochShift
kPreWait = 1'u shl kPreWaitShift
kWait = 1'u
kTransitionToWait = kWait - kPreWait
kWaitMask = kPreWait-1
kAnyWaiterMask = kEpoch-1
kPreWaitMask = kAnyWaiterMask xor kWaitMask # 0x0000FF00 on 32-bit
func initialize*(ec: var EventCount) {.inline.} = func initialize*(ec: var EventCount) {.inline.} =
ec.state.store(0, moRelaxed) ec.state.store(0, moRelaxed)
@ -151,36 +179,41 @@ func `=destroy`*(ec: var EventCount) {.inline.} =
proc sleepy*(ec: var Eventcount): ParkingTicket {.inline.} = proc sleepy*(ec: var Eventcount): ParkingTicket {.inline.} =
## To be called before checking if the condition to not sleep is met. ## To be called before checking if the condition to not sleep is met.
## Returns a ticket to be used when committing to sleep ## Returns a ticket to be used when committing to sleep
let prevState = ec.state.fetchAdd(kAddWaiter, moAcquireRelease) let prevState = ec.state.fetchAdd(kPreWait, moAcquireRelease)
result.epoch = prevState shr kEpochShift result.epoch = uint32(prevState shr kEpochShift)
proc sleep*(ec: var Eventcount, ticket: ParkingTicket) {.inline.} = proc sleep*(ec: var Eventcount, ticket: ParkingTicket) {.inline.} =
## Put a thread to sleep until notified. ## Put a thread to sleep until notified.
## If the ticket becomes invalid (a notification has been received) ## If the ticket becomes invalid (a notification has been received)
## by the time sleep is called, the thread won't enter sleep ## by the time sleep is called, the thread won't enter sleep
discard ec.state.fetchAdd(kTransitionToWait, moAcquireRelease)
while ec.state.load(moAcquire) shr kEpochShift == ticket.epoch: while ec.state.load(moAcquire) shr kEpochShift == ticket.epoch:
ec.futex.wait(ticket.epoch) # We don't use the futex internal value ec.futex.wait(ticket.epoch) # We don't use the futex internal value
let prev {.used.} = ec.state.fetchSub(kSubWaiter, moRelaxed) let prev {.used.} = ec.state.fetchSub(kWait, moRelaxed)
proc cancelSleep*(ec: var Eventcount) {.inline.} = proc cancelSleep*(ec: var Eventcount) {.inline.} =
## Cancel a sleep that was scheduled. ## Cancel a sleep that was scheduled.
let prev {.used.} = ec.state.fetchSub(kSubWaiter, moRelaxed) let prev {.used.} = ec.state.fetchSub(kPreWait, moRelaxed)
proc wake*(ec: var EventCount) {.inline.} = proc wake*(ec: var EventCount) {.inline.} =
## Wake a thread if at least 1 is parked ## Wake a thread if at least 1 is parked
let prev = ec.state.fetchAdd(kAddEpoch, moAcquireRelease) let prev = ec.state.fetchAdd(kEpoch, moAcquireRelease)
if (prev and kWaiterMask) != 0: if (prev and kAnyWaiterMask) != 0:
ec.futex.wake() ec.futex.wake()
proc wakeAll*(ec: var EventCount) {.inline.} = proc wakeAll*(ec: var EventCount) {.inline.} =
## Wake all threads if at least 1 is parked ## Wake all threads if at least 1 is parked
let prev = ec.state.fetchAdd(kAddEpoch, moAcquireRelease) let prev = ec.state.fetchAdd(kEpoch, moAcquireRelease)
if (prev and kWaiterMask) != 0: if (prev and kAnyWaiterMask) != 0:
ec.futex.wakeAll() ec.futex.wakeAll()
proc getNumWaiters*(ec: var EventCount): uint32 {.inline.} = proc getNumWaiters*(ec: var EventCount): tuple[preSleep, committedSleep: uint32] {.inline.} =
## Get the number of parked threads ## Get the number of idle threads:
ec.state.load(moRelaxed) and kWaiterMask ## (planningToSleep, committedToSleep)
let waiters = ec.state.load(moAcquire)
result.preSleep = uint32((waiters and kPreWaitMask) shr kPreWaitShift)
result.committedSleep = uint32(waiters and kWaitMask)
{.pop.} # {.push raises:[], checks:off.} {.pop.} # {.push raises:[], checks:off.}
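
To make the pre-wait/wait split concrete, here is an illustrative consumer/producer protocol built on the API above. The shared `hasWork` flag and the proc names are invented for the example (the real callers are the threadpool's eventLoop and schedule); `Atomic`, `moAcquire` and `moRelease` are the same atomics used throughout this file.

proc consumerLoop(ec: var Eventcount, hasWork: var Atomic[bool]) =
  while true:
    let ticket = ec.sleepy()           # register as pre-waiter, snapshot the epoch
    if hasWork.load(moAcquire):
      ec.cancelSleep()                 # condition already true: undo the pre-wait
      # ... process the work ...
    else:
      ec.sleep(ticket)                 # commit to waiting; returns immediately if
                                       # the epoch moved (a wake raced with us)

proc producer(ec: var Eventcount, hasWork: var Atomic[bool]) =
  hasWork.store(true, moRelease)
  ec.wake()                            # bump the epoch, wake one parked thread if any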

View File

@ -76,8 +76,8 @@ proc peek*(tq: var Taskqueue): int =
## ##
## This is a non-locking operation. ## This is a non-locking operation.
let # Handle race conditions let # Handle race conditions
b = tq.back.load(moRelaxed) b = tq.back.load(moAcquire)
f = tq.front.load(moRelaxed) f = tq.front.load(moAcquire)
if b >= f: if b >= f:
return b-f return b-f
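
The relaxed→acquire upgrade matters because `peek` is what decides whether a producer must pay for a wake-up; the scheduler call site later in this diff follows this shape:

let wasEmpty = ctx.taskqueue[].peek() == 0   # peek now uses acquire loads on front/back
ctx.taskqueue[].push(tn)
if forceWake or wasEmpty:
  ctx.threadpool.globalBackoff.wake()        # only wake a sleeper when the queue was empty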

View File

@ -41,7 +41,7 @@ type
# Execution # Execution
# ------------------ # ------------------
fn*: proc (param: pointer) {.nimcall, gcsafe.} fn*: proc (param: pointer) {.nimcall, gcsafe, raises: [].}
# destroy*: proc (param: pointer) {.nimcall, gcsafe.} # Constantine only deals with plain old data # destroy*: proc (param: pointer) {.nimcall, gcsafe.} # Constantine only deals with plain old data
data*{.align:sizeof(int).}: UncheckedArray[byte] data*{.align:sizeof(int).}: UncheckedArray[byte]
@ -60,6 +60,7 @@ proc new*(
result = allocHeapUnchecked(T, size) result = allocHeapUnchecked(T, size)
result.parent = parent result.parent = parent
result.thiefID.store(SentinelThief, moRelaxed) result.thiefID.store(SentinelThief, moRelaxed)
result.hasFuture = false
result.completed.store(false, moRelaxed) result.completed.store(false, moRelaxed)
result.waiter.store(nil, moRelaxed) result.waiter.store(nil, moRelaxed)
result.fn = fn result.fn = fn
@ -67,7 +68,7 @@ proc new*(
proc new*( proc new*(
T: typedesc[Task], T: typedesc[Task],
parent: ptr Task, parent: ptr Task,
fn: proc (param: pointer) {.nimcall, gcsafe.}, fn: proc (param: pointer) {.nimcall, gcsafe, raises: [].},
params: auto): ptr Task {.inline.} = params: auto): ptr Task {.inline.} =
const size = sizeof(T) + # size without Unchecked const size = sizeof(T) + # size without Unchecked
@ -76,6 +77,7 @@ proc new*(
result = allocHeapUnchecked(T, size) result = allocHeapUnchecked(T, size)
result.parent = parent result.parent = parent
result.thiefID.store(SentinelThief, moRelaxed) result.thiefID.store(SentinelThief, moRelaxed)
result.hasFuture = false
result.completed.store(false, moRelaxed) result.completed.store(false, moRelaxed)
result.waiter.store(nil, moRelaxed) result.waiter.store(nil, moRelaxed)
result.fn = fn result.fn = fn
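
With the field type tightened to `raises: []`, any function stored in `Task.fn` must prove it cannot raise; a minimal conforming task body (names illustrative) looks like this. The test change below, wrapping `stdout.write` in try/except, exists for the same reason: IO can raise, so the body has to absorb it.

proc incrementInPlace(param: pointer) {.nimcall, gcsafe, raises: [].} =
  # A task body receives its argument buffer as an untyped pointer
  # and must handle (or avoid) every exception itself.
  let x = cast[ptr int](param)
  x[] += 1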

View File

@ -2,9 +2,12 @@ import ../threadpool
block: # Async without result block: # Async without result
proc displayInt(x: int) = proc displayInt(x: int) {.raises: [].} =
stdout.write(x) try:
stdout.write(" - SUCCESS\n") stdout.write(x)
stdout.write(" - SUCCESS\n")
except:
quit 1
proc main() = proc main() =
echo "\n==============================================================================================" echo "\n=============================================================================================="

View File

@ -8,7 +8,6 @@
import import
std/macros, std/macros,
./instrumentation,
./crossthread/tasks_flowvars ./crossthread/tasks_flowvars
# Task parallelism - spawn # Task parallelism - spawn
@ -43,12 +42,10 @@ proc spawnVoid(funcCall: NimNode, args, argsTy: NimNode, workerContext, schedule
# Create the async call # Create the async call
result.add quote do: result.add quote do:
proc `async_fn`(param: pointer) {.nimcall.} = proc `async_fn`(param: pointer) {.nimcall.} =
# preCondition: not isRootTask(`workerContext`.currentTask)
when bool(`withArgs`): when bool(`withArgs`):
let `data` = cast[ptr `argsTy`](param) let `data` = cast[ptr `argsTy`](param)
`fnCall` `fnCall`
# Create the task # Create the task
result.add quote do: result.add quote do:
block enq_deq_task: block enq_deq_task:
@ -110,8 +107,6 @@ proc spawnRet(funcCall: NimNode, retTy, args, argsTy: NimNode, workerContext, sc
result.add quote do: result.add quote do:
proc `async_fn`(param: pointer) {.nimcall.} = proc `async_fn`(param: pointer) {.nimcall.} =
# preCondition: not isRootTask(`workerContext`.currentTask)
let `data` = cast[ptr `futArgsTy`](param) let `data` = cast[ptr `futArgsTy`](param)
let res = `fnCall` let res = `fnCall`
readyWith(`data`[0], res) readyWith(`data`[0], res)
@ -136,7 +131,7 @@ proc spawnRet(funcCall: NimNode, retTy, args, argsTy: NimNode, workerContext, sc
proc spawnImpl*(tp: NimNode{nkSym}, funcCall: NimNode, workerContext, schedule: NimNode): NimNode = proc spawnImpl*(tp: NimNode{nkSym}, funcCall: NimNode, workerContext, schedule: NimNode): NimNode =
funcCall.expectKind(nnkCall) funcCall.expectKind(nnkCall)
# Get the return type if any # Get the return type if any
let retType = funcCall[0].getImpl[3][0] let retType = funcCall[0].getImpl[3][0]
let needFuture = retType.kind != nnkEmpty let needFuture = retType.kind != nnkEmpty
@ -157,4 +152,3 @@ proc spawnImpl*(tp: NimNode{nkSym}, funcCall: NimNode, workerContext, schedule:
# Wrap in a block for namespacing # Wrap in a block for namespacing
result = nnkBlockStmt.newTree(newEmptyNode(), result) result = nnkBlockStmt.newTree(newEmptyNode(), result)
# echo result.toStrLit
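
Roughly, for a call such as `tp.spawn f(a, b)`, the `spawnVoid` path above generates a thunk of the following shape before allocating and scheduling the task. This is a simplified, hand-expanded illustration (the real code comes from the quote-do templates, `Task.new` and `schedule`); `f` and `ArgsTy` stand in for the user proc and the captured `argsTy`.

proc f(x, y: int) {.raises: [].} = discard x + y   # example user proc

type ArgsTy = (int, int)                           # the captured argument tuple type

proc async_fn(param: pointer) {.nimcall, gcsafe, raises: [].} =
  let data = cast[ptr ArgsTy](param)               # recover the argument tuple
  let (x, y) = data[]
  f(x, y)                                          # the wrapped fnCall

# ...then, inside `block enq_deq_task`:
#   let task = Task.new(parent = ctx.currentTask, fn = async_fn, params = (a, b))
#   schedule(ctx, task)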

View File

@ -45,9 +45,9 @@ func pthread_barrier_init*(
barrier.lock.initLock() barrier.lock.initLock()
{.locks: [barrier.lock].}: {.locks: [barrier.lock].}:
barrier.cond.initCond() barrier.cond.initCond()
barrier.sense = false
barrier.left = count barrier.left = count
barrier.count = count barrier.count = count
# barrier.sense = false
proc pthread_barrier_wait*(barrier: var PthreadBarrier): Errno = proc pthread_barrier_wait*(barrier: var PthreadBarrier): Errno =
## Wait on `barrier` ## Wait on `barrier`

View File

@ -109,9 +109,9 @@ proc teardownWorker() =
workerContext.localBackoff.`=destroy`() workerContext.localBackoff.`=destroy`()
workerContext.taskqueue[].teardown() workerContext.taskqueue[].teardown()
proc eventLoop(ctx: var WorkerContext) {.raises:[Exception].} proc eventLoop(ctx: var WorkerContext) {.raises:[].}
proc workerEntryFn(params: tuple[threadpool: Threadpool, id: WorkerID]) {.raises: [Exception].} = proc workerEntryFn(params: tuple[threadpool: Threadpool, id: WorkerID]) {.raises: [].} =
## On the start of the threadpool workers will execute this ## On the start of the threadpool workers will execute this
## until they receive a termination signal ## until they receive a termination signal
# We assume that thread_local variables start all at their binary zero value # We assume that thread_local variables start all at their binary zero value
@ -146,7 +146,7 @@ proc workerEntryFn(params: tuple[threadpool: Threadpool, id: WorkerID]) {.raises
const ReadyFuture = cast[ptr EventNotifier](0xCA11AB1E) const ReadyFuture = cast[ptr EventNotifier](0xCA11AB1E)
const RootTask = cast[ptr Task](0xEFFACED0) const RootTask = cast[ptr Task](0xEFFACED0)
proc run*(ctx: var WorkerContext, task: ptr Task) {.raises:[Exception].} = proc run*(ctx: var WorkerContext, task: ptr Task) {.raises:[].} =
## Run a task and free it if it is not owned by a Flowvar ## Run a task and free it if it is not owned by a Flowvar
let suspendedTask = workerContext.currentTask let suspendedTask = workerContext.currentTask
ctx.currentTask = task ctx.currentTask = task
@ -178,7 +178,6 @@ proc schedule(ctx: var WorkerContext, tn: ptr Task, forceWake = false) {.inline.
# Lazy binary-splitting: a run-time adaptive work-stealing scheduler. # Lazy binary-splitting: a run-time adaptive work-stealing scheduler.
# In PPoPP '10, Bangalore, India, January 2010. ACM, pp. 179–190. # In PPoPP '10, Bangalore, India, January 2010. ACM, pp. 179–190.
# https://user.eng.umd.edu/~barua/ppopp164.pdf # https://user.eng.umd.edu/~barua/ppopp164.pdf
let wasEmpty = ctx.taskqueue[].peek() == 0 let wasEmpty = ctx.taskqueue[].peek() == 0
ctx.taskqueue[].push(tn) ctx.taskqueue[].push(tn)
if forceWake or wasEmpty: if forceWake or wasEmpty:
@ -326,9 +325,9 @@ proc tryLeapfrog(ctx: var WorkerContext, awaitedTask: ptr Task): ptr Task =
return leapTask return leapTask
return nil return nil
proc eventLoop(ctx: var WorkerContext) {.raises:[Exception].} = proc eventLoop(ctx: var WorkerContext) {.raises:[].} =
## Each worker thread executes this loop over and over. ## Each worker thread executes this loop over and over.
while not ctx.signal.terminate.load(moRelaxed): while true:
# 1. Pick from local queue # 1. Pick from local queue
debug: log("Worker %2d: eventLoop 1 - searching task from local queue\n", ctx.id) debug: log("Worker %2d: eventLoop 1 - searching task from local queue\n", ctx.id)
while (var task = ctx.taskqueue[].pop(); not task.isNil): while (var task = ctx.taskqueue[].pop(); not task.isNil):
@ -338,13 +337,17 @@ proc eventLoop(ctx: var WorkerContext) {.raises:[Exception].} =
# 2. Run out of tasks, become a thief # 2. Run out of tasks, become a thief
debug: log("Worker %2d: eventLoop 2 - becoming a thief\n", ctx.id) debug: log("Worker %2d: eventLoop 2 - becoming a thief\n", ctx.id)
let ticket = ctx.threadpool.globalBackoff.sleepy() let ticket = ctx.threadpool.globalBackoff.sleepy()
var stolenTask = ctx.tryStealAdaptative() if (var stolenTask = ctx.tryStealAdaptative(); not stolenTask.isNil):
if not stolenTask.isNil:
# We manage to steal a task, cancel sleep # We manage to steal a task, cancel sleep
ctx.threadpool.globalBackoff.cancelSleep() ctx.threadpool.globalBackoff.cancelSleep()
# 2.a Run task # 2.a Run task
debug: log("Worker %2d: eventLoop 2.a - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask) debug: log("Worker %2d: eventLoop 2.a - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask)
ctx.run(stolenTask) ctx.run(stolenTask)
elif ctx.signal.terminate.load(moAcquire):
# 2.b Threadpool has no more tasks and we were signaled to terminate
ctx.threadpool.globalBackoff.cancelSleep()
debugTermination: log("Worker %2d: eventLoop 2.b - terminated\n", ctx.id)
break
else: else:
# 2.b Park the thread until a new task enters the threadpool # 2.b Park the thread until a new task enters the threadpool
debug: log("Worker %2d: eventLoop 2.b - sleeping\n", ctx.id) debug: log("Worker %2d: eventLoop 2.b - sleeping\n", ctx.id)
@ -357,9 +360,8 @@ proc eventLoop(ctx: var WorkerContext) {.raises:[Exception].} =
template isRootTask(task: ptr Task): bool = template isRootTask(task: ptr Task): bool =
task == RootTask task == RootTask
proc completeFuture*[T](fv: Flowvar[T], parentResult: var T) {.raises:[Exception].} = proc completeFuture*[T](fv: Flowvar[T], parentResult: var T) {.raises:[].} =
## Eagerly complete an awaited FlowVar ## Eagerly complete an awaited FlowVar
template ctx: untyped = workerContext template ctx: untyped = workerContext
template isFutReady(): untyped = template isFutReady(): untyped =
@ -422,7 +424,7 @@ proc completeFuture*[T](fv: Flowvar[T], parentResult: var T) {.raises:[Exception
if compareExchange(fv.task.waiter, expected, desired = ctx.localBackoff.addr, moAcquireRelease): if compareExchange(fv.task.waiter, expected, desired = ctx.localBackoff.addr, moAcquireRelease):
ctx.localBackoff.park() ctx.localBackoff.park()
proc syncAll*(tp: Threadpool) {.raises: [Exception].} = proc syncAll*(tp: Threadpool) {.raises: [].} =
## Blocks until all pending tasks are completed ## Blocks until all pending tasks are completed
## This MUST only be called from ## This MUST only be called from
## the root scope that created the threadpool ## the root scope that created the threadpool
@ -435,35 +437,31 @@ proc syncAll*(tp: Threadpool) {.raises: [Exception].} =
preCondition: ctx.currentTask.isRootTask() preCondition: ctx.currentTask.isRootTask()
# Empty all tasks # Empty all tasks
var foreignThreadsParked = false tp.globalBackoff.wakeAll()
while not foreignThreadsParked:
while true:
# 1. Empty local tasks # 1. Empty local tasks
debug: log("Worker %2d: syncAll 1 - searching task from local queue\n", ctx.id) debug: log("Worker %2d: syncAll 1 - searching task from local queue\n", ctx.id)
while (let task = ctx.taskqueue[].pop(); not task.isNil): while (let task = ctx.taskqueue[].pop(); not task.isNil):
debug: log("Worker %2d: syncAll 1 - running task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, task, task.parent, ctx.currentTask) debug: log("Worker %2d: syncAll 1 - running task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, task, task.parent, ctx.currentTask)
ctx.run(task) ctx.run(task)
if tp.numThreads == 1 or foreignThreadsParked: if tp.numThreads == 1:
break break
# 2. Help other threads # 2. Help other threads
debug: log("Worker %2d: syncAll 2 - becoming a thief\n", ctx.id) debug: log("Worker %2d: syncAll 2 - becoming a thief\n", ctx.id)
let stolenTask = ctx.tryStealAdaptative() if (var stolenTask = ctx.tryStealAdaptative(); not stolenTask.isNil):
# 2.a We stole some task
if not stolenTask.isNil: debug: log("Worker %2d: syncAll 2.a - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask)
# 2.1 We stole some task
debug: log("Worker %2d: syncAll 2.1 - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask)
ctx.run(stolenTask) ctx.run(stolenTask)
elif tp.globalBackoff.getNumWaiters() == (0'u32, tp.numThreads - 1):
# 2.b all threads besides the current are parked
debugTermination: log("Worker %2d: syncAll 2.b - termination, all other threads sleeping\n", ctx.id)
break
else: else:
# 2.2 No task to steal # 2.c We don't park as there is no notif for task completion
if tp.globalBackoff.getNumWaiters() == tp.numThreads - 1: cpuRelax()
# 2.2.1 all threads besides the current are parked
debugTermination:
log("Worker %2d: syncAll 2.2.1 - termination, all other threads sleeping\n", ctx.id)
foreignThreadsParked = true
else:
# 2.2.2 We don't park as there is no notif for task completion
cpuRelax()
debugTermination: debugTermination:
log(">>> Worker %2d leaves barrier <<<\n", ctx.id) log(">>> Worker %2d leaves barrier <<<\n", ctx.id)
@ -471,7 +469,7 @@ proc syncAll*(tp: Threadpool) {.raises: [Exception].} =
# Runtime # Runtime
# --------------------------------------------- # ---------------------------------------------
proc new*(T: type Threadpool, numThreads = countProcessors()): T {.raises: [Exception].} = proc new*(T: type Threadpool, numThreads = countProcessors()): T {.raises: [ResourceExhaustedError].} =
## Initialize a threadpool that manages `numThreads` threads. ## Initialize a threadpool that manages `numThreads` threads.
## Default to the number of logical processors available. ## Default to the number of logical processors available.
@ -503,7 +501,7 @@ proc new*(T: type Threadpool, numThreads = countProcessors()): T {.raises: [Exce
discard tp.barrier.wait() discard tp.barrier.wait()
return tp return tp
proc cleanup(tp: var Threadpool) {.raises: [OSError].} = proc cleanup(tp: var Threadpool) {.raises: [].} =
## Cleanup all resources allocated by the threadpool ## Cleanup all resources allocated by the threadpool
preCondition: workerContext.currentTask.isRootTask() preCondition: workerContext.currentTask.isRootTask()
@ -518,14 +516,14 @@ proc cleanup(tp: var Threadpool) {.raises: [OSError].} =
tp.freeHeapAligned() tp.freeHeapAligned()
proc shutdown*(tp: var Threadpool) {.raises:[Exception].} = proc shutdown*(tp: var Threadpool) {.raises:[].} =
## Wait until all tasks are processed and then shutdown the threadpool ## Wait until all tasks are processed and then shutdown the threadpool
preCondition: workerContext.currentTask.isRootTask() preCondition: workerContext.currentTask.isRootTask()
tp.syncAll() tp.syncAll()
# Signal termination to all threads # Signal termination to all threads
for i in 0 ..< tp.numThreads: for i in 0 ..< tp.numThreads:
tp.workerSignals[i].terminate.store(true, moRelaxed) tp.workerSignals[i].terminate.store(true, moRelease)
tp.globalBackoff.wakeAll() tp.globalBackoff.wakeAll()
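
Putting the new termination protocol together, a typical lifecycle from the root thread looks like this. `tp.spawn` is assumed to be the macro front-end of `spawnImpl` above, and `noop` is an illustrative task:

proc noop(i: int) {.raises: [].} = discard i

proc main() =
  var tp = Threadpool.new()       # may raise ResourceExhaustedError
  for i in 0 ..< 1000:
    tp.spawn noop(i)              # first push wakes parked workers via globalBackoff
  tp.syncAll()                    # root runs and steals until no task is left
                                  # and every other worker is parked
  tp.shutdown()                   # syncAll, then terminate (release) + wakeAll;
                                  # workers see terminate with acquire in eventLoop

main()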

View File

@ -14,8 +14,7 @@ import
# Test utilities # Test utilities
./t_ec_template ./t_ec_template
const const numPoints = [1, 2, 8, 16, 128, 1024, 2048, 16384, 32768] # 262144, 1048576]
numPoints = [1, 2, 8, 16, 128, 1024, 2048, 16384, 32768] # 262144, 1048576]
run_EC_batch_add_impl( run_EC_batch_add_impl(
ec = ECP_ShortW_Jac[Fp[BN254_Snarks], G1], ec = ECP_ShortW_Jac[Fp[BN254_Snarks], G1],

View File

@ -14,8 +14,7 @@ import
# Test utilities # Test utilities
./t_ec_template ./t_ec_template
const const numPoints = [1, 2, 8, 16, 128, 1024, 2048, 16384, 32768] # 262144, 1048576]
numPoints = [1, 2, 8, 16, 128, 1024, 2048, 16384, 32768] # 262144, 1048576]
run_EC_batch_add_impl( run_EC_batch_add_impl(
ec = ECP_ShortW_Prj[Fp[BN254_Snarks], G1], ec = ECP_ShortW_Prj[Fp[BN254_Snarks], G1],

View File

@ -0,0 +1,29 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
# Internals
../../constantine/math/config/curves,
../../constantine/math/elliptic/ec_shortweierstrass_jacobian,
../../constantine/math/arithmetic,
# Test utilities
./t_ec_template_parallel
const numPoints = [1, 2, 8, 16, 128, 1024, 2048, 16384, 32768] # 262144, 1048576]
run_EC_batch_add_parallel_impl(
ec = ECP_ShortW_Jac[Fp[BN254_Snarks], G1],
numPoints = numPoints,
moduleName = "test_ec_shortweierstrass_jacobian_batch_add_parallel_" & $BN254_Snarks
)
run_EC_batch_add_parallel_impl(
ec = ECP_ShortW_Jac[Fp[BLS12_381], G1],
numPoints = numPoints,
moduleName = "test_ec_shortweierstrass_jacobian_batch_add_parallel_" & $BLS12_381
)

View File

@ -0,0 +1,29 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
# Internals
../../constantine/math/config/curves,
../../constantine/math/elliptic/ec_shortweierstrass_projective,
../../constantine/math/arithmetic,
# Test utilities
./t_ec_template_parallel
const numPoints = [1, 2, 8, 16, 128, 1024, 2048, 16384, 32768] # 262144, 1048576]
run_EC_batch_add_parallel_impl(
ec = ECP_ShortW_Prj[Fp[BN254_Snarks], G1],
numPoints = numPoints,
moduleName = "test_ec_shortweierstrass_projective_batch_add_parallel_" & $BN254_Snarks
)
run_EC_batch_add_parallel_impl(
ec = ECP_ShortW_Prj[Fp[BLS12_381], G1],
numPoints = numPoints,
moduleName = "test_ec_shortweierstrass_projective_batch_add_parallel_" & $BLS12_381
)

View File

@ -0,0 +1,146 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
# ############################################################
#
# Template tests for elliptic curve operations
#
# ############################################################
import
# Standard library
std/[unittest, times],
# Internals
../../constantine/platforms/abstractions,
../../constantine/math/[arithmetic, extension_fields],
../../constantine/math/elliptic/[
ec_shortweierstrass_affine,
ec_shortweierstrass_jacobian,
ec_shortweierstrass_projective,
ec_shortweierstrass_batch_ops_parallel],
../../constantine/platforms/threadpool/threadpool,
# Test utilities
../../helpers/prng_unsafe
export unittest, abstractions, arithmetic # Generic sandwich
type
RandomGen* = enum
Uniform
HighHammingWeight
Long01Sequence
func random_point*(rng: var RngState, EC: typedesc, randZ: bool, gen: RandomGen): EC {.noInit.} =
when EC is ECP_ShortW_Aff:
if gen == Uniform:
result = rng.random_unsafe(EC)
elif gen == HighHammingWeight:
result = rng.random_highHammingWeight(EC)
else:
result = rng.random_long01Seq(EC)
else:
if not randZ:
if gen == Uniform:
result = rng.random_unsafe(EC)
elif gen == HighHammingWeight:
result = rng.random_highHammingWeight(EC)
else:
result = rng.random_long01Seq(EC)
else:
if gen == Uniform:
result = rng.random_unsafe_with_randZ(EC)
elif gen == HighHammingWeight:
result = rng.random_highHammingWeight_with_randZ(EC)
else:
result = rng.random_long01Seq_with_randZ(EC)
proc run_EC_batch_add_parallel_impl*[N: static int](
ec: typedesc,
numPoints: array[N, int],
moduleName: string
) =
# Random seed for reproducibility
var rng: RngState
let seed = 1674654772 # uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
rng.seed(seed)
echo "\n------------------------------------------------------\n"
echo moduleName, " xoshiro512** seed: ", seed
when ec.G == G1:
const G1_or_G2 = "G1"
else:
const G1_or_G2 = "G2"
const testSuiteDesc = "Elliptic curve parallel batch addition for Short Weierstrass form"
suite testSuiteDesc & " - " & $ec & " - [" & $WordBitWidth & "-bit mode]":
for n in numPoints:
test $ec & " batch addition (N=" & $n & ")":
proc test(EC: typedesc, gen: RandomGen) =
var tp = Threadpool.new()
defer: tp.shutdown()
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](n)
for i in 0 ..< n:
points[i] = rng.random_point(ECP_ShortW_Aff[EC.F, EC.G], randZ = false, gen)
var r_batch{.noinit.}, r_ref{.noInit.}: EC
r_ref.setInf()
for i in 0 ..< n:
r_ref += points[i]
tp.sum_batch_vartime_parallel(r_batch, points)
check: bool(r_batch == r_ref)
test(ec, gen = Uniform)
test(ec, gen = HighHammingWeight)
test(ec, gen = Long01Sequence)
test "EC " & G1_or_G2 & " batch addition (N=" & $n & ") - special cases":
proc test(EC: typedesc, gen: RandomGen) =
var tp = Threadpool.new()
defer: tp.shutdown()
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](n)
let halfN = n div 2
for i in 0 ..< halfN:
points[i] = rng.random_point(ECP_ShortW_Aff[EC.F, EC.G], randZ = false, gen)
for i in halfN ..< n:
# The special cases test relies on internal knowledge that we sum(points[i], points[i+n/2])
# It should be changed if the scheduling changes, for example if we sum(points[2*i], points[2*i+1])
let c = rng.random_unsafe(3)
if c == 0:
points[i] = rng.random_point(ECP_ShortW_Aff[EC.F, EC.G], randZ = false, gen)
elif c == 1:
points[i] = points[i-halfN]
else:
points[i].neg(points[i-halfN])
var r_batch{.noinit.}, r_ref{.noInit.}: EC
r_ref.setInf()
for i in 0 ..< n:
r_ref += points[i]
tp.sum_batch_vartime_parallel(r_batch, points)
check: bool(r_batch == r_ref)
test(ec, gen = Uniform)
test(ec, gen = HighHammingWeight)
test(ec, gen = Long01Sequence)