From c401c0b21d56aaaf696cb8495146360e9558e5a1 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Tue, 20 Sep 2022 08:22:46 -0400 Subject: [PATCH] feat: wasm (#38) --- .github/workflows/ci.yml | 7 +- Cargo.toml | 1 + rln-wasm/Cargo.toml | 27 ++ rln-wasm/README.md | 20 ++ rln-wasm/resources/witness_calculator.js | 331 +++++++++++++++++++++++ rln-wasm/src/lib.rs | 222 +++++++++++++++ rln-wasm/src/utils.js | 18 ++ rln/Cargo.toml | 17 +- rln/src/circuit.rs | 18 +- rln/src/ffi.rs | 7 +- rln/src/lib.rs | 4 +- rln/src/protocol.rs | 145 ++++++++-- rln/src/public.rs | 217 ++++++++++++--- rln/src/utils.rs | 12 +- 14 files changed, 973 insertions(+), 73 deletions(-) create mode 100644 rln-wasm/Cargo.toml create mode 100644 rln-wasm/README.md create mode 100644 rln-wasm/resources/witness_calculator.js create mode 100644 rln-wasm/src/lib.rs create mode 100644 rln-wasm/src/utils.js diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d17530..8c46ff6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ jobs: run: git submodule update --init --recursive - name: cargo test run: | - cargo test + cargo test --release --workspace --exclude rln-wasm lint: runs-on: ubuntu-latest steps: @@ -40,6 +40,9 @@ jobs: - name: cargo fmt run: cargo fmt --all -- --check - name: cargo clippy - run: cargo clippy + run: | + (cd multiplier && cargo clippy) + (cd rln && cargo clippy) + (cd semaphore && cargo clippy) # Currently not treating warnings as error, too noisy # -- -D warnings diff --git a/Cargo.toml b/Cargo.toml index 4eaf11c..5b96cb0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,4 +3,5 @@ members = [ "multiplier", "semaphore", "rln", + "rln-wasm", ] diff --git a/rln-wasm/Cargo.toml b/rln-wasm/Cargo.toml new file mode 100644 index 0000000..be7069f --- /dev/null +++ b/rln-wasm/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "rln-wasm" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib", "rlib"] + +[dependencies] +rln = { path = "../rln", default-features = false, features = ["wasm"] } +num-bigint = { version = "0.4", default-features = false, features = ["rand", "serde"] } +wasmer = { version = "2.3", default-features = false, features = ["js", "std"] } +web-sys = {version = "0.3", features=["console"]} +getrandom = { version = "0.2.7", default-features = false, features = ["js"] } +wasm-bindgen = "0.2.63" +serde-wasm-bindgen = "0.4" +js-sys = "0.3.59" +console_error_panic_hook = "0.1.7" +serde_json = "1.0.85" + +[dev-dependencies] +wasm-bindgen-test = "0.3.0" +wasm-bindgen-futures = "0.4.33" + +[profile.release] +debug = true + diff --git a/rln-wasm/README.md b/rln-wasm/README.md new file mode 100644 index 0000000..fa14cad --- /dev/null +++ b/rln-wasm/README.md @@ -0,0 +1,20 @@ +# RLN for WASM +This library is used in [waku-org/js-rln](https://github.com/waku-org/js-rln/) + +## Building the library +1. Make sure you have nodejs installed and the `build-essential` package if using ubuntu. +2. Install wasm-pack +``` +curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh +``` +3. Compile zerokit for `wasm32-unknown-unknown`: +``` +cd rln-wasm +wasm-pack build --release +``` + +## Running tests +``` +cd rln-wasm +wasm-pack test --node --release +``` \ No newline at end of file diff --git a/rln-wasm/resources/witness_calculator.js b/rln-wasm/resources/witness_calculator.js new file mode 100644 index 0000000..1570779 --- /dev/null +++ b/rln-wasm/resources/witness_calculator.js @@ -0,0 +1,331 @@ +module.exports = async function builder(code, options) { + + options = options || {}; + + let wasmModule; + try { + wasmModule = await WebAssembly.compile(code); + } catch (err) { + console.log(err); + console.log("\nTry to run circom --c in order to generate c++ code instead\n"); + throw new Error(err); + } + + let wc; + + let errStr = ""; + let msgStr = ""; + + const instance = await WebAssembly.instantiate(wasmModule, { + runtime: { + exceptionHandler : function(code) { + let err; + if (code == 1) { + err = "Signal not found.\n"; + } else if (code == 2) { + err = "Too many signals set.\n"; + } else if (code == 3) { + err = "Signal already set.\n"; + } else if (code == 4) { + err = "Assert Failed.\n"; + } else if (code == 5) { + err = "Not enough memory.\n"; + } else if (code == 6) { + err = "Input signal array access exceeds the size.\n"; + } else { + err = "Unknown error.\n"; + } + throw new Error(err + errStr); + }, + printErrorMessage : function() { + errStr += getMessage() + "\n"; + // console.error(getMessage()); + }, + writeBufferMessage : function() { + const msg = getMessage(); + // Any calls to `log()` will always end with a `\n`, so that's when we print and reset + if (msg === "\n") { + console.log(msgStr); + msgStr = ""; + } else { + // If we've buffered other content, put a space in between the items + if (msgStr !== "") { + msgStr += " " + } + // Then append the message to the message we are creating + msgStr += msg; + } + }, + showSharedRWMemory : function() { + printSharedRWMemory (); + } + + } + }); + + const sanityCheck = + options +// options && +// ( +// options.sanityCheck || +// options.logGetSignal || +// options.logSetSignal || +// options.logStartComponent || +// options.logFinishComponent +// ); + + + wc = new WitnessCalculator(instance, sanityCheck); + return wc; + + function getMessage() { + var message = ""; + var c = instance.exports.getMessageChar(); + while ( c != 0 ) { + message += String.fromCharCode(c); + c = instance.exports.getMessageChar(); + } + return message; + } + + function printSharedRWMemory () { + const shared_rw_memory_size = instance.exports.getFieldNumLen32(); + const arr = new Uint32Array(shared_rw_memory_size); + for (let j=0; j { + const h = fnvHash(k); + const hMSB = parseInt(h.slice(0,8), 16); + const hLSB = parseInt(h.slice(8,16), 16); + const fArr = flatArray(input[k]); + let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB); + if (signalSize < 0){ + throw new Error(`Signal ${k} not found\n`); + } + if (fArr.length < signalSize) { + throw new Error(`Not enough values for input signal ${k}\n`); + } + if (fArr.length > signalSize) { + throw new Error(`Too many values for input signal ${k}\n`); + } + for (let i=0; i0) { + res.unshift(0); + i--; + } + } + return res; +} + +function fromArray32(arr) { //returns a BigInt + var res = BigInt(0); + const radix = BigInt(0x100000000); + for (let i = 0; i, +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +#[wasm_bindgen(js_name = newRLN)] +pub fn wasm_new(tree_height: usize, zkey: Uint8Array, vk: Uint8Array) -> *mut RLNWrapper { + let instance = RLN::new_with_params(tree_height, zkey.to_vec(), vk.to_vec()); + let wrapper = RLNWrapper { instance }; + Box::into_raw(Box::new(wrapper)) +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +#[wasm_bindgen(js_name = getSerializedRLNWitness)] +pub fn wasm_get_serialized_rln_witness(ctx: *mut RLNWrapper, input: Uint8Array) -> Uint8Array { + let wrapper = unsafe { &mut *ctx }; + let rln_witness = wrapper + .instance + .get_serialized_rln_witness(&input.to_vec()[..]); + + Uint8Array::from(&rln_witness[..]) +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +#[wasm_bindgen(js_name = insertMember)] +pub fn wasm_set_next_leaf(ctx: *mut RLNWrapper, input: Uint8Array) -> Result<(), String> { + let wrapper = unsafe { &mut *ctx }; + if wrapper.instance.set_next_leaf(&input.to_vec()[..]).is_ok() { + Ok(()) + } else { + Err("could not insert member into merkle tree".into()) + } +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +#[wasm_bindgen(js_name = RLNWitnessToJson)] +pub fn rln_witness_to_json(ctx: *mut RLNWrapper, serialized_witness: Uint8Array) -> Object { + let wrapper = unsafe { &mut *ctx }; + let inputs = wrapper + .instance + .get_rln_witness_json(&serialized_witness.to_vec()[..]) + .unwrap(); + + let js_value = serde_wasm_bindgen::to_value(&inputs).unwrap(); + let obj = Object::from_entries(&js_value); + obj.unwrap() +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +#[wasm_bindgen] +pub fn generate_rln_proof_with_witness( + ctx: *mut RLNWrapper, + calculated_witness: Vec, + serialized_witness: Uint8Array, +) -> Result { + let wrapper = unsafe { &mut *ctx }; + + let witness_vec: Vec = calculated_witness + .iter() + .map(|v| { + v.to_string(10) + .unwrap() + .as_string() + .unwrap() + .parse::() + .unwrap() + }) + .collect(); + + let mut output_data: Vec = Vec::new(); + + if wrapper + .instance + .generate_rln_proof_with_witness(witness_vec, serialized_witness.to_vec(), &mut output_data) + .is_ok() + { + let result = Uint8Array::from(&output_data[..]); + std::mem::forget(output_data); + Ok(result) + } else { + std::mem::forget(output_data); + Err("could not generate proof".into()) + } +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +#[wasm_bindgen(js_name = generateMembershipKey)] +pub fn wasm_key_gen(ctx: *const RLNWrapper) -> Result { + let wrapper = unsafe { &*ctx }; + let mut output_data: Vec = Vec::new(); + if wrapper.instance.key_gen(&mut output_data).is_ok() { + let result = Uint8Array::from(&output_data[..]); + std::mem::forget(output_data); + Ok(result) + } else { + std::mem::forget(output_data); + Err("could not generate membership keys".into()) + } +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +#[wasm_bindgen(js_name = verifyProof)] +pub fn wasm_verify(ctx: *const RLNWrapper, proof: Uint8Array) -> bool { + let wrapper = unsafe { &*ctx }; + if match wrapper.instance.verify(&proof.to_vec()[..]) { + Ok(verified) => verified, + Err(_) => return false, + } { + return true; + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + use rln::circuit::TEST_TREE_HEIGHT; + use wasm_bindgen_test::wasm_bindgen_test; + + #[wasm_bindgen(module = "/src/utils.js")] + extern "C" { + #[wasm_bindgen(catch)] + fn read_file(path: &str) -> Result; + + #[wasm_bindgen(catch)] + async fn calculateWitness(circom_path: &str, input: Object) -> Result; + } + + #[wasm_bindgen_test] + pub async fn test_basic_flow() { + let tree_height = TEST_TREE_HEIGHT; + let circom_path = format!("../rln/resources/tree_height_{TEST_TREE_HEIGHT}/rln.wasm"); + let zkey_path = format!("../rln/resources/tree_height_{TEST_TREE_HEIGHT}/rln_final.zkey"); + let vk_path = + format!("../rln/resources/tree_height_{TEST_TREE_HEIGHT}/verification_key.json"); + let zkey = read_file(&zkey_path).unwrap(); + let vk = read_file(&vk_path).unwrap(); + + // Creating an instance of RLN + let rln_instance = wasm_new(tree_height, zkey, vk); + + // Creating membership key + let mem_keys = wasm_key_gen(rln_instance).unwrap(); + let idkey = mem_keys.subarray(0, 32); + let idcommitment = mem_keys.subarray(32, 64); + + // Insert PK + wasm_set_next_leaf(rln_instance, idcommitment).unwrap(); + + // Prepare the message + let mut signal = "Hello World".as_bytes().to_vec(); + let signal_len: u64 = signal.len() as u64; + + // Setting up the epoch (With 0s for the test) + let epoch = Uint8Array::new_with_length(32); + epoch.fill(0, 0, 32); + + let identity_index: u64 = 0; + + // Serializing the message + let mut serialized_vec: Vec = Vec::new(); + serialized_vec.append(&mut idkey.to_vec()); + serialized_vec.append(&mut identity_index.to_le_bytes().to_vec()); + serialized_vec.append(&mut epoch.to_vec()); + serialized_vec.append(&mut signal_len.to_le_bytes().to_vec()); + serialized_vec.append(&mut signal); + let serialized_message = Uint8Array::from(&serialized_vec[..]); + + let serialized_rln_witness = + wasm_get_serialized_rln_witness(rln_instance, serialized_message); + + // Obtaining inputs that should be sent to circom witness calculator + let json_inputs = rln_witness_to_json(rln_instance, serialized_rln_witness.clone()); + + // Calculating witness with JS + // (Using a JSON since wasm_bindgen does not like Result,JsValue>) + let calculated_witness_json = calculateWitness(&circom_path, json_inputs) + .await + .unwrap() + .as_string() + .unwrap(); + let calculated_witness_vec_str: Vec = + serde_json::from_str(&calculated_witness_json).unwrap(); + let calculated_witness: Vec = calculated_witness_vec_str + .iter() + .map(|x| JsBigInt::new(&x.into()).unwrap()) + .collect(); + + // Generating proof + let proof = generate_rln_proof_with_witness( + rln_instance, + calculated_witness.into(), + serialized_rln_witness, + ) + .unwrap(); + + // Validate Proof + let is_proof_valid = wasm_verify(rln_instance, proof); + + assert!( + is_proof_valid, + "validating proof generated with wasm failed" + ); + } +} diff --git a/rln-wasm/src/utils.js b/rln-wasm/src/utils.js new file mode 100644 index 0000000..e89807a --- /dev/null +++ b/rln-wasm/src/utils.js @@ -0,0 +1,18 @@ +const fs = require("fs"); + +// Utils functions for loading circom witness calculator and reading files from test + +module.exports = { + read_file: function (path) { + return fs.readFileSync(path); + }, + + calculateWitness: async function(circom_path, inputs){ + const wc = require("resources/witness_calculator.js"); + const wasmFile = fs.readFileSync(circom_path); + const wasmFileBuffer = wasmFile.slice(wasmFile.byteOffset, wasmFile.byteOffset + wasmFile.byteLength); + const witnessCalculator = await wc(wasmFileBuffer); + const calculatedWitness = await witnessCalculator.calculateWitness(inputs, false); + return JSON.stringify(calculatedWitness, (key, value) => typeof value === "bigint" ? value.toString() : value); + } +} diff --git a/rln/Cargo.toml b/rln/Cargo.toml index 2c26f50..632dab5 100644 --- a/rln/Cargo.toml +++ b/rln/Cargo.toml @@ -6,18 +6,22 @@ edition = "2021" [lib] crate-type = ["cdylib", "rlib", "staticlib"] + [dependencies] # ZKP Generation -ark-ff = { version = "0.3.0", default-features = false, features = ["parallel", "asm"] } -ark-std = { version = "0.3.0", default-features = false, features = ["parallel"] } +ark-ec = { version = "0.3.0", default-features = false } +ark-ff = { version = "0.3.0", default-features = false, features = [ "asm"] } +ark-std = { version = "0.3.0", default-features = false } ark-bn254 = { version = "0.3.0" } -ark-groth16 = { git = "https://github.com/arkworks-rs/groth16", rev = "765817f", features = ["parallel"] } +ark-groth16 = { git = "https://github.com/arkworks-rs/groth16", rev = "765817f", default-features = false } ark-relations = { version = "0.3.0", default-features = false, features = [ "std" ] } ark-serialize = { version = "0.3.0", default-features = false } -ark-circom = { git = "https://github.com/gakonst/ark-circom", rev = "06eb075", features = ["circom-2"] } +ark-circom = { git = "https://github.com/vacp2p/ark-circom", branch = "wasm", default-features = false, features = ["circom-2"] } #ark-circom = { git = "https://github.com/vacp2p/ark-circom", branch = "no-ethers-core", features = ["circom-2"] } -wasmer = "2.3.0" + +# WASM +wasmer = { version = "2.3.0", default-features = false } # error handling color-eyre = "0.5.11" @@ -39,4 +43,7 @@ serde_json = "1.0.48" hex-literal = "0.3.4" [features] +default = ["parallel", "wasmer/sys-default"] fullmerkletree = [] +parallel = ["ark-ec/parallel", "ark-ff/parallel", "ark-std/parallel", "ark-groth16/parallel"] +wasm = ["wasmer/js", "wasmer/std"] diff --git a/rln/src/circuit.rs b/rln/src/circuit.rs index 45e85b8..14f23a1 100644 --- a/rln/src/circuit.rs +++ b/rln/src/circuit.rs @@ -4,18 +4,25 @@ use ark_bn254::{ Bn254, Fq as ArkFq, Fq2 as ArkFq2, Fr as ArkFr, G1Affine as ArkG1Affine, G1Projective as ArkG1Projective, G2Affine as ArkG2Affine, G2Projective as ArkG2Projective, }; -use ark_circom::{read_zkey, WitnessCalculator}; +use ark_circom::read_zkey; use ark_groth16::{ProvingKey, VerifyingKey}; use ark_relations::r1cs::ConstraintMatrices; +use cfg_if::cfg_if; use num_bigint::BigUint; -use once_cell::sync::OnceCell; use serde_json::Value; use std::fs::File; use std::io::{Cursor, Error, ErrorKind, Result}; use std::path::Path; use std::str::FromStr; -use std::sync::Mutex; -use wasmer::{Module, Store}; + +cfg_if! { + if #[cfg(not(target_arch = "wasm32"))] { + use ark_circom::{WitnessCalculator}; + use once_cell::sync::OnceCell; + use std::sync::Mutex; + use wasmer::{Module, Store}; + } +} const ZKEY_FILENAME: &str = "rln_final.zkey"; const VK_FILENAME: &str = "verifying_key.json"; @@ -109,9 +116,11 @@ pub fn vk_from_folder(resources_folder: &str) -> Result> { } } +#[cfg(not(target_arch = "wasm32"))] static WITNESS_CALCULATOR: OnceCell> = OnceCell::new(); // Initializes the witness calculator using a bytes vector +#[cfg(not(target_arch = "wasm32"))] pub fn circom_from_raw(wasm_buffer: Vec) -> &'static Mutex { WITNESS_CALCULATOR.get_or_init(|| { let store = Store::default(); @@ -123,6 +132,7 @@ pub fn circom_from_raw(wasm_buffer: Vec) -> &'static Mutex &'static Mutex { // We read the wasm file let wasm_path = format!("{resources_folder}{WASM_FILENAME}"); diff --git a/rln/src/ffi.rs b/rln/src/ffi.rs index 5d99661..ea18f4f 100644 --- a/rln/src/ffi.rs +++ b/rln/src/ffi.rs @@ -59,7 +59,12 @@ pub extern "C" fn new_with_params( let circom_data = <&[u8]>::from(unsafe { &*circom_buffer }); let zkey_data = <&[u8]>::from(unsafe { &*zkey_buffer }); let vk_data = <&[u8]>::from(unsafe { &*vk_buffer }); - let rln = RLN::new_with_params(tree_height, circom_data, zkey_data, vk_data); + let rln = RLN::new_with_params( + tree_height, + circom_data.to_vec(), + zkey_data.to_vec(), + vk_data.to_vec(), + ); unsafe { *ctx = Box::into_raw(Box::new(rln)) }; true } diff --git a/rln/src/lib.rs b/rln/src/lib.rs index d2e0d58..7598b07 100644 --- a/rln/src/lib.rs +++ b/rln/src/lib.rs @@ -1,7 +1,6 @@ #![allow(dead_code)] pub mod circuit; -pub mod ffi; pub mod merkle_tree; pub mod poseidon_constants; pub mod poseidon_hash; @@ -10,6 +9,9 @@ pub mod protocol; pub mod public; pub mod utils; +#[cfg(not(target_arch = "wasm32"))] +pub mod ffi; + #[cfg(test)] mod test { diff --git a/rln/src/protocol.rs b/rln/src/protocol.rs index 5bd4fc8..0a8ff49 100644 --- a/rln/src/protocol.rs +++ b/rln/src/protocol.rs @@ -11,6 +11,7 @@ use ark_std::{rand::thread_rng, UniformRand}; use color_eyre::Result; use num_bigint::BigInt; use rand::Rng; +#[cfg(not(target_arch = "wasm32"))] use std::sync::Mutex; #[cfg(debug_assertions)] use std::time::Instant; @@ -22,6 +23,7 @@ use crate::poseidon_hash::poseidon_hash; use crate::poseidon_tree::*; use crate::public::RLN_IDENTIFIER; use crate::utils::*; +use cfg_if::cfg_if; /////////////////////////////////////////////////////// // RLN Witness data structure and utility functions @@ -121,12 +123,9 @@ pub fn proof_inputs_to_rln_witness( let signal_len = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()); all_read += 8; - let signal: Vec = - serialized[all_read..all_read + usize::try_from(signal_len).unwrap()].to_vec(); + let signal: Vec = serialized[all_read..all_read + (signal_len as usize)].to_vec(); - let merkle_proof = tree - .proof(usize::try_from(id_index).unwrap()) - .expect("proof should exist"); + let merkle_proof = tree.proof(id_index as usize).expect("proof should exist"); let path_elements = merkle_proof.get_path_elements(); let identity_path_index = merkle_proof.get_path_index(); @@ -374,16 +373,70 @@ pub enum ProofError { SynthesisError(#[from] SynthesisError), } -/// Generates a RLN proof -/// -/// # Errors -/// -/// Returns a [`ProofError`] if proving fails. -pub fn generate_proof( - witness_calculator: &Mutex, +fn calculate_witness_element(witness: Vec) -> Result> { + use ark_ff::{FpParameters, PrimeField}; + let modulus = <::Params as FpParameters>::MODULUS; + + // convert it to field elements + use num_traits::Signed; + let witness = witness + .into_iter() + .map(|w| { + let w = if w.sign() == num_bigint::Sign::Minus { + // Need to negate the witness element if negative + modulus.into() - w.abs().to_biguint().unwrap() + } else { + w.to_biguint().unwrap() + }; + E::Fr::from(w) + }) + .collect::>(); + + Ok(witness) +} + +pub fn generate_proof_with_witness( + witness: Vec, proving_key: &(ProvingKey, ConstraintMatrices), - rln_witness: &RLNWitnessInput, ) -> Result, ProofError> { + // If in debug mode, we measure and later print time take to compute witness + #[cfg(debug_assertions)] + let now = Instant::now(); + + let full_assignment = calculate_witness_element::(witness) + .map_err(ProofError::WitnessError) + .unwrap(); + + #[cfg(debug_assertions)] + println!("witness generation took: {:.2?}", now.elapsed()); + + // Random Values + let mut rng = thread_rng(); + let r = Fr::rand(&mut rng); + let s = Fr::rand(&mut rng); + + // If in debug mode, we measure and later print time take to compute proof + #[cfg(debug_assertions)] + let now = Instant::now(); + + let proof = create_proof_with_reduction_and_matrices::<_, CircomReduction>( + &proving_key.0, + r, + s, + &proving_key.1, + proving_key.1.num_instance_variables, + proving_key.1.num_constraints, + full_assignment.as_slice(), + ) + .unwrap(); + + #[cfg(debug_assertions)] + println!("proof generation took: {:.2?}", now.elapsed()); + + Ok(proof) +} + +pub fn inputs_for_witness_calculation(rln_witness: &RLNWitnessInput) -> [(&str, Vec); 6] { // We confert the path indexes to field elements // TODO: check if necessary let mut path_elements = Vec::new(); @@ -398,7 +451,7 @@ pub fn generate_proof( .iter() .for_each(|v| identity_path_index.push(BigInt::from(*v))); - let inputs = [ + [ ( "identity_secret", vec![to_bigint(&rln_witness.identity_secret)], @@ -411,8 +464,21 @@ pub fn generate_proof( "rln_identifier", vec![to_bigint(&rln_witness.rln_identifier)], ), - ]; - let inputs = inputs + ] +} + +/// Generates a RLN proof +/// +/// # Errors +/// +/// Returns a [`ProofError`] if proving fails. +pub fn generate_proof( + #[cfg(not(target_arch = "wasm32"))] witness_calculator: &Mutex, + #[cfg(target_arch = "wasm32")] witness_calculator: &mut WitnessCalculator, + proving_key: &(ProvingKey, ConstraintMatrices), + rln_witness: &RLNWitnessInput, +) -> Result, ProofError> { + let inputs = inputs_for_witness_calculation(rln_witness) .into_iter() .map(|(name, values)| (name.to_string(), values)); @@ -420,11 +486,19 @@ pub fn generate_proof( #[cfg(debug_assertions)] let now = Instant::now(); - let full_assignment = witness_calculator - .lock() - .expect("witness_calculator mutex should not get poisoned") - .calculate_witness_element::(inputs, false) - .map_err(ProofError::WitnessError)?; + cfg_if! { + if #[cfg(target_arch = "wasm32")] { + let full_assignment = witness_calculator + .calculate_witness_element::(inputs, false) + .map_err(ProofError::WitnessError)?; + } else { + let full_assignment = witness_calculator + .lock() + .expect("witness_calculator mutex should not get poisoned") + .calculate_witness_element::(inputs, false) + .map_err(ProofError::WitnessError)?; + } + } #[cfg(debug_assertions)] println!("witness generation took: {:.2?}", now.elapsed()); @@ -490,3 +564,32 @@ pub fn verify_proof( Ok(verified) } + +/// Get CIRCOM JSON inputs +/// +/// Returns a JSON object containing the inputs necessary to calculate +/// the witness with CIRCOM on javascript +pub fn get_json_inputs(rln_witness: &RLNWitnessInput) -> serde_json::Value { + let mut path_elements = Vec::new(); + rln_witness + .path_elements + .iter() + .for_each(|v| path_elements.push(to_bigint(v).to_str_radix(10))); + + let mut identity_path_index = Vec::new(); + rln_witness + .identity_path_index + .iter() + .for_each(|v| identity_path_index.push(BigInt::from(*v).to_str_radix(10))); + + let inputs = serde_json::json!({ + "identity_secret": to_bigint(&rln_witness.identity_secret).to_str_radix(10), + "path_elements": path_elements, + "identity_path_index": identity_path_index, + "x": to_bigint(&rln_witness.x).to_str_radix(10), + "epoch": format!("0x{:064x}", to_bigint(&rln_witness.epoch)), + "rln_identifier": to_bigint(&rln_witness.rln_identifier).to_str_radix(10), + }); + + inputs +} diff --git a/rln/src/public.rs b/rln/src/public.rs index c972350..fcf5bac 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -1,21 +1,28 @@ -/// This is the main public API for RLN module. It is used by the FFI, and should be -/// used by tests etc as well -use ark_circom::WitnessCalculator; -use ark_groth16::Proof as ArkProof; -use ark_groth16::{ProvingKey, VerifyingKey}; -use ark_relations::r1cs::ConstraintMatrices; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use std::default::Default; -use std::io::{self, Cursor, Read, Result, Write}; -use std::sync::Mutex; - -use crate::circuit::{ - circom_from_folder, circom_from_raw, vk_from_folder, vk_from_raw, zkey_from_folder, - zkey_from_raw, Curve, Fr, TEST_RESOURCES_FOLDER, TEST_TREE_HEIGHT, -}; +use crate::circuit::{vk_from_raw, zkey_from_raw, Curve, Fr}; use crate::poseidon_tree::PoseidonTree; use crate::protocol::*; use crate::utils::*; +/// This is the main public API for RLN module. It is used by the FFI, and should be +/// used by tests etc as well +use ark_groth16::Proof as ArkProof; +use ark_groth16::{ProvingKey, VerifyingKey}; +use ark_relations::r1cs::ConstraintMatrices; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, Write}; +use cfg_if::cfg_if; +use num_bigint::BigInt; +use std::io::Cursor; +use std::io::{self, Result}; + +cfg_if! { + if #[cfg(not(target_arch = "wasm32"))] { + use std::default::Default; + use std::sync::Mutex; + use crate::circuit::{circom_from_folder, vk_from_folder, circom_from_raw, zkey_from_folder, TEST_RESOURCES_FOLDER, TEST_TREE_HEIGHT}; + use ark_circom::WitnessCalculator; + } else { + use std::marker::*; + } +} // Application specific RLN identifier pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809"; @@ -23,13 +30,21 @@ pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809"; // TODO Add Engine here? i.e. not // TODO Assuming we want to use IncrementalMerkleTree, figure out type/trait conversions pub struct RLN<'a> { - witness_calculator: &'a Mutex, proving_key: Result<(ProvingKey, ConstraintMatrices)>, verification_key: Result>, tree: PoseidonTree, + + // The witness calculator can't be loaded in zerokit. Since this struct + // contains a lifetime, a PhantomData is necessary to avoid a compiler + // error since the lifetime is not being used + #[cfg(not(target_arch = "wasm32"))] + witness_calculator: &'a Mutex, + #[cfg(target_arch = "wasm32")] + _marker: PhantomData<&'a ()>, } impl RLN<'_> { + #[cfg(not(target_arch = "wasm32"))] pub fn new(tree_height: usize, mut input_data: R) -> RLN<'static> { // We read input let mut input: Vec = Vec::new(); @@ -50,23 +65,18 @@ impl RLN<'_> { proving_key, verification_key, tree, + #[cfg(target_arch = "wasm32")] + _marker: PhantomData, } } - pub fn new_with_params( + pub fn new_with_params( tree_height: usize, - mut circom_data: R, - mut zkey_data: R, - mut vk_data: R, + #[cfg(not(target_arch = "wasm32"))] circom_vec: Vec, + zkey_vec: Vec, + vk_vec: Vec, ) -> RLN<'static> { - // We read input - let mut circom_vec: Vec = Vec::new(); - circom_data.read_to_end(&mut circom_vec).unwrap(); - let mut zkey_vec: Vec = Vec::new(); - zkey_data.read_to_end(&mut zkey_vec).unwrap(); - let mut vk_vec: Vec = Vec::new(); - vk_data.read_to_end(&mut vk_vec).unwrap(); - + #[cfg(not(target_arch = "wasm32"))] let witness_calculator = circom_from_raw(circom_vec); let proving_key = zkey_from_raw(&zkey_vec); @@ -76,10 +86,13 @@ impl RLN<'_> { let tree = PoseidonTree::default(tree_height); RLN { + #[cfg(not(target_arch = "wasm32"))] witness_calculator, proving_key, verification_key, tree, + #[cfg(target_arch = "wasm32")] + _marker: PhantomData, } } @@ -165,6 +178,7 @@ impl RLN<'_> { //////////////////////////////////////////////////////// // zkSNARK APIs //////////////////////////////////////////////////////// + #[cfg(not(target_arch = "wasm32"))] pub fn prove( &mut self, mut input_data: R, @@ -182,7 +196,7 @@ impl RLN<'_> { */ let proof = generate_proof( - self.witness_calculator, + &mut self.witness_calculator, self.proving_key.as_ref().unwrap(), &rln_witness, ) @@ -213,9 +227,29 @@ impl RLN<'_> { Ok(verified) } + /// Get the serialized rln_witness for some input + pub fn get_serialized_rln_witness(&mut self, mut input_data: R) -> Vec { + // We read input RLN witness and we deserialize it + let mut witness_byte: Vec = Vec::new(); + input_data.read_to_end(&mut witness_byte).unwrap(); + let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte); + + serialize_witness(&rln_witness) + } + + /// Get JSON inputs for serialized RLN witness + pub fn get_rln_witness_json( + &mut self, + serialized_witness: &[u8], + ) -> io::Result { + let (rln_witness, _) = deserialize_witness(serialized_witness); + Ok(get_json_inputs(&rln_witness)) + } + // This API keeps partial compatibility with kilic's rln public API https://github.com/kilic/rln/blob/7ac74183f8b69b399e3bc96c1ae8ab61c026dc43/src/public.rs#L148 // input_data is [ id_key<32> | id_index<8> | epoch<32> | signal_len<8> | signal ] // output_data is [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> ] + #[cfg(not(target_arch = "wasm32"))] pub fn generate_rln_proof( &mut self, mut input_data: R, @@ -242,6 +276,29 @@ impl RLN<'_> { Ok(()) } + /// Generate RLN Proof using a witness calculated from outside zerokit + /// + /// output_data is [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> ] + pub fn generate_rln_proof_with_witness( + &mut self, + calculated_witness: Vec, + rln_witness_vec: Vec, + mut output_data: W, + ) -> io::Result<()> { + let (rln_witness, _) = deserialize_witness(&rln_witness_vec[..]); + let proof_values = proof_values_from_witness(&rln_witness); + + let proof = + generate_proof_with_witness(calculated_witness, self.proving_key.as_ref().unwrap()) + .unwrap(); + + // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof + // This proof is compressed, i.e. 128 bytes long + proof.serialize(&mut output_data).unwrap(); + output_data.write_all(&serialize_proof_values(&proof_values))?; + Ok(()) + } + // Input data is serialized for Curve as: // [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> | signal_len<8> | signal ] pub fn verify_rln_proof(&self, mut input_data: R) -> io::Result { @@ -253,10 +310,8 @@ impl RLN<'_> { let (proof_values, read) = deserialize_proof_values(&serialized[all_read..].to_vec()); all_read += read; - let signal_len = usize::try_from(u64::from_le_bytes( - serialized[all_read..all_read + 8].try_into().unwrap(), - )) - .unwrap(); + let signal_len = + u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()) as usize; all_read += 8; let signal: Vec = serialized[all_read..all_read + signal_len].to_vec(); @@ -299,6 +354,7 @@ impl RLN<'_> { } } +#[cfg(not(target_arch = "wasm32"))] impl Default for RLN<'_> { fn default() -> Self { let tree_height = TEST_TREE_HEIGHT; @@ -673,6 +729,101 @@ mod test { assert!(verified); } + #[test] + fn test_rln_with_witness() { + let tree_height = TEST_TREE_HEIGHT; + let no_of_leaves = 256; + + // We generate a vector of random leaves + let mut leaves: Vec = Vec::new(); + let mut rng = thread_rng(); + for _ in 0..no_of_leaves { + leaves.push(Fr::rand(&mut rng)); + } + + // We create a new RLN instance + let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); + let mut rln = RLN::new(tree_height, input_buffer); + + // We add leaves in a batch into the tree + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); + rln.set_leaves(&mut buffer).unwrap(); + + // Generate identity pair + let (identity_secret, id_commitment) = keygen(); + + // We set as leaf id_commitment after storing its index + let identity_index = u64::try_from(rln.tree.leaves_set()).unwrap(); + let mut buffer = Cursor::new(fr_to_bytes_le(&id_commitment)); + rln.set_next_leaf(&mut buffer).unwrap(); + + // We generate a random signal + let mut rng = rand::thread_rng(); + let signal: [u8; 32] = rng.gen(); + let signal_len = u64::try_from(signal.len()).unwrap(); + + // We generate a random epoch + let epoch = hash_to_field(b"test-epoch"); + + // We prepare input for generate_rln_proof API + // input_data is [ id_key<32> | id_index<8> | epoch<32> | signal_len<8> | signal ] + let mut serialized: Vec = Vec::new(); + serialized.append(&mut fr_to_bytes_le(&identity_secret)); + serialized.append(&mut identity_index.to_le_bytes().to_vec()); + serialized.append(&mut fr_to_bytes_le(&epoch)); + serialized.append(&mut signal_len.to_le_bytes().to_vec()); + serialized.append(&mut signal.to_vec()); + + let mut input_buffer = Cursor::new(serialized); + + // We read input RLN witness and we deserialize it + let mut witness_byte: Vec = Vec::new(); + input_buffer.read_to_end(&mut witness_byte).unwrap(); + let (rln_witness, _) = proof_inputs_to_rln_witness(&mut rln.tree, &witness_byte); + + let serialized_witness = serialize_witness(&rln_witness); + + // Calculate witness outside zerokit (simulating what JS is doing) + let inputs = inputs_for_witness_calculation(&rln_witness) + .into_iter() + .map(|(name, values)| (name.to_string(), values)); + let calculated_witness = rln + .witness_calculator + .lock() + .expect("witness_calculator mutex should not get poisoned") + .calculate_witness_element::(inputs, false) + .map_err(ProofError::WitnessError) + .unwrap(); + + let calculated_witness_vec: Vec = calculated_witness + .into_iter() + .map(|v| to_bigint(&v)) + .collect(); + + // Generating the proof + let mut output_buffer = Cursor::new(Vec::::new()); + rln.generate_rln_proof_with_witness( + calculated_witness_vec, + serialized_witness, + &mut output_buffer, + ) + .unwrap(); + + // output_data is [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> ] + let mut proof_data = output_buffer.into_inner(); + + // We prepare input for verify_rln_proof API + // input_data is [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> | signal_len<8> | signal ] + // that is [ proof_data || signal_len<8> | signal ] + proof_data.append(&mut signal_len.to_le_bytes().to_vec()); + proof_data.append(&mut signal.to_vec()); + + let mut input_buffer = Cursor::new(proof_data); + let verified = rln.verify_rln_proof(&mut input_buffer).unwrap(); + + assert!(verified); + } + #[test] fn test_hash_to_field() { let rln = RLN::default(); diff --git a/rln/src/utils.rs b/rln/src/utils.rs index f9ab7a6..f424911 100644 --- a/rln/src/utils.rs +++ b/rln/src/utils.rs @@ -85,7 +85,7 @@ pub fn fr_to_bytes_be(input: &Fr) -> Vec { pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Vec { let mut bytes: Vec = Vec::new(); //We store the vector length - bytes.extend(input.len().to_le_bytes().to_vec()); + bytes.extend(u64::try_from(input.len()).unwrap().to_le_bytes().to_vec()); // We store each element input.iter().for_each(|el| bytes.extend(fr_to_bytes_le(el))); @@ -95,7 +95,7 @@ pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Vec { pub fn vec_fr_to_bytes_be(input: &[Fr]) -> Vec { let mut bytes: Vec = Vec::new(); //We store the vector length - bytes.extend(input.len().to_be_bytes().to_vec()); + bytes.extend(u64::try_from(input.len()).unwrap().to_be_bytes().to_vec()); // We store each element input.iter().for_each(|el| bytes.extend(fr_to_bytes_be(el))); @@ -121,7 +121,7 @@ pub fn vec_u8_to_bytes_be(input: Vec) -> Vec { pub fn bytes_le_to_vec_u8(input: &[u8]) -> (Vec, usize) { let mut read: usize = 0; - let len = usize::try_from(u64::from_le_bytes(input[0..8].try_into().unwrap())).unwrap(); + let len = u64::from_le_bytes(input[0..8].try_into().unwrap()) as usize; read += 8; let res = input[8..8 + len].to_vec(); @@ -133,7 +133,7 @@ pub fn bytes_le_to_vec_u8(input: &[u8]) -> (Vec, usize) { pub fn bytes_be_to_vec_u8(input: &[u8]) -> (Vec, usize) { let mut read: usize = 0; - let len = usize::try_from(u64::from_be_bytes(input[0..8].try_into().unwrap())).unwrap(); + let len = u64::from_be_bytes(input[0..8].try_into().unwrap()) as usize; read += 8; let res = input[8..8 + len].to_vec(); @@ -147,7 +147,7 @@ pub fn bytes_le_to_vec_fr(input: &[u8]) -> (Vec, usize) { let mut read: usize = 0; let mut res: Vec = Vec::new(); - let len = usize::try_from(u64::from_le_bytes(input[0..8].try_into().unwrap())).unwrap(); + let len = u64::from_le_bytes(input[0..8].try_into().unwrap()) as usize; read += 8; let el_size = fr_byte_size(); @@ -164,7 +164,7 @@ pub fn bytes_be_to_vec_fr(input: &[u8]) -> (Vec, usize) { let mut read: usize = 0; let mut res: Vec = Vec::new(); - let len = usize::try_from(u64::from_be_bytes(input[0..8].try_into().unwrap())).unwrap(); + let len = u64::from_be_bytes(input[0..8].try_into().unwrap()) as usize; read += 8; let el_size = fr_byte_size();