chore(rln): refactor resource initialization (#260)

* chore(rln): optimize into Lazy OnceCells

* fix

* fix: dont change duration

* fix: increase duration?

* chore: add backtrace

* fix: remove plotter to avoid f64 range failure

* fix: remove ci alteration

* fix: use arc over witness calc

* fix: remove more lifetimes

* fix: benchmark correct fn call, not the getter

* fix: bench config
This commit is contained in:
Aaryamann Challani 2024-06-17 13:43:09 +05:30 committed by GitHub
parent c6493bd10f
commit d8f813bc2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 87 additions and 136 deletions

20
Cargo.lock generated
View File

@ -1460,25 +1460,6 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "include_dir"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18762faeff7122e89e0857b02f7ce6fcc0d101d5e9ad2ad7846cc01d61b7f19e"
dependencies = [
"include_dir_macros",
]
[[package]]
name = "include_dir_macros"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b139284b5cf57ecfa712bcc66950bb635b31aff41c188e8a4cfc758eca374a3f"
dependencies = [
"proc-macro2",
"quote",
]
[[package]]
name = "indenter"
version = "0.3.3"
@ -2303,7 +2284,6 @@ dependencies = [
"cfg-if",
"color-eyre",
"criterion 0.4.0",
"include_dir",
"num-bigint",
"num-traits",
"once_cell",

View File

@ -5,12 +5,12 @@ use std::fs::File;
use crate::config::{Config, InnerConfig};
#[derive(Default)]
pub(crate) struct State<'a> {
pub rln: Option<RLN<'a>>,
pub(crate) struct State {
pub rln: Option<RLN>,
}
impl<'a> State<'a> {
pub(crate) fn load_state() -> Result<State<'a>> {
impl State {
pub(crate) fn load_state() -> Result<State> {
let config = Config::load_config()?;
let rln = if let Some(InnerConfig { file, tree_height }) = config.inner {
let resources = File::open(&file)?;

View File

@ -19,7 +19,7 @@ pub fn init_panic_hook() {
pub struct RLNWrapper {
// The purpose of this wrapper is to hold a RLN instance with the 'static lifetime
// because wasm_bindgen does not allow returning elements with lifetimes
instance: RLN<'static>,
instance: RLN,
}
// Macro to call methods with arbitrary amount of arguments,
@ -150,8 +150,8 @@ impl<T> ProcessArg for Vec<T> {
}
}
impl<'a> ProcessArg for *const RLN<'a> {
type ReturnType = &'a RLN<'a>;
impl ProcessArg for *const RLN {
type ReturnType = &'static RLN;
fn process(self) -> Self::ReturnType {
unsafe { &*self }
}

View File

@ -58,8 +58,6 @@ utils = { package = "zerokit_utils", version = "=0.5.0", path = "../utils/", def
serde_json = "=1.0.96"
serde = { version = "=1.0.163", features = ["derive"] }
include_dir = "=0.7.3"
[dev-dependencies]
sled = "=0.34.7"
criterion = { version = "=0.4.0", features = ["html_reports"] }

View File

@ -1,17 +1,15 @@
use criterion::{criterion_group, criterion_main, Criterion};
use rln::circuit::{vk_from_ark_serialized, RESOURCES_DIR, VK_FILENAME};
use std::path::Path;
use rln::circuit::{vk_from_ark_serialized, VK_BYTES};
// Here we benchmark how long the deserialization of the
// verifying_key takes, only testing the json => verifying_key conversion,
// and skipping conversion from bytes => string => serde_json::Value
pub fn vk_deserialize_benchmark(c: &mut Criterion) {
let vk = RESOURCES_DIR.get_file(Path::new(VK_FILENAME)).unwrap();
let vk = vk.contents();
let vk = VK_BYTES;
c.bench_function("circuit::to_verifying_key", |b| {
c.bench_function("vk::vk_from_ark_serialized", |b| {
b.iter(|| {
let _ = vk_from_ark_serialized(&vk);
let _ = vk_from_ark_serialized(vk);
})
});
}

View File

@ -1,11 +1,14 @@
use criterion::{criterion_group, criterion_main, Criterion};
use rln::circuit::{zkey_from_raw, ZKEY_BYTES};
// Depending on the key type (enabled by the `--features arkzkey` flag)
// the upload speed from the `rln_final.zkey` or `rln_final.arkzkey` file is calculated
pub fn key_load_benchmark(c: &mut Criterion) {
c.bench_function("zkey::upload_from_folder", |b| {
let zkey = ZKEY_BYTES.to_vec();
c.bench_function("zkey::zkey_from_raw", |b| {
b.iter(|| {
let _ = rln::circuit::zkey_from_folder();
let _ = zkey_from_raw(&zkey);
})
});
}

View File

@ -13,33 +13,49 @@ use color_eyre::{Report, Result};
cfg_if! {
if #[cfg(not(target_arch = "wasm32"))] {
use ark_circom::{WitnessCalculator};
use once_cell::sync::OnceCell;
use once_cell::sync::{Lazy};
use std::sync::Mutex;
use wasmer::{Module, Store};
use include_dir::{include_dir, Dir};
use std::path::Path;
use std::sync::Arc;
}
}
cfg_if! {
if #[cfg(feature = "arkzkey")] {
use ark_zkey::read_arkzkey_from_bytes;
const ARKZKEY_FILENAME: &str = "tree_height_20/rln_final.arkzkey";
const ARKZKEY_BYTES: &[u8] = include_bytes!("tree_height_20/rln_final.arkzkey");
} else {
use std::io::Cursor;
use ark_circom::read_zkey;
}
}
const ZKEY_FILENAME: &str = "tree_height_20/rln_final.zkey";
pub const VK_FILENAME: &str = "tree_height_20/verification_key.arkvkey";
const WASM_FILENAME: &str = "tree_height_20/rln.wasm";
pub const TEST_TREE_HEIGHT: usize = 20;
pub const ZKEY_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln_final.zkey");
pub const VK_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/verification_key.arkvkey");
const WASM_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln.wasm");
#[cfg(not(target_arch = "wasm32"))]
pub static RESOURCES_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/resources");
static ZKEY: Lazy<(ProvingKey<Curve>, ConstraintMatrices<Fr>)> = Lazy::new(|| {
cfg_if! {
if #[cfg(feature = "arkzkey")] {
read_arkzkey_from_bytes(ARKZKEY_BYTES).expect("Failed to read arkzkey")
} else {
let mut reader = Cursor::new(ZKEY_BYTES);
read_zkey(&mut reader).expect("Failed to read zkey")
}
}
});
#[cfg(not(target_arch = "wasm32"))]
static VK: Lazy<VerifyingKey<Curve>> =
Lazy::new(|| vk_from_ark_serialized(VK_BYTES).expect("Failed to read vk"));
#[cfg(not(target_arch = "wasm32"))]
static WITNESS_CALCULATOR: Lazy<Arc<Mutex<WitnessCalculator>>> = Lazy::new(|| {
circom_from_raw(WASM_BYTES.to_vec()).expect("Failed to create witness calculator")
});
pub const TEST_TREE_HEIGHT: usize = 20;
// The following types define the pairing friendly elliptic curve, the underlying finite fields and groups default to this module
// Note that proofs are serialized assuming Fr to be 4x8 = 32 bytes in size. Hence, changing to a curve with different encoding will make proof verification to fail
@ -72,26 +88,8 @@ pub fn zkey_from_raw(zkey_data: &Vec<u8>) -> Result<(ProvingKey<Curve>, Constrai
// Loads the proving key
#[cfg(not(target_arch = "wasm32"))]
pub fn zkey_from_folder() -> Result<(ProvingKey<Curve>, ConstraintMatrices<Fr>)> {
#[cfg(feature = "arkzkey")]
let zkey = RESOURCES_DIR.get_file(Path::new(ARKZKEY_FILENAME));
#[cfg(not(feature = "arkzkey"))]
let zkey = RESOURCES_DIR.get_file(Path::new(ZKEY_FILENAME));
if let Some(zkey) = zkey {
let proving_key_and_matrices = match () {
#[cfg(feature = "arkzkey")]
() => read_arkzkey_from_bytes(zkey.contents())?,
#[cfg(not(feature = "arkzkey"))]
() => {
let mut c = Cursor::new(zkey.contents());
read_zkey(&mut c)?
}
};
Ok(proving_key_and_matrices)
} else {
Err(Report::msg("No proving key found!"))
}
pub fn zkey_from_folder() -> &'static (ProvingKey<Curve>, ConstraintMatrices<Fr>) {
&ZKEY
}
// Loads the verification key from a bytes vector
@ -112,49 +110,24 @@ pub fn vk_from_raw(vk_data: &[u8], zkey_data: &Vec<u8>) -> Result<VerifyingKey<C
// Loads the verification key
#[cfg(not(target_arch = "wasm32"))]
pub fn vk_from_folder() -> Result<VerifyingKey<Curve>> {
let vk = RESOURCES_DIR.get_file(Path::new(VK_FILENAME));
let zkey = RESOURCES_DIR.get_file(Path::new(ZKEY_FILENAME));
let verifying_key: VerifyingKey<Curve>;
if let Some(vk) = vk {
verifying_key = vk_from_ark_serialized(vk.contents())?;
Ok(verifying_key)
} else if let Some(_zkey) = zkey {
let (proving_key, _matrices) = zkey_from_folder()?;
verifying_key = proving_key.vk;
Ok(verifying_key)
} else {
Err(Report::msg("No proving/verification key found!"))
}
pub fn vk_from_folder() -> &'static VerifyingKey<Curve> {
&VK
}
#[cfg(not(target_arch = "wasm32"))]
static WITNESS_CALCULATOR: OnceCell<Mutex<WitnessCalculator>> = OnceCell::new();
// Initializes the witness calculator using a bytes vector
#[cfg(not(target_arch = "wasm32"))]
pub fn circom_from_raw(wasm_buffer: Vec<u8>) -> Result<&'static Mutex<WitnessCalculator>> {
WITNESS_CALCULATOR.get_or_try_init(|| {
let store = Store::default();
let module = Module::new(&store, wasm_buffer)?;
let result = WitnessCalculator::from_module(module)?;
Ok::<Mutex<WitnessCalculator>, Report>(Mutex::new(result))
})
pub fn circom_from_raw(wasm_buffer: Vec<u8>) -> Result<Arc<Mutex<WitnessCalculator>>> {
let store = Store::default();
let module = Module::new(&store, wasm_buffer)?;
let result = WitnessCalculator::from_module(module)?;
let wrapped = Mutex::new(result);
Ok(Arc::new(wrapped))
}
// Initializes the witness calculator
#[cfg(not(target_arch = "wasm32"))]
pub fn circom_from_folder() -> Result<&'static Mutex<WitnessCalculator>> {
// We read the wasm file
let wasm = RESOURCES_DIR.get_file(Path::new(WASM_FILENAME));
if let Some(wasm) = wasm {
let wasm_buffer = wasm.contents();
circom_from_raw(wasm_buffer.to_vec())
} else {
Err(Report::msg("No wasm file found!"))
}
pub fn circom_from_folder() -> &'static Arc<Mutex<WitnessCalculator>> {
&WITNESS_CALCULATOR
}
// Computes the verification key from a bytes vector containing pre-processed ark-serialized verification key
@ -167,7 +140,7 @@ pub fn vk_from_ark_serialized(data: &[u8]) -> Result<VerifyingKey<Curve>> {
// Checks verification key to be correct with respect to proving key
#[cfg(not(target_arch = "wasm32"))]
pub fn check_vk_from_zkey(verifying_key: VerifyingKey<Curve>) -> Result<()> {
let (proving_key, _matrices) = zkey_from_folder()?;
let (proving_key, _matrices) = zkey_from_folder();
if proving_key.vk == verifying_key {
Ok(())
} else {

View File

@ -143,15 +143,15 @@ impl ProcessArg for *const Buffer {
}
}
impl<'a> ProcessArg for *const RLN<'a> {
type ReturnType = &'a RLN<'a>;
impl ProcessArg for *const RLN {
type ReturnType = &'static RLN;
fn process(self) -> Self::ReturnType {
unsafe { &*self }
}
}
impl<'a> ProcessArg for *mut RLN<'a> {
type ReturnType = &'a mut RLN<'a>;
impl ProcessArg for *mut RLN {
type ReturnType = &'static mut RLN;
fn process(self) -> Self::ReturnType {
unsafe { &mut *self }
}

View File

@ -22,6 +22,7 @@ cfg_if! {
use ark_circom::WitnessCalculator;
use serde_json::{json, Value};
use utils::{Hasher};
use std::sync::Arc;
use std::str::FromStr;
} else {
use std::marker::*;
@ -39,7 +40,7 @@ pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809";
/// It implements the methods required to update the internal Merkle Tree, generate and verify RLN ZK proofs.
///
/// I/O is mostly done using writers and readers implementing `std::io::Write` and `std::io::Read`, respectively.
pub struct RLN<'a> {
pub struct RLN {
proving_key: (ProvingKey<Curve>, ConstraintMatrices<Fr>),
pub(crate) verification_key: VerifyingKey<Curve>,
pub(crate) tree: PoseidonTree,
@ -48,12 +49,12 @@ pub struct RLN<'a> {
// contains a lifetime, a PhantomData is necessary to avoid a compiler
// error since the lifetime is not being used
#[cfg(not(target_arch = "wasm32"))]
pub(crate) witness_calculator: &'a Mutex<WitnessCalculator>,
pub(crate) witness_calculator: Arc<Mutex<WitnessCalculator>>,
#[cfg(target_arch = "wasm32")]
_marker: PhantomData<&'a ()>,
_marker: PhantomData<()>,
}
impl RLN<'_> {
impl RLN {
/// Creates a new RLN object by loading circuit resources from a folder.
///
/// Input parameters are
@ -70,7 +71,7 @@ impl RLN<'_> {
/// let mut rln = RLN::new(tree_height, input);
/// ```
#[cfg(not(target_arch = "wasm32"))]
pub fn new<R: Read>(tree_height: usize, mut input_data: R) -> Result<RLN<'static>> {
pub fn new<R: Read>(tree_height: usize, mut input_data: R) -> Result<RLN> {
// We read input
let mut input: Vec<u8> = Vec::new();
input_data.read_to_end(&mut input)?;
@ -78,10 +79,10 @@ impl RLN<'_> {
let rln_config: Value = serde_json::from_str(&String::from_utf8(input)?)?;
let tree_config = rln_config["tree_config"].to_string();
let witness_calculator = circom_from_folder()?;
let proving_key = zkey_from_folder()?;
let witness_calculator = circom_from_folder();
let proving_key = zkey_from_folder();
let verification_key = vk_from_folder()?;
let verification_key = vk_from_folder();
let tree_config: <PoseidonTree as ZerokitMerkleTree>::Config = if tree_config.is_empty() {
<PoseidonTree as ZerokitMerkleTree>::Config::default()
@ -97,9 +98,9 @@ impl RLN<'_> {
)?;
Ok(RLN {
witness_calculator,
proving_key,
verification_key,
witness_calculator: witness_calculator.to_owned(),
proving_key: proving_key.to_owned(),
verification_key: verification_key.to_owned(),
tree,
#[cfg(target_arch = "wasm32")]
_marker: PhantomData,
@ -150,7 +151,7 @@ impl RLN<'_> {
zkey_vec: Vec<u8>,
vk_vec: Vec<u8>,
mut tree_config_input: R,
) -> Result<RLN<'static>> {
) -> Result<RLN> {
#[cfg(not(target_arch = "wasm32"))]
let witness_calculator = circom_from_raw(circom_vec)?;
@ -179,15 +180,13 @@ impl RLN<'_> {
proving_key,
verification_key,
tree,
#[cfg(target_arch = "wasm32")]
_marker: PhantomData,
})
}
#[cfg(target_arch = "wasm32")]
pub fn new_with_params(
tree_height: usize,
zkey_vec: Vec<u8>,
vk_vec: Vec<u8>,
) -> Result<RLN<'static>> {
pub fn new_with_params(tree_height: usize, zkey_vec: Vec<u8>, vk_vec: Vec<u8>) -> Result<RLN> {
#[cfg(not(target_arch = "wasm32"))]
let witness_calculator = circom_from_raw(circom_vec)?;
@ -682,7 +681,7 @@ impl RLN<'_> {
}
*/
let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?;
let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?;
// Note: we export a serialization of ark-groth16::Proof not semaphore::Proof
proof.serialize_compressed(&mut output_data)?;
@ -805,7 +804,7 @@ impl RLN<'_> {
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte)?;
let proof_values = proof_values_from_witness(&rln_witness)?;
let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?;
let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?;
// Note: we export a serialization of ark-groth16::Proof not semaphore::Proof
// This proof is compressed, i.e. 128 bytes long
@ -853,7 +852,7 @@ impl RLN<'_> {
let (rln_witness, _) = deserialize_witness(&witness_byte)?;
let proof_values = proof_values_from_witness(&rln_witness)?;
let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?;
let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?;
// Note: we export a serialization of ark-groth16::Proof not semaphore::Proof
// This proof is compressed, i.e. 128 bytes long
@ -1284,7 +1283,7 @@ impl RLN<'_> {
}
#[cfg(not(target_arch = "wasm32"))]
impl Default for RLN<'_> {
impl Default for RLN {
fn default() -> Self {
let tree_height = TEST_TREE_HEIGHT;
let buffer = Cursor::new(json!({}).to_string());

View File

@ -16,7 +16,7 @@ mod test {
const NO_OF_LEAVES: usize = 256;
fn create_rln_instance() -> &'static mut RLN<'static> {
fn create_rln_instance() -> &'static mut RLN {
let mut rln_pointer = MaybeUninit::<*mut RLN>::uninit();
let input_config = json!({}).to_string();
let input_buffer = &Buffer::from(input_config.as_bytes());

View File

@ -127,9 +127,9 @@ mod test {
// We test a RLN proof generation and verification
fn test_witness_from_json() {
// We generate all relevant keys
let proving_key = zkey_from_folder().unwrap();
let verification_key = vk_from_folder().unwrap();
let builder = circom_from_folder().unwrap();
let proving_key = zkey_from_folder();
let verification_key = vk_from_folder();
let builder = circom_from_folder();
// We compute witness from the json input
let rln_witness = get_test_witness();
@ -156,9 +156,9 @@ mod test {
assert_eq!(rln_witness_deser, rln_witness);
// We generate all relevant keys
let proving_key = zkey_from_folder().unwrap();
let verification_key = vk_from_folder().unwrap();
let builder = circom_from_folder().unwrap();
let proving_key = zkey_from_folder();
let verification_key = vk_from_folder();
let builder = circom_from_folder();
// Let's generate a zkSNARK proof
let proof = generate_proof(builder, &proving_key, &rln_witness_deser).unwrap();