preliminary multithreading support (WIP; only for MSM right now)

This commit is contained in:
Balazs Komuves 2024-02-29 18:59:35 +01:00
parent 1ef1b040d6
commit 9d743247e9
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
6 changed files with 172 additions and 113 deletions

View File

@ -28,6 +28,7 @@ proc printHelp() =
echo " -h, --help : print this help"
echo " -v, --verbose : verbose output"
echo " -d, --debug : debug output"
echo " -j, --nthreads : number of CPU threads"
echo " -t, --time : print time measurements"
echo " -p, --prove : create a proof"
echo " -y, --verify : verify a proof"
@ -54,6 +55,7 @@ type Config = object
do_verify: bool
do_setup: bool
no_masking: bool
nthreads: int
const dummyConfig =
Config( zkey_file: ""
@ -67,6 +69,7 @@ const dummyConfig =
, do_verify: false
, do_setup: false
, no_masking: false
, nthreads: 0
)
proc printConfig(cfg: Config) =
@ -102,6 +105,7 @@ proc parseCliOptions(): Config =
of "h", "help" : printHelp()
of "v", "verbose" : cfg.verbose = true
of "d", "debug" : cfg.debug = true
of "j", "nthreads" : cfg.nthreads = parseInt(value)
of "t", "time" : cfg.measure_time = true
of "p", "prove" : cfg.do_prove = true
of "y", "verify" : cfg.do_verify = true
@ -160,25 +164,19 @@ proc cliMain(cfg: Config) =
if not (cfg.wtns_file == ""):
echo("\nparsing witness file " & quoted(cfg.wtns_file))
let start = cpuTime()
wtns = parseWitness(cfg.wtns_file)
let elapsed = cpuTime() - start
if cfg.measure_time: echo("parsing the witness took ",seconds(elapsed))
withMeasureTime(cfg.measure_time,"parsing the witness"):
wtns = parseWitness(cfg.wtns_file)
if not (cfg.zkey_file == ""):
echo("\nparsing zkey file " & quoted(cfg.zkey_file))
let start = cpuTime()
zkey = parseZKey(cfg.zkey_file)
let elapsed = cpuTime() - start
if cfg.measure_time: echo("parsing the zkey took ",seconds(elapsed))
withMeasureTime(cfg.measure_time,"parsing the zkey"):
zkey = parseZKey(cfg.zkey_file)
if not (cfg.r1cs_file == ""):
echo("\nparsing r1cs file " & quoted(cfg.r1cs_file))
let start = cpuTime()
r1cs = parseR1CS(cfg.r1cs_file)
let elapsed = cpuTime() - start
if cfg.measure_time: echo("parsing the r1cs took ",seconds(elapsed))
withMeasureTime(cfg.measure_time,"parsing the r1cs"):
r1cs = parseR1CS(cfg.r1cs_file)
if cfg.do_setup:
if not (cfg.zkey_file == ""):
echo("\nwe are doing a fake trusted setup, don't specify the zkey file!")
@ -187,11 +185,9 @@ proc cliMain(cfg: Config) =
echo("\nerror: r1cs file is required for the fake setup!")
quit()
echo("\nperforming fake trusted setup...")
let start = cpuTime()
zkey = createFakeCircuitSetup( r1cs, flavour=Snarkjs )
let elapsed = cpuTime() - start
if cfg.measure_time: echo("fake setup took ",seconds(elapsed))
withMeasureTime(cfg.measure_time,"fake setup"):
zkey = createFakeCircuitSetup( r1cs, flavour=Snarkjs )
if cfg.debug:
printGrothHeader(zkey.header)
# debugPrintCoeffs(zkey.coeffs)
@ -202,14 +198,13 @@ proc cliMain(cfg: Config) =
quit()
else:
echo("generating proof...")
let start = cpuTime()
let print_timings = cfg.measure_time and cfg.verbose
if cfg.no_masking:
proof = generateProofWithTrivialMask(print_timings, zkey, wtns)
else:
proof = generateProof(print_timings, zkey, wtns)
let elapsed = cpuTime() - start
if cfg.measure_time: echo("proving took ",seconds(elapsed))
withMeasureTime(cfg.measure_time,"proving"):
if cfg.no_masking:
proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns)
else:
proof = generateProof(cfg.nthreads, print_timings, zkey, wtns)
if not (cfg.output_file == ""):
echo("exporting the proof to " & quoted(cfg.output_file))
exportProof( cfg.output_file, proof )
@ -224,12 +219,11 @@ proc cliMain(cfg: Config) =
else:
let vkey = extractVKey( zkey)
echo("\nverifying the proof...")
let start = cpuTime()
let ok = verifyProof( vkey, proof )
let elapsed = cpuTime() - start
echo("verification succeeded = ",ok)
if cfg.measure_time: echo("verifying took ",seconds(elapsed))
var ok : bool
withMeasureTime(cfg.measure_time,"verifying"):
ok = verifyProof( vkey, proof )
echo("verification succeeded = ",ok)
echo("")
#-------------------------------------------------------------------------------

View File

@ -1 +1,2 @@
--path:".."
--threads:on

View File

@ -7,4 +7,5 @@ skipDirs = @["groth16/example"]
binDir = "build"
namedBin = {"cli/cli_main": "nim-groth16"}.toTable()
requires "https://github.com/status-im/nim-taskpools"
requires "https://github.com/mratsim/constantine#5f7ba18f2ed351260015397c9eae079a6decaee1"

View File

@ -4,6 +4,8 @@
#
import system
import std/cpuinfo
import taskpools
# import constantine/curves_primitives except Fp, Fp2, Fr
@ -23,11 +25,16 @@ import constantine/math/elliptic/ec_scalar_mul_vartime as scl except Su
import constantine/math/elliptic/ec_multi_scalar_mul as msm except Subgroup
import groth16/bn128/fields
import groth16/bn128/curves
import groth16/bn128/curves as mycurves
import groth16/misc # TEMP DEBUGGING
import std/times
#-------------------------------------------------------------------------------
func msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
proc msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
# let start = cpuTime()
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
@ -46,6 +53,9 @@ func msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
var rAff: G1
prj.affine(rAff, r)
# let elapsed = cpuTime() - start
# echo("computing an MSM of size " & ($N) & " took " & seconds(elapsed))
return rAff
#---------------------------------------
@ -74,41 +84,96 @@ func msmConstantineG2*( coeffs: openArray[Fr] , points: openArray[G2] ): G2 =
#-------------------------------------------------------------------------------
#[
type InputTuple = tuple[idx:int, coeffs: openArray[Fr] , points: openArray[G1]]
const task_multiplier : int = 1
func msmMultiThreadedG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
let nthreadsTarget = 8
proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1] ): G1 =
# for N <= 255 , we use 1 thread
# for N == 256 , we use 2 threads
# for N == 512 , we use 4 threads
# for N >= 1024, we use 8 threads
let nthreads = max( 1 , min( N div 128 , nthreadsTarget ) )
# for N >= 1024, we use 8+ threads
let m = N div nthreads
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
var threads : seq[Thread[InputTuple]] = newSeq[Thread[InputTuple]]( nthreads )
var results : seq[G1] = newSeq[G1]( nthreads )
# echo("msm with #threads = " & $nthreads)
proc myThreadFunc( inp: InputTuple ) {.thread.} =
results[inp.idx] = msmConstantineG1( inp.coeffs, inp.points )
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks)
for i in 0..<nthreads:
let a = i*m
let b = if (i == nthreads-1): N else: (i+1)*m
createThread(threads[i], myThreadFunc, (i, coeffs[a..<b], points[a..<b]))
# nim is just batshit crazy...
GC_ref(coeffs)
GC_ref(points)
joinThreads(threads)
var a : int = 0
var b : int
for k in 0..<ntasks:
if k < ntasks-1:
b = (N*(k+1)) div ntasks
else:
b = N
let cs = coeffs[a..<b]
let ps = points[a..<b]
pending[k] = pool.spawn msmConstantineG1( cs, ps );
a = b
var r : G1 = infG1
for i in 0..<nthreads: r += results[i]
var res : G1 = infG1
for k in 0..<ntasks:
res += sync pending[k]
return r
]#
pool.syncAll()
pool.shutdown()
GC_unref(coeffs)
GC_unref(points)
return res
#---------------------------------------
proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
# echo("G2 msm with #threads = " & $nthreads)
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks)
# nim is just batshit crazy...
GC_ref(coeffs)
GC_ref(points)
var a : int = 0
var b : int
for k in 0..<ntasks:
if k < ntasks-1:
b = (N*(k+1)) div ntasks
else:
b = N
let cs = coeffs[a..<b]
let ps = points[a..<b]
pending[k] = pool.spawn msmConstantineG2( cs, ps );
a = b
var res : G2 = infG2
for k in 0..<ntasks:
res += sync pending[k]
pool.syncAll()
pool.shutdown()
GC_unref(coeffs)
GC_unref(points)
return res
#-------------------------------------------------------------------------------
@ -152,8 +217,8 @@ func msmNaiveG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
#-------------------------------------------------------------------------------
func msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 = msmConstantineG1(coeffs, points)
func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 = msmConstantineG2(coeffs, points)
proc msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 = msmConstantineG1(coeffs, points)
proc msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 = msmConstantineG2(coeffs, points)
#-------------------------------------------------------------------------------

View File

@ -4,6 +4,7 @@
#
import strformat
import times, os, strutils
#-------------------------------------------------------------------------------
@ -13,6 +14,19 @@ func quoted*(s: string): string = fmt"`{s:s}`"
#-------------------------------------------------------------------------------
template withMeasureTime*(doPrint: bool, text: string, code: untyped) =
block:
if doPrint:
let t0 = epochTime()
code
let elapsed = epochTime() - t0
let elapsedStr = elapsed.formatFloat(format = ffDecimal, precision = 4)
echo ( text & " took " & elapsedStr & " seconds" )
else:
code
#-------------------------------------------------------------------------------
func delta*(i, j: int) : int =
return (if (i == j): 1 else: 0)

View File

@ -165,7 +165,7 @@ type
r*: Fr # masking coefficients
s*: Fr # for zero knowledge
proc generateProofWithMask*( printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
@ -190,28 +190,24 @@ proc generateProofWithMask*( printTimings: bool, zkey: ZKey, wtns: Witness, mask
for i in 0..npubs: pubIO[i] = witness[i]
start = cpuTime()
var abc : ABC = buildABC( zkey, witness )
if printTimings:
let elapsed = cpuTime() - start
echo("building 'ABC' took ",seconds(elapsed))
var abc : ABC
withMeasureTime(printTimings,"building 'ABC'"):
abc = buildABC( zkey, witness )
start = cpuTime()
var qs : seq[Fr]
case zkey.header.flavour
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
case zkey.header.flavour
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc )
if printTimings:
let elapsed = cpuTime() - start
echo("computing the quotient took ",seconds(elapsed))
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc )
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
for j in npubs+1..<nvars:
@ -229,57 +225,45 @@ proc generateProofWithMask*( printTimings: bool, zkey: ZKey, wtns: Witness, mask
assert( nvars - npubs - 1 == zs.len )
assert( nvars - npubs - 1 == pts.pointsC1.len )
start = cpuTime()
var pi_a : G1
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmG1( witness , pts.pointsA1 )
if printTimings:
let elapsed = cpuTime() - start
echo("computing pi_A (G1 MSM) took ",seconds(elapsed))
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmMultiThreadedG1( nthreads , witness , pts.pointsA1 )
start = cpuTime()
var rho : G1
rho = spec.beta1
rho += s ** spec.delta1
rho += msmG1( witness , pts.pointsB1 )
if printTimings:
let elapsed = cpuTime() - start
echo("computing rho (G1 MSM) took ",seconds(elapsed))
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
rho = spec.beta1
rho += s ** spec.delta1
rho += msmMultiThreadedG1( nthreads , witness , pts.pointsB1 )
start = cpuTime()
var pi_b : G2
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmG2( witness , pts.pointsB2 )
if printTimings:
let elapsed = cpuTime() - start
echo("computing pi_B (G2 MSM) took ",seconds(elapsed))
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmMultiThreadedG2( nthreads , witness , pts.pointsB2 )
start = cpuTime()
var pi_c : G1
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmG1( qs , pts.pointsH1 )
pi_c += msmG1( zs , pts.pointsC1 )
if printTimings:
let elapsed = cpuTime() - start
echo("computing pi_C (2x G1 MSM) took ",seconds(elapsed))
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmMultiThreadedG1( nthreads, qs , pts.pointsH1 )
pi_c += msmMultiThreadedG1( nthreads, zs , pts.pointsC1 )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------
proc generateProofWithTrivialMask*( printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
proc generateProofWithTrivialMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
let mask = Mask(r: intToFr(0), s: intToFr(0))
return generateProofWithMask( printTimings, zkey, wtns, mask )
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
proc generateProof*( printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
proc generateProof*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
# masking coeffs
let r : Fr = randFr()
let s : Fr = randFr()
let mask = Mask(r: r, s: s)
return generateProofWithMask( printTimings, zkey, wtns, mask )
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )