From 8560db1ee52153b9afd12a688a955de71431b240 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Thu, 16 Mar 2023 13:59:08 -0600 Subject: [PATCH] wip rust ffi --- Cargo.toml | 29 ++++ circuits/main.circom | 2 +- circuits/simple-hasher.circom | 12 ++ circuits/storer.circom | 41 ++++-- src/ffi.rs | 40 ++++++ src/lib.rs | 3 + src/storageproofs.rs | 239 +++++++++++++++++++++++++++++++ test/circuits/storer_test.circom | 3 +- test/storer.js | 22 +-- 9 files changed, 365 insertions(+), 26 deletions(-) create mode 100644 Cargo.toml create mode 100644 circuits/simple-hasher.circom create mode 100644 src/ffi.rs create mode 100644 src/lib.rs create mode 100644 src/storageproofs.rs diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..2f7be91 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "codex-storage-proofs" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +crate-type = [ + "staticlib", # Ensure it gets compiled as a (static) C library + # "cdylib", # If you want a shared/dynamic C library (advanced) + "lib", # For downstream Rust dependents: `examples/`, `tests/` etc. +] + +[dependencies] +ark-bn254 = { version = "0.3" } +ark-ec = { version = "0.4", default-features = false, features = ["parallel"] } +ark-groth16 = { version = "0.3", features = ["parallel"] } +ark-std = { version = "0.3", default-features = false, features = ["parallel"] } +ark-serialize = { version = "0.3", default-features = false } + +num-bigint = { version = "0.4", default-features = false, features = ["rand"] } +ark-circom = { git = "https://github.com/gakonst/ark-circom.git#master", features = ["circom-2"] } +arkworks-native-gadgets = "1.2.0" +arkworks-utils = { version = "1.0.1", features = ["parallel", "poseidon_bn254_x5_3", "poseidon_bn254_x5_5"] } +ark-ff = { version = "0.4.1", features = ["std"] } + +[dev-dependencies] +ff = { package="ff_ce", version="0.11", features = ["derive"] } diff --git a/circuits/main.circom b/circuits/main.circom index ab6410e..d69f0c0 100644 --- a/circuits/main.circom +++ b/circuits/main.circom @@ -2,4 +2,4 @@ pragma circom 2.1.0; include "./storer.circom"; -component main { public [root, salt] } = StorageProver(32, 4, 2); +component main { public [root, salt] } = StorageProver(32, 4, 2, 4); diff --git a/circuits/simple-hasher.circom b/circuits/simple-hasher.circom new file mode 100644 index 0000000..b0a39ac --- /dev/null +++ b/circuits/simple-hasher.circom @@ -0,0 +1,12 @@ +include "../node_modules/circomlib/circuits/poseidon.circom"; + +template SimpleHasher(SIZE) { + signal input in[SIZE]; + signal input hash; + + component hasher = Poseidon(SIZE); + hasher.inputs[0] <== in; + hasher.out === hash; +} + +component main = SimpleHasher(2); diff --git a/circuits/storer.circom b/circuits/storer.circom index e8c241e..a6aabd4 100644 --- a/circuits/storer.circom +++ b/circuits/storer.circom @@ -32,29 +32,39 @@ template parallel MerkleProof(LEVELS) { root <== hasher[LEVELS - 1].out; } -function min(arg1, arg2) { - return arg1 < arg2 ? arg1 : arg2; +function roundUpDiv(x, n) { + var last = x % n; // get the last digit + var div = x \ n; // get the division + + if (last > 0) { + return div + 1; + } else { + return div; + } } -template parallel HashCheck(BLOCK_SIZE) { +template parallel HashCheck(BLOCK_SIZE, CHUNK_SIZE) { signal input block[BLOCK_SIZE]; - signal input blockHash; + signal output hash; - // TODO: make CHUNK_SIZE a parameter - // Split array into chunks of size 16 - var CHUNK_SIZE = 16; - var NUM_CHUNKS = BLOCK_SIZE / CHUNK_SIZE; + // Split array into chunks of size CHUNK_SIZE + var NUM_CHUNKS = roundUpDiv(BLOCK_SIZE, CHUNK_SIZE); // Initialize an array to store hashes of each block component hashes[NUM_CHUNKS]; // Loop over chunks and hash them using Poseidon() for (var i = 0; i < NUM_CHUNKS; i++) { - var start = i * CHUNK_SIZE; - var end = min(start + CHUNK_SIZE, BLOCK_SIZE); hashes[i] = Poseidon(CHUNK_SIZE); + + var start = i * CHUNK_SIZE; + var end = start + CHUNK_SIZE; for (var j = start; j < end; j++) { - hashes[i].inputs[j - start] <== block[j]; + if (j >= BLOCK_SIZE) { + hashes[i].inputs[j - start] <== 0; + } else { + hashes[i].inputs[j - start] <== block[j]; + } } } @@ -69,13 +79,14 @@ template parallel HashCheck(BLOCK_SIZE) { h.inputs <== concat; // Assign output to hash signal - h.out === blockHash; + hash <== h.out; } -template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS) { +template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS, CHUNK_SIZE) { // BLOCK_SIZE: size of block in symbols // QUERY_LEN: query length, i.e. number if indices to be proven // LEVELS: size of Merkle Tree in the manifest + // CHUNK_SIZE: number of symbols to hash in one go signal input chunks[QUERY_LEN][BLOCK_SIZE]; // chunks to be proven signal input siblings[QUERY_LEN][LEVELS]; // siblings hashes of chunks to be proven signal input path[QUERY_LEN]; // path of chunks to be proven @@ -87,9 +98,9 @@ template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS) { component hashers[QUERY_LEN]; for (var i = 0; i < QUERY_LEN; i++) { - hashers[i] = HashCheck(BLOCK_SIZE); + hashers[i] = HashCheck(BLOCK_SIZE, CHUNK_SIZE); hashers[i].block <== chunks[i]; - hashers[i].blockHash <== hashes[i]; + hashers[i].hash === hashes[i]; } component merkelizer[QUERY_LEN]; diff --git a/src/ffi.rs b/src/ffi.rs new file mode 100644 index 0000000..a76d575 --- /dev/null +++ b/src/ffi.rs @@ -0,0 +1,40 @@ +use crate::storageproofs::StorageProofs; +use std::str; + +#[no_mangle] +pub extern "C" fn init( + r1cs: *const u8, + r1cs_len: usize, + wasm: *const u8, + wasm_len: usize, +) -> *mut StorageProofs { + let r1cs = unsafe { + let slice = std::slice::from_raw_parts(r1cs, r1cs_len); + str::from_utf8(slice).unwrap() + }; + + let wasm = unsafe { + let slice = std::slice::from_raw_parts(wasm, wasm_len); + str::from_utf8(slice).unwrap() + }; + + let storage_proofs = Box::into_raw(Box::new(StorageProofs::new( + wasm.to_string(), + r1cs.to_string(), + ))); + + return storage_proofs; +} + +#[cfg(test)] +mod tests { + use super::init; + + #[test] + fn should_prove() { + let r1cs = "/Users/dryajov/personal/projects/status/codex-zk/test/circuits/artifacts/storer_test.r1cs"; + let wasm = "/Users/dryajov/personal/projects/status/codex-zk/test/circuits/artifacts/storer_test_js/storer_test.wasm"; + + let prover = init(r1cs.as_ptr(), r1cs.len(), wasm.as_ptr(), wasm.len()); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..298d93b --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,3 @@ +pub mod ffi; +pub mod storageproofs; +mod simple_hasher; diff --git a/src/storageproofs.rs b/src/storageproofs.rs new file mode 100644 index 0000000..269b4de --- /dev/null +++ b/src/storageproofs.rs @@ -0,0 +1,239 @@ +use ark_std::rand::rngs::ThreadRng; +use arkworks_native_gadgets::prelude::ark_ff::PrimeField; +use num_bigint::{BigInt, Sign}; + +use ark_bn254::{Bn254, Fq}; +use ark_circom::{CircomBuilder, CircomConfig}; +use ark_groth16::{ + create_random_proof as prove, generate_random_parameters, prepare_verifying_key, verify_proof, + Proof, ProvingKey, +}; + +#[derive(Debug, Clone)] +#[repr(C)] +pub struct StorageProofs { + builder: CircomBuilder, + pvk: ProvingKey, + rng: ThreadRng, +} + +impl StorageProofs { + pub fn new(wtns: String, r1cs: String) -> Self { + let mut rng = ThreadRng::default(); + let builder = CircomBuilder::new(CircomConfig::::new(wtns, r1cs).unwrap()); + let pvk = generate_random_parameters::(builder.setup(), &mut rng).unwrap(); + + Self { builder, pvk, rng } + } + + pub fn prove( + &mut self, + chunks: Vec>, + siblings: Vec>, + hashes: Vec, + path: Vec, + root: Fq, + salt: Fq, + proof_bytes: Vec, + public_inputs_bytes: Vec, + ) -> Result<(), String> { + let mut builder = self.builder.clone(); + + chunks.iter().flat_map(|c| c.into_iter()).for_each(|c| { + builder.push_input( + "chunks", + BigInt::from_biguint(Sign::Plus, c.into_repr().into()), + ) + }); + + siblings.iter().flat_map(|c| c.into_iter()).for_each(|c| { + builder.push_input( + "siblings", + BigInt::from_biguint(Sign::Plus, c.into_repr().into()), + ) + }); + + hashes.iter().for_each(|c| { + builder.push_input( + "hashes", + BigInt::from_biguint(Sign::Plus, c.into_repr().into()), + ) + }); + + path.iter() + .for_each(|c| builder.push_input("path", BigInt::new(Sign::Plus, vec![*c]))); + + builder.push_input( + "root", + BigInt::from_biguint(Sign::Plus, root.into_repr().into()), + ); + + builder.push_input( + "salt", + BigInt::from_biguint(Sign::Plus, salt.into_repr().into()), + ); + + let circuit = builder.build().unwrap(); + let inputs = circuit.get_public_inputs().unwrap(); + let proof = prove(circuit, &self.pvk, &mut self.rng).unwrap(); + let vk = prepare_verifying_key(&self.pvk.vk); + + // proof.serialize(proof_bytes).unwrap(); + // inputs.serialize(public_inputs_bytes).unwrap(); + + Ok(()) + } + + // fn verify(self, hashes: Vec, root: i32, salt: i32,vk_bytes: R, proof_bytes: R) -> Result<(), String> { + // let vk = ProvingKey::::deserialize(vk_bytes).unwrap(); + // let proof = Proof::::deserialize(proof_bytes).unwrap(); + + // let vk = prepare_verifying_key(&self.pvk.vk); + // verify_proof(&vk, &proof, &public_inputs).unwrap(); + + // Ok(()) + // } +} + +#[cfg(test)] +mod test { + use super::StorageProofs; + use ark_bn254::Fq; + use ark_ff::{UniformRand, Zero}; + use ark_std::rand::{rngs::ThreadRng, Rng}; + use arkworks_native_gadgets::{ + poseidon::{sbox::PoseidonSbox, *}, + prelude::ark_ff::PrimeField, + }; + + use arkworks_utils::{ + bytes_matrix_to_f, bytes_vec_to_f, poseidon_params::setup_poseidon_params, Curve, + }; + + type PoseidonHasher = Poseidon; + type Hasher = Box) -> Result>; + pub fn setup_params(curve: Curve, exp: i8, width: u8) -> PoseidonParameters { + let pos_data = setup_poseidon_params(curve, exp, width).unwrap(); + + let mds_f = bytes_matrix_to_f(&pos_data.mds); + let rounds_f = bytes_vec_to_f(&pos_data.rounds); + + PoseidonParameters { + mds_matrix: mds_f, + round_keys: rounds_f, + full_rounds: pos_data.full_rounds, + partial_rounds: pos_data.partial_rounds, + sbox: PoseidonSbox(pos_data.exp), + width: pos_data.width, + } + } + + fn hasher(curve: Curve, exp: i8, width: u8) -> Hasher { + let params = setup_params(curve, exp, width); + let poseidon = PoseidonHasher::new(params); + + return Box::new(move |inputs| poseidon.hash(&inputs)); + } + + fn digest(input: Vec, chunk_size: Option) -> Result { + let chunk_size = chunk_size.unwrap_or(4); + let chunks = ((input.len() as f32) / (chunk_size as f32)).ceil() as usize; + let mut concat = vec![]; + let hasher = hasher(Curve::Bn254, 5, (chunk_size + 1) as u8); + + let mut i: usize = 0; + while i < chunks { + let range = (i * chunk_size)..std::cmp::min((i + 1) * chunk_size, input.len()); + + let mut chunk: Vec = input[range].to_vec(); + + if chunk.len() < chunk_size { + chunk.resize(chunk_size as usize, Fq::zero()); + } + + concat.push(hasher(chunk)?); + i += chunk_size; + } + + if concat.len() > 1 { + return hasher(concat); + } + + return Ok(concat[0]); + } + + fn merkelize(leafs: Vec) -> Fq { + // simple merkle root (treehash) generator + // unbalanced trees will have the last leaf duplicated + let mut merkle: Vec = leafs; + let hasher = hasher(Curve::Bn254, 5, 3); + + while merkle.len() > 1 { + let mut new_merkle = Vec::new(); + let mut i = 0; + while i < merkle.len() { + new_merkle.push(hasher(vec![merkle[i], merkle[i + 1]]).unwrap()); + i += 2; + } + + if merkle.len() % 2 == 1 { + new_merkle.push( + hasher(vec![merkle[merkle.len() - 2], merkle[merkle.len() - 2]]).unwrap(), + ); + } + + merkle = new_merkle; + } + + return merkle[0]; + } + + #[test] + fn should_proove() { + let mut rng = ThreadRng::default(); + let data: Vec<(Vec, Fq)> = (0..4) + .map(|_| { + let preimages = vec![Fq::rand(&mut rng); 32]; + let hash = digest(preimages.clone(), None).unwrap(); + return (preimages, hash); + }) + .collect(); + + let chunks: Vec> = data.iter().map(|c| c.0.to_vec()).collect(); + let hashes: Vec = data.iter().map(|c| c.1).collect(); + let path = [0, 1, 2, 3].to_vec(); + + let hash2 = hasher(Curve::Bn254, 5, 3); + let parent_hash_l = hash2(vec![hashes[0], hashes[1]]).unwrap(); + let parent_hash_r = hash2(vec![hashes[2], hashes[3]]).unwrap(); + + let siblings = [ + [hashes[1], parent_hash_r].to_vec(), + [hashes[1], parent_hash_r].to_vec(), + [hashes[3], parent_hash_l].to_vec(), + [hashes[2], parent_hash_l].to_vec(), + ] + .to_vec(); + + let root = merkelize(hashes.clone()); + let mut proof_bytes: Vec = Vec::new(); + let mut public_inputs_bytes: Vec = Vec::new(); + + let r1cs = "/Users/dryajov/personal/projects/status/codex-zk/test/circuits/artifacts/storer_test.r1cs"; + let wasm = "/Users/dryajov/personal/projects/status/codex-zk/test/circuits/artifacts/storer_test_js/storer_test.wasm"; + + let mut prover = StorageProofs::new(wasm.to_string(), r1cs.to_string()); + prover + .prove( + chunks, + siblings, + hashes, + path, + root, + root, // random salt + proof_bytes, + public_inputs_bytes, + ) + .unwrap(); + } +} diff --git a/test/circuits/storer_test.circom b/test/circuits/storer_test.circom index 21edfee..06908ce 100644 --- a/test/circuits/storer_test.circom +++ b/test/circuits/storer_test.circom @@ -2,4 +2,5 @@ pragma circom 2.1.0; include "../../circuits/storer.circom"; -component main { public [root, salt] } = StorageProver(32, 4, 2); +// component main { public [root, salt] } = StorageProver(32, 4, 2, 4); +component main { public [root, salt] } = StorageProver(32, 4, 2, 2); diff --git a/test/storer.js b/test/storer.js index 1db8b85..d61d19b 100644 --- a/test/storer.js +++ b/test/storer.js @@ -16,12 +16,16 @@ const Fr = new F1Field(p); const assert = chai.assert; const expect = chai.expect; -function digest(input, chunkSize = 16) { +function digest(input, chunkSize = 5) { let chunks = Math.ceil(input.length / chunkSize); let concat = []; for (let i = 0; i < chunks; i++) { - concat.push(poseidon(input.slice(i * chunkSize, Math.min((i + 1) * chunkSize, input.length)))); + let chunk = input.slice(i * chunkSize, (i + 1) * chunkSize); + if (chunk.length < chunkSize) { + chunk = chunk.concat(Array(chunkSize - chunk.length).fill(0)); + } + concat.push(poseidon(chunk)); } if (concat.length > 1) { @@ -41,12 +45,12 @@ function merkelize(leafs) { var i = 0; while (i < merkle.length) { - newMerkle.push(digest([merkle[i], merkle[i + 1]])); + newMerkle.push(digest([merkle[i], merkle[i + 1]], 2)); i += 2; } if (merkle.length % 2 == 1) { - newMerkle.add(digest([merkle[merkle.length - 2], merkle[merkle.length - 2]])); + newMerkle.add(digest([merkle[merkle.length - 2], merkle[merkle.length - 2]], 2)); } merkle = newMerkle; @@ -71,7 +75,7 @@ describe("Storer test", function () { it("Should merkelize", async () => { let root = merkelize([aHash, bHash]); - let hash = digest([aHash, bHash]); + let hash = digest([aHash, bHash], 2); assert.equal(hash, root); }); @@ -81,8 +85,8 @@ describe("Storer test", function () { const root = merkelize([aHash, bHash, cHash, dHash]); - const parentHashL = digest([aHash, bHash]); - const parentHashR = digest([cHash, dHash]); + const parentHashL = digest([aHash, bHash], 2); + const parentHashR = digest([cHash, dHash], 2); await cir.calculateWitness({ "chunks": [[a], [b], [c], [d]], @@ -103,8 +107,8 @@ describe("Storer test", function () { const root = merkelize([aHash, bHash, cHash, dHash]); - const parentHashL = digest([aHash, bHash]); - const parentHashR = digest([cHash, dHash]); + const parentHashL = digest([aHash, bHash], 2); + const parentHashR = digest([cHash, dHash], 2); const fn = async () => { return await cir.calculateWitness({