From 8033bfd2daccc1f8d22348e2bbca861c90833135 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Sat, 18 Mar 2023 17:59:02 -0600 Subject: [PATCH] adding rust circtuit tests --- ...n-hasher.circom => poseidon-digest.circom} | 2 +- circuits/storer.circom | 4 +- src/circuit_tests/mod.rs | 84 +++++++++++++++++++ src/circuit_tests/poseidon-digest-test.circom | 5 ++ .../circuit_tests/poseidon-hash-test.circom | 6 +- src/circuit_tests/storer-test.circom | 5 ++ src/lib.rs | 3 +- src/utils.rs | 42 ++++++++-- 8 files changed, 135 insertions(+), 16 deletions(-) rename circuits/{poseidon-hasher.circom => poseidon-digest.circom} (96%) create mode 100644 src/circuit_tests/mod.rs create mode 100644 src/circuit_tests/poseidon-digest-test.circom rename circuits/simple-hasher.circom => src/circuit_tests/poseidon-hash-test.circom (60%) create mode 100644 src/circuit_tests/storer-test.circom diff --git a/circuits/poseidon-hasher.circom b/circuits/poseidon-digest.circom similarity index 96% rename from circuits/poseidon-hasher.circom rename to circuits/poseidon-digest.circom index eba1b0f..3a7a767 100644 --- a/circuits/poseidon-hasher.circom +++ b/circuits/poseidon-digest.circom @@ -13,7 +13,7 @@ function roundUpDiv(x, n) { return div; } -template parallel PoseidonHasher(BLOCK_SIZE, CHUNK_SIZE) { +template parallel PoseidonDigest(BLOCK_SIZE, CHUNK_SIZE) { // BLOCK_SIZE - size of the input block array // CHUNK_SIZE - number of elements to hash at once signal input block[BLOCK_SIZE]; // Input block array diff --git a/circuits/storer.circom b/circuits/storer.circom index bde062a..86b1407 100644 --- a/circuits/storer.circom +++ b/circuits/storer.circom @@ -3,7 +3,7 @@ pragma circom 2.1.0; include "../node_modules/circomlib/circuits/poseidon.circom"; include "../node_modules/circomlib/circuits/switcher.circom"; include "../node_modules/circomlib/circuits/bitify.circom"; -include "./poseidon-hasher.circom"; +include "./poseidon-digest.circom"; template parallel MerkleProof(LEVELS) { signal input leaf; @@ -49,7 +49,7 @@ template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS, CHUNK_SIZE) { component hashers[QUERY_LEN]; for (var i = 0; i < QUERY_LEN; i++) { - hashers[i] = PoseidonHasher(BLOCK_SIZE, CHUNK_SIZE); + hashers[i] = PoseidonDigest(BLOCK_SIZE, CHUNK_SIZE); hashers[i].block <== chunks[i]; hashers[i].hash === hashes[i]; } diff --git a/src/circuit_tests/mod.rs b/src/circuit_tests/mod.rs new file mode 100644 index 0000000..3b01074 --- /dev/null +++ b/src/circuit_tests/mod.rs @@ -0,0 +1,84 @@ +use ark_bn254::Bn254; +use ark_circom::{CircomBuilder, CircomConfig}; +use ark_groth16::{ + create_random_proof as prove, generate_random_parameters, prepare_inputs, + prepare_verifying_key, verify_proof_with_prepared_inputs, ProvingKey, +}; +use ark_std::rand::rngs::ThreadRng; +use ruint::aliases::U256; + +pub struct CircuitsTests { + builder: CircomBuilder, + params: ProvingKey, + rng: ThreadRng, +} + +impl CircuitsTests { + pub fn new(wtns: String, r1cs: String) -> CircuitsTests { + let mut rng = ThreadRng::default(); + let builder = CircomBuilder::new(CircomConfig::::new(wtns, r1cs).unwrap()); + let params = generate_random_parameters::(builder.setup(), &mut rng).unwrap(); + + CircuitsTests { + builder, + params, + rng, + } + } + + pub fn poseidon_hash(&mut self, elements: &[U256], hash: U256) { + let mut builder = self.builder.clone(); + + elements.iter().for_each(|c| builder.push_input("in", *c)); + builder.push_input("hash", hash); + + let circuit = builder.build().unwrap(); + let inputs = circuit.get_public_inputs().unwrap(); + let proof = prove(circuit, &self.params, &mut self.rng).unwrap(); + let vk = prepare_verifying_key(&self.params.vk); + let public_inputs = prepare_inputs(&vk, &inputs).unwrap(); + verify_proof_with_prepared_inputs(&vk, &proof, &public_inputs).unwrap(); + } + + pub fn poseidon_digest(&mut self, elements: &[U256], hash: U256) { + let mut builder = self.builder.clone(); + + elements + .iter() + .for_each(|c| builder.push_input("block", *c)); + builder.push_input("hash", hash); + + let circuit = builder.build().unwrap(); + let inputs = circuit.get_public_inputs().unwrap(); + let proof = prove(circuit, &self.params, &mut self.rng).unwrap(); + let vk = prepare_verifying_key(&self.params.vk); + let public_inputs = prepare_inputs(&vk, &inputs).unwrap(); + verify_proof_with_prepared_inputs(&vk, &proof, &public_inputs).unwrap(); + } +} + +#[cfg(test)] +mod test { + use super::CircuitsTests; + use crate::{poseidon::hash, utils::digest}; + use ruint::aliases::U256; + + #[test] + fn test_poseidon_hash_circuit() { + let r1cs = "src/circuit_tests/artifacts/poseidon-hash-test.r1cs"; + let wasm = "src/circuit_tests/artifacts/poseidon-hash-test_js/poseidon-hash-test.wasm"; + + let mut hasher = CircuitsTests::new(wasm.to_string(), r1cs.to_string()); + hasher.poseidon_hash(&[U256::from(1)], hash(&[U256::from(1)])); + } + + #[test] + fn test_digest_digest_circuit() { + let r1cs = "src/circuit_tests/artifacts/poseidon-digest-test.r1cs"; + let wasm = "src/circuit_tests/artifacts/poseidon-digest-test_js/poseidon-digest-test.wasm"; + + let mut hasher = CircuitsTests::new(wasm.to_string(), r1cs.to_string()); + let input: Vec = (0..256).map(|_| U256::from(1)).collect(); + hasher.poseidon_digest(&input, digest(&input, Some(16))); + } +} diff --git a/src/circuit_tests/poseidon-digest-test.circom b/src/circuit_tests/poseidon-digest-test.circom new file mode 100644 index 0000000..f8d244f --- /dev/null +++ b/src/circuit_tests/poseidon-digest-test.circom @@ -0,0 +1,5 @@ +pragma circom 2.1.0; + +include "../../circuits/poseidon-digest.circom"; + +component main = PoseidonDigest(256, 16); diff --git a/circuits/simple-hasher.circom b/src/circuit_tests/poseidon-hash-test.circom similarity index 60% rename from circuits/simple-hasher.circom rename to src/circuit_tests/poseidon-hash-test.circom index a8f8676..5450e49 100644 --- a/circuits/simple-hasher.circom +++ b/src/circuit_tests/poseidon-hash-test.circom @@ -1,6 +1,6 @@ -include "../node_modules/circomlib/circuits/poseidon.circom"; +include "../../node_modules/circomlib/circuits/poseidon.circom"; -template SimpleHasher(SIZE) { +template PoseidonHash(SIZE) { signal input in[SIZE]; signal input hash; @@ -12,4 +12,4 @@ template SimpleHasher(SIZE) { hasher.out === hash; } -component main = SimpleHasher(1); +component main = PoseidonHash(1); diff --git a/src/circuit_tests/storer-test.circom b/src/circuit_tests/storer-test.circom new file mode 100644 index 0000000..2d7a8e7 --- /dev/null +++ b/src/circuit_tests/storer-test.circom @@ -0,0 +1,5 @@ +pragma circom 2.1.0; + +include "../../circuits/storer.circom"; + +component main { public [root, salt] } = StorageProver(32, 4, 2, 5); diff --git a/src/lib.rs b/src/lib.rs index bf6b6c5..4a7177b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,4 +2,5 @@ pub mod ffi; pub mod hash; pub mod poseidon; // pub mod storageproofs; -mod simple_hasher; +mod circuit_tests; +mod utils; diff --git a/src/utils.rs b/src/utils.rs index efaf0ea..8c372c7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,26 +1,50 @@ -use crate::poseidon::hash1; -use ruint::aliases::U256; +use crate::poseidon::hash; +use ruint::{aliases::U256, uint}; -fn digest(input: &[U256], chunk_size: Option) -> U256 { +pub fn digest(input: &[U256], chunk_size: Option) -> U256 { let chunk_size = chunk_size.unwrap_or(4); let chunks = ((input.len() as f32) / (chunk_size as f32)).ceil() as usize; - let mut concat = vec![]; + let mut concat: Vec = vec![]; let mut i: usize = 0; while i < chunks { let range = (i * chunk_size)..std::cmp::min((i + 1) * chunk_size, input.len()); - let mut chunk: Vec = input[range].to_vec(); + let mut chunk: Vec = input[range].to_vec(); if chunk.len() < chunk_size { - chunk.resize(chunk_size as usize, Fq::zero()); + chunk.resize(chunk_size as usize, uint!(0_U256)); } - concat.push(hash1(chunk)?); + concat.push(hash(chunk.as_slice())); i += chunk_size; } if concat.len() > 1 { - return hasher(concat); + return hash(concat.as_slice()); } - return Ok(concat[0]); + return concat[0]; +} + +pub fn merkelize(leafs: &[U256]) -> U256 { + // simple merkle root (treehash) generator + // unbalanced trees will have the last leaf duplicated + let mut merkle: Vec = leafs.to_vec(); + + while merkle.len() > 1 { + let mut new_merkle = Vec::new(); + let mut i = 0; + while i < merkle.len() { + new_merkle.push(hash(&[merkle[i], merkle[i + 1]])); + i += 2; + } + + if merkle.len() % 2 == 1 { + new_merkle + .push(hash(&[merkle[merkle.len() - 2], merkle[merkle.len() - 2]])); + } + + merkle = new_merkle; + } + + return merkle[0]; }