diff --git a/Cargo.toml b/Cargo.toml
index ec070c1..95a835a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,7 +17,7 @@ ark-bn254 = { version = "0.3.0" }
 ark-ec = { version = "0.3.0", default-features = false, features = [
     "parallel",
 ] }
-ark-groth16 = { version = "0.3.0", features = ["parallel"] }
+ark-groth16 = { git = "https://github.com/arkworks-rs/groth16", rev = "765817f", features = ["parallel"] }
 ark-std = { version = "0.3.0", default-features = false, features = [
     "parallel",
 ] }
diff --git a/src/circuit_tests/mod.rs b/src/circuit_tests/mod.rs
index ffa89d1..a246aa2 100644
--- a/src/circuit_tests/mod.rs
+++ b/src/circuit_tests/mod.rs
@@ -119,7 +119,7 @@ mod test {
         let r1cs = "./src/circuit_tests/artifacts/storer-test.r1cs";
         let wasm = "./src/circuit_tests/artifacts/storer-test_js/storer-test.wasm";

-        let mut prover = StorageProofs::new(wasm.to_string(), r1cs.to_string());
+        let mut prover = StorageProofs::new(wasm.to_string(), r1cs.to_string(), None, None);

        let root = merkelize(hashes.as_slice());
        let proof_bytes = &mut Vec::new();
diff --git a/src/ffi.rs b/src/ffi.rs
index ab9f605..903dc2e 100644
--- a/src/ffi.rs
+++ b/src/ffi.rs
@@ -7,6 +7,8 @@ pub extern "C" fn init(
     r1cs_len: usize,
     wasm: *const u8,
     wasm_len: usize,
+    zkey: *const u8,
+    zkey_len: usize,
 ) -> *mut StorageProofs {
     let r1cs = unsafe {
         let slice = std::slice::from_raw_parts(r1cs, r1cs_len);
@@ -18,9 +20,17 @@ pub extern "C" fn init(
         str::from_utf8(slice).unwrap()
     };

+    let zkey = unsafe {
+        let slice = std::slice::from_raw_parts(zkey, zkey_len);
+        str::from_utf8(slice).unwrap()
+    };
+
     let storage_proofs = Box::into_raw(Box::new(StorageProofs::new(
         wasm.to_string(),
         r1cs.to_string(),
+        Some(zkey.to_string()),
+        // ThreadRng is not FFI-safe, so no RNG can be injected over the C ABI;
+        // StorageProofs::new falls back to a fresh thread-local RNG.
+        None,
     )));

     return storage_proofs;
diff --git a/src/poseidon/constants.rs b/src/poseidon/constants.rs
index 46fdc79..4ea9465 100644
--- a/src/poseidon/constants.rs
+++ b/src/poseidon/constants.rs
@@ -99,7 +99,6 @@ pub static M_CONST: Lazy<Vec<Vec<Vec<U256>>>> = Lazy::new(|| {
             })
             .collect()
         })
-        // .flatten()
         .collect::<Vec<Vec<Vec<U256>>>>()
         .try_into()
         .unwrap()
@@ -134,7 +133,6 @@ pub static P_CONST: Lazy<Vec<Vec<Vec<U256>>>> = Lazy::new(|| {
             })
             .collect()
         })
-        // .flatten()
         .collect::<Vec<Vec<Vec<U256>>>>()
         .try_into()
         .unwrap()
diff --git a/src/poseidon/mod.rs b/src/poseidon/mod.rs
index 2701290..1cb741b 100644
--- a/src/poseidon/mod.rs
+++ b/src/poseidon/mod.rs
@@ -130,7 +130,7 @@ pub fn hash(inputs: &[U256]) -> U256 {
                         (0, acc.1 + m[item.0][i] * item.1)
                     })
                     .1
-            }, // reduce((acc, a, j) => F.add(acc, F.mul(M[j][i], a)), F.zero)
+            },
         )
         .collect();

diff --git a/src/storageproofs.rs b/src/storageproofs.rs
index 2682ca5..ce0a80a 100644
--- a/src/storageproofs.rs
+++ b/src/storageproofs.rs
@@ -1,5 +1,7 @@
+use std::fs::File;
+
 use ark_bn254::{Bn254, Fr};
-use ark_circom::{CircomBuilder, CircomConfig};
+use ark_circom::{read_zkey, CircomBuilder, CircomConfig};
 use ark_groth16::{
     create_random_proof as prove, generate_random_parameters, prepare_verifying_key, verify_proof,
     Proof, ProvingKey,
@@ -10,17 +12,30 @@ use ruint::aliases::U256;

 #[derive(Debug, Clone)]
 #[repr(C)]
 pub struct StorageProofs {
     builder: CircomBuilder<Bn254>,
-    pvk: ProvingKey<Bn254>,
+    params: ProvingKey<Bn254>,
     rng: ThreadRng,
 }

 impl StorageProofs {
-    pub fn new(wtns: String, r1cs: String) -> Self {
-        let mut rng = ThreadRng::default();
+    // `zkey`: optional path to a pre-generated Groth16 zkey; `rng`: optional
+    // caller-supplied RNG (a fresh thread-local RNG is used when omitted).
+    pub fn new(wtns: String, r1cs: String, zkey: Option<String>, rng: Option<ThreadRng>) -> Self {
+        let mut rng = rng.unwrap_or_default();
         let builder = CircomBuilder::new(CircomConfig::<Bn254>::new(wtns, r1cs).unwrap());
-        let pvk = generate_random_parameters::<Bn254, _, _>(builder.setup(), &mut rng).unwrap();
+        let params: ProvingKey<Bn254> = match zkey {
+            Some(zkey) => {
+                let mut file = File::open(zkey).unwrap();
+                read_zkey(&mut file).unwrap().0
+            }
+            // NOTE: the ad-hoc local setup is for testing only; deployments must load a zkey.
+            None => generate_random_parameters::<Bn254, _, _>(builder.setup(), &mut rng).unwrap(),
+        };

-        Self { builder, pvk, rng }
+        Self {
+            builder,
+            params,
+            rng,
+        }
     }
@@ -61,7 +76,7 @@ impl StorageProofs {
         let inputs = circuit
             .get_public_inputs()
             .ok_or("Unable to get public inputs!")?;
-        let proof = prove(circuit, &self.pvk, &mut self.rng).map_err(|e| e.to_string())?;
+        let proof = prove(circuit, &self.params, &mut self.rng).map_err(|e| e.to_string())?;

         proof.serialize(proof_bytes).map_err(|e| e.to_string())?;
         inputs
@@ -75,7 +90,7 @@ impl StorageProofs {
         let inputs: Vec<Fr> =
             CanonicalDeserialize::deserialize(&mut public_inputs).map_err(|e| e.to_string())?;
         let proof = Proof::<Bn254>::deserialize(proof_bytes).map_err(|e| e.to_string())?;
-        let vk = prepare_verifying_key(&self.pvk.vk);
+        let vk = prepare_verifying_key(&self.params.vk);

         verify_proof(&vk, &proof, &inputs.as_slice()).map_err(|e| e.to_string())?;
diff --git a/test/circuits/storer-test.circom b/test/circuits/storer-test.circom
new file mode 100644
index 0000000..2d7a8e7
--- /dev/null
+++ b/test/circuits/storer-test.circom
@@ -0,0 +1,5 @@
+pragma circom 2.1.0;
+
+include "../../circuits/storer.circom";
+
+component main { public [root, salt] } = StorageProver(32, 4, 2, 5);
diff --git a/test/storer.js b/test/storer.js
new file mode 100644
index 0000000..ff561fe
--- /dev/null
+++ b/test/storer.js
@@ -0,0 +1,141 @@
+const chai = require("chai");
+const path = require("path");
+const crypto = require("crypto");
+const F1Field = require("ffjavascript").F1Field;
+const Scalar = require("ffjavascript").Scalar;
+const {c} = require("circom_tester");
+const chaiAsPromised = require('chai-as-promised');
+const poseidon = require("circomlibjs/src/poseidon");
+const wasm_tester = require("circom_tester").wasm;
+
+chai.use(chaiAsPromised);
+
+const p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
+const Fr = new F1Field(p);
+
+const assert = chai.assert;
+const expect = chai.expect;
+
+function digest(input, chunkSize = 5) {
+    let chunks = Math.ceil(input.length / chunkSize);
+    let concat = [];
+
+    for (let i = 0; i < chunks; i++) {
+        let chunk = input.slice(i * chunkSize, (i + 1) * chunkSize);
+        if (chunk.length < chunkSize) {
+            chunk = chunk.concat(Array(chunkSize - chunk.length).fill(0));
+        }
+        concat.push(poseidon(chunk));
+    }
+
+    if (concat.length > 1) {
+        return poseidon(concat);
+    }
+
+    return concat[0]
+}
+
+function merkelize(leafs) {
+    // simple merkle root (treehash) generator
+    // unbalanced trees will have the last leaf duplicated
+    var merkle = leafs;
+
+    while (merkle.length > 1) {
+        var newMerkle = [];
+
+        var i = 0;
+        while (i < merkle.length - 1) {
+            newMerkle.push(digest([merkle[i], merkle[i + 1]], 2));
+            i += 2;
+        }
+
+        if (merkle.length % 2 == 1) {
+            newMerkle.push(digest([merkle[merkle.length - 1], merkle[merkle.length - 1]], 2));
+        }
+
+        merkle = newMerkle;
+    }
+
+    return merkle[0];
+}
+
+describe("Storer test", function () {
+    this.timeout(100000);
+
+    const a = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v));
+    const aHash = digest(a);
+    const b = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v));
+    const bHash = digest(b);
+    const c = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v));
+    const cHash = digest(c);
+    const d = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v));
+    const dHash = digest(d);
+    const salt = Array.from(crypto.randomBytes(32).values()).map((v) => BigInt(v));
+    const saltHash = digest(salt);
+
+    it("Should merkelize", async () => {
+        let root = merkelize([aHash, bHash]);
+        let hash = digest([aHash, bHash], 2);
+
+        assert.equal(hash, root);
+    });
+
+    it("Should verify chunk is correct and part of dataset", async () => {
+        const cir = await wasm_tester(path.join(__dirname, "./circuits", "storer-test.circom"));
+
+        const root = merkelize([aHash, bHash, cHash, dHash]);
+
+        const parentHashL = digest([aHash, bHash], 2);
+        const parentHashR = digest([cHash, dHash], 2);
+
+        await cir.calculateWitness({
+            "chunks": [[a], [b], [c], [d]],
+            "siblings": [
+                [bHash, parentHashR],
+                [aHash, parentHashR],
+                [dHash, parentHashL],
+                [cHash, parentHashL]],
+            "hashes": [aHash, bHash, cHash, dHash],
+            "path": [0, 1, 2, 3],
+            "root": root,
+            "salt": saltHash,
+        }, true);
+    });
+
+    it("Should verify chunk is not correct and part of dataset", async () => {
+        const cir = await wasm_tester(path.join(__dirname, "./circuits", "storer-test.circom"));
+
+        const root = merkelize([aHash, bHash, cHash, dHash]);
+
+        const parentHashL = digest([aHash, bHash], 2);
+        const parentHashR = digest([cHash, dHash], 2);
+
+        const fn = async () => {
+            return await cir.calculateWitness({
+                "chunks": [
+                    [salt], // wrong chunk
+                    [b],
+                    [c],
+                    [d]],
+                "siblings": [
+                    [bHash, parentHashR],
+                    [aHash, parentHashR],
+                    [dHash, parentHashL],
+                    [cHash, parentHashL]],
+                "hashes": [saltHash, bHash, cHash, dHash],
+                "path": [0, 1, 2, 3],
+                "root": root,
+                "salt": saltHash,
+            }, true);
+        }
+
+        await assert.isRejected(
+            fn(), Error,
+            /Error: Error: Assert Failed.\nError in template StorageProver_7 line: 75/);
+    });
+
+
+    it("Should should hash item", async () => {
+        console.log(digest([0, 0, 0]).toString(16));
+    });
+});