proof input generation reference impl. in Nim (WIP, untested!)

This commit is contained in:
Balazs Komuves 2023-11-17 18:17:26 +01:00
parent fad92b6c75
commit 9e61c14e5d
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
11 changed files with 631 additions and 0 deletions

3
reference/nim/proof_input/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
.DS_Store
testmain
*.json

View File

@ -0,0 +1,11 @@
version = "0.0.1"
author = "Balazs Komuves"
description = "reference implementation for generating the proof inputs"
license = "MIT or Apache-2.0"
srcDir = "src"
bin = @["testmain"]
requires "nim >= 1.6.0"
requires "https://github.com/mratsim/constantine"
requires "https://github.com/codex-storage/nim-poseidon2#596f7b18070b44ca0bf305bf9bdf1dc4f6011181"

View File

@ -0,0 +1,56 @@
import sugar
import std/sequtils
#import poseidon2/types
import poseidon2/io
import poseidon2/sponge
import poseidon2/merkle
import types
import merkle
#-------------------------------------------------------------------------------
func hashCellOpen( cellData: openArray[byte] ): Hash =
assert( cellData.len == cellSize , "cells are expected to be exactly 2048 bytes" )
return Sponge.digest( cellData, rate=2 )
func hashCell*(cellData: Cell): Hash = hashCellOpen(cellData)
#-------------------------------------------------------------------------------
func splitBlockIntoCells( blockData: openArray[byte] ): seq[Cell] =
assert( blockData.len == blockSize , "network blocks are expected to be exactly 65536 bytes" )
var cells : seq[seq[byte]] = newSeq[seq[byte]]( cellsPerBlock )
let start = low(blockData)
var leaves : seq[Hash] = newSeq[Hash]( cellsPerBlock )
for i in 0..<cellsPerBlock:
let a = start + i * cellSize
let b = start + (i+1) * cellSize
cells[i] = blockData[a..<b].toSeq()
return cells
# returns the special hash of a network block (this is a Merkle root built on the
# top of the hashes of the 32 cells inside the block)
func hashNetworkBlockOpen( blockData: openArray[byte] ): Hash =
let cells = splitBlockIntoCells(blockData)
let leaves = collect( newSeq , (for i in 0..<cellsPerBlock: hashCell(cells[i]) ))
return merkleRoot(leaves)
func hashNetworkBlock*(blockData: Block): Hash = hashNetworkBlockOpen(blockData)
#-------------------------------------------------------------------------------
# returns the mini Merkle tree built on the 32 cells inside a network block
func networkBlockTreeOpen( blockData: openArray[byte] ): MerkleTree =
let cells = splitBlockIntoCells(blockData)
let leaves = collect( newSeq , (for i in 0..<cellsPerBlock: hashCell(cells[i]) ))
return merkleTree(leaves)
func networkBlockTree*(blockData: Block): MerkleTree = networkBlockTreeOpen(blockData)
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,50 @@
#
# generate the input data for the proof
# see `json.nim` to export it in Snarkjs-compatible format
#
import sugar
import std/sequtils
import blocks
import slot
import sample
import merkle
import types
#-------------------------------------------------------------------------------
proc generateProofInput*( cfg: SlotConfig, entropy: Entropy ): SlotProofInput =
let ncells = cfg.nCells
let nblocks = ncells div cellsPerBlock
assert( nblocks * cellsPerBlock == ncells )
let blocks : seq[Block] = collect( newSeq, (for i in 0..<nblocks: loadBlockData(cfg, i) ))
let miniTrees : seq[MerkleTree] = map( blocks , networkBlockTree )
let blockHashes : seq[Root] = map( miniTrees , treeRoot )
let bigTree = merkleTree( blockHashes )
let slotRoot = treeRoot( bigTree )
let indices = cellIndices(entropy, slotRoot, ncells, cfg.nSamples)
var inputs : seq[CellProofInput]
for cellIdx in indices:
let blockIdx = cellIdx div cellsPerBlock
let blockTree = miniTrees[ blockIdx ]
let cellData = loadCellData( cfg, cellIdx )
let botProof = merkleProof( blockTree , cellIdx mod cellsPerBlock )
let topProof = merkleProof( bigTree , blockIdx )
let prf = mergeMerkleProofs( botProof, topProof )
inputs.add( CellProofInput(cellData: cellData, merkleProof: prf) )
return SlotProofInput( slotRoot: slotRoot
, entropy: entropy
, nCells: ncells
, proofInputs: inputs
)
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,99 @@
#
# export the proof inputs as a JSON file suitable for `snarkjs`
#
import sugar
import std/strutils
import std/sequtils
import std/streams
from poseidon2/io import elements
import types
#-------------------------------------------------------------------------------
func toQuotedDecimalF(x: F): string =
let s : string = toDecimalF(x)
return ("\"" & s & "\"")
func mkIndent(foo: string): string =
return spaces(foo.len)
proc writeF(h: Stream, prefix: string, x: F) =
h.writeLine(prefix & toQuotedDecimalF(x))
#[
proc writeSeq(h: Stream, prefix: string, xs: seq[F])
let n = xs.len
let indent = mkIndent(prefix)
for i in 0..<n:
let str : string = toQuotedF( xs[i] )
if i==0:
h.writeLine(prefix & "[ " & str)
else:
h.writeLine(indent & ", " & str)
h.writeLine(indent & "] ")
]#
#-------------------------------------------------------------------------------
type
WriteFun[T] = proc (stream: Stream, prefix: string, what: T) {.closure.}
proc writeList[T](h: Stream, prefix: string, xs: seq[T], writeFun: WriteFun[T]) =
let n = xs.len
let indent = mkIndent(prefix)
for i in 0..<n:
if i==0:
writeFun(h, prefix & "[ ", xs[i])
else:
writeFun(h, indent & ", ", xs[i])
h.writeLine( indent & "]" )
proc writeFieldElems(h: Stream, prefix: string, xs: seq[F]) =
writeList[F]( h, prefix, xs, writeF )
#-------------------------------------------------------------------------------
proc writeSingleCellData(h: Stream, prefix:string , cell: Cell) =
let flds : seq[F] = cell.elements(F).toSeq()
writeFieldElems(h, prefix, flds)
proc writeAllCellData(h: Stream, cells: seq[Cell]) =
writeList(h, " ", cells, writeSingleCellData )
#-------------------------------------------------------------------------------
proc writeSingleMerklePath(h: Stream, prefix: string, path: MerkleProof) =
let flds = path.merklePath
writeFieldElems(h, prefix, flds)
proc writeAllMerklePaths(h: Stream, cells: seq[MerkleProof]) =
writeList(h, " ", cells, writeSingleMerklePath )
#-------------------------------------------------------------------------------
#[
signal input entropy; // public input
signal input slotRoot; // public input
signal input nCells; // public input
signal input cellData[nSamples][nFieldElemsPerCell]; // private input
signal input merklePaths[nSamples][depth]; // private input
]#
proc exportProofInput*(fname: string, prfInput: SlotProofInput) =
let h = openFileStream(fname, fmWrite)
defer: h.close()
h.writeLine("{")
h.writeLine(" \"slotRoot\": " & toQuotedDecimalF(prfInput.slotRoot) )
h.writeLine(" \"entropy\": " & toQuotedDecimalF(prfInput.entropy ) )
h.writeLine(" \"nCells\": " & $(prfInput.nCells) )
h.writeLine(" \"cellData\": ")
writeAllCellData(h, collect( newSeq , (for p in prfInput.proofInputs: p.cellData) ))
h.writeLine(" \"merklePaths\":")
writeAllMerklePaths(h, collect( newSeq , (for p in prfInput.proofInputs: p.merkleProof) ))
h.writeLine("}")

View File

@ -0,0 +1,149 @@
import std/bitops
import std/sequtils
import constantine/math/arithmetic
import constantine/math/io/io_fields
import poseidon2/types
import poseidon2/merkle
import poseidon2/compress
import types
#-------------------------------------------------------------------------------
func treeDepth*(tree: MerkleTree): int =
return tree.layers.len - 1
func treeNumberOfLeaves*(tree: MerkleTree): int =
return tree.layers[0].len
func treeRoot*(tree: MerkleTree): Hash =
let last = tree.layers[tree.layers.len-1]
assert( last.len == 1 )
return last[0]
#-------------------------------------------------------------------------------
func merkleProof*(tree: MerkleTree, index: int): MerkleProof =
let depth = treeDepth(tree)
let nleaves = treeNumberOfLeaves(tree)
assert( index >= 0 and index < nleaves )
var path : seq[Hash] = newSeq[Hash](depth)
var k = index
var m = nleaves
for i in 0..<depth:
let j = k xor 1
path[i] = if (j < m): tree.layers[i][j] else: zero
k = k shr 1
m = (m+1) shr 1
return MerkleProof( leafIndex: index
, leafValue: tree.layers[0][index]
, merklePath: path
, numberOfLeaves: nleaves
)
#-------------------------------------------------------------------------------
func compressWithKey(key: int, x: F, y: F): F =
compress(x,y, key=toF(key))
func reconstructRoot*(proof: MerkleProof): Hash =
var m : int = proof.numberOfLeaves
var j : int = proof.leafIndex
var h : Hash = proof.leafValue
var bottomFlag : int = 1
for p in proof.merklePath:
let oddIndex : bool = (bitand(j,1) != 0)
if oddIndex:
# the index of the child is odd, so the node tiself can't be odd (a bit counterintuitive, yeah :)
let key = bottomFlag
h = compressWithKey( key , p , h )
else:
if j==m-1:
# single child => odd node
let key = bottomFlag + 2
h = compressWithKey( key , h , p )
else:
# even node
let key = bottomFlag
h = compressWithKey( key , h , p )
bottomFlag = 0
j = j shr 1
m = (m+1) shr 1
return h
func checkMerkleProof*(root: Root, proof: MerkleProof): bool =
return bool(root == reconstructRoot(proof))
#-------------------------------------------------------------------------------
# TODO: maybe move this (and the rest?) into poseidon2-nim
const KeyNone = F.fromHex("0x0")
const KeyBottomLayer = F.fromHex("0x1")
const KeyOdd = F.fromHex("0x2")
const KeyOddAndBottomLayer = F.fromhex("0x3")
func merkleTreeWorker(xs: openArray[F], isBottomLayer: static bool) : seq[seq[F]] =
let a = low(xs)
let b = high(xs)
let m = b-a+1
when not isBottomLayer:
if m==1:
return @[ xs.toSeq() ]
let halfn : int = m div 2
let n : int = 2*halfn
let isOdd : bool = (n != m)
var ys : seq[F]
if not isOdd:
ys = newSeq[F](halfn)
else:
ys = newSeq[F](halfn+1)
for i in 0..<halfn:
const key = when isBottomLayer: KeyBottomLayer else: KeyNone
ys[i] = compress( xs[a+2*i], xs[a+2*i+1], key = key )
if isOdd:
const key = when isBottomLayer: KeyOddAndBottomLayer else: KeyOdd
ys[halfn] = compress( xs[n], zero, key = key )
var ls : seq[seq[F]]
ls = @[ xs.toSeq() ]
ls = ls & merkleTreeWorker(ys, isBottomLayer = false)
return ls
func merkleTree*(xs: openArray[F]) : MerkleTree =
return MerkleTree(layers: merkleTreeWorker(xs, isBottomLayer = true))
#-------------------------------------------------------------------------------
#
# given a top tree, with small, fixed size (!) trees grafted to the leaves,
# we can compose proofs (though for checking these proofs we need to remember
# this and use a specific custom convention, because we mark the bottom layers)
#
func mergeMerkleProofs*(bottomProof, topProof: MerkleProof): MerkleProof =
let botRoot = reconstructRoot( bottomProof )
assert( bool(botRoot == topProof.leafValue) )
let idx = topProof.leafIndex * bottomProof.numberOfLeaves + bottomProof.leafIndex
let val = bottomProof.leafValue
let nlvs = bottomProof.numberOfLeaves * topProof.numberOfLeaves
let path = bottomProof.merklePath & topProof.merklePath
return MerkleProof( leafIndex: idx
, leafValue: val
, merklePath: path
, numberOfLeaves: nlvs
)
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,22 @@
#
# helper functions
#
#-------------------------------------------------------------------------------
func floorLog2* (x : int) : int =
var k = -1
var y = x
while (y > 0):
k += 1
y = y shr 1
return k
func ceilingLog2* (x : int) : int =
if (x==0):
return -1
else:
return (floorLog2(x-1) + 1)
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,43 @@
import sugar
import std/bitops
import constantine/math/arithmetic
import poseidon2/types
import poseidon2/io
import poseidon2/sponge
import types
import misc
#-------------------------------------------------------------------------------
func extractLowBits[n: static int]( A: BigInt[n], k: int): uint64 =
assert( k>0 and k<=64 )
var r : uint64 = 0
for i in 0..<k:
let b = bit[n](A, n-1-i) # it's BIG-ENDIAN (wtf)
let y = uint64(b)
if (y != 0):
r = bitor( r, 1'u64 shl y )
return r
func extractLowBits(fld: F, k: int): uint64 =
let A : BigInt[254] = fld.toBig()
return extractLowBits(A, k);
#-------------------------------------------------------------------------------
func cellIndex*(entropy: Entropy, slotRoot: Root, numberOfCells: int, counter: int): int =
let log2 = ceilingLog2(numberOfCells)
assert( 1 shl log2 == numberOfCells , "for this version, `numberOfCells` is assumed to be a power of two")
let input : seq[F] = @[ entropy, slotRoot, toF(counter) ]
let H : Hash = Sponge.digest( input, rate = 2 )
return int(extractLowBits(H,log2))
func cellIndices*(entropy: Entropy, slotRoot: Root, numberOfCells: int, nSamples: int): seq[int] =
return collect( newSeq, (for i in 1..nSamples: cellIndex(entropy, slotRoot, numberOfCells, i) ))
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,56 @@
import sugar
import std/streams
import std/sequtils
import types
import blocks
#-------------------------------------------------------------------------------
# Example slot configuration
#
const exSlotCfg =
SlotConfig( nCells: 1024
, nSamples: 20
, dataSrc: DataSource(kind: FakeData, seed: 12345)
)
#-------------------------------------------------------------------------------
{.overflowChecks: off.}
func genFakeCell(cfg: SlotConfig, seed1: Seed, seed2: CellIdx): Cell =
var cell : seq[byte] = newSeq[byte](cellSize)
var state : int64 = 0
for i in 0..<cellSize:
state = state*state + seed1*state + (seed2 + 17)
cell[i] = byte(state)
return cell
#-------------------------------------------------------------------------------
proc loadCellData*(cfg: SlotConfig, idx: CellIdx): Cell =
case cfg.dataSrc.kind
of FakeData:
return genFakeCell(cfg, cfg.dataSrc.seed, idx)
of SlotFile:
let stream = newFileStream(cfg.dataSrc.filename, mode = fmRead)
defer: stream.close()
stream.setPosition( cellSize * idx )
var arr : array[cellSize, byte]
discard stream.readData( addr(arr), cellSize )
return arr.toSeq()
proc loadBlockData*(cfg: SlotConfig, idx: BlockIdx): Block =
let cells : seq[seq[byte]] =
collect( newSeq , (for i in 0..<cellsPerBlock: loadCellData(cfg, idx*cellsPerBlock+i) ))
return concat(cells)
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,55 @@
import sugar
import std/sequtils
import constantine/math/arithmetic
import poseidon2/types
import poseidon2/merkle
import types
import blocks
import slot
import sample
import merkle
import gen_input
import json
#-------------------------------------------------------------------------------
proc testMerkleProofs*( input: seq[F] ) =
let tree = merkleTree(input)
let root = merkleRoot(input)
assert( bool(root == treeRoot(tree)) )
let n = input.len
var ok = true
var oks : seq[bool] = newSeq[bool]( n )
for i in 0..<n:
let proof = merkleProof(tree, i)
let b = checkMerkleProof(root, proof)
oks[i] = b
ok = ok and b
let prefix = ("testing Merkle proofs for an input of size " & $n & " ... ")
if ok:
echo (prefix & "OK.")
else:
echo (prefix & "FAILED!")
echo oks
proc testAllMerkleProofs*( N: int ) =
for k in 1..N:
let input = collect( newSeq , (for i in 1..k: toF(100+i) ))
testMerkleProofs( input )
#-------------------------------------------------------------------------------
when isMainModule:
# testAllMerkleProofs(20)
let fakedata = DataSource(kind: FakeData, seed: 12345)
let slotcfg = SlotConfig( nCells: 128, nSamples: 3, dataSrc: fakedata)
let entropy = toF( 1234567 )
let prfInput = generateProofInput(slotcfg, entropy)
exportProofInput( "foo.json" , prfInput )

View File

@ -0,0 +1,87 @@
import std/strutils
from constantine/math/io/io_fields import toDecimal
import poseidon2/types
export types
#-------------------------------------------------------------------------------
const cellSize* : int = 2048 # size of the cells we prove
const blockSize* : int = 65536 # size of the network block
const cellsPerBlock* : int = blockSize div cellSize
#-------------------------------------------------------------------------------
type Entropy* = F
type Hash* = F
type Root* = Hash
#-------------------------------------------------------------------------------
func toDecimalF*(a : F): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
#-------------------------------------------------------------------------------
type Cell* = seq[byte]
type Block* = seq[byte]
#-------------------------------------------------------------------------------
type
MerkleProof* = object
leafIndex* : int # linear index of the leaf, starting from 0
leafValue* : Hash # value of the leaf
merklePath* : seq[Hash] # order: from the bottom to the top
numberOfLeaves* : int # number of leaves in the tree (=size of input)
MerkleTree* = object
layers*: seq[seq[Hash]]
# ^^^ note: the first layer is the bottom layer, and the last layer is the root
#-------------------------------------------------------------------------------
type
CellProofInput* = object
cellData*: Cell
merkleProof*: MerkleProof
SlotProofInput* = object
slotRoot*: Root
entropy*: Entropy
nCells*: int
proofInputs*: seq[CellProofInput]
#-------------------------------------------------------------------------------
type
Seed* = int
CellIdx* = int
BlockIdx* = int
DataSourceKind* = enum
SlotFile,
FakeData
DataSource* = object
case kind*: DataSourceKind
of SlotFile:
filename*: string
of FakeData:
seed*: Seed
SlotConfig* = object
nCells* : int # number of cells per slot (should be power of two)
nSamples* : int # how many cells we sample
dataSrc* : DataSource # slot data source
#-------------------------------------------------------------------------------