feat: wasm (#38)

This commit is contained in:
Richard Ramos 2022-09-20 08:22:46 -04:00 committed by GitHub
parent 4f08818d7a
commit c401c0b21d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 973 additions and 73 deletions

View File

@ -22,7 +22,7 @@ jobs:
run: git submodule update --init --recursive
- name: cargo test
run: |
cargo test
cargo test --release --workspace --exclude rln-wasm
lint:
runs-on: ubuntu-latest
steps:
@ -40,6 +40,9 @@ jobs:
- name: cargo fmt
run: cargo fmt --all -- --check
- name: cargo clippy
run: cargo clippy
run: |
(cd multiplier && cargo clippy)
(cd rln && cargo clippy)
(cd semaphore && cargo clippy)
# Currently not treating warnings as error, too noisy
# -- -D warnings

View File

@ -3,4 +3,5 @@ members = [
"multiplier",
"semaphore",
"rln",
"rln-wasm",
]

27
rln-wasm/Cargo.toml Normal file
View File

@ -0,0 +1,27 @@
[package]
name = "rln-wasm"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib", "rlib"]
[dependencies]
rln = { path = "../rln", default-features = false, features = ["wasm"] }
num-bigint = { version = "0.4", default-features = false, features = ["rand", "serde"] }
wasmer = { version = "2.3", default-features = false, features = ["js", "std"] }
web-sys = {version = "0.3", features=["console"]}
getrandom = { version = "0.2.7", default-features = false, features = ["js"] }
wasm-bindgen = "0.2.63"
serde-wasm-bindgen = "0.4"
js-sys = "0.3.59"
console_error_panic_hook = "0.1.7"
serde_json = "1.0.85"
[dev-dependencies]
wasm-bindgen-test = "0.3.0"
wasm-bindgen-futures = "0.4.33"
[profile.release]
debug = true

20
rln-wasm/README.md Normal file
View File

@ -0,0 +1,20 @@
# RLN for WASM
This library is used in [waku-org/js-rln](https://github.com/waku-org/js-rln/)
## Building the library
1. Make sure you have nodejs installed and the `build-essential` package if using ubuntu.
2. Install wasm-pack
```
curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
```
3. Compile zerokit for `wasm32-unknown-unknown`:
```
cd rln-wasm
wasm-pack build --release
```
## Running tests
```
cd rln-wasm
wasm-pack test --node --release
```

View File

@ -0,0 +1,331 @@
module.exports = async function builder(code, options) {
options = options || {};
let wasmModule;
try {
wasmModule = await WebAssembly.compile(code);
} catch (err) {
console.log(err);
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
throw new Error(err);
}
let wc;
let errStr = "";
let msgStr = "";
const instance = await WebAssembly.instantiate(wasmModule, {
runtime: {
exceptionHandler : function(code) {
let err;
if (code == 1) {
err = "Signal not found.\n";
} else if (code == 2) {
err = "Too many signals set.\n";
} else if (code == 3) {
err = "Signal already set.\n";
} else if (code == 4) {
err = "Assert Failed.\n";
} else if (code == 5) {
err = "Not enough memory.\n";
} else if (code == 6) {
err = "Input signal array access exceeds the size.\n";
} else {
err = "Unknown error.\n";
}
throw new Error(err + errStr);
},
printErrorMessage : function() {
errStr += getMessage() + "\n";
// console.error(getMessage());
},
writeBufferMessage : function() {
const msg = getMessage();
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
if (msg === "\n") {
console.log(msgStr);
msgStr = "";
} else {
// If we've buffered other content, put a space in between the items
if (msgStr !== "") {
msgStr += " "
}
// Then append the message to the message we are creating
msgStr += msg;
}
},
showSharedRWMemory : function() {
printSharedRWMemory ();
}
}
});
const sanityCheck =
options
// options &&
// (
// options.sanityCheck ||
// options.logGetSignal ||
// options.logSetSignal ||
// options.logStartComponent ||
// options.logFinishComponent
// );
wc = new WitnessCalculator(instance, sanityCheck);
return wc;
function getMessage() {
var message = "";
var c = instance.exports.getMessageChar();
while ( c != 0 ) {
message += String.fromCharCode(c);
c = instance.exports.getMessageChar();
}
return message;
}
function printSharedRWMemory () {
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
const arr = new Uint32Array(shared_rw_memory_size);
for (let j=0; j<shared_rw_memory_size; j++) {
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
}
// If we've buffered other content, put a space in between the items
if (msgStr !== "") {
msgStr += " "
}
// Then append the value to the message we are creating
msgStr += (fromArray32(arr).toString());
}
};
class WitnessCalculator {
constructor(instance, sanityCheck) {
this.instance = instance;
this.version = this.instance.exports.getVersion();
this.n32 = this.instance.exports.getFieldNumLen32();
this.instance.exports.getRawPrime();
const arr = new Uint32Array(this.n32);
for (let i=0; i<this.n32; i++) {
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
}
this.prime = fromArray32(arr);
this.witnessSize = this.instance.exports.getWitnessSize();
this.sanityCheck = sanityCheck;
}
circom_version() {
return this.instance.exports.getVersion();
}
async _doCalculateWitness(input, sanityCheck) {
//input is assumed to be a map from signals to arrays of bigints
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
const keys = Object.keys(input);
var input_counter = 0;
keys.forEach( (k) => {
const h = fnvHash(k);
const hMSB = parseInt(h.slice(0,8), 16);
const hLSB = parseInt(h.slice(8,16), 16);
const fArr = flatArray(input[k]);
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
if (signalSize < 0){
throw new Error(`Signal ${k} not found\n`);
}
if (fArr.length < signalSize) {
throw new Error(`Not enough values for input signal ${k}\n`);
}
if (fArr.length > signalSize) {
throw new Error(`Too many values for input signal ${k}\n`);
}
for (let i=0; i<fArr.length; i++) {
const arrFr = toArray32(BigInt(fArr[i])%this.prime,this.n32)
for (let j=0; j<this.n32; j++) {
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
}
try {
this.instance.exports.setInputSignal(hMSB, hLSB,i);
input_counter++;
} catch (err) {
// console.log(`After adding signal ${i} of ${k}`)
throw new Error(err);
}
}
});
if (input_counter < this.instance.exports.getInputSize()) {
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
}
}
async calculateWitness(input, sanityCheck) {
const w = [];
await this._doCalculateWitness(input, sanityCheck);
for (let i=0; i<this.witnessSize; i++) {
this.instance.exports.getWitness(i);
const arr = new Uint32Array(this.n32);
for (let j=0; j<this.n32; j++) {
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
}
w.push(fromArray32(arr));
}
return w;
}
async calculateBinWitness(input, sanityCheck) {
const buff32 = new Uint32Array(this.witnessSize*this.n32);
const buff = new Uint8Array( buff32.buffer);
await this._doCalculateWitness(input, sanityCheck);
for (let i=0; i<this.witnessSize; i++) {
this.instance.exports.getWitness(i);
const pos = i*this.n32;
for (let j=0; j<this.n32; j++) {
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
}
}
return buff;
}
async calculateWTNSBin(input, sanityCheck) {
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
const buff = new Uint8Array( buff32.buffer);
await this._doCalculateWitness(input, sanityCheck);
//"wtns"
buff[0] = "w".charCodeAt(0)
buff[1] = "t".charCodeAt(0)
buff[2] = "n".charCodeAt(0)
buff[3] = "s".charCodeAt(0)
//version 2
buff32[1] = 2;
//number of sections: 2
buff32[2] = 2;
//id section 1
buff32[3] = 1;
const n8 = this.n32*4;
//id section 1 length in 64bytes
const idSection1length = 8 + n8;
const idSection1lengthHex = idSection1length.toString(16);
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
//this.n32
buff32[6] = n8;
//prime number
this.instance.exports.getRawPrime();
var pos = 7;
for (let j=0; j<this.n32; j++) {
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
}
pos += this.n32;
// witness size
buff32[pos] = this.witnessSize;
pos++;
//id section 2
buff32[pos] = 2;
pos++;
// section 2 length
const idSection2length = n8*this.witnessSize;
const idSection2lengthHex = idSection2length.toString(16);
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
pos += 2;
for (let i=0; i<this.witnessSize; i++) {
this.instance.exports.getWitness(i);
for (let j=0; j<this.n32; j++) {
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
}
pos += this.n32;
}
return buff;
}
}
function toArray32(rem,size) {
const res = []; //new Uint32Array(size); //has no unshift
const radix = BigInt(0x100000000);
while (rem) {
res.unshift( Number(rem % radix));
rem = rem / radix;
}
if (size) {
var i = size - res.length;
while (i>0) {
res.unshift(0);
i--;
}
}
return res;
}
function fromArray32(arr) { //returns a BigInt
var res = BigInt(0);
const radix = BigInt(0x100000000);
for (let i = 0; i<arr.length; i++) {
res = res*radix + BigInt(arr[i]);
}
return res;
}
function flatArray(a) {
var res = [];
fillArray(res, a);
return res;
function fillArray(res, a) {
if (Array.isArray(a)) {
for (let i=0; i<a.length; i++) {
fillArray(res, a[i]);
}
} else {
res.push(a);
}
}
}
function fnvHash(str) {
const uint64_max = BigInt(2) ** BigInt(64);
let hash = BigInt("0xCBF29CE484222325");
for (var i = 0; i < str.length; i++) {
hash ^= BigInt(str[i].charCodeAt());
hash *= BigInt(0x100000001B3);
hash %= uint64_max;
}
let shash = hash.toString(16);
let n = 16 - shash.length;
shash = '0'.repeat(n).concat(shash);
return shash;
}

222
rln-wasm/src/lib.rs Normal file
View File

@ -0,0 +1,222 @@
extern crate wasm_bindgen;
extern crate web_sys;
use js_sys::{BigInt as JsBigInt, Object, Uint8Array};
use num_bigint::BigInt;
use rln::public::RLN;
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub fn init_panic_hook() {
console_error_panic_hook::set_once();
}
#[wasm_bindgen(js_name = RLN)]
pub struct RLNWrapper {
// The purpose of this wrapper is to hold a RLN instance with the 'static lifetime
// because wasm_bindgen does not allow returning elements with lifetimes
instance: RLN<'static>,
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = newRLN)]
pub fn wasm_new(tree_height: usize, zkey: Uint8Array, vk: Uint8Array) -> *mut RLNWrapper {
let instance = RLN::new_with_params(tree_height, zkey.to_vec(), vk.to_vec());
let wrapper = RLNWrapper { instance };
Box::into_raw(Box::new(wrapper))
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = getSerializedRLNWitness)]
pub fn wasm_get_serialized_rln_witness(ctx: *mut RLNWrapper, input: Uint8Array) -> Uint8Array {
let wrapper = unsafe { &mut *ctx };
let rln_witness = wrapper
.instance
.get_serialized_rln_witness(&input.to_vec()[..]);
Uint8Array::from(&rln_witness[..])
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = insertMember)]
pub fn wasm_set_next_leaf(ctx: *mut RLNWrapper, input: Uint8Array) -> Result<(), String> {
let wrapper = unsafe { &mut *ctx };
if wrapper.instance.set_next_leaf(&input.to_vec()[..]).is_ok() {
Ok(())
} else {
Err("could not insert member into merkle tree".into())
}
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = RLNWitnessToJson)]
pub fn rln_witness_to_json(ctx: *mut RLNWrapper, serialized_witness: Uint8Array) -> Object {
let wrapper = unsafe { &mut *ctx };
let inputs = wrapper
.instance
.get_rln_witness_json(&serialized_witness.to_vec()[..])
.unwrap();
let js_value = serde_wasm_bindgen::to_value(&inputs).unwrap();
let obj = Object::from_entries(&js_value);
obj.unwrap()
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen]
pub fn generate_rln_proof_with_witness(
ctx: *mut RLNWrapper,
calculated_witness: Vec<JsBigInt>,
serialized_witness: Uint8Array,
) -> Result<Uint8Array, String> {
let wrapper = unsafe { &mut *ctx };
let witness_vec: Vec<BigInt> = calculated_witness
.iter()
.map(|v| {
v.to_string(10)
.unwrap()
.as_string()
.unwrap()
.parse::<BigInt>()
.unwrap()
})
.collect();
let mut output_data: Vec<u8> = Vec::new();
if wrapper
.instance
.generate_rln_proof_with_witness(witness_vec, serialized_witness.to_vec(), &mut output_data)
.is_ok()
{
let result = Uint8Array::from(&output_data[..]);
std::mem::forget(output_data);
Ok(result)
} else {
std::mem::forget(output_data);
Err("could not generate proof".into())
}
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = generateMembershipKey)]
pub fn wasm_key_gen(ctx: *const RLNWrapper) -> Result<Uint8Array, String> {
let wrapper = unsafe { &*ctx };
let mut output_data: Vec<u8> = Vec::new();
if wrapper.instance.key_gen(&mut output_data).is_ok() {
let result = Uint8Array::from(&output_data[..]);
std::mem::forget(output_data);
Ok(result)
} else {
std::mem::forget(output_data);
Err("could not generate membership keys".into())
}
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = verifyProof)]
pub fn wasm_verify(ctx: *const RLNWrapper, proof: Uint8Array) -> bool {
let wrapper = unsafe { &*ctx };
if match wrapper.instance.verify(&proof.to_vec()[..]) {
Ok(verified) => verified,
Err(_) => return false,
} {
return true;
}
false
}
#[cfg(test)]
mod tests {
use super::*;
use rln::circuit::TEST_TREE_HEIGHT;
use wasm_bindgen_test::wasm_bindgen_test;
#[wasm_bindgen(module = "/src/utils.js")]
extern "C" {
#[wasm_bindgen(catch)]
fn read_file(path: &str) -> Result<Uint8Array, JsValue>;
#[wasm_bindgen(catch)]
async fn calculateWitness(circom_path: &str, input: Object) -> Result<JsValue, JsValue>;
}
#[wasm_bindgen_test]
pub async fn test_basic_flow() {
let tree_height = TEST_TREE_HEIGHT;
let circom_path = format!("../rln/resources/tree_height_{TEST_TREE_HEIGHT}/rln.wasm");
let zkey_path = format!("../rln/resources/tree_height_{TEST_TREE_HEIGHT}/rln_final.zkey");
let vk_path =
format!("../rln/resources/tree_height_{TEST_TREE_HEIGHT}/verification_key.json");
let zkey = read_file(&zkey_path).unwrap();
let vk = read_file(&vk_path).unwrap();
// Creating an instance of RLN
let rln_instance = wasm_new(tree_height, zkey, vk);
// Creating membership key
let mem_keys = wasm_key_gen(rln_instance).unwrap();
let idkey = mem_keys.subarray(0, 32);
let idcommitment = mem_keys.subarray(32, 64);
// Insert PK
wasm_set_next_leaf(rln_instance, idcommitment).unwrap();
// Prepare the message
let mut signal = "Hello World".as_bytes().to_vec();
let signal_len: u64 = signal.len() as u64;
// Setting up the epoch (With 0s for the test)
let epoch = Uint8Array::new_with_length(32);
epoch.fill(0, 0, 32);
let identity_index: u64 = 0;
// Serializing the message
let mut serialized_vec: Vec<u8> = Vec::new();
serialized_vec.append(&mut idkey.to_vec());
serialized_vec.append(&mut identity_index.to_le_bytes().to_vec());
serialized_vec.append(&mut epoch.to_vec());
serialized_vec.append(&mut signal_len.to_le_bytes().to_vec());
serialized_vec.append(&mut signal);
let serialized_message = Uint8Array::from(&serialized_vec[..]);
let serialized_rln_witness =
wasm_get_serialized_rln_witness(rln_instance, serialized_message);
// Obtaining inputs that should be sent to circom witness calculator
let json_inputs = rln_witness_to_json(rln_instance, serialized_rln_witness.clone());
// Calculating witness with JS
// (Using a JSON since wasm_bindgen does not like Result<Vec<JsBigInt>,JsValue>)
let calculated_witness_json = calculateWitness(&circom_path, json_inputs)
.await
.unwrap()
.as_string()
.unwrap();
let calculated_witness_vec_str: Vec<String> =
serde_json::from_str(&calculated_witness_json).unwrap();
let calculated_witness: Vec<JsBigInt> = calculated_witness_vec_str
.iter()
.map(|x| JsBigInt::new(&x.into()).unwrap())
.collect();
// Generating proof
let proof = generate_rln_proof_with_witness(
rln_instance,
calculated_witness.into(),
serialized_rln_witness,
)
.unwrap();
// Validate Proof
let is_proof_valid = wasm_verify(rln_instance, proof);
assert!(
is_proof_valid,
"validating proof generated with wasm failed"
);
}
}

18
rln-wasm/src/utils.js Normal file
View File

@ -0,0 +1,18 @@
const fs = require("fs");
// Utils functions for loading circom witness calculator and reading files from test
module.exports = {
read_file: function (path) {
return fs.readFileSync(path);
},
calculateWitness: async function(circom_path, inputs){
const wc = require("resources/witness_calculator.js");
const wasmFile = fs.readFileSync(circom_path);
const wasmFileBuffer = wasmFile.slice(wasmFile.byteOffset, wasmFile.byteOffset + wasmFile.byteLength);
const witnessCalculator = await wc(wasmFileBuffer);
const calculatedWitness = await witnessCalculator.calculateWitness(inputs, false);
return JSON.stringify(calculatedWitness, (key, value) => typeof value === "bigint" ? value.toString() : value);
}
}

View File

@ -6,18 +6,22 @@ edition = "2021"
[lib]
crate-type = ["cdylib", "rlib", "staticlib"]
[dependencies]
# ZKP Generation
ark-ff = { version = "0.3.0", default-features = false, features = ["parallel", "asm"] }
ark-std = { version = "0.3.0", default-features = false, features = ["parallel"] }
ark-ec = { version = "0.3.0", default-features = false }
ark-ff = { version = "0.3.0", default-features = false, features = [ "asm"] }
ark-std = { version = "0.3.0", default-features = false }
ark-bn254 = { version = "0.3.0" }
ark-groth16 = { git = "https://github.com/arkworks-rs/groth16", rev = "765817f", features = ["parallel"] }
ark-groth16 = { git = "https://github.com/arkworks-rs/groth16", rev = "765817f", default-features = false }
ark-relations = { version = "0.3.0", default-features = false, features = [ "std" ] }
ark-serialize = { version = "0.3.0", default-features = false }
ark-circom = { git = "https://github.com/gakonst/ark-circom", rev = "06eb075", features = ["circom-2"] }
ark-circom = { git = "https://github.com/vacp2p/ark-circom", branch = "wasm", default-features = false, features = ["circom-2"] }
#ark-circom = { git = "https://github.com/vacp2p/ark-circom", branch = "no-ethers-core", features = ["circom-2"] }
wasmer = "2.3.0"
# WASM
wasmer = { version = "2.3.0", default-features = false }
# error handling
color-eyre = "0.5.11"
@ -39,4 +43,7 @@ serde_json = "1.0.48"
hex-literal = "0.3.4"
[features]
default = ["parallel", "wasmer/sys-default"]
fullmerkletree = []
parallel = ["ark-ec/parallel", "ark-ff/parallel", "ark-std/parallel", "ark-groth16/parallel"]
wasm = ["wasmer/js", "wasmer/std"]

View File

@ -4,18 +4,25 @@ use ark_bn254::{
Bn254, Fq as ArkFq, Fq2 as ArkFq2, Fr as ArkFr, G1Affine as ArkG1Affine,
G1Projective as ArkG1Projective, G2Affine as ArkG2Affine, G2Projective as ArkG2Projective,
};
use ark_circom::{read_zkey, WitnessCalculator};
use ark_circom::read_zkey;
use ark_groth16::{ProvingKey, VerifyingKey};
use ark_relations::r1cs::ConstraintMatrices;
use cfg_if::cfg_if;
use num_bigint::BigUint;
use once_cell::sync::OnceCell;
use serde_json::Value;
use std::fs::File;
use std::io::{Cursor, Error, ErrorKind, Result};
use std::path::Path;
use std::str::FromStr;
use std::sync::Mutex;
use wasmer::{Module, Store};
cfg_if! {
if #[cfg(not(target_arch = "wasm32"))] {
use ark_circom::{WitnessCalculator};
use once_cell::sync::OnceCell;
use std::sync::Mutex;
use wasmer::{Module, Store};
}
}
const ZKEY_FILENAME: &str = "rln_final.zkey";
const VK_FILENAME: &str = "verifying_key.json";
@ -109,9 +116,11 @@ pub fn vk_from_folder(resources_folder: &str) -> Result<VerifyingKey<Curve>> {
}
}
#[cfg(not(target_arch = "wasm32"))]
static WITNESS_CALCULATOR: OnceCell<Mutex<WitnessCalculator>> = OnceCell::new();
// Initializes the witness calculator using a bytes vector
#[cfg(not(target_arch = "wasm32"))]
pub fn circom_from_raw(wasm_buffer: Vec<u8>) -> &'static Mutex<WitnessCalculator> {
WITNESS_CALCULATOR.get_or_init(|| {
let store = Store::default();
@ -123,6 +132,7 @@ pub fn circom_from_raw(wasm_buffer: Vec<u8>) -> &'static Mutex<WitnessCalculator
}
// Initializes the witness calculator
#[cfg(not(target_arch = "wasm32"))]
pub fn circom_from_folder(resources_folder: &str) -> &'static Mutex<WitnessCalculator> {
// We read the wasm file
let wasm_path = format!("{resources_folder}{WASM_FILENAME}");

View File

@ -59,7 +59,12 @@ pub extern "C" fn new_with_params(
let circom_data = <&[u8]>::from(unsafe { &*circom_buffer });
let zkey_data = <&[u8]>::from(unsafe { &*zkey_buffer });
let vk_data = <&[u8]>::from(unsafe { &*vk_buffer });
let rln = RLN::new_with_params(tree_height, circom_data, zkey_data, vk_data);
let rln = RLN::new_with_params(
tree_height,
circom_data.to_vec(),
zkey_data.to_vec(),
vk_data.to_vec(),
);
unsafe { *ctx = Box::into_raw(Box::new(rln)) };
true
}

View File

@ -1,7 +1,6 @@
#![allow(dead_code)]
pub mod circuit;
pub mod ffi;
pub mod merkle_tree;
pub mod poseidon_constants;
pub mod poseidon_hash;
@ -10,6 +9,9 @@ pub mod protocol;
pub mod public;
pub mod utils;
#[cfg(not(target_arch = "wasm32"))]
pub mod ffi;
#[cfg(test)]
mod test {

View File

@ -11,6 +11,7 @@ use ark_std::{rand::thread_rng, UniformRand};
use color_eyre::Result;
use num_bigint::BigInt;
use rand::Rng;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::Mutex;
#[cfg(debug_assertions)]
use std::time::Instant;
@ -22,6 +23,7 @@ use crate::poseidon_hash::poseidon_hash;
use crate::poseidon_tree::*;
use crate::public::RLN_IDENTIFIER;
use crate::utils::*;
use cfg_if::cfg_if;
///////////////////////////////////////////////////////
// RLN Witness data structure and utility functions
@ -121,12 +123,9 @@ pub fn proof_inputs_to_rln_witness(
let signal_len = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap());
all_read += 8;
let signal: Vec<u8> =
serialized[all_read..all_read + usize::try_from(signal_len).unwrap()].to_vec();
let signal: Vec<u8> = serialized[all_read..all_read + (signal_len as usize)].to_vec();
let merkle_proof = tree
.proof(usize::try_from(id_index).unwrap())
.expect("proof should exist");
let merkle_proof = tree.proof(id_index as usize).expect("proof should exist");
let path_elements = merkle_proof.get_path_elements();
let identity_path_index = merkle_proof.get_path_index();
@ -374,16 +373,70 @@ pub enum ProofError {
SynthesisError(#[from] SynthesisError),
}
/// Generates a RLN proof
///
/// # Errors
///
/// Returns a [`ProofError`] if proving fails.
pub fn generate_proof(
witness_calculator: &Mutex<WitnessCalculator>,
fn calculate_witness_element<E: ark_ec::PairingEngine>(witness: Vec<BigInt>) -> Result<Vec<E::Fr>> {
use ark_ff::{FpParameters, PrimeField};
let modulus = <<E::Fr as PrimeField>::Params as FpParameters>::MODULUS;
// convert it to field elements
use num_traits::Signed;
let witness = witness
.into_iter()
.map(|w| {
let w = if w.sign() == num_bigint::Sign::Minus {
// Need to negate the witness element if negative
modulus.into() - w.abs().to_biguint().unwrap()
} else {
w.to_biguint().unwrap()
};
E::Fr::from(w)
})
.collect::<Vec<_>>();
Ok(witness)
}
pub fn generate_proof_with_witness(
witness: Vec<BigInt>,
proving_key: &(ProvingKey<Curve>, ConstraintMatrices<Fr>),
rln_witness: &RLNWitnessInput,
) -> Result<ArkProof<Curve>, ProofError> {
// If in debug mode, we measure and later print time take to compute witness
#[cfg(debug_assertions)]
let now = Instant::now();
let full_assignment = calculate_witness_element::<Curve>(witness)
.map_err(ProofError::WitnessError)
.unwrap();
#[cfg(debug_assertions)]
println!("witness generation took: {:.2?}", now.elapsed());
// Random Values
let mut rng = thread_rng();
let r = Fr::rand(&mut rng);
let s = Fr::rand(&mut rng);
// If in debug mode, we measure and later print time take to compute proof
#[cfg(debug_assertions)]
let now = Instant::now();
let proof = create_proof_with_reduction_and_matrices::<_, CircomReduction>(
&proving_key.0,
r,
s,
&proving_key.1,
proving_key.1.num_instance_variables,
proving_key.1.num_constraints,
full_assignment.as_slice(),
)
.unwrap();
#[cfg(debug_assertions)]
println!("proof generation took: {:.2?}", now.elapsed());
Ok(proof)
}
pub fn inputs_for_witness_calculation(rln_witness: &RLNWitnessInput) -> [(&str, Vec<BigInt>); 6] {
// We confert the path indexes to field elements
// TODO: check if necessary
let mut path_elements = Vec::new();
@ -398,7 +451,7 @@ pub fn generate_proof(
.iter()
.for_each(|v| identity_path_index.push(BigInt::from(*v)));
let inputs = [
[
(
"identity_secret",
vec![to_bigint(&rln_witness.identity_secret)],
@ -411,8 +464,21 @@ pub fn generate_proof(
"rln_identifier",
vec![to_bigint(&rln_witness.rln_identifier)],
),
];
let inputs = inputs
]
}
/// Generates a RLN proof
///
/// # Errors
///
/// Returns a [`ProofError`] if proving fails.
pub fn generate_proof(
#[cfg(not(target_arch = "wasm32"))] witness_calculator: &Mutex<WitnessCalculator>,
#[cfg(target_arch = "wasm32")] witness_calculator: &mut WitnessCalculator,
proving_key: &(ProvingKey<Curve>, ConstraintMatrices<Fr>),
rln_witness: &RLNWitnessInput,
) -> Result<ArkProof<Curve>, ProofError> {
let inputs = inputs_for_witness_calculation(rln_witness)
.into_iter()
.map(|(name, values)| (name.to_string(), values));
@ -420,11 +486,19 @@ pub fn generate_proof(
#[cfg(debug_assertions)]
let now = Instant::now();
let full_assignment = witness_calculator
.lock()
.expect("witness_calculator mutex should not get poisoned")
.calculate_witness_element::<Curve, _>(inputs, false)
.map_err(ProofError::WitnessError)?;
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
let full_assignment = witness_calculator
.calculate_witness_element::<Curve, _>(inputs, false)
.map_err(ProofError::WitnessError)?;
} else {
let full_assignment = witness_calculator
.lock()
.expect("witness_calculator mutex should not get poisoned")
.calculate_witness_element::<Curve, _>(inputs, false)
.map_err(ProofError::WitnessError)?;
}
}
#[cfg(debug_assertions)]
println!("witness generation took: {:.2?}", now.elapsed());
@ -490,3 +564,32 @@ pub fn verify_proof(
Ok(verified)
}
/// Get CIRCOM JSON inputs
///
/// Returns a JSON object containing the inputs necessary to calculate
/// the witness with CIRCOM on javascript
pub fn get_json_inputs(rln_witness: &RLNWitnessInput) -> serde_json::Value {
let mut path_elements = Vec::new();
rln_witness
.path_elements
.iter()
.for_each(|v| path_elements.push(to_bigint(v).to_str_radix(10)));
let mut identity_path_index = Vec::new();
rln_witness
.identity_path_index
.iter()
.for_each(|v| identity_path_index.push(BigInt::from(*v).to_str_radix(10)));
let inputs = serde_json::json!({
"identity_secret": to_bigint(&rln_witness.identity_secret).to_str_radix(10),
"path_elements": path_elements,
"identity_path_index": identity_path_index,
"x": to_bigint(&rln_witness.x).to_str_radix(10),
"epoch": format!("0x{:064x}", to_bigint(&rln_witness.epoch)),
"rln_identifier": to_bigint(&rln_witness.rln_identifier).to_str_radix(10),
});
inputs
}

View File

@ -1,21 +1,28 @@
/// This is the main public API for RLN module. It is used by the FFI, and should be
/// used by tests etc as well
use ark_circom::WitnessCalculator;
use ark_groth16::Proof as ArkProof;
use ark_groth16::{ProvingKey, VerifyingKey};
use ark_relations::r1cs::ConstraintMatrices;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use std::default::Default;
use std::io::{self, Cursor, Read, Result, Write};
use std::sync::Mutex;
use crate::circuit::{
circom_from_folder, circom_from_raw, vk_from_folder, vk_from_raw, zkey_from_folder,
zkey_from_raw, Curve, Fr, TEST_RESOURCES_FOLDER, TEST_TREE_HEIGHT,
};
use crate::circuit::{vk_from_raw, zkey_from_raw, Curve, Fr};
use crate::poseidon_tree::PoseidonTree;
use crate::protocol::*;
use crate::utils::*;
/// This is the main public API for RLN module. It is used by the FFI, and should be
/// used by tests etc as well
use ark_groth16::Proof as ArkProof;
use ark_groth16::{ProvingKey, VerifyingKey};
use ark_relations::r1cs::ConstraintMatrices;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, Write};
use cfg_if::cfg_if;
use num_bigint::BigInt;
use std::io::Cursor;
use std::io::{self, Result};
cfg_if! {
if #[cfg(not(target_arch = "wasm32"))] {
use std::default::Default;
use std::sync::Mutex;
use crate::circuit::{circom_from_folder, vk_from_folder, circom_from_raw, zkey_from_folder, TEST_RESOURCES_FOLDER, TEST_TREE_HEIGHT};
use ark_circom::WitnessCalculator;
} else {
use std::marker::*;
}
}
// Application specific RLN identifier
pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809";
@ -23,13 +30,21 @@ pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809";
// TODO Add Engine here? i.e. <E: Engine> not <Curve>
// TODO Assuming we want to use IncrementalMerkleTree, figure out type/trait conversions
pub struct RLN<'a> {
witness_calculator: &'a Mutex<WitnessCalculator>,
proving_key: Result<(ProvingKey<Curve>, ConstraintMatrices<Fr>)>,
verification_key: Result<VerifyingKey<Curve>>,
tree: PoseidonTree,
// The witness calculator can't be loaded in zerokit. Since this struct
// contains a lifetime, a PhantomData is necessary to avoid a compiler
// error since the lifetime is not being used
#[cfg(not(target_arch = "wasm32"))]
witness_calculator: &'a Mutex<WitnessCalculator>,
#[cfg(target_arch = "wasm32")]
_marker: PhantomData<&'a ()>,
}
impl RLN<'_> {
#[cfg(not(target_arch = "wasm32"))]
pub fn new<R: Read>(tree_height: usize, mut input_data: R) -> RLN<'static> {
// We read input
let mut input: Vec<u8> = Vec::new();
@ -50,23 +65,18 @@ impl RLN<'_> {
proving_key,
verification_key,
tree,
#[cfg(target_arch = "wasm32")]
_marker: PhantomData,
}
}
pub fn new_with_params<R: Read>(
pub fn new_with_params(
tree_height: usize,
mut circom_data: R,
mut zkey_data: R,
mut vk_data: R,
#[cfg(not(target_arch = "wasm32"))] circom_vec: Vec<u8>,
zkey_vec: Vec<u8>,
vk_vec: Vec<u8>,
) -> RLN<'static> {
// We read input
let mut circom_vec: Vec<u8> = Vec::new();
circom_data.read_to_end(&mut circom_vec).unwrap();
let mut zkey_vec: Vec<u8> = Vec::new();
zkey_data.read_to_end(&mut zkey_vec).unwrap();
let mut vk_vec: Vec<u8> = Vec::new();
vk_data.read_to_end(&mut vk_vec).unwrap();
#[cfg(not(target_arch = "wasm32"))]
let witness_calculator = circom_from_raw(circom_vec);
let proving_key = zkey_from_raw(&zkey_vec);
@ -76,10 +86,13 @@ impl RLN<'_> {
let tree = PoseidonTree::default(tree_height);
RLN {
#[cfg(not(target_arch = "wasm32"))]
witness_calculator,
proving_key,
verification_key,
tree,
#[cfg(target_arch = "wasm32")]
_marker: PhantomData,
}
}
@ -165,6 +178,7 @@ impl RLN<'_> {
////////////////////////////////////////////////////////
// zkSNARK APIs
////////////////////////////////////////////////////////
#[cfg(not(target_arch = "wasm32"))]
pub fn prove<R: Read, W: Write>(
&mut self,
mut input_data: R,
@ -182,7 +196,7 @@ impl RLN<'_> {
*/
let proof = generate_proof(
self.witness_calculator,
&mut self.witness_calculator,
self.proving_key.as_ref().unwrap(),
&rln_witness,
)
@ -213,9 +227,29 @@ impl RLN<'_> {
Ok(verified)
}
/// Get the serialized rln_witness for some input
pub fn get_serialized_rln_witness<R: Read>(&mut self, mut input_data: R) -> Vec<u8> {
// We read input RLN witness and we deserialize it
let mut witness_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut witness_byte).unwrap();
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte);
serialize_witness(&rln_witness)
}
/// Get JSON inputs for serialized RLN witness
pub fn get_rln_witness_json(
&mut self,
serialized_witness: &[u8],
) -> io::Result<serde_json::Value> {
let (rln_witness, _) = deserialize_witness(serialized_witness);
Ok(get_json_inputs(&rln_witness))
}
// This API keeps partial compatibility with kilic's rln public API https://github.com/kilic/rln/blob/7ac74183f8b69b399e3bc96c1ae8ab61c026dc43/src/public.rs#L148
// input_data is [ id_key<32> | id_index<8> | epoch<32> | signal_len<8> | signal<var> ]
// output_data is [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> ]
#[cfg(not(target_arch = "wasm32"))]
pub fn generate_rln_proof<R: Read, W: Write>(
&mut self,
mut input_data: R,
@ -242,6 +276,29 @@ impl RLN<'_> {
Ok(())
}
/// Generate RLN Proof using a witness calculated from outside zerokit
///
/// output_data is [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> ]
pub fn generate_rln_proof_with_witness<W: Write>(
&mut self,
calculated_witness: Vec<BigInt>,
rln_witness_vec: Vec<u8>,
mut output_data: W,
) -> io::Result<()> {
let (rln_witness, _) = deserialize_witness(&rln_witness_vec[..]);
let proof_values = proof_values_from_witness(&rln_witness);
let proof =
generate_proof_with_witness(calculated_witness, self.proving_key.as_ref().unwrap())
.unwrap();
// Note: we export a serialization of ark-groth16::Proof not semaphore::Proof
// This proof is compressed, i.e. 128 bytes long
proof.serialize(&mut output_data).unwrap();
output_data.write_all(&serialize_proof_values(&proof_values))?;
Ok(())
}
// Input data is serialized for Curve as:
// [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> | signal_len<8> | signal<var> ]
pub fn verify_rln_proof<R: Read>(&self, mut input_data: R) -> io::Result<bool> {
@ -253,10 +310,8 @@ impl RLN<'_> {
let (proof_values, read) = deserialize_proof_values(&serialized[all_read..].to_vec());
all_read += read;
let signal_len = usize::try_from(u64::from_le_bytes(
serialized[all_read..all_read + 8].try_into().unwrap(),
))
.unwrap();
let signal_len =
u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()) as usize;
all_read += 8;
let signal: Vec<u8> = serialized[all_read..all_read + signal_len].to_vec();
@ -299,6 +354,7 @@ impl RLN<'_> {
}
}
#[cfg(not(target_arch = "wasm32"))]
impl Default for RLN<'_> {
fn default() -> Self {
let tree_height = TEST_TREE_HEIGHT;
@ -673,6 +729,101 @@ mod test {
assert!(verified);
}
#[test]
fn test_rln_with_witness() {
let tree_height = TEST_TREE_HEIGHT;
let no_of_leaves = 256;
// We generate a vector of random leaves
let mut leaves: Vec<Fr> = Vec::new();
let mut rng = thread_rng();
for _ in 0..no_of_leaves {
leaves.push(Fr::rand(&mut rng));
}
// We create a new RLN instance
let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER);
let mut rln = RLN::new(tree_height, input_buffer);
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.set_leaves(&mut buffer).unwrap();
// Generate identity pair
let (identity_secret, id_commitment) = keygen();
// We set as leaf id_commitment after storing its index
let identity_index = u64::try_from(rln.tree.leaves_set()).unwrap();
let mut buffer = Cursor::new(fr_to_bytes_le(&id_commitment));
rln.set_next_leaf(&mut buffer).unwrap();
// We generate a random signal
let mut rng = rand::thread_rng();
let signal: [u8; 32] = rng.gen();
let signal_len = u64::try_from(signal.len()).unwrap();
// We generate a random epoch
let epoch = hash_to_field(b"test-epoch");
// We prepare input for generate_rln_proof API
// input_data is [ id_key<32> | id_index<8> | epoch<32> | signal_len<8> | signal<var> ]
let mut serialized: Vec<u8> = Vec::new();
serialized.append(&mut fr_to_bytes_le(&identity_secret));
serialized.append(&mut identity_index.to_le_bytes().to_vec());
serialized.append(&mut fr_to_bytes_le(&epoch));
serialized.append(&mut signal_len.to_le_bytes().to_vec());
serialized.append(&mut signal.to_vec());
let mut input_buffer = Cursor::new(serialized);
// We read input RLN witness and we deserialize it
let mut witness_byte: Vec<u8> = Vec::new();
input_buffer.read_to_end(&mut witness_byte).unwrap();
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut rln.tree, &witness_byte);
let serialized_witness = serialize_witness(&rln_witness);
// Calculate witness outside zerokit (simulating what JS is doing)
let inputs = inputs_for_witness_calculation(&rln_witness)
.into_iter()
.map(|(name, values)| (name.to_string(), values));
let calculated_witness = rln
.witness_calculator
.lock()
.expect("witness_calculator mutex should not get poisoned")
.calculate_witness_element::<Curve, _>(inputs, false)
.map_err(ProofError::WitnessError)
.unwrap();
let calculated_witness_vec: Vec<BigInt> = calculated_witness
.into_iter()
.map(|v| to_bigint(&v))
.collect();
// Generating the proof
let mut output_buffer = Cursor::new(Vec::<u8>::new());
rln.generate_rln_proof_with_witness(
calculated_witness_vec,
serialized_witness,
&mut output_buffer,
)
.unwrap();
// output_data is [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> ]
let mut proof_data = output_buffer.into_inner();
// We prepare input for verify_rln_proof API
// input_data is [ proof<128> | share_y<32> | nullifier<32> | root<32> | epoch<32> | share_x<32> | rln_identifier<32> | signal_len<8> | signal<var> ]
// that is [ proof_data || signal_len<8> | signal<var> ]
proof_data.append(&mut signal_len.to_le_bytes().to_vec());
proof_data.append(&mut signal.to_vec());
let mut input_buffer = Cursor::new(proof_data);
let verified = rln.verify_rln_proof(&mut input_buffer).unwrap();
assert!(verified);
}
#[test]
fn test_hash_to_field() {
let rln = RLN::default();

View File

@ -85,7 +85,7 @@ pub fn fr_to_bytes_be(input: &Fr) -> Vec<u8> {
pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Vec<u8> {
let mut bytes: Vec<u8> = Vec::new();
//We store the vector length
bytes.extend(input.len().to_le_bytes().to_vec());
bytes.extend(u64::try_from(input.len()).unwrap().to_le_bytes().to_vec());
// We store each element
input.iter().for_each(|el| bytes.extend(fr_to_bytes_le(el)));
@ -95,7 +95,7 @@ pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Vec<u8> {
pub fn vec_fr_to_bytes_be(input: &[Fr]) -> Vec<u8> {
let mut bytes: Vec<u8> = Vec::new();
//We store the vector length
bytes.extend(input.len().to_be_bytes().to_vec());
bytes.extend(u64::try_from(input.len()).unwrap().to_be_bytes().to_vec());
// We store each element
input.iter().for_each(|el| bytes.extend(fr_to_bytes_be(el)));
@ -121,7 +121,7 @@ pub fn vec_u8_to_bytes_be(input: Vec<u8>) -> Vec<u8> {
pub fn bytes_le_to_vec_u8(input: &[u8]) -> (Vec<u8>, usize) {
let mut read: usize = 0;
let len = usize::try_from(u64::from_le_bytes(input[0..8].try_into().unwrap())).unwrap();
let len = u64::from_le_bytes(input[0..8].try_into().unwrap()) as usize;
read += 8;
let res = input[8..8 + len].to_vec();
@ -133,7 +133,7 @@ pub fn bytes_le_to_vec_u8(input: &[u8]) -> (Vec<u8>, usize) {
pub fn bytes_be_to_vec_u8(input: &[u8]) -> (Vec<u8>, usize) {
let mut read: usize = 0;
let len = usize::try_from(u64::from_be_bytes(input[0..8].try_into().unwrap())).unwrap();
let len = u64::from_be_bytes(input[0..8].try_into().unwrap()) as usize;
read += 8;
let res = input[8..8 + len].to_vec();
@ -147,7 +147,7 @@ pub fn bytes_le_to_vec_fr(input: &[u8]) -> (Vec<Fr>, usize) {
let mut read: usize = 0;
let mut res: Vec<Fr> = Vec::new();
let len = usize::try_from(u64::from_le_bytes(input[0..8].try_into().unwrap())).unwrap();
let len = u64::from_le_bytes(input[0..8].try_into().unwrap()) as usize;
read += 8;
let el_size = fr_byte_size();
@ -164,7 +164,7 @@ pub fn bytes_be_to_vec_fr(input: &[u8]) -> (Vec<Fr>, usize) {
let mut read: usize = 0;
let mut res: Vec<Fr> = Vec::new();
let len = usize::try_from(u64::from_be_bytes(input[0..8].try_into().unwrap())).unwrap();
let len = u64::from_be_bytes(input[0..8].try_into().unwrap()) as usize;
read += 8;
let el_size = fr_byte_size();