fix the multithreading failures by switching to arc

This commit is contained in:
Balazs Komuves 2024-03-04 09:27:12 +01:00
parent d790dc3162
commit 5616a1c52f
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
4 changed files with 40 additions and 32 deletions

View File

@ -19,14 +19,14 @@ at your choice.
### TODO
- [ ] find and fix the _second_ totally surreal bug
- [x] find and fix the _second_ totally surreal bug
- [ ] clean up the code
- [ ] make it compatible with the latest constantine and also Nim 2.0.x
- [x] make it a nimble package
- [ ] compare `.r1cs` to the "coeffs" section of `.zkey`
- [x] generate fake circuit-specific setup ourselves
- [x] make a CLI interface
- [ ] multithreading support (MSM, and possibly also FFT)
- [x] multithreading support (MSM, and possibly also FFT)
- [ ] add Groth16 notes
- [ ] document the `snarkjs` circuit-specific setup `H` points convention
- [ ] make it work for different curves

View File

@ -1,2 +1,3 @@
--path:".."
--threads:on
--threads:on
--mm:arc

View File

@ -102,10 +102,6 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1]
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks)
# nim is just batshit crazy...
GC_ref(coeffs)
GC_ref(points)
var a : int = 0
var b : int
for k in 0..<ntasks:
@ -125,9 +121,6 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1]
pool.syncAll()
pool.shutdown()
GC_unref(coeffs)
GC_unref(points)
return res
#---------------------------------------
@ -143,9 +136,6 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2]
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks)
GC_ref(coeffs)
GC_ref(points)
var a : int = 0
var b : int
for k in 0..<ntasks:
@ -165,9 +155,6 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2]
pool.syncAll()
pool.shutdown()
GC_unref(coeffs)
GC_unref(points)
return res
#-------------------------------------------------------------------------------

View File

@ -128,16 +128,11 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
var pool = Taskpool.new(num_threads = nthreads)
GCref(abc.valuesAz)
GCref(abc.valuesBz)
GCref(abc.valuesCz)
var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta )
var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta )
var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta )
pool.syncAll()
GCunref(abc.valuesAz)
GCunref(abc.valuesBz)
GCunref(abc.valuesCz)
let A1 = sync A1fv
let B1 = sync B1fv
let C1 = sync C1fv
@ -147,7 +142,9 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
pool.syncAll()
pool.shutdown()
return Poly(coeffs: cs)
#---------------------------------------
@ -158,7 +155,7 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr] =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
@ -166,16 +163,11 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
let eta = createDomain(2*n).domainGen
var pool = Taskpool.new(num_threads = nthreads)
GCref(abc.valuesAz)
GCref(abc.valuesBz)
GCref(abc.valuesCz)
var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta )
var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta )
var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta )
pool.syncAll()
GCunref(abc.valuesAz)
GCunref(abc.valuesBz)
GCunref(abc.valuesCz)
let A1 = sync A1fv
let B1 = sync B1fv
let C1 = sync C1fv
@ -183,9 +175,34 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
pool.syncAll()
pool.shutdown()
return ys
#[
proc computeSnarkjsScalarCoeffs_st( abc: ABC ): seq[Fr] =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
let A1 : seq[Fr] = shiftEvalDomain( abc.valuesAz, D, eta )
let B1 : seq[Fr] = shiftEvalDomain( abc.valuesBz, D, eta )
let C1 : seq[Fr] = shiftEvalDomain( abc.valuesCz, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
return ys
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
if nthreads <= 1:
computeSnarkjsScalarCoeffs_st( abc )
else:
computeSnarkjsScalarCoeffs_mt( nthreads, abc )
]#
#-------------------------------------------------------------------------------
# the prover
#
@ -197,6 +214,9 @@ type
proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
# echo( "wtns.curve = " & ($wtns.curve ) )
@ -286,7 +306,7 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
#-------------------------------------------------------------------------------
proc generateProofWithTrivialMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
let mask = Mask(r: intToFr(0), s: intToFr(0))
let mask = Mask( r: zeroFr , s: zeroFr )
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
proc generateProof*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =