adding rust circtuit tests

This commit is contained in:
Dmitriy Ryajov 2023-03-18 17:59:02 -06:00
parent 65caca5f78
commit 8033bfd2da
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
8 changed files with 135 additions and 16 deletions

View File

@ -13,7 +13,7 @@ function roundUpDiv(x, n) {
return div;
}
template parallel PoseidonHasher(BLOCK_SIZE, CHUNK_SIZE) {
template parallel PoseidonDigest(BLOCK_SIZE, CHUNK_SIZE) {
// BLOCK_SIZE - size of the input block array
// CHUNK_SIZE - number of elements to hash at once
signal input block[BLOCK_SIZE]; // Input block array

View File

@ -3,7 +3,7 @@ pragma circom 2.1.0;
include "../node_modules/circomlib/circuits/poseidon.circom";
include "../node_modules/circomlib/circuits/switcher.circom";
include "../node_modules/circomlib/circuits/bitify.circom";
include "./poseidon-hasher.circom";
include "./poseidon-digest.circom";
template parallel MerkleProof(LEVELS) {
signal input leaf;
@ -49,7 +49,7 @@ template StorageProver(BLOCK_SIZE, QUERY_LEN, LEVELS, CHUNK_SIZE) {
component hashers[QUERY_LEN];
for (var i = 0; i < QUERY_LEN; i++) {
hashers[i] = PoseidonHasher(BLOCK_SIZE, CHUNK_SIZE);
hashers[i] = PoseidonDigest(BLOCK_SIZE, CHUNK_SIZE);
hashers[i].block <== chunks[i];
hashers[i].hash === hashes[i];
}

84
src/circuit_tests/mod.rs Normal file
View File

@ -0,0 +1,84 @@
use ark_bn254::Bn254;
use ark_circom::{CircomBuilder, CircomConfig};
use ark_groth16::{
create_random_proof as prove, generate_random_parameters, prepare_inputs,
prepare_verifying_key, verify_proof_with_prepared_inputs, ProvingKey,
};
use ark_std::rand::rngs::ThreadRng;
use ruint::aliases::U256;
pub struct CircuitsTests {
builder: CircomBuilder<Bn254>,
params: ProvingKey<Bn254>,
rng: ThreadRng,
}
impl CircuitsTests {
pub fn new(wtns: String, r1cs: String) -> CircuitsTests {
let mut rng = ThreadRng::default();
let builder = CircomBuilder::new(CircomConfig::<Bn254>::new(wtns, r1cs).unwrap());
let params = generate_random_parameters::<Bn254, _, _>(builder.setup(), &mut rng).unwrap();
CircuitsTests {
builder,
params,
rng,
}
}
pub fn poseidon_hash(&mut self, elements: &[U256], hash: U256) {
let mut builder = self.builder.clone();
elements.iter().for_each(|c| builder.push_input("in", *c));
builder.push_input("hash", hash);
let circuit = builder.build().unwrap();
let inputs = circuit.get_public_inputs().unwrap();
let proof = prove(circuit, &self.params, &mut self.rng).unwrap();
let vk = prepare_verifying_key(&self.params.vk);
let public_inputs = prepare_inputs(&vk, &inputs).unwrap();
verify_proof_with_prepared_inputs(&vk, &proof, &public_inputs).unwrap();
}
pub fn poseidon_digest(&mut self, elements: &[U256], hash: U256) {
let mut builder = self.builder.clone();
elements
.iter()
.for_each(|c| builder.push_input("block", *c));
builder.push_input("hash", hash);
let circuit = builder.build().unwrap();
let inputs = circuit.get_public_inputs().unwrap();
let proof = prove(circuit, &self.params, &mut self.rng).unwrap();
let vk = prepare_verifying_key(&self.params.vk);
let public_inputs = prepare_inputs(&vk, &inputs).unwrap();
verify_proof_with_prepared_inputs(&vk, &proof, &public_inputs).unwrap();
}
}
#[cfg(test)]
mod test {
use super::CircuitsTests;
use crate::{poseidon::hash, utils::digest};
use ruint::aliases::U256;
#[test]
fn test_poseidon_hash_circuit() {
let r1cs = "src/circuit_tests/artifacts/poseidon-hash-test.r1cs";
let wasm = "src/circuit_tests/artifacts/poseidon-hash-test_js/poseidon-hash-test.wasm";
let mut hasher = CircuitsTests::new(wasm.to_string(), r1cs.to_string());
hasher.poseidon_hash(&[U256::from(1)], hash(&[U256::from(1)]));
}
#[test]
fn test_digest_digest_circuit() {
let r1cs = "src/circuit_tests/artifacts/poseidon-digest-test.r1cs";
let wasm = "src/circuit_tests/artifacts/poseidon-digest-test_js/poseidon-digest-test.wasm";
let mut hasher = CircuitsTests::new(wasm.to_string(), r1cs.to_string());
let input: Vec<U256> = (0..256).map(|_| U256::from(1)).collect();
hasher.poseidon_digest(&input, digest(&input, Some(16)));
}
}

View File

@ -0,0 +1,5 @@
pragma circom 2.1.0;
include "../../circuits/poseidon-digest.circom";
component main = PoseidonDigest(256, 16);

View File

@ -1,6 +1,6 @@
include "../node_modules/circomlib/circuits/poseidon.circom";
include "../../node_modules/circomlib/circuits/poseidon.circom";
template SimpleHasher(SIZE) {
template PoseidonHash(SIZE) {
signal input in[SIZE];
signal input hash;
@ -12,4 +12,4 @@ template SimpleHasher(SIZE) {
hasher.out === hash;
}
component main = SimpleHasher(1);
component main = PoseidonHash(1);

View File

@ -0,0 +1,5 @@
pragma circom 2.1.0;
include "../../circuits/storer.circom";
component main { public [root, salt] } = StorageProver(32, 4, 2, 5);

View File

@ -2,4 +2,5 @@ pub mod ffi;
pub mod hash;
pub mod poseidon;
// pub mod storageproofs;
mod simple_hasher;
mod circuit_tests;
mod utils;

View File

@ -1,26 +1,50 @@
use crate::poseidon::hash1;
use ruint::aliases::U256;
use crate::poseidon::hash;
use ruint::{aliases::U256, uint};
fn digest(input: &[U256], chunk_size: Option<usize>) -> U256 {
pub fn digest(input: &[U256], chunk_size: Option<usize>) -> U256 {
let chunk_size = chunk_size.unwrap_or(4);
let chunks = ((input.len() as f32) / (chunk_size as f32)).ceil() as usize;
let mut concat = vec![];
let mut concat: Vec<U256> = vec![];
let mut i: usize = 0;
while i < chunks {
let range = (i * chunk_size)..std::cmp::min((i + 1) * chunk_size, input.len());
let mut chunk: Vec<Fq> = input[range].to_vec();
let mut chunk: Vec<U256> = input[range].to_vec();
if chunk.len() < chunk_size {
chunk.resize(chunk_size as usize, Fq::zero());
chunk.resize(chunk_size as usize, uint!(0_U256));
}
concat.push(hash1(chunk)?);
concat.push(hash(chunk.as_slice()));
i += chunk_size;
}
if concat.len() > 1 {
return hasher(concat);
return hash(concat.as_slice());
}
return Ok(concat[0]);
return concat[0];
}
pub fn merkelize(leafs: &[U256]) -> U256 {
// simple merkle root (treehash) generator
// unbalanced trees will have the last leaf duplicated
let mut merkle: Vec<U256> = leafs.to_vec();
while merkle.len() > 1 {
let mut new_merkle = Vec::new();
let mut i = 0;
while i < merkle.len() {
new_merkle.push(hash(&[merkle[i], merkle[i + 1]]));
i += 2;
}
if merkle.len() % 2 == 1 {
new_merkle
.push(hash(&[merkle[merkle.len() - 2], merkle[merkle.len() - 2]]));
}
merkle = new_merkle;
}
return merkle[0];
}