do the three FFT/IFFT pairs in parallel

This commit is contained in:
Balazs Komuves 2024-02-29 20:35:16 +01:00
parent cfd30a045e
commit d790dc3162
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
3 changed files with 46 additions and 17 deletions

View File

@ -3,6 +3,7 @@ import sugar
import std/strutils
import std/sequtils
import std/os
import std/cpuinfo
import std/parseopt
import std/times
import std/options
@ -128,6 +129,9 @@ proc parseCliOptions(): Config =
printHelp()
quit()
if cfg.nthreads <= 0:
cfg.nthreads = countProcessors()
return cfg
#-------------------------------------------------------------------------------

View File

@ -99,8 +99,6 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1]
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
# echo("msm with #threads = " & $nthreads)
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks)
@ -142,12 +140,9 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2]
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
# echo("G2 msm with #threads = " & $nthreads)
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks)
# nim is just batshit crazy...
GC_ref(coeffs)
GC_ref(points)

View File

@ -17,6 +17,9 @@ import ./zkey
import std/os
import std/times
import std/cpuinfo
import system
import taskpools
import constantine/math/arithmetic except Fp, Fr
#import constantine/math/io/io_extfields except Fp12
@ -112,7 +115,7 @@ func shiftEvalDomain( values: seq[Fr], D: Domain, eta: Fr ): seq[Fr] =
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
func computeQuotientPointwise( abc: ABC ): Poly =
proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
@ -124,15 +127,27 @@ func computeQuotientPointwise( abc: ABC ): Poly =
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
let A1 = shiftEvalDomain( abc.valuesAz, D, eta )
let B1 = shiftEvalDomain( abc.valuesBz, D, eta )
let C1 = shiftEvalDomain( abc.valuesCz, D, eta )
var pool = Taskpool.new(num_threads = nthreads)
GCref(abc.valuesAz)
GCref(abc.valuesBz)
GCref(abc.valuesCz)
var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta )
var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta )
var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta )
pool.syncAll()
GCunref(abc.valuesAz)
GCunref(abc.valuesBz)
GCunref(abc.valuesCz)
let A1 = sync A1fv
let B1 = sync B1fv
let C1 = sync C1fv
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
pool.shutdown()
return Poly(coeffs: cs)
#---------------------------------------
@ -143,17 +158,32 @@ func computeQuotientPointwise( abc: ABC ): Poly =
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] =
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
let A1 = shiftEvalDomain( abc.valuesAz, D, eta )
let B1 = shiftEvalDomain( abc.valuesBz, D, eta )
let C1 = shiftEvalDomain( abc.valuesCz, D, eta )
var pool = Taskpool.new(num_threads = nthreads)
GCref(abc.valuesAz)
GCref(abc.valuesBz)
GCref(abc.valuesCz)
var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta )
var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta )
var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta )
pool.syncAll()
GCunref(abc.valuesAz)
GCunref(abc.valuesBz)
GCunref(abc.valuesCz)
let A1 = sync A1fv
let B1 = sync B1fv
let C1 = sync C1fv
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
pool.shutdown()
return ys
#-------------------------------------------------------------------------------
@ -201,13 +231,13 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc )
let polyQ = computeQuotientPointwise( nthreads, abc )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc )
qs = computeSnarkjsScalarCoeffs( nthreads, abc )
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
for j in npubs+1..<nvars: