wip rust ffi

Dmitriy Ryajov 2023-03-16 13:59:08 -06:00
parent e7b296ebbb
commit 8560db1ee5
9 changed files with 365 additions and 26 deletions

Cargo.toml

@@ -0,0 +1,29 @@
[package]
name = "codex-storage-proofs"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
crate-type = [
"staticlib", # Ensure it gets compiled as a (static) C library
# "cdylib", # If you want a shared/dynamic C library (advanced)
"lib", # For downstream Rust dependents: `examples/`, `tests/` etc.
]
[dependencies]
ark-bn254 = { version = "0.3" }
ark-ec = { version = "0.4", default-features = false, features = ["parallel"] }
ark-groth16 = { version = "0.3", features = ["parallel"] }
ark-std = { version = "0.3", default-features = false, features = ["parallel"] }
ark-serialize = { version = "0.3", default-features = false }
num-bigint = { version = "0.4", default-features = false, features = ["rand"] }
ark-circom = { git = "https://github.com/gakonst/ark-circom.git#master", features = ["circom-2"] }
arkworks-native-gadgets = "1.2.0"
arkworks-utils = { version = "1.0.1", features = ["parallel", "poseidon_bn254_x5_3", "poseidon_bn254_x5_5"] }
ark-ff = { version = "0.4.1", features = ["std"] }
[dev-dependencies]
ff = { package="ff_ce", version="0.11", features = ["derive"] }
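As the crate-type comments above note, the `staticlib` output is what a C consumer links against (see src/ffi.rs), while the `lib` target lets Rust code under examples/ or tests/ use the crate directly. A minimal sketch of such a Rust consumer follows; the file name and artifact paths are hypothetical and not part of this commit.

// examples/prove.rs (hypothetical) -- relies on the `lib` crate-type above.
use codex_storage_proofs::storageproofs::StorageProofs;

fn main() {
    // Placeholder paths; the real artifacts come from compiling the circom circuits.
    let r1cs = "artifacts/storer_test.r1cs";
    let wasm = "artifacts/storer_test_js/storer_test.wasm";
    // Same constructor the FFI init() wraps: wasm/witness generator first, then r1cs.
    let _prover = StorageProofs::new(wasm.to_string(), r1cs.to_string());
}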


@@ -2,4 +2,4 @@ pragma circom 2.1.0;
include "./storer.circom";
component main { public [root, salt] } = StorageProver(32, 4, 2);
component main { public [root, salt] } = StorageProver(32, 4, 2, 4);


@@ -0,0 +1,12 @@
include "../node_modules/circomlib/circuits/poseidon.circom";
template SimpleHasher(SIZE) {
signal input in[SIZE];
signal input hash;
component hasher = Poseidon(SIZE);
hasher.inputs <== in;
hasher.out === hash;
}
component main = SimpleHasher(2);


@@ -32,29 +32,39 @@ template parallel MerkleProof(LEVELS) {
root <== hasher[LEVELS - 1].out;
}
function min(arg1, arg2) {
return arg1 < arg2 ? arg1 : arg2;
function roundUpDiv(x, n) {
var last = x % n; // remainder after dividing by n
var div = x \ n; // integer (floor) division
if (last > 0) {
return div + 1;
} else {
return div;
}
}
template parallel HashCheck(BLOCK_SIZE) {
template parallel HashCheck(BLOCK_SIZE, CHUNK_SIZE) {
signal input block[BLOCK_SIZE];
signal input blockHash;
signal output hash;
// TODO: make CHUNK_SIZE a parameter
// Split array into chunks of size 16
var CHUNK_SIZE = 16;
var NUM_CHUNKS = BLOCK_SIZE / CHUNK_SIZE;
// Split array into chunks of size CHUNK_SIZE
var NUM_CHUNKS = roundUpDiv(BLOCK_SIZE, CHUNK_SIZE);
// Initialize an array to store hashes of each block
component hashes[NUM_CHUNKS];
// Loop over chunks and hash them using Poseidon()
for (var i = 0; i < NUM_CHUNKS; i++) {
var start = i * CHUNK_SIZE;
var end = min(start + CHUNK_SIZE, BLOCK_SIZE);
hashes[i] = Poseidon(CHUNK_SIZE);
var start = i * CHUNK_SIZE;
var end = start + CHUNK_SIZE;
for (var j = start; j < end; j++) {
hashes[i].inputs[j - start] <== block[j];
if (j >= BLOCK_SIZE) {
hashes[i].inputs[j - start] <== 0;
} else {
hashes[i].inputs[j - start] <== block[j];
}
}
}
@@ -69,13 +79,14 @@ template parallel HashCheck(BLOCK_SIZE) {
h.inputs <== concat;
// Assign output to hash signal
h.out === blockHash;
hash <== h.out;
}
template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS) {
template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS, CHUNK_SIZE) {
// BLOCK_SIZE: size of block in symbols
// QUERY_LEN: query length, i.e. number of indices to be proven
// LEVELS: size of Merkle Tree in the manifest
// CHUNK_SIZE: number of symbols to hash in one go
signal input chunks[QUERY_LEN][BLOCK_SIZE]; // chunks to be proven
signal input siblings[QUERY_LEN][LEVELS]; // siblings hashes of chunks to be proven
signal input path[QUERY_LEN]; // path of chunks to be proven
@@ -87,9 +98,9 @@ template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS) {
component hashers[QUERY_LEN];
for (var i = 0; i < QUERY_LEN; i++) {
hashers[i] = HashCheck(BLOCK_SIZE);
hashers[i] = HashCheck(BLOCK_SIZE, CHUNK_SIZE);
hashers[i].block <== chunks[i];
hashers[i].blockHash <== hashes[i];
hashers[i].hash === hashes[i];
}
component merkelizer[QUERY_LEN];
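The chunking above (roundUpDiv plus zero padding of the final chunk) can be mirrored on the host when preparing inputs. Below is a sketch in Rust with hypothetical helper names and u64 standing in for field elements. With the StorageProver(32, 4, 2, 4) instantiation this yields 32 / 4 = 8 Poseidon calls per block; the prover then supplies chunks as a 4 x 32 array, siblings as 4 x 2, and path/hashes of length 4 (compare the inputs pushed in src/storageproofs.rs below).

// Hypothetical host-side helpers (not part of this commit) mirroring HashCheck's chunking.
// round_up_div matches the circuit's roundUpDiv: integer ceiling division.
fn round_up_div(x: usize, n: usize) -> usize {
    if x % n > 0 { x / n + 1 } else { x / n }
}

// Split a block into NUM_CHUNKS chunks of CHUNK_SIZE, zero-padding the last one,
// exactly as the `j >= BLOCK_SIZE` branch does before each Poseidon(CHUNK_SIZE) call.
fn chunk_block(block: &[u64], chunk_size: usize) -> Vec<Vec<u64>> {
    let num_chunks = round_up_div(block.len(), chunk_size);
    (0..num_chunks)
        .map(|i| {
            let start = i * chunk_size;
            let end = usize::min(start + chunk_size, block.len());
            let mut chunk = block[start..end].to_vec();
            chunk.resize(chunk_size, 0); // zero padding for a short final chunk
            chunk
        })
        .collect()
}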

src/ffi.rs

@@ -0,0 +1,40 @@
use crate::storageproofs::StorageProofs;
use std::str;
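// init() builds a StorageProofs prover from the r1cs and wasm paths, passed as raw
// UTF-8 byte buffers, and returns an owned pointer produced by Box::into_raw.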
#[no_mangle]
pub extern "C" fn init(
r1cs: *const u8,
r1cs_len: usize,
wasm: *const u8,
wasm_len: usize,
) -> *mut StorageProofs {
let r1cs = unsafe {
let slice = std::slice::from_raw_parts(r1cs, r1cs_len);
str::from_utf8(slice).unwrap()
};
let wasm = unsafe {
let slice = std::slice::from_raw_parts(wasm, wasm_len);
str::from_utf8(slice).unwrap()
};
let storage_proofs = Box::into_raw(Box::new(StorageProofs::new(
wasm.to_string(),
r1cs.to_string(),
)));
return storage_proofs;
}
#[cfg(test)]
mod tests {
use super::init;
#[test]
fn should_prove() {
let r1cs = "/Users/dryajov/personal/projects/status/codex-zk/test/circuits/artifacts/storer_test.r1cs";
let wasm = "/Users/dryajov/personal/projects/status/codex-zk/test/circuits/artifacts/storer_test_js/storer_test.wasm";
let prover = init(r1cs.as_ptr(), r1cs.len(), wasm.as_ptr(), wasm.len());
}
}
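Since init() transfers ownership of the boxed StorageProofs to the caller via Box::into_raw, the C side will eventually need a matching destructor. This is not part of the commit; a minimal sketch under that assumption, with a hypothetical name:

#[no_mangle]
pub extern "C" fn free_prover(prover: *mut StorageProofs) {
    if prover.is_null() {
        return;
    }
    // Reclaim the Box handed out by init() so it is dropped normally.
    unsafe {
        drop(Box::from_raw(prover));
    }
}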

src/lib.rs

@@ -0,0 +1,3 @@
pub mod ffi;
pub mod storageproofs;
mod simple_hasher;

src/storageproofs.rs

@@ -0,0 +1,239 @@
use ark_std::rand::rngs::ThreadRng;
use arkworks_native_gadgets::prelude::ark_ff::PrimeField;
use num_bigint::{BigInt, Sign};
use ark_bn254::{Bn254, Fq};
use ark_circom::{CircomBuilder, CircomConfig};
use ark_groth16::{
create_random_proof as prove, generate_random_parameters, prepare_verifying_key, verify_proof,
Proof, ProvingKey,
};
#[derive(Debug, Clone)]
#[repr(C)]
pub struct StorageProofs {
builder: CircomBuilder<Bn254>,
pvk: ProvingKey<Bn254>,
rng: ThreadRng,
}
impl StorageProofs {
pub fn new(wtns: String, r1cs: String) -> Self {
let mut rng = ThreadRng::default();
let builder = CircomBuilder::new(CircomConfig::<Bn254>::new(wtns, r1cs).unwrap());
let pvk = generate_random_parameters::<Bn254, _, _>(builder.setup(), &mut rng).unwrap();
Self { builder, pvk, rng }
}
pub fn prove(
&mut self,
chunks: Vec<Vec<Fq>>,
siblings: Vec<Vec<Fq>>,
hashes: Vec<Fq>,
path: Vec<u32>,
root: Fq,
salt: Fq,
proof_bytes: Vec<u8>,
public_inputs_bytes: Vec<u8>,
) -> Result<(), String> {
let mut builder = self.builder.clone();
chunks.iter().flat_map(|c| c.into_iter()).for_each(|c| {
builder.push_input(
"chunks",
BigInt::from_biguint(Sign::Plus, c.into_repr().into()),
)
});
siblings.iter().flat_map(|c| c.into_iter()).for_each(|c| {
builder.push_input(
"siblings",
BigInt::from_biguint(Sign::Plus, c.into_repr().into()),
)
});
hashes.iter().for_each(|c| {
builder.push_input(
"hashes",
BigInt::from_biguint(Sign::Plus, c.into_repr().into()),
)
});
path.iter()
.for_each(|c| builder.push_input("path", BigInt::new(Sign::Plus, vec![*c])));
builder.push_input(
"root",
BigInt::from_biguint(Sign::Plus, root.into_repr().into()),
);
builder.push_input(
"salt",
BigInt::from_biguint(Sign::Plus, salt.into_repr().into()),
);
let circuit = builder.build().unwrap();
let inputs = circuit.get_public_inputs().unwrap();
let proof = prove(circuit, &self.pvk, &mut self.rng).unwrap();
let vk = prepare_verifying_key(&self.pvk.vk);
// proof.serialize(proof_bytes).unwrap();
// inputs.serialize(public_inputs_bytes).unwrap();
Ok(())
}
// fn verify<R: Read>(self, hashes: Vec<i32>, root: i32, salt: i32,vk_bytes: R, proof_bytes: R) -> Result<(), String> {
// let vk = ProvingKey::<Bn254>::deserialize(vk_bytes).unwrap();
// let proof = Proof::<Bn254>::deserialize(proof_bytes).unwrap();
// let vk = prepare_verifying_key(&self.pvk.vk);
// verify_proof(&vk, &proof, &public_inputs).unwrap();
// Ok(())
// }
}
#[cfg(test)]
mod test {
use super::StorageProofs;
use ark_bn254::Fq;
use ark_ff::{UniformRand, Zero};
use ark_std::rand::{rngs::ThreadRng, Rng};
use arkworks_native_gadgets::{
poseidon::{sbox::PoseidonSbox, *},
prelude::ark_ff::PrimeField,
};
use arkworks_utils::{
bytes_matrix_to_f, bytes_vec_to_f, poseidon_params::setup_poseidon_params, Curve,
};
type PoseidonHasher = Poseidon<Fq>;
type Hasher = Box<dyn Fn(Vec<Fq>) -> Result<Fq, PoseidonError>>;
pub fn setup_params<F: PrimeField>(curve: Curve, exp: i8, width: u8) -> PoseidonParameters<F> {
let pos_data = setup_poseidon_params(curve, exp, width).unwrap();
let mds_f = bytes_matrix_to_f(&pos_data.mds);
let rounds_f = bytes_vec_to_f(&pos_data.rounds);
PoseidonParameters {
mds_matrix: mds_f,
round_keys: rounds_f,
full_rounds: pos_data.full_rounds,
partial_rounds: pos_data.partial_rounds,
sbox: PoseidonSbox(pos_data.exp),
width: pos_data.width,
}
}
fn hasher(curve: Curve, exp: i8, width: u8) -> Hasher {
let params = setup_params(curve, exp, width);
let poseidon = PoseidonHasher::new(params);
return Box::new(move |inputs| poseidon.hash(&inputs));
}
fn digest(input: Vec<Fq>, chunk_size: Option<usize>) -> Result<Fq, PoseidonError> {
let chunk_size = chunk_size.unwrap_or(4);
let chunks = ((input.len() as f32) / (chunk_size as f32)).ceil() as usize;
let mut concat = vec![];
let hasher = hasher(Curve::Bn254, 5, (chunk_size + 1) as u8);
let mut i: usize = 0;
while i < chunks {
let range = (i * chunk_size)..std::cmp::min((i + 1) * chunk_size, input.len());
let mut chunk: Vec<Fq> = input[range].to_vec();
if chunk.len() < chunk_size {
chunk.resize(chunk_size as usize, Fq::zero());
}
concat.push(hasher(chunk)?);
i += 1; // advance one chunk at a time
}
if concat.len() > 1 {
return hasher(concat);
}
return Ok(concat[0]);
}
fn merkelize(leafs: Vec<Fq>) -> Fq {
// simple merkle root (treehash) generator
// unbalanced trees will have the last leaf duplicated
let mut merkle: Vec<Fq> = leafs;
let hasher = hasher(Curve::Bn254, 5, 3);
while merkle.len() > 1 {
let mut new_merkle = Vec::new();
let mut i = 0;
while i + 1 < merkle.len() {
new_merkle.push(hasher(vec![merkle[i], merkle[i + 1]]).unwrap());
i += 2;
}
// odd number of nodes: hash the last node with itself
if merkle.len() % 2 == 1 {
new_merkle.push(
hasher(vec![merkle[merkle.len() - 1], merkle[merkle.len() - 1]]).unwrap(),
);
}
merkle = new_merkle;
}
return merkle[0];
}
#[test]
fn should_prove() {
let mut rng = ThreadRng::default();
let data: Vec<(Vec<Fq>, Fq)> = (0..4)
.map(|_| {
let preimages: Vec<Fq> = (0..32).map(|_| Fq::rand(&mut rng)).collect(); // 32 distinct random symbols
let hash = digest(preimages.clone(), None).unwrap();
return (preimages, hash);
})
.collect();
let chunks: Vec<Vec<Fq>> = data.iter().map(|c| c.0.to_vec()).collect();
let hashes: Vec<Fq> = data.iter().map(|c| c.1).collect();
let path = [0, 1, 2, 3].to_vec();
let hash2 = hasher(Curve::Bn254, 5, 3);
let parent_hash_l = hash2(vec![hashes[0], hashes[1]]).unwrap();
let parent_hash_r = hash2(vec![hashes[2], hashes[3]]).unwrap();
let siblings = [
[hashes[1], parent_hash_r].to_vec(),
[hashes[0], parent_hash_r].to_vec(),
[hashes[3], parent_hash_l].to_vec(),
[hashes[2], parent_hash_l].to_vec(),
]
.to_vec();
let root = merkelize(hashes.clone());
let mut proof_bytes: Vec<u8> = Vec::new();
let mut public_inputs_bytes: Vec<u8> = Vec::new();
let r1cs = "/Users/dryajov/personal/projects/status/codex-zk/test/circuits/artifacts/storer_test.r1cs";
let wasm = "/Users/dryajov/personal/projects/status/codex-zk/test/circuits/artifacts/storer_test_js/storer_test.wasm";
let mut prover = StorageProofs::new(wasm.to_string(), r1cs.to_string());
prover
.prove(
chunks,
siblings,
hashes,
path,
root,
root, // salt (the test simply reuses the root as the salt)
proof_bytes,
public_inputs_bytes,
)
.unwrap();
}
}
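The serialization that is commented out in prove() could be finished with ark-serialize along these lines. This is a sketch, not part of the commit; it assumes the output buffers are passed as &mut Vec<u8> so the caller actually receives the bytes, and the helper name is hypothetical.

use ark_bn254::{Bn254, Fr};
use ark_groth16::Proof;
use ark_serialize::CanonicalSerialize;

// Hypothetical helper: serialize the proof and public inputs into caller-provided buffers.
fn serialize_outputs(
    proof: &Proof<Bn254>,
    public_inputs: &[Fr],
    proof_bytes: &mut Vec<u8>,
    public_inputs_bytes: &mut Vec<u8>,
) {
    proof.serialize(&mut *proof_bytes).unwrap();
    for input in public_inputs {
        input.serialize(&mut *public_inputs_bytes).unwrap();
    }
}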


@@ -2,4 +2,5 @@ pragma circom 2.1.0;
include "../../circuits/storer.circom";
component main { public [root, salt] } = StorageProver(32, 4, 2);
// component main { public [root, salt] } = StorageProver(32, 4, 2, 4);
component main { public [root, salt] } = StorageProver(32, 4, 2, 2);


@@ -16,12 +16,16 @@ const Fr = new F1Field(p);
const assert = chai.assert;
const expect = chai.expect;
function digest(input, chunkSize = 16) {
function digest(input, chunkSize = 5) {
let chunks = Math.ceil(input.length / chunkSize);
let concat = [];
for (let i = 0; i < chunks; i++) {
concat.push(poseidon(input.slice(i * chunkSize, Math.min((i + 1) * chunkSize, input.length))));
let chunk = input.slice(i * chunkSize, (i + 1) * chunkSize);
if (chunk.length < chunkSize) {
chunk = chunk.concat(Array(chunkSize - chunk.length).fill(0));
}
concat.push(poseidon(chunk));
}
if (concat.length > 1) {
@@ -41,12 +45,12 @@ function merkelize(leafs) {
var i = 0;
while (i < merkle.length) {
newMerkle.push(digest([merkle[i], merkle[i + 1]]));
newMerkle.push(digest([merkle[i], merkle[i + 1]], 2));
i += 2;
}
if (merkle.length % 2 == 1) {
newMerkle.add(digest([merkle[merkle.length - 2], merkle[merkle.length - 2]]));
newMerkle.push(digest([merkle[merkle.length - 1], merkle[merkle.length - 1]], 2));
}
merkle = newMerkle;
@@ -71,7 +75,7 @@ describe("Storer test", function () {
it("Should merkelize", async () => {
let root = merkelize([aHash, bHash]);
let hash = digest([aHash, bHash]);
let hash = digest([aHash, bHash], 2);
assert.equal(hash, root);
});
@@ -81,8 +85,8 @@ describe("Storer test", function () {
const root = merkelize([aHash, bHash, cHash, dHash]);
const parentHashL = digest([aHash, bHash]);
const parentHashR = digest([cHash, dHash]);
const parentHashL = digest([aHash, bHash], 2);
const parentHashR = digest([cHash, dHash], 2);
await cir.calculateWitness({
"chunks": [[a], [b], [c], [d]],
@@ -103,8 +107,8 @@ describe("Storer test", function () {
const root = merkelize([aHash, bHash, cHash, dHash]);
const parentHashL = digest([aHash, bHash]);
const parentHashR = digest([cHash, dHash]);
const parentHashL = digest([aHash, bHash], 2);
const parentHashR = digest([cHash, dHash], 2);
const fn = async () => {
return await cir.calculateWitness({