mirror of
https://github.com/logos-storage/logos-storage-nim.git
synced 2026-01-05 15:03:07 +00:00
323 lines
9.7 KiB
Nim
323 lines
9.7 KiB
Nim
## Nim-Codex
|
|
## Copyright (c) 2024 Status Research & Development GmbH
|
|
## Licensed under either of
|
|
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
|
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
|
## at your option.
|
|
## This file may not be copied, modified, or distributed except according to
|
|
## those terms.
|
|
|
|
{.push raises: [].}
|
|
|
|
import std/[sugar, atomics, locks]
|
|
|
|
import pkg/chronos
|
|
import pkg/taskpools
|
|
import pkg/chronos/threadsync
|
|
import pkg/questionable/results
|
|
import pkg/circomcompat
|
|
|
|
import ../../types
|
|
import ../../../stores
|
|
import ../../../contracts
|
|
|
|
import ./converters
|
|
|
|
export circomcompat, converters
|
|
export taskpools
|
|
|
|
type
|
|
CircomCompat* = object
|
|
slotDepth: int # max depth of the slot tree
|
|
datasetDepth: int # max depth of dataset tree
|
|
blkDepth: int # depth of the block merkle tree (pow2 for now)
|
|
cellElms: int # number of field elements per cell
|
|
numSamples: int # number of samples per slot
|
|
r1csPath: string # path to the r1cs file
|
|
wasmPath: string # path to the wasm file
|
|
zkeyPath: string # path to the zkey file
|
|
backendCfg: ptr CircomBn254Cfg
|
|
vkp*: ptr CircomKey
|
|
taskpool: Taskpool
|
|
|
|
NormalizedProofInputs*[H] {.borrow: `.`.} = distinct ProofInputs[H]
|
|
|
|
ProveTask = object
|
|
circom: ptr CircomCompat
|
|
ctx: ptr CircomCompatCtx
|
|
proofPtr: ProofPtr
|
|
res: Atomic[int32]
|
|
signal: ThreadSignalPtr
|
|
|
|
VerifyTask = object
|
|
proof: ptr CircomProof
|
|
vkp: ptr CircomKey
|
|
inputs: ptr CircomInputs
|
|
res: Atomic[int32]
|
|
signal: ThreadSignalPtr
|
|
|
|
func normalizeInput*[H](
|
|
self: CircomCompat, input: ProofInputs[H]
|
|
): NormalizedProofInputs[H] =
|
|
## Parameters in CIRCOM circuits are statically sized and must be properly
|
|
## padded before they can be passed onto the circuit. This function takes
|
|
## variable length parameters and performs that padding.
|
|
##
|
|
## The output from this function can be JSON-serialized and used as direct
|
|
## inputs to the CIRCOM circuit for testing and debugging when one wishes
|
|
## to bypass the Rust FFI.
|
|
|
|
let normSamples = collect:
|
|
for sample in input.samples:
|
|
var merklePaths = sample.merklePaths
|
|
merklePaths.setLen(self.slotDepth)
|
|
Sample[H](cellData: sample.cellData, merklePaths: merklePaths)
|
|
|
|
var normSlotProof = input.slotProof
|
|
normSlotProof.setLen(self.datasetDepth)
|
|
|
|
NormalizedProofInputs[H] ProofInputs[H](
|
|
entropy: input.entropy,
|
|
datasetRoot: input.datasetRoot,
|
|
slotIndex: input.slotIndex,
|
|
slotRoot: input.slotRoot,
|
|
nCellsPerSlot: input.nCellsPerSlot,
|
|
nSlotsPerDataSet: input.nSlotsPerDataSet,
|
|
slotProof: normSlotProof,
|
|
samples: normSamples,
|
|
)
|
|
|
|
proc release*(self: CircomCompat) =
|
|
## Release the ctx
|
|
##
|
|
|
|
if not isNil(self.backendCfg):
|
|
self.backendCfg.unsafeAddr.release_cfg()
|
|
|
|
if not isNil(self.vkp):
|
|
self.vkp.unsafeAddr.release_key()
|
|
|
|
proc circomProveTask(task: ptr ProveTask) {.gcsafe.} =
|
|
defer:
|
|
discard task[].signal.fireSync()
|
|
|
|
let res = task.circom.backendCfg.prove_circuit(task[].ctx, task[].proofPtr.addr)
|
|
task.res.store(res)
|
|
|
|
proc prove*[H](
|
|
self: CircomCompat, input: ProofInputs[H]
|
|
): Future[?!CircomProof] {.async, raises: [CancelledError].} =
|
|
## Prove a circuit using a ctx
|
|
##
|
|
|
|
var input = self.normalizeInput(input)
|
|
doAssert input.samples.len == self.numSamples, "Number of samples does not match"
|
|
|
|
doAssert input.slotProof.len <= self.datasetDepth,
|
|
"Slot proof is too deep - dataset has more slots than what we can handle?"
|
|
|
|
doAssert input.samples.allIt(
|
|
block:
|
|
(
|
|
it.merklePaths.len <= self.slotDepth + self.blkDepth and
|
|
it.cellData.len == self.cellElms
|
|
)
|
|
), "Merkle paths too deep or cells too big for circuit"
|
|
|
|
# TODO: All parameters should match circom's static parametter
|
|
var ctx: ptr CircomCompatCtx
|
|
|
|
defer:
|
|
if ctx != nil:
|
|
ctx.addr.release_circom_compat()
|
|
|
|
if init_circom_compat(self.backendCfg, addr ctx) != ERR_OK or ctx == nil:
|
|
raiseAssert("failed to initialize CircomCompat ctx")
|
|
|
|
var
|
|
entropy = input.entropy.toBytes
|
|
dataSetRoot = input.datasetRoot.toBytes
|
|
slotRoot = input.slotRoot.toBytes
|
|
|
|
if ctx.push_input_u256_array("entropy".cstring, entropy[0].addr, entropy.len.uint32) !=
|
|
ERR_OK:
|
|
return failure("Failed to push entropy")
|
|
|
|
if ctx.push_input_u256_array(
|
|
"dataSetRoot".cstring, dataSetRoot[0].addr, dataSetRoot.len.uint32
|
|
) != ERR_OK:
|
|
return failure("Failed to push data set root")
|
|
|
|
if ctx.push_input_u256_array(
|
|
"slotRoot".cstring, slotRoot[0].addr, slotRoot.len.uint32
|
|
) != ERR_OK:
|
|
return failure("Failed to push data set root")
|
|
|
|
if ctx.push_input_u32("nCellsPerSlot".cstring, input.nCellsPerSlot.uint32) != ERR_OK:
|
|
return failure("Failed to push nCellsPerSlot")
|
|
|
|
if ctx.push_input_u32("nSlotsPerDataSet".cstring, input.nSlotsPerDataSet.uint32) !=
|
|
ERR_OK:
|
|
return failure("Failed to push nSlotsPerDataSet")
|
|
|
|
if ctx.push_input_u32("slotIndex".cstring, input.slotIndex.uint32) != ERR_OK:
|
|
return failure("Failed to push slotIndex")
|
|
|
|
var slotProof = input.slotProof.mapIt(it.toBytes).concat
|
|
|
|
doAssert(slotProof.len == self.datasetDepth)
|
|
# arrays are always flattened
|
|
if ctx.push_input_u256_array(
|
|
"slotProof".cstring, slotProof[0].addr, uint (slotProof[0].len * slotProof.len)
|
|
) != ERR_OK:
|
|
return failure("Failed to push slot proof")
|
|
|
|
for s in input.samples:
|
|
var
|
|
merklePaths = s.merklePaths.mapIt(@(it.toBytes)).concat
|
|
data = s.cellData.mapIt(@(it.toBytes)).concat
|
|
|
|
if ctx.push_input_u256_array(
|
|
"merklePaths".cstring, merklePaths[0].addr, uint (merklePaths.len)
|
|
) != ERR_OK:
|
|
return failure("Failed to push merkle paths")
|
|
|
|
if ctx.push_input_u256_array("cellData".cstring, data[0].addr, data.len.uint) !=
|
|
ERR_OK:
|
|
return failure("Failed to push cell data")
|
|
|
|
without threadPtr =? ThreadSignalPtr.new().mapFailure, err:
|
|
trace "Failed to create thread signal", err = err.msg
|
|
return failure("Unable to create thread signal")
|
|
|
|
var task = ProveTask(circom: addr self, ctx: ctx, signal: threadPtr)
|
|
|
|
defer:
|
|
threadPtr.close().expect("closing once works")
|
|
if task.proofPtr != nil:
|
|
task.proofPtr.addr.release_proof()
|
|
|
|
doAssert task.circom.taskpool.numThreads > 1,
|
|
"Must have at least one separate thread or signal will never be fired"
|
|
task.circom.taskpool.spawn circomProveTask(addr task)
|
|
let threadFut = threadPtr.wait()
|
|
|
|
if joinErr =? catch(await threadFut.join()).errorOption:
|
|
if err =? catch(await noCancel threadFut).errorOption:
|
|
trace "Failed to prove circuit", err = err.msg
|
|
return failure(err)
|
|
if joinErr of CancelledError:
|
|
trace "Cancellation requested for proving circuit, re-raising"
|
|
raise joinErr
|
|
else:
|
|
return failure(joinErr)
|
|
|
|
if task.res.load() != ERR_OK:
|
|
return failure("Failed to prove circuit")
|
|
|
|
success(task.proofPtr[])
|
|
|
|
proc circomVerifyTask(task: ptr VerifyTask) {.gcsafe.} =
|
|
defer:
|
|
task[].inputs[].releaseCircomInputs()
|
|
discard task[].signal.fireSync()
|
|
|
|
let res = verify_circuit(task[].proof, task[].inputs, task[].vkp)
|
|
task.res.store(res)
|
|
|
|
proc verify*[H](
|
|
self: CircomCompat, proof: CircomProof, inputs: ProofInputs[H]
|
|
): Future[?!bool] {.async, raises: [CancelledError].} =
|
|
## Verify a proof using a ctx
|
|
##
|
|
try:
|
|
var
|
|
proofPtr = unsafeAddr proof
|
|
inputs = inputs.toCircomInputs()
|
|
|
|
without threadPtr =? ThreadSignalPtr.new().mapFailure, err:
|
|
trace "Failed to create thread signal", err = err.msg
|
|
return failure("Unable to create thread signal")
|
|
|
|
defer:
|
|
threadPtr.close().expect("closing once works")
|
|
inputs.releaseCircomInputs()
|
|
|
|
var task =
|
|
VerifyTask(proof: proofPtr, vkp: self.vkp, inputs: addr inputs, signal: threadPtr)
|
|
|
|
doAssert self.taskpool.numThreads > 1,
|
|
"Must have at least one separate thread or signal will never be fired"
|
|
|
|
self.taskpool.spawn circomVerifyTask(addr task)
|
|
|
|
let threadFut = threadPtr.wait()
|
|
|
|
if joinErr =? catch(await threadFut.join()).errorOption:
|
|
if err =? catch(await noCancel threadFut).errorOption:
|
|
trace "Error verifying proof", err = err.msg
|
|
return failure(err)
|
|
if joinErr of CancelledError:
|
|
trace "Cancellation requested for verifying proof, re-raising"
|
|
raise joinErr
|
|
else:
|
|
return failure(joinErr)
|
|
|
|
let res = task.res.load()
|
|
case res
|
|
of ERR_FAILED_TO_VERIFY_PROOF:
|
|
trace "Failed to verify proof", res
|
|
return success(false)
|
|
of ERR_OK:
|
|
return success(true)
|
|
else:
|
|
trace "Unknown error verifying proof", res
|
|
return failure("Unknown error")
|
|
except CancelledError as exc:
|
|
raise exc
|
|
|
|
proc init*(
|
|
_: type CircomCompat,
|
|
r1csPath: string,
|
|
wasmPath: string,
|
|
zkeyPath: string = "",
|
|
slotDepth = DefaultMaxSlotDepth,
|
|
datasetDepth = DefaultMaxDatasetDepth,
|
|
blkDepth = DefaultBlockDepth,
|
|
cellElms = DefaultCellElms,
|
|
numSamples = DefaultSamplesNum,
|
|
taskpool: Taskpool,
|
|
): CircomCompat =
|
|
## Create a new ctx
|
|
##
|
|
|
|
var cfg: ptr CircomBn254Cfg
|
|
var zkey = if zkeyPath.len > 0: zkeyPath.cstring else: nil
|
|
|
|
if init_circom_config(r1csPath.cstring, wasmPath.cstring, zkey, cfg.addr) != ERR_OK or
|
|
cfg == nil:
|
|
if cfg != nil:
|
|
cfg.addr.release_cfg()
|
|
raiseAssert("failed to initialize circom compat config")
|
|
|
|
var vkpPtr: ptr VerifyingKey = nil
|
|
|
|
if cfg.get_verifying_key(vkpPtr.addr) != ERR_OK or vkpPtr == nil:
|
|
if vkpPtr != nil:
|
|
vkpPtr.addr.release_key()
|
|
raiseAssert("Failed to get verifying key")
|
|
|
|
CircomCompat(
|
|
r1csPath: r1csPath,
|
|
wasmPath: wasmPath,
|
|
zkeyPath: zkeyPath,
|
|
slotDepth: slotDepth,
|
|
datasetDepth: datasetDepth,
|
|
blkDepth: blkDepth,
|
|
cellElms: cellElms,
|
|
numSamples: numSamples,
|
|
backendCfg: cfg,
|
|
vkp: vkpPtr,
|
|
taskpool: taskpool,
|
|
)
|