mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-05-18 16:49:30 +00:00
use SharedBuf to facilitate safe cross-GC boundary task spawning
This commit is contained in:
parent
f876bd6dba
commit
afd479e822
@ -12,3 +12,13 @@ requires "nim >= 2.2.0"
|
||||
requires "https://github.com/status-im/nim-taskpools >= 0.0.5"
|
||||
requires "https://github.com/mratsim/constantine"
|
||||
# requires "https://github.com/mratsim/constantine#bc3845aa492b52f7fef047503b1592e830d1a774"
|
||||
|
||||
task test_arc, "run the test suite under --mm:arc":
|
||||
exec "nim c -r --threads:on --mm:arc tests/test.nim"
|
||||
|
||||
task test_refc, "run the test suite under --mm:refc":
|
||||
exec "nim c -r --threads:on --mm:refc tests/test.nim"
|
||||
|
||||
task test, "run the test suite under both --mm:arc and --mm:refc":
|
||||
exec "nim c -r --threads:on --mm:arc tests/test.nim"
|
||||
exec "nim c -r --threads:on --mm:refc tests/test.nim"
|
||||
|
||||
@ -22,6 +22,7 @@ import constantine/math/elliptic/ec_multi_scalar_mul as msm except Su
|
||||
|
||||
#import groth16/bn128/fields
|
||||
import groth16/bn128/curves as mycurves
|
||||
import groth16/sharedbuf
|
||||
|
||||
#import groth16/misc # TEMP DEBUGGING
|
||||
#import std/cpuinfo
|
||||
@ -79,6 +80,33 @@ func msmConstantineG2*( coeffs: openArray[Fr[BN254_Snarks]] , points: openArray[
|
||||
|
||||
return rAff
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# spawnable wrappers: take SharedBuf views, delegate to the core.
|
||||
# These are what `pool.spawn` calls, so they carry {.gcsafe, raises: [].}.
|
||||
#
|
||||
# Local aliases `AffG1`/`AffG2` are required because taskpools' `spawn` macro
|
||||
# does `getImpl().replaceSymsByIdents()`, which strips qualifications. With
|
||||
# `SharedBuf[mycurves.G1]` the bare ident `G1` then re-resolves to the enum
|
||||
# value `aff.G1` (of type `Subgroup`), not the type alias. Renaming dodges
|
||||
# the collision.
|
||||
|
||||
type
|
||||
AffG1 = mycurves.G1
|
||||
AffG2 = mycurves.G2
|
||||
FrBN = Fr[BN254_Snarks]
|
||||
|
||||
func msmConstantineG1Range( coeffs: SharedBuf[FrBN] ,
|
||||
points: SharedBuf[AffG1] ): AffG1
|
||||
{.gcsafe, raises: [].} =
|
||||
msmConstantineG1( toOpenArray(coeffs.payload, 0, coeffs.len - 1),
|
||||
toOpenArray(points.payload, 0, points.len - 1) )
|
||||
|
||||
func msmConstantineG2Range( coeffs: SharedBuf[FrBN] ,
|
||||
points: SharedBuf[AffG2] ): AffG2
|
||||
{.gcsafe, raises: [].} =
|
||||
msmConstantineG2( toOpenArray(coeffs.payload, 0, coeffs.len - 1),
|
||||
toOpenArray(points.payload, 0, points.len - 1) )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
const task_multiplier : int = 1
|
||||
@ -105,9 +133,9 @@ proc msmMultiThreadedG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1], pool:
|
||||
b = (N*(k+1)) div ntasks
|
||||
else:
|
||||
b = N
|
||||
let cs = coeffs[a..<b]
|
||||
let ps = points[a..<b]
|
||||
pending[k] = pool.spawn msmConstantineG1( cs, ps );
|
||||
let cs = SharedBuf.view(toOpenArray(coeffs, a, b - 1))
|
||||
let ps = SharedBuf.view(toOpenArray(points, a, b - 1))
|
||||
pending[k] = pool.spawn msmConstantineG1Range(cs, ps)
|
||||
a = b
|
||||
|
||||
var res : G1 = infG1
|
||||
@ -135,9 +163,9 @@ proc msmMultiThreadedG2*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2], pool:
|
||||
b = (N*(k+1)) div ntasks
|
||||
else:
|
||||
b = N
|
||||
let cs = coeffs[a..<b]
|
||||
let ps = points[a..<b]
|
||||
pending[k] = pool.spawn msmConstantineG2( cs, ps );
|
||||
let cs = SharedBuf.view(toOpenArray(coeffs, a, b - 1))
|
||||
let ps = SharedBuf.view(toOpenArray(points, a, b - 1))
|
||||
pending[k] = pool.spawn msmConstantineG2Range(cs, ps)
|
||||
a = b
|
||||
|
||||
var res : G2 = infG2
|
||||
|
||||
@ -33,9 +33,6 @@ import groth16/partial/types
|
||||
|
||||
proc generatePartialProof*( zkey: ZKey, pwtns: PartialWitness, pool: Taskpool, printTimings: bool): PartialProof =
|
||||
|
||||
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
|
||||
{.fatal: "Compile with arc/orc!".}
|
||||
|
||||
# assert( zkey.header.curve == wtns.curve )
|
||||
|
||||
let partial_witness = pwtns.values
|
||||
|
||||
@ -33,9 +33,6 @@ import groth16/prover/shared
|
||||
|
||||
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
|
||||
|
||||
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
|
||||
{.fatal: "Compile with arc/orc!".}
|
||||
|
||||
# if (zkey.header.curve != wtns.curve):
|
||||
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
|
||||
# echo( "wtns.curve = " & ($wtns.curve ) )
|
||||
|
||||
@ -18,10 +18,19 @@ import groth16/bn128
|
||||
import groth16/math/domain
|
||||
import groth16/math/poly
|
||||
import groth16/zkey_types
|
||||
import groth16/sharedbuf
|
||||
#import groth16/misc
|
||||
|
||||
import groth16/prover/types
|
||||
|
||||
# `FrBN` shadows `Fr[BN254_Snarks]` to dodge a taskpools `spawn` macro issue:
|
||||
# its `getImpl().replaceSymsByIdents()` doesn't roundtrip qualified generic
|
||||
# instantiations cleanly, so the bare ident `Fr[BN254_Snarks]` re-resolves
|
||||
# wrong at the wrapper-proc reconstruction step. Using a plain alias gives
|
||||
# the macro an unambiguous symbol to work with. Same workaround as
|
||||
# `AffG1`/`AffG2` in `groth16/bn128/msm.nim`.
|
||||
type FrBN = Fr[BN254_Snarks]
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc randomMask*(): Mask =
|
||||
@ -96,16 +105,40 @@ func shiftEvalDomain*(
|
||||
var ds : seq[Fr[BN254_Snarks]] = multiplyByPowers( cs, eta )
|
||||
return polyForwardNTT( Poly(coeffs:ds), D )
|
||||
|
||||
# Wraps shiftEvalDomain such that it can be called by Taskpool.spawn. The result
|
||||
# is written to the output parameter. Has an unused return type because
|
||||
# Taskpool.spawn cannot handle a void return type.
|
||||
func shiftEvalDomainTask*(
|
||||
values: seq[Fr[BN254_Snarks]],
|
||||
D: Domain,
|
||||
eta: Fr[BN254_Snarks],
|
||||
output: ptr Isolated[seq[Fr[BN254_Snarks]]]): bool =
|
||||
# Spawnable wrapper for shiftEvalDomain. Crosses the spawn boundary using
|
||||
# SharedBuf views — no GC type travels across, so this works under any
|
||||
# GC mode (refc / arc / orc).
|
||||
#
|
||||
# Contract: caller owns both `values` (read) and `output` (written),
|
||||
# both of length n. Caller must keep them alive until the FlowVar is sync'd.
|
||||
proc shiftEvalDomainTask(
|
||||
values: SharedBuf[FrBN],
|
||||
D: Domain,
|
||||
eta: FrBN,
|
||||
output: SharedBuf[FrBN]): bool
|
||||
{.gcsafe, raises: [].} =
|
||||
|
||||
output[] = isolate shiftEvalDomain(values, D, eta)
|
||||
let n = values.len
|
||||
|
||||
# TODO(perf): two N-sized boundary copies (input → `local`, `result` → output)
|
||||
# are an artifact of `shiftEvalDomain` / `polyInverseNTT` / `polyForwardNTT`
|
||||
# taking and returning `seq[Fr]`. To eliminate both:
|
||||
# 1. Add `inverseNTTInto`/`forwardNTTInto` variants in `groth16/math/ntt.nim`
|
||||
# that take `openArray[Fr]` for input and `var openArray[Fr]` for output
|
||||
# (the inner workers already operate on stride/offset triples — change is
|
||||
# mechanical).
|
||||
# 2. Rewrite this proc to call the `…Into` variants directly on
|
||||
# `toOpenArray(values)` and `toOpenArray(output)`.
|
||||
|
||||
# Copy input view → worker-local seq for the existing seq-flavoured core.
|
||||
var local = newSeq[FrBN](n)
|
||||
for i in 0 ..< n: local[i] = values.payload[i]
|
||||
|
||||
let result = shiftEvalDomain(local, D, eta)
|
||||
|
||||
# Copy result → caller's output buffer through the SharedBuf payload pointer.
|
||||
for i in 0 ..< n: output.payload[i] = result[i]
|
||||
return true
|
||||
|
||||
# computes the quotient polynomial Q = (A*B - C) / Z
|
||||
# by computing the values on a shifted domain, and interpolating the result
|
||||
@ -116,28 +149,37 @@ proc computeQuotientPointwise*( abc: ABC, pool: TaskPool ): Poly =
|
||||
assert( abc.valuesCz.len == n )
|
||||
|
||||
let D = createDomain(n)
|
||||
|
||||
# (eta*omega^j)^n - 1 = eta^n - 1
|
||||
|
||||
# (eta*omega^j)^n - 1 = eta^n - 1
|
||||
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
|
||||
|
||||
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
|
||||
# Pre-allocate caller-owned output buffers; workers write into these
|
||||
# through bare pointers, so nothing GC-managed crosses the spawn.
|
||||
var outA = newSeq[Fr[BN254_Snarks]](n)
|
||||
var outB = newSeq[Fr[BN254_Snarks]](n)
|
||||
var outC = newSeq[Fr[BN254_Snarks]](n)
|
||||
|
||||
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
|
||||
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
|
||||
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
|
||||
let taskA1 = pool.spawn shiftEvalDomainTask(
|
||||
SharedBuf.view(toOpenArray(abc.valuesAz, 0, n - 1)),
|
||||
D, eta,
|
||||
SharedBuf.view(toOpenArray(outA, 0, n - 1)))
|
||||
let taskB1 = pool.spawn shiftEvalDomainTask(
|
||||
SharedBuf.view(toOpenArray(abc.valuesBz, 0, n - 1)),
|
||||
D, eta,
|
||||
SharedBuf.view(toOpenArray(outB, 0, n - 1)))
|
||||
let taskC1 = pool.spawn shiftEvalDomainTask(
|
||||
SharedBuf.view(toOpenArray(abc.valuesCz, 0, n - 1)),
|
||||
D, eta,
|
||||
SharedBuf.view(toOpenArray(outC, 0, n - 1)))
|
||||
|
||||
discard sync taskA1
|
||||
discard sync taskB1
|
||||
discard sync taskC1
|
||||
|
||||
let A1 = outputA1.extract()
|
||||
let B1 = outputB1.extract()
|
||||
let C1 = outputC1.extract()
|
||||
|
||||
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
|
||||
for j in 0..<n: ys[j] = ( outA[j]*outB[j] - outC[j] ) * invZ1
|
||||
let Q1 = polyInverseNTT( ys, D )
|
||||
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
|
||||
|
||||
@ -147,7 +189,7 @@ proc computeQuotientPointwise*( abc: ABC, pool: TaskPool ): Poly =
|
||||
|
||||
# Snarkjs does something different, not actually computing the quotient poly
|
||||
# they can get away with this, because during the trusted setup, they
|
||||
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
|
||||
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
|
||||
# (shifted) Lagrange bases.
|
||||
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
@ -158,22 +200,29 @@ proc computeSnarkjsScalarCoeffs*( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snark
|
||||
let D = createDomain(n)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
|
||||
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
|
||||
var outA = newSeq[Fr[BN254_Snarks]](n)
|
||||
var outB = newSeq[Fr[BN254_Snarks]](n)
|
||||
var outC = newSeq[Fr[BN254_Snarks]](n)
|
||||
|
||||
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
|
||||
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
|
||||
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
|
||||
let taskA1 = pool.spawn shiftEvalDomainTask(
|
||||
SharedBuf.view(toOpenArray(abc.valuesAz, 0, n - 1)),
|
||||
D, eta,
|
||||
SharedBuf.view(toOpenArray(outA, 0, n - 1)))
|
||||
let taskB1 = pool.spawn shiftEvalDomainTask(
|
||||
SharedBuf.view(toOpenArray(abc.valuesBz, 0, n - 1)),
|
||||
D, eta,
|
||||
SharedBuf.view(toOpenArray(outB, 0, n - 1)))
|
||||
let taskC1 = pool.spawn shiftEvalDomainTask(
|
||||
SharedBuf.view(toOpenArray(abc.valuesCz, 0, n - 1)),
|
||||
D, eta,
|
||||
SharedBuf.view(toOpenArray(outC, 0, n - 1)))
|
||||
|
||||
discard sync taskA1
|
||||
discard sync taskB1
|
||||
discard sync taskC1
|
||||
|
||||
let A1 = outputA1.extract()
|
||||
let B1 = outputB1.extract()
|
||||
let C1 = outputC1.extract()
|
||||
|
||||
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
|
||||
for j in 0..<n: ys[j] = ( outA[j] * outB[j] - outC[j] )
|
||||
|
||||
return ys
|
||||
|
||||
|
||||
43
groth16/sharedbuf.nim
Normal file
43
groth16/sharedbuf.nim
Normal file
@ -0,0 +1,43 @@
|
||||
#
|
||||
# SharedBuf[T] — non-owning (ptr, len) view over a contiguous buffer.
|
||||
#
|
||||
# Purpose: cross a `Taskpool.spawn` boundary safely under any GC mode
|
||||
# (refc / arc / orc) without sending a `seq[T]` through the spawn closure.
|
||||
#
|
||||
# Contract:
|
||||
# - The pointed-to memory is owned by the caller. The caller MUST keep
|
||||
# the underlying storage alive until every worker holding the view
|
||||
# has finished (i.e. all FlowVars from `pool.spawn` have been `sync`ed).
|
||||
# - Element type `T` must have no GC fields (workers read/write payload
|
||||
# bytes through a raw pointer; the seq's GC machinery is never touched).
|
||||
#
|
||||
|
||||
{.push raises: [], gcsafe.}
|
||||
|
||||
template makeUncheckedArray[T](p: ptr T): ptr UncheckedArray[T] =
|
||||
cast[ptr UncheckedArray[T]](p)
|
||||
|
||||
type SharedBuf*[T] = object
|
||||
payload*: ptr UncheckedArray[T]
|
||||
len*: int
|
||||
|
||||
func view*[T](_: type SharedBuf, v: openArray[T]): SharedBuf[T] =
|
||||
if v.len > 0:
|
||||
SharedBuf[T](payload: makeUncheckedArray(addr v[0]), len: v.len)
|
||||
else:
|
||||
default(SharedBuf[T])
|
||||
|
||||
template checkIdx(v: SharedBuf, i: int) =
|
||||
doAssert i >= 0 and i < v.len
|
||||
|
||||
func `[]`*[T](v: SharedBuf[T], i: int): var T =
|
||||
v.checkIdx(i)
|
||||
v.payload[i]
|
||||
|
||||
template toOpenArray*[T](v: SharedBuf[T]): var openArray[T] =
|
||||
v.payload.toOpenArray(0, v.len - 1)
|
||||
|
||||
template toOpenArray*[T](v: SharedBuf[T], s, e: int): var openArray[T] =
|
||||
v.toOpenArray().toOpenArray(s, e)
|
||||
|
||||
{.pop.}
|
||||
Loading…
x
Reference in New Issue
Block a user