mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-01-02 13:43:09 +00:00
preliminary multithreading support (WIP; only for MSM right now)
This commit is contained in:
parent
1ef1b040d6
commit
9d743247e9
@ -28,6 +28,7 @@ proc printHelp() =
|
||||
echo " -h, --help : print this help"
|
||||
echo " -v, --verbose : verbose output"
|
||||
echo " -d, --debug : debug output"
|
||||
echo " -j, --nthreads : number of CPU threads"
|
||||
echo " -t, --time : print time measurements"
|
||||
echo " -p, --prove : create a proof"
|
||||
echo " -y, --verify : verify a proof"
|
||||
@ -54,6 +55,7 @@ type Config = object
|
||||
do_verify: bool
|
||||
do_setup: bool
|
||||
no_masking: bool
|
||||
nthreads: int
|
||||
|
||||
const dummyConfig =
|
||||
Config( zkey_file: ""
|
||||
@ -67,6 +69,7 @@ const dummyConfig =
|
||||
, do_verify: false
|
||||
, do_setup: false
|
||||
, no_masking: false
|
||||
, nthreads: 0
|
||||
)
|
||||
|
||||
proc printConfig(cfg: Config) =
|
||||
@ -102,6 +105,7 @@ proc parseCliOptions(): Config =
|
||||
of "h", "help" : printHelp()
|
||||
of "v", "verbose" : cfg.verbose = true
|
||||
of "d", "debug" : cfg.debug = true
|
||||
of "j", "nthreads" : cfg.nthreads = parseInt(value)
|
||||
of "t", "time" : cfg.measure_time = true
|
||||
of "p", "prove" : cfg.do_prove = true
|
||||
of "y", "verify" : cfg.do_verify = true
|
||||
@ -160,25 +164,19 @@ proc cliMain(cfg: Config) =
|
||||
|
||||
if not (cfg.wtns_file == ""):
|
||||
echo("\nparsing witness file " & quoted(cfg.wtns_file))
|
||||
let start = cpuTime()
|
||||
wtns = parseWitness(cfg.wtns_file)
|
||||
let elapsed = cpuTime() - start
|
||||
if cfg.measure_time: echo("parsing the witness took ",seconds(elapsed))
|
||||
|
||||
withMeasureTime(cfg.measure_time,"parsing the witness"):
|
||||
wtns = parseWitness(cfg.wtns_file)
|
||||
|
||||
if not (cfg.zkey_file == ""):
|
||||
echo("\nparsing zkey file " & quoted(cfg.zkey_file))
|
||||
let start = cpuTime()
|
||||
zkey = parseZKey(cfg.zkey_file)
|
||||
let elapsed = cpuTime() - start
|
||||
if cfg.measure_time: echo("parsing the zkey took ",seconds(elapsed))
|
||||
|
||||
withMeasureTime(cfg.measure_time,"parsing the zkey"):
|
||||
zkey = parseZKey(cfg.zkey_file)
|
||||
|
||||
if not (cfg.r1cs_file == ""):
|
||||
echo("\nparsing r1cs file " & quoted(cfg.r1cs_file))
|
||||
let start = cpuTime()
|
||||
r1cs = parseR1CS(cfg.r1cs_file)
|
||||
let elapsed = cpuTime() - start
|
||||
if cfg.measure_time: echo("parsing the r1cs took ",seconds(elapsed))
|
||||
|
||||
withMeasureTime(cfg.measure_time,"parsing the r1cs"):
|
||||
r1cs = parseR1CS(cfg.r1cs_file)
|
||||
|
||||
if cfg.do_setup:
|
||||
if not (cfg.zkey_file == ""):
|
||||
echo("\nwe are doing a fake trusted setup, don't specify the zkey file!")
|
||||
@ -187,11 +185,9 @@ proc cliMain(cfg: Config) =
|
||||
echo("\nerror: r1cs file is required for the fake setup!")
|
||||
quit()
|
||||
echo("\nperforming fake trusted setup...")
|
||||
let start = cpuTime()
|
||||
zkey = createFakeCircuitSetup( r1cs, flavour=Snarkjs )
|
||||
let elapsed = cpuTime() - start
|
||||
if cfg.measure_time: echo("fake setup took ",seconds(elapsed))
|
||||
|
||||
withMeasureTime(cfg.measure_time,"fake setup"):
|
||||
zkey = createFakeCircuitSetup( r1cs, flavour=Snarkjs )
|
||||
|
||||
if cfg.debug:
|
||||
printGrothHeader(zkey.header)
|
||||
# debugPrintCoeffs(zkey.coeffs)
|
||||
@ -202,14 +198,13 @@ proc cliMain(cfg: Config) =
|
||||
quit()
|
||||
else:
|
||||
echo("generating proof...")
|
||||
let start = cpuTime()
|
||||
let print_timings = cfg.measure_time and cfg.verbose
|
||||
if cfg.no_masking:
|
||||
proof = generateProofWithTrivialMask(print_timings, zkey, wtns)
|
||||
else:
|
||||
proof = generateProof(print_timings, zkey, wtns)
|
||||
let elapsed = cpuTime() - start
|
||||
if cfg.measure_time: echo("proving took ",seconds(elapsed))
|
||||
withMeasureTime(cfg.measure_time,"proving"):
|
||||
if cfg.no_masking:
|
||||
proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns)
|
||||
else:
|
||||
proof = generateProof(cfg.nthreads, print_timings, zkey, wtns)
|
||||
|
||||
if not (cfg.output_file == ""):
|
||||
echo("exporting the proof to " & quoted(cfg.output_file))
|
||||
exportProof( cfg.output_file, proof )
|
||||
@ -224,12 +219,11 @@ proc cliMain(cfg: Config) =
|
||||
else:
|
||||
let vkey = extractVKey( zkey)
|
||||
echo("\nverifying the proof...")
|
||||
let start = cpuTime()
|
||||
let ok = verifyProof( vkey, proof )
|
||||
let elapsed = cpuTime() - start
|
||||
echo("verification succeeded = ",ok)
|
||||
if cfg.measure_time: echo("verifying took ",seconds(elapsed))
|
||||
|
||||
var ok : bool
|
||||
withMeasureTime(cfg.measure_time,"verifying"):
|
||||
ok = verifyProof( vkey, proof )
|
||||
echo("verification succeeded = ",ok)
|
||||
|
||||
echo("")
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -1 +1,2 @@
|
||||
--path:".."
|
||||
--threads:on
|
||||
@ -7,4 +7,5 @@ skipDirs = @["groth16/example"]
|
||||
binDir = "build"
|
||||
namedBin = {"cli/cli_main": "nim-groth16"}.toTable()
|
||||
|
||||
requires "https://github.com/status-im/nim-taskpools"
|
||||
requires "https://github.com/mratsim/constantine#5f7ba18f2ed351260015397c9eae079a6decaee1"
|
||||
@ -4,6 +4,8 @@
|
||||
#
|
||||
|
||||
import system
|
||||
import std/cpuinfo
|
||||
import taskpools
|
||||
|
||||
# import constantine/curves_primitives except Fp, Fp2, Fr
|
||||
|
||||
@ -23,11 +25,16 @@ import constantine/math/elliptic/ec_scalar_mul_vartime as scl except Su
|
||||
import constantine/math/elliptic/ec_multi_scalar_mul as msm except Subgroup
|
||||
|
||||
import groth16/bn128/fields
|
||||
import groth16/bn128/curves
|
||||
import groth16/bn128/curves as mycurves
|
||||
|
||||
import groth16/misc # TEMP DEBUGGING
|
||||
import std/times
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
|
||||
proc msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
|
||||
|
||||
# let start = cpuTime()
|
||||
|
||||
let N = coeffs.len
|
||||
assert( N == points.len, "incompatible sequence lengths" )
|
||||
@ -46,6 +53,9 @@ func msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
|
||||
var rAff: G1
|
||||
prj.affine(rAff, r)
|
||||
|
||||
# let elapsed = cpuTime() - start
|
||||
# echo("computing an MSM of size " & ($N) & " took " & seconds(elapsed))
|
||||
|
||||
return rAff
|
||||
|
||||
#---------------------------------------
|
||||
@ -74,41 +84,96 @@ func msmConstantineG2*( coeffs: openArray[Fr] , points: openArray[G2] ): G2 =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
#[
|
||||
type InputTuple = tuple[idx:int, coeffs: openArray[Fr] , points: openArray[G1]]
|
||||
const task_multiplier : int = 1
|
||||
|
||||
func msmMultiThreadedG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
|
||||
let N = coeffs.len
|
||||
assert( N == points.len, "incompatible sequence lengths" )
|
||||
|
||||
let nthreadsTarget = 8
|
||||
proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1] ): G1 =
|
||||
|
||||
# for N <= 255 , we use 1 thread
|
||||
# for N == 256 , we use 2 threads
|
||||
# for N == 512 , we use 4 threads
|
||||
# for N >= 1024, we use 8 threads
|
||||
let nthreads = max( 1 , min( N div 128 , nthreadsTarget ) )
|
||||
# for N >= 1024, we use 8+ threads
|
||||
|
||||
let m = N div nthreads
|
||||
let N = coeffs.len
|
||||
assert( N == points.len, "incompatible sequence lengths" )
|
||||
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
|
||||
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
|
||||
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
|
||||
|
||||
var threads : seq[Thread[InputTuple]] = newSeq[Thread[InputTuple]]( nthreads )
|
||||
var results : seq[G1] = newSeq[G1]( nthreads )
|
||||
# echo("msm with #threads = " & $nthreads)
|
||||
|
||||
proc myThreadFunc( inp: InputTuple ) {.thread.} =
|
||||
results[inp.idx] = msmConstantineG1( inp.coeffs, inp.points )
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks)
|
||||
|
||||
for i in 0..<nthreads:
|
||||
let a = i*m
|
||||
let b = if (i == nthreads-1): N else: (i+1)*m
|
||||
createThread(threads[i], myThreadFunc, (i, coeffs[a..<b], points[a..<b]))
|
||||
# nim is just batshit crazy...
|
||||
GC_ref(coeffs)
|
||||
GC_ref(points)
|
||||
|
||||
joinThreads(threads)
|
||||
var a : int = 0
|
||||
var b : int
|
||||
for k in 0..<ntasks:
|
||||
if k < ntasks-1:
|
||||
b = (N*(k+1)) div ntasks
|
||||
else:
|
||||
b = N
|
||||
let cs = coeffs[a..<b]
|
||||
let ps = points[a..<b]
|
||||
pending[k] = pool.spawn msmConstantineG1( cs, ps );
|
||||
a = b
|
||||
|
||||
var r : G1 = infG1
|
||||
for i in 0..<nthreads: r += results[i]
|
||||
var res : G1 = infG1
|
||||
for k in 0..<ntasks:
|
||||
res += sync pending[k]
|
||||
|
||||
return r
|
||||
]#
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
GC_unref(coeffs)
|
||||
GC_unref(points)
|
||||
|
||||
return res
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2] ): G2 =
|
||||
|
||||
let N = coeffs.len
|
||||
assert( N == points.len, "incompatible sequence lengths" )
|
||||
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
|
||||
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
|
||||
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
|
||||
|
||||
# echo("G2 msm with #threads = " & $nthreads)
|
||||
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks)
|
||||
|
||||
# nim is just batshit crazy...
|
||||
GC_ref(coeffs)
|
||||
GC_ref(points)
|
||||
|
||||
var a : int = 0
|
||||
var b : int
|
||||
for k in 0..<ntasks:
|
||||
if k < ntasks-1:
|
||||
b = (N*(k+1)) div ntasks
|
||||
else:
|
||||
b = N
|
||||
let cs = coeffs[a..<b]
|
||||
let ps = points[a..<b]
|
||||
pending[k] = pool.spawn msmConstantineG2( cs, ps );
|
||||
a = b
|
||||
|
||||
var res : G2 = infG2
|
||||
for k in 0..<ntasks:
|
||||
res += sync pending[k]
|
||||
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
GC_unref(coeffs)
|
||||
GC_unref(points)
|
||||
|
||||
return res
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -152,8 +217,8 @@ func msmNaiveG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 = msmConstantineG1(coeffs, points)
|
||||
func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 = msmConstantineG2(coeffs, points)
|
||||
proc msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 = msmConstantineG1(coeffs, points)
|
||||
proc msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 = msmConstantineG2(coeffs, points)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#
|
||||
|
||||
import strformat
|
||||
import times, os, strutils
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -13,6 +14,19 @@ func quoted*(s: string): string = fmt"`{s:s}`"
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
template withMeasureTime*(doPrint: bool, text: string, code: untyped) =
|
||||
block:
|
||||
if doPrint:
|
||||
let t0 = epochTime()
|
||||
code
|
||||
let elapsed = epochTime() - t0
|
||||
let elapsedStr = elapsed.formatFloat(format = ffDecimal, precision = 4)
|
||||
echo ( text & " took " & elapsedStr & " seconds" )
|
||||
else:
|
||||
code
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func delta*(i, j: int) : int =
|
||||
return (if (i == j): 1 else: 0)
|
||||
|
||||
|
||||
@ -165,7 +165,7 @@ type
|
||||
r*: Fr # masking coefficients
|
||||
s*: Fr # for zero knowledge
|
||||
|
||||
proc generateProofWithMask*( printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
|
||||
proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
|
||||
|
||||
# if (zkey.header.curve != wtns.curve):
|
||||
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
|
||||
@ -190,28 +190,24 @@ proc generateProofWithMask*( printTimings: bool, zkey: ZKey, wtns: Witness, mask
|
||||
for i in 0..npubs: pubIO[i] = witness[i]
|
||||
|
||||
start = cpuTime()
|
||||
var abc : ABC = buildABC( zkey, witness )
|
||||
if printTimings:
|
||||
let elapsed = cpuTime() - start
|
||||
echo("building 'ABC' took ",seconds(elapsed))
|
||||
var abc : ABC
|
||||
withMeasureTime(printTimings,"building 'ABC'"):
|
||||
abc = buildABC( zkey, witness )
|
||||
|
||||
start = cpuTime()
|
||||
var qs : seq[Fr]
|
||||
case zkey.header.flavour
|
||||
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
|
||||
case zkey.header.flavour
|
||||
|
||||
# the points H are [delta^-1 * tau^i * Z(tau)]
|
||||
of JensGroth:
|
||||
let polyQ = computeQuotientPointwise( abc )
|
||||
qs = polyQ.coeffs
|
||||
|
||||
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
|
||||
# where L_i are Lagrange basis polynomials on the double-sized domain
|
||||
of Snarkjs:
|
||||
qs = computeSnarkjsScalarCoeffs( abc )
|
||||
|
||||
if printTimings:
|
||||
let elapsed = cpuTime() - start
|
||||
echo("computing the quotient took ",seconds(elapsed))
|
||||
# the points H are [delta^-1 * tau^i * Z(tau)]
|
||||
of JensGroth:
|
||||
let polyQ = computeQuotientPointwise( abc )
|
||||
qs = polyQ.coeffs
|
||||
|
||||
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
|
||||
# where L_i are Lagrange basis polynomials on the double-sized domain
|
||||
of Snarkjs:
|
||||
qs = computeSnarkjsScalarCoeffs( abc )
|
||||
|
||||
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
|
||||
for j in npubs+1..<nvars:
|
||||
@ -229,57 +225,45 @@ proc generateProofWithMask*( printTimings: bool, zkey: ZKey, wtns: Witness, mask
|
||||
assert( nvars - npubs - 1 == zs.len )
|
||||
assert( nvars - npubs - 1 == pts.pointsC1.len )
|
||||
|
||||
start = cpuTime()
|
||||
var pi_a : G1
|
||||
pi_a = spec.alpha1
|
||||
pi_a += r ** spec.delta1
|
||||
pi_a += msmG1( witness , pts.pointsA1 )
|
||||
if printTimings:
|
||||
let elapsed = cpuTime() - start
|
||||
echo("computing pi_A (G1 MSM) took ",seconds(elapsed))
|
||||
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
|
||||
pi_a = spec.alpha1
|
||||
pi_a += r ** spec.delta1
|
||||
pi_a += msmMultiThreadedG1( nthreads , witness , pts.pointsA1 )
|
||||
|
||||
start = cpuTime()
|
||||
var rho : G1
|
||||
rho = spec.beta1
|
||||
rho += s ** spec.delta1
|
||||
rho += msmG1( witness , pts.pointsB1 )
|
||||
if printTimings:
|
||||
let elapsed = cpuTime() - start
|
||||
echo("computing rho (G1 MSM) took ",seconds(elapsed))
|
||||
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
|
||||
rho = spec.beta1
|
||||
rho += s ** spec.delta1
|
||||
rho += msmMultiThreadedG1( nthreads , witness , pts.pointsB1 )
|
||||
|
||||
start = cpuTime()
|
||||
var pi_b : G2
|
||||
pi_b = spec.beta2
|
||||
pi_b += s ** spec.delta2
|
||||
pi_b += msmG2( witness , pts.pointsB2 )
|
||||
if printTimings:
|
||||
let elapsed = cpuTime() - start
|
||||
echo("computing pi_B (G2 MSM) took ",seconds(elapsed))
|
||||
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
|
||||
pi_b = spec.beta2
|
||||
pi_b += s ** spec.delta2
|
||||
pi_b += msmMultiThreadedG2( nthreads , witness , pts.pointsB2 )
|
||||
|
||||
start = cpuTime()
|
||||
var pi_c : G1
|
||||
pi_c = s ** pi_a
|
||||
pi_c += r ** rho
|
||||
pi_c += negFr(r*s) ** spec.delta1
|
||||
pi_c += msmG1( qs , pts.pointsH1 )
|
||||
pi_c += msmG1( zs , pts.pointsC1 )
|
||||
if printTimings:
|
||||
let elapsed = cpuTime() - start
|
||||
echo("computing pi_C (2x G1 MSM) took ",seconds(elapsed))
|
||||
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
|
||||
pi_c = s ** pi_a
|
||||
pi_c += r ** rho
|
||||
pi_c += negFr(r*s) ** spec.delta1
|
||||
pi_c += msmMultiThreadedG1( nthreads, qs , pts.pointsH1 )
|
||||
pi_c += msmMultiThreadedG1( nthreads, zs , pts.pointsC1 )
|
||||
|
||||
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc generateProofWithTrivialMask*( printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
|
||||
proc generateProofWithTrivialMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
|
||||
let mask = Mask(r: intToFr(0), s: intToFr(0))
|
||||
return generateProofWithMask( printTimings, zkey, wtns, mask )
|
||||
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
|
||||
|
||||
proc generateProof*( printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
|
||||
proc generateProof*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
|
||||
|
||||
# masking coeffs
|
||||
let r : Fr = randFr()
|
||||
let s : Fr = randFr()
|
||||
let mask = Mask(r: r, s: s)
|
||||
|
||||
return generateProofWithMask( printTimings, zkey, wtns, mask )
|
||||
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user