import circom helpers
This commit is contained in:
parent
e39b5ef694
commit
018be61ebc
|
@ -13,15 +13,15 @@ import pkg/constantine/math/[arithmetic, io/io_bigints, io/io_fields]
|
||||||
import ./utils
|
import ./utils
|
||||||
import ./create_circuits
|
import ./create_circuits
|
||||||
|
|
||||||
type CircuitFiles* = object
|
type CircomCircuit* = object
|
||||||
r1cs*: string
|
r1cs*: string
|
||||||
wasm*: string
|
wasm*: string
|
||||||
zkey*: string
|
zkey*: string
|
||||||
inputs*: string
|
inputs*: string
|
||||||
dir*: string
|
dir*: string
|
||||||
circName*: string
|
circName*: string
|
||||||
backendCfg : ptr CircomBn254Cfg
|
backendCfg: ptr CircomBn254Cfg
|
||||||
vkp* : ptr VerifyingKey
|
vkp*: ptr VerifyingKey
|
||||||
|
|
||||||
proc release*(self: CircomCompat) =
|
proc release*(self: CircomCompat) =
|
||||||
## Release the ctx
|
## Release the ctx
|
||||||
|
@ -33,9 +33,7 @@ proc release*(self: CircomCompat) =
|
||||||
if not isNil(self.vkp):
|
if not isNil(self.vkp):
|
||||||
self.vkp.unsafeAddr.release_key()
|
self.vkp.unsafeAddr.release_key()
|
||||||
|
|
||||||
proc prove*[H](
|
proc prove*[H](self: CircomCompat, input: ProofInputs[H]): ?!CircomProof =
|
||||||
self: CircomCompat,
|
|
||||||
input: ProofInputs[H]): ?!CircomProof =
|
|
||||||
## Encode buffers using a ctx
|
## Encode buffers using a ctx
|
||||||
##
|
##
|
||||||
|
|
||||||
|
@ -44,28 +42,27 @@ proc prove*[H](
|
||||||
# to the circom ffi - `setLen` is used to adjust the
|
# to the circom ffi - `setLen` is used to adjust the
|
||||||
# sequence length to the correct size which also 0 pads
|
# sequence length to the correct size which also 0 pads
|
||||||
# to the correct length
|
# to the correct length
|
||||||
doAssert input.samples.len == self.numSamples,
|
doAssert input.samples.len == self.numSamples, "Number of samples does not match"
|
||||||
"Number of samples does not match"
|
|
||||||
|
|
||||||
doAssert input.slotProof.len <= self.datasetDepth,
|
doAssert input.slotProof.len <= self.datasetDepth,
|
||||||
"Number of slot proofs does not match"
|
"Number of slot proofs does not match"
|
||||||
|
|
||||||
doAssert input.samples.allIt(
|
doAssert input.samples.allIt(
|
||||||
block:
|
block:
|
||||||
(it.merklePaths.len <= self.slotDepth + self.blkDepth and
|
(
|
||||||
it.cellData.len <= self.cellElms * 32)), "Merkle paths length does not match"
|
it.merklePaths.len <= self.slotDepth + self.blkDepth and
|
||||||
|
it.cellData.len <= self.cellElms * 32
|
||||||
|
)
|
||||||
|
), "Merkle paths length does not match"
|
||||||
|
|
||||||
# TODO: All parameters should match circom's static parametter
|
# TODO: All parameters should match circom's static parametter
|
||||||
var
|
var ctx: ptr CircomCompatCtx
|
||||||
ctx: ptr CircomCompatCtx
|
|
||||||
|
|
||||||
defer:
|
defer:
|
||||||
if ctx != nil:
|
if ctx != nil:
|
||||||
ctx.addr.releaseCircomCompat()
|
ctx.addr.releaseCircomCompat()
|
||||||
|
|
||||||
if initCircomCompat(
|
if initCircomCompat(self.backendCfg, addr ctx) != ERR_OK or ctx == nil:
|
||||||
self.backendCfg,
|
|
||||||
addr ctx) != ERR_OK or ctx == nil:
|
|
||||||
raiseAssert("failed to initialize CircomCompat ctx")
|
raiseAssert("failed to initialize CircomCompat ctx")
|
||||||
|
|
||||||
var
|
var
|
||||||
|
@ -73,69 +70,61 @@ proc prove*[H](
|
||||||
dataSetRoot = input.datasetRoot.toBytes
|
dataSetRoot = input.datasetRoot.toBytes
|
||||||
slotRoot = input.slotRoot.toBytes
|
slotRoot = input.slotRoot.toBytes
|
||||||
|
|
||||||
if ctx.pushInputU256Array(
|
if ctx.pushInputU256Array("entropy".cstring, entropy[0].addr, entropy.len.uint32) !=
|
||||||
"entropy".cstring, entropy[0].addr, entropy.len.uint32) != ERR_OK:
|
ERR_OK:
|
||||||
return failure("Failed to push entropy")
|
return failure("Failed to push entropy")
|
||||||
|
|
||||||
if ctx.pushInputU256Array(
|
if ctx.pushInputU256Array(
|
||||||
"dataSetRoot".cstring, dataSetRoot[0].addr, dataSetRoot.len.uint32) != ERR_OK:
|
"dataSetRoot".cstring, dataSetRoot[0].addr, dataSetRoot.len.uint32
|
||||||
|
) != ERR_OK:
|
||||||
return failure("Failed to push data set root")
|
return failure("Failed to push data set root")
|
||||||
|
|
||||||
if ctx.pushInputU256Array(
|
if ctx.pushInputU256Array("slotRoot".cstring, slotRoot[0].addr, slotRoot.len.uint32) !=
|
||||||
"slotRoot".cstring, slotRoot[0].addr, slotRoot.len.uint32) != ERR_OK:
|
ERR_OK:
|
||||||
return failure("Failed to push data set root")
|
return failure("Failed to push data set root")
|
||||||
|
|
||||||
if ctx.pushInputU32(
|
if ctx.pushInputU32("nCellsPerSlot".cstring, input.nCellsPerSlot.uint32) != ERR_OK:
|
||||||
"nCellsPerSlot".cstring, input.nCellsPerSlot.uint32) != ERR_OK:
|
|
||||||
return failure("Failed to push nCellsPerSlot")
|
return failure("Failed to push nCellsPerSlot")
|
||||||
|
|
||||||
if ctx.pushInputU32(
|
if ctx.pushInputU32("nSlotsPerDataSet".cstring, input.nSlotsPerDataSet.uint32) !=
|
||||||
"nSlotsPerDataSet".cstring, input.nSlotsPerDataSet.uint32) != ERR_OK:
|
ERR_OK:
|
||||||
return failure("Failed to push nSlotsPerDataSet")
|
return failure("Failed to push nSlotsPerDataSet")
|
||||||
|
|
||||||
if ctx.pushInputU32(
|
if ctx.pushInputU32("slotIndex".cstring, input.slotIndex.uint32) != ERR_OK:
|
||||||
"slotIndex".cstring, input.slotIndex.uint32) != ERR_OK:
|
|
||||||
return failure("Failed to push slotIndex")
|
return failure("Failed to push slotIndex")
|
||||||
|
|
||||||
var
|
var slotProof = input.slotProof.mapIt(it.toBytes).concat
|
||||||
slotProof = input.slotProof.mapIt( it.toBytes ).concat
|
|
||||||
|
|
||||||
slotProof.setLen(self.datasetDepth) # zero pad inputs to correct size
|
slotProof.setLen(self.datasetDepth) # zero pad inputs to correct size
|
||||||
|
|
||||||
# arrays are always flattened
|
# arrays are always flattened
|
||||||
if ctx.pushInputU256Array(
|
if ctx.pushInputU256Array(
|
||||||
"slotProof".cstring,
|
"slotProof".cstring, slotProof[0].addr, uint (slotProof[0].len * slotProof.len)
|
||||||
slotProof[0].addr,
|
) != ERR_OK:
|
||||||
uint (slotProof[0].len * slotProof.len)) != ERR_OK:
|
|
||||||
return failure("Failed to push slot proof")
|
return failure("Failed to push slot proof")
|
||||||
|
|
||||||
for s in input.samples:
|
for s in input.samples:
|
||||||
var
|
var
|
||||||
merklePaths = s.merklePaths.mapIt( it.toBytes )
|
merklePaths = s.merklePaths.mapIt(it.toBytes)
|
||||||
data = s.cellData
|
data = s.cellData
|
||||||
|
|
||||||
merklePaths.setLen(self.slotDepth) # zero pad inputs to correct size
|
merklePaths.setLen(self.slotDepth) # zero pad inputs to correct size
|
||||||
if ctx.pushInputU256Array(
|
if ctx.pushInputU256Array(
|
||||||
"merklePaths".cstring,
|
"merklePaths".cstring,
|
||||||
merklePaths[0].addr,
|
merklePaths[0].addr,
|
||||||
uint (merklePaths[0].len * merklePaths.len)) != ERR_OK:
|
uint (merklePaths[0].len * merklePaths.len),
|
||||||
|
) != ERR_OK:
|
||||||
return failure("Failed to push merkle paths")
|
return failure("Failed to push merkle paths")
|
||||||
|
|
||||||
data.setLen(self.cellElms * 32) # zero pad inputs to correct size
|
data.setLen(self.cellElms * 32) # zero pad inputs to correct size
|
||||||
if ctx.pushInputU256Array(
|
if ctx.pushInputU256Array("cellData".cstring, data[0].addr, data.len.uint) != ERR_OK:
|
||||||
"cellData".cstring,
|
|
||||||
data[0].addr,
|
|
||||||
data.len.uint) != ERR_OK:
|
|
||||||
return failure("Failed to push cell data")
|
return failure("Failed to push cell data")
|
||||||
|
|
||||||
var
|
var proofPtr: ptr Proof = nil
|
||||||
proofPtr: ptr Proof = nil
|
|
||||||
|
|
||||||
let proof =
|
let proof =
|
||||||
try:
|
try:
|
||||||
if (
|
if (let res = self.backendCfg.proveCircuit(ctx, proofPtr.addr); res != ERR_OK) or
|
||||||
let res = self.backendCfg.proveCircuit(ctx, proofPtr.addr);
|
|
||||||
res != ERR_OK) or
|
|
||||||
proofPtr == nil:
|
proofPtr == nil:
|
||||||
return failure("Failed to prove - err code: " & $res)
|
return failure("Failed to prove - err code: " & $res)
|
||||||
|
|
||||||
|
@ -146,10 +135,20 @@ proc prove*[H](
|
||||||
|
|
||||||
success proof
|
success proof
|
||||||
|
|
||||||
proc verify*[H](
|
proc toCircomInputs*(inputs: ProofInputs[Poseidon2Hash]): Inputs =
|
||||||
self: CircomCompat,
|
var
|
||||||
proof: CircomProof,
|
slotIndex = inputs.slotIndex.toF.toBytes.toArray32
|
||||||
inputs: ProofInputs[H]): ?!bool =
|
datasetRoot = inputs.datasetRoot.toBytes.toArray32
|
||||||
|
entropy = inputs.entropy.toBytes.toArray32
|
||||||
|
|
||||||
|
elms = [entropy, datasetRoot, slotIndex]
|
||||||
|
|
||||||
|
let inputsPtr = allocShared0(32 * elms.len)
|
||||||
|
copyMem(inputsPtr, addr elms[0], elms.len * 32)
|
||||||
|
|
||||||
|
CircomInputs(elms: cast[ptr array[32, byte]](inputsPtr), len: elms.len.uint)
|
||||||
|
|
||||||
|
proc verify*(self: CircomCompat, proof: CircomProof, inputs: ProofInputs[H]): ?!bool =
|
||||||
## Verify a proof using a ctx
|
## Verify a proof using a ctx
|
||||||
##
|
##
|
||||||
|
|
||||||
|
@ -170,48 +169,49 @@ proc verify*[H](
|
||||||
|
|
||||||
proc init*(
|
proc init*(
|
||||||
_: type CircomCompat,
|
_: type CircomCompat,
|
||||||
r1csPath : string,
|
r1csPath: string,
|
||||||
wasmPath : string,
|
wasmPath: string,
|
||||||
zkeyPath : string = "",
|
zkeyPath: string = "",
|
||||||
slotDepth = DefaultMaxSlotDepth,
|
slotDepth = DefaultMaxSlotDepth,
|
||||||
datasetDepth = DefaultMaxDatasetDepth,
|
datasetDepth = DefaultMaxDatasetDepth,
|
||||||
blkDepth = DefaultBlockDepth,
|
blkDepth = DefaultBlockDepth,
|
||||||
cellElms = DefaultCellElms,
|
cellElms = DefaultCellElms,
|
||||||
numSamples = DefaultSamplesNum): CircomCompat =
|
numSamples = DefaultSamplesNum,
|
||||||
|
): CircomCompat =
|
||||||
## Create a new ctx
|
## Create a new ctx
|
||||||
##
|
##
|
||||||
|
|
||||||
var cfg: ptr CircomBn254Cfg
|
var cfg: ptr CircomBn254Cfg
|
||||||
var zkey = if zkeyPath.len > 0: zkeyPath.cstring else: nil
|
var zkey = if zkeyPath.len > 0: zkeyPath.cstring else: nil
|
||||||
|
|
||||||
if initCircomConfig(
|
if initCircomConfig(r1csPath.cstring, wasmPath.cstring, zkey, cfg.addr) != ERR_OK or
|
||||||
r1csPath.cstring,
|
cfg == nil:
|
||||||
wasmPath.cstring,
|
if cfg != nil:
|
||||||
zkey, cfg.addr) != ERR_OK or cfg == nil:
|
cfg.addr.releaseCfg()
|
||||||
if cfg != nil: cfg.addr.releaseCfg()
|
|
||||||
raiseAssert("failed to initialize circom compat config")
|
raiseAssert("failed to initialize circom compat config")
|
||||||
|
|
||||||
var
|
var vkpPtr: ptr VerifyingKey = nil
|
||||||
vkpPtr: ptr VerifyingKey = nil
|
|
||||||
|
|
||||||
if cfg.getVerifyingKey(vkpPtr.addr) != ERR_OK or vkpPtr == nil:
|
if cfg.getVerifyingKey(vkpPtr.addr) != ERR_OK or vkpPtr == nil:
|
||||||
if vkpPtr != nil: vkpPtr.addr.releaseKey()
|
if vkpPtr != nil:
|
||||||
|
vkpPtr.addr.releaseKey()
|
||||||
raiseAssert("Failed to get verifying key")
|
raiseAssert("Failed to get verifying key")
|
||||||
|
|
||||||
CircomCompat(
|
CircomCompat(
|
||||||
r1csPath : r1csPath,
|
r1csPath: r1csPath,
|
||||||
wasmPath : wasmPath,
|
wasmPath: wasmPath,
|
||||||
zkeyPath : zkeyPath,
|
zkeyPath: zkeyPath,
|
||||||
slotDepth : slotDepth,
|
slotDepth: slotDepth,
|
||||||
datasetDepth: datasetDepth,
|
datasetDepth: datasetDepth,
|
||||||
blkDepth : blkDepth,
|
blkDepth: blkDepth,
|
||||||
cellElms : cellElms,
|
cellElms: cellElms,
|
||||||
numSamples : numSamples,
|
numSamples: numSamples,
|
||||||
backendCfg : cfg,
|
backendCfg: cfg,
|
||||||
vkp : vkpPtr)
|
vkp: vkpPtr,
|
||||||
|
)
|
||||||
|
|
||||||
proc runArkCircom(
|
proc runArkCircom(
|
||||||
args: CircuitArgs, files: CircuitFiles, proofInputs: ProofInputs[Poseidon2Hash]
|
args: CircuitArgs, files: CircomCircuit, proofInputs: ProofInputs[Poseidon2Hash]
|
||||||
) =
|
) =
|
||||||
echo "Loading sample proof..."
|
echo "Loading sample proof..."
|
||||||
var circom = CircomCompat.init(
|
var circom = CircomCompat.init(
|
||||||
|
@ -250,7 +250,7 @@ proc printHelp() =
|
||||||
|
|
||||||
quit(1)
|
quit(1)
|
||||||
|
|
||||||
proc parseCliOptions(args: var CircuitArgs, files: var CircuitFiles) =
|
proc parseCliOptions(args: var CircuitArgs, files: var CircomCircuit) =
|
||||||
var argCtr: int = 0
|
var argCtr: int = 0
|
||||||
template expectPath(val: string): string =
|
template expectPath(val: string): string =
|
||||||
if val == "":
|
if val == "":
|
||||||
|
@ -299,7 +299,7 @@ proc run*() =
|
||||||
|
|
||||||
var
|
var
|
||||||
args = CircuitArgs()
|
args = CircuitArgs()
|
||||||
files = CircuitFiles()
|
files = CircomCircuit()
|
||||||
|
|
||||||
parseCliOptions(args, files)
|
parseCliOptions(args, files)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue