mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-01-09 17:13:08 +00:00
fix the multithreading failures by switching to arc
This commit is contained in:
parent
d790dc3162
commit
5616a1c52f
@ -19,14 +19,14 @@ at your choice.
|
||||
|
||||
### TODO
|
||||
|
||||
- [ ] find and fix the _second_ totally surreal bug
|
||||
- [x] find and fix the _second_ totally surreal bug
|
||||
- [ ] clean up the code
|
||||
- [ ] make it compatible with the latest constantine and also Nim 2.0.x
|
||||
- [x] make it a nimble package
|
||||
- [ ] compare `.r1cs` to the "coeffs" section of `.zkey`
|
||||
- [x] generate fake circuit-specific setup ourselves
|
||||
- [x] make a CLI interface
|
||||
- [ ] multithreading support (MSM, and possibly also FFT)
|
||||
- [x] multithreading support (MSM, and possibly also FFT)
|
||||
- [ ] add Groth16 notes
|
||||
- [ ] document the `snarkjs` circuit-specific setup `H` points convention
|
||||
- [ ] make it work for different curves
|
||||
|
||||
@ -1,2 +1,3 @@
|
||||
--path:".."
|
||||
--threads:on
|
||||
--threads:on
|
||||
--mm:arc
|
||||
@ -102,10 +102,6 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1]
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks)
|
||||
|
||||
# nim is just batshit crazy...
|
||||
GC_ref(coeffs)
|
||||
GC_ref(points)
|
||||
|
||||
var a : int = 0
|
||||
var b : int
|
||||
for k in 0..<ntasks:
|
||||
@ -125,9 +121,6 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1]
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
GC_unref(coeffs)
|
||||
GC_unref(points)
|
||||
|
||||
return res
|
||||
|
||||
#---------------------------------------
|
||||
@ -143,9 +136,6 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2]
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks)
|
||||
|
||||
GC_ref(coeffs)
|
||||
GC_ref(points)
|
||||
|
||||
var a : int = 0
|
||||
var b : int
|
||||
for k in 0..<ntasks:
|
||||
@ -165,9 +155,6 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2]
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
GC_unref(coeffs)
|
||||
GC_unref(points)
|
||||
|
||||
return res
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -128,16 +128,11 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
|
||||
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
|
||||
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
GCref(abc.valuesAz)
|
||||
GCref(abc.valuesBz)
|
||||
GCref(abc.valuesCz)
|
||||
|
||||
var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta )
|
||||
var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta )
|
||||
var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta )
|
||||
pool.syncAll()
|
||||
GCunref(abc.valuesAz)
|
||||
GCunref(abc.valuesBz)
|
||||
GCunref(abc.valuesCz)
|
||||
|
||||
let A1 = sync A1fv
|
||||
let B1 = sync B1fv
|
||||
let C1 = sync C1fv
|
||||
@ -147,7 +142,9 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
|
||||
let Q1 = polyInverseNTT( ys, D )
|
||||
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
|
||||
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
return Poly(coeffs: cs)
|
||||
|
||||
#---------------------------------------
|
||||
@ -158,7 +155,7 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
|
||||
# (shifted) Lagrange bases.
|
||||
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
|
||||
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr] =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
@ -166,16 +163,11 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
|
||||
let eta = createDomain(2*n).domainGen
|
||||
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
GCref(abc.valuesAz)
|
||||
GCref(abc.valuesBz)
|
||||
GCref(abc.valuesCz)
|
||||
|
||||
var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta )
|
||||
var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta )
|
||||
var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta )
|
||||
pool.syncAll()
|
||||
GCunref(abc.valuesAz)
|
||||
GCunref(abc.valuesBz)
|
||||
GCunref(abc.valuesCz)
|
||||
|
||||
let A1 = sync A1fv
|
||||
let B1 = sync B1fv
|
||||
let C1 = sync C1fv
|
||||
@ -183,9 +175,34 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
|
||||
var ys : seq[Fr] = newSeq[Fr]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
|
||||
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
return ys
|
||||
|
||||
#[
|
||||
|
||||
proc computeSnarkjsScalarCoeffs_st( abc: ABC ): seq[Fr] =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
let D = createDomain(n)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let A1 : seq[Fr] = shiftEvalDomain( abc.valuesAz, D, eta )
|
||||
let B1 : seq[Fr] = shiftEvalDomain( abc.valuesBz, D, eta )
|
||||
let C1 : seq[Fr] = shiftEvalDomain( abc.valuesCz, D, eta )
|
||||
var ys : seq[Fr] = newSeq[Fr]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
|
||||
return ys
|
||||
|
||||
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
|
||||
if nthreads <= 1:
|
||||
computeSnarkjsScalarCoeffs_st( abc )
|
||||
else:
|
||||
computeSnarkjsScalarCoeffs_mt( nthreads, abc )
|
||||
|
||||
]#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# the prover
|
||||
#
|
||||
@ -197,6 +214,9 @@ type
|
||||
|
||||
proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
|
||||
|
||||
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
|
||||
{.fatal: "Compile with arc/orc!".}
|
||||
|
||||
# if (zkey.header.curve != wtns.curve):
|
||||
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
|
||||
# echo( "wtns.curve = " & ($wtns.curve ) )
|
||||
@ -286,7 +306,7 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc generateProofWithTrivialMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
|
||||
let mask = Mask(r: intToFr(0), s: intToFr(0))
|
||||
let mask = Mask( r: zeroFr , s: zeroFr )
|
||||
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
|
||||
|
||||
proc generateProof*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user