Fixes prover behavior with singleton proof trees (#859)

* add logs and test

* add Merkle proof checks

* factor out Circom input normalization, fix proof input serialization

* add test and update existing ones

* update circuit assets

* add back trace message

* switch contracts to fix branch

* update codex-contracts-eth to latest

* do not expose prove with prenormalized inputs
This commit is contained in:
Giuliano Mega 2024-07-18 10:25:06 -03:00 committed by GitHub
parent 8f740b42e6
commit fbce240e3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 111 additions and 87 deletions

View File

@ -9,17 +9,14 @@
{.push raises: [].} {.push raises: [].}
import std/sequtils import std/sugar
import pkg/chronos import pkg/chronos
import pkg/questionable/results import pkg/questionable/results
import pkg/circomcompat import pkg/circomcompat
import pkg/poseidon2/io
import ../../types import ../../types
import ../../../stores import ../../../stores
import ../../../merkletree
import ../../../codextypes
import ../../../contracts import ../../../contracts
import ./converters import ./converters
@ -39,6 +36,41 @@ type
backendCfg : ptr CircomBn254Cfg backendCfg : ptr CircomBn254Cfg
vkp* : ptr CircomKey vkp* : ptr CircomKey
NormalizedProofInputs*[H] {.borrow: `.`.} = distinct ProofInputs[H]
func normalizeInput*[H](self: CircomCompat, input: ProofInputs[H]):
NormalizedProofInputs[H] =
## Parameters in CIRCOM circuits are statically sized and must be properly
## padded before they can be passed onto the circuit. This function takes
## variable length parameters and performs that padding.
##
## The output from this function can be JSON-serialized and used as direct
## inputs to the CIRCOM circuit for testing and debugging when one wishes
## to bypass the Rust FFI.
let normSamples = collect:
for sample in input.samples:
var merklePaths = sample.merklePaths
merklePaths.setLen(self.slotDepth)
Sample[H](
cellData: sample.cellData,
merklePaths: merklePaths
)
var normSlotProof = input.slotProof
normSlotProof.setLen(self.datasetDepth)
NormalizedProofInputs[H] ProofInputs[H](
entropy: input.entropy,
datasetRoot: input.datasetRoot,
slotIndex: input.slotIndex,
slotRoot: input.slotRoot,
nCellsPerSlot: input.nCellsPerSlot,
nSlotsPerDataSet: input.nSlotsPerDataSet,
slotProof: normSlotProof,
samples: normSamples
)
proc release*(self: CircomCompat) = proc release*(self: CircomCompat) =
## Release the ctx ## Release the ctx
## ##
@ -49,27 +81,20 @@ proc release*(self: CircomCompat) =
if not isNil(self.vkp): if not isNil(self.vkp):
self.vkp.unsafeAddr.release_key() self.vkp.unsafeAddr.release_key()
proc prove*[H]( proc prove[H](
self: CircomCompat, self: CircomCompat,
input: ProofInputs[H]): ?!CircomProof = input: NormalizedProofInputs[H]): ?!CircomProof =
## Encode buffers using a ctx
##
# NOTE: All inputs are statically sized per circuit
# and adjusted accordingly right before being passed
# to the circom ffi - `setLen` is used to adjust the
# sequence length to the correct size which also 0 pads
# to the correct length
doAssert input.samples.len == self.numSamples, doAssert input.samples.len == self.numSamples,
"Number of samples does not match" "Number of samples does not match"
doAssert input.slotProof.len <= self.datasetDepth, doAssert input.slotProof.len <= self.datasetDepth,
"Number of slot proofs does not match" "Slot proof is too deep - dataset has more slots than what we can handle?"
doAssert input.samples.allIt( doAssert input.samples.allIt(
block: block:
(it.merklePaths.len <= self.slotDepth + self.blkDepth and (it.merklePaths.len <= self.slotDepth + self.blkDepth and
it.cellData.len <= self.cellElms * 32)), "Merkle paths length does not match" it.cellData.len == self.cellElms)), "Merkle paths too deep or cells too big for circuit"
# TODO: All parameters should match circom's static parametter # TODO: All parameters should match circom's static parametter
var var
@ -116,8 +141,7 @@ proc prove*[H](
var var
slotProof = input.slotProof.mapIt( it.toBytes ).concat slotProof = input.slotProof.mapIt( it.toBytes ).concat
slotProof.setLen(self.datasetDepth) # zero pad inputs to correct size doAssert(slotProof.len == self.datasetDepth)
# arrays are always flattened # arrays are always flattened
if ctx.pushInputU256Array( if ctx.pushInputU256Array(
"slotProof".cstring, "slotProof".cstring,
@ -128,16 +152,14 @@ proc prove*[H](
for s in input.samples: for s in input.samples:
var var
merklePaths = s.merklePaths.mapIt( it.toBytes ) merklePaths = s.merklePaths.mapIt( it.toBytes )
data = s.cellData data = s.cellData.mapIt( @(it.toBytes) ).concat
merklePaths.setLen(self.slotDepth) # zero pad inputs to correct size
if ctx.pushInputU256Array( if ctx.pushInputU256Array(
"merklePaths".cstring, "merklePaths".cstring,
merklePaths[0].addr, merklePaths[0].addr,
uint (merklePaths[0].len * merklePaths.len)) != ERR_OK: uint (merklePaths[0].len * merklePaths.len)) != ERR_OK:
return failure("Failed to push merkle paths") return failure("Failed to push merkle paths")
data.setLen(self.cellElms * 32) # zero pad inputs to correct size
if ctx.pushInputU256Array( if ctx.pushInputU256Array(
"cellData".cstring, "cellData".cstring,
data[0].addr, data[0].addr,
@ -162,6 +184,12 @@ proc prove*[H](
success proof success proof
proc prove*[H](
self: CircomCompat,
input: ProofInputs[H]): ?!CircomProof =
self.prove(self.normalizeInput(input))
proc verify*[H]( proc verify*[H](
self: CircomCompat, self: CircomCompat,
proof: CircomProof, proof: CircomProof,

View File

@ -38,7 +38,7 @@ type
func getCell*[T, H]( func getCell*[T, H](
self: DataSampler[T, H], self: DataSampler[T, H],
blkBytes: seq[byte], blkBytes: seq[byte],
blkCellIdx: Natural): seq[byte] = blkCellIdx: Natural): seq[H] =
let let
cellSize = self.builder.cellSize.uint64 cellSize = self.builder.cellSize.uint64
@ -47,7 +47,7 @@ func getCell*[T, H](
doAssert (dataEnd - dataStart) == cellSize, "Invalid cell size" doAssert (dataEnd - dataStart) == cellSize, "Invalid cell size"
toInputData[H](blkBytes[dataStart ..< dataEnd]) blkBytes[dataStart ..< dataEnd].elements(H).toSeq()
proc getSample*[T, H]( proc getSample*[T, H](
self: DataSampler[T, H], self: DataSampler[T, H],

View File

@ -7,23 +7,13 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import std/sugar
import std/bitops import std/bitops
import std/sequtils
import pkg/questionable/results import pkg/questionable/results
import pkg/poseidon2
import pkg/poseidon2/io
import pkg/constantine/math/arithmetic import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import ../../merkletree import ../../merkletree
func toInputData*[H](data: seq[byte]): seq[byte] =
return toSeq(data.elements(H)).mapIt( @(it.toBytes) ).concat
func extractLowBits*[n: static int](elm: BigInt[n], k: int): uint64 = func extractLowBits*[n: static int](elm: BigInt[n], k: int): uint64 =
doAssert( k > 0 and k <= 64 ) doAssert( k > 0 and k <= 64 )
var r = 0'u64 var r = 0'u64
@ -39,6 +29,7 @@ func extractLowBits(fld: Poseidon2Hash, k: int): uint64 =
return extractLowBits(elm, k); return extractLowBits(elm, k);
func floorLog2*(x : int) : int = func floorLog2*(x : int) : int =
doAssert ( x > 0 )
var k = -1 var k = -1
var y = x var y = x
while (y > 0): while (y > 0):
@ -47,10 +38,8 @@ func floorLog2*(x : int) : int =
return k return k
func ceilingLog2*(x : int) : int = func ceilingLog2*(x : int) : int =
if (x == 0): doAssert ( x > 0 )
return -1 return (floorLog2(x - 1) + 1)
else:
return (floorLog2(x-1) + 1)
func toBlkInSlot*(cell: Natural, numCells: Natural): Natural = func toBlkInSlot*(cell: Natural, numCells: Natural): Natural =
let log2 = ceilingLog2(numCells) let log2 = ceilingLog2(numCells)
@ -80,7 +69,7 @@ func cellIndices*(
numCells: Natural, nSamples: Natural): seq[Natural] = numCells: Natural, nSamples: Natural): seq[Natural] =
var indices: seq[Natural] var indices: seq[Natural]
while (indices.len < nSamples): for i in 1..nSamples:
let idx = cellIndex(entropy, slotRoot, numCells, indices.len + 1) indices.add(cellIndex(entropy, slotRoot, numCells, i))
indices.add(idx.Natural)
indices indices

View File

@ -9,7 +9,7 @@
type type
Sample*[H] = object Sample*[H] = object
cellData*: seq[byte] cellData*: seq[H]
merklePaths*: seq[H] merklePaths*: seq[H]
PublicInputs*[H] = object PublicInputs*[H] = object
@ -24,5 +24,5 @@ type
slotRoot*: H slotRoot*: H
nCellsPerSlot*: Natural nCellsPerSlot*: Natural
nSlotsPerDataSet*: Natural nSlotsPerDataSet*: Natural
slotProof*: seq[H] slotProof*: seq[H] # inclusion proof that shows that the slot root (leaf) is part of the dataset (root)
samples*: seq[Sample[H]] samples*: seq[Sample[H]] # inclusion proofs which show that the selected cells (leafs) are part of the slot (roots)

View File

@ -17,21 +17,6 @@ import pkg/codex/utils/json
export types export types
func fromCircomData*(_: type Poseidon2Hash, cellData: seq[byte]): seq[Poseidon2Hash] =
var
pos = 0
cellElms: seq[Bn254Fr]
while pos < cellData.len:
var
step = 32
offset = min(pos + step, cellData.len)
data = cellData[pos..<offset]
let ff = Bn254Fr.fromBytes(data.toArray32).get
cellElms.add(ff)
pos += data.len
cellElms
func toJsonDecimal*(big: BigInt[254]): string = func toJsonDecimal*(big: BigInt[254]): string =
let s = big.toDecimal.strip( leading = true, trailing = false, chars = {'0'} ) let s = big.toDecimal.strip( leading = true, trailing = false, chars = {'0'} )
if s.len == 0: "0" else: s if s.len == 0: "0" else: s
@ -78,13 +63,16 @@ func toJson*(input: ProofInputs[Poseidon2Hash]): JsonNode =
"slotRoot": input.slotRoot.toDecimal, "slotRoot": input.slotRoot.toDecimal,
"slotProof": input.slotProof.mapIt( it.toBig.toJsonDecimal ), "slotProof": input.slotProof.mapIt( it.toBig.toJsonDecimal ),
"cellData": input.samples.mapIt( "cellData": input.samples.mapIt(
toSeq( it.cellData.elements(Poseidon2Hash) ).mapIt( it.toBig.toJsonDecimal ) it.cellData.mapIt( it.toBig.toJsonDecimal )
), ),
"merklePaths": input.samples.mapIt( "merklePaths": input.samples.mapIt(
it.merklePaths.mapIt( it.toBig.toJsonDecimal ) it.merklePaths.mapIt( it.toBig.toJsonDecimal )
) )
} }
func toJson*(input: NormalizedProofInputs[Poseidon2Hash]): JsonNode =
toJson(ProofInputs[Poseidon2Hash](input))
func jsonToProofInput*(_: type Poseidon2Hash, inputJson: JsonNode): ProofInputs[Poseidon2Hash] = func jsonToProofInput*(_: type Poseidon2Hash, inputJson: JsonNode): ProofInputs[Poseidon2Hash] =
let let
cellData = cellData =
@ -93,10 +81,12 @@ func jsonToProofInput*(_: type Poseidon2Hash, inputJson: JsonNode): ProofInputs[
block: block:
var var
big: BigInt[256] big: BigInt[256]
data = newSeq[byte](big.bits div 8) hash: Poseidon2Hash
data: array[32, byte]
assert bool(big.fromDecimal( it.str )) assert bool(big.fromDecimal( it.str ))
data.marshal(big, littleEndian) assert data.marshal(big, littleEndian)
data
Poseidon2Hash.fromBytes(data).get
).concat # flatten out elements ).concat # flatten out elements
) )

View File

@ -58,7 +58,7 @@ suite "Test Sampler - control samples":
proofInput.nCellsPerSlot, proofInput.nCellsPerSlot,
sample.merklePaths[5..<9]).tryGet sample.merklePaths[5..<9]).tryGet
cellData = Poseidon2Hash.fromCircomData(sample.cellData) cellData = sample.cellData
cellLeaf = Poseidon2Hash.spongeDigest(cellData, rate = 2).tryGet cellLeaf = Poseidon2Hash.spongeDigest(cellData, rate = 2).tryGet
slotLeaf = cellProof.reconstructRoot(cellLeaf).tryGet slotLeaf = cellProof.reconstructRoot(cellLeaf).tryGet
@ -158,7 +158,7 @@ suite "Test Sampler":
nSlotCells, nSlotCells,
sample.merklePaths[5..<sample.merklePaths.len]).tryGet sample.merklePaths[5..<sample.merklePaths.len]).tryGet
cellData = Poseidon2Hash.fromCircomData(sample.cellData) cellData = sample.cellData
cellLeaf = Poseidon2Hash.spongeDigest(cellData, rate = 2).tryGet cellLeaf = Poseidon2Hash.spongeDigest(cellData, rate = 2).tryGet
slotLeaf = cellProof.reconstructRoot(cellLeaf).tryGet slotLeaf = cellProof.reconstructRoot(cellLeaf).tryGet

View File

@ -24,23 +24,19 @@ import ./backends/helpers
suite "Test Prover": suite "Test Prover":
let let
slotId = 1
samples = 5 samples = 5
ecK = 3
ecM = 2
numDatasetBlocks = 8
blockSize = DefaultBlockSize blockSize = DefaultBlockSize
cellSize = DefaultCellSize cellSize = DefaultCellSize
repoTmp = TempLevelDb.new() repoTmp = TempLevelDb.new()
metaTmp = TempLevelDb.new() metaTmp = TempLevelDb.new()
r1cs = "tests/circuits/fixtures/proof_main.r1cs"
wasm = "tests/circuits/fixtures/proof_main.wasm"
circomBackend = CircomCompat.init(r1cs, wasm)
challenge = 1234567.toF.toBytes.toArray32
var var
datasetBlocks: seq[bt.Block]
store: BlockStore store: BlockStore
manifest: Manifest prover: Prover
protected: Manifest
verifiable: Manifest
sampler: Poseidon2Sampler
setup: setup:
let let
@ -48,14 +44,7 @@ suite "Test Prover":
metaDs = metaTmp.newDb() metaDs = metaTmp.newDb()
store = RepoStore.new(repoDs, metaDs) store = RepoStore.new(repoDs, metaDs)
prover = Prover.new(store, circomBackend, samples)
(manifest, protected, verifiable) =
await createVerifiableManifest(
store,
numDatasetBlocks,
ecK, ecM,
blockSize,
cellSize)
teardown: teardown:
await repoTmp.destroyDb() await repoTmp.destroyDb()
@ -63,13 +52,41 @@ suite "Test Prover":
test "Should sample and prove a slot": test "Should sample and prove a slot":
let let
r1cs = "tests/circuits/fixtures/proof_main.r1cs" (_, _, verifiable) =
wasm = "tests/circuits/fixtures/proof_main.wasm" await createVerifiableManifest(
store,
8, # number of blocks in the original dataset (before EC)
5, # ecK
3, # ecM
blockSize,
cellSize)
circomBackend = CircomCompat.init(r1cs, wasm) let
prover = Prover.new(store, circomBackend, samples) (inputs, proof) = (
challenge = 1234567.toF.toBytes.toArray32 await prover.prove(1, verifiable, challenge)).tryGet
(inputs, proof) = (await prover.prove(1, verifiable, challenge)).tryGet
check:
(await prover.verify(proof, inputs)).tryGet == true
test "Should generate valid proofs when slots consist of single blocks":
# To get single-block slots, we just need to set the number of blocks in
# the original dataset to be the same as ecK. The total number of blocks
# after generating random data for parity will be ecK + ecM, which will
# match the number of slots.
let
(_, _, verifiable) =
await createVerifiableManifest(
store,
2, # number of blocks in the original dataset (before EC)
2, # ecK
1, # ecM
blockSize,
cellSize)
let
(inputs, proof) = (
await prover.prove(1, verifiable, challenge)).tryGet
check: check:
(await prover.verify(proof, inputs)).tryGet == true (await prover.verify(proof, inputs)).tryGet == true

@ -1 +1 @@
Subproject commit 57e8cd5013325f05e16833a5320b575d32a403f3 Subproject commit 7ad26688a3b75b914d626e2623174a36f4425f51