use SharedBuf to facilitate safe cross-GC boundary task spawning

This commit is contained in:
munna0908 2026-05-03 20:38:07 +05:30
parent f876bd6dba
commit afd479e822
No known key found for this signature in database
GPG Key ID: 2FFCD637E937D3E6
6 changed files with 166 additions and 42 deletions

View File

@ -12,3 +12,13 @@ requires "nim >= 2.2.0"
requires "https://github.com/status-im/nim-taskpools >= 0.0.5"
requires "https://github.com/mratsim/constantine"
# requires "https://github.com/mratsim/constantine#bc3845aa492b52f7fef047503b1592e830d1a774"
# Test tasks: one per memory-management mode, plus `test`, which runs the
# suite under both modes in sequence (arc first, then refc).
task test_arc, "run the test suite under --mm:arc":
  exec("nim c -r --threads:on --mm:arc tests/test.nim")

task test_refc, "run the test suite under --mm:refc":
  exec("nim c -r --threads:on --mm:refc tests/test.nim")

task test, "run the test suite under both --mm:arc and --mm:refc":
  # Same two commands, same order, as the dedicated tasks above.
  for command in [
    "nim c -r --threads:on --mm:arc tests/test.nim",
    "nim c -r --threads:on --mm:refc tests/test.nim",
  ]:
    exec command

View File

@ -22,6 +22,7 @@ import constantine/math/elliptic/ec_multi_scalar_mul as msm except Su
#import groth16/bn128/fields
import groth16/bn128/curves as mycurves
import groth16/sharedbuf
#import groth16/misc # TEMP DEBUGGING
#import std/cpuinfo
@ -79,6 +80,33 @@ func msmConstantineG2*( coeffs: openArray[Fr[BN254_Snarks]] , points: openArray[
return rAff
#-------------------------------------------------------------------------------
# spawnable wrappers: take SharedBuf views, delegate to the core.
# These are what `pool.spawn` calls, so they carry {.gcsafe, raises: [].}.
#
# Local aliases `AffG1`/`AffG2` are required because taskpools' `spawn` macro
# does `getImpl().replaceSymsByIdents()`, which strips qualifications. With
# `SharedBuf[mycurves.G1]` the bare ident `G1` then re-resolves to the enum
# value `aff.G1` (of type `Subgroup`), not the type alias. Renaming dodges
# the collision.
# Unqualified aliases used in the signatures of the spawnable wrappers below.
type AffG1 = mycurves.G1
type AffG2 = mycurves.G2
type FrBN = Fr[BN254_Snarks]
func msmConstantineG1Range(coeffs: SharedBuf[FrBN],
                           points: SharedBuf[AffG1]): AffG1
    {.gcsafe, raises: [].} =
  ## Spawnable wrapper around `msmConstantineG1`.
  ##
  ## Exposes the full extent of both `SharedBuf` views (elements
  ## `0 .. len - 1`) as open arrays and forwards them unchanged to the core
  ## routine, returning whatever it returns. The views are non-owning, so the
  ## storage behind them must stay alive for the duration of the call.
  let
    lastCoeff = coeffs.len - 1
    lastPoint = points.len - 1
  result = msmConstantineG1(coeffs.payload.toOpenArray(0, lastCoeff),
                            points.payload.toOpenArray(0, lastPoint))
func msmConstantineG2Range(coeffs: SharedBuf[FrBN],
                           points: SharedBuf[AffG2]): AffG2
    {.gcsafe, raises: [].} =
  ## Spawnable wrapper around `msmConstantineG2`.
  ##
  ## Exposes the full extent of both `SharedBuf` views (elements
  ## `0 .. len - 1`) as open arrays and forwards them unchanged to the core
  ## routine, returning whatever it returns. The views are non-owning, so the
  ## storage behind them must stay alive for the duration of the call.
  let
    lastCoeff = coeffs.len - 1
    lastPoint = points.len - 1
  result = msmConstantineG2(coeffs.payload.toOpenArray(0, lastCoeff),
                            points.payload.toOpenArray(0, lastPoint))
#-------------------------------------------------------------------------------
const task_multiplier : int = 1
@ -105,9 +133,9 @@ proc msmMultiThreadedG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1], pool:
b = (N*(k+1)) div ntasks
else:
b = N
let cs = coeffs[a..<b]
let ps = points[a..<b]
pending[k] = pool.spawn msmConstantineG1( cs, ps );
let cs = SharedBuf.view(toOpenArray(coeffs, a, b - 1))
let ps = SharedBuf.view(toOpenArray(points, a, b - 1))
pending[k] = pool.spawn msmConstantineG1Range(cs, ps)
a = b
var res : G1 = infG1
@ -135,9 +163,9 @@ proc msmMultiThreadedG2*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2], pool:
b = (N*(k+1)) div ntasks
else:
b = N
let cs = coeffs[a..<b]
let ps = points[a..<b]
pending[k] = pool.spawn msmConstantineG2( cs, ps );
let cs = SharedBuf.view(toOpenArray(coeffs, a, b - 1))
let ps = SharedBuf.view(toOpenArray(points, a, b - 1))
pending[k] = pool.spawn msmConstantineG2Range(cs, ps)
a = b
var res : G2 = infG2

View File

@ -33,9 +33,6 @@ import groth16/partial/types
proc generatePartialProof*( zkey: ZKey, pwtns: PartialWitness, pool: Taskpool, printTimings: bool): PartialProof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# assert( zkey.header.curve == wtns.curve )
let partial_witness = pwtns.values

View File

@ -33,9 +33,6 @@ import groth16/prover/shared
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
# echo( "wtns.curve = " & ($wtns.curve ) )

View File

@ -18,10 +18,19 @@ import groth16/bn128
import groth16/math/domain
import groth16/math/poly
import groth16/zkey_types
import groth16/sharedbuf
#import groth16/misc
import groth16/prover/types
# `FrBN` is a plain alias for `Fr[BN254_Snarks]`, introduced to dodge a
# taskpools `spawn` macro issue: its `getImpl().replaceSymsByIdents()` doesn't
# roundtrip qualified generic instantiations cleanly, so the type expression
# `Fr[BN254_Snarks]` re-resolves wrong at the wrapper-proc reconstruction
# step. Using a plain alias gives the macro an unambiguous symbol to work
# with. Same workaround as `AffG1`/`AffG2` in `groth16/bn128/msm.nim`.
type FrBN = Fr[BN254_Snarks]
#-------------------------------------------------------------------------------
proc randomMask*(): Mask =
@ -96,16 +105,40 @@ func shiftEvalDomain*(
var ds : seq[Fr[BN254_Snarks]] = multiplyByPowers( cs, eta )
return polyForwardNTT( Poly(coeffs:ds), D )
# Wraps shiftEvalDomain such that it can be called by Taskpool.spawn. The result
# is written to the output parameter. Has an unused return type because
# Taskpool.spawn cannot handle a void return type.
func shiftEvalDomainTask*(
values: seq[Fr[BN254_Snarks]],
D: Domain,
eta: Fr[BN254_Snarks],
output: ptr Isolated[seq[Fr[BN254_Snarks]]]): bool =
# Spawnable wrapper for shiftEvalDomain. Crosses the spawn boundary using
# SharedBuf views — no GC type travels across, so this works under any
# GC mode (refc / arc / orc).
#
# Contract: caller owns both `values` (read) and `output` (written),
# both of length n. Caller must keep them alive until the FlowVar is sync'd.
proc shiftEvalDomainTask(
    values: SharedBuf[FrBN],
    D: Domain,
    eta: FrBN,
    output: SharedBuf[FrBN]): bool
    {.gcsafe, raises: [].} =
  ## Spawnable wrapper for `shiftEvalDomain`.
  ##
  ## Reads `values.len` elements through the `values` view, runs
  ## `shiftEvalDomain` on a worker-local copy, and stores the outcome through
  ## the `output` view.
  ##
  ## * `values` - read-only input view, owned by the caller.
  ## * `D`, `eta` - forwarded unchanged to `shiftEvalDomain`.
  ## * `output` - destination view, owned by the caller; must hold at least
  ##   `values.len` elements.
  ##
  ## Always returns `true`; the return type exists only because
  ## `Taskpool.spawn` cannot handle a void return type.
  let n = values.len

  # The result is stored through a raw pointer with no bounds checking, so an
  # undersized destination would be silent memory corruption. Fail loudly.
  doAssert output.len >= n,
    "shiftEvalDomainTask: output buffer is shorter than the input"

  # TODO(perf): two N-sized boundary copies (input -> `local`,
  # `shifted` -> output) are an artifact of `shiftEvalDomain` /
  # `polyInverseNTT` / `polyForwardNTT` taking and returning `seq[Fr]`.
  # To eliminate both:
  #   1. Add `inverseNTTInto`/`forwardNTTInto` variants in
  #      `groth16/math/ntt.nim` that take `openArray[Fr]` for input and
  #      `var openArray[Fr]` for output (the inner workers already operate on
  #      stride/offset triples, so the change is mechanical).
  #   2. Rewrite this proc to call the `...Into` variants directly on
  #      `toOpenArray(values)` and `toOpenArray(output)`.

  # Copy input view -> worker-local seq for the existing seq-flavoured core.
  var local = newSeq[FrBN](n)
  for i in 0 ..< n: local[i] = values.payload[i]

  # Named `shifted` (not `result`) so the proc's implicit `result` variable
  # is not shadowed.
  let shifted = shiftEvalDomain(local, D, eta)
  doAssert shifted.len >= n,
    "shiftEvalDomainTask: shiftEvalDomain returned fewer values than the input"

  # Copy result -> caller's output buffer through the SharedBuf payload pointer.
  for i in 0 ..< n: output.payload[i] = shifted[i]
  true
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
@ -116,28 +149,37 @@ proc computeQuotientPointwise*( abc: ABC, pool: TaskPool ): Poly =
assert( abc.valuesCz.len == n )
let D = createDomain(n)
# (eta*omega^j)^n - 1 = eta^n - 1
# (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
# Pre-allocate caller-owned output buffers; workers write into these
# through bare pointers, so nothing GC-managed crosses the spawn.
var outA = newSeq[Fr[BN254_Snarks]](n)
var outB = newSeq[Fr[BN254_Snarks]](n)
var outC = newSeq[Fr[BN254_Snarks]](n)
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
let taskA1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesAz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outA, 0, n - 1)))
let taskB1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesBz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outB, 0, n - 1)))
let taskC1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesCz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outC, 0, n - 1)))
discard sync taskA1
discard sync taskB1
discard sync taskC1
let A1 = outputA1.extract()
let B1 = outputB1.extract()
let C1 = outputC1.extract()
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
for j in 0..<n: ys[j] = ( outA[j]*outB[j] - outC[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
@ -147,7 +189,7 @@ proc computeQuotientPointwise*( abc: ABC, pool: TaskPool ): Poly =
# Snarkjs does something different, not actually computing the quotient poly
# they can get away with this, because during the trusted setup, they
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
@ -158,22 +200,29 @@ proc computeSnarkjsScalarCoeffs*( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snark
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
var outA = newSeq[Fr[BN254_Snarks]](n)
var outB = newSeq[Fr[BN254_Snarks]](n)
var outC = newSeq[Fr[BN254_Snarks]](n)
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
let taskA1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesAz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outA, 0, n - 1)))
let taskB1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesBz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outB, 0, n - 1)))
let taskC1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesCz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outC, 0, n - 1)))
discard sync taskA1
discard sync taskB1
discard sync taskC1
let A1 = outputA1.extract()
let B1 = outputB1.extract()
let C1 = outputC1.extract()
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
for j in 0..<n: ys[j] = ( outA[j] * outB[j] - outC[j] )
return ys

43
groth16/sharedbuf.nim Normal file
View File

@ -0,0 +1,43 @@
#
# SharedBuf[T] — non-owning (ptr, len) view over a contiguous buffer.
#
# Purpose: cross a `Taskpool.spawn` boundary safely under any GC mode
# (refc / arc / orc) without sending a `seq[T]` through the spawn closure.
#
# Contract:
# - The pointed-to memory is owned by the caller. The caller MUST keep
# the underlying storage alive until every worker holding the view
# has finished (i.e. all FlowVars from `pool.spawn` have been `sync`ed).
# - Element type `T` must have no GC fields (workers read/write payload
# bytes through a raw pointer; the seq's GC machinery is never touched).
#
import std/typetraits

{.push raises: [], gcsafe.}

template makeUncheckedArray[T](p: ptr T): ptr UncheckedArray[T] =
  ## Reinterprets a pointer to a single `T` as a pointer to an unchecked
  ## array of `T` beginning at that element.
  cast[ptr UncheckedArray[T]](p)

type SharedBuf*[T] = object
  ## Non-owning (pointer, length) pair describing contiguous storage of `T`.
  ## Copying a `SharedBuf` copies only the pointer and the length.
  payload*: ptr UncheckedArray[T] ## first element; nil when `view` was given an empty input
  len*: int                       ## number of elements reachable through `payload`

func view*[T](_: type SharedBuf, v: openArray[T]): SharedBuf[T] =
  ## Builds a view over the elements of `v` without copying them.
  ##
  ## An empty `v` yields the default value (nil `payload`, `len == 0`).
  ## Element types with GC-managed fields are rejected at compile time,
  ## because the elements are later accessed through a raw pointer.
  when not supportsCopyMem(T):
    {.error: "SharedBuf requires an element type without GC-managed fields".}
  if v.len > 0:
    SharedBuf[T](payload: makeUncheckedArray(addr v[0]), len: v.len)
  else:
    default(SharedBuf[T])

template checkIdx(v: SharedBuf, i: int) =
  ## Aborts with an assertion failure unless `0 <= i < v.len`.
  doAssert i >= 0 and i < v.len

func `[]`*[T](v: SharedBuf[T], i: int): var T =
  ## Bounds-checked access to element `i`. The returned location is mutable:
  ## writes go straight to the viewed storage.
  v.checkIdx(i)
  v.payload[i]

template toOpenArray*[T](v: SharedBuf[T]): var openArray[T] =
  ## Exposes the whole view (`0 .. len - 1`) as an open array.
  ## NOTE: being a template, the argument expression `v` is substituted twice.
  v.payload.toOpenArray(0, v.len - 1)

template toOpenArray*[T](v: SharedBuf[T], s, e: int): var openArray[T] =
  ## Exposes the inclusive index range `s .. e` of the view as an open array.
  v.toOpenArray().toOpenArray(s, e)

{.pop.}