Ark circom and rust ffi (#5)

* wip rust ffi

* proper test component instantiation

* adding quick&dirty poseidon implementation

* update gitignode

* gitignore

* adding rust circuit tests

* gitignore

* rename

* add storer tests

* move utils under circuit_tests

* fix storage proofs

* wip: ffi

* instantiate storer

* enable ark-serialize

* delete js tests

* update CI to run cargo tests

* keep the artifacts dir

* update .gitignore

* build circuits

* remove package json

* place built circuits in correct dirs

* update gitignore

* remove node

* fix ci

* updating readme

* storageproofs.rs to storage_proofs.rs

* flatten tests chunks by default

* add ffi

* fix digest

* minor fixes for ffi

* fix storer test

* use random data for chunks

* debug optimizations to speed witness generation

* clippy & other lint stuff

* add back missing unsafe blocks

* release mode disables constraint checks

* fix ffi

* fix hashes serialization

* make naming more consistent

* add missing pragma

* use correct circuits

* add todo

* add clarification to readme

* silence unused warning

* include constants file into exec

* remove unused imports
This commit is contained in:
Dmitriy Ryajov 2023-04-12 23:17:00 +01:00 committed by GitHub
parent e7b296ebbb
commit ebef300064
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 25915 additions and 112 deletions

View File

@ -28,5 +28,7 @@ jobs:
- name: Install circom if not cached - name: Install circom if not cached
run: sh ./scripts/install-circom.sh run: sh ./scripts/install-circom.sh
- run: npm ci - run: npm ci
- name: Build circuits
run: sh ./scripts/circuit-prep.sh
- name: Run the tests - name: Run the tests
run: npm test run: RUST_BACKTRACE=full cargo test

5
.gitignore vendored
View File

@ -11,3 +11,8 @@ node_modules/
#/target #/target
/Cargo.lock /Cargo.lock
.vscode
test/circuits/artifacts
out.log
src/circuit_tests/artifacts/*
!src/circuit_tests/artifacts/.keep

38
Cargo.toml Normal file
View File

@ -0,0 +1,38 @@
[package]
name = "codex-storage-proofs"
version = "0.1.0"
edition = "2021"
[profile.dev]
opt-level = 3
[lib]
crate-type = [
"staticlib", # Ensure it gets compiled as a (static) C library
# "cdylib", # If you want a shared/dynamic C library (advanced)
"lib", # For downstream Rust dependents: `examples/`, `tests/` etc.
]
[dependencies]
ark-bn254 = { version = "0.3.0" }
ark-ec = { version = "0.3.0", default-features = false, features = [
"parallel",
] }
ark-groth16 = { git = "https://github.com/arkworks-rs/groth16", rev = "765817f", features = [
"parallel",
] }
ark-std = { version = "0.3.0", default-features = false, features = [
"parallel",
] }
ark-serialize = { version = "0.3.0", default-features = false }
num-bigint = { version = "0.4", default-features = false, features = ["rand"] }
ark-circom = { git = "https://github.com/gakonst/ark-circom.git", rev = "35ce5a9", features = [
"circom-2",
] }
ark-ff = { version = "0.3.0", features = ["std"] }
ruint = { version = "1.7.0", features = ["serde", "num-bigint", "ark-ff"] }
once_cell = "1.17.1"
serde = "1.0.156"
serde_json = "1.0.94"
num-traits = "0.2.15"
ark-relations = { version = "0.4.0", features = ["std", "tracing-subscriber"] }

View File

@ -18,22 +18,24 @@ or
at your option. These files may not be copied, modified, or distributed except according to those terms. at your option. These files may not be copied, modified, or distributed except according to those terms.
## Usage ## Usage
First
``` First, clone the repo and install the circom components:
```sh
git clone git@github.com:status-im/codex-storage-proofs.git git clone git@github.com:status-im/codex-storage-proofs.git
cd codex-storage-proofs cd codex-storage-proofs
npm i npm i
cd circuits cd circuits
``` ```
Preparing test key material (only suitable for testing) Nex, compile circuits:
```
../scripts/circuit_prep.sh storer 13 ```sh
../scripts/circuit_prep.sh
``` ```
Running the tests: Running the tests:
`npm test` ```sh
``` cargo test # don't run in release more as it dissables circuit assets
npm test test/merkletree.js
``` ```

View File

@ -1,5 +0,0 @@
pragma circom 2.1.0;
include "./storer.circom";
component main { public [root, salt] } = StorageProver(32, 4, 2);

View File

@ -0,0 +1,55 @@
pragma circom 2.1.0;
include "../node_modules/circomlib/circuits/poseidon.circom";
function roundUpDiv(x, n) {
var last = x % n; // get the last digit
var div = x \ n; // get the division
if (last > 0) {
return div + 1;
}
return div;
}
template parallel PoseidonDigest(BLOCK_SIZE, DIGEST_CHUNK) {
// BLOCK_SIZE - size of the input block array
// DIGEST_CHUNK - number of elements to hash at once
signal input block[BLOCK_SIZE]; // Input block array
signal output hash; // Output hash
// Split array into chunks of size DIGEST_CHUNK, usually 2
var NUM_CHUNKS = roundUpDiv(BLOCK_SIZE, DIGEST_CHUNK);
// Initialize an array to store hashes of each block
component hashes[NUM_CHUNKS];
// Loop over chunks and hash them using Poseidon()
for (var i = 0; i < NUM_CHUNKS; i++) {
hashes[i] = Poseidon(DIGEST_CHUNK);
var start = i * DIGEST_CHUNK;
var end = start + DIGEST_CHUNK;
for (var j = start; j < end; j++) {
if (j >= BLOCK_SIZE) {
hashes[i].inputs[j - start] <== 0;
} else {
hashes[i].inputs[j - start] <== block[j];
}
}
}
// Concatenate hashes into a single block
var concat[NUM_CHUNKS];
for (var i = 0; i < NUM_CHUNKS; i++) {
concat[i] = hashes[i].out;
}
// Hash concatenated array using Poseidon() again
component h = Poseidon(NUM_CHUNKS);
h.inputs <== concat;
// Assign output to hash signal
hash <== h.out;
}

View File

@ -4,6 +4,8 @@ include "../node_modules/circomlib/circuits/poseidon.circom";
include "../node_modules/circomlib/circuits/switcher.circom"; include "../node_modules/circomlib/circuits/switcher.circom";
include "../node_modules/circomlib/circuits/bitify.circom"; include "../node_modules/circomlib/circuits/bitify.circom";
include "./poseidon-digest.circom";
template parallel MerkleProof(LEVELS) { template parallel MerkleProof(LEVELS) {
signal input leaf; signal input leaf;
signal input pathElements[LEVELS]; signal input pathElements[LEVELS];
@ -32,50 +34,11 @@ template parallel MerkleProof(LEVELS) {
root <== hasher[LEVELS - 1].out; root <== hasher[LEVELS - 1].out;
} }
function min(arg1, arg2) { template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS, DIGEST_CHUNK) {
return arg1 < arg2 ? arg1 : arg2;
}
template parallel HashCheck(BLOCK_SIZE) {
signal input block[BLOCK_SIZE];
signal input blockHash;
// TODO: make CHUNK_SIZE a parameter
// Split array into chunks of size 16
var CHUNK_SIZE = 16;
var NUM_CHUNKS = BLOCK_SIZE / CHUNK_SIZE;
// Initialize an array to store hashes of each block
component hashes[NUM_CHUNKS];
// Loop over chunks and hash them using Poseidon()
for (var i = 0; i < NUM_CHUNKS; i++) {
var start = i * CHUNK_SIZE;
var end = min(start + CHUNK_SIZE, BLOCK_SIZE);
hashes[i] = Poseidon(CHUNK_SIZE);
for (var j = start; j < end; j++) {
hashes[i].inputs[j - start] <== block[j];
}
}
// Concatenate hashes into a single block
var concat[NUM_CHUNKS];
for (var i = 0; i < NUM_CHUNKS; i++) {
concat[i] = hashes[i].out;
}
// Hash concatenated array using Poseidon() again
component h = Poseidon(NUM_CHUNKS);
h.inputs <== concat;
// Assign output to hash signal
h.out === blockHash;
}
template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS) {
// BLOCK_SIZE: size of block in symbols // BLOCK_SIZE: size of block in symbols
// QUERY_LEN: query length, i.e. number if indices to be proven // QUERY_LEN: query length, i.e. number if indices to be proven
// LEVELS: size of Merkle Tree in the manifest // LEVELS: size of Merkle Tree in the manifest
// DIGEST_CHUNK: number of symbols to hash in one go
signal input chunks[QUERY_LEN][BLOCK_SIZE]; // chunks to be proven signal input chunks[QUERY_LEN][BLOCK_SIZE]; // chunks to be proven
signal input siblings[QUERY_LEN][LEVELS]; // siblings hashes of chunks to be proven signal input siblings[QUERY_LEN][LEVELS]; // siblings hashes of chunks to be proven
signal input path[QUERY_LEN]; // path of chunks to be proven signal input path[QUERY_LEN]; // path of chunks to be proven
@ -87,9 +50,9 @@ template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS) {
component hashers[QUERY_LEN]; component hashers[QUERY_LEN];
for (var i = 0; i < QUERY_LEN; i++) { for (var i = 0; i < QUERY_LEN; i++) {
hashers[i] = HashCheck(BLOCK_SIZE); hashers[i] = PoseidonDigest(BLOCK_SIZE, DIGEST_CHUNK);
hashers[i].block <== chunks[i]; hashers[i].block <== chunks[i];
hashers[i].blockHash <== hashes[i]; hashers[i].hash === hashes[i];
} }
component merkelizer[QUERY_LEN]; component merkelizer[QUERY_LEN];

View File

@ -0,0 +1,5 @@
pragma circom 2.1.0;
include "./storer.circom";
component main { public [root, salt] } = StorageProver(256, 80, 32, 16);

View File

@ -1,28 +1,5 @@
#!/bin/bash #!/bin/bash
set -e
set -x
CIRCUIT=`basename $1`
POWER="${2:-12}"
CURVE="${3:-bn128}"
POTPREFIX=pot${POWER}_${CURVE}
if [ ! -f ${POTPREFIX}_final.ptau ]
then
snarkjs powersoftau new $CURVE $POWER ${POTPREFIX}_0000.ptau -v
snarkjs powersoftau contribute ${POTPREFIX}_0000.ptau ${POTPREFIX}_0001.ptau --name="First contribution" -v -e="random text"
snarkjs powersoftau verify ${POTPREFIX}_0001.ptau
snarkjs powersoftau beacon ${POTPREFIX}_0001.ptau ${POTPREFIX}_beacon.ptau 0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f 10 -n="Final Beacon"
snarkjs powersoftau prepare phase2 ${POTPREFIX}_beacon.ptau ${POTPREFIX}_final.ptau -v
snarkjs powersoftau verify ${POTPREFIX}_final.ptau
fi
# phase 2
circom $1.circom --r1cs --wasm
snarkjs groth16 setup ${CIRCUIT}.r1cs ${POTPREFIX}_final.ptau ${CIRCUIT}_0000.zkey
snarkjs zkey contribute ${CIRCUIT}_0000.zkey ${CIRCUIT}_0001.zkey --name="1st Contributor Name" -v -e="another random text"
snarkjs zkey verify ${CIRCUIT}.r1cs ${POTPREFIX}_final.ptau ${CIRCUIT}_0001.zkey
snarkjs zkey beacon ${CIRCUIT}_0001.zkey ${CIRCUIT}_final.zkey 0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f 10 -n="Final Beacon phase2"
circom src/circuit_tests/poseidon-digest-test.circom --r1cs --wasm -o src/circuit_tests/artifacts
circom src/circuit_tests/poseidon-hash-test.circom --r1cs --wasm -o src/circuit_tests/artifacts
circom src/circuit_tests/storer-test.circom --r1cs --wasm -o src/circuit_tests/artifacts

View File

153
src/circuit_tests/mod.rs Normal file
View File

@ -0,0 +1,153 @@
pub mod utils;
#[cfg(test)]
mod test {
use ark_bn254::Bn254;
use ark_circom::{CircomBuilder, CircomConfig};
use ark_groth16::{
create_random_proof as prove, generate_random_parameters, prepare_inputs,
prepare_verifying_key, verify_proof_with_prepared_inputs, ProvingKey,
};
use ark_std::rand::{distributions::Alphanumeric, rngs::ThreadRng, Rng};
use ruint::aliases::U256;
use crate::{
circuit_tests::utils::{digest, merkelize},
poseidon::hash,
storage_proofs::StorageProofs,
};
pub struct CircuitsTests {
builder: CircomBuilder<Bn254>,
params: ProvingKey<Bn254>,
rng: ThreadRng,
}
impl CircuitsTests {
pub fn new(wtns: String, r1cs: String) -> CircuitsTests {
let mut rng = ThreadRng::default();
let builder = CircomBuilder::new(CircomConfig::<Bn254>::new(wtns, r1cs).unwrap());
let params =
generate_random_parameters::<Bn254, _, _>(builder.setup(), &mut rng).unwrap();
CircuitsTests {
builder,
params,
rng,
}
}
pub fn poseidon_hash(&mut self, elements: &[U256], hash: U256) -> bool {
let mut builder = self.builder.clone();
elements.iter().for_each(|c| builder.push_input("in", *c));
builder.push_input("hash", hash);
let circuit = builder.build().unwrap();
let inputs = circuit.get_public_inputs().unwrap();
let proof = prove(circuit, &self.params, &mut self.rng).unwrap();
let vk = prepare_verifying_key(&self.params.vk);
let public_inputs = prepare_inputs(&vk, &inputs).unwrap();
verify_proof_with_prepared_inputs(&vk, &proof, &public_inputs).is_ok()
}
pub fn poseidon_digest(&mut self, elements: &[U256], hash: U256) -> bool {
let mut builder = self.builder.clone();
elements
.iter()
.for_each(|c| builder.push_input("block", *c));
builder.push_input("hash", hash);
let circuit = builder.build().unwrap();
let inputs = circuit.get_public_inputs().unwrap();
let proof = prove(circuit, &self.params, &mut self.rng).unwrap();
let vk = prepare_verifying_key(&self.params.vk);
let public_inputs = prepare_inputs(&vk, &inputs).unwrap();
verify_proof_with_prepared_inputs(&vk, &proof, &public_inputs).is_ok()
}
}
#[test]
fn test_poseidon_hash() {
let r1cs = "./src/circuit_tests/artifacts/poseidon-hash-test.r1cs";
let wasm = "./src/circuit_tests/artifacts/poseidon-hash-test_js/poseidon-hash-test.wasm";
let mut hasher = CircuitsTests::new(wasm.to_string(), r1cs.to_string());
assert!(hasher.poseidon_hash(&[U256::from(1)], hash(&[U256::from(1)])));
}
#[test]
fn test_poseidon_digest() {
let r1cs = "./src/circuit_tests/artifacts/poseidon-digest-test.r1cs";
let wasm =
"./src/circuit_tests/artifacts/poseidon-digest-test_js/poseidon-digest-test.wasm";
let mut hasher = CircuitsTests::new(wasm.to_string(), r1cs.to_string());
let input: Vec<U256> = (0..256).map(|c| U256::from(c)).collect();
assert!(hasher.poseidon_digest(&input, digest(&input, Some(16))));
}
#[test]
fn test_storer() {
let r1cs = "./src/circuit_tests/artifacts/storer-test.r1cs";
let wasm = "./src/circuit_tests/artifacts/storer-test_js/storer-test.wasm";
let mut prover = StorageProofs::new(wasm.to_string(), r1cs.to_string(), None);
// generate a tuple of (preimages, hash), where preimages is a vector of 256 U256s
// and hash is the hash of each vector generated using the digest function
let data = (0..4)
.map(|_| {
let rng = ThreadRng::default();
let preimages: Vec<U256> = rng
.sample_iter(Alphanumeric)
.take(256)
.map(|c| U256::from(c))
.collect();
let hash = digest(&preimages, Some(16));
(preimages, hash)
})
.collect::<Vec<(Vec<U256>, U256)>>();
let chunks: Vec<U256> = data.iter().flat_map(|c| c.0.to_vec()).collect();
let hashes: Vec<U256> = data.iter().map(|c| c.1).collect();
let path = [0, 1, 2, 3].to_vec();
let parent_hash_l = hash(&[hashes[0], hashes[1]]);
let parent_hash_r = hash(&[hashes[2], hashes[3]]);
let siblings = &[
hashes[1],
parent_hash_r,
hashes[0],
parent_hash_r,
hashes[3],
parent_hash_l,
hashes[2],
parent_hash_l,
];
let root = merkelize(hashes.as_slice());
let proof_bytes = &mut Vec::new();
let public_inputs_bytes = &mut Vec::new();
prover
.prove(
chunks.as_slice(),
siblings,
hashes.as_slice(),
path.as_slice(),
root,
root, // random salt - block hash
proof_bytes,
public_inputs_bytes,
)
.unwrap();
assert!(prover
.verify(proof_bytes.as_slice(), public_inputs_bytes.as_slice())
.is_ok());
}
}

View File

@ -0,0 +1,20 @@
pragma circom 2.1.0;
include "../../circuits/poseidon-digest.circom";
template PoseidonDigestTest(BLOCK_SIZE, CHUNK_SIZE) {
signal input block[BLOCK_SIZE];
signal input hash;
signal output hash2;
component digest = PoseidonDigest(BLOCK_SIZE, CHUNK_SIZE);
for (var i = 0; i < BLOCK_SIZE; i++) {
digest.block[i] <== block[i];
}
digest.hash === hash; // verify that the hash is correct
hash2 <== digest.hash;
}
component main { public [hash] } = PoseidonDigestTest(256, 16);

View File

@ -0,0 +1,17 @@
pragma circom 2.1.0;
include "../../node_modules/circomlib/circuits/poseidon.circom";
template PoseidonHash(SIZE) {
signal input in[SIZE];
signal input hash;
component hasher = Poseidon(SIZE);
for(var i = 0; i < SIZE; i++) {
hasher.inputs[i] <== in[i];
}
hasher.out === hash;
}
component main { public [hash] } = PoseidonHash(1);

View File

@ -0,0 +1,5 @@
pragma circom 2.1.0;
include "../../circuits/storer.circom";
component main { public [root, salt] } = StorageProver(256, 4, 2, 16);

View File

@ -0,0 +1,49 @@
#![allow(dead_code)]
use crate::poseidon::hash;
use ruint::{aliases::U256, uint};
pub fn digest(input: &[U256], chunk_size: Option<usize>) -> U256 {
let chunk_size = chunk_size.unwrap_or(4);
let chunks = ((input.len() as f32) / (chunk_size as f32)).ceil() as usize;
let mut concat: Vec<U256> = vec![];
for i in 0..chunks {
let range = (i * chunk_size)..std::cmp::min((i + 1) * chunk_size, input.len());
let mut chunk = input[range].to_vec();
if chunk.len() < chunk_size {
chunk.resize(chunk_size, uint!(0_U256));
}
concat.push(hash(chunk.as_slice()));
}
if concat.len() > 1 {
return hash(concat.as_slice());
}
concat[0]
}
pub fn merkelize(leafs: &[U256]) -> U256 {
// simple merkle root (treehash) generator
// unbalanced trees will have the last leaf duplicated
let mut merkle: Vec<U256> = leafs.to_vec();
while merkle.len() > 1 {
let mut new_merkle = Vec::new();
let mut i = 0;
while i < merkle.len() {
new_merkle.push(hash(&[merkle[i], merkle[i + 1]]));
i += 2;
}
if merkle.len() % 2 == 1 {
new_merkle.push(hash(&[merkle[merkle.len() - 2], merkle[merkle.len() - 2]]));
}
merkle = new_merkle;
}
merkle[0]
}

302
src/ffi.rs Normal file
View File

@ -0,0 +1,302 @@
use ruint::aliases::U256;
use crate::storage_proofs::StorageProofs;
use std::str;
#[derive(Debug, Clone)]
#[repr(C)]
pub struct Buffer {
pub data: *const u8,
pub len: usize,
}
#[derive(Debug, Clone)]
#[repr(C)]
pub struct ProofCtx {
pub proof: Buffer,
pub public_inputs: Buffer,
}
impl ProofCtx {
pub fn new(proof: &[u8], public_inputs: &[u8]) -> Self {
Self {
proof: Buffer {
data: proof.as_ptr(),
len: proof.len(),
},
public_inputs: Buffer {
data: public_inputs.as_ptr(),
len: public_inputs.len(),
},
}
}
}
/// # Safety
///
/// Construct a StorageProofs object
#[no_mangle]
pub unsafe extern "C" fn init(
r1cs: *const &Buffer,
wasm: *const &Buffer,
zkey: *const &Buffer,
) -> *mut StorageProofs {
let r1cs = {
if r1cs.is_null() {
return std::ptr::null_mut();
}
let slice = std::slice::from_raw_parts((*r1cs).data, (*r1cs).len);
str::from_utf8(slice).unwrap().to_string()
};
let wasm = {
if wasm.is_null() {
return std::ptr::null_mut();
}
let slice = std::slice::from_raw_parts((*wasm).data, (*wasm).len);
str::from_utf8(slice).unwrap().to_string()
};
let zkey = {
if !zkey.is_null() {
let slice = std::slice::from_raw_parts((*zkey).data, (*zkey).len);
Some(str::from_utf8(slice).unwrap().to_string())
} else {
None
}
};
Box::into_raw(Box::new(StorageProofs::new(wasm, r1cs, zkey)))
}
/// # Safety
///
/// Use after constructing a StorageProofs object with init
#[no_mangle]
pub unsafe extern "C" fn prove(
prover_ptr: *mut StorageProofs,
chunks: *const Buffer,
siblings: *const Buffer,
hashes: *const Buffer,
path: *const i32,
path_len: usize,
pubkey: *const Buffer,
root: *const Buffer,
salt: *const Buffer,
) -> *mut ProofCtx {
let chunks = {
let slice = std::slice::from_raw_parts((*chunks).data, (*chunks).len);
slice
.chunks(U256::BYTES)
.map(|c| U256::try_from_le_slice(c).unwrap())
.collect::<Vec<U256>>()
};
let siblings = {
let slice = std::slice::from_raw_parts((*siblings).data, (*siblings).len);
slice
.chunks(U256::BYTES)
.map(|c| U256::try_from_le_slice(c).unwrap())
.collect::<Vec<U256>>()
};
let hashes = {
let slice = std::slice::from_raw_parts((*hashes).data, (*hashes).len);
slice
.chunks(U256::BYTES)
.map(|c| U256::try_from_le_slice(c).unwrap())
.collect::<Vec<U256>>()
};
let path = {
let slice = std::slice::from_raw_parts(path, path_len);
slice.to_vec()
};
let pubkey =
U256::try_from_le_slice(std::slice::from_raw_parts((*pubkey).data, (*pubkey).len)).unwrap();
let root =
U256::try_from_le_slice(std::slice::from_raw_parts((*root).data, (*root).len)).unwrap();
let salt =
U256::try_from_le_slice(std::slice::from_raw_parts((*salt).data, (*salt).len)).unwrap();
let proof_bytes = &mut Vec::new();
let public_inputs_bytes = &mut Vec::new();
let mut _prover = &mut *prover_ptr;
_prover
.prove(
chunks.as_slice(),
siblings.as_slice(),
hashes.as_slice(),
path.as_slice(),
root,
salt,
proof_bytes,
public_inputs_bytes,
)
.unwrap();
Box::into_raw(Box::new(ProofCtx::new(proof_bytes, public_inputs_bytes)))
}
#[no_mangle]
/// # Safety
///
/// Should be called on a valid proof and public inputs previously generated by prove
pub unsafe extern "C" fn verify(
prover_ptr: *mut StorageProofs,
proof: *const Buffer,
public_inputs: *const Buffer,
) -> bool {
let proof = std::slice::from_raw_parts((*proof).data, (*proof).len);
let public_inputs = std::slice::from_raw_parts((*public_inputs).data, (*public_inputs).len);
let mut _prover = &mut *prover_ptr;
_prover.verify(proof, public_inputs).is_ok()
}
/// # Safety
///
/// Use on a valid pointer to StorageProofs or panics
#[no_mangle]
pub unsafe extern "C" fn free_prover(prover: *mut StorageProofs) {
if prover.is_null() {
return;
}
unsafe { drop(Box::from_raw(prover)) }
}
/// # Safety
///
/// Use on a valid pointer to ProofCtx or panics
#[no_mangle]
pub unsafe extern "C" fn free_proof_ctx(ctx: *mut ProofCtx) {
if ctx.is_null() {
return;
}
drop(Box::from_raw(ctx))
}
#[cfg(test)]
mod tests {
use ark_std::rand::{distributions::Alphanumeric, rngs::ThreadRng, Rng};
use ruint::aliases::U256;
use crate::{
circuit_tests::utils::{digest, merkelize},
poseidon::hash,
};
use super::{init, prove, Buffer};
#[test]
fn test_storer_ffi() {
// generate a tuple of (preimages, hash), where preimages is a vector of 256 U256s
// and hash is the hash of each vector generated using the digest function
let data = (0..4)
.map(|_| {
let rng = ThreadRng::default();
let preimages: Vec<U256> = rng
.sample_iter(Alphanumeric)
.take(256)
.map(|c| U256::from(c))
.collect();
let hash = digest(&preimages, Some(16));
(preimages, hash)
})
.collect::<Vec<(Vec<U256>, U256)>>();
let chunks: Vec<u8> = data
.iter()
.map(|c| {
c.0.iter()
.map(|c| c.to_le_bytes_vec())
.flatten()
.collect::<Vec<u8>>()
})
.flatten()
.collect();
let hashes: Vec<U256> = data.iter().map(|c| c.1).collect();
let hashes_slice: Vec<u8> = hashes.iter().map(|c| c.to_le_bytes_vec()).flatten().collect();
let path = [0, 1, 2, 3];
let parent_hash_l = hash(&[hashes[0], hashes[1]]);
let parent_hash_r = hash(&[hashes[2], hashes[3]]);
let sibling_hashes = &[
hashes[1],
parent_hash_r,
hashes[0],
parent_hash_r,
hashes[3],
parent_hash_l,
hashes[2],
parent_hash_l,
];
let siblings: Vec<u8> = sibling_hashes
.iter()
.map(|c| c.to_le_bytes_vec())
.flatten()
.collect();
let root = merkelize(hashes.as_slice());
let chunks_buff = Buffer {
data: chunks.as_ptr() as *const u8,
len: chunks.len(),
};
let siblings_buff = Buffer {
data: siblings.as_ptr() as *const u8,
len: siblings.len(),
};
let hashes_buff = Buffer {
data: hashes_slice.as_ptr() as *const u8,
len: hashes_slice.len(),
};
let root_bytes: [u8; U256::BYTES] = root.to_le_bytes();
let root_buff = Buffer {
data: root_bytes.as_ptr() as *const u8,
len: root_bytes.len(),
};
let r1cs_path = "src/circuit_tests/artifacts/storer-test.r1cs";
let wasm_path = "src/circuit_tests/artifacts/storer-test_js/storer-test.wasm";
let r1cs = &Buffer {
data: r1cs_path.as_ptr(),
len: r1cs_path.len(),
};
let wasm = &Buffer {
data: wasm_path.as_ptr(),
len: wasm_path.len(),
};
let prover_ptr = unsafe { init(&r1cs, &wasm, std::ptr::null()) };
let prove_ctx = unsafe {
prove(
prover_ptr,
&chunks_buff as *const Buffer,
&siblings_buff as *const Buffer,
&hashes_buff as *const Buffer,
&path as *const i32,
path.len(),
&root_buff as *const Buffer, // root
&root_buff as *const Buffer, // pubkey
&root_buff as *const Buffer, // salt/block hash
)
};
assert!(prove_ctx.is_null() == false);
}
}

4
src/lib.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod ffi;
pub mod poseidon;
pub mod storage_proofs;
mod circuit_tests;

118
src/poseidon/constants.rs Normal file
View File

@ -0,0 +1,118 @@
use ark_bn254::Fr;
use num_bigint::BigUint;
use once_cell::sync::Lazy;
use num_traits::Num;
const CONSTANTS_STR: &str = include_str!("poseidon_constants_opt.json");
pub static CONSTANTS: Lazy<serde_json::Value> = Lazy::new(|| {
serde_json::from_str(CONSTANTS_STR).unwrap()
});
pub static C_CONST: Lazy<Vec<Vec<Fr>>> = Lazy::new(|| {
CONSTANTS["C"]
.as_array()
.unwrap()
.iter()
.map(|row| {
row.as_array()
.unwrap()
.iter()
.map(|c| {
Fr::try_from(
BigUint::from_str_radix(
c.as_str().unwrap().strip_prefix("0x").unwrap(),
16,
)
.unwrap(),
)
})
.collect::<Result<Vec<Fr>, _>>()
.unwrap()
})
.collect::<Vec<Vec<Fr>>>()
});
pub static S_CONST: Lazy<Vec<Vec<Fr>>> = Lazy::new(|| {
CONSTANTS["S"]
.as_array()
.unwrap()
.iter()
.map(|row| {
row.as_array()
.unwrap()
.iter()
.map(|c| {
Fr::try_from(
BigUint::from_str_radix(
c.as_str().unwrap().strip_prefix("0x").unwrap(),
16,
)
.unwrap(),
)
})
.collect::<Result<Vec<Fr>, _>>()
.unwrap()
})
.collect::<Vec<Vec<Fr>>>()
});
pub static M_CONST: Lazy<Vec<Vec<Vec<Fr>>>> = Lazy::new(|| {
CONSTANTS["M"]
.as_array()
.unwrap()
.iter()
.map(|row| {
row.as_array()
.unwrap()
.iter()
.map(|c| {
c.as_array()
.unwrap()
.iter()
.map(|c| {
Fr::try_from(
BigUint::from_str_radix(
c.as_str().unwrap().strip_prefix("0x").unwrap(),
16,
)
.unwrap(),
)
})
.collect::<Result<Vec<Fr>, _>>()
.unwrap()
})
.collect()
})
.collect::<Vec<Vec<Vec<Fr>>>>()
});
pub static P_CONST: Lazy<Vec<Vec<Vec<Fr>>>> = Lazy::new(|| {
CONSTANTS["P"]
.as_array()
.unwrap()
.iter()
.map(|row| {
row.as_array()
.unwrap()
.iter()
.map(|c| {
c.as_array()
.unwrap()
.iter()
.map(|c| {
Fr::try_from(
BigUint::from_str_radix(
c.as_str().unwrap().strip_prefix("0x").unwrap(),
16,
)
.unwrap(),
)
})
.collect::<Result<Vec<Fr>, _>>()
.unwrap()
})
.collect()
})
.collect::<Vec<Vec<Vec<Fr>>>>()
});

154
src/poseidon/mod.rs Normal file
View File

@ -0,0 +1,154 @@
mod constants;
use ark_bn254::Fr;
use ark_ff::{Field, Zero};
use ruint::aliases::U256;
const N_ROUNDS_F: u8 = 8;
const N_ROUNDS_P: [i32; 16] = [
56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68,
];
// Compute a Poseidon hash function of the input vector.
//
// # Panics
//
// Panics if `input` is not a valid field element.
#[must_use]
pub fn hash(inputs: &[U256]) -> U256 {
assert!(!inputs.is_empty());
assert!(inputs.len() <= N_ROUNDS_P.len());
let t = inputs.len() + 1;
let n_rounds_f = N_ROUNDS_F as usize;
let n_rounds_p = N_ROUNDS_P[t - 2] as usize;
let c = constants::C_CONST[t - 2].clone();
let s = constants::S_CONST[t - 2].clone();
let m = constants::M_CONST[t - 2].clone();
let p = constants::P_CONST[t - 2].clone();
let mut state: Vec<Fr> = inputs.iter().map(|f| f.try_into().unwrap()).collect();
state.insert(0, Fr::zero());
state = state.iter().enumerate().map(|(j, a)| *a + c[j]).collect();
for r in 0..(n_rounds_f / 2 - 1) {
state = state
.iter()
.map(|a| a.pow([5]))
.enumerate()
.map(|(i, a)| a + c[(r + 1) * t + i])
.collect();
state = state
.iter()
.enumerate()
.map(|(i, _)| {
state
.iter()
.enumerate()
.fold((0, Fr::zero()), |acc, item| {
(0, (acc.1 + m[item.0][i] * item.1))
})
.1
})
.collect();
}
state = state
.iter()
.map(|a| a.pow([5]))
.enumerate()
.map(|(i, a)| a + c[(n_rounds_f / 2 - 1 + 1) * t + i])
.collect();
state = state
.iter()
.enumerate()
.map(|(i, _)| {
state
.iter()
.enumerate()
.fold((0, Fr::zero()), |acc, item| {
(0, (acc.1 + p[item.0][i] * item.1))
})
.1
})
.collect();
for r in 0..n_rounds_p {
state[0] = state[0].pow([5]);
state[0] += c[(n_rounds_f / 2 + 1) * t + r];
let s0 = state
.iter()
.enumerate()
.fold((0, Fr::zero()), |acc, item| {
(0, acc.1 + s[(t * 2 - 1) * r + item.0] * item.1)
})
.1;
for k in 1..t {
state[k] = state[k] + state[0] * s[(t * 2 - 1) * r + t + k - 1];
}
state[0] = s0;
}
for r in 0..(n_rounds_f / 2 - 1) {
state = state
.iter()
.map(|a| a.pow([5]))
.enumerate()
.map(|(i, a)| a + c[(n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + i])
.collect();
state = state
.iter()
.enumerate()
.map(|(i, _)| {
state
.iter()
.enumerate()
.fold((0, Fr::zero()), |acc, item| {
(0, acc.1 + m[item.0][i] * item.1)
})
.1
})
.collect();
}
state = state.iter().map(|a| a.pow([5])).collect();
state = state
.iter()
.enumerate()
.map(
|(i, _)| {
state
.iter()
.enumerate()
.fold((0, Fr::zero()), |acc, item| {
(0, acc.1 + m[item.0][i] * item.1)
})
.1
},
)
.collect();
state[0].into()
}
#[cfg(test)]
mod tests {
use super::*;
use ruint::uint;
#[test]
fn test_hash_inputs() {
uint! {
assert_eq!(hash(&[0_U256]), 0x2a09a9fd93c590c26b91effbb2499f07e8f7aa12e2b4940a3aed2411cb65e11c_U256);
assert_eq!(hash(&[0_U256, 0_U256]), 0x2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864_U256);
assert_eq!(hash(&[0_U256, 0_U256, 0_U256]), 0xbc188d27dcceadc1dcfb6af0a7af08fe2864eecec96c5ae7cee6db31ba599aa_U256);
assert_eq!(hash(&[31213_U256, 132_U256]), 0x303f59cd0831b5633bcda50514521b33776b5d4280eb5868ba1dbbe2e4d76ab5_U256);
}
}
}

File diff suppressed because it is too large Load Diff

98
src/storage_proofs.rs Normal file
View File

@ -0,0 +1,98 @@
use std::fs::File;
use ark_bn254::{Bn254, Fr};
use ark_circom::{read_zkey, CircomBuilder, CircomConfig};
use ark_groth16::{
create_random_proof as prove, generate_random_parameters, prepare_verifying_key, verify_proof,
Proof, ProvingKey,
};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read};
use ark_std::rand::rngs::ThreadRng;
use ruint::aliases::U256;
#[derive(Debug, Clone)]
pub struct StorageProofs {
builder: CircomBuilder<Bn254>,
params: ProvingKey<Bn254>,
rng: ThreadRng,
}
impl StorageProofs {
// TODO: add rng
pub fn new(
wtns: String,
r1cs: String,
zkey: Option<String>, /* , rng: Option<ThreadRng> */
) -> Self {
let mut rng = ThreadRng::default();
let builder = CircomBuilder::new(CircomConfig::<Bn254>::new(wtns, r1cs).unwrap());
let params: ProvingKey<Bn254> = match zkey {
Some(zkey) => {
let mut file = File::open(zkey).unwrap();
read_zkey(&mut file).unwrap().0
}
None => generate_random_parameters::<Bn254, _, _>(builder.setup(), &mut rng).unwrap(),
};
Self {
builder,
params,
rng,
}
}
pub fn prove(
&mut self,
chunks: &[U256],
siblings: &[U256],
hashes: &[U256],
path: &[i32],
root: U256,
salt: U256,
proof_bytes: &mut Vec<u8>,
public_inputs_bytes: &mut Vec<u8>,
) -> Result<(), String> {
let mut builder = self.builder.clone();
// vec of vecs is flattened, since wasm expects a contiguous array in memory
chunks.iter().for_each(|c| builder.push_input("chunks", *c));
siblings
.iter()
.for_each(|c| builder.push_input("siblings", *c));
hashes.iter().for_each(|c| builder.push_input("hashes", *c));
path.iter().for_each(|c| builder.push_input("path", *c));
builder.push_input("root", root);
builder.push_input("salt", salt);
let circuit = builder.build().map_err(|e| e.to_string())?;
let inputs = circuit
.get_public_inputs()
.ok_or("Unable to get public inputs!")?;
let proof = prove(circuit, &self.params, &mut self.rng).map_err(|e| e.to_string())?;
proof.serialize(proof_bytes).map_err(|e| e.to_string())?;
inputs
.serialize(public_inputs_bytes)
.map_err(|e| e.to_string())?;
Ok(())
}
pub fn verify<RR: Read>(
&mut self,
proof_bytes: RR,
mut public_inputs: RR,
) -> Result<(), String> {
let inputs: Vec<Fr> =
CanonicalDeserialize::deserialize(&mut public_inputs).map_err(|e| e.to_string())?;
let proof = Proof::<Bn254>::deserialize(proof_bytes).map_err(|e| e.to_string())?;
let vk = prepare_verifying_key(&self.params.vk);
verify_proof(&vk, &proof, inputs.as_slice()).map_err(|e| e.to_string())?;
Ok(())
}
}

View File

@ -0,0 +1,5 @@
pragma circom 2.1.0;
include "../../circuits/storer.circom";
component main { public [root, salt] } = StorageProver(32, 4, 2, 5);

View File

@ -1,5 +0,0 @@
pragma circom 2.1.0;
include "../../circuits/storer.circom";
component main { public [root, salt] } = StorageProver(32, 4, 2);

View File

@ -7,6 +7,8 @@ const {c} = require("circom_tester");
const chaiAsPromised = require('chai-as-promised'); const chaiAsPromised = require('chai-as-promised');
const poseidon = require("circomlibjs/src/poseidon"); const poseidon = require("circomlibjs/src/poseidon");
const wasm_tester = require("circom_tester").wasm; const wasm_tester = require("circom_tester").wasm;
// const snarkjs = require("snarkjs");
// const fs = require("fs");
chai.use(chaiAsPromised); chai.use(chaiAsPromised);
@ -16,12 +18,16 @@ const Fr = new F1Field(p);
const assert = chai.assert; const assert = chai.assert;
const expect = chai.expect; const expect = chai.expect;
function digest(input, chunkSize = 16) { function digest(input, chunkSize = 5) {
let chunks = Math.ceil(input.length / chunkSize); let chunks = Math.ceil(input.length / chunkSize);
let concat = []; let concat = [];
for (let i = 0; i < chunks; i++) { for (let i = 0; i < chunks; i++) {
concat.push(poseidon(input.slice(i * chunkSize, Math.min((i + 1) * chunkSize, input.length)))); let chunk = input.slice(i * chunkSize, (i + 1) * chunkSize);
if (chunk.length < chunkSize) {
chunk = chunk.concat(Array(chunkSize - chunk.length).fill(0));
}
concat.push(poseidon(chunk));
} }
if (concat.length > 1) { if (concat.length > 1) {
@ -41,12 +47,12 @@ function merkelize(leafs) {
var i = 0; var i = 0;
while (i < merkle.length) { while (i < merkle.length) {
newMerkle.push(digest([merkle[i], merkle[i + 1]])); newMerkle.push(digest([merkle[i], merkle[i + 1]], 2));
i += 2; i += 2;
} }
if (merkle.length % 2 == 1) { if (merkle.length % 2 == 1) {
newMerkle.add(digest([merkle[merkle.length - 2], merkle[merkle.length - 2]])); newMerkle.add(digest([merkle[merkle.length - 2], merkle[merkle.length - 2]], 2));
} }
merkle = newMerkle; merkle = newMerkle;
@ -55,34 +61,36 @@ function merkelize(leafs) {
return merkle[0]; return merkle[0];
} }
// TODO: should be removed at some point, as the rust test should be sufficient, but left here for now to aid debugging
describe("Storer test", function () { describe("Storer test", function () {
this.timeout(100000); this.timeout(100000);
const a = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v)); const a = Array.from(crypto.randomBytes(256).values()).map((v) => BigInt(v));
const aHash = digest(a); const aHash = digest(a, 16);
const b = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v)); const b = Array.from(crypto.randomBytes(256).values()).map((v) => BigInt(v));
const bHash = digest(b); const bHash = digest(b, 16);
const c = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v)); const c = Array.from(crypto.randomBytes(256).values()).map((v) => BigInt(v));
const cHash = digest(c); const cHash = digest(c, 16);
const d = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v)); const d = Array.from(crypto.randomBytes(256).values()).map((v) => BigInt(v));
const dHash = digest(d); const dHash = digest(d, 16);
const salt = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v)); const salt = Array.from(crypto.randomBytes(256).values()).map((v) => BigInt(v));
const saltHash = digest(salt); const saltHash = digest(salt, 16);
it("Should merkelize", async () => { it("Should merkelize", async () => {
let root = merkelize([aHash, bHash]); let root = merkelize([aHash, bHash]);
let hash = digest([aHash, bHash]); let hash = digest([aHash, bHash], 2);
assert.equal(hash, root); assert.equal(hash, root);
}); });
it("Should verify chunk is correct and part of dataset", async () => { it("Should verify chunk is correct and part of dataset", async () => {
const cir = await wasm_tester(path.join(__dirname, "./circuits", "storer_test.circom")); const cir = await wasm_tester("src/circuit_tests/storer-test.circom");
const root = merkelize([aHash, bHash, cHash, dHash]); const root = merkelize([aHash, bHash, cHash, dHash]);
const parentHashL = digest([aHash, bHash]); const parentHashL = digest([aHash, bHash], 2);
const parentHashR = digest([cHash, dHash]); const parentHashR = digest([cHash, dHash], 2);
await cir.calculateWitness({ await cir.calculateWitness({
"chunks": [[a], [b], [c], [d]], "chunks": [[a], [b], [c], [d]],
@ -96,15 +104,15 @@ describe("Storer test", function () {
"root": root, "root": root,
"salt": saltHash, "salt": saltHash,
}, true); }, true);
}).timeout(100000); });
it("Should verify chunk is correct and part of dataset", async () => { it("Should verify chunk is not correct and part of dataset", async () => {
const cir = await wasm_tester(path.join(__dirname, "./circuits", "storer_test.circom")); const cir = await wasm_tester("src/circuit_tests/storer-test.circom");
const root = merkelize([aHash, bHash, cHash, dHash]); const root = merkelize([aHash, bHash, cHash, dHash]);
const parentHashL = digest([aHash, bHash]); const parentHashL = digest([aHash, bHash], 2);
const parentHashR = digest([cHash, dHash]); const parentHashR = digest([cHash, dHash], 2);
const fn = async () => { const fn = async () => {
return await cir.calculateWitness({ return await cir.calculateWitness({
@ -128,6 +136,33 @@ describe("Storer test", function () {
assert.isRejected( assert.isRejected(
fn(), Error, fn(), Error,
/Error: Error: Assert Failed.\nError in template StorageProver_7 line: 75/); /Error: Error: Assert Failed.\nError in template StorageProver_7 line: 75/);
});
}).timeout(100000);
function range(start, end) {
return Array(end - start + 1).fill().map((_, idx) => start + idx)
}
it("Should test poseidon digest", async () => {
const cir = await wasm_tester("src/circuit_tests/poseidon-digest-test.circom");
let input = range(0, 255).map((c) => BigInt(c));
await cir.calculateWitness({
"block": input,
"hash": digest(input, 16),
});
});
// it("Should prove digest with zkey file", async () => {
// let input = range(0, 255).map((c) => BigInt(c));
// const {proof, publicSignals} = await snarkjs.groth16.fullProve(
// {
// "block": input,
// "hash": digest(input, 16),
// },
// "src/circuit_tests/artifacts/poseidon-digest-test_js/poseidon-digest-test.wasm",
// "circuit_0000.zkey");
// const vKey = JSON.parse(fs.readFileSync("verification_key.json"));
// const res = await snarkjs.groth16.verify(vKey, publicSignals, proof);
// assert(res);
// });
}); });